## BUGS

## TODO
- work on MixtureParamsMLP to output the right dimensions 
- how to print out trace for pi

In [13]:
import numpy as np
import time

import pyro
from pyro.distributions import *
from pyro.infer import Predictive, SVI, Trace_ELBO
from pyro.optim import Adam

import torch
import torch.nn as nn

In [14]:
clust1 = np.random.normal(0, 1, 1000)
clust2 = np.random.normal(12, 1, 1000)
clust3 = np.random.normal(6, 1, 1000)
clust4 = np.random.normal(3, 1, 1000)
clust5 = np.random.normal(9, 1, 1000)

data = torch.from_numpy(np.concatenate((clust1, clust2, clust3, clust4, clust5))).double()
N = data.shape[0]

In [15]:
T = 5  # Fixed number of components.

class MixturePropsMLP(nn.Module):
  '''
    Outputs a probability vector of length T
  '''
  def __init__(self):
    super().__init__()
    hidden_layer_1_size = 128
    hidden_layer_2_size = 128
    output_size = T
    self.layers = nn.Sequential(
        nn.AdaptiveAvgPool1d(1),
        nn.Linear(1, hidden_layer_1_size),
        nn.ReLU(),
        nn.Linear(hidden_layer_1_size, hidden_layer_2_size),
        nn.ReLU(),
        nn.Linear(hidden_layer_2_size, output_size),
        nn.Softmax(0)  # 0-dim because only 1D
    )

  def forward(self, x):
    return self.layers(x.unsqueeze(0).unsqueeze(0))

class MixtureParamsMLP(nn.Module):
  '''
    Outputs a probability vector of length T
  '''
  def __init__(self):
    super().__init__()
    hidden_layer_1_size = 128
    hidden_layer_2_size = 128
    output_size = T
    self.layers = nn.Sequential(
        nn.AdaptiveAvgPool1d(1),
        nn.Linear(1, hidden_layer_1_size),
        nn.ReLU(),
        nn.Linear(hidden_layer_1_size, hidden_layer_2_size),
        nn.ReLU(),
        nn.Linear(hidden_layer_2_size, output_size)
    )

  def forward(self, x):
    return self.layers(x.unsqueeze(0).unsqueeze(0))


def model(data):
    with pyro.plate('components', T):
        locs = pyro.sample('locs', Normal(0, 1))

    with pyro.plate('data', N):
        # Local variables.
        assignments = pyro.sample('assignments', Categorical(torch.ones(T) / T)) # returns a vector of length T
        obs = pyro.sample('obs', Normal(locs[assignments], 1), obs=data)

In [16]:
def guide(data):
    # amortize using MLP
    pyro.module('pi_mlp', pi_mlp)
    
    # sample mixture components mu
    tau = pyro.param('tau', lambda: Normal(0, 1).sample([T]))
    with pyro.plate('components', T) as i:      
        pyro.sample('locs', Normal(tau[i], 1))
    
    # sample cluster assignments
    pi = pi_mlp(data.double()) # returns a vector of length T
    # pyro.param('phi', phi)    
    with pyro.plate("data", N):
        pyro.sample("assignments", Categorical(pi)) # returns a vector of length N

In [17]:
def print_progress(step):
    print('='*10, 'Iteration {}'.format(step), '='*10)
    tau = pyro.param('tau')
    # pi = pyro.param('pi')

    print('tau is', tau)
    # print('pi is', pi)

In [18]:
pi_mlp = MixturePropsMLP().double()
adam_params = {"lr": 0.005}
optimizer = Adam(adam_params)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

n_steps = 2501
start = time.time()
for step in range(n_steps):
  svi.step(data)
  pyro.get_param_store()
  if step % 100 == 0:
        end = time.time()
        print_progress(step)
        print('took', end-start, 'seconds')
        start = time.time()

tau is tensor([ 1.0423,  1.2229, -1.9547, -0.8401,  0.7332], requires_grad=True)
took 0.017001867294311523 seconds
tau is tensor([ 1.5189,  1.7085, -1.4704, -0.3554,  1.2206], requires_grad=True)
took 1.073002815246582 seconds
tau is tensor([ 1.9730,  2.1706, -1.0016,  0.1052,  1.6818], requires_grad=True)
took 0.9394094944000244 seconds
tau is tensor([ 2.3985,  2.5750, -0.5509,  0.5700,  2.1176], requires_grad=True)
took 0.9525508880615234 seconds
tau is tensor([ 2.8043,  2.9558, -0.1072,  1.0050,  2.5292], requires_grad=True)
took 0.9041855335235596 seconds
tau is tensor([3.1851, 3.3281, 0.3203, 1.4221, 2.9114], requires_grad=True)
took 0.8289201259613037 seconds
tau is tensor([3.5293, 3.6553, 0.7291, 1.8092, 3.2658], requires_grad=True)
took 0.9510729312896729 seconds
tau is tensor([3.8574, 3.9662, 1.1119, 2.1732, 3.6093], requires_grad=True)
took 0.9516282081604004 seconds
tau is tensor([4.1317, 4.2689, 1.5048, 2.5324, 3.9087], requires_grad=True)
took 0.955132246017456 seconds
tau