In [1]:
import pyro
import torch

import numpy as np

## Data generating model

In [484]:
def model(num_time_steps):
    
    """
    Toy Gillespie model implemented in pyro. Used to test different observational scenerios.
    """
    
    batch = torch.tensor([[0,0,0]] * num_time_steps)
    
    ## Latent model parameters. Interested in infering rate
    with pyro.plate("latent_rates"):
        #r1 = pyro.param("rate", torch.tensor(.5),
        #                     constraint=pyro.distributions.constraints.positive)
        r1 = pyro.sample("r1", pyro.distributions.Normal(.5, .01))
    
    with pyro.plate("data"):
    #for i in range(len(data)):    
        ## Initial model setup
        #with pyro.plate("starting_state"):
        s1_temp = pyro.sample("s1_start", pyro.distributions.Normal(10, .00001))#, obs = data[0, i, 0])
        s2_temp = pyro.sample("s2_start", pyro.distributions.Normal(10, .00001))#, obs = data[0, i, 1])
        s3_temp = pyro.sample("s3_start", pyro.distributions.Normal(10, .00001))#, obs = data[0, i, 2])
            # t_prev = pyro.sample("time_start", pyro.distributions.Delta(torch.tensor(0)), obs = data[0, 3])
        batch[0] = torch.tensor([s1_temp, s2_temp, s3_temp])
        ## Time series sampling and updates
        for t in range(1, num_time_steps):

            ## Gillespie sampling
            #with pyro.plate("latent_discrete"):
            sample = pyro.sample("sample_{0}".format(str(t)), 
                                     pyro.distributions.Categorical(
                                         torch.tensor([s1_temp * r1, 
                                                       s2_temp * .1, 
                                                       s3_temp * .9])))#,
                                     #infer={"enumerate": "parallel"})

            ## Update species
            update = np.zeros(3)
            update[sample] = 1
            s1_temp = pyro.sample("s1_{0}".format(str(t)), 
                               pyro.distributions.Normal(s1_temp + update[0], .00001))#, 
                               #obs = data[t, i, 0])
            s2_temp = pyro.sample("s2_{0}".format(str(t)), 
                               pyro.distributions.Normal(s2_temp + update[1], .00001))#, 
                               #obs = data[t, i, 1])
            s3_temp = pyro.sample("s3_{0}".format(str(t)), 
                               pyro.distributions.Normal(s3_temp + update[2], .00001))#, 
                               #obs = data[t, i, 2])
            batch[t] = torch.tensor([s1_temp, s2_temp, s3_temp])
            ## Update time
            # t_temp =  pyro.sample("t_{0}_temp".format(str(t)), 
            #                       pyro.distributions.Exponential(np.sum([s1_t, s2_t, s3_t])))
            # t_t =  pyro.sample("t_{0}".format(str(t)), pyro.distributions.Normal(t_prev + t_temp, .00001),
            #                    obs = data[t,3])

            # s1_prev = s1_t
            # s2_prev = s2_t
            # s3_prev = s3_t
            # t_prev = t_t
            
    return batch


In [485]:
data = list()
for x in range(100):
    temp = model(100)
    data.append(temp)
data = torch.stack(data)

## Incorrectly specified model

In [386]:
s1_temp = torch.tensor([4,12,12])
s2_temp = torch.tensor([4,12,12])
s3_temp = torch.tensor([4,12,12])

In [384]:
with pyro.plate("loop", size = len(s1_temp)):
    pyro.sample("test", pyro.distributions.Categorical(torch.tensor([torch.tensor([4,12,12]) * .5,
                                                  torch.tensor([4,12,12]) * .5,
                                                  torch.tensor([4,12,12]) * .5])))

ValueError: only one element tensors can be converted to Python scalars

In [390]:
data.shape[1]

10

In [575]:
def model(data):
    
    ## Latent model parameters. Interested in infering rate
    #with pyro.plate("latent_rates"):
        
    ## Rate parameter
    rate = pyro.param("rate", torch.tensor(.99),
                     constraint=pyro.distributions.constraints.positive)

    r1 = pyro.sample("r1", pyro.distributions.Normal(rate, .01))

    ## starting points
    s1_start = pyro.sample("s1_start", pyro.distributions.Normal(10, .00001))
    s2_start = pyro.sample("s2_start", pyro.distributions.Normal(10, .00001))
    s3_start = pyro.sample("s3_start", pyro.distributions.Normal(10, .00001))

    with pyro.plate("loop", data.shape[0]) as i:
        
        for t in range(data.shape[1]):

            ## Gillespie sampling
            #with pyro.plate("latent_discrete"):
            sample = pyro.sample("sample_{0}".format(str(t)), 
                                     pyro.distributions.Categorical(
                                         torch.tensor([s1_start * r1, 
                                                       s2_start * .1, 
                                                       s3_start * .9])).to_event(1))#,
                                     #infer={"enumerate": "parallel"})

            ## Update species
            update = np.zeros(3)
            update[sample] = 1
            s1_temp = pyro.sample("s1_{0}".format(str(t)), 
                               pyro.distributions.Normal(s1_start + update[0], .00001).to_event(1), 
                               obs = data[i, t, 0])
            s2_temp = pyro.sample("s2_{0}".format(str(t)), 
                               pyro.distributions.Normal(s2_start + update[1], .00001).to_event(1), 
                               obs = data[i, t, 1])
            s3_temp = pyro.sample("s3_{0}".format(str(t)), 
                               pyro.distributions.Normal(s3_start + update[2], .00001).to_event(1), 
                               obs = data[i, t, 2])
            # batch[t] = torch.tensor([s1_temp, s2_temp, s3_temp])
            ## Update time
            # t_temp =  pyro.sample("t_{0}_temp".format(str(t)), 
            #                       pyro.distributions.Exponential(np.sum([s1_t, s2_t, s3_t])))
            # t_t =  pyro.sample("t_{0}".format(str(t)), pyro.distributions.Normal(t_prev + t_temp, .00001),
            #                    obs = data[t,3])

            s1_start = s1_temp
            s2_start = s2_temp
            s3_start = s3_temp
            # t_prev = t_t
            
    return s1_start


In [576]:
model(data)

ValueError: Expected reinterpreted_batch_ndims <= len(base_distribution.batch_shape), actual 1 vs 0

In [439]:
nuts_kernel = pyro.infer.NUTS(model)
mcmc = pyro.infer.MCMC(nuts_kernel, num_samples=500)
mcmc.run(data)
samples = mcmc.get_samples()

Warmup:   0%|                                            | 0/1000 [00:00, ?it/s]

ValueError: number of dimensions must be within [0, 32]
  Trace Shapes:                                                                      
   Param Sites:                                                                      
           rate                                                                      
  Sample Sites:                                                                      
        r1 dist                                                                     |
          value                                                                     |
  s1_start dist                                                                     |
          value                                                                     |
  s2_start dist                                                                     |
          value                                                                     |
  s3_start dist                                                                     |
          value                                                                     |
      loop dist                                                                     |
          value                                                                 100 |
sample_0_0 dist                                                                     |
          value                                                                   3 |
    s1_0_0 dist                                                                     |
          value                                                                     |
    s2_0_0 dist                                                                     |
          value                                                                     |
    s3_0_0 dist                                                                     |
          value                                                                     |
sample_1_0 dist                                                                     |
          value                                                               3   1 |
    s1_1_0 dist                                                                     |
          value                                                                     |
    s2_1_0 dist                                                                     |
          value                                                                     |
    s3_1_0 dist                                                                     |
          value                                                                     |
sample_2_0 dist                                                                     |
          value                                                             3 1   1 |
    s1_2_0 dist                                                                     |
          value                                                                     |
    s2_2_0 dist                                                                     |
          value                                                                     |
    s3_2_0 dist                                                                     |
          value                                                                     |
sample_3_0 dist                                                                     |
          value                                                           3 1 1   1 |
    s1_3_0 dist                                                                     |
          value                                                                     |
    s2_3_0 dist                                                                     |
          value                                                                     |
    s3_3_0 dist                                                                     |
          value                                                                     |
sample_4_0 dist                                                                     |
          value                                                         3 1 1 1   1 |
    s1_4_0 dist                                                                     |
          value                                                                     |
    s2_4_0 dist                                                                     |
          value                                                                     |
    s3_4_0 dist                                                                     |
          value                                                                     |
sample_5_0 dist                                                                     |
          value                                                       3 1 1 1 1   1 |
    s1_5_0 dist                                                                     |
          value                                                                     |
    s2_5_0 dist                                                                     |
          value                                                                     |
    s3_5_0 dist                                                                     |
          value                                                                     |
sample_6_0 dist                                                                     |
          value                                                     3 1 1 1 1 1   1 |
    s1_6_0 dist                                                                     |
          value                                                                     |
    s2_6_0 dist                                                                     |
          value                                                                     |
    s3_6_0 dist                                                                     |
          value                                                                     |
sample_7_0 dist                                                                     |
          value                                                   3 1 1 1 1 1 1   1 |
    s1_7_0 dist                                                                     |
          value                                                                     |
    s2_7_0 dist                                                                     |
          value                                                                     |
    s3_7_0 dist                                                                     |
          value                                                                     |
sample_8_0 dist                                                                     |
          value                                                 3 1 1 1 1 1 1 1   1 |
    s1_8_0 dist                                                                     |
          value                                                                     |
    s2_8_0 dist                                                                     |
          value                                                                     |
    s3_8_0 dist                                                                     |
          value                                                                     |
sample_9_0 dist                                                                     |
          value                                               3 1 1 1 1 1 1 1 1   1 |
    s1_9_0 dist                                                                     |
          value                                                                     |
    s2_9_0 dist                                                                     |
          value                                                                     |
    s3_9_0 dist                                                                     |
          value                                                                     |
sample_0_1 dist                                                                     |
          value                                             3 1 1 1 1 1 1 1 1 1   1 |
    s1_0_1 dist                                                                     |
          value                                                                     |
    s2_0_1 dist                                                                     |
          value                                                                     |
    s3_0_1 dist                                                                     |
          value                                                                     |
sample_1_1 dist                                                                     |
          value                                           3 1 1 1 1 1 1 1 1 1 1   1 |
    s1_1_1 dist                                                                     |
          value                                                                     |
    s2_1_1 dist                                                                     |
          value                                                                     |
    s3_1_1 dist                                                                     |
          value                                                                     |
sample_2_1 dist                                                                     |
          value                                         3 1 1 1 1 1 1 1 1 1 1 1   1 |
    s1_2_1 dist                                                                     |
          value                                                                     |
    s2_2_1 dist                                                                     |
          value                                                                     |
    s3_2_1 dist                                                                     |
          value                                                                     |
sample_3_1 dist                                                                     |
          value                                       3 1 1 1 1 1 1 1 1 1 1 1 1   1 |
    s1_3_1 dist                                                                     |
          value                                                                     |
    s2_3_1 dist                                                                     |
          value                                                                     |
    s3_3_1 dist                                                                     |
          value                                                                     |
sample_4_1 dist                                                                     |
          value                                     3 1 1 1 1 1 1 1 1 1 1 1 1 1   1 |
    s1_4_1 dist                                                                     |
          value                                                                     |
    s2_4_1 dist                                                                     |
          value                                                                     |
    s3_4_1 dist                                                                     |
          value                                                                     |
sample_5_1 dist                                                                     |
          value                                   3 1 1 1 1 1 1 1 1 1 1 1 1 1 1   1 |
    s1_5_1 dist                                                                     |
          value                                                                     |
    s2_5_1 dist                                                                     |
          value                                                                     |
    s3_5_1 dist                                                                     |
          value                                                                     |
sample_6_1 dist                                                                     |
          value                                 3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1   1 |
    s1_6_1 dist                                                                     |
          value                                                                     |
    s2_6_1 dist                                                                     |
          value                                                                     |
    s3_6_1 dist                                                                     |
          value                                                                     |
sample_7_1 dist                                                                     |
          value                               3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1   1 |
    s1_7_1 dist                                                                     |
          value                                                                     |
    s2_7_1 dist                                                                     |
          value                                                                     |
    s3_7_1 dist                                                                     |
          value                                                                     |
sample_8_1 dist                                                                     |
          value                             3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1   1 |
    s1_8_1 dist                                                                     |
          value                                                                     |
    s2_8_1 dist                                                                     |
          value                                                                     |
    s3_8_1 dist                                                                     |
          value                                                                     |
sample_9_1 dist                                                                     |
          value                           3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1   1 |
    s1_9_1 dist                                                                     |
          value                                                                     |
    s2_9_1 dist                                                                     |
          value                                                                     |
    s3_9_1 dist                                                                     |
          value                                                                     |
sample_0_2 dist                                                                     |
          value                         3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1   1 |
    s1_0_2 dist                                                                     |
          value                                                                     |
    s2_0_2 dist                                                                     |
          value                                                                     |
    s3_0_2 dist                                                                     |
          value                                                                     |
sample_1_2 dist                                                                     |
          value                       3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1   1 |
    s1_1_2 dist                                                                     |
          value                                                                     |
    s2_1_2 dist                                                                     |
          value                                                                     |
    s3_1_2 dist                                                                     |
          value                                                                     |
sample_2_2 dist                                                                     |
          value                     3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1   1 |
    s1_2_2 dist                                                                     |
          value                                                                     |
    s2_2_2 dist                                                                     |
          value                                                                     |
    s3_2_2 dist                                                                     |
          value                                                                     |
sample_3_2 dist                                                                     |
          value                   3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1   1 |
    s1_3_2 dist                                                                     |
          value                                                                     |
    s2_3_2 dist                                                                     |
          value                                                                     |
    s3_3_2 dist                                                                     |
          value                                                                     |
sample_4_2 dist                                                                     |
          value                 3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1   1 |
    s1_4_2 dist                                                                     |
          value                                                                     |
    s2_4_2 dist                                                                     |
          value                                                                     |
    s3_4_2 dist                                                                     |
          value                                                                     |
sample_5_2 dist                                                                     |
          value               3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1   1 |
    s1_5_2 dist                                                                     |
          value                                                                     |
    s2_5_2 dist                                                                     |
          value                                                                     |
    s3_5_2 dist                                                                     |
          value                                                                     |
sample_6_2 dist                                                                     |
          value             3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1   1 |
    s1_6_2 dist                                                                     |
          value                                                                     |
    s2_6_2 dist                                                                     |
          value                                                                     |
    s3_6_2 dist                                                                     |
          value                                                                     |
sample_7_2 dist                                                                     |
          value           3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1   1 |
    s1_7_2 dist                                                                     |
          value                                                                     |
    s2_7_2 dist                                                                     |
          value                                                                     |
    s3_7_2 dist                                                                     |
          value                                                                     |
sample_8_2 dist                                                                     |
          value         3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1   1 |
    s1_8_2 dist                                                                     |
          value                                                                     |
    s2_8_2 dist                                                                     |
          value                                                                     |
    s3_8_2 dist                                                                     |
          value                                                                     |
sample_9_2 dist                                                                     |
          value       3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1   1 |
    s1_9_2 dist                                                                     |
          value                                                                     |
    s2_9_2 dist                                                                     |
          value                                                                     |
    s3_9_2 dist                                                                     |
          value                                                                     |
sample_0_3 dist                                                                     |
          value     3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1   1 |
    s1_0_3 dist                                                                     |
          value                                                                     |
    s2_0_3 dist                                                                     |
          value                                                                     |
    s3_0_3 dist                                                                     |
          value                                                                     |
sample_1_3 dist                                                                     |
          value   3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1   1 |
    s1_1_3 dist                                                                     |
          value                                                                     |
    s2_1_3 dist                                                                     |
          value                                                                     |
    s3_1_3 dist                                                                     |
          value                                                                     |
sample_2_3 dist                                                                     |
          value 3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1   1 |

Warmup:   6%|▏ | 64/1000 [00:15, 36.76it/s, step size=2.96e-08, acc. prob=0.743]

In [435]:
mcmc.summary()


                mean       std    median      5.0%     95.0%     n_eff     r_hat
        r1     -1.08      0.00     -1.08     -1.08     -1.08     11.09      1.04
  s1_start      9.50      0.00      9.50      9.50      9.50      0.50       nan
  s2_start      9.00      0.00      9.00      9.00      9.00      0.58      1.00
  s3_start      9.00      0.00      9.00      9.00      9.00      0.50       nan

Number of divergences: 446


In [588]:
def model(data):
    
    ### Latent model parameters. Interested in infering rate
    #with pyro.plate("latent_rates"):
        
    ## Rate parameter
    #rate = pyro.param("rate", torch.tensor(.99))#,
#                     constraint=pyro.distributions.constraints.positive)
    rate = pyro.param("rate", torch.tensor(.99),
                     constraint=pyro.distributions.constraints.positive)
    r1 = pyro.sample("r1", pyro.distributions.Normal(rate, .01))

    ## starting points
    s1_start = pyro.sample("s1_start", pyro.distributions.Normal(10, .00001))
    s2_start = pyro.sample("s2_start", pyro.distributions.Normal(10, .00001))
    s3_start = pyro.sample("s3_start", pyro.distributions.Normal(10, .00001))

    for i in range(data.shape[0]):
        
        for t in range(data.shape[1]):

            ## Gillespie sampling
            with pyro.plate("latent_discrete"):
                sample = pyro.sample("sample_{0}_{1}".format(str(t), str(i)), 
                                         pyro.distributions.Categorical(
                                             torch.tensor([s1_start * r1, 
                                                           s2_start * .1, 
                                                           s3_start * .9])),
                                         infer={"enumerate": "parallel"})

            ## Update species
            with pyro.plate("update_species"):
                update = np.zeros(3)
                update[sample] = 1

                s1_temp = pyro.sample("s1_{0}_{1}".format(str(t), str(i)), 
                                   pyro.distributions.Normal(s1_start + update[0], .00001), 
                                   obs = data[i, t, 0])
                s2_temp = pyro.sample("s2_{0}_{1}".format(str(t), str(i)), 
                                   pyro.distributions.Normal(s2_start + update[1], .00001), 
                                   obs = data[i, t, 1])
                s3_temp = pyro.sample("s3_{0}_{1}".format(str(t), str(i)), 
                                   pyro.distributions.Normal(s3_start + update[2], .00001), 
                                   obs = data[i, t, 2])
            # batch[t] = torch.tensor([s1_temp, s2_temp, s3_temp])
            ## Update time
            # t_temp =  pyro.sample("t_{0}_temp".format(str(t)), 
            #                       pyro.distributions.Exponential(np.sum([s1_t, s2_t, s3_t])))
            # t_t =  pyro.sample("t_{0}".format(str(t)), pyro.distributions.Normal(t_prev + t_temp, .00001),
            #                    obs = data[t,3])

            s1_start = s1_temp
            s2_start = s2_temp
            s3_start = s3_temp
            # t_prev = t_t
            
    return s1_start


In [560]:
def guide(data):
    
    ## Latent model parameters. Interested in infering rate
    #with pyro.plate("latent_rates"):
        
    ## Rate parameter
    rate = pyro.param("rate", torch.tensor(.99),
                     constraint=pyro.distributions.constraints.positive)

    pyro.sample("r1", pyro.distributions.Normal(rate, .001))

    ## starting points
    s1_start = pyro.sample("s1_start", pyro.distributions.Normal(10, .00001))
    s2_start = pyro.sample("s2_start", pyro.distributions.Normal(10, .00001))
    s3_start = pyro.sample("s3_start", pyro.distributions.Normal(10, .00001))

    for i in range(data.shape[0]):
        
        for t in range(data.shape[1]):

            ## Gillespie sampling
            #with pyro.plate("latent_discrete"):
            pyro.sample("sample_{0}_{1}".format(str(t), str(i)), 
                                     pyro.distributions.Categorical(
                                         torch.tensor([s1_start * r1, 
                                                       s2_start * .1, 
                                                       s3_start * .9])))#,
                                    # infer={"enumerate": "parallel"})

            

In [555]:
guide = pyro.infer.autoguide.AutoNormal(pyro.poutine.block(model, expose=["r1"]))

pyro.clear_param_store()
elbo = pyro.infer.TraceEnum_ELBO(max_plate_nesting=1)
elbo.loss(model, guide, data)

ValueError: number of dimensions must be within [0, 32]
   Trace Shapes:                                                                    
    Param Sites:                                                                    
   Sample Sites:                                                                    
         r1 dist                                                                   |
           value                                                                   |
   s1_start dist                                                                   |
           value                                                                   |
   s2_start dist                                                                   |
           value                                                                   |
   s3_start dist                                                                   |
           value                                                                   |
 sample_0_0 dist                                                                   |
           value                                                               3 1 |
     s1_0_0 dist                                                                   |
           value                                                                   |
     s2_0_0 dist                                                                   |
           value                                                                   |
     s3_0_0 dist                                                                   |
           value                                                                   |
 sample_1_0 dist                                                                   |
           value                                                             3 1 1 |
     s1_1_0 dist                                                                   |
           value                                                                   |
     s2_1_0 dist                                                                   |
           value                                                                   |
     s3_1_0 dist                                                                   |
           value                                                                   |
 sample_2_0 dist                                                                   |
           value                                                           3 1 1 1 |
     s1_2_0 dist                                                                   |
           value                                                                   |
     s2_2_0 dist                                                                   |
           value                                                                   |
     s3_2_0 dist                                                                   |
           value                                                                   |
 sample_3_0 dist                                                                   |
           value                                                         3 1 1 1 1 |
     s1_3_0 dist                                                                   |
           value                                                                   |
     s2_3_0 dist                                                                   |
           value                                                                   |
     s3_3_0 dist                                                                   |
           value                                                                   |
 sample_4_0 dist                                                                   |
           value                                                       3 1 1 1 1 1 |
     s1_4_0 dist                                                                   |
           value                                                                   |
     s2_4_0 dist                                                                   |
           value                                                                   |
     s3_4_0 dist                                                                   |
           value                                                                   |
 sample_5_0 dist                                                                   |
           value                                                     3 1 1 1 1 1 1 |
     s1_5_0 dist                                                                   |
           value                                                                   |
     s2_5_0 dist                                                                   |
           value                                                                   |
     s3_5_0 dist                                                                   |
           value                                                                   |
 sample_6_0 dist                                                                   |
           value                                                   3 1 1 1 1 1 1 1 |
     s1_6_0 dist                                                                   |
           value                                                                   |
     s2_6_0 dist                                                                   |
           value                                                                   |
     s3_6_0 dist                                                                   |
           value                                                                   |
 sample_7_0 dist                                                                   |
           value                                                 3 1 1 1 1 1 1 1 1 |
     s1_7_0 dist                                                                   |
           value                                                                   |
     s2_7_0 dist                                                                   |
           value                                                                   |
     s3_7_0 dist                                                                   |
           value                                                                   |
 sample_8_0 dist                                                                   |
           value                                               3 1 1 1 1 1 1 1 1 1 |
     s1_8_0 dist                                                                   |
           value                                                                   |
     s2_8_0 dist                                                                   |
           value                                                                   |
     s3_8_0 dist                                                                   |
           value                                                                   |
 sample_9_0 dist                                                                   |
           value                                             3 1 1 1 1 1 1 1 1 1 1 |
     s1_9_0 dist                                                                   |
           value                                                                   |
     s2_9_0 dist                                                                   |
           value                                                                   |
     s3_9_0 dist                                                                   |
           value                                                                   |
sample_10_0 dist                                                                   |
           value                                           3 1 1 1 1 1 1 1 1 1 1 1 |
    s1_10_0 dist                                                                   |
           value                                                                   |
    s2_10_0 dist                                                                   |
           value                                                                   |
    s3_10_0 dist                                                                   |
           value                                                                   |
sample_11_0 dist                                                                   |
           value                                         3 1 1 1 1 1 1 1 1 1 1 1 1 |
    s1_11_0 dist                                                                   |
           value                                                                   |
    s2_11_0 dist                                                                   |
           value                                                                   |
    s3_11_0 dist                                                                   |
           value                                                                   |
sample_12_0 dist                                                                   |
           value                                       3 1 1 1 1 1 1 1 1 1 1 1 1 1 |
    s1_12_0 dist                                                                   |
           value                                                                   |
    s2_12_0 dist                                                                   |
           value                                                                   |
    s3_12_0 dist                                                                   |
           value                                                                   |
sample_13_0 dist                                                                   |
           value                                     3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 |
    s1_13_0 dist                                                                   |
           value                                                                   |
    s2_13_0 dist                                                                   |
           value                                                                   |
    s3_13_0 dist                                                                   |
           value                                                                   |
sample_14_0 dist                                                                   |
           value                                   3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 |
    s1_14_0 dist                                                                   |
           value                                                                   |
    s2_14_0 dist                                                                   |
           value                                                                   |
    s3_14_0 dist                                                                   |
           value                                                                   |
sample_15_0 dist                                                                   |
           value                                 3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 |
    s1_15_0 dist                                                                   |
           value                                                                   |
    s2_15_0 dist                                                                   |
           value                                                                   |
    s3_15_0 dist                                                                   |
           value                                                                   |
sample_16_0 dist                                                                   |
           value                               3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 |
    s1_16_0 dist                                                                   |
           value                                                                   |
    s2_16_0 dist                                                                   |
           value                                                                   |
    s3_16_0 dist                                                                   |
           value                                                                   |
sample_17_0 dist                                                                   |
           value                             3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 |
    s1_17_0 dist                                                                   |
           value                                                                   |
    s2_17_0 dist                                                                   |
           value                                                                   |
    s3_17_0 dist                                                                   |
           value                                                                   |
sample_18_0 dist                                                                   |
           value                           3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 |
    s1_18_0 dist                                                                   |
           value                                                                   |
    s2_18_0 dist                                                                   |
           value                                                                   |
    s3_18_0 dist                                                                   |
           value                                                                   |
sample_19_0 dist                                                                   |
           value                         3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 |
    s1_19_0 dist                                                                   |
           value                                                                   |
    s2_19_0 dist                                                                   |
           value                                                                   |
    s3_19_0 dist                                                                   |
           value                                                                   |
sample_20_0 dist                                                                   |
           value                       3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 |
    s1_20_0 dist                                                                   |
           value                                                                   |
    s2_20_0 dist                                                                   |
           value                                                                   |
    s3_20_0 dist                                                                   |
           value                                                                   |
sample_21_0 dist                                                                   |
           value                     3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 |
    s1_21_0 dist                                                                   |
           value                                                                   |
    s2_21_0 dist                                                                   |
           value                                                                   |
    s3_21_0 dist                                                                   |
           value                                                                   |
sample_22_0 dist                                                                   |
           value                   3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 |
    s1_22_0 dist                                                                   |
           value                                                                   |
    s2_22_0 dist                                                                   |
           value                                                                   |
    s3_22_0 dist                                                                   |
           value                                                                   |
sample_23_0 dist                                                                   |
           value                 3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 |
    s1_23_0 dist                                                                   |
           value                                                                   |
    s2_23_0 dist                                                                   |
           value                                                                   |
    s3_23_0 dist                                                                   |
           value                                                                   |
sample_24_0 dist                                                                   |
           value               3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 |
    s1_24_0 dist                                                                   |
           value                                                                   |
    s2_24_0 dist                                                                   |
           value                                                                   |
    s3_24_0 dist                                                                   |
           value                                                                   |
sample_25_0 dist                                                                   |
           value             3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 |
    s1_25_0 dist                                                                   |
           value                                                                   |
    s2_25_0 dist                                                                   |
           value                                                                   |
    s3_25_0 dist                                                                   |
           value                                                                   |
sample_26_0 dist                                                                   |
           value           3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 |
    s1_26_0 dist                                                                   |
           value                                                                   |
    s2_26_0 dist                                                                   |
           value                                                                   |
    s3_26_0 dist                                                                   |
           value                                                                   |
sample_27_0 dist                                                                   |
           value         3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 |
    s1_27_0 dist                                                                   |
           value                                                                   |
    s2_27_0 dist                                                                   |
           value                                                                   |
    s3_27_0 dist                                                                   |
           value                                                                   |
sample_28_0 dist                                                                   |
           value       3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 |
    s1_28_0 dist                                                                   |
           value                                                                   |
    s2_28_0 dist                                                                   |
           value                                                                   |
    s3_28_0 dist                                                                   |
           value                                                                   |
sample_29_0 dist                                                                   |
           value     3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 |
    s1_29_0 dist                                                                   |
           value                                                                   |
    s2_29_0 dist                                                                   |
           value                                                                   |
    s3_29_0 dist                                                                   |
           value                                                                   |
sample_30_0 dist                                                                   |
           value   3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 |
    s1_30_0 dist                                                                   |
           value                                                                   |
    s2_30_0 dist                                                                   |
           value                                                                   |
    s3_30_0 dist                                                                   |
           value                                                                   |
sample_31_0 dist                                                                   |
           value 3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 |

In [590]:
# set up the optimizer
adam_params = {"lr": 0.05, "betas": (0.90, 0.999)}
optimizer = pyro.optim.Adam(adam_params)
guide = pyro.infer.autoguide.AutoDelta(pyro.poutine.block(model, 
            hide=["sample_{0}_{1}".format(str(x), str(y)) for x in range(100) for y in range(100)]))#, hide=['sample_0_0'])
#guide = pyro.infer.autoguide.AutoNormal(pyro.poutine.block(model, expose=["r1"]))

# setup the inference algorithm
svi = pyro.infer.SVI(model, guide, optimizer, loss=pyro.infer.Trace_ELBO())

n_steps = 100
pyro.clear_param_store()
# do gradient steps
for step in range(n_steps):
    svi.step(data)
    if step % 10 == 0:
        print('.', end='')




..........

In [None]:
pyro.param("rate").item()