## Notes
- Idea: use a large batch size and treat each batch as the set of images to cluster.
- To support variable dataset sizes, add an outer loop that changes the batch size while continuing to train the same model.

### Import Packages

In [1]:
import os
import time

import numpy as np

import pyro
from pyro.distributions import *
from pyro.infer import Predictive, SVI, Trace_ELBO
from pyro.optim import Adam, AdamW
from pyro.nn import PyroModule

import torch
import torch.nn as nn
from torchvision import transforms, datasets

import matplotlib.pyplot as plt
%matplotlib inline

from tabulate import tabulate

from permnet import PermNet

### Load Data

In [2]:
# from https://medium.com/ai-society/gans-from-scratch-1-a-deep-introduction-with-code-in-pytorch-and-tensorflow-cb03cdcdba0f

def mnist_data(out_dir='./data/mnist', train=True):
    """Load MNIST as a torchvision dataset, downloading it if necessary.

    Images are converted with ToTensor and then normalised with mean 0.5 and
    std 0.5 (i.e. x -> (x - 0.5) / 0.5).

    Args:
        out_dir: directory where the MNIST files are stored / downloaded to.
        train: if True load the training split, otherwise the test split.

    Returns:
        A ``torchvision.datasets.MNIST`` instance yielding (image, label) pairs.
    """
    compose = transforms.Compose([
        transforms.ToTensor(),
        # Normalize documents per-channel sequences; MNIST has one channel,
        # so use one-element tuples. `(.5)` is just a float, not a tuple.
        transforms.Normalize((0.5,), (0.5,)),
    ])
    return datasets.MNIST(root=out_dir, train=train, transform=compose, download=True)

def images_to_vectors(images):
    """Flatten a batch of 28x28 images into rows of 784 pixel values.

    Returns a view of shape (batch, 784); ``images`` must therefore be
    view-compatible and contain exactly 784 elements per batch entry.
    """
    batch_size = images.size(0)
    return images.view(batch_size, 28 * 28)

def vectors_to_images(vectors):
    """Reshape rows of 784 pixels back into single-channel 28x28 images.

    Returns a view of shape (batch, 1, 28, 28).
    """
    batch_size = vectors.size(0)
    channels, height, width = 1, 28, 28
    return vectors.view(batch_size, channels, height, width)

In [3]:
# Load MNIST and wrap it in a shuffling loader; each batch of N images is
# the set of points the mixture model clusters.
N = 100  # images per batch; also the size of the 'data' plate in model/guide
data = mnist_data()
data_loader = torch.utils.data.DataLoader(
    dataset=data,
    batch_size=N,
    shuffle=True,
)
num_batches = len(data_loader)  # number of batches per epoch

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


### Pyro

In [4]:
def model(data, step):
    """Generative model: Gaussian mixture over PCA-projected images.

    Args:
        data: batch of flattened images; the training loop passes
            ``images_to_vectors(batch).unsqueeze(0)``, i.e. shape (1, N, 784).
        step: SVI step index. Unused here, but SVI passes the same arguments
            to model and guide, and the guide uses it.

    Globals read: T (number of components), M (number of PCA features),
    N (size of the 'data' plate; must equal the number of images in ``data``).
    """
    # global variables: mixture weights and component means
    alpha = torch.ones(T)
    weights = pyro.sample('weights', Dirichlet(alpha))

    with pyro.plate('components', T):
        locs = pyro.sample('locs', MultivariateNormal(torch.zeros(M), torch.eye(M)))

    # Use PCA to reduce dimensionality so training is faster.
    # reshape(-1, D) rather than squeeze(): squeeze would also drop the image
    # axis when a batch holds a single image.
    img_set = data.reshape(-1, data.shape[-1])
    # Request exactly M components; q=None defaults to min(6, m, n) and only
    # matched M by coincidence.
    # NOTE(review): pca_lowrank is randomized and recomputed every step, so
    # the basis (e.g. component signs) may change between steps; consider
    # computing it once outside the model.
    (U, S, V) = torch.pca_lowrank(img_set, q=M, center=True, niter=10)
    # Project the 2-D image matrix so z has shape (num_images, M), matching
    # the 'data' plate and the M-dimensional likelihood below.
    z = torch.matmul(img_set, V[:, :M])

    # local variables: per-image cluster assignment and observation
    with pyro.plate('data', N):
        assignment = pyro.sample('assignments', Categorical(weights))
        pyro.sample('obs', MultivariateNormal(locs[assignment], torch.eye(M)), obs=z)
        
def guide(data, step):
    """Amortized variational guide for ``model``.

    ``tau_mlp`` maps the batch to the variational means of the T component
    locations (reshaped to T x M), and ``alpha_mlp`` maps it to the Dirichlet
    concentration over the mixture weights.

    Args:
        data: batch of flattened images, shape (1, N, 784) from the training loop.
        step: SVI step index; a table of the current samples is printed on
            every even step.

    Globals read: T, M, N, use_gpu, tau_mlp, alpha_mlp.
    """
    # Register both networks with Pyro so SVI's optimizer sees (and updates)
    # their parameters; otherwise the param store is empty and nothing trains.
    pyro.module('tau_mlp', tau_mlp)
    pyro.module('alpha_mlp', alpha_mlp)

    if use_gpu:
        data = data.cuda()

    tau = tau_mlp(data.float())
    tau = tau.view(T, M)

    # sample mixture components mu; build the covariance on tau's device so
    # the GPU path does not mix CPU and CUDA tensors
    with pyro.plate('components', T):
        locs = pyro.sample('locs', MultivariateNormal(tau, torch.eye(M, device=tau.device)))

    # sample mixture weights and cluster assignments
    # NOTE(review): Dirichlet needs strictly positive concentrations; confirm
    # that PermNet's output guarantees this.
    alpha = alpha_mlp(data.float())
    weights = pyro.sample('weights', Dirichlet(alpha))
    with pyro.plate('data', size=N):
        assignments = pyro.sample('assignments', Categorical(weights))

    # logging
    if step % 2 == 0:
        _log_guide_state(step, weights, locs)


def _log_guide_state(step, weights, locs):
    """Print mixture proportions and component means as a table.

    Columns are derived from ``locs.shape[0]`` (clusters) and rows from
    ``locs.shape[1]`` (dimensions), so the table follows T and M instead of
    being hard-coded to 10 clusters x 6 dimensions.
    """
    print('='*10, 'Iteration {}'.format(step), '='*10)
    num_clusters, num_dims = locs.shape
    # weights may carry a leading batch axis (e.g. (1, T)); take the first row
    props = weights.reshape(-1, num_clusters)[0]
    rows = [['props'] + [p.item() for p in props]]
    for d in range(num_dims):
        rows.append(['mu{}'.format(d + 1)] + [locs[k, d].item() for k in range(num_clusters)])
    headers = [''] + ['clust{}'.format(k + 1) for k in range(num_clusters)]
    print(tabulate(rows, headers=headers))

### Initialization

In [5]:
T = 10  # number of mixture components (clusters)
M = 6  # how many features after pca
SEED = 0  # pyro.set_rng_seed seeds torch, random and numpy

# Reset Pyro's global param store so re-running this cell starts from fresh
# network weights rather than values left over from a previous run.
pyro.clear_param_store()
pyro.set_rng_seed(SEED)

alpha_mlp = PermNet(num_pixels=784, hidden=2048, output_size=T).float()
tau_mlp = PermNet(num_pixels=784, hidden=2048, output_size=T*M).float()

adam_params = {"lr": 0.00001}
optimizer = Adam(adam_params)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

use_gpu = torch.cuda.is_available()
if use_gpu:
    print('using GPU!')
    alpha_mlp = alpha_mlp.cuda()
    tau_mlp = tau_mlp.cuda()
else:
    print('not using GPU!')

not using GPU!


  return torch._C._cuda_getDeviceCount() > 0


### Training

In [6]:
epochs = 20
dry = False  # True: train without writing checkpoints
n_steps = 10  # SVI steps taken on each batch
# Total number of batches to train on across all epochs; None = full run.
# 1 reproduces the previous smoke-test behaviour (first batch of first epoch).
max_batches = 1
checkpoint_dir = 'saved_models'

if not dry:
    # torch.save does not create missing directories
    os.makedirs(checkpoint_dir, exist_ok=True)

batches_done = 0
budget_reached = False
for epoch in range(epochs):
    for n_batch, (real_batch, _) in enumerate(data_loader):
        # flattened images with a leading set axis; named batch_vectors so the
        # MNIST dataset held in the global `data` is not overwritten
        batch_vectors = images_to_vectors(real_batch).unsqueeze(0)

        start = time.time()
        for step in range(n_steps):
            loss = svi.step(batch_vectors, step)
            if step % 2 == 0:
                end = time.time()
                print('took', end-start, 'seconds', '| loss', loss)
                start = time.time()

        if not dry:
            # Tag checkpoints by epoch/batch: `step` is always n_steps - 1
            # here, so naming by step overwrote the files every batch.
            tag = 'epoch{}_batch{}'.format(epoch, n_batch)
            torch.save({'model_state_dict': alpha_mlp.state_dict()},
                       os.path.join(checkpoint_dir, 'alpha_mlp_{}.pth'.format(tag)))
            torch.save({'model_state_dict': tau_mlp.state_dict()},
                       os.path.join(checkpoint_dir, 'tau_mlp_{}.pth'.format(tag)))

        batches_done += 1
        budget_reached = max_batches is not None and batches_done >= max_batches
        if budget_reached:
            break
    if budget_reached:
        break

         clust1     clust2      clust3     clust4     clust5     clust6    clust7      clust8    clust9    clust10
-----  --------  ---------  ----------  ---------  ---------  ---------  --------  ----------  --------  ---------
props  0.220254  0.0136781   0.0502475  0.106647   0.0562601  0.0284842  0.097218   0.0823469  0.143951   0.200912
mu1    0.537987  2.33507    -0.957864   0.449947   2.93169    2.5207     3.07959    3.17407    1.64365    1.46379
mu2    1.56636   2.94652     2.27128    0.0901594  1.47747    2.54397    2.84904    0.109354   3.51041    3.04167
mu3    3.06897   2.35485     3.35059    3.44511    2.69556    0.902499   2.7231     1.72833    1.97658    2.06199
mu4    1.44914   1.57254     0.113334   3.76592    3.21001    3.61562    1.19541    3.7302     2.21537    2.67731
mu5    0.258158  0.421295    1.94493    1.43892    0.552475   1.53576    1.47021    1.27526    1.6497     1.58753
mu6    1.19847   2.13961     1.50532    2.56593    4.06919    2.69376    0.922602  -0