In [8]:
import torch
import math
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.bernoulli import Bernoulli
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from torchvision import datasets, transforms

In [9]:
class PCPPCA(nn.Module):
    # D is the dimension of the image (i.e. width/height)
    # K is the dimension of the laten variables
    # Y is the number of classes
    def __init__(self, W, K, Y, comp=1):
        super().__init__()
        self.D = math.ceil(W/comp) ** 2
        self.K = K
        self.pool = nn.MaxPool2d(comp,ceil_mode=True)
        self.F = nn.Parameter(torch.randn(self.D,self.K),requires_grad=True)
        self.sig = nn.Parameter(torch.tensor(1.), requires_grad = True)
        self.mu = nn.Parameter(torch.randn(self.D),requires_grad=True)
        self.linear = nn.Linear(K, Y)

    def forward(self, x):
        M = x.shape[0]
        self.compressed = self.pool(x).mean(axis=1).view(M,self.D)
        self.Phi = torch.eye(self.D) * torch.exp(self.sig.clamp(min=-5))
        self.Phi_inv = torch.eye(self.D) * torch.exp(-self.sig.clamp(min=-5))
        self.gen = MultivariateNormal(loc=self.mu, covariance_matrix=self.F.mm(self.F.t()) + self.Phi)
        sigz = (torch.eye(self.K) + self.F.t().mm(self.Phi_inv).mm(self.F)).inverse()
        muz = sigz.mm(self.F.t()).mm(self.Phi_inv).mm((self.compressed - self.mu.unsqueeze(0).repeat(self.compressed.shape[0],1)).t()).t()
        return self.linear(muz)

In [10]:
data_path = 'images2em'
dataset = datasets.ImageFolder(
    root=data_path,
    transform= transforms.ToTensor()
)

In [11]:
dataset.classes

['happiness', 'neutral']

In [12]:
M=100
dataloader = torch.utils.data.DataLoader(dataset, batch_size=M, shuffle=True)

In [18]:
model = PCPPCA(350,30,2,comp=5)
optimizer = optim.SGD(model.parameters(), lr=0.005)
l = torch.nn.CrossEntropyLoss()
lmbda = 2000

In [19]:
for i in range(100):
    print(i)
    images,y = next(iter(dataloader))
    optimizer.zero_grad()
    preds = model.forward(images)
    loss = -torch.mean(model.gen.log_prob(model.compressed)) + lmbda * l(preds,y)
    print("Predictive loss: ", lmbda* l(preds,y))
    print("Generative loss: ", loss - l(preds,y))
    print()
    loss.backward()
    optimizer.step()

0
Predictive loss:  tensor(1382.6555, grad_fn=<MulBackward0>)
Generative loss:  tensor(9615.6250, grad_fn=<SubBackward0>)

1
Predictive loss:  tensor(1401.6411, grad_fn=<MulBackward0>)
Generative loss:  tensor(463193.0625, grad_fn=<SubBackward0>)

2
Predictive loss:  tensor(1659.2236, grad_fn=<MulBackward0>)
Generative loss:  tensor(36672.2344, grad_fn=<SubBackward0>)

3
Predictive loss:  tensor(3338.5605, grad_fn=<MulBackward0>)
Generative loss:  tensor(9733.4473, grad_fn=<SubBackward0>)

4
Predictive loss:  tensor(4487.3872, grad_fn=<MulBackward0>)
Generative loss:  tensor(9409.4805, grad_fn=<SubBackward0>)

5
Predictive loss:  tensor(6492.6021, grad_fn=<MulBackward0>)
Generative loss:  tensor(10667.0879, grad_fn=<SubBackward0>)

6
Predictive loss:  tensor(2337.8381, grad_fn=<MulBackward0>)
Generative loss:  tensor(8641.6025, grad_fn=<SubBackward0>)

7
Predictive loss:  tensor(5359.5444, grad_fn=<MulBackward0>)
Generative loss:  tensor(9839.1504, grad_fn=<SubBackward0>)

8
Predictive

66
Predictive loss:  tensor(7587.9062, grad_fn=<MulBackward0>)
Generative loss:  tensor(12909.9971, grad_fn=<SubBackward0>)

67
Predictive loss:  tensor(2548.7971, grad_fn=<MulBackward0>)
Generative loss:  tensor(7099.8652, grad_fn=<SubBackward0>)

68
Predictive loss:  tensor(5967.1948, grad_fn=<MulBackward0>)
Generative loss:  tensor(10209.6826, grad_fn=<SubBackward0>)

69
Predictive loss:  tensor(2536.0688, grad_fn=<MulBackward0>)
Generative loss:  tensor(7298.5903, grad_fn=<SubBackward0>)

70
Predictive loss:  tensor(5812.3340, grad_fn=<MulBackward0>)
Generative loss:  tensor(10340.4316, grad_fn=<SubBackward0>)

71
Predictive loss:  tensor(2211.4424, grad_fn=<MulBackward0>)
Generative loss:  tensor(7555.0132, grad_fn=<SubBackward0>)

72
Predictive loss:  tensor(7060.8579, grad_fn=<MulBackward0>)
Generative loss:  tensor(12340.2393, grad_fn=<SubBackward0>)

73
Predictive loss:  tensor(3652.2627, grad_fn=<MulBackward0>)
Generative loss:  tensor(8592.1113, grad_fn=<SubBackward0>)

74
P

In [15]:
loadalldata = torch.utils.data.DataLoader(dataset, batch_size=dataset.__len__())
allimages,ally = next(iter(loadalldata))

In [20]:
preds = model.forward(allimages)

In [21]:
torch.mean((torch.argmax(preds,axis=1) - ally) ** 2 * 1.)

tensor(0.4415)