In [1]:
import numpy as np
import time

import pyro
from pyro.distributions import *
from pyro.infer import Predictive, SVI, Trace_ELBO
from pyro.optim import Adam

import torch
import torch.nn as nn

In [2]:
clust1 = np.random.normal(0, 1, 1000)
clust2 = np.random.normal(12, 1, 1000)
clust3 = np.random.normal(6, 1, 1000)
clust4 = np.random.normal(3, 1, 1000)
clust5 = np.random.normal(9, 1, 1000)

data = torch.from_numpy(np.concatenate((clust1, clust2, clust3, clust4, clust5))).double()
N = data.shape[0]

In [75]:
T = 5  # Fixed number of components.

class MLP(nn.Module):
  '''
    Outputs a probability vector of length T
  '''
  def __init__(self):
    super().__init__()
    hidden_layer_1_size = 2*N
    hidden_layer_2_size = 2*N
    input_size = N
    output_size = T
    self.layers = nn.Sequential(
      nn.Linear(input_size, hidden_layer_1_size),
      nn.Linear(hidden_layer_1_size, hidden_layer_2_size),
      nn.ReLU(),
      nn.Linear(hidden_layer_2_size, output_size),
      nn.Softmax(0)
    )

  def forward(self, x):
    return self.layers(x)


def model(data):
    with pyro.plate('components', T):
        locs = pyro.sample('locs', Normal(0, 1))

    with pyro.plate('data', N):
        # Local variables.
        assignments = pyro.sample('assignments', Categorical(torch.ones(T) / T)) # returns a vector of length T
        print('len(locs[assignments])', len(locs[assignments]))
        print('len(assignments)', len(assignments))
        obs = pyro.sample('obs', Normal(locs[assignments], 1), obs=data)
        print('len(obs)', len(obs))

In [76]:
mlp = MLP().double()

def guide(data):
    # amortize using MLP
    pyro.module('mlp', mlp)
    
    # sample mixture components mu
    tau = pyro.param('tau', lambda: Normal(0, 1))
    with pyro.plate('components', T) as i:      
        pyro.sample('locs', Normal(tau[i], 1))
    
    # sample cluster assignments
    with pyro.plate("data", N):
        phi = mlp(data.double()) # returns a vector of length T
        pyro.param('phi', Delta(phi))
        pyro.sample("assignments", Categorical(phi)) # returns a vector of length N

In [77]:
def print_progress(step):
    print('='*10, 'Iteration {}'.format(step), '='*10)
    tau = pyro.param('tau')
    phi = pyro.param('phi')

    print('tau is', tau)
    print('phi is', phi)

In [78]:
adam_params = {"lr": 0.005}
optimizer = Adam(adam_params)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

n_steps = 2501
start = time.time()
for step in range(n_steps):
  svi.step(data)
  pyro.get_param_store()
  if step % 100 == 0:
        end = time.time()
        print_progress(step)
        print('took', end-start, 'seconds')
        start = time.time()

len(locs[assignments]) 5000
len(assignments) 5000


ValueError: Shape mismatch inside plate('data') at site obs dim -1, 5000 vs 5
   Trace Shapes:         
    Param Sites:         
   Sample Sites:         
 components dist        |
           value      5 |
       locs dist      5 |
           value 5    5 |
       data dist        |
           value   5000 |
assignments dist   5000 |
           value   5000 |