In [96]:
import torch
import torch.distributions as tdist
import torch.nn.functional as F 
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
import math
import numpy as np
import matplotlib.pyplot as plt


##### Model
helpers:

In [2]:
alleles = ['A', 'C', 'G', 'T', '_']

def to_allele(idxs):
    return([alleles[int(idx)] for idx in idxs])

def from_allele(name):
    return(float(alleles.index(name)))


init:

In [3]:
def init_seq(size, p_init_seq=torch.tensor([1/4, 1/4, 1/4, 1/4])):
    return(tdist.Categorical(
        p_init_seq.expand(size, -1)).sample())

def init_trans_seq(size, p_init_trans_seq=1/2):
    return(tdist.Bernoulli(
        p_init_trans_seq * torch.ones(size)).sample())

basic

In [4]:
def edit_time(seq_size, t, alpha1):
    '''
    dependence on site index in seq:
    time scale at each site with Gamma distribution
    (Yang [1993])
    
    Inputs:
     - ``t`` and ``alpha1`` -- describe time and variance
     ex:  t: 0.01, alpha1: 0.1
    
    '''
    gd = tdist.Gamma(torch.arange(1, seq_size+1).float(),
                     alpha1*torch.ones(seq_size))

    rr = gd.sample()
    # print("gamma rr = ", rr)
    # print("t original = ", t)
    ts = rr*t/gd.mean
    # print("ts = ", ts)
    return(ts)


def high_level(seq_size, ltrans_seq, probs):
    
    '''Dependence from previus seq, not from seq shift!
    So p(M|M) means P(seq_{i}=M|last_seq_{i}=M)
    
    Inputs:
    - ``ltrans_seq`` -- last transition seq
    - ``probs`` -- dict with keys "M|M", "M|D"
    ex: {"M|M":0.9, "M|D": 0.7}
    '''
    
    trans_seq = ltrans_seq.detach().clone()
    # 1 means M:
    cond_M = ltrans_seq == 1
    
    # TODO: dependence on previus state
    # <M, D>|M:
    trans_seq[cond_M] = tdist.Bernoulli(
        probs["M|M"]*torch.ones(seq_size)[cond_M]).sample()
    
    # <M, D>|D
    trans_seq[torch.logical_not(cond_M)] = tdist.Bernoulli(
        probs["M|D"]*torch.ones(seq_size)[torch.logical_not(cond_M)]).sample()
    
    return(trans_seq)


def low_level(lseq, trans_seq, alpha, ts, prob_del, Dn):
    '''
    Compute emissions for M and D states (from `trans_seq`):
        P(<A, C, G, T>_{i}|<A, C, G, T>_{i-1} , M) or P("_"|D)
    if M state given use time scale Yang [1993]
        ts
    and Jukes-Cantor substitution matrix:
        S = torch.tensor([[rt, st, st, st],
                          [st, rt, st, st],
                          [st, st, rt, st],
                          [st, st, st, rt]])
    if D state given set binomial(Dn, prob_del) for
    each D_{i} to generate subsequence of "_" in each i
    (`i` taken from `trans_seq`)
    
    Inputs:
    - ``prob_del`` and ``Dn`` used for subsequence of "_":
    binomial(Dn, prob_del)
    ex: prob_del: 0.1, Dn: 3
    
    - ``alpha`` and ``ts`` used for `rt` and `ts` parameters
    of `S` matrix.
    
    '''
    rt = torch.unsqueeze(1/4*(1+3*np.exp(-4*alpha*ts)), 0)
    st = torch.unsqueeze(1/4*(1-np.exp(-4*alpha*ts)), 0)
    # print("rt, st = ", (rt, st))

    
    # FOR M:
    cond_trans_M = trans_seq==1
    # print("cond_trans_M:")
    # print(cond_trans_M)
    cond_A = lseq == from_allele("A")
    probs_A = [rt, st, st, st]
    
    cond_C = lseq == from_allele("C")
    probs_C = [st, rt, st, st]
    
    cond_G = lseq == from_allele("G")
    probs_G = [st, st, rt, st]
    
    cond_T = lseq == from_allele("T")
    probs_T = [st, st, st, rt]
    
    cond_del = lseq == from_allele("_")
    probs_del = [1/4 * torch.unsqueeze(torch.ones(seq_size), 0), 
                 1/4 * torch.unsqueeze(torch.ones(seq_size), 0),
                 1/4 * torch.unsqueeze(torch.ones(seq_size), 0),
                 1/4 * torch.unsqueeze(torch.ones(seq_size), 0)]
    
    seq = lseq.detach().clone()
    
    emissions = [(cond_A, probs_A), (cond_C, probs_C),
                 (cond_G, probs_G), (cond_T, probs_G),
                 (cond_T, probs_T), (cond_del, probs_del)
                ]
    for cond_word, probs_word in emissions:
        cond_M_A = torch.logical_and(cond_trans_M, cond_word)
        size_M_A = len(lseq[cond_M_A])
        seq[cond_M_A] = tdist.Categorical(
            torch.cat(probs_word).T[cond_M_A]).sample()  # .expand(size_M_A, -1))
    # print("seq after M:")
    # print(seq)
    # END FOR
    
    # FOR D:
    cond_trans_D = trans_seq==0
    # print("cond_trans_D:")
    # print(cond_trans_D)
    idxs = (cond_trans_D).nonzero().flatten()
    # print("idxs:")
    # print(idxs)
    idxs_shift_sample = tdist.Binomial(Dn, prob_del * torch.ones(idxs.size()[0])).sample()
    # print("idxs_shift_sample:")
    # print(idxs_shift_sample)
    
    # add subsequences idxs to del:
    idxs1 = torch.cat((idxs, idxs+idxs_shift_sample),0).long()
    # print("idxs1:")
    # print(idxs1)
    
    # cut oversize, del subsequnces:
    seq[idxs1[idxs1<seq.size()[0]]] = from_allele("_")
    
    # print("seq after D:")
    # print(seq)
    # END FOR
      
    # cd = tdist.Categorical(S)
    # dwords = tdist.Categorical(S)
    return(seq)    

##### Single test:

In [5]:
# init:
seq_size = 100
ilseq = init_seq(seq_size)
iltseq = init_trans_seq(seq_size)
print("init seq:")
print("".join(to_allele(ilseq)))

# main
# transition:
ltseq = high_level(seq_size, iltseq, {"M|M":0.9, "M|D": 0.7})
# print("last transition seq:")
# print(ltseq)

# time scale Yang [1993]:
ts = edit_time(seq_size, 0.01, 0.1)
# print("ts:")
# print(ts)

# emission:
seq = low_level(ilseq, ltseq, 10.0, ts, 0.1, 3)
print("\nseq:")
print("".join(to_allele(seq)))
# trans_seq = 
# low_level(lseq, trans_seq, alpha, ts, 0.9)
# ltseq

init seq:
CATGCACCACCGGCCTTGGGGGGTAGACACATACCCAACGAGAGTTTCACCTACGCGTGGACATGTCTGGTATGACTAGGGAATGCTGGGCGGTGCGACT

seq:
C__GCA__CCC_GCC_TGGCGGATAGA_AC_CAGACAAC_A__GTTTCACCTAC__GT__C__GG__TGCG__GATTCA_GAGTG_TG____GTGGGTCT


	nonzero()
Consider using one of the following signatures instead:
	nonzero(*, bool as_tuple) (Triggered internally at  /pytorch/torch/csrc/utils/python_arg_parser.cpp:766.)


##### Iterations:

In [6]:
def iteration(lseq, ltseq, t=0.01, time_alpha=0.1, s_alpha=10.0):
    '''
    s_alpha = 10.0  # for transition matrix S
    time_alpha = 0.9  # for gamma dist
    t = 0.01  # timestep \in [0, 1]
    '''
    seq_size = lseq.size()[0]
    
    # transition:
    ltseq = high_level(seq_size, iltseq, {"M|M":0.9, "M|D": 0.7})
    # print("last transition seq:")
    # print(ltseq)

    # time scale Yang [1993]:
    ts = edit_time(seq_size, t, time_alpha)
    # print("ts:")
    # print(ts)

    # emission:
    seq = low_level(ilseq, ltseq, s_alpha, ts, 0.1, 3)
    # print("seq:")
    # print("".join(to_allele(seq)))
    return(seq, ltseq)

##### Test iterations

In [7]:
# init:
seq_size = 100
ilseq = init_seq(seq_size)
iltseq = init_trans_seq(seq_size)
print("init seq:")
print("".join(to_allele(ilseq)))
seq, tseq = ilseq, iltseq

for i in range(30):
    seq, tseq = iteration(seq, tseq)
print("\nseq:")
print("".join(to_allele(seq)))
    

init seq:
GGGTCCAAAAGCGGGGGAATTCCTATGGGTAAGATAGGCGCACGCTTCAGCGTCTCCGTAGAAGCCGCAGCGTTAAGGAGAATAACATAAAGTTCTTCTC

seq:
GAGGGC__AAGACAGGGAATCCCT_TGGGA__TAT__GTG_ATC__TCAGCG____T_TCG__A__GCAG_GTG_A_GAG_TTAA_ATC_AG_GCTT_TT


##### Loop for a tree:

In [8]:
import networkx as nx

In [236]:
def timer_succ(t):
    time_dist = tdist.Uniform(0, 0.1)
    t += float(time_dist.sample())
    return(t)
    
def loop(N, timer, net=None, nodes=[]):
    
    
    
    # init:
    if net is None:
        lseq = init_seq(seq_size,
                        p_init_seq=torch.tensor([1/4, 1/4, 1/4, 1/4]))
        ltseq = init_trans_seq(seq_size, p_init_trans_seq=1/2)
        net = nx.DiGraph()
        t = timer(0)
        
        net.add_node("s", seq="".join(to_allele(lseq)),
                     pos=(1, 1))
        nodes = [("s", lseq, ltseq, t)]
        return(loop(N-1, timer, net, nodes))
    
    # finish:
    if N <= 0:
        return(net)
    
    # main:
    # choice that with minimum time:
    nodes.sort(key=lambda x: x[-1])
    first = nodes.pop(0)
    pidx, lseq, ltseq, t = first
    lidx = pidx + "0"
    ridx = pidx + "1"
    p_pos = net.nodes[pidx]['pos']
    
    left_seq, left_tseq = iteration(lseq, ltseq)
    right_seq, right_tseq = iteration(lseq, ltseq)
    lt = timer(t)
    
    net.add_node(lidx, seq="".join(to_allele(left_seq)),
                 pos=(p_pos[0]+0.01*lt, p_pos[1]+lt*30+30/N))
    net.add_edge(pidx, lidx, time=lt)
    rt = timer(t)
    net.add_node(ridx, seq="".join(to_allele(right_seq)),
                pos=(p_pos[0]-0.01*rt, p_pos[1]+rt*30+30/N))
    net.add_edge(pidx, ridx, time=rt)
    nodes.append((lidx, left_seq, left_tseq, lt))
    nodes.append((ridx, right_seq, right_tseq, rt))
    return(loop(N-1, timer, net, nodes))

##### Test loop

In [237]:
net = loop(10, timer_succ)


In [238]:
# next(nx.dfs_edges(net, "0"))
net.nodes(data="pos")

NodeDataView({'s': (1, 1), 's0': (1.0008799329586326, 6.973132209231457), 's1': (0.9985589176416397, 8.656580408414206), 's00': (1.002474001608789, 15.505338159700235), 's01': (0.9994487079232931, 15.016807315250238), 's010': (1.0012430205754936, 24.685459557565906), 's011': (0.9971755396388471, 26.122026454302528), 's10': (1.0009172266721726, 20.731507500012718), 's11': (0.9963853763788939, 20.17720419665178), 's000': (1.004350966066122, 27.1362315316995), 's001': (1.0005370890535414, 27.316075825442873), 's0100': (1.0033903862163425, 38.62755648011253), 's0101': (0.9984582717716695, 40.539705969038465), 's0000': (1.0070593437179922, 45.26136448731025), 's0001': (1.0016993990913035, 45.09093245615562), 's0010': (1.0034183793142437, 50.95994660754999), 's0011': (0.9979447156190873, 50.09319612880548), 's01000': (1.0059723158739506, 76.37334545293734), 's01001': (1.0009087496064604, 76.07246630975888)}, data='pos')

In [239]:
%matplotlib
pos =  net.nodes(data="pos")
#
labels=nx.draw_networkx_labels(net, pos=pos)
# labels=nx.draw_networkx_labels(net,pos=nx.spring_layout(net))
edge_labels = dict([((u, v), "%.2f" % (float(c)))
                    for u, v, c in net.edges(data="time")])
edge_labels=nx.draw_networkx_edge_labels(net, pos=pos,
                                         edge_labels=edge_labels)
# nx.draw(net, pos=nx.spring_layout(net))

nx.draw(net, pos=pos)

Using matplotlib backend: Qt5Agg


##### Test Gamma:

In [92]:
alpha1 = 0.9  # for gamma dist
t = 1  # timestep
r = float(100)  # site number in seq
gd = tdist.Gamma(torch.tensor(r),torch.tensor(alpha1))
def gen(t):
    rr = gd.sample()
    return(rr)
    # return(rr*t/gd.mean)

# print("gamma rr = ", rr)
# print("t original = ", t)
# print("t = ", t)
data = [gen(t) for i in range(100)]

%matplotlib
plt.hist(data)

Using matplotlib backend: Qt5Agg


(array([ 5., 11.,  6., 12., 11., 16., 20.,  6.,  8.,  5.]),
 array([ 84.7612  ,  89.795235,  94.82926 ,  99.8633  , 104.89732 ,
        109.93136 , 114.96539 , 119.99942 , 125.033455, 130.06749 ,
        135.10152 ], dtype=float32),
 <a list of 10 Patch objects>)

##### Gamma distribution

In [43]:
dd = tdist.Gamma(torch.tensor(7.0),torch.tensor(0.9))
data = [int(dd.sample()) for i in range(1000)]
print(len(set(data)), len(data))
classes = set(data)
# print(classes)
print(dd.mean)
%matplotlib
# plt.figure(figsize=(10,1))
plt.hist(data, len(classes), # density=True, 
        # orientation='horizontal',
        stacked=True,
        rwidth=0.1, label = [str(c) for c in classes])


19 1000
tensor(7.7778)
Using matplotlib backend: Qt5Agg


(array([ 30.,  50., 102., 133., 150., 132., 113.,  87.,  57.,  80.,  33.,
         12.,   7.,   9.,   3.,   0.,   0.,   1.,   1.]),
 array([ 1.        ,  2.10526316,  3.21052632,  4.31578947,  5.42105263,
         6.52631579,  7.63157895,  8.73684211,  9.84210526, 10.94736842,
        12.05263158, 13.15789474, 14.26315789, 15.36842105, 16.47368421,
        17.57894737, 18.68421053, 19.78947368, 20.89473684, 22.        ]),
 <a list of 19 Patch objects>)

##### Model Wright-Fisher

In [92]:
def model0(init_i, n=0, states=torch.tensor([])):
    Ns = torch.sum(init_i, -1).int()
    # Ns = N * torch.ones(len(init_i.shape))
    '''
    if not all(torch.sum(init_i, -1).int() == Ns.int()):
        print(init_i)
        print(torch.sum(init_i, -1))
        print(Ns)
        print(torch.sum(init_i, -1) == Ns)
        raise(BaseException("wrong args"))
    '''
    if n <= 0:
        return(states)
    n -= 1
    i = torch.tensor(init_i).float()
    # print((i.T/Ns).T)
    # print(Ns.unsqueeze(1))
    dd = tdist.Binomial(Ns.unsqueeze(1), (i.T/Ns).T)
    x = (Ns * F.normalize(dd.sample(), p=1, dim=-1).T).T
    
    # print("x = ", x)
    # print("prob(x_t|x_t-1)= ", torch.exp(dd.log_prob(x)))
    return(model0(x, n=n,
                  states=torch.cat((states, torch.unsqueeze(x, 0)), 0)))

#### Test Wright-Fisher

In [93]:
# states = model0(torch.tensor([[2, 1, 1, 1, 1, 1, 1, 1, 1],
#                               [1, 1, 1, 1, 1, 1, 1, 3, 0]]), 10, 70)
states = model0(torch.tensor([[50000, 50000] for i in range(100)]), 1000)



In [94]:
# states[:,0,:].T
states.shape

torch.Size([1000, 100, 2])

In [95]:
# (states[0].T/torch.sum(states[0], -1).int()).T
# (states[:,i,:].T/torch.sum(states[:,i,:],-1).int())[0].shape

#### Results Wright-Fisher

In [101]:
%matplotlib
plt.ylim(0, 1)
allele = 0
for i in range(states.shape[-2]):
    plt.plot((states[:,i,:].T/torch.sum(states[:,i,:],-1).int())[allele])

Using matplotlib backend: Qt5Agg


In [100]:
%matplotlib
plt.ylim(0, 100000)
allele = 0
for i in range(states.shape[-2]):
    plt.plot(states[:,i,:].T[allele])

Using matplotlib backend: Qt5Agg


In [99]:
%matplotlib
gen=0
for i in range(states.shape[-1]):
    plt.plot(states[:,gen,:].T[i])


Using matplotlib backend: Qt5Agg


##### Model SEIR

In [71]:
def model1(T=3, dt=0.01, S0=997, E0=0, I0=3, R0=0,
           beta=1.5, eps=0.35, gamma=0.035, mu=0.005,
           states=torch.tensor([])):
    '''Assuming birth=death'''
    if T==0:
        return(states)
        # return(S0, E0, I0, R0, states)
    
    N = S0+E0+I0+R0
    S = dt*(mu*N-mu*S0-beta*I0*S0/N)+S0
    E = dt*(beta*S0*I0/N-(eps+mu)*E0)+E0
    I = dt*(eps*E0-(gamma+mu)*I0)+I0
    R = dt*(gamma*I0 - mu*R0)+R0
    T-=1
    # print("N=",S+E+I+R)
    # print(S, E, I, R)
    return(model1(T, dt=dt, S0=S, E0=E, I0=I, R0=R,
                  beta=beta, eps=eps, gamma=gamma, mu=mu,
                 states=torch.cat((states,
                                   torch.unsqueeze(torch.tensor([S, E, I, R]), 0)), 0)))

##### Test SEIR

In [84]:
# for deep recursion avoidance:
states = torch.tensor([[997, 0, 3, 0]])
for i in range(7):
    res = model1(1000, S0=float(states[-1][0]), E0=float(states[-1][1]),
                 I0=float(states[-1][2]), R0=float(states[-1][3]))
    states = torch.cat((states, res),0)
states.shape

torch.Size([7001, 4])

#### Results SEIR

In [91]:
%matplotlib
plt.ylim(0, 1000)

for i in range(states.shape[-1]):
    plt.plot(states.T[i], label=['S', 'E', 'I', 'R'][i])
plt.legend(loc="upper left")

Using matplotlib backend: Qt5Agg


<matplotlib.legend.Legend at 0x7fe44fcad4e0>

##### Appendix

In [26]:
pyro.clear_param_store()
# for vectorized, sampled, depended:
def model1(init_i, N):
    # mu = torch.tensor(0.5)
    # sigma = torch.tensor(0.1)
    # p = pyro.sample("latent_fairness", dist.Normal(mu, sigma))
    # print("p = ", p)
    i = init_i
    
    # vectorized, sampled, dependent:
    p = i/N
    # dd = dist.Bernoulli(p).expand([7, 2]).to_event(1)
    x = dd.sample()
    print("x= ", x)
    print("prob(x)= ", torch.exp(dd.log_prob(x)))

In [27]:
model1(3, 7)

x=  tensor([[0., 0.],
        [1., 0.],
        [1., 0.],
        [0., 1.],
        [0., 0.],
        [0., 0.],
        [0., 0.]])
prob(x)=  tensor([0.3265, 0.2449, 0.2449, 0.2449, 0.3265, 0.3265, 0.3265])


In [20]:
pyro.clear_param_store()
# for vectorized, sampled, depended:
def model(init_i, N):
    # mu = torch.tensor(0.5)
    # sigma = torch.tensor(0.1)
    # p = pyro.sample("latent_fairness", dist.Normal(mu, sigma))
    # print("p = ", p)
    i = init_i
    
    # vectorized, sampled, dependent:
    with pyro.plate("data_loop", size=3, subsample_size=2) as ind:
        p = i/N
        dd = dist.Bernoulli(p).expand([7, 2]).to_event(1)
        print("dd.batch_shape:")
        print(dd.batch_shape)
        print("dd.event_shape:")
        print(dd.event_shape)
        x = dd.sample()
        print("x = ", x)
        print("prob(x) = ", torch.exp(dd.log_prob(x)))
        # print("accurate: ",
        #       torch.tensor([(p if x0 else 1-p)*(p if x1 else 1-p)
        #                     for x0, x1 in x]))

        y = pyro.sample("y", dd)
        print("y = ", y)
        print("ind:")
        print(ind)
        
        # a = pyro.sample("obs", dd, obs=data.index_select(0, ind))
        # print("a:")
        # print(a)

In [21]:
model(3,7)

dd.batch_shape:
torch.Size([7])
dd.event_shape:
torch.Size([2])
x =  tensor([[1., 0.],
        [0., 1.],
        [0., 0.],
        [0., 1.],
        [1., 0.],
        [1., 0.],
        [0., 1.]])
prob(x) =  tensor([0.2449, 0.2449, 0.3265, 0.2449, 0.2449, 0.2449, 0.2449])


ValueError: Shape mismatch inside plate('data_loop') at site y dim -1, 2 vs 7

In [28]:
x_axis = pyro.plate("x", 3, dim=-1)
y_axis = pyro.plate("y", 2, dim=-2)
with x_axis:
    x = pyro.sample("x", dist.Normal(0, 1))
    # this dont work here because of plate:
    # x = pyro.sample("x", dist.Normal(0, 1).expand([5, 2]).to_event(1))
with y_axis:
    y = pyro.sample("y", dist.Normal(0, 1))
print("x: ", x)
print("y: ", y)

x:  tensor([-0.5929,  0.3222,  0.2136])
y:  tensor([[0.8491],
        [1.8800]])
