In [26]:
%matplotlib inline
# import some dependencies
import numpy as np
import matplotlib.pyplot as plt
try:
    import seaborn as sns
    sns.set()
except ImportError:
    pass

import torch
from torch.autograd import Variable

import pyro
import pyro.infer
import pyro.optim
import pyro.distributions as dist
from pyro.optim import Adam
from pyro.infer import SVI
torch.manual_seed(101)

<torch._C.Generator at 0x7f33af08f530>

In [None]:
from __future__ import absolute_import, division, print_function

import numbers

import scipy.stats as spr
import torch
from torch.autograd import Variable

from pyro.distributions.distribution import Distribution
from pyro.distributions.util import log_gamma


[docs]
class InvGamma(Distribution):
    """
    Gamma distribution parameterized by `alpha` and `beta`.

    This is often used in conjunction with `torch.nn.Softplus` to ensure
    `alpha` and `beta` parameters are positive.

    :param torch.autograd.Variable alpha: Shape parameter. Should be positive.
    :param torch.autograd.Variable beta: Shape parameter. Should be positive.
        Shouldb be the same shape as `alpha`.
    """

    def __init__(self, alpha, beta, batch_size=None, *args, **kwargs):
        if alpha.size() != beta.size():
            raise ValueError("Expected alpha.size() == beta.size(), but got {} vs {}".format(alpha.size(), beta.size()))
        self.alpha = alpha
        self.beta = beta
        if alpha.dim() == 1 and beta.dim() == 1 and batch_size is not None:
            self.alpha = alpha.expand(batch_size, alpha.size(0))
            self.beta = beta.expand(batch_size, beta.size(0))
        super(Gamma, self).__init__(*args, **kwargs)

[docs]
    def batch_shape(self, x=None):
        """
        Ref: :py:meth:`pyro.distributions.distribution.Distribution.batch_shape`
        """
        event_dim = 1
        alpha = self.alpha
        if x is not None:
            if x.size()[-event_dim] != alpha.size()[-event_dim]:
                raise ValueError("The event size for the data and distribution parameters must match.\n"
                                 "Expected x.size()[-1] == self.alpha.size()[-1], but got {} vs {}".format(
                                     x.size(-1), alpha.size(-1)))
            try:
                alpha = self.alpha.expand_as(x)
            except RuntimeError as e:
                raise ValueError("Parameter `alpha` with shape {} is not broadcastable to "
                                 "the data shape {}. \nError: {}".format(alpha.size(), x.size(), str(e)))
        return alpha.size()[:-event_dim]


[docs]
    def event_shape(self):
        """
        Ref: :py:meth:`pyro.distributions.distribution.Distribution.event_shape`
        """
        event_dim = 1
        return self.alpha.size()[-event_dim:]


[docs]
    def sample(self):
        """
        Ref: :py:meth:`pyro.distributions.distribution.Distribution.sample`
        """
        theta = torch.pow(self.beta, -1.0)
        np_sample = spr.gamma.rvs(self.alpha.data.cpu().numpy(), scale=theta.data.cpu().numpy())
        if isinstance(np_sample, numbers.Number):
            np_sample = [np_sample]
        x = Variable(torch.Tensor(np_sample).type_as(self.alpha.data))
        x = x.expand(self.shape())
        return x


[docs]
    def batch_log_pdf(self, x):
        """
        Ref: :py:meth:`pyro.distributions.distribution.Distribution.batch_log_pdf`
        """
        alpha = self.alpha.expand(self.shape(x))
        beta = self.beta.expand(self.shape(x))
        ll_1 = -beta * x
        ll_2 = (alpha - 1.0) * torch.log(x)
        ll_3 = alpha * torch.log(beta)
        ll_4 = -log_gamma(alpha)
        log_pdf = torch.sum(ll_1 + ll_2 + ll_3 + ll_4, -1)
        batch_log_pdf_shape = self.batch_shape(x) + (1,)
        return log_pdf.contiguous().view(batch_log_pdf_shape)


[docs]
    def analytic_mean(self):
        """
        Ref: :py:meth:`pyro.distributions.distribution.Distribution.analytic_mean`
        """
        return self.alpha / self.beta


[docs]
    def analytic_var(self):
        """
        Ref: :py:meth:`pyro.distributions.distribution.Distribution.analytic_var`
        """
        return self.alpha / torch.pow(self.beta, 2.0)

In [28]:
import data
import random

data_dict = {c:data.generate_data(c, 2) for c in ['red', 'green', 'blue', 'yellow']}
objs = {}
n = 2
for j, c in enumerate(['red', 'green', 'blue', 'yellow']):
    for i in range(n):
        objs["o{}".format(j*n + i)] = (c, data_dict[c][i])
        
        
def check_rule(o1, o2):
    c1, d1 = o1
    c2, d2 = o2
    
    if c1 == 'red': 
         if c2 != 'blue':
            return 'correction' 
    
    return 'no correction'

def correction_to_int(corr_string):
    return int(corr_string == 'correction')

def colour_to_int(corr_string, colour):
    return int(corr_string == colour)

def my_sample(objs, n=10):
    cs = []
    red = []
    blue = []
    
    for i in range(n):
        o1 = random.sample(objs.keys(), k=1)[0]
        o2 = random.sample(objs.keys(), k=1)[0]
        while o1 == o2:
            o2 = random.sample(objs.keys(), k=1)[0]
        c = correction_to_int(check_rule(objs[o1], objs[o2]))
        f_red = objs[o1][1]
        f_blue = objs[o2][1]
        cs.append(c)
        red.append(f_red)
        blue.append(f_blue)
    return cs, red, blue

def my_colour_sample(objs, n=10):
    cs = []
    red = []
    
    for i in range(n):
        o1 = random.sample(objs.keys(), k=1)[0]
        c = objs[o1][0]
        f_red = objs[o1][1]
        cs.append(c)
        red.append(f_red)
    return cs, red

colours, c_data = my_colour_sample(objs)
colours = list(map(lambda x: [colour_to_int(x, 'blue')], colours))
colours = np.array(colours)
c_data = np.array(c_data)

colours

array([[0],
       [0],
       [0],
       [1],
       [0],
       [0],
       [1],
       [1],
       [0],
       [0]])

In [None]:
def inverse_gamma(a, b):
    

In [102]:
def mixture(colour, obj):
    # I'm not sure what difference between dist.Beta(a, b) vs dist.Beta, a, b is but the second one throws an error...
    prior_colour = pyro.sample('{}_prior'.format(colour), dist.beta,
                               Variable(torch.Tensor([1.])), Variable(torch.Tensor([1.])))
    colour_sample = pyro.sample(colour, dist.bernoulli,
                        prior_colour)
    colour_sample = colour if colour_sample.data[0] == 1.0 else 'not'
    mean_prior = pyro.sample('{}_mean_prior'.format(colour), dist.normal, 
                             Variable(torch.Tensor([0.5, 0.5, 0.5])), Variable(torch.Tensor([1., 1., 1.])))
    mean_f = {colour: mean_prior, 'not': Variable(torch.Tensor([0.5, 0.5, 0.5]))}[colour_sample]
    sigma_f = {colour:[0.1, 0.1, 0.1], 'not': [1,1,1]}[colour_sample]
    fo2 = pyro.sample('f({})'.format(obj), dist.normal,
                     mean_f,
                     Variable(torch.Tensor(sigma_f)))
    return colour_sample, fo2.data

In [103]:
mixture('blue', 'o1')

('blue', 
  1.5104
  0.3225
  1.9995
 [torch.FloatTensor of size 3])

In [70]:
posterior = pyro.infer.Importance(mixture, num_samples=100)

In [71]:
marginal = pyro.infer.Marginal(posterior)

In [72]:
cond_mixture = pyro.condition(mixture, data={'f(o2)': Variable(torch.Tensor([1,0,0]))})

In [73]:
post_cond = pyro.infer.Importance(cond_mixture, num_samples=100)
marg_cond = pyro.infer.Marginal(post_cond)

In [76]:
data = []
for i in range(10000):
    s = pyro.sample('model|data', marg_cond)
    data.append(int(s[0] == 'blue'))
print(np.mean(data))

0.0


In [81]:
data = []
for i in range(10000):
    s = pyro.sample('model', mixture)
    data.append(int(s[0] == 'blue'))
print(np.mean(data))

0.495


In [None]:
for c, d in zip(colours, c_data):
    

In [74]:
def training_mixture(colour, obj):
    def mix(colour_data, feature_data):
        #def mixture():
        # p_colour latent variable to be estimated
        alpha0 = Variable(torch.Tensor([1.]))
        beta0 = Variable(torch.Tensor([1.]))
        prior_colour = pyro.sample('{}_prior'.format(colour), dist.beta, alpha0, beta0)
        # the mean of the normal for the case of colour == True
        pos_mean_prior = Variable(torch.Tensor([0.5,0.5,0.5]))
        pos_sigm_prior = Variable(torch.Tensor([1., 1.,1.]))
        pos_mean = pyro.sample('{}_mu+'.format(colour), dist.normal, pos_mean_prior, pos_sigm_prior)
        # the mean of the normal for the case of colour == False
        neg_mean_prior = Variable(torch.Tensor([0.5,0.5,0.5]))
        neg_sigm_prior = Variable(torch.Tensor([1., 1.,1.]))
        neg_mean = pyro.sample('{}_mu-'.format(colour), dist.normal, pos_mean_prior, pos_sigm_prior)
        # the std deviation for the normal when colour == True
        pos_sigm = Variable(torch.Tensor([0.1, 0.1, 0.1]))
        # the std deviation for the normal when colour == False
        neg_sigm = Variable(torch.Tensor([1., 1., 1.]))

        # observe values
        for i, (c, d) in enumerate(zip(colour_data, feature_data)):
            colour_obs = pyro.observe("{}.{}".format(colour, i), dist.bernoulli, c, prior_colour)
            colour_sample = colour if colour_obs.data[0] == 1.0 else 'not'
            mean_f = {colour: pos_mean, 'not': neg_mean}[colour_sample]
            sigma_f = {colour:pos_sigm, 'not': neg_sigm}[colour_sample]
            pyro.observe('f({}).{}'.format(obj, i), dist.normal, d, mean_f, sigma_f)

            #colour_sample = pyro.sample(colour, dist.bernoulli, prior_colour)
            #colour_sample = colour if colour_obs.data[0] == 1.0 else 'not'
            #mean_f = {colour: pos_mean, 'not': neg_mean}[colour_sample]
            #sigma_f = {colour:pos_sigm, 'not': neg_sigm}[colour_sample]
            #f = pyro.sample('f({})'.format(obj), dist.normal, mean_f, sigma_f)
            #return colour_sample, f
        #return mixture
        
       
    def guide(colour_data, feature_data):
        log_alpha_q_0 = Variable(torch.Tensor([np.log(15.0)]), requires_grad=True)
        log_beta_q_0 = Variable(torch.Tensor([np.log(15.0)]), requires_grad=True)

        log_alpha_q = pyro.param("log_alpha_q", log_alpha_q_0)
        log_beta_q = pyro.param("log_beta_q", log_beta_q_0)

        alpha_q, beta_q = torch.exp(log_alpha_q), torch.exp(log_beta_q)

        pyro.sample('{}_prior'.format(colour), dist.beta, alpha_q, beta_q)
        pos_mean_prior_q0 = Variable(torch.Tensor([0.5,0.5,0.5]), requires_grad=True)
        pos_sigm_prior_q0 = Variable(torch.Tensor([1., 1.,1.]), requires_grad=True)
        pos_mean_prior_q = pyro.param("pos_mu_q", pos_mean_prior_q0)
        pos_sigm_prior_q = pyro.param("pos_sigm_q", pos_sigm_prior_q0)
        pos_mean_q = pyro.sample('{}_mu+'.format(colour), dist.normal, pos_mean_prior_q, pos_sigm_prior_q)
        # the mean of the normal for the case of colour == False
        neg_mean_prior_q0 = Variable(torch.Tensor([0.5,0.5,0.5]), requires_grad=True)
        neg_sigm_prior_q0 = Variable(torch.Tensor([1., 1.,1.]), requires_grad=True)
        neg_mean_prior_q = pyro.param("neg_mu_q", neg_mean_prior_q0)
        neg_sigm_prior_q = pyro.param("neg_sigm_q", neg_sigm_prior_q0)
        neg_mean = pyro.sample('{}_mu-'.format(colour), dist.normal, neg_mean_prior_q, neg_sigm_prior_q)

        
    return mix, guide

In [75]:
mixture, guide = training_mixture('blue', 'o2')

In [76]:
adam_params = {"lr":0.0005, "betas": (0.90, 0.999)}
optimiser = Adam(adam_params)

In [77]:
svi = SVI(mixture, guide, optimiser, loss="ELBO", num_particles=7)

In [80]:
n_steps = 4000
for stip in range(n_steps):
    svi.step(Variable(torch.Tensor(colours)), Variable(torch.Tensor(c_data)))
    if step % 100 == 0:
        print('.', end=' ')

ValueError: optimizing a parameter that doesn't require gradients

In [71]:
m = mixture(colours, c_data)
posterior = pyro.infer.Importance(m, num_samples=100)
marginal = pyro.infer.Marginal(posterior)

In [72]:
cond_mixture = pyro.condition(marginal, data={'f(o2)': Variable(torch.Tensor([1,0,0]))})
post_cond = pyro.infer.Importance(cond_mixture, num_samples=100)
marg_cond = pyro.infer.Marginal(post_cond)

In [73]:
data = []
for i in range(10000):
    s = pyro.sample('model|data{}'.format(i), marg_cond)
    data.append(int(s[0] == 'blue'))
print(np.mean(data))

TypeError: 'int' object is not callable

In [68]:
marg_cond.trace_dist()

TypeError: 'int' object is not callable

In [150]:
pyro.sample('s', marginal)

ValueError: Got ps=None, logits=None. Either `ps` or `logits` must be specified, but not both.

In [144]:
marg_cond.

ValueError: Got ps=None, logits=None. Either `ps` or `logits` must be specified, but not both.

In [8]:
import pyro.poutine as poutine

In [9]:
guides = poutine.block(mixture, hide_types=['observe'])

In [10]:
guide_trace = poutine.trace(guides).get_trace()

In [11]:
model_trace = poutine.trace(poutine.replay(mixture, guide_trace)).get_trace()

In [12]:
log_weight = model_trace.log_prob_sum() - guide_trace.log_prob_sum()

AttributeError: 'Trace' object has no attribute 'log_prob_sum'

In [14]:
model_trace.batch_log_pdf()

ValueError: Got ps=None, logits=None. Either `ps` or `logits` must be specified, but not both.

In [84]:
p = dist.Bernoulli(Variable(torch.Tensor(1)))

In [85]:
p.

<pyro.distributions.bernoulli.Bernoulli at 0x7f33aa729898>