# RPM without the conditional independence assumption
- Dependency structure is overrated!

In [None]:
import torchvision
import torch

import matplotlib.pyplot as plt
import numpy as np

# Raw torchvision MNIST splits (train first, then test), read from `root`.
root = './data/MNIST'
ds0train, ds0test = (
    torchvision.datasets.MNIST(root=root, train=is_train) for is_train in (True, False)
)

class Peersupervision(torch.utils.data.Dataset):
    """Dataset of "peer" tuples: each sample grouped with same-class samples.

    Indexing with an integer ``idx`` returns a tensor of shape ``(1, J, ...)``;
    indexing with an integer array of length ``B`` returns ``(B, J, ...)``.
    With ``ifstack=False`` the J views are returned as a list instead of being
    stacked along dim 1.

    Slot 0 of every tuple is the queried sample itself. The other slots are
    drawn uniformly without replacement from the samples sharing its target.
    The queried sample itself may be drawn again for one of those slots.

    Args:
        data: tensor indexed along dim 0 by sample id.
        targets: tensor of per-sample class labels, aligned with ``data``.
        J: number of views per tuple.
        ifstack: stack the J views along dim 1 (True) or return them as a list.
    """

    def __init__(self, data, targets, J, ifstack=True):
        self.data = data
        self.targets = targets.detach().numpy()
        self.J = J
        self.ifstack = ifstack
        # Compute the sample ids of every class once. Running np.where over all
        # targets on each lookup made building N tuples O(N^2).
        self._class_ids = {c: np.where(self.targets == c)[0]
                           for c in np.unique(self.targets)}

    def __len__(self):
        return len(self.data)

    def _draw_peers(self, c):
        """Draw J distinct sample ids uniformly from class ``c``."""
        idc = self._class_ids[c]
        return idc[np.random.choice(len(idc), self.J, replace=False)]

    def __getitem__(self, idx):
        c = self.targets[idx]
        if c.ndim == 0:
            # A single index gives one tuple. Keep it 2-D, shape (1, J), so both
            # branches share the indexing code below.
            pair_ids = self._draw_peers(c).reshape(1, -1)
        else:
            # Use an integer dtype: these ids are used to index self.data.
            pair_ids = np.zeros((len(idx), self.J), dtype=np.int64)
            for i, c_ in enumerate(c):
                pair_ids[i] = self._draw_peers(c_)
        # Slot 0 is always the queried sample itself.
        pair_ids[:, 0] = idx

        out = [self.data[pair_ids[:, j]] for j in range(self.J)]
        return torch.stack(out, dim=1) if self.ifstack else out

# N training images, grouped into peer tuples of size J.
N, J = len(ds0train), 2

# Pixel values are divided by 256 (raw MNIST pixels are presumably uint8
# in 0..255 -- confirm against torchvision).
ds_train = Peersupervision(
    data=ds0train.data / 256.,
    targets=ds0train.targets,
    J=J,
    ifstack=True,
)
train_labels = ds0train.targets
# One random draw of peers per training image. This fixed set of tuples is
# what the later cells train and evaluate on.
train_data = ds_train[np.arange(N)]

ds_test = Peersupervision(
    data=ds0test.data / 256.,
    targets=ds0test.targets,
    J=J,
    ifstack=True,
)
test_labels = ds0test.targets
test_data = ds_test[np.arange(len(ds0test))]


In [None]:
# Preview the first few peer tuples: one row per tuple, one column per view.
n_show = min(N, 5)
for n in range(n_show):
    for i in range(J):
        plt.subplot(n_show, J, J * n + i + 1)
        if ds_train.ifstack:
            data_show = train_data[n, i].detach().numpy()
        else:
            data_show = train_data[i][n].detach().numpy()
        plt.imshow(data_show)
        plt.axis('off')
plt.suptitle(f'First {n_show} out of N= {N} peer tuples of size J ={J}')
plt.show()


In [None]:
from rpm import RPMEmpiricalMarginals, EmpiricalDistribution
from discreteRPM import discreteRPM, discretenonCondIndRPM, Prior_discrete, RecognitionFactor_discrete
from implicitRPM import ObservedMarginal, IndependentMarginal, GaussianCopula_ExponentialMarginals

from discreteRPM import discreteRPM, discretenonCondIndRPM, Prior_discrete, RecognitionFactor_scaled_discrete

# Number of latent states K = number of distinct training labels.
K = np.unique(train_labels.detach().numpy()).size
dim_T = K  # dimension of the sufficient statistics


class Net(torch.nn.Module):
    """Small LeNet-style CNN shared across independent factors.

    Expects 28x28 inputs with ``C_in`` channels. Each of two 5x5 convolutions
    is followed by 2x2 max-pooling and ReLU, shrinking the spatial size
    28 -> 24 -> 12 -> 8 -> 4. That leaves ``2*C_hidden`` feature maps of
    size 4x4. These are flattened and mapped through two linear layers (with
    ReLU between them) to ``n_out`` values, and then ``activation_out`` is
    applied.

    Args:
        C_in: number of input channels (e.g. the J stacked views).
        n_out: output dimension.
        C_hidden: channels of the first conv layer; the second has twice as many.
        n_hidden: width of the hidden linear layer.
        activation_out: module applied to the final output. ``None`` (the
            default) means identity.
    """

    def __init__(self, C_in, n_out, C_hidden, n_hidden, activation_out=None):
        super(Net, self).__init__()
        # Create a fresh Identity for each instance, rather than sharing one
        # module object through a default argument.
        self.activation_out = torch.nn.Identity() if activation_out is None else activation_out
        self.conv1 = torch.nn.Conv2d(C_in, C_hidden, kernel_size=5)
        self.conv2 = torch.nn.Conv2d(C_hidden, 2 * C_hidden, kernel_size=5)
        # NOTE(review): defined but never applied in forward(). Dropout2d has
        # no parameters, so it does not affect the state dict.
        self.conv2_drop = torch.nn.Dropout2d()
        # Size of the flattened conv features for a 28x28 input.
        self.n_flat = 4 * 4 * 2 * C_hidden
        self.fc1 = torch.nn.Linear(self.n_flat, n_hidden)
        self.fc2 = torch.nn.Linear(n_hidden, n_out)

    def forward(self, x):
        x = torch.nn.functional.relu(torch.nn.functional.max_pool2d(self.conv1(x), 2))
        x = torch.nn.functional.relu(torch.nn.functional.max_pool2d(self.conv2(x), 2))
        # Flatten to (-1, n_flat). n_flat follows C_hidden; a hard-coded
        # 4*4*20 would only fit C_hidden=10.
        x = x.reshape(-1, self.n_flat)
        x = torch.nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return self.activation_out(x)


# Recognition model: a single CNN sees all J views stacked as input channels
# and emits K+1 outputs, which are consumed by RecognitionFactor_scaled_discrete
# (presumably K logits plus one extra term -- confirm in discreteRPM).
# Unscaled alternative, kept for reference:
#natparam_model = Net(C_in=J, n_out=K, C_hidden=10, n_hidden=50, activation_out=torch.nn.Identity())
#rec_model = RecognitionFactor_discrete(model=natparam_model)
natparam_model = Net(
    C_in=J,
    n_out=K + 1,
    C_hidden=10,
    n_hidden=50,
    activation_out=torch.nn.Identity(),
)
rec_model = RecognitionFactor_scaled_discrete(model=natparam_model)

# Discrete prior over K latent states, with parameters initialised to zeros.
prior = Prior_discrete(param=torch.zeros(size=(K,)))

# Per-view samples: one entry per view j, collected across all N tuples.
if ds_train.ifstack:
    xjs = [train_data[:, j] for j in range(J)]
else:
    xjs = [train_data[j] for j in range(J)]
pxj = RPMEmpiricalMarginals(xjs)

# Construct the discrete RPM without the conditional-independence assumption.
full_F = False
drpm = discretenonCondIndRPM(
    rec_model,
    latent_prior=prior,
    pxjs=pxj,
    full_F=full_F,
)

# Inactive alternative (one recognition factor per view); the triple-quoted
# string below is never executed.
"""
natparam_models = [Net(C_in=1, n_out=K, C_hidden=10, n_hidden=50, activation_out=torch.nn.Identity()) for j in range(J)]
rec_models = [RecognitionFactor_discrete(model=m) for m in natparam_models] 

# constsruct implicit RPM
drpm = discreteRPM(rec_models, 
                      latent_prior=Prior_discrete(param=torch.randn(size=(K,))), 
                      pxjs=xjs)
"""

In [None]:
# Freeze the latent prior's parameter so that it receives no gradient during training.
prior.param_.requires_grad = False

In [None]:
prior.param_  # notebook display: inspect the prior's parameter tensor

In [None]:
import matplotlib.pyplot as plt
optimizer = torch.optim.Adam(drpm.parameters(), lr=1e-4)

epochs = 50
batch_size = 16

# Each peer tuple is one training example. With ifstack the J views are already
# stacked along dim 1; otherwise they are wrapped as J separate tensors.
if ds_train.ifstack:
    ds_load = torch.utils.data.TensorDataset(train_data)
else:
    ds_load = torch.utils.data.TensorDataset(*train_data)
dl = torch.utils.data.DataLoader(dataset=ds_load, batch_size=batch_size, shuffle=True, drop_last=True)

# One loss value per optimisation step. The buffer is sized from the loader,
# so it always matches the number of batches the loader actually yields.
ls, t = np.zeros(epochs * len(dl)), 0
for i in range(epochs):
    for batch in dl:
        optimizer.zero_grad()
        batch = batch[0] if ds_train.ifstack else [x.unsqueeze(1) for x in batch]
        loss = drpm.training_step(batch, batch_idx=t)
        loss.backward()
        optimizer.step()
        ls[t] = loss.item()
        t += 1
    print(f'epoch #{i + 1}/{epochs}, loss : {ls[t - 1]}')
plt.plot(ls)
plt.show()


In [None]:
epochs = 50
batch_size = 16

# Continue training the same model and optimizer. Rebuild the loader so that
# the batch_size set in this cell actually takes effect; the earlier loader was
# built with the batch_size from its own cell.
dl = torch.utils.data.DataLoader(dataset=ds_load, batch_size=batch_size, shuffle=True, drop_last=True)

# One loss value per optimisation step, sized from the loader.
ls, t = np.zeros(epochs * len(dl)), 0
for i in range(epochs):
    for batch in dl:
        optimizer.zero_grad()
        batch = batch[0] if ds_train.ifstack else [x.unsqueeze(1) for x in batch]
        loss = drpm.training_step(batch, batch_idx=t)
        loss.backward()
        optimizer.step()
        ls[t] = loss.item()
        t += 1
    print(f'epoch #{i + 1}/{epochs}, loss : {ls[t - 1]}')
plt.plot(ls)
plt.show()


In [None]:
# Estimate the normaliser F(Z) on a subsample. F(Z) is approximated by the mean
# of the recognition factor over all N_**J combinations of views taken from the
# first N_est_F training tuples. These are sorted by label, so the heatmaps
# below are grouped by class.
N_est_F = 200
idx_sort = np.argsort(train_labels[:N_est_F].detach().numpy())
all_xjs = [marginal.x[:N_est_F][idx_sort] for marginal in drpm.pxjs.pxjs]
N_ = all_xjs[0].shape[0]
assert all(N_ == xj.shape[0] for xj in all_xjs)
# Every combination of subsample indices, one column per view: (N_**J, J).
shuffle_ids = torch.cartesian_prod(*[torch.arange(N_, dtype=torch.int) for _ in range(J)])
xshuffled = torch.stack([all_xjs[j][shuffle_ids[:, j]] for j in range(J)], dim=1)  # (N_**J, J, ...)

# Evaluation only. Without no_grad, autograd would keep the intermediate
# activations for all N_**J inputs alive.
with torch.no_grad():
    log_fxs = drpm.rec_model.log_probs(xshuffled)  # presumably (N_**J, K)
    # log of the mean over combinations: logsumexp - log(N_**J)
    log_denom = torch.logsumexp(log_fxs, dim=0).reshape(1, -1) - J * np.log(N_)  # (1, K)
    log_prior = drpm.latent_prior.log_probs().reshape(1, -1)  # (1, K)
    # P(Z) / F(Z), computed in log space.
    pOverF = torch.exp(log_prior - log_denom).numpy()

plt.plot(torch.exp(log_denom[0]).detach().numpy(), label='F(Z) (est. from subsample)')
plt.plot(torch.exp(log_prior[0]).detach().numpy(), label='P(Z)')
plt.legend()
plt.show()

# Same reason as above: N full training tuples, evaluation only.
with torch.no_grad():
    posts = torch.exp(drpm.rec_model(train_data)).numpy()
labels_np = train_labels.detach().numpy()
for c in range(K):
    plt.subplot((K + 1) // 2, 2, c + 1)
    plt.plot((posts[labels_np == c] * pOverF).mean(axis=0))
    # c is the true (0-based) digit label, so show c itself.
    plt.ylabel('c=' + str(c))
plt.suptitle('avg. posteriors per class')
plt.show()

In [None]:
# Total recognition mass (summed over latent states) for each pair of
# subsample points. The row indexes view 1 and the column indexes view 2,
# following cartesian_prod order; this assumes J == 2.
fx_total = torch.exp(log_fxs).sum(dim=-1)
plt.imshow(fx_total.reshape(N_est_F, N_est_F).detach().numpy())
plt.colorbar()

In [None]:
j = 0  # NOTE(review): not used in this cell
plt.figure(figsize=(16, 16))
fx_all = torch.exp(log_fxs)
for k in range(K):
    plt.subplot(4, 3, k + 1)
    # Heatmap of exp(log_fxs[:, k]) over all subsample pairs (assumes J == 2).
    plt.imshow(fx_all[:, k].reshape(N_est_F, N_est_F).detach().numpy())
    plt.axis('off')
    plt.colorbar()
plt.show()

In [None]:
# For each of the first 50 tuples, show its J peer images next to its
# posterior over the K latent states, reweighted by P(Z)/F(Z).
# Output: 10 figures with 5 tuples each.
with torch.no_grad():  # evaluation only; don't build an autograd graph
    posts = torch.exp(drpm.rec_model(train_data)).numpy() * pOverF
n_rows = 5
for m in range(10):
    for n in range(n_rows):
        row = m * n_rows + n
        for i in range(J):
            plt.subplot(n_rows, J + 1, (J + 1) * n + i + 1)
            plt.imshow(train_data[row, i].detach().numpy())
            plt.axis('off')
        # The last column of each row (position J+1) holds the posterior plot.
        plt.subplot(n_rows, J + 1, (J + 1) * n + J + 1)
        plt.plot(np.arange(K) + 1, posts[row])
    plt.show()


In [None]:
from sklearn import metrics as skmetrics

with torch.no_grad():  # evaluation only; don't build an autograd graph
    posts = torch.exp(drpm.rec_model(train_data)).numpy() * pOverF

# Rows: true digit label. Columns: predicted latent state (argmax posterior).
M = skmetrics.confusion_matrix(y_true=train_labels, y_pred=np.argmax(posts, axis=1))
plt.imshow(M)
plt.colorbar()
plt.show()

# NOTE(review): hand-picked cluster->digit matching, written as M[digit, cluster]
# for the orientation above. The matching presumably changes with each
# trained model.
#(M[0,3] + M[1,1] + M[2,2] + M[3,5] + M[4,4] + M[5,6] + M[6,7] + M[7,9] + M[8,8] + M[9,9]) / M.sum()

In [None]:
def mapij(a):
    """Relabel cluster ids using a fixed, hand-chosen cluster->digit table.

    Returns ``np.zeros_like(a)`` with each position set to the digit matched
    to the cluster id at that position. Values outside 0..9 map to 0.
    """
    # (cluster id, digit) pairs, checked in this order for every element.
    cluster_to_digit = (
        (0, 5), (1, 2), (2, 7), (3, 8), (4, 9),
        (5, 4), (6, 6), (7, 0), (8, 1), (9, 3),
    )
    relabeled = np.zeros_like(a)
    for pos, value in enumerate(a):
        for cluster, digit in cluster_to_digit:
            if value == cluster:
                relabeled[pos] = digit
    return relabeled

# Relabel the predicted clusters with mapij, then tabulate them against the
# true digits. Rows: true digit. Columns: relabeled prediction.
Mperm = skmetrics.confusion_matrix(y_true=train_labels, y_pred=mapij(np.argmax(posts, axis=1)))
plt.imshow(Mperm)
plt.colorbar()
plt.show()
# Training accuracy under this relabeling (shown as the cell output).
np.diag(Mperm).sum() / Mperm.sum()

In [None]:
# Test-set posteriors, reweighted with the P(Z)/F(Z) ratio estimated on the
# training subsample.
with torch.no_grad():  # evaluation only; don't build an autograd graph
    posts = torch.exp(drpm.rec_model(test_data)).numpy() * pOverF

# Rows: true digit label. Columns: predicted latent state (argmax posterior).
M = skmetrics.confusion_matrix(y_true=test_labels, y_pred=np.argmax(posts, axis=1))
plt.imshow(M)
plt.colorbar()
plt.show()

# The same predictions after relabeling clusters with mapij.
Mperm = skmetrics.confusion_matrix(y_true=test_labels, y_pred=mapij(np.argmax(posts, axis=1)))
plt.imshow(Mperm)
plt.colorbar()
plt.show()

# Test accuracy under this relabeling (shown as the cell output).
np.diag(Mperm).sum() / Mperm.sum()