In [210]:
import logging
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.distributions.constraints as constraints

import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS, SVI

In [27]:
# Flag that is True when the CI environment variable is set.
smoke_test = ('CI' in os.environ)
# Fail fast if the installed Pyro version string does not start with 1.9.0.
assert pyro.__version__.startswith('1.9.0')

# Enable Pyro's validation checks and fix the RNG seed so that the sampled
# data below are repeatable when the cells are run in order.
pyro.enable_validation(True)
pyro.set_rng_seed(1)
# Log bare messages at INFO level and above.
logging.basicConfig(format='%(message)s', level=logging.INFO)

# Set matplotlib settings
%matplotlib inline
plt.style.use('default')

In [28]:
def create_data_object_types(alpha=2,num_objs=10):
  '''Sample cluster (object type) assignments from a Chinese Restaurant Process.

  Objects are seated one at a time. An object joins an existing cluster with
  probability proportional to that cluster's current size, or opens a new
  cluster with probability proportional to alpha.

  Arguments:
    alpha: Concentration parameter (non-negative); larger values favour
      opening new clusters.
    num_objs: Number of objects to be clustered.

  Returns:
    cluster_assignments: 1-D long tensor of length num_objs (empty when
      num_objs <= 0). Labels are contiguous integers starting at 0, numbered
      in order of first appearance.
  '''
  # Nothing to cluster: return an empty assignment vector instead of a
  # spurious single object.
  if num_objs <= 0:
    return torch.empty(0, dtype=torch.long)
  # The first object always opens cluster 0.
  cluster_frequencies = [1]
  cluster_assignments = [0]
  for _ in range(num_objs - 1):
    # Unnormalized CRP weights: one entry per existing cluster (its size)
    # followed by alpha for the "new cluster" option.
    weights = torch.tensor(cluster_frequencies + [alpha], dtype=torch.float)
    z_i = dist.Categorical(weights / weights.sum()).sample().item()
    # The last index of `weights` is the new-cluster slot.
    if z_i < len(cluster_frequencies):
      cluster_frequencies[z_i] += 1
    else:
      cluster_frequencies.append(1)
    cluster_assignments.append(z_i)
  return torch.tensor(cluster_assignments, dtype=torch.long)

In [29]:
def create_data_pred_matrix(obj_types, num_preds, theta):
  '''Draw a random predicability matrix with one row per distinct object type.

  Every entry is an independent Bernoulli(theta) draw.

  Arguments:
    obj_types: Tensor of object type assignments; only the number of
      distinct values in it is used.
    num_preds: Number of predicates (columns).
    theta: Success parameter of the Bernoulli draws.

  Returns:
    pred_matrix: Tensor of 0./1. values with shape
      (number of distinct types, num_preds).
  '''
  type_count = torch.unique(obj_types).numel()
  # Constant matrix of success probabilities, one per (type, predicate) cell.
  success_probs = torch.ones((type_count, num_preds)) * theta
  return torch.bernoulli(success_probs)

In [30]:
def create_data_truth_matrix(pred_matrix, obj_types, pi):
  '''Generate a truth-value matrix for a set of typed objects.

  For every (object, predicate) pair whose predicate applies to the object's
  type (pred_matrix[type, predicate] == 1) the truth value is drawn from
  Bernoulli(pi). Pairs whose predicate does not apply are left at 0.

  Arguments:
    pred_matrix: Predicability matrix, indexed as [type, predicate].
    obj_types: List or 1-D tensor of object type assignments; each entry
      must be a valid row index into pred_matrix.
    pi: Success parameter for the Bernoulli truth values.

  Returns:
    truth_matrix: Tensor of shape (len(obj_types), num_preds) holding
      0./1. truth values.
  '''
  num_objs = len(obj_types)
  num_preds = pred_matrix.size()[1]
  # Starts at zero, so non-applicable cells need no explicit assignment.
  truth_matrix = torch.zeros(num_objs, num_preds)
  for i in range(num_objs):
    # Look the type up in the obj_types argument (not a notebook global) so
    # the function depends only on its inputs.
    obj_type = int(obj_types[i])
    for j in range(num_preds):
      if pred_matrix[obj_type, j] == 1:
        truth_matrix[i, j] = pyro.sample("t", dist.Bernoulli(pi)).item()
  return truth_matrix

In [31]:
def create_data_freq_matrix(truth_matrix, pred_matrix, obj_types):
  '''Generate the frequency (observation) matrix from truth values.

  Where a predicate applies to an object's type
  (pred_matrix[type, predicate] == 1) the entry is copied from truth_matrix;
  everywhere else the entry is -1.

  Arguments:
    truth_matrix: Truth matrix of shape (num_objs, num_preds).
    pred_matrix: Predicability matrix, indexed as [type, predicate].
    obj_types: List or 1-D tensor of object type assignments; each entry
      must be a valid row index into pred_matrix.

  Returns:
    freq_matrix: Tensor of shape (num_objs, num_preds) with the truth value
      in applicable cells and -1 in the others.
  '''
  # Index by the obj_types argument (not a notebook global) so the function
  # depends only on its inputs.
  type_index = torch.as_tensor(obj_types, dtype=torch.long)
  truth = torch.as_tensor(truth_matrix, dtype=torch.get_default_dtype())
  # Row-gather: applicable[i, j] is True when predicate j applies to object i.
  applicable = pred_matrix[type_index] == 1
  return torch.where(applicable, truth, torch.full_like(truth, -1.0))

In [582]:
# Number of objects
num_objs = 5
# Number of predicates
num_preds = 5
# Concentration param
alpha = 10
# Success parameter for predicability, drawn from Beta(1, 1) and stored as a
# plain Python float.
theta = pyro.sample("theta", dist.Beta(1,1)).item()
# Success parameter for truth, drawn the same way. theta is sampled before
# pi, so the draw order matters for reproducing the printed values.
pi = pyro.sample("pi", dist.Beta(1,1)).item()
print(f"Theta {theta} Pi {pi}")

Theta 0.5328394770622253 Pi 0.255874365568161


In [583]:
# Object type assignments: one label per object, generated by
# create_data_object_types from the notebook-level alpha and num_objs.
z = create_data_object_types(alpha=alpha, num_objs=num_objs)
print("Object type assignments:")
print(z)

Object type assignments:
tensor([0, 1, 2, 3, 4])


In [990]:
# Predicability matrix for the types found in z (rows) and num_preds
# predicates (columns), generated with success parameter theta.
r = create_data_pred_matrix(obj_types=z, num_preds=num_preds, theta=theta)
print("Predicability matrix:")
r

Predicability matrix:


tensor([[0., 1., 0., 0., 1.],
        [0., 0., 1., 1., 0.]])

In [585]:
# Truth matrix: one row per object in z, one column per predicate of r,
# generated with success parameter pi.
t = create_data_truth_matrix(pred_matrix=r, obj_types=z, pi=pi)
print("Truth matrix:")
t

Truth matrix:


tensor([[0., 0., 0., 0., 1.],
        [1., 0., 1., 1., 0.],
        [0., 1., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])

In [586]:
# Frequency matrix built from the truth matrix t, the predicability matrix r
# and the type assignments z.
d = create_data_freq_matrix(truth_matrix=t, pred_matrix=r, obj_types=z)
print("Freq matrix:")
d

Freq matrix:


tensor([[-1.,  0., -1., -1.,  1.],
        [ 1.,  0.,  1.,  1., -1.],
        [ 0.,  1., -1., -1.,  0.],
        [ 1., -1.,  0.,  0., -1.],
        [ 0., -1., -1., -1.,  0.]])

In [985]:
def model(freq_matrix):
  '''A Pyro model for the TIRM model (work in progress).

  Arguments:
    freq_matrix: The frequency matrix of observations. Its first dimension
      is read as the number of objects and its second as the number of
      predicates.

  What the code currently does:
    1. Registers `theta` (initial value 0.5, constrained to [0, 1]) and
       `alpha` (initial value 10, constrained positive) with pyro.param.
    2. Seats objects one at a time: site `z_i` is a Categorical over the
       normalized vector [count of each table..., alpha], and the chosen
       table's count is updated from the sampled value.
    3. Samples a site `z` inside the `obj_axis` plate, using the `probs`
       left over from the last loop iteration.
    4. Treats `freq_matrix` as the observed value of site `r` under
       Bernoulli(theta) inside the `obj_axis` and `type_axis` plates.

  Notes:
    Currently not sampling alpha; it is a pyro.param.
    Nothing is returned: the return statement is commented out.
    NOTE(review): the SVI traceback recorded in this notebook reports that
      the value at site 'r' is not broadcastable with the site's batch
      shape; the plate layout at the end of this function needs revisiting.
    NOTE(review): `torch.tensor(freqs + [alpha])` builds a new tensor from
      values, which presumably cuts the gradient path to `alpha` - confirm.
    NOTE(review): branching on `z_sample.item()` presumably does not work
      when the `z_i` sites are enumerated - confirm.
  '''
  # Model parameters
  theta = pyro.param("theta", torch.tensor(0.5), constraint=constraints.interval(0., 1.0))
  # pi = pyro.param("pi", torch.tensor(0.5), constraint=constraints.interval(0., 1.0))
  alpha = pyro.param("alpha", torch.tensor(10), constraint=constraints.positive)
  # Useful variables
  num_objs = freq_matrix.size()[0]
  num_preds = freq_matrix.size()[1]
  freqs = [] # number of customers at each table
  for i in range(num_objs):
    # Weights: existing table counts followed by alpha for a new table,
    # normalized in place to sum to one.
    probs = torch.tensor(freqs + [alpha])
    probs /= probs.sum()
    z_sample = pyro.sample(f"z_{i}", dist.Categorical(probs))
    z_item = z_sample.item()
    # The last index of probs is the "new table" option.
    if z_item >= len(freqs):
      freqs += [1.]
    else:
      freqs[z_item] += 1.
  obj_axis = pyro.plate("obj_axis", num_objs)
  # pred_axis is created but not entered by any active code below.
  pred_axis = pyro.plate("pred_axis", num_preds)
  with obj_axis:
    z = pyro.sample("z", dist.Categorical(probs))
  # One slot per type label up to the largest label sampled in z.
  type_axis = pyro.plate("type_axis", max(z)+1)
  with obj_axis, type_axis:
    r = pyro.sample("r", dist.Bernoulli(theta), obs=freq_matrix)
  # r = F.pad(input=r, pad=(0,0,0,num_objs-r.shape[0]), mode='constant', value=0)
  # with obj_axis, pred_axis:
  #   t = pyro.sample("t", dist.Bernoulli(r[z]*pi), obs=freq_matrix)
    # obs = t+(r[z]-1)
    # d = pyro.sample("d", dist.Delta(obs), obs=freq_matrix)
  # return z, r, t, d

In [None]:
def model(freq_matrix):
  '''A Pyro model for the TIRM model (second definition, work in progress).

  This cell redefines `model`, so whichever definition cell was run last is
  the one in effect. The only difference from the earlier definition is the
  name of the observed site: it is f"r_{i}", where `i` is the value left
  over from the seating loop, instead of the fixed name "r".

  Arguments:
    freq_matrix: The frequency matrix of observations. Its first dimension
      is read as the number of objects and its second as the number of
      predicates.

  Notes:
    Currently not sampling alpha; it is a pyro.param.
    Nothing is returned: the return statement is commented out.
    NOTE(review): `torch.tensor(freqs + [alpha])` builds a new tensor from
      values, which presumably cuts the gradient path to `alpha` - confirm.
    NOTE(review): the observed site sits inside the `obj_axis` and
      `type_axis` plates while the observation is shaped by objects and
      predicates - confirm that these shapes are meant to agree.
  '''
  # Model parameters
  theta = pyro.param("theta", torch.tensor(0.5), constraint=constraints.interval(0., 1.0))
  # pi = pyro.param("pi", torch.tensor(0.5), constraint=constraints.interval(0., 1.0))
  alpha = pyro.param("alpha", torch.tensor(10), constraint=constraints.positive)
  # Useful variables
  num_objs = freq_matrix.size()[0]
  num_preds = freq_matrix.size()[1]
  freqs = [] # number of customers at each table
  for i in range(num_objs):
    # Weights: existing table counts followed by alpha for a new table,
    # normalized in place to sum to one.
    probs = torch.tensor(freqs + [alpha])
    probs /= probs.sum()
    z_sample = pyro.sample(f"z_{i}", dist.Categorical(probs))
    z_item = z_sample.item()
    # The last index of probs is the "new table" option.
    if z_item >= len(freqs):
      freqs += [1.]
    else:
      freqs[z_item] += 1.
  obj_axis = pyro.plate("obj_axis", num_objs)
  # pred_axis is created but not entered by any active code below.
  pred_axis = pyro.plate("pred_axis", num_preds)
  with obj_axis:
    z = pyro.sample("z", dist.Categorical(probs))
  # One slot per type label up to the largest label sampled in z.
  type_axis = pyro.plate("type_axis", max(z)+1)
  with obj_axis, type_axis:
    # `i` here is the last index of the seating loop above.
    r = pyro.sample(f"r_{i}", dist.Bernoulli(theta), obs=freq_matrix)
  # r = F.pad(input=r, pad=(0,0,0,num_objs-r.shape[0]), mode='constant', value=0)
  # with obj_axis, pred_axis:
  #   t = pyro.sample("t", dist.Bernoulli(r[z]*pi), obs=freq_matrix)
    # obs = t+(r[z]-1)
    # d = pyro.sample("d", dist.Delta(obs), obs=freq_matrix)
  # return z, r, t, d

In [992]:
def guide(freq_matrix):
  '''Guide paired with `model` for SVI (work in progress).

  Repeats the model's sequential seating loop (sites `z_0`, `z_1`, ...) and
  the plated site `z`, and defines no site for `r`.

  Arguments:
    freq_matrix: The frequency matrix of observations; only its first
      dimension (number of objects) affects the sites created here.

  Notes:
    The parameter names `theta` and `alpha` are the same ones used in
      `model`, but `theta` is given a different initial value here (0.7).
    NOTE(review): with a shared name, presumably whichever pyro.param call
      runs first decides the initial value - confirm which is intended.
    NOTE(review): the latent distributions are the same expressions as in
      the model, so the guide adds no variational parameters of its own for
      `z` - confirm this is intended.
  '''
  # Model parameters
  theta = pyro.param("theta", torch.tensor(.7), constraint=constraints.interval(0., 1.0))
  # pi = pyro.param("pi", torch.tensor(.7), constraint=constraints.interval(0., 1.0))
  alpha = pyro.param("alpha", torch.tensor(10), constraint=constraints.positive)
  # Useful variables
  num_objs = freq_matrix.size()[0]
  num_preds = freq_matrix.size()[1]
  freqs = [] # number of customers at each table
  for i in range(num_objs):
    # Weights: existing table counts followed by alpha for a new table,
    # normalized in place to sum to one.
    probs = torch.tensor(freqs + [alpha])
    probs /= probs.sum()
    z_sample = pyro.sample(f"z_{i}", dist.Categorical(probs))
    z_item = z_sample.item()
    # The last index of probs is the "new table" option.
    if z_item >= len(freqs):
      freqs += [1.]
    else:
      freqs[z_item] += 1.
  obj_axis = pyro.plate("obj_axis", num_objs)
  # pred_axis = pyro.plate("pred_axis", num_preds)
  with obj_axis:
    z = pyro.sample("z", dist.Categorical(probs))
  # type_axis = pyro.plate("type_axis", max(z)+1)
  # with obj_axis, type_axis:
  #   r = pyro.sample("r", dist.Bernoulli(theta))
  # r = F.pad(input=r, pad=(0,0,0,num_objs-r.shape[0]), mode='constant', value=0)
  # with obj_axis, pred_axis:
  #   t = pyro.sample("t", dist.Bernoulli(r[z]*pi))
    # obs = t+(r[z]-1)
    # d = pyro.sample("d", dist.Delta(obs))

In [995]:
# Display the current notebook-level value of z.
z

tensor([4, 4, 2, 4, 4])

In [994]:
# Empty the param store so theta/alpha are registered afresh by this run.
pyro.clear_param_store()
# Adam with learning rate 1 and a TraceEnum_ELBO loss.
# NOTE(review): the output recorded under this cell is a ValueError at site
# 'r', so this loop does not currently run to completion.
svi = SVI(model, guide, pyro.optim.Adam({"lr": 1}), loss=pyro.infer.TraceEnum_ELBO())
num_steps = 1000
for step in range(num_steps):
    # The predicability matrix r is what gets passed in as the model's
    # freq_matrix argument here.
    # NOTE(review): other inference cells pass d instead - confirm which
    # is intended.
    loss = svi.step(r)
    # Report the loss every 100 steps.
    if step % 100 == 0:
        print("Step {}: loss = {:.2f}".format(step, loss))

ValueError: Error while computing log_prob at site 'r':
Value is not broadcastable with batch_shape+event_shape: torch.Size([2, 5]) vs torch.Size([1, 2]).
Trace Shapes:      
 Param Sites:      
        theta      
        alpha      
Sample Sites:      
     z_0 dist     |
        value     |
     log_prob     |
     z_1 dist     |
        value     |
     log_prob     |
       z dist   2 |
        value   2 |
     log_prob   2 |
       r dist 1 2 |
        value 2 5 |

In [976]:
# Detached copy of every value currently in Pyro's param store, keyed by
# parameter name. NOTE(review): despite the variable name these are the
# current parameter values, not samples from a posterior.
posterior_samples = {param: pyro.param(param).detach().clone() for param in pyro.get_param_store().keys()}

In [977]:
# Display the parameter snapshot.
posterior_samples

{'theta': tensor(1.0000), 'pi': tensor(0.0800), 'alpha': tensor(10.)}

In [822]:
pyro.clear_param_store()
# Run the model once with r as its freq_matrix argument and rebind the
# notebook-level z, r, t, d to the four returned values.
# NOTE(review): both `model` definitions in this notebook have their return
# statement commented out, so this unpacking presumably relied on an
# earlier version of `model` - confirm before re-running.
z, r, t, d =model(r)

tensor([[ 0., -1.,  1.,  0.,  0.],
        [ 0., -1.,  1.,  1.,  1.],
        [ 1., -1.,  1.,  0.,  1.],
        [ 0., -1.,  0.,  0.,  0.],
        [ 1., -1.,  0.,  1.,  1.]])


In [955]:
# Start from an empty param store so earlier runs do not leak into this one.
pyro.clear_param_store()
# NUTS and MCMC are provided by pyro.infer.
nuts_kernel = NUTS(model=model)
mcmc = MCMC(kernel=nuts_kernel, num_samples=1000, num_chains=10, warmup_steps=100)
# MCMC.run() returns None; the draws are retrieved afterwards with
# get_samples(), so `posterior` now holds the samples keyed by site name.
mcmc.run(d)
posterior = mcmc.get_samples()

Sample [0]: 100%|█████████████████████████████████| 1100/1100 [00:00, 30432.20it/s, step size=1.00e+00, acc. prob=1.000]                                                                                           


tensor([1.])
tensor([0.0909, 0.9091])
tensor([0.0833, 0.0833, 0.8333])
tensor([0.0769, 0.0769, 0.0769, 0.7692])
tensor([0.0714, 0.0714, 0.0714, 0.0714, 0.7143])
tensor([1.])
tensor([0.0909, 0.9091])
Z_1
Sample: tensor([[[[0]]],


        [[[1]]]])
Sample Shape: torch.Size([2, 1, 1, 1])
Probs tensor([0.0909, 0.9091])
tensor([1.])
tensor([0.0909, 0.9091])
Z_1
Sample: tensor([[[[0]]],


        [[[1]]]])
Sample Shape: torch.Size([2, 1, 1, 1])
Probs tensor([0.0909, 0.9091])


Sample [1]: 100%|█████████████████████████████████| 1100/1100 [00:00, 37270.66it/s, step size=1.00e+00, acc. prob=1.000]                                                                                           


tensor([1.])
tensor([0.0909, 0.9091])
tensor([0.0833, 0.0833, 0.8333])
tensor([0.0769, 0.0769, 0.0769, 0.7692])
tensor([0.0714, 0.0714, 0.0714, 0.0714, 0.7143])
tensor([1.])
tensor([0.0909, 0.9091])
Z_1
Sample: tensor([[[[0]]],


        [[[1]]]])
Sample Shape: torch.Size([2, 1, 1, 1])
Probs tensor([0.0909, 0.9091])
tensor([1.])
tensor([0.0909, 0.9091])
Z_1
Sample: tensor([[[[0]]],


        [[[1]]]])
Sample Shape: torch.Size([2, 1, 1, 1])
Probs tensor([0.0909, 0.9091])


Sample [2]: 100%|█████████████████████████████████| 1100/1100 [00:00, 36940.31it/s, step size=1.00e+00, acc. prob=1.000]                                                                                           


tensor([1.])
tensor([0.0909, 0.9091])
tensor([0.0833, 0.0833, 0.8333])
tensor([0.1538, 0.0769, 0.7692])
tensor([0.1429, 0.1429, 0.7143])
tensor([1.])
tensor([0.0909, 0.9091])
Z_1
Sample: tensor([[[[0]]],


        [[[1]]]])
Sample Shape: torch.Size([2, 1, 1, 1])
Probs tensor([0.0909, 0.9091])
tensor([1.])
tensor([0.0909, 0.9091])
Z_1
Sample: tensor([[[[0]]],


        [[[1]]]])
Sample Shape: torch.Size([2, 1, 1, 1])
Probs tensor([0.0909, 0.9091])


Sample [3]: 100%|█████████████████████████████████| 1100/1100 [00:00, 37601.75it/s, step size=1.00e+00, acc. prob=1.000]                                                                                           


tensor([1.])
tensor([0.0909, 0.9091])
tensor([0.0833, 0.0833, 0.8333])
tensor([0.0769, 0.0769, 0.0769, 0.7692])
tensor([0.0714, 0.0714, 0.0714, 0.0714, 0.7143])
tensor([1.])
tensor([0.0909, 0.9091])
Z_1
Sample: tensor([[[[0]]],


        [[[1]]]])
Sample Shape: torch.Size([2, 1, 1, 1])
Probs tensor([0.0909, 0.9091])
tensor([1.])
tensor([0.0909, 0.9091])
Z_1
Sample: tensor([[[[0]]],


        [[[1]]]])
Sample Shape: torch.Size([2, 1, 1, 1])
Probs tensor([0.0909, 0.9091])


Sample [4]: 100%|█████████████████████████████████| 1100/1100 [00:00, 37284.21it/s, step size=1.00e+00, acc. prob=1.000]                                                                                           


tensor([1.])
tensor([0.0909, 0.9091])
tensor([0.0833, 0.0833, 0.8333])
tensor([0.0769, 0.0769, 0.0769, 0.7692])
tensor([0.0714, 0.0714, 0.0714, 0.0714, 0.7143])
tensor([1.])
tensor([0.0909, 0.9091])
Z_1
Sample: tensor([[[[0]]],


        [[[1]]]])
Sample Shape: torch.Size([2, 1, 1, 1])
Probs tensor([0.0909, 0.9091])
tensor([1.])
tensor([0.0909, 0.9091])
Z_1
Sample: tensor([[[[0]]],


        [[[1]]]])
Sample Shape: torch.Size([2, 1, 1, 1])
Probs tensor([0.0909, 0.9091])


Sample [5]: 100%|█████████████████████████████████| 1100/1100 [00:00, 36703.16it/s, step size=1.00e+00, acc. prob=1.000]                                                                                           


tensor([1.])
tensor([0.0909, 0.9091])
tensor([0.0833, 0.0833, 0.8333])
tensor([0.0769, 0.1538, 0.7692])
tensor([0.1429, 0.1429, 0.7143])
tensor([1.])
tensor([0.0909, 0.9091])
Z_1
Sample: tensor([[[[0]]],


        [[[1]]]])
Sample Shape: torch.Size([2, 1, 1, 1])
Probs tensor([0.0909, 0.9091])
tensor([1.])
tensor([0.0909, 0.9091])
Z_1
Sample: tensor([[[[0]]],


        [[[1]]]])
Sample Shape: torch.Size([2, 1, 1, 1])
Probs tensor([0.0909, 0.9091])


Warmup:   0%|                                                                                                                                                                               | 0/1100 [00:00, ?it/s]

tensor([1.])
tensor([0.0909, 0.9091])
tensor([0.0833, 0.0833, 0.8333])
tensor([0.0769, 0.0769, 0.0769, 0.7692])
tensor([0.0714, 0.0714, 0.0714, 0.0714, 0.7143])
tensor([1.])
tensor([0.0909, 0.9091])
Z_1
Sample: tensor([[[[0]]],


        [[[1]]]])
Sample Shape: torch.Size([2, 1, 1, 1])
Probs tensor([0.0909, 0.9091])
tensor([1.])
tensor([0.0909, 0.9091])
Z_1
Sample: tensor([[[[0]]],


        [[[1]]]])
Sample Shape: torch.Size([2, 1, 1, 1])
Probs tensor([0.0909, 0.9091])


Sample [6]: 100%|█████████████████████████████████| 1100/1100 [00:00, 33428.50it/s, step size=1.00e+00, acc. prob=1.000]                                                                                           
Warmup:   0%|                                                                                                                                                                               | 0/1100 [00:00, ?it/s]

tensor([1.])
tensor([0.0909, 0.9091])
tensor([0.0833, 0.0833, 0.8333])
tensor([0.0769, 0.0769, 0.0769, 0.7692])
tensor([0.1429, 0.0714, 0.0714, 0.7143])
tensor([1.])
tensor([0.0909, 0.9091])
Z_1
Sample: tensor([[[[0]]],


        [[[1]]]])
Sample Shape: torch.Size([2, 1, 1, 1])
Probs tensor([0.0909, 0.9091])
tensor([1.])
tensor([0.0909, 0.9091])
Z_1
Sample: tensor([[[[0]]],


        [[[1]]]])
Sample Shape: torch.Size([2, 1, 1, 1])
Probs tensor([0.0909, 0.9091])


Sample [7]: 100%|█████████████████████████████████| 1100/1100 [00:00, 36868.29it/s, step size=1.00e+00, acc. prob=1.000]                                                                                           
Warmup:   0%|                                                                                                                                                                               | 0/1100 [00:00, ?it/s]

tensor([1.])
tensor([0.0909, 0.9091])
tensor([0.0833, 0.0833, 0.8333])
tensor([0.0769, 0.0769, 0.0769, 0.7692])
tensor([0.0714, 0.0714, 0.0714, 0.0714, 0.7143])
tensor([1.])
tensor([0.0909, 0.9091])
Z_1
Sample: tensor([[[[0]]],


        [[[1]]]])
Sample Shape: torch.Size([2, 1, 1, 1])
Probs tensor([0.0909, 0.9091])
tensor([1.])
tensor([0.0909, 0.9091])
Z_1
Sample: tensor([[[[0]]],


        [[[1]]]])
Sample Shape: torch.Size([2, 1, 1, 1])
Probs tensor([0.0909, 0.9091])


Sample [8]: 100%|█████████████████████████████████| 1100/1100 [00:00, 35496.54it/s, step size=1.00e+00, acc. prob=1.000]                                                                                           
Warmup:   0%|                                                                                                                                                                               | 0/1100 [00:00, ?it/s]

tensor([1.])
tensor([0.0909, 0.9091])
tensor([0.1667, 0.8333])
tensor([0.1538, 0.0769, 0.7692])
tensor([0.1429, 0.1429, 0.7143])
tensor([1.])
tensor([0.0909, 0.9091])
Z_1
Sample: tensor([[[[0]]],


        [[[1]]]])
Sample Shape: torch.Size([2, 1, 1, 1])
Probs tensor([0.0909, 0.9091])
tensor([1.])
tensor([0.0909, 0.9091])
Z_1
Sample: tensor([[[[0]]],


        [[[1]]]])
Sample Shape: torch.Size([2, 1, 1, 1])
Probs tensor([0.0909, 0.9091])


Sample [9]: 100%|█████████████████████████████████| 1100/1100 [00:00, 37103.41it/s, step size=1.00e+00, acc. prob=1.000]                                                                                           


In [409]:
# Build a Predictive object from the model and guide (10000 samples) and
# call it with the frequency matrix d as the model/guide argument.
predictive = pyro.infer.Predictive(model, guide=guide, num_samples=10000)
svi_samples = predictive(d)

In [843]:
import torch.nn.functional as F

In [854]:
# Start from a 4x4 matrix of ones.
data = torch.ones((4, 4))
print(data)
# The pad tuple is read as (left, right, top, bottom); here one zero-filled
# column is appended on the right and no rows are added.
new_data = F.pad(data, (0, 1, 0, 0), mode='constant', value=0)
print(new_data)

tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])
tensor([[1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 0.]])


In [866]:
# Pad z on the right of its last dimension with zeros, by 10 - len(z)
# entries (no padding on the left).
z_pad = F.pad(input=z, pad=(0,10-len(z)), mode='constant', value=0)
z_pad

tensor([3, 3, 3, 3, 3, 0, 0, 0, 0, 0])

In [871]:
# Rebind num_objs to the size of d's first dimension.
num_objs = d.shape[0]

In [875]:
print(r)
# Leave the last dimension unpadded and append num_objs - r.shape[0] zero
# rows at the bottom, so that r_ has num_objs rows.
r_ = F.pad(input=r, pad=(0,0,0,num_objs-r.shape[0]), mode='constant', value=0)
print(r_)

tensor([[0., 0., 1., 1., 0.],
        [1., 1., 0., 1., 1.],
        [0., 0., 1., 1., 0.],
        [1., 0., 1., 1., 1.]])
tensor([[0., 0., 1., 1., 0.],
        [1., 1., 0., 1., 1.],
        [0., 0., 1., 1., 0.],
        [1., 0., 1., 1., 1.],
        [0., 0., 0., 0., 0.]])


In [876]:
# Index the padded matrix by z: one row of r_ per entry of z.
r_[z]

tensor([[1., 0., 1., 1., 1.],
        [1., 0., 1., 1., 1.],
        [1., 0., 1., 1., 1.],
        [1., 0., 1., 1., 1.],
        [1., 0., 1., 1., 1.]])

In [117]:
import os
import torch
import pyro
from torch.distributions import constraints
from pyro.distributions import Bernoulli, Categorical, MultivariateNormal, Normal
from pyro.distributions.util import broadcast_shape
from pyro.infer import Trace_ELBO, TraceEnum_ELBO, config_enumerate
import pyro.poutine as poutine
from pyro.optim import Adam

# Same CI flag and Pyro version check as the configuration cell near the
# top of the notebook, repeated for this scratch section.
smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.9.0')

# We'll use this helper to check our models are correct.
def test_model(model, guide, loss):
    """Clear the param store, then evaluate `loss` once on (model, guide).

    Arguments:
        model: Model callable passed to ``loss.loss``.
        guide: Guide callable passed to ``loss.loss``.
        loss: Object exposing a ``loss(model, guide)`` method.

    Returns:
        None; the computed loss value is discarded, so the call only
        surfaces any error raised while evaluating the loss.
    """
    pyro.clear_param_store()
    loss.loss(model, guide)

In [161]:
# A Bernoulli distribution parameterized by a length-2 tensor of
# probabilities. Note that this rebinds `d`, which earlier cells use for the
# frequency matrix.
d = dist.Bernoulli(torch.tensor([0.3,0.7]))

In [156]:
# Draw one sample from the distribution d.
x = d.sample()

In [157]:
# Display the sampled value.
x

tensor([1., 0.])

In [158]:
# Shape of the sample.
x.shape

torch.Size([2])

In [159]:
# Event shape of the distribution d.
d.event_shape

torch.Size([])

In [160]:
# Batch shape of the distribution d.
d.batch_shape

torch.Size([2])