In [None]:
import torchvision
import torch

import matplotlib.pyplot as plt
import numpy as np

import torchvision.transforms.functional as TF

# Reproducibility: seed both RNGs used in this notebook (torch for model
# initialisation / shuffling, numpy for the peer sampling).
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

# MNIST train / test splits, read from `root` (no download is requested here).
root = './data/MNIST'
ds0train, ds0test = (
    torchvision.datasets.MNIST(root=root, train=is_train)
    for is_train in (True, False)
)
dtype = torch.float32


class Peersupervision(torch.utils.data.Dataset):
    """Dataset that pairs every sample with randomly drawn peers of the same class.

    Indexing returns ``(views, idx)`` where ``views`` holds ``J`` entries: view 0 is
    the requested sample(s) themselves, views 1..J-1 are samples drawn (via the global
    ``np.random`` state) from the same class as given by ``targets``.

    Note that the requested index is not excluded from the draw, so a peer can
    coincide with the sample itself.

    Args:
        data: indexable container of samples (first axis = sample index).
        targets: tensor of class labels; converted with ``.detach().numpy()``.
        J: number of views per tuple.
        ifstack: if True, the views are combined with ``torch.stack`` along dim 1;
            otherwise a plain list of ``J`` entries is returned.
    """

    def __init__(self, data, targets, J, ifstack=True):
        self.data = data
        self.targets = targets.detach().numpy()
        self.J = J
        self.ifstack = ifstack

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(views, idx)`` for a scalar index or an array of indices."""
        c = self.targets[idx]
        if c.ndim == 0:
            # Single index: one row of peer ids. J+1 ids are drawn although only the
            # first J columns are read below (kept to leave the RNG stream unchanged).
            idc = np.where(self.targets == c)[0]
            pair_ids = idc[np.random.choice(len(idc), self.J+1, replace=False).reshape(1,-1)]
            pair_ids[0,0] = idx
        else:
            # Integer dtype: these values are used as indices into `data`. A float
            # table only works where the container silently casts index arrays and
            # raises for numpy-backed data.
            pair_ids = np.zeros((len(idx), self.J), dtype=np.int64)
            # Members of each class are looked up once per call instead of once per
            # sample; the order and arguments of the random draws are unchanged.
            members_by_class = {}
            for i, c_ in enumerate(c):
                key = c_.item()
                idc = members_by_class.get(key)
                if idc is None:
                    idc = np.where(self.targets == c_)[0]
                    members_by_class[key] = idc
                pair_ids[i] = idc[np.random.choice(len(idc), self.J, replace=False)]
            # View 0 is always the requested sample itself.
            pair_ids[:,0] = idx

        out = [self.data[pair_ids[:,j]] for j in range(self.J)]
        return (torch.stack(out, dim=1), idx) if self.ifstack else (out, idx)

# N: number of training samples, J: number of views per peer tuple.
N,J = len(ds0train), 2

# Peers are drawn once here (a single indexing call over all samples) and then kept
# fixed. Pixel values are scaled by 1/256. With ifstack=False the dataset returns a
# list with one tensor per view; a singleton axis is inserted at position 1 of every
# view (for all J views, not only the first two).
ds_train = Peersupervision(data=ds0train.data/256., targets=ds0train.targets, J=J, ifstack=False)
train_data = ds_train[np.arange(N)]
train_data = [[view.unsqueeze(1) for view in train_data[0]], train_data[1]]
train_labels = ds0train.targets

ds_test = Peersupervision(data=ds0test.data/256., targets=ds0test.targets, J=J, ifstack=False)
test_data = ds_test[np.arange(len(ds0test))]
test_data = [[view.unsqueeze(1) for view in test_data[0]], test_data[1]]
test_labels = ds0test.targets


In [None]:
# Number of classes and, per class, the fraction of its samples that is kept.
K = 10

#fracs_k = [5/600, 5/600, 5/600, 5/600, 5/600, 5/600, 10/600, 10/600, 25/600, 25/600]
fracs_k = [10/600] * 10

# For every class keep the leading round(len * frac) samples (in dataset order).
idx_ks = []
for k in range(K):
    idx_k = np.where(train_labels == k)[0]
    n_keep = np.int32(np.round(len(idx_k) * fracs_k[k]))
    idx_k = idx_k[:n_keep]
    idx_ks.append(idx_k)
idx_re = np.concatenate(idx_ks)

N = len(idx_re)

# NOTE: train_labels / train_data are overwritten from themselves, so re-running this
# cell subsamples the already subsampled data; re-run the data cell first.
train_labels = train_labels[idx_re]
train_data = [[train_data[0][0][idx_re], train_data[0][1][idx_re]], np.arange(N)]

N

In [None]:
# Preview: one row per peer tuple (at most five), one column per view.
n_rows = np.minimum(N, 5)
plt.figure(figsize=(4,6))
for n in range(n_rows):
    for j in range(J):
        plt.subplot(n_rows, J, J*n + j + 1)
        # Layout of train_data[0] depends on whether the dataset stacked the views.
        if ds_train.ifstack:
            data_show = train_data[0][n,j].detach().numpy()
        else:
            data_show = train_data[0][j][n].detach().numpy()
        plt.imshow(data_show[0], cmap='gray')
        plt.axis('off')
plt.suptitle('First ' + str(n_rows) + ' out of N= ' +str(N) + ' peer tuples of size J =' +str(J))
plt.show()

In [None]:
from rpm import RPMEmpiricalMarginals, EmpiricalDistribution
from discreteRPM import discreteRPM, Prior_discrete, discreteRPM_softmaxForm
from implicitRPM import ObservedMarginal, IndependentMarginal
from discreteRPM import RecognitionFactor_discrete, RecognitionFunction_discrete, RecognitionFunction_discrete_norm

# Number of distinct labels present in the (subsampled) training labels.
K = np.unique(train_labels.detach().numpy()).size
dim_T = K # dimension of sufficient statistics

class Net(torch.nn.Module):
    """Small convolutional network: two conv + max-pool stages and two linear layers.

    Args:
        C_in: number of input channels.
        n_out: size of the output vector.
        C_hidden: channels of the first convolution (the second uses ``2*C_hidden``).
        n_hidden: width of the hidden fully connected layer.
        activation_out: module applied to the final linear output.

    ``fc1`` expects ``4*4*2*C_hidden`` features, i.e. a 4x4 feature map after the
    second pooling stage (which is what 28x28 inputs produce with these kernels).
    """
    # Convolutional Neural Network shared across independent factors
    def __init__(self, C_in, n_out, C_hidden, n_hidden, activation_out=torch.nn.Identity()):
        super().__init__()
        self.activation_out = activation_out
        self.conv1 = torch.nn.Conv2d(C_in, C_hidden, kernel_size=5)
        self.conv2 = torch.nn.Conv2d(C_hidden, 2*C_hidden, kernel_size=5)
        # Kept for attribute compatibility; it is not applied in forward().
        self.conv2_drop = torch.nn.Dropout2d()
        self.fc1 = torch.nn.Linear(4*4*2*C_hidden, n_hidden)
        self.fc2 = torch.nn.Linear(n_hidden, n_out)

    def forward(self, x):
        """Map a batch of images to ``n_out`` outputs per sample."""
        x = torch.nn.functional.relu(torch.nn.functional.max_pool2d(self.conv1(x), 2))
        x = torch.nn.functional.relu(torch.nn.functional.max_pool2d(self.conv2(x), 2))
        # Flatten to whatever fc1 expects instead of a literal 4*4*20, which is only
        # correct for C_hidden == 10.
        x = x.view(-1, self.fc1.in_features)
        x = torch.nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return self.activation_out(x)

# One independent network per view j; each maps an image to K+1 outputs.
# NOTE: the construction order below consumes the torch RNG, so reordering these
# statements changes the initial weights.
natparam_models = [Net(C_in=1, 
                       n_out=K+1, # K+1 for the extra term hj(xj) !
                       C_hidden=10, 
                       n_hidden=50, 
                       activation_out=torch.nn.Identity()) for j in range(J)]
# Alternative recognition wrapper, kept for reference:
#rec_models = [RecognitionFunction_discrete(model=natparam_models[j]) for j in range(J)]
rec_models = [RecognitionFunction_discrete_norm(model=natparam_models[j]) for j in range(J)]


# Prior over the K latent states, parameterised by an all-zero vector of length K.
prior =  Prior_discrete(param=torch.zeros(size=(K,)))

# Per-view training tensors and their empirical distributions.
xjs = [train_data[0][j] for j in range(J)]
pxjs = [EmpiricalDistribution(xjs[j]) for j in range(J)]
px_allj = RPMEmpiricalMarginals(xjs)

# construct the RPM from the recognition models, the prior and the empirical marginals
drpm = discreteRPM_softmaxForm(rec_models, latent_prior=prior, pxjs=pxjs)


In [None]:
import sklearn
from sklearn import metrics as skmetrics

# NOTE: xjs is reassigned here but not read in this cell.
xjs = [train_data[0][j].unsqueeze(1) for j in range(J)]
idx_n = torch.tensor(train_data[1]).unsqueeze(1)

log_w, posterior = drpm.eval(idx_n)
posts = posterior.detach().numpy() # posteriors over global (!) latent

# Rows of a confusion matrix are y_true and columns are y_pred, so the class labels
# go in as y_true and the most probable latent state as y_pred.
M = skmetrics.confusion_matrix(y_true=train_labels, y_pred=np.argmax(posts,axis=1))
plt.imshow(M)
plt.xlabel('argmax of posterior (latent state)')
plt.ylabel('class label')
plt.title('labels vs. latent assignments (before training)')
plt.colorbar()
plt.show()

# Sum over latent states of exp(affine_all_z(x)) for the first view, on at most the
# first 10000 samples, plotted in ascending order.
xjis = [drpm.pxjs[j].x[:10000] for j in range(drpm.J)]                                # N - D
gji = [m.affine_all_z(xj) for m,xj in zip(drpm.rec_models, xjis)]             # N - K  x J
norms_init = torch.exp(gji[0]).sum(axis=1).detach().numpy()
plt.plot(np.sort(norms_init))
plt.xlabel('sample (sorted)')
plt.ylabel('sum_k exp(g_0(x)[k])')
plt.title('norms at initialisation')
plt.show()


In [None]:
import matplotlib.pyplot as plt
optimizer = torch.optim.Adam(drpm.parameters(), lr=1e-3)

epochs = 100
batch_size = 16

# The loader yields (view_0, ..., view_{J-1}, sample index) per batch; only the sample
# index is passed on to the model.
#ds_load = torch.utils.data.TensorDataset(*test_data[0], torch.tensor(test_data[1]))
ds_load = torch.utils.data.TensorDataset(*train_data[0], torch.tensor(train_data[1]))
dl = torch.utils.data.DataLoader(dataset=ds_load, batch_size=batch_size, shuffle=True, drop_last=True)

# drop_last=True gives exactly N//batch_size steps per epoch, matching the length of
# the loss trace `ls`; `t` counts optimisation steps.
ls,t = np.zeros(epochs*(N//batch_size)),0
for i in range(epochs):
    for batch in dl:
        optimizer.zero_grad()
        # The index tensor is the last element of the batch for any number of views
        # (batch[2] is only correct for J == 2).
        idx_n = batch[-1]
        loss = drpm.training_step(idx_n, batch_idx=t)
        loss.backward()
        optimizer.step()
        ls[t] = loss.item()
        t+=1
    # Reports the loss of the last step of the epoch, not an epoch average.
    print('epoch #' + str(i+1) + '/' + str(epochs) + ', loss : ' + str(ls[t-1]))
plt.plot(ls)
plt.show()


In [None]:
import sklearn
from sklearn import metrics as skmetrics

# NOTE: xjs is reassigned here but not read in this cell.
xjs = [train_data[0][j].unsqueeze(1) for j in range(J)]
idx_n = torch.tensor(train_data[1]).unsqueeze(1)

log_w, posterior = drpm.eval(idx_n)
posts = posterior.detach().numpy() # posteriors over global (!) latent

# Rows of a confusion matrix are y_true and columns are y_pred, so the class labels
# go in as y_true and the most probable latent state as y_pred.
M = skmetrics.confusion_matrix(y_true=train_labels, y_pred=np.argmax(posts,axis=1))
plt.imshow(M)
plt.xlabel('argmax of posterior (latent state)')
plt.ylabel('class label')
plt.title('labels vs. latent assignments (after training)')
plt.colorbar()
plt.show()

# Sum over latent states of exp(affine_all_z(x)) for the first view, on at most the
# first 10000 samples, plotted in ascending order.
xjis = [drpm.pxjs[j].x[:10000] for j in range(drpm.J)]                                # N - D
gji = [m.affine_all_z(xj) for m,xj in zip(drpm.rec_models, xjis)]             # N - K  x J
norms_early = torch.exp(gji[0]).sum(axis=1).detach().numpy()
plt.plot(np.sort(norms_early))
plt.xlabel('sample (sorted)')
plt.ylabel('sum_k exp(g_0(x)[k])')
plt.title('norms after training')
plt.show()


In [None]:
# Group sample indices by their most probable latent state so that samples assigned
# to the same state are contiguous after re-ordering.
# NOTE: this reuses (and overwrites) the name idx_re from the subsampling cell.
map_state = np.argmax(posts,axis=1)
idx_js = []
for k in range(K):
    idx_js.append(np.where(map_state == k)[0])
idx_re = np.concatenate(idx_js)
train_images_re = [train_data[0][0][idx_re], train_data[0][1][idx_re]]

idx_n = idx_re #np.arange(N)

# xjis comes from the evaluation cell (at most the first 10000 samples per view);
# assumes every entry of idx_re is a valid row of it -- TODO confirm for large N.
gji = [m.affine_all_z(xj[idx_n]) for m,xj in zip(drpm.rec_models, xjis)]       # N - K  x J
log_Zj = [torch.logsumexp(gji[j],axis=0).unsqueeze(0) for j in range(drpm.J)]  # 1 - K  x J  
log_aji = [gji[j] - log_Zj[j] for j in range(drpm.J) ]                         # b - K  x J
# Normalise over the last axis in log space; logsumexp avoids the overflow/underflow
# of log(sum(exp(.))) and matches the normaliser computed two lines above.
log_aji = [log_aji[j] - torch.logsumexp(log_aji[j], dim=-1).unsqueeze(-1) for j in range(drpm.J)]

# Pairwise combination of the two views (hard-coded for J == 2), summed over the
# latent axis.
log_joint = drpm.latent_prior.log_probs() + (log_aji[0].unsqueeze(0)+log_aji[1].unsqueeze(1))
marginal = torch.exp(log_joint).sum(axis=-1)


plt.figure(figsize=(12,6))
for j in range(J):
    plt.subplot(1,2,j+1)
    if j == 0:
        plt.ylabel('kj')
    # softmax over axis 0 == exp(log_aji) / exp(log_aji).sum(axis=0)
    plt.imshow(torch.softmax(log_aji[j], dim=0).detach().numpy(), 
                aspect='auto', interpolation='none')
    plt.xlabel('Z')
    plt.colorbar()
plt.suptitle('p(kj | Z) for j = 1, 2')
plt.show()

plt.imshow(marginal.detach().numpy())
plt.title('marginal P(K) on standard grid after permuting indices')
plt.show()
#log_normalizer = torch.exp(log_joint).sum(axis=-1)

plt.figure(figsize=(12,6))
plt.subplot(1,2,1)
plt.plot(np.diag(marginal.detach().numpy()), label='p(K=[n,n,..])')
# Prior probabilities are drawn at the centres of K equal-width segments of [0, N].
plt.plot(np.arange(K)*N/K+0.5*N/K, torch.exp(drpm.latent_prior.log_probs()).detach().numpy(), 'o-', label='prior p(Z=k)')
plt.legend()
plt.title('diagonal profile of P(K)')
plt.subplot(1,2,2)
plt.plot(torch.exp(drpm.latent_prior.log_probs()).detach().numpy(), np.array([len(idx) for idx in idx_js]), 'x')
plt.title('prior p(Z) vs cluster size')
plt.xlabel('p(Z)')
plt.ylabel('#n which argmax_kj p(kj|Z) = n')
plt.show()



# skewed label distributions 
- 25% of data are 8's
- 25% of data are 9's
- remaining 8 classes each constitute 5%

In [None]:
# skewed label distributions - 25% of data are 8's, 25% of data are 9's, remaining 8 classes each constitute 5%

In [None]:
# Sorted norm curves from the different training stages, with a dotted reference
# line at 1.
plt.plot(np.sort(norms_init),  label='init')
plt.plot(np.sort(norms_early), label='early')
# norms_late / norms_final are not assigned by any cell visible in this notebook;
# plot them only if they exist so the cell still runs after Restart & Run All.
for stage in ('late', 'final'):
    stage_norms = globals().get('norms_' + stage)
    if stage_norms is not None:
        plt.plot(np.sort(stage_norms),  label=stage)
plt.plot([0, len(norms_init)], [1, 1], 'k:')
plt.legend()
plt.show()