In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
import matplotlib.pylab as plt


from pyro.infer import Predictive, SVI, Trace_ELBO, TraceEnum_ELBO
from pyro.optim import Adam
from tqdm import tqdm


from pyro.ops.indexing import Vindex

In [2]:
pyro.set_rng_seed(42)

In [20]:
def model(data, C):
    N = data.shape[0]
    
    global_shrinkage_prior_scale = pyro.param('global_shrinkage_scale_prior', 5*torch.ones(1), constraint=constraints.positive)
    global_shrinkage = pyro.sample('global_shrinkage', dist.HalfNormal(global_shrinkage_prior_scale))
    b = pyro.sample('b', dist.InverseGamma(0.5,1./torch.ones(C)**2).to_event(1))
    local_shrinkage = pyro.sample('local_shrinkage', dist.InverseGamma(0.5,1./b).to_event(1))
    
    mixing_weights = pyro.sample("mixing_weights", dist.Dirichlet(global_shrinkage*local_shrinkage))
#    with pyro.plate("clusters", C):
#        mu = pyro.sample("mu", dist.Normal(0.5, 1.0))
#        sigma = pyro.sample("sigma", dist.LogNormal(-4., 1.))
    mu_mu_prior = pyro.param("mu_mu_prior", 0.5*torch.randn(C))
    mu_sigma_prior = pyro.param("mu_sigma_prior", 20*torch.abs(torch.randn(C)), constraint=constraints.positive)
    sigma_mu_prior = pyro.param("sigma_mu_prior", -1*torch.randn(C))
    sigma_sigma_prior = pyro.param("sigma_sigma_prior", torch.abs(torch.randn(C)), constraint=constraints.positive)

    mu = pyro.sample("mu", dist.Normal(mu_mu_prior, mu_sigma_prior).to_event(1))
    sigma = pyro.sample("sigma", dist.LogNormal(sigma_mu_prior, sigma_sigma_prior).to_event(1))
        
    with pyro.plate("data", N):
        z = pyro.sample("z", dist.Categorical(mixing_weights), infer={'enumerate': 'parallel'})
        #z = pyro.sample("z", dist.Categorical(mixing_weights), infer={'enumerate': 'sequential'})
        #print(z.shape)
        #print(mu.shape)
        #print(Vindex(mu)[...,z].shape)
        data = pyro.sample("obs", dist.Normal(Vindex(mu)[...,z],Vindex(sigma)[...,z]), obs=data)
        #data = pyro.sample("obs", dist.Normal(mu.index_select(-1,z),sigma.index_select(-1,z)), obs=data)
        
    return data

In [4]:
def guide(data, C):
    N = data.shape[0]
    
    global_shrinkage_prior_scale = pyro.param('global_shrinkage_scale', torch.ones(1), constraint=constraints.positive)
    global_shrinkage = pyro.sample('global_shrinkage', dist.HalfNormal(global_shrinkage_prior_scale))
    local_shrinkage_loc = pyro.param('local_shrinkage_loc', torch.ones(C))
    local_shrinkage_scale = pyro.param('local_shrinkage_scale', torch.ones(C), constraint=constraints.positive)
    local_shrinkage = pyro.sample('local_shrinkage', dist.LogNormal(local_shrinkage_loc,local_shrinkage_scale).to_event(1))

    #mixing_weights_concentration = pyro.param("mixing_weights_concentration", torch.ones(C), constraint=constraints.positive)
    #mixing_weights = pyro.sample("mixing_weights", dist.Dirichlet(torch.ones(C)))
    mixing_weights = pyro.sample("mixing_weights", dist.Dirichlet(global_shrinkage*local_shrinkage))
    #with pyro.plate("clusters", C):
    mu_mu = pyro.param("mu_mu", torch.ones(C))
    mu_sigma = pyro.param("mu_sigma", torch.ones(C), constraint=constraints.positive)
    sigma_mu = pyro.param("sigma_mu", torch.ones(C))
    sigma_sigma = pyro.param("sigma_sigma", torch.ones(C), constraint=constraints.positive)
    mu = pyro.sample("mu", dist.Normal(mu_mu,mu_sigma).expand([C]).to_event(1))
    sigma = pyro.sample("sigma", dist.LogNormal(sigma_mu, sigma_sigma).expand([C]).to_event(1))

In [5]:
unmodel = pyro.poutine.uncondition(model)

In [6]:
C = 3
N = 5000

In [7]:
pyro.clear_param_store()

In [8]:
dgp = pyro.poutine.trace(unmodel).get_trace(torch.zeros(N),C)

torch.Size([5000])
torch.Size([3])
torch.Size([5000])


In [9]:
dgp.nodes['mixing_weights']

{'type': 'sample',
 'name': 'mixing_weights',
 'fn': Dirichlet(concentration: torch.Size([3])),
 'is_observed': False,
 'args': (),
 'kwargs': {},
 'value': tensor([5.7167e-01, 1.1755e-38, 4.2833e-01], grad_fn=<_DirichletBackward>),
 'infer': {},
 'scale': 1.0,
 'mask': None,
 'cond_indep_stack': (),
 'done': True,
 'stop': False,
 'continuation': None}

In [10]:
dgp.nodes['mu']['value']

tensor([-3.8127,  0.0493,  6.7976], grad_fn=<AddBackward0>)

In [11]:
dgp.nodes['sigma']['value']

tensor([0.7296, 0.3772, 2.9780], grad_fn=<ExpBackward>)

In [12]:
dgp.nodes['z']['value']

tensor([2, 2, 0,  ..., 2, 0, 0])

In [13]:
%matplotlib qt

In [14]:
plt.figure()
plt.scatter(dgp.nodes['obs']['value'].detach(),torch.zeros(N),c=dgp.nodes['z']['value'].detach())
#plt.xlim(0,1)

<matplotlib.collections.PathCollection at 0x7fc4c80b9dd8>

In [15]:
data = dgp.nodes['obs']['value'].detach()

In [16]:
plt.figure()
plt.hist(data,bins=100);

In [17]:
def train(num_iterations):
    pyro.clear_param_store()
    for j in tqdm(range(num_iterations)):
        loss = svi.step(data,C)
        losses.append(loss)

In [21]:
n_iter = 1500
C = 10
optim = Adam({"lr": 0.01})
svi = SVI(model, guide, optim, loss=TraceEnum_ELBO(max_plate_nesting=1, num_particles = 16))
#svi = SVI(model, guide, optim, loss=TraceEnum_ELBO(max_plate_nesting=1, num_particles=16, vectorize_particles=True))
losses = []

train(n_iter)

100%|██████████| 1500/1500 [09:26<00:00,  2.65it/s]


In [22]:
plt.plot(losses)

[<matplotlib.lines.Line2D at 0x7fc4c2270080>]

In [23]:
foo = pyro.infer.Predictive(unmodel, guide=guide, num_samples=1000)

In [24]:
tmp= foo(data,C)

torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size

torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size

torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size

torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size

torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size

torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size([5000])
torch.Size([10])
torch.Size([5000])
torch.Size

In [25]:
tmp['z']

tensor([[9, 9, 9,  ..., 9, 1, 9],
        [1, 1, 1,  ..., 9, 9, 1],
        [1, 1, 1,  ..., 9, 9, 1],
        ...,
        [9, 1, 9,  ..., 1, 1, 1],
        [1, 1, 9,  ..., 1, 1, 9],
        [9, 9, 9,  ..., 9, 9, 9]])

In [30]:
tmp['mixing_weights'][-1]

tensor([[0.0020, 0.3740, 0.0092, 0.0016, 0.0018, 0.0131, 0.0039, 0.0010, 0.0084,
         0.5849]], grad_fn=<SelectBackward>)

In [27]:
plt.figure()
plt.scatter(tmp['obs'][10].detach(),torch.zeros(N),c=tmp['z'][10].detach())
plt.scatter(data,torch.zeros(N),c='r')
#plt.xlim(0,1)

<matplotlib.collections.PathCollection at 0x7fc4c15cb748>

In [28]:
plt.figure()
plt.hist(data,bins=100, density=True);
plt.hist(tmp['obs'][10].detach(),bins=100, density=True);