## Notes
- Thinking of making batch size large and just use the batch as the set of images to cluster
- If want to go for variable dataset sizes, can just have an outer loop that will change the batch size but train the same model

### Import Packages

In [1]:
import numpy as np
import time

import pyro
from pyro.distributions import *
from pyro.infer import Predictive, SVI, Trace_ELBO
from pyro.optim import Adam, AdamW
from pyro.nn import PyroModule

import torch
import torch.nn as nn
from torchvision import transforms, datasets

import matplotlib.pyplot as plt
%matplotlib inline

from tabulate import tabulate

### Load Data

In [2]:
# from https://medium.com/ai-society/gans-from-scratch-1-a-deep-introduction-with-code-in-pytorch-and-tensorflow-cb03cdcdba0f

def mnist_data():
    compose = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((.5), (.5))
        ])
    out_dir = './data/mnist'
    return datasets.MNIST(root=out_dir, train=True, transform=compose, download=True)

def images_to_vectors(images):
    return images.view(images.size(0), 784)

def vectors_to_images(vectors):
    return vectors.view(vectors.size(0), 1, 28, 28)

In [3]:
# Load data
data = mnist_data()
# Create loader with data, so that we can iterate over it
data_loader = torch.utils.data.DataLoader(data, batch_size=500, shuffle=True)
# Num batches
num_batches = len(data_loader)

### Neural Net

In [4]:
class PermutationLayer(PyroModule):
    '''
    PointNet inspired nn architecture for permutation invariance
    GitHub repo: https://github.com/fxia22/pointnet.pytorch/blob/f0c2430b0b1529e3f76fb5d6cd6ca14be763d975/pointnet/model.py#L11
    Paper: https://arxiv.org/pdf/1612.00593.pdf
    '''
    def __init__(self, num_pixels=784):
        super().__init__()
        
        self.num_pixels = num_pixels
        
        self.inside_network = nn.Sequential(
            nn.Conv2d(2*num_pixels, 2048, 1),
            nn.SELU(),
            nn.BatchNorm2d(2048),
            
            nn.Conv2d(2048, 2048, 1),
            nn.SELU(),
            nn.BatchNorm2d(2048),
            
            nn.Conv2d(2048, 2048, 1),
            nn.SELU(),
            nn.BatchNorm2d(2048),            
            
            nn.Conv2d(2048, 2048, 1),
            nn.SELU(),
            nn.BatchNorm2d(2048)
        )
            

    def forward(self, x):
        # do some reshaping to construct pairs
        num_objs = x.shape[1]
        
        rs = x.permute(0,2,1).unsqueeze(3)
        z1 = torch.tile(rs, (1,1,1,num_objs))
        z2 = z1.permute((0,1,3,2))

        # shape: [1, 1568, 500, 500] => [BATCH_SIZE, 2*NUM_PIXELS, NUM_IMGS, NUM_IMGS]
        # the Z tensor contains the interactions between every possible pair of images in the dataset
        Z = torch.cat([z1,z2], axis=1)
        
        Y = self.inside_network(Z)
        Y = torch.max(Y, axis=3)
        return Y

### Pyro

In [5]:
def model(data, step):
    # global variables
    alpha = torch.ones(T)
    weights = pyro.sample('weights', Dirichlet(alpha))
    
    with pyro.plate('components', T):
        locs = pyro.sample('locs', MultivariateNormal(torch.zeros(M), torch.eye(M)))

    # local variables
    with pyro.plate('data', N):
        assignment = pyro.sample('assignments', Categorical(weights))
        pyro.sample('obs', MultivariateNormal(locs[assignment], torch.eye(M)), obs=data)
        
def guide(data, step):
    # amortize using MLP
    
    if use_gpu: 
        data = data.cuda()
        
    perm = perm_mlp(data.float())
    print('nn output', perm.shape)

    # sample mixture components mu
    with pyro.plate('components', T):
        locs = pyro.sample('locs', MultivariateNormal(perm, torch.eye(M)))
    
    # sample cluster assignments
    alpha = pyro.param('alpha', torch.ones(T)) # returns a vector of length T
    weights = pyro.sample('weights', Dirichlet(alpha))  # vector of length T
    with pyro.plate('data', N):
        assignments = pyro.sample('assignments', Categorical(weights))
    
    # logging
    if step % 1000 == 0:
        
        print('='*10, 'Iteration {}'.format(step), '='*10)
        weight_data = [weights[i] for i in range(len(weights))]
        weight_data.insert(0, 'props')

        mu1_data = [locs[i,0] for i in range(locs.shape[0])]
        mu1_data.insert(0, 'mu1')

        mu2_data = [locs[i,1] for i in range(locs.shape[0])]
        mu2_data.insert(0, 'mu2')
        
        data = [weight_data, mu1_data, mu2_data]
        
        print(tabulate(data, headers=['', 'clust1', 'clust2', 'clust3', 'clust4', 'clust5']))

### Initialization

In [6]:
perm_mlp = PermutationLayer().float()
T = 10
M = 784

adam_params = {"lr": 0.00001}
optimizer = Adam(adam_params)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

use_gpu = troch.cude.is_available()
if use_gpu:
    perm_mlp = perm_mlp.cuda()

### Training

In [7]:
epochs = 20

for epoch in range(epochs):
    for n_batch, (real_batch,_) in enumerate(data_loader):
        data = images_to_vectors(real_batch).unsqueeze(0)

        n_steps = 50000
        start = time.time()
        for step in range(n_steps):
            svi.step(data, step)
            pyro.get_param_store()
            if step % 100 == 0:
                end = time.time()
                print('took', end-start, 'seconds')
                start = time.time()

                if not dry:
                    torch.save({'model_state_dict': alpha_mlp.state_dict(),
                               }, 'saved_models/alpha_mlp_{}.pth'.format(step))

                    torch.save({'model_state_dict': tau_mlp.state_dict(),
                               }, 'saved_models/tau_mlp_{}.pth'.format(step))
        break
    break

Z.shape torch.Size([1, 1568, 500, 500])


AttributeError: 'torch.return_types.max' object has no attribute 'shape'