In [6]:
import torch
import pyro
import pyro.distributions as dist

import numpy as np
import matplotlib.pyplot as plt
from collections import Counter

In [59]:
def samples_hist(pdf, n=100):
  labels, vals = zip(*sorted(Counter(
    int(pyro.sample('samples_hist', pdf)) for _ in range(n)).items()))

  width = 1
  inds = np.arange(len(labels))
  plt.bar(inds, vals, width)
  plt.xticks(inds, labels)
  plt.show()

In [142]:
class population_history:
  def __init__(self, labels):
    self.labels = labels
    self.states = []
  
  def append(self, *args):
    self.states.append(args)

  def __repr__(self):
    N_PRINT = 10
    more_states = len(self.states) - N_PRINT - 1
    more_str = f'\n ... ({more_states} more)' if more_states > 0 else ''      
    return '\n'.join(
      ', '.join(f'{l}: {s[ind]}' for ind, l in enumerate(self.labels)
        ) for s in self.states[:-N_PRINT-1:-1]) + more_str

class population_state:
  def __init__(self):
    self.infected = int(pyro.sample('init', dist.Poisson(5)))
    self.dead = 0
    self.recovered = 0
    self.history = population_history(['i', 'd', 'r'])
    self.history.append(self.infected, self.dead, self.recovered)
    
    self.r_infect = 0.2 # daily growth of 20%
    self.r_die = 0.001 # 2% risk after 20 days -> .1% risk/day
    self.r_recover = 0.049 # 98% recovery after 20 days -> .049% risk/day
    
  def __repr__(self):
    return f'i: {self.infected}, d: {self.dead}, r: {self.recovered}'
    
  def step(self, n=1):
    for _ in range(n):
      def samp_infected(r, label):
        return int(pyro.sample(label, dist.Poisson(r * self.infected)))
      new_infected = samp_infected(self.r_infect, 'i')
      new_dead = samp_infected(self.r_die, 'd')
      new_recovered = samp_infected(self.r_recover, 'r')
#       print(f'new i: {new_infected}, d: {new_dead}, r: {new_recovered}')

      self.infected += new_infected - new_dead - new_recovered
      self.dead += new_dead
      self.recovered += new_recovered
      self.history.append(self.infected, self.dead, self.recovered)

In [143]:
s = population_state()
s.step(11)
s.history

i: 35, d: 0, r: 8
i: 30, d: 0, r: 6
i: 27, d: 0, r: 4
i: 26, d: 0, r: 3
i: 18, d: 0, r: 3
i: 15, d: 0, r: 2
i: 13, d: 0, r: 1
i: 10, d: 0, r: 1
i: 10, d: 0, r: 0
i: 8, d: 0, r: 0
 ... (1 more)

In [144]:
# template from SVI Tutorial Part 3

import torch
import torch.distributions.constraints as constraints
import pyro
import pyro.distributions as dist
import pyro.optim as optim
from pyro.infer import SVI, TraceGraph_ELBO

def param_abs_error(name, target):
    return torch.sum(torch.abs(target - pyro.param(name))).item()

class Inference:
  def __init__(self, data, max_steps = 10000):
    self.max_steps = max_steps
    self.data = data
    self.n_data = self.data.size(0)
    # declare parameters
    # declare initial values for guide params

  def model(self):
    pass
    # declare the model
  
  def guide(self):
    pass
    # declare the guide

  def do_inference(self, tol=0.8):
    pyro.clear_param_store()
    optimizer = optim.Adam({'lr': .0005, 'betas': (0.93, 0.999)})
    svi = SVI(self.model, self.guide, optimizer, loss=TraceGraph_ELBO())
    print('Doing inference:')
    
    for k in range(self.max_steps):
      svi.step()
      if k % 100 == 0:
        print('.', end='')
        
      # compute some error using param_abs_error
      
      # stop inference early if error is small
      # if some_error < tol:
      #   break
      
    print(f'\nDid {k} steps of inference.')