## Notes

The equivariant neural network implementation is loosely based on [PointNet](https://arxiv.org/pdf/1612.00593.pdf). I implemented only the first 'standard transform layer', but a much more detailed version could be designed. The reference code is in this [GitHub repo](https://github.com/fxia22/pointnet.pytorch/blob/f0c2430b0b1529e3f76fb5d6cd6ca14be763d975/pointnet/model.py#L11).

## Hyperparameters
- Adam{"lr": 0.00001} / SELU() / max / SVI

## BUGS
- cluster degeneracy (one cluster makes up the majority of the proportion vector)
    - [DEEP UNSUPERVISED CLUSTERING WITH GAUSSIAN MIXTURE VARIATIONAL AUTOENCODERS](https://arxiv.org/pdf/1611.02648.pdf)


## Resources
- [Tutorial on to_event and .expand()](https://bochang.me/blog/posts/pytorch-distributions/)
- [Event, Batch, Sample shapes](https://ericmjl.github.io/blog/2019/5/29/reasoning-about-shapes-and-probability-distributions/)
- [Debugging Neural Networks](https://stats.stackexchange.com/questions/352036/what-should-i-do-when-my-neural-network-doesnt-learn)

### Import Packages

In [None]:
from pointnet import PointNet

import numpy as np
import time

import pyro
from pyro.distributions import *
from pyro.infer import Predictive, SVI, Trace_ELBO
from pyro.optim import Adam, AdamW
from pyro.nn import PyroModule

import torch
import torch.nn as nn

import matplotlib.pyplot as plt
%matplotlib inline

from tabulate import tabulate

### Generate Data

In [None]:
# Problem dimensions shared by the model and guide defined further down.
N = 500  # total number of observations (five clusters of 100 points each)
M = 2    # number of features per observation
T = 5    # fixed number of mixture components

# All clusters share the same identity covariance for now.
cov = np.identity(M)

# Cluster centres, listed in the order the clusters are drawn.
cluster_means = (
    np.zeros(M),
    np.ones(M) * 10,
    np.ones(M) * -10,
    [10, -10],
    [-10, 10],
)

# Draw 100 points around each centre. The generator is consumed in order,
# so the random draws happen in the same sequence as the centres above.
clust1, clust2, clust3, clust4, clust5 = (
    np.random.multivariate_normal(mean, cov, 100) for mean in cluster_means
)

# Stack the clusters row-wise into a single (points x features) array.
data = np.concatenate([clust1, clust2, clust3, clust4, clust5])

# Quick visual check of the generated clusters.
x_coords, y_coords = data[:, 0], data[:, 1]
plt.scatter(x_coords, y_coords)
plt.grid()

# Pyro/PyTorch expect float32 tensors rather than float64 numpy arrays.
data = torch.from_numpy(data).float()

### Pyro

In [None]:
def model(data, step):
    """Generative model: a mixture of T Gaussians with identity covariance.

    Sample sites:
        'weights'     -- mixing proportions, Dirichlet prior with all-ones
                         concentration.
        'locs'        -- one mean per component, each drawn from a zero-mean,
                         identity-covariance Gaussian.
        'assignments' -- per-observation component index.
        'obs'         -- the observations, conditioned on ``data``.

    Args:
        data: observed points; expected to line up with the 'data' plate of
            size N and the M-dimensional event shape.
        step: current training iteration. Unused here; accepted so that the
            model and the guide share the same call signature.
    """
    # Global latent: mixing proportions.
    uniform_concentration = torch.ones(T)
    weights = pyro.sample('weights', Dirichlet(uniform_concentration))

    # Global latents: component means, independent across the plate.
    identity_cov = torch.eye(M)
    loc_prior = MultivariateNormal(torch.zeros(M), identity_cov)
    with pyro.plate('components', T):
        locs = pyro.sample('locs', loc_prior)

    # Local latents: pick a component for every point, then score the point
    # under that component's Gaussian.
    with pyro.plate('data', N):
        assignment = pyro.sample('assignments', Categorical(weights))
        likelihood = MultivariateNormal(locs[assignment], identity_cov)
        pyro.sample('obs', likelihood, obs=data)
        
def _print_guide_sample(step, weights, locs):
    """Print a table of the proportions and component means just sampled.

    One column per component; the first row holds the mixture proportions and
    each following row holds one coordinate of the component means.
    """
    print('='*10, 'Iteration {}'.format(step), '='*10)

    n_components, n_features = locs.shape[0], locs.shape[1]

    rows = [['props', *weights]]
    for dim in range(n_features):
        rows.append(['mu{}'.format(dim + 1), *locs[:, dim]])

    headers = [''] + ['clust{}'.format(k + 1) for k in range(n_components)]
    print(tabulate(rows, headers=headers))


def guide(data, step):
    """Amortized variational guide for ``model``.

    Two networks read the whole data set and emit the variational parameters:
    ``tau_mlp`` produces the component means and ``alpha_mlp`` produces the
    Dirichlet concentration of the mixture proportions.

    Args:
        data: observed points; transposed with ``permute(1, 0)`` before being
            fed to the networks.
        step: current training iteration; every 100th step the sampled
            proportions and means are printed.
    """
    # Register the networks so their parameters are optimized by SVI.
    pyro.module('alpha_mlp', alpha_mlp)
    pyro.module('tau_mlp', tau_mlp)

    # Both networks consume the same transposed view of the data.
    features = data.permute(1, 0).float()

    # Component means: reshape the flat network output to one row per
    # component, using the global sizes instead of hard-coded literals.
    tau = tau_mlp(features).view(T, M)
    with pyro.plate('components', T):
        locs = pyro.sample('locs', MultivariateNormal(tau, torch.eye(M)))

    # Mixture proportions.
    # NOTE(review): Dirichlet needs strictly positive concentrations; confirm
    # that alpha_mlp's output is constrained accordingly.
    alpha = alpha_mlp(features)
    weights = pyro.sample('weights', Dirichlet(alpha))

    # Cluster assignments, one per data point.
    with pyro.plate('data', N):
        pyro.sample('assignments', Categorical(weights))

    if step % 100 == 0:
        _print_guide_sample(step, weights, locs)

### Initialization

In [None]:
# When True, the training loop skips writing checkpoints to disk.
dry = True

# Amortization networks used by the guide. The guide treats alpha_mlp's
# output as the Dirichlet concentration and reshapes tau_mlp's output into
# the component means.
# NOTE(review): PointNet's arguments look like (output size, feature count);
# confirm against the pointnet module.
alpha_mlp = PointNet(T, M).float()
tau_mlp = PointNet(T * M, M).float()

# Stochastic variational inference with Adam and the Trace ELBO objective.
adam_params = dict(lr=1e-5)
optimizer = Adam(adam_params)
svi = SVI(model=model, guide=guide, optim=optimizer, loss=Trace_ELBO())

### Training

In [None]:
import os

n_steps = 50000

if not dry:
    # torch.save does not create missing parent directories, so make sure the
    # checkpoint folder exists before the first save.
    os.makedirs('saved_models', exist_ok=True)

start = time.time()
for step in range(n_steps):
    # One gradient step on the ELBO; `step` is forwarded to model and guide
    # (the guide uses it to decide when to print its sampled parameters).
    svi.step(data, step)

    if step % 100 == 0:
        # Report the wall-clock time since the previous report, then restart
        # the timer.
        end = time.time()
        print('took', end-start, 'seconds')
        start = time.time()

        if not dry:
            # Checkpoint both amortization networks, tagged with the step.
            torch.save({'model_state_dict': alpha_mlp.state_dict(),
                       }, 'saved_models/alpha_mlp_{}.pth'.format(step))

            torch.save({'model_state_dict': tau_mlp.state_dict(),
                       }, 'saved_models/tau_mlp_{}.pth'.format(step))