In [66]:
import torch
import torch.distributions as tdist
import torch.nn.functional as F 
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
import numpy as np
import matplotlib.pyplot as plt

##### Model Wright-Fisher

In [92]:
def model0(init_i, n=0, states=torch.tensor([])):
    Ns = torch.sum(init_i, -1).int()
    # Ns = N * torch.ones(len(init_i.shape))
    '''
    if not all(torch.sum(init_i, -1).int() == Ns.int()):
        print(init_i)
        print(torch.sum(init_i, -1))
        print(Ns)
        print(torch.sum(init_i, -1) == Ns)
        raise(BaseException("wrong args"))
    '''
    if n <= 0:
        return(states)
    n -= 1
    i = torch.tensor(init_i).float()
    # print((i.T/Ns).T)
    # print(Ns.unsqueeze(1))
    dd = tdist.Binomial(Ns.unsqueeze(1), (i.T/Ns).T)
    x = (Ns * F.normalize(dd.sample(), p=1, dim=-1).T).T
    
    # print("x = ", x)
    # print("prob(x_t|x_t-1)= ", torch.exp(dd.log_prob(x)))
    return(model0(x, n=n,
                  states=torch.cat((states, torch.unsqueeze(x, 0)), 0)))

#### Test Wright-Fisher

In [93]:
# states = model0(torch.tensor([[2, 1, 1, 1, 1, 1, 1, 1, 1],
#                               [1, 1, 1, 1, 1, 1, 1, 3, 0]]), 10, 70)
states = model0(torch.tensor([[50000, 50000] for i in range(100)]), 1000)



In [94]:
# states[:,0,:].T
states.shape

torch.Size([1000, 100, 2])

In [95]:
# (states[0].T/torch.sum(states[0], -1).int()).T
# (states[:,i,:].T/torch.sum(states[:,i,:],-1).int())[0].shape

#### Results Wright-Fisher

In [101]:
%matplotlib
plt.ylim(0, 1)
allele = 0
for i in range(states.shape[-2]):
    plt.plot((states[:,i,:].T/torch.sum(states[:,i,:],-1).int())[allele])

Using matplotlib backend: Qt5Agg


In [100]:
%matplotlib
plt.ylim(0, 100000)
allele = 0
for i in range(states.shape[-2]):
    plt.plot(states[:,i,:].T[allele])

Using matplotlib backend: Qt5Agg


In [99]:
%matplotlib
gen=0
for i in range(states.shape[-1]):
    plt.plot(states[:,gen,:].T[i])


Using matplotlib backend: Qt5Agg


##### Model SEIR

In [71]:
def model1(T=3, dt=0.01, S0=997, E0=0, I0=3, R0=0,
           beta=1.5, eps=0.35, gamma=0.035, mu=0.005,
           states=torch.tensor([])):
    '''Assuming birth=death'''
    if T==0:
        return(states)
        # return(S0, E0, I0, R0, states)
    
    N = S0+E0+I0+R0
    S = dt*(mu*N-mu*S0-beta*I0*S0/N)+S0
    E = dt*(beta*S0*I0/N-(eps+mu)*E0)+E0
    I = dt*(eps*E0-(gamma+mu)*I0)+I0
    R = dt*(gamma*I0 - mu*R0)+R0
    T-=1
    # print("N=",S+E+I+R)
    # print(S, E, I, R)
    return(model1(T, dt=dt, S0=S, E0=E, I0=I, R0=R,
                  beta=beta, eps=eps, gamma=gamma, mu=mu,
                 states=torch.cat((states,
                                   torch.unsqueeze(torch.tensor([S, E, I, R]), 0)), 0)))

##### Test SEIR

In [84]:
# for deep recursion avoidance:
states = torch.tensor([[997, 0, 3, 0]])
for i in range(7):
    res = model1(1000, S0=float(states[-1][0]), E0=float(states[-1][1]),
                 I0=float(states[-1][2]), R0=float(states[-1][3]))
    states = torch.cat((states, res),0)
states.shape

torch.Size([7001, 4])

#### Results SEIR

In [91]:
%matplotlib
plt.ylim(0, 1000)

for i in range(states.shape[-1]):
    plt.plot(states.T[i], label=['S', 'E', 'I', 'R'][i])
plt.legend(loc="upper left")

Using matplotlib backend: Qt5Agg


<matplotlib.legend.Legend at 0x7fe44fcad4e0>

##### Appendix

In [26]:
pyro.clear_param_store()
# for vectorized, sampled, depended:
def model1(init_i, N):
    # mu = torch.tensor(0.5)
    # sigma = torch.tensor(0.1)
    # p = pyro.sample("latent_fairness", dist.Normal(mu, sigma))
    # print("p = ", p)
    i = init_i
    
    # vectorized, sampled, dependent:
    p = i/N
    # dd = dist.Bernoulli(p).expand([7, 2]).to_event(1)
    x = dd.sample()
    print("x= ", x)
    print("prob(x)= ", torch.exp(dd.log_prob(x)))

In [27]:
model1(3, 7)

x=  tensor([[0., 0.],
        [1., 0.],
        [1., 0.],
        [0., 1.],
        [0., 0.],
        [0., 0.],
        [0., 0.]])
prob(x)=  tensor([0.3265, 0.2449, 0.2449, 0.2449, 0.3265, 0.3265, 0.3265])


In [20]:
pyro.clear_param_store()
# for vectorized, sampled, depended:
def model(init_i, N):
    # mu = torch.tensor(0.5)
    # sigma = torch.tensor(0.1)
    # p = pyro.sample("latent_fairness", dist.Normal(mu, sigma))
    # print("p = ", p)
    i = init_i
    
    # vectorized, sampled, dependent:
    with pyro.plate("data_loop", size=3, subsample_size=2) as ind:
        p = i/N
        dd = dist.Bernoulli(p).expand([7, 2]).to_event(1)
        print("dd.batch_shape:")
        print(dd.batch_shape)
        print("dd.event_shape:")
        print(dd.event_shape)
        x = dd.sample()
        print("x = ", x)
        print("prob(x) = ", torch.exp(dd.log_prob(x)))
        # print("accurate: ",
        #       torch.tensor([(p if x0 else 1-p)*(p if x1 else 1-p)
        #                     for x0, x1 in x]))

        y = pyro.sample("y", dd)
        print("y = ", y)
        print("ind:")
        print(ind)
        
        # a = pyro.sample("obs", dd, obs=data.index_select(0, ind))
        # print("a:")
        # print(a)

In [21]:
model(3,7)

dd.batch_shape:
torch.Size([7])
dd.event_shape:
torch.Size([2])
x =  tensor([[1., 0.],
        [0., 1.],
        [0., 0.],
        [0., 1.],
        [1., 0.],
        [1., 0.],
        [0., 1.]])
prob(x) =  tensor([0.2449, 0.2449, 0.3265, 0.2449, 0.2449, 0.2449, 0.2449])


ValueError: Shape mismatch inside plate('data_loop') at site y dim -1, 2 vs 7

In [28]:
x_axis = pyro.plate("x", 3, dim=-1)
y_axis = pyro.plate("y", 2, dim=-2)
with x_axis:
    x = pyro.sample("x", dist.Normal(0, 1))
    # this dont work here because of plate:
    # x = pyro.sample("x", dist.Normal(0, 1).expand([5, 2]).to_event(1))
with y_axis:
    y = pyro.sample("y", dist.Normal(0, 1))
print("x: ", x)
print("y: ", y)

x:  tensor([-0.5929,  0.3222,  0.2136])
y:  tensor([[0.8491],
        [1.8800]])
