In [None]:
import torch as tn
import torchtt as tntt
import TTCME
import matplotlib.pyplot as plt 
import datetime
import numpy as np

# Use float64 as the default floating-point dtype. The probabilities are rounded
# with tolerances of 1e-13 / 1e-10 further down, far below float32 resolution.
# (`set_default_tensor_type` is deprecated; on CPU this call is equivalent.)
tn.set_default_dtype(tn.float64)
# Toggle: when True, the generator and the state are converted to QTT format
# (`.to_qtt()`) before time integration; otherwise plain TT format is used.
qtt = True

In [None]:
# Species of the network; this order fixes the axis order of the state space.
species = ['S', 'E', 'I', 'R']

# (reaction equation, rate constant) pairs -- presumably an SEIR-type model.
reaction_table = [
    ('S+I->E+I', 0.1),   # S turns into E, with I acting as catalyst
    ('E->I',     0.5),   # E turns into I
    ('I->S',     1.0),   # I turns back into S
    ('S->',      0.01),  # S is removed
    ('E->',      0.01),  # E is removed
    ('I->R',     0.01),  # I turns into R
    ('->S',      0.4),   # S is produced from nothing (inflow)
]
# Each reaction receives its own copy of the species list.
r1, r2, r3, r4, r5, r6, r7 = [
    TTCME.ChemicalReaction(list(species), equation, rate)
    for equation, rate in reaction_table
]

mdl = TTCME.ReactionSystem(list(species), [r1, r2, r3, r4, r5, r6, r7])
# Grid size per species; presumably copy numbers 0..N[k]-1 are represented.
N = [128, 64, 64, 64]

# Operator of the reaction system assembled in TT format on the grid N.
Att = mdl.generatorTT(N)

In [None]:
# Initial state: separable Gaussian bump centred at mu0 with width sigma,
# stored as a rank-1 TT tensor and normalised to unit total mass.
mu0 = [50, 4, 0, 0]
sigma = 1

gaussian_factors = []
for centre, size in zip(mu0, N):
    states = tn.arange(size)
    gaussian_factors.append(tn.exp(-0.5 * (centre - states) ** 2 / sigma ** 2))

p0 = tntt.rank1TT(gaussian_factors)
p0 = p0 / p0.sum()

In [None]:
# Time horizon and number of output steps (each step has length Tend / Nt).
Nt = 4
Tend = 8

# In QTT mode, the integrator operates on the (rounded) QTT form of the generator.
generator_op = Att.to_qtt().round(1e-13) if qtt else Att
fwd_int = TTCME.TimeIntegrator.TTInt(
    generator_op,
    epsilon=1e-5,
    N_max=8,
    dt_max=1,
    method='cheby',
)

In [None]:
# Integrate from t=0 to Tend in Nt equal steps. After each step, store a TT
# snapshot of shape N in Ps; Ps[0] is the initial state.
dt = Tend / Nt
p = p0.clone()
time = 0.0  # note: shadows the stdlib module name `time` in this notebook
Ps = [p0.clone()]
if qtt:
    p = p.to_qtt()

for i in range(Nt):
    t_start = datetime.datetime.now()
    if qtt:
        p = fwd_int.solve(p, dt, intervals=8, qtt=True, verb=False, rounding=False)
    else:
        p = fwd_int.solve(p, dt, intervals=8)
    tme = datetime.datetime.now() - t_start
    time += dt
    if qtt:
        # Snapshot in the original mode shape, taken before the rank truncation.
        Ps.append(tntt.reshape(p.clone(), N))
        p = p.round(1e-10)
    else:
        Ps.append(p.clone())
    print('Time ', time, ', rank ', p.R, ', solver time ', tme)

if qtt:
    # Bring the final QTT state back to the mode shape N.
    p = tntt.reshape(p, N)

In [None]:
# Final distribution on the full grid as a dense array indexed (x1, x2, x3, x4).
Pend = p.numpy()
# Reference used for the absolute-error panels. Replace it with an independently
# computed dense solution (e.g. `Pend` from a run with `qtt = False`) to see the
# error. While it aliases `Pend`, the error is identically zero, so those panels
# are skipped below instead of drawing blank figures.
P_ref = Pend


def plot_marginal(P, keep, title=None):
    """Plot the 2-D marginal of a dense array over the axis pair ``keep``.

    Parameters
    ----------
    P : numpy.ndarray
        Dense array. Every axis not listed in ``keep`` is summed out.
    keep : tuple of int
        The two axes to retain. The lower-numbered axis runs along the
        horizontal axis and is labelled ``x_{i+1}``; the higher one runs
        along the vertical axis.
    title : str, optional
        Title shown above the image.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure.
    """
    i, j = sorted(keep)
    summed_axes = tuple(axis for axis in range(P.ndim) if axis not in (i, j))
    marginal = P.sum(axis=summed_axes)  # shape (P.shape[i], P.shape[j])
    fig, ax = plt.subplots()
    # Transpose so that axis i runs along x and axis j along y (origin bottom-left).
    image = ax.imshow(marginal.transpose(), origin='lower', cmap='gray_r')
    fig.colorbar(image, ax=ax)
    ax.set_xlabel(rf'$x_{i + 1}$')
    ax.set_ylabel(rf'$x_{j + 1}$')
    if title is not None:
        ax.set_title(title)
    return fig


# Axis pairs, in the order (x1,x2), (x2,x3), (x3,x4), (x1,x4).
marginal_pairs = [(0, 1), (1, 2), (2, 3), (0, 3)]
has_reference = P_ref is not Pend

for pair in marginal_pairs:
    plot_marginal(Pend, pair, title=f'Marginal at t = {Tend}')
    if has_reference:
        plot_marginal(np.abs(Pend - P_ref), pair, title=f'Absolute error at t = {Tend}')

# Intermediate snapshots in the (x2, x3) plane. Ps[k] is the state after k
# steps of length Tend / Nt. The last entry is (up to rounding) the final
# state already shown above, so it is skipped.
for k, snapshot in enumerate(Ps[:-1]):
    plot_marginal(snapshot.numpy(), (1, 2), title=f'Marginal at t = {k * Tend / Nt}')

plt.show()