In [None]:
# Re-import edited project modules (e.g. rpm, expFam, imported further down) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import torchvision
import torch

import matplotlib.pyplot as plt
import numpy as np

import torchvision.transforms.functional as TF

# Reproducibility: seed both global RNGs used below. numpy drives the peer sampling and the
# random augmentations; torch drives weight initialisation and DataLoader shuffling.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

root = './data/MNIST'
# download=True so that "Restart & Run All" works on a fresh checkout; torchvision skips the
# download when the files are already present under `root`.
ds0train = torchvision.datasets.MNIST(root=root, train=True, download=True)
ds0test = torchvision.datasets.MNIST(root=root, train=False, download=True)

# NOTE(review): `dtype` is not referenced in any of the visible cells.
dtype=torch.float32

def rotate(images, angles=(0, 120, 240)):
    """Rotate every image by an angle drawn uniformly at random from `angles`.

    Parameters
    ----------
    images : torch.Tensor
        Batch of single-channel images, indexed by the first axis; each ``images[i]`` is a
        single image (presumably (H, W) - TODO confirm against callers).
    angles : sequence of float, optional
        Candidate angles in degrees, as interpreted by
        ``torchvision.transforms.functional.rotate``. Defaults to the original 0/120/240.

    Returns
    -------
    torch.Tensor
        New tensor with the same shape as `images`; image ``i`` is rotated by
        ``angles[idx[i]]`` where ``idx`` is drawn from numpy's global RNG.
        An empty batch returns an empty tensor.
    """
    # One angle index per image, drawn exactly as before so the RNG stream is unchanged.
    idx = np.random.randint(0, len(angles), len(images))
    with torch.no_grad():
        out = images.clone()
        # Rotate all images that share an angle in a single batched call instead of one
        # call per image.
        for k, angle in enumerate(angles):
            mask = torch.from_numpy(idx == k)
            if mask.any():
                out[mask] = TF.rotate(images[mask], angle)
    return out

def colorize(images, colors=None):
    """Tint each grayscale image with a colour drawn at random from a palette.

    Parameters
    ----------
    images : torch.Tensor
        Batch of shape (B, H, W); each image is broadcast over 3 colour channels.
    colors : array-like of shape (P, 3), optional
        RGB palette. Defaults to the original five fixed colours scaled by 1/256.

    Returns
    -------
    torch.Tensor
        Tensor of shape (B, 3, H, W) with ``out[b] = images[b] * colors[c_b][:, None, None]``,
        where the palette index ``c_b`` is drawn from numpy's global RNG. An empty batch
        returns an empty (0, 3, H, W) tensor.
    """
    if colors is None:
        colors = np.array(
            [[166, 206, 227],
             [31, 120, 180],
             [178, 223, 138],
             [51, 160, 44],
             [251, 154, 153]],
        ) / 256.
    palette = torch.as_tensor(np.asarray(colors, dtype=np.float32))
    # One palette index per image, drawn exactly as before so the RNG stream is unchanged.
    idx = np.random.randint(0, len(palette), len(images))
    with torch.no_grad():
        # (B, 1, H, W) * (B, 3, 1, 1) -> (B, 3, H, W): vectorised instead of a per-image loop.
        per_image = palette[torch.as_tensor(idx, dtype=torch.long)][:, :, None, None]
        out = images.unsqueeze(1) * per_image
    return out

# transforms_j[j] produces view j of every peer tuple (view 0: rotation, view 1: colour tint).
transforms_j = (rotate, colorize)

class Peersupervision(torch.utils.data.Dataset):
    """Dataset of "peer tuples": J items sharing a target, each passed through its own transform.

    Indexing with an int or an integer array ``idx`` returns ``(views, idx)`` where
    ``views[j] = transforms[j](data[pair_ids[:, j]])``. Column 0 of ``pair_ids`` is ``idx``
    itself. The remaining columns come from J draws without replacement (numpy's global RNG)
    among the items that have the same target as ``idx``, so they may coincide with ``idx``.
    With ``ifstack=True`` the views are stacked along dim 1 instead of returned as a list.
    """

    def __init__(self, data, targets, J, transforms, ifstack=True):
        self.data = data
        self.targets = targets.detach().numpy()
        self.J = J
        self.ifstack = ifstack
        assert len(transforms) == J
        self.transforms = transforms
        # Per-class index arrays, built once. Previously every item rescanned all N targets,
        # which made materialising the whole dataset O(N^2).
        self._ids_by_class = {c: np.flatnonzero(self.targets == c) for c in np.unique(self.targets)}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        ids = np.atleast_1d(np.asarray(idx))
        # Integer dtype is required: the old np.zeros default (float64) cannot index a tensor.
        pair_ids = np.empty((len(ids), self.J), dtype=np.int64)
        for row, c in enumerate(self.targets[ids]):
            same_class = self._ids_by_class[c]
            pair_ids[row] = same_class[np.random.choice(len(same_class), self.J, replace=False)]
        pair_ids[:, 0] = ids

        out = [self.transforms[j](self.data[pair_ids[:, j]]) for j in range(self.J)]
        return (torch.stack(out, dim=1), idx) if self.ifstack else (out, idx)

N, J = len(ds0train), 2  # number of training items, views per peer tuple


def _materialize_peer_tuples(peer_ds, n_items):
    """Draw one peer tuple for every index of `peer_ds`; return ``[[view_0, view_1], indices]``.

    View 0 gets an explicit channel axis via ``unsqueeze(1)`` (the view-0 network is built with
    C_in=1); view 1 is passed through unchanged.
    """
    views, indices = peer_ds[np.arange(n_items)]
    gray = views[0].unsqueeze(1)
    return [[gray, views[1]], indices]


ds_train = Peersupervision(data=ds0train.data/256., targets=ds0train.targets, J=J,
                           transforms=transforms_j, ifstack=False)
train_data = _materialize_peer_tuples(ds_train, N)
train_labels = ds0train.targets

ds_test = Peersupervision(data=ds0test.data/256., targets=ds0test.targets, J=J,
                          transforms=transforms_j, ifstack=False)
test_data = _materialize_peer_tuples(ds_test, len(ds0test))
test_labels = ds0test.targets


In [None]:
# Show the first few peer tuples: one row per tuple, one column per view.
n_show = np.minimum(N, 5)
plt.figure(figsize=(4,6))
for row in range(n_show):
    for col in range(J):
        plt.subplot(n_show, J, row * J + col + 1)
        if ds_train.ifstack:
            img = train_data[0][row, col].detach().numpy()
        else:
            img = train_data[0][col][row].detach().numpy()
        # view 0 carries a single channel; view 1 (colorize output) is channel-first RGB
        if col == 0:
            plt.imshow(img[0], cmap='gray')
        else:
            plt.imshow(img.transpose(1, 2, 0))
        plt.axis('off')
plt.suptitle('First ' + str(n_show) + ' out of N= ' + str(N) + ' peer tuples of size J =' + str(J))
plt.show()

In [None]:
from rpm import RPMEmpiricalMarginals, EmpiricalDistribution, RPM
from expFam import LogPartition_gauss_diagonal, ExpFam, ConditionalExpFam


K = len(np.unique(train_labels.detach().numpy()))  # number of distinct training labels
dim_Z = 2  # dimension of the latent variable Z
# Diagonal Gaussian over Z: one mean-type and one precision-type natural parameter per latent
# dimension, so the sufficient statistics have twice the latent dimension (= 4, as before).
dim_T = 2 * dim_Z  # dimension of sufficient statistics

class Net(torch.nn.Module):
    """Small two-stage CNN mapping a (B, C_in, H, W) batch to (B, n_out) outputs.

    Each stage is a 5x5 convolution, 2x2 max-pool and ReLU. The flattened feature map feeds a
    hidden ReLU layer and an output layer, whose result is passed through ``activation_out``.
    ``fc1`` is sized for a 4x4 final feature map, i.e. 28x28 inputs:
    28 -> 24 (conv1) -> 12 (pool) -> 8 (conv2) -> 4 (pool).
    One class definition is used for every view; each view gets its own instance.
    """

    def __init__(self, C_in, n_out, C_hidden, n_hidden, activation_out=torch.nn.Identity()):
        super().__init__()
        self.activation_out = activation_out
        self.conv1 = torch.nn.Conv2d(C_in, C_hidden, kernel_size=5)
        self.conv2 = torch.nn.Conv2d(C_hidden, 2*C_hidden, kernel_size=5)
        # NOTE(review): never applied in forward(); kept so the module structure is unchanged.
        self.conv2_drop = torch.nn.Dropout2d()
        self.fc1 = torch.nn.Linear(4*4*2*C_hidden, n_hidden)
        self.fc2 = torch.nn.Linear(n_hidden, n_out)

    def forward(self, x):
        x = torch.nn.functional.relu(torch.nn.functional.max_pool2d(self.conv1(x), 2))
        x = torch.nn.functional.relu(torch.nn.functional.max_pool2d(self.conv2(x), 2))
        # Flatten per sample. The previous view(-1, 4*4*20) hard-coded C_hidden=10 and could
        # silently fold channels into the batch axis for other widths.
        x = x.flatten(start_dim=1)
        x = torch.nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return self.activation_out(x)

# define Gaussian prior in natural parametrization
def activation_out(x, d=None):
    """Map raw network outputs to valid diagonal-Gaussian natural parameters.

    The last axis of ``x`` is split into ``d`` leading unconstrained entries (the m/sig2 part)
    followed by the remaining entries, which are forced negative via ``-softplus`` (the
    -1/(2*sig2) part). This assumes the natural-parameter layout is [all m/sig2, all
    -1/(2*sig2)] - TODO confirm against LogPartition_gauss_diagonal.

    Parameters
    ----------
    x : torch.Tensor
        Raw outputs; natural parameters lie along the last axis.
    d : int, optional
        Number of leading unconstrained entries. Defaults to half the last-axis size, i.e. one
        mean-type and one precision-type entry per latent dimension. The previous fixed default
        d=1 only fit a 1-D Gaussian; with a 4-dim output it forced the second mean-type entry
        negative. For a 2-dim output the new default is still d=1.

    Raises
    ------
    ValueError
        If ``d`` is omitted and the last axis has odd length.
    """
    if d is None:
        if x.shape[-1] % 2:
            raise ValueError('activation_out: last axis must have even length when d is not given')
        d = x.shape[-1] // 2
    return torch.cat([x[..., :d], -torch.nn.functional.softplus(x[..., d:])], dim=-1)
# Log-partition of a diagonal-covariance Gaussian over the dim_Z-dimensional latent; the same
# object is shared by the prior and every recognition factor below.
log_partition = LogPartition_gauss_diagonal(d=dim_Z)

# Latent prior: a free (1, dim_T) natural-parameter vector initialised from a standard normal
# (presumably mapped through activation_out inside ExpFam - TODO confirm).
# NOTE: this torch.normal draw and the Net constructions below all consume the torch RNG,
# so their order determines the initial parameters for a given seed.
latent_prior = ExpFam(natparam=torch.normal(mean=0.0, std=torch.ones(dim_T).reshape(1,-1)),
                                    log_partition=log_partition, activation_out=activation_out)

C_ins_j = [1,3] # number of input channels (per j) for CNNs
# (view 0 = rotated grayscale digit with 1 channel, view 1 = colorized digit with 3 channels)
# One recognition CNN per view, each emitting dim_T natural parameters.
natparam_models = [Net(C_in=C_ins_j[j], 
                       n_out=dim_T, 
                       C_hidden=10, 
                       n_hidden=50, 
                       activation_out=activation_out) for j in range(J)]
rec_factors = [ConditionalExpFam(model=natparam_models[j], log_partition=log_partition) for j in range(J)]

# Auxiliary networks passed to RPM as `nu` (together with iviNatParametrization='delta').
# Each takes sum(C_ins_j) = 4 input channels - presumably all views concatenated along the
# channel axis; TODO confirm against RPM.
ivi_natparam_models = [Net(C_in=sum(C_ins_j), n_out=dim_T, C_hidden=10, n_hidden=50, activation_out=activation_out) for j in range(J)]
ivi_rec_models = [ConditionalExpFam(model=m, log_partition=log_partition) for m in ivi_natparam_models]

nu = ivi_rec_models

# Full training-set tensors of every view; RPMEmpiricalMarginals presumably represents the
# empirical marginal of each view over these N items.
xjs = [train_data[0][j] for j in range(J)]
pxjs = RPMEmpiricalMarginals(xjs)

# Assemble the model from the per-view recognition factors, the prior and the marginals.
# The meaning of q / nu / the two flags is defined in the rpm module (not visible here).
rpm = RPM(rec_factors, latent_prior=latent_prior, px=pxjs, 
          q='use_theta',  
          nu=nu, iviNatParametrization='delta',
          stack_xjs_new_axis=False, full_N_for_Fj=False)


In [None]:
import matplotlib.pyplot as plt
optimizer = torch.optim.Adam(rpm.parameters(), lr=1e-3)

epochs = 1
batch_size = 16

# All views plus the peer-tuple indices in one TensorDataset, so the DataLoader shuffles them jointly.
ds_load = torch.utils.data.TensorDataset(*train_data[0], torch.tensor(train_data[1]))
dl = torch.utils.data.DataLoader(dataset=ds_load, batch_size=batch_size, shuffle=True, drop_last=True)

# One loss entry per optimisation step; len(dl) already accounts for drop_last.
ls, t = np.zeros(epochs * len(dl)), 0
for i in range(epochs):
    epoch_start = t
    for batch in dl:
        optimizer.zero_grad()
        # NOTE(review): only the ifstack=False branch is exercised here; with ifstack=True
        # batch[0] is a single tensor that *batch would unpack along its first axis.
        batch = batch[0] if ds_train.ifstack else ((batch[0], batch[1]), batch[2])
        loss = rpm.training_step(*batch, batch_idx=t)
        loss.backward()
        optimizer.step()
        ls[t] = loss.item()
        t += 1
    # Report the mean over the epoch rather than the (noisy) loss of the last mini-batch.
    print('epoch #' + str(i+1) + '/' + str(epochs) + ', mean loss : ' + str(ls[epoch_start:t].mean()))
plt.plot(ls)
plt.xlabel('optimisation step')
plt.ylabel('training loss')
plt.show()


In [None]:
# Posterior means of the recognition model on all training tuples, one 2-D histogram per class.
with torch.no_grad():  # evaluation only: don't build an autograd graph over all N tuples
    eta_q = rpm.comp_eta_q(xjs)[0]
    mu_q = log_partition.nat2meanparam(eta_q)
mu_np = mu_q.detach().numpy()
labels_np = train_labels.detach().numpy()
# Shared bin range so per-class panels are comparable (each class used to be binned over its
# own min/max).
x_range = (mu_np[:, 0].min(), mu_np[:, 0].max())
y_range = (mu_np[:, 1].min(), mu_np[:, 1].max())
n_rows = -(-K // 5)  # ceil(K / 5) rows of 5 panels
for k in range(K):
    plt.subplot(n_rows, 5, k + 1)
    counts, _, _ = np.histogram2d(mu_np[labels_np == k, 0], mu_np[labels_np == k, 1],
                                  range=[x_range, y_range])
    # histogram2d indexes counts as [x_bin, y_bin]; transposing with origin='lower' puts mu_0
    # on the horizontal axis and mu_1 on the vertical axis.
    plt.imshow(counts.T, origin='lower', extent=(*x_range, *y_range), aspect='auto')
    plt.title(str(k))
plt.show()

In [None]:
# Marginal histograms of each posterior-mean coordinate (columns) for each label (rows).
mu_np = mu_q.detach().numpy()
labels_np = train_labels.detach().numpy()
for dim in range(dim_Z):
    lo, hi = mu_np[:, dim].min(), mu_np[:, dim].max()
    for digit in range(10):
        plt.subplot(10, 2, 2 * digit + dim + 1)
        plt.hist(mu_np[labels_np == digit, dim])
        plt.xlim(lo, hi)
plt.show()

In [None]:
# Continue training the same model and optimizer on the same DataLoader for more epochs.
# The loss buffer and the step counter t (also passed as batch_idx) restart at 0.
epochs = 10
ls, t = np.zeros(epochs * len(dl)), 0  # len(dl) already accounts for drop_last
for i in range(epochs):
    epoch_start = t
    for batch in dl:
        optimizer.zero_grad()
        # NOTE(review): only the ifstack=False branch is exercised here.
        batch = batch[0] if ds_train.ifstack else ((batch[0], batch[1]), batch[2])
        loss = rpm.training_step(*batch, batch_idx=t)
        loss.backward()
        optimizer.step()
        ls[t] = loss.item()
        t += 1
    # Report the mean over the epoch rather than the (noisy) loss of the last mini-batch.
    print('epoch #' + str(i+1) + '/' + str(epochs) + ', mean loss : ' + str(ls[epoch_start:t].mean()))
plt.plot(ls)
plt.xlabel('optimisation step')
plt.ylabel('training loss')
plt.show()


In [None]:
# Re-evaluate the posterior means after the additional training epochs.
with torch.no_grad():  # evaluation only: don't build an autograd graph over all N tuples
    eta_q = rpm.comp_eta_q(xjs)[0]
    mu_q = log_partition.nat2meanparam(eta_q)

mu0 = mu_q[:, 0].detach().numpy()
labels_np = train_labels.detach().numpy()
# Same 10 bins for every class so bar widths and positions are comparable across panels.
shared_bins = np.histogram_bin_edges(mu0, bins=10)
for k in range(K):
    plt.subplot(K, 1, k + 1)
    plt.hist(mu0[labels_np == k], bins=shared_bins)
    plt.xlim(mu0.min(), mu0.max())
plt.show()

In [None]:
import sklearn
from sklearn import metrics as skmetrics

xjs = [train_data[0][j] for j in range(J)]
idx_n = train_data[1].reshape(-1,1)

# FIXME(review): `drpm` is not defined anywhere in this notebook (the trained model is `rpm`),
# so this cell raises NameError under Restart & Run All. The argmax over `posts` below treats
# the global latent as discrete, whereas the model built above uses a Gaussian latent;
# presumably this cell was carried over from a discrete-latent variant - confirm the intended API.
log_pzj_xs, log_pzg_x, log_px = drpm.eval(xjs, idx_n)
posts = log_pzg_x.detach().numpy() # posteriors over global (!) latent
train_clusters = np.argmax(posts, axis=1)  # most probable latent state per training item

# sklearn convention: rows index y_true, columns index y_pred. The arguments were previously
# swapped, which transposed the plot.
M = skmetrics.confusion_matrix(y_true=train_labels.detach().numpy(), y_pred=train_clusters)
plt.imshow(M)
plt.xlabel('latent cluster')
plt.ylabel('true label')
plt.colorbar()
plt.show()

In [None]:
def mapij(a):
    """Relabel cluster ids 0..9 with a fixed permutation.

    The table was presumably read off the confusion matrix of one particular training run -
    TODO: it will not transfer to another run or seed. Entries outside 0..9 map to 0.
    Returns a new array with the dtype of ``np.asarray(a)``.
    """
    target_of = (2, 6, 1, 8, 4, 3, 5, 0, 9, 7)  # target_of[src] is the new label for src
    src = np.asarray(a)
    relabelled = np.zeros_like(src)
    for cluster, digit in enumerate(target_of):
        relabelled[src == cluster] = digit
    return relabelled

# Confusion matrix after relabelling clusters with mapij. Rows index the true label and columns
# the mapped cluster (sklearn convention; the arguments were previously swapped).
Mperm = skmetrics.confusion_matrix(y_true=train_labels.detach().numpy(),
                                   y_pred=mapij(np.argmax(posts, axis=1)))
plt.imshow(Mperm)
plt.xlabel('mapped cluster')
plt.ylabel('true label')
plt.colorbar()
plt.show()
# Fraction of training items whose mapped cluster equals their label (the trace is unaffected
# by the argument fix).
np.diag(Mperm).sum() / Mperm.sum()

In [None]:
# Separate name so the training-set `xjs` used by earlier cells is not clobbered.
xjs_test = [test_data[0][j] for j in range(J)]
idx_n = test_data[1].reshape(-1,1)
# FIXME(review): `drpm` is not defined anywhere in this notebook (the trained model is `rpm`),
# so this cell raises NameError under Restart & Run All - confirm the intended model/API.
log_pzj_xs, log_pzg_x, log_px = drpm.eval(xjs_test, idx_n)
posts = log_pzg_x.detach().numpy() # posteriors over global (!) latent

test_true = test_labels.detach().numpy()
test_clusters = np.argmax(posts, axis=1)  # most probable latent state per test item

# Rows index the true label and columns the cluster (sklearn convention; the arguments were
# previously swapped).
M = skmetrics.confusion_matrix(y_true=test_true, y_pred=test_clusters)
plt.imshow(M)
plt.xlabel('latent cluster')
plt.ylabel('true label')
plt.colorbar()
plt.show()

Mperm = skmetrics.confusion_matrix(y_true=test_true, y_pred=mapij(test_clusters))
plt.imshow(Mperm)
plt.xlabel('mapped cluster')
plt.ylabel('true label')
plt.colorbar()
plt.show()

# Test-set fraction of items whose mapped cluster equals their label.
np.diag(Mperm).sum() / Mperm.sum()