In [1]:
import torch
from torch.distributions import constraints
from collections import defaultdict


import pyro
import pyro.distributions as dist
import pyro.optim as optim
from pyro.optim import Adam
from pyro import poutine
from pyro.contrib.autoguide import AutoDelta
from pyro.infer import SVI, Trace_ELBO

from matplotlib import pyplot
%matplotlib inline

In [2]:
import numpy as np

In [3]:
import matplotlib.pylab as plt

In [4]:
torch.set_default_tensor_type(torch.FloatTensor)

In [5]:
pyro.enable_validation(True)

From https://pyro.ai/examples/intro_part_i.html

Primitive stochastic functions, or distributions, are an important class of stochastic functions for which we can explicitly compute the probability of the outputs given the inputs. 

All probabilistic programs are built up by composing primitive stochastic functions and deterministic computation.

In [6]:
def make_categorical_matrix(N,D):
    # number of categories is sampled between 2 and number of features
    C = np.random.choice(np.arange(2,D*2),D)
    print(C)
    categorical_matrix = []
    for c in C:
        categorical_matrix.append(np.random.choice(c, size=N))
    return np.array(categorical_matrix).T

In [7]:
N, D = 1000,4
data = make_categorical_matrix(N,D)

[5 5 6 2]


In [8]:
data

array([[4, 2, 1, 0],
       [2, 1, 0, 0],
       [4, 1, 5, 0],
       ...,
       [3, 3, 4, 1],
       [3, 1, 1, 0],
       [1, 3, 0, 1]])

In [9]:
def independentCategorical(data):
    obs_dim, data_dim = data.shape
    # compute number of categories in each feature
    C = [np.unique(data[:,i]).shape[0] for i in range(D)]
    Cmax = np.max(C)
    C = torch.tensor(C)
    # Since there can be different number of categories,
    # we let all features have the same categories as the most category-rich feature,
    # but set the prior of non-existent categories to epsilon.
    dirichlet_params = torch.tensor([np.r_[np.ones(c)*10*Cmax, np.ones(Cmax - c)*0.1*(1/Cmax)] for c in C])
    #for feature in pyro.plate('features', data_dim):
    with pyro.plate('feature_plate', data_dim):
        probs = pyro.sample('probs', pyro.distributions.Dirichlet(dirichlet_params))
        #probs = pyro.sample('probs_{}'.format(feature), pyro.distributions.Dirichlet(dirichlet_params))
        with pyro.plate('data_plate', obs_dim):
            data = pyro.sample('obs', pyro.distributions.Categorical(probs=probs),obs=data)
    return data

In [10]:
data = np.float32(data) # float64 creates mismatch with torch's defaults
data = torch.tensor(data)

In [11]:
#?pyro.poutine.Trace

In [12]:
trace = pyro.poutine.trace(independentCategorical).get_trace(data)
#logp = trace.log_prob_sum()
#params = [trace.nodes[name]["value"].unconstrained() for name in trace.param_nodes]
#trace2 = pyro.poutine.trace(independentCategorical2).get_trace(data)

In [13]:
cond_model = pyro.condition(independentCategorical, data = {'obs', data})

In [14]:
data

tensor([[4., 2., 1., 0.],
        [2., 1., 0., 0.],
        [4., 1., 5., 0.],
        ...,
        [3., 3., 4., 1.],
        [3., 1., 1., 0.],
        [1., 3., 0., 1.]])

In [15]:
independentCategorical(data)

tensor([[4., 2., 1., 0.],
        [2., 1., 0., 0.],
        [4., 1., 5., 0.],
        ...,
        [3., 3., 4., 1.],
        [3., 1., 1., 0.],
        [1., 3., 0., 1.]])

In [16]:
hide_everything_but_probs_model = poutine.block(independentCategorical, expose=['probs'])

In [17]:
def independentCategoricalGuide(data):
    obs_dim, data_dim = data.shape
    # compute number of categories in each feature
    C = [np.unique(data[:,i]).shape[0] for i in range(D)]
    Cmax = np.max(C)
    C = torch.tensor(C)
    alphas = pyro.param('alphas', torch.ones(D,Cmax,dtype=torch.float64), constraint=constraints.positive)
    with pyro.plate('feature_plate', data_dim):
        pyro.sample('probs', pyro.distributions.Dirichlet(alphas))

In [18]:
optim = pyro.optim.Adam({'lr': 0.01, 'betas': [0.9, 0.999]})
elbo = Trace_ELBO(max_plate_nesting=2)
svi = SVI(independentCategorical, independentCategoricalGuide, optim, loss=elbo)

In [19]:
pyro.clear_param_store()

In [20]:
guide_trace = pyro.poutine.trace(independentCategoricalGuide).get_trace(data)

In [21]:
pyro.util.check_model_guide_match(trace, guide_trace)

In [22]:
pyro.get_param_store().named_parameters()

dict_items([('alphas', tensor([[0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]], dtype=torch.float64, requires_grad=True))])

In [23]:
# Register hooks to monitor gradient norms.
gradient_norms = defaultdict(list)
for name, value in pyro.get_param_store().named_parameters():
    value.register_hook(lambda g, name=name: gradient_norms[name].append(g.norm().item()))

losses = []
# while not all parameters are locked:
# for parameter in parameters:
# lock parameter if gradient_norm is below tolerance
for i in range(2000):
    loss = svi.step(data)
    losses.append(loss)
    print('.' if i % 100 else '\n', end='')


...................................................................................................
...................................................................................................
...................................................................................................
...................................................................................................
...................................................................................................
...................................................................................................
...................................................................................................
...................................................................................................
...................................................................................................
...................................................................................................

In [24]:
def rel_err(true_parameter,estimated_parameter):
    if type(true_parameter) != torch.Tensor:
        true_parameter = torch.Tensor(true_parameter)
    norm = torch.norm
    return norm(true_parameter-estimated_parameter)/norm(true_parameter)

In [25]:
pyro.param('alphas')

tensor([[15.2690, 15.8122, 15.7969, 15.5664, 15.4220,  0.0534],
        [15.8758, 15.5346, 14.0170, 15.4700, 16.1497,  0.0443],
        [14.0860, 14.8840, 14.3809, 13.2279, 14.6049, 12.8324],
        [19.1577, 19.3151,  0.0365,  0.0373,  0.0382,  0.0380]],
       dtype=torch.float64, grad_fn=<AddBackward0>)

In [26]:
print(rel_err(some_data_locs,pyro.param('auto_locs')))
print(rel_err(some_data_scales,pyro.param('auto_scales')))

NameError: name 'some_data_locs' is not defined

In [None]:
pyplot.figure(figsize=(10,3), dpi=100).set_facecolor('white')
pyplot.plot(losses)
pyplot.xlabel('iters')
pyplot.ylabel('loss')
pyplot.yscale('log')
pyplot.title('Convergence of SVI');

In [None]:
pyplot.figure(figsize=(10,4), dpi=100).set_facecolor('white')
for name, grad_norms in gradient_norms.items():
    pyplot.plot(grad_norms, label=name)
pyplot.xlabel('iters')
pyplot.ylabel('gradient norm')
pyplot.yscale('log')
pyplot.legend(loc='best')
pyplot.title('Gradient norms during SVI');

In [None]:
optim = pyro.optim.Adam({'lr': 0.005, 'betas': [0.9, 0.999]})
svi = SVI(independentGaussian, global_guide, optim, loss=elbo)

In [None]:
for i in range(1000):
    loss = svi.step(data)
    losses.append(loss)
    print('.' if i % 100 else '\n', end='')

In [None]:
pyplot.figure(figsize=(10,3), dpi=100).set_facecolor('white')
pyplot.plot(losses)
pyplot.xlabel('iters')
pyplot.ylabel('loss')
pyplot.yscale('log')
pyplot.title('Convergence of SVI');

In [None]:
pyplot.figure(figsize=(10,4), dpi=100).set_facecolor('white')
for name, grad_norms in gradient_norms.items():
    pyplot.plot(grad_norms, label=name)
pyplot.xlabel('iters')
pyplot.ylabel('gradient norm')
pyplot.yscale('log')
pyplot.legend(loc='best')
pyplot.title('Gradient norms during SVI');

In [None]:
print(rel_err(some_data_locs,pyro.param('auto_locs')))
print(rel_err(some_data_scales,pyro.param('auto_scales')))

In [None]:
print(rel_err(some_data_locs,data.mean(0)))
print(rel_err(some_data_scales,np.sqrt(data.var(0))))