In [6]:
import torch
import pyro
import pyro.distributions as dist

import numpy as np
import matplotlib.pyplot as plt
from collections import Counter

In [59]:
def samples_hist(pdf, n=100):
  labels, vals = zip(*sorted(Counter(
    int(pyro.sample('samples_hist', pdf)) for _ in range(n)).items()))

  width = 1
  inds = np.arange(len(labels))
  plt.bar(inds, vals, width)
  plt.xticks(inds, labels)
  plt.show()

In [137]:
class population_history:
  def __init__(self, labels):
    self.labels = labels
    self.states = []
  
  def append(self, *args):
    self.states.append(args)

  def __repr__(self):
    N_PRINT = 10
    more_states = len(self.states) - N_PRINT - 1
    more_str = f'\n ... ({more_states} more)' if more_states > 0 else ''      
    return '\n'.join(
      ', '.join(f'{l}: {s[ind]}' for ind, l in enumerate(self.labels)
        ) for s in self.states[:-N_PRINT-1:-1]) + more_str

class population_state:
  def __init__(self):
    self.infected = int(pyro.sample('inf_init', dist.Poisson(5)))
    self.dead = 0
    self.recovered = 0
    self.history = population_history(['i', 'd', 'r'])
    self.history.append(self.infected, self.dead, self.recovered)
    
    self.r_infect = 0.2 # daily growth of 20%
    self.r_die = 0.001 # 2% risk after 20 days -> .1% risk/day
    self.r_recover = 0.049 # 98% recovery after 20 days -> .049% risk/day
    
  def __repr__(self):
    return f'i: {self.infected}, d: {self.dead}, r: {self.recovered}'
    
  def step(self, n=1):
    for _ in range(n):
      def samp_infected(r):
        return int(pyro.sample(f'inf_{r}', dist.Poisson(r * self.infected)))
      new_infected = samp_infected(self.r_infect)
      new_dead = samp_infected(self.r_die)
      new_recovered = samp_infected(self.r_recover)
#       print(f'new i: {new_infected}, d: {new_dead}, r: {new_recovered}')

      self.infected += new_infected - new_dead - new_recovered
      self.dead += new_dead
      self.recovered += new_recovered
      self.history.append(self.infected, self.dead, self.recovered)

In [138]:
s = population_state()
s.step(60)
s.history

i: 46599, d: 298, r: 15461
i: 40529, d: 262, r: 13499
i: 35146, d: 223, r: 11739
i: 30628, d: 179, r: 10245
i: 26606, d: 157, r: 8968
i: 23169, d: 138, r: 7798
i: 20232, d: 126, r: 6787
i: 17691, d: 110, r: 5879
i: 15430, d: 93, r: 5111
i: 13427, d: 84, r: 4440
 ... (50 more)