In [None]:
from IPython.display import clear_output
from matplotlib import pyplot as plt
import numpy as np
from camminapy import * 
from scipy.spatial.distance import cdist
import matplotlib.collections
import gif


%matplotlib inline
# NOTE(review): "kitishnotex" is not one of matplotlib's bundled styles; presumably it is
# installed together with camminapy (which also appears to supply the `kit` colours used
# below via the star import) — confirm, since plt.style.use fails on an unknown style name.
plt.style.use("kitishnotex")

In [None]:
@gif.frame
def live_plot(pos, types, time, radius, S, I, R):
    """Draw one animation frame of the SIR particle simulation.

    Left panel: every agent drawn as a circle of ``radius`` at its row in ``pos``,
    coloured by its state code in ``types`` (0 = S, 1 = I, 2 = R).
    Right panel: stacked history of the I, S and R counts, with a shared legend
    listed in S, I, R order.

    NOTE(review): the ``@gif.frame`` wrapper presumably captures the current
    figure after this function returns (gif library behaviour), so the figure
    is built here and deliberately not closed.

    Parameters
    ----------
    pos : ndarray, shape (n, 2)
        Agent positions; the left axes show the unit square plus a margin.
    types : ndarray of int, shape (n,)
        State code per agent; used as an index into the colour list.
    time : float
        Current simulation time. Currently unused; kept so existing callers work.
    radius : float
        Circle radius in data units.
    S, I, R : sequence of int
        Susceptible / infected / recovered counts, one entry per recorded step.
    """
    eps = 0.1  # margin around the unit square so circles at the edge are not cut off
    clear_output(wait=True)
    colors = [kit.blue, kit.orange, kit.green]  # indexed by state code 0, 1, 2
    fig, (ax_agents, ax_history) = plt.subplots(1, 2, dpi=100, figsize=(2 * 3, 3))

    for state in np.unique(types):
        # flatnonzero yields a 1-D index, so each row unpacks into scalar x, y.
        # (np.argwhere would give (k, 1) indices and 1-element arrays as centres.)
        members = np.flatnonzero(types == state)
        patches = [plt.Circle((x, y), radius) for x, y in pos[members]]
        ax_agents.add_collection(
            matplotlib.collections.PatchCollection(patches, facecolors=colors[state])
        )
    ax_agents.set_xlim([0 - eps, 1 + eps])
    ax_agents.set_ylim([0 - eps, 1 + eps])
    ax_agents.axis('off')

    # I is passed first, so it forms the bottom layer of the stack.
    ax_history.stackplot(np.arange(len(S)), I, S, R,
                         colors=[kit.orange, kit.blue, kit.green],
                         labels=["I", "S", "R"])
    ax_history.axis('off')

    # Reorder legend entries from stacking order (I, S, R) to S, I, R.
    handles, labels = ax_history.get_legend_handles_labels()
    order = [1, 0, 2]
    fig.legend([handles[k] for k in order], [labels[k] for k in order],
               bbox_to_anchor=(0.63, 1.1), ncol=3)

In [None]:
n = 200          # number of agents
dt = 0.002       # time step
sigma = 1        # currently unused
radius = 0.015   # agent radius; an S agent closer than 2*radius to an I agent gets infected
theal = 0.3      # an infected agent recovers once its infected time exceeds this
seed = 0         # RNG seed so Restart & Run All reproduces the same animation

rng = np.random.default_rng(seed)
pos = rng.random((n, 2))
vel = rng.standard_normal((n, 2))
# Random directions, all with speed 2 (dividing by 0.5*|v| is multiplying by 2/|v|).
vel /= 0.5 * np.linalg.norm(vel, axis=1, keepdims=True)

types = np.zeros(n, dtype=int)   # state codes: 0 = susceptible, 1 = infected, 2 = recovered
types[0] = 1                     # a single initially infected agent
timeinfected = np.zeros(n)       # accumulated infected time per agent

t = 0
S = [np.sum(types == 0)]
I = [np.sum(types == 1)]
R = [np.sum(types == 2)]
# Draw the initial state once; every loop iteration then records and draws the
# state *after* its update, so the history panel never lags the agent panel
# and the initial counts are not recorded twice.
frames = [live_plot(pos, types, t, radius, S, I, R)]

while True:
    t += dt

    ## Moving around part (periodic boundaries on the unit square)
    pos += dt * vel
    pos = np.where(pos < 0, 1 + pos, pos)
    pos = np.where(pos > 1, pos - 1, pos)

    ## Recovery part
    susceptible = np.flatnonzero(types == 0)
    infecting = np.flatnonzero(types == 1)
    timeinfected[infecting] += dt
    types[timeinfected > theal] = 2
    infecting = np.flatnonzero(types == 1)

    ## Infection part: agents infected before this step become infectious next step
    if len(infecting) > 0 and len(susceptible) > 0:
        # minimal distance from every susceptible to any infected agent
        dist = cdist(pos[susceptible], pos[infecting], 'euclidean').min(axis=1)
        types[susceptible[dist < 2 * radius]] = 1

    S.append(np.sum(types == 0))
    I.append(np.sum(types == 1))
    R.append(np.sum(types == 2))
    frames.append(live_plot(pos, types, t, radius, S, I, R))

    # Epidemic is over once nobody is infected after the recovery update.
    if len(infecting) == 0:
        break


In [None]:
# Write every collected frame to a single animated GIF.
gif_path = "Model.gif"
gif.save(frames, gif_path, duration=50)  # presumably 50 ms per frame (gif's default unit) — confirm
