In [1]:
import torch
import math
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.bernoulli import Bernoulli
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from torchvision import datasets, transforms

In [4]:
class PCPPCA(nn.Module):
    """Probabilistic-PCA style generative model with a linear read-out.

    Each input batch is max-pooled, averaged over dim 1 and flattened to a
    vector of length ``D``. That vector is modelled as a Gaussian with mean
    ``mu`` and covariance ``F F^T + exp(sig) * I``. ``get_preds`` maps the
    posterior mean of the ``K``-dimensional latent through ``self.linear``.

    Note: ``forward`` only refreshes the generative state; if training uses
    the generative loss alone, ``self.linear`` receives no gradient and keeps
    its random initialisation.
    """

    # W is the width/height of the (square) input image
    # K is the dimension of the latent variables
    # Y is the number of classes
    # comp is the max-pooling kernel size used to shrink the image
    def __init__(self, W, K, Y, comp=1):
        super().__init__()
        # Number of pixels left after pooling with ceil_mode=True.
        self.D = math.ceil(W/comp) ** 2
        self.K = K
        self.pool = nn.MaxPool2d(comp,ceil_mode=True)
        self.F = nn.Parameter(torch.randn(self.D,self.K),requires_grad=True)
        # Log of the isotropic noise variance (exponentiated before use).
        self.sig = nn.Parameter(torch.tensor(1.), requires_grad = True)
        self.mu = nn.Parameter(torch.randn(self.D),requires_grad=True)
        # Read-out used by get_preds; created last so the random draws for
        # F and mu are the same as before this layer was restored.
        self.linear = nn.Linear(K, Y)

    def _update_generative(self, x):
        """Refresh ``compressed``, ``Phi``, ``Phi_inv`` and ``gen`` for batch ``x``."""
        M = x.shape[0]
        # assumes x is (batch, channels, height, width) -- TODO confirm
        self.compressed = self.pool(x).mean(dim=1).view(M, self.D)
        # Build the identity next to the parameters so the module still works
        # after .to(device) / .double().
        eye = torch.eye(self.D, dtype=self.F.dtype, device=self.F.device)
        # Clamping the log-variance keeps the noise variance >= exp(-5).
        log_var = self.sig.clamp(min=-5)
        self.Phi = eye * torch.exp(log_var)
        self.Phi_inv = eye * torch.exp(-log_var)
        self.gen = MultivariateNormal(
            loc=self.mu,
            covariance_matrix=self.F.mm(self.F.t()) + self.Phi,
        )

    def forward(self, x):
        """Update the generative state for ``x``.

        Callers read ``self.gen`` and ``self.compressed`` afterwards; the
        return value is an empty placeholder tensor.
        """
        self._update_generative(x)
        return torch.Tensor(0)

    def get_preds(self,x):
        """Return class scores of shape ``(x.shape[0], Y)`` for batch ``x``.

        The scores are ``self.linear`` applied to the posterior mean of the
        latent variables given the pooled image.
        """
        self._update_generative(x)
        eye_k = torch.eye(self.K, dtype=self.F.dtype, device=self.F.device)
        # Posterior covariance of the latent: (I + F^T Phi^-1 F)^-1
        sigz = (eye_k + self.F.t().mm(self.Phi_inv).mm(self.F)).inverse()
        # Broadcasting subtracts mu from every row.
        centered = self.compressed - self.mu
        # Posterior mean, one row per sample.
        muz = sigz.mm(self.F.t()).mm(self.Phi_inv).mm(centered.t()).t()
        return self.linear(muz)


In [5]:
# Root folder handed to ImageFolder; every image is converted to a tensor on load.
data_path = 'images2em'
dataset = datasets.ImageFolder(root=data_path, transform=transforms.ToTensor())

In [6]:
# Display the class names exposed by the ImageFolder dataset.
dataset.classes

['happiness', 'neutral']

In [7]:
# Mini-batch size for the shuffled training loader.
M = 100
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=M,
    shuffle=True,
)

In [8]:
# Model over W=350 inputs pooled by a factor of 5, with 30 latent dimensions
# and 2 classes.
model = PCPPCA(W=350, K=30, Y=2, comp=5)
# Plain SGD over every parameter of the model.
optimizer = optim.SGD(params=model.parameters(), lr=0.005)
# Classification loss; only referenced by the commented-out predictive term
# of the training loop.
l = torch.nn.CrossEntropyLoss()
# NOTE(review): looks like the weight of that predictive term -- confirm.
lmbda = 2000

In [9]:
# Run 100 SGD steps on the generative (negative log-likelihood) loss.
# One persistent iterator is used, so consecutive steps walk through an epoch
# instead of rebuilding the DataLoader iterator and taking only its first
# batch on every step.
batches = iter(dataloader)
for i in range(100):
    print(i)
    try:
        images,y = next(batches)
    except StopIteration:
        # Epoch finished: start a new (reshuffled) pass over the data.
        batches = iter(dataloader)
        images,y = next(batches)
    optimizer.zero_grad()
    # Call the module rather than .forward() so registered hooks run; the
    # call refreshes model.gen and model.compressed, which are read below.
    model(images)
    loss = -torch.mean(model.gen.log_prob(model.compressed)) #+ lmbda * l(preds,y)
    #print("Predictive loss: ", lmbda* l(preds,y))
    print("Generative loss: ", loss) #- l(preds,y))
    print()
    loss.backward()
    optimizer.step()

0
Generative loss:  tensor(8205.6855, grad_fn=<NegBackward>)

1
Generative loss:  tensor(448395.6875, grad_fn=<NegBackward>)

2
Generative loss:  tensor(33058.5781, grad_fn=<NegBackward>)

3
Generative loss:  tensor(7781.5020, grad_fn=<NegBackward>)

4
Generative loss:  tensor(4730.9546, grad_fn=<NegBackward>)

5
Generative loss:  tensor(4394.0142, grad_fn=<NegBackward>)



KeyboardInterrupt: 

In [15]:
# A loader whose single batch holds the entire dataset, then pull that batch.
loadalldata = torch.utils.data.DataLoader(dataset, batch_size=len(dataset))
allimages, ally = next(iter(loadalldata))

In [20]:
# Inference only: disable autograd so no graph is built for the whole dataset.
with torch.no_grad():
    preds = model.get_preds(allimages)

In [21]:
# Misclassification rate: fraction of samples whose arg-max class differs from
# the label. Equal to the squared-difference form for two classes, and still
# correct when there are more.
(torch.argmax(preds, dim=1) != ally).float().mean()

tensor(0.4415)