In [1]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
import pyro
from pyro.contrib.gp import Parameterized
import pyro.contrib.gp as gp
from pyro import distributions as dist

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset_path = os.path.join("/mnt/dl/datasets/alphabet_dataset/dataset/")

In [3]:
class AlphabetDataset(Dataset):
    letters = {chr(ord('A') + i): i for i in range(26)}
    
    def __init__(self, path, split, n=512, size=64):
        self.path = path
        self.n = n
        self.split = split
        self.size = size
        self.data = self.load_files()
        self.size = size
    
    def __getitem__(self, index):
        fname, letter = self.data[index]
        label = self.letters[letter]
        
        img = Image.open(fname).resize((self.size, self.size))
        img = np.array(img).astype(np.float32)
        img /= 255.0
        img = img[:, :, None]
        img = np.transpose(img, (2, 0, 1))
        return (img, label)        
    
    def load_files(self):
        data = []
        for c in self.letters:
            path = os.path.join(self.path, self.split, c)
            files = os.listdir(path)
            files = list( map(lambda x: os.path.join(path, x), sorted(files)))
            data.extend(list(zip(files[:self.n], [c] * self.n)))
        return data
    
    def __len__(self):
        return self.n * len(self.letters)
    
class TrainAlphabetDataset(AlphabetDataset):
    
    def __init__(self, path, split="train", n=512):
        super().__init__(path, split, n)

    
class TestAlphabetDataset(AlphabetDataset):
    
    def __init__(self, path, split="test", n=16):
        super().__init__(path, split, n)

In [4]:
batch_size = 512
latent_shape = 50
num_classes = 26
size = 64
train_ds = TrainAlphabetDataset(dataset_path)
test_ds = TestAlphabetDataset(dataset_path)

train_loader = DataLoader(train_ds, batch_size=batch_size, 
                          num_workers=4, shuffle=True)

test_loader = DataLoader(test_ds, batch_size=16*num_classes,  shuffle=True)


In [5]:
class CNN(nn.Module):
    
    def __init__(self, ):
        super().__init__()
        self.lin1 = nn.Linear(64*64, 1096)
        self.lin2 = nn.Linear(1096, latent_shape)
    
    def forward(self, x):
        # print("Cnn ", x.size())
        bz = x.size(0)
        x = x.view(bz, -1)
        x = self.lin1(x)
        x = F.softplus(x)
        x = self.lin2(x)
        x = F.leaky_relu(x)
        # print("Cnn output ", x.size())
        
        return x       

class Classification(Parameterized):
    
    def __init__(self, ):
        super().__init__()
        self.lin1 = nn.Linear(latent_shape, num_classes)
        
    def forward(self, f_loc, f_var, y=None):
        # print(f_loc.size(), f_var.size())
        f = dist.Normal(f_loc, f_var.sqrt()).to_event(1)()
        f = f.permute((1, 0))
        # print(f.size())
        f = self.lin1(f)
        obs = pyro.sample("y", dist.Categorical(logits=f).to_event(1), obs=y)
        return obs      

def test(model, loader):
    with torch.no_grad():
        for img, label in loader:
            img = img.cuda()
            label = label.cuda()
            x = model(img)
            yhat = model.likelihood(*x)
            acc = yhat == label
            acc = acc.sum() * 100 / yhat.size(0)
    return acc.item()
# Classification()(torch.randn((2, 50), requires_grad=True), torch.randn((2, 50),  requires_grad=True).abs()) 
# Classification()(torch.randn((2, 50)), torch.randn((2, 50)).abs()) 

In [6]:
X, y = next(iter(train_loader))

In [7]:
X.shape, y.shape, X.dtype, y.dtype

(torch.Size([512, 1, 64, 64]), torch.Size([512]), torch.float32, torch.int64)

In [8]:
cnn = CNN()

In [9]:
kernel = gp.kernels.Warping(gp.kernels.RBF(input_dim=latent_shape, 
                                        #    lengthscale=torch.ones(latent_shape),
                                        #    variance=torch.ones(latent_shape),
                                           ),
                            iwarping_fn=cnn
                            )

In [10]:
likelihood =  Classification()#gp.likelihoods.MultiClass(num_classes=num_classes)


In [11]:
# gp_model = gp.models.VariationalGP(X, y, kernel=kernel, likelihood=likelihood, 
#                                    latent_shape=torch.Size([latent_shape]),
#                                    jitter=1e-3)


gp_model = gp.models.VariationalSparseGP(
        X=X,
        y=None,
        kernel=kernel,
        Xu=X,
        likelihood=likelihood,
        latent_shape=torch.Size([latent_shape]),
        num_data=batch_size * num_classes,
        whiten=True,
        jitter=2e-3,
    )

In [12]:
for n, p in gp_model.named_parameters():
    print(n, p.size())

Xu torch.Size([512, 1, 64, 64])
u_loc torch.Size([50, 512])
u_scale_tril_unconstrained torch.Size([50, 512, 512])
kernel.kern.variance_unconstrained torch.Size([])
kernel.kern.lengthscale_unconstrained torch.Size([])
kernel.iwarping_fn.lin1.weight torch.Size([1096, 4096])
kernel.iwarping_fn.lin1.bias torch.Size([1096])
kernel.iwarping_fn.lin2.weight torch.Size([50, 1096])
kernel.iwarping_fn.lin2.bias torch.Size([50])
likelihood.lin1.weight torch.Size([26, 50])
likelihood.lin1.bias torch.Size([26])


In [13]:
optimizer = torch.optim.Adam(gp_model.parameters(), lr=1e-2)

In [14]:
gp_model.cuda()

VariationalSparseGP(
  (kernel): Warping(
    (kern): RBF()
    (iwarping_fn): CNN(
      (lin1): Linear(in_features=4096, out_features=1096, bias=True)
      (lin2): Linear(in_features=1096, out_features=50, bias=True)
    )
  )
  (likelihood): Classification(
    (lin1): Linear(in_features=50, out_features=26, bias=True)
  )
)

In [15]:
loss_fn = pyro.infer.TraceMeanField_ELBO().differentiable_loss

In [16]:
losses = []
eps = 15
val_losses = []
for i in range(eps):
    for j, (img, label) in enumerate(train_loader):
        img = img.cuda()
        label = label.cuda()
        gp_model.set_data(img, label)      
        loss = loss_fn(gp_model.model, gp_model.guide) #/ (batch_size * size * size)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.cpu().item())
        step = i * eps + j
        if step % 10 == 0:
            print(f"Episode {i}/{eps}, Step: {j}, Loss: {np.mean(losses)}")
    acc = test(gp_model, test_loader)
    val_losses.append(acc)
    print(f"Episode {i}/{eps}, Loss: {np.mean(losses)} Acc: {acc}")

Episode 0/15, Step: 0, Loss: 45174.1484375
Episode 0/15, Step: 10, Loss: 44027.91122159091
Episode 0/15, Step: 20, Loss: 40957.80896577381
Episode 0/15, Loss: 38337.902268629805 Acc: 27.163461685180664
Episode 1/15, Step: 5, Loss: 34928.965576171875
Episode 1/15, Step: 15, Loss: 30227.71507626488
Episode 1/15, Step: 25, Loss: 27026.764610877402
Episode 1/15, Loss: 27026.764610877402 Acc: 64.42308044433594
Episode 2/15, Step: 0, Loss: 26727.841612617925
Episode 2/15, Step: 10, Loss: 24390.228097098214
Episode 2/15, Step: 20, Loss: 22733.964455800513
Episode 2/15, Loss: 22063.681415264422 Acc: 72.59615325927734
Episode 3/15, Step: 5, Loss: 21236.463344029016
Episode 3/15, Step: 15, Loss: 20050.605821974736
Episode 3/15, Step: 25, Loss: 19074.922072190504
Episode 3/15, Loss: 19074.922072190504 Acc: 77.40384674072266
Episode 4/15, Step: 0, Loss: 18981.044800967262
Episode 4/15, Step: 10, Loss: 18090.767238451088
Episode 4/15, Step: 20, Loss: 17320.46883203125
Episode 4/15, Loss: 16981.0761