<a href="https://colab.research.google.com/github/framante/contagion_model/blob/main/contagion_SEIR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.gridspec as gridspec
import matplotlib.patches as patches


class ContagionSimulator:
    """
    Agent-based SEIR-style contagion model with births and natural deaths.

    Agents random-walk inside the unit square. Agents that are not sick are
    confined to the quadrant of their starting ("home") position; sick agents
    may roam the whole square. Each agent's state is encoded in
    ``agentStatus``: 0 vulnerable, 1 exposed, 2 sick, 3 dead.

    Parameters are attached as attributes by set_params(); the keys read by
    the model are ``nagents``, ``carriercapacity``, ``tstepsperday``, ``nt``,
    ``dtGetSick``, ``dtDeath``, ``dtLife``, ``dtPregnancy``, ``pContagion``,
    ``noffspring``, ``rcont`` and ``dr``.
    """

    def __init__(self):
        # All state is created by set_params() and initialise().
        pass

    def set_params(self, params):
        """
        Set model parameters from a dictionary.

        Every key/value pair becomes an attribute of the simulator. The
        initial population size is remembered so that initialise() can
        restore it even after births have grown ``self.nagents``.
        """
        self.__dict__.update(params)
        if "nagents" in params:
            self._nagents_initial = params["nagents"]

    def initialise(self):
        """
        (Re)initialise all per-agent buffers and the time-series counters.

        Safe to call more than once: the population is reset to its initial
        size rather than to whatever size births have grown it to.
        """
        self.nagents = getattr(self, "_nagents_initial", self.nagents)
        nagents = self.nagents

        # Positions, uniform in the unit square, plus a copy of the starting
        # position, which decides the agent's home quadrant.
        self.x = np.random.rand(nagents)
        self.y = np.random.rand(nagents)
        self.xbak = self.x.copy()
        self.ybak = self.y.copy()

        # Step vector of length dr in a random direction.
        alpha = 2 * np.pi * np.random.rand(nagents)
        self.dx = self.dr * np.sin(alpha)
        self.dy = self.dr * np.cos(alpha)

        # Health flags and the equivalent integer code
        # (0 vulnerable, 1 exposed, 2 sick, 3 dead).
        self.isVulnerable = np.ones(nagents, dtype=bool)
        self.isExposed = np.zeros(nagents, dtype=bool)
        self.isSick = np.zeros(nagents, dtype=bool)
        self.isDead = np.zeros(nagents, dtype=bool)
        self.agentStatus = np.zeros(nagents, dtype=int)
        self.tExposure = np.zeros(nagents)
        self.tSick = np.zeros(nagents)

        # Natality / ageing timers.
        self.tPregnancy = np.zeros(nagents)
        self.tLife = np.zeros(nagents)

        # Patient zero: agent index 0 starts out sick.
        iseed = 0
        self.isVulnerable[iseed] = False
        self.isSick[iseed] = True
        self.agentStatus[iseed] = 2

        # Per-step counts; NaN until filled in by count_cases().
        nt = self.nt
        self.nvulnerable = np.full(nt, np.nan)
        self.nexposed = np.full(nt, np.nan)
        self.nsick = np.full(nt, np.nan)
        self.ndead = np.full(nt, np.nan)
        self.tt = np.arange(nt) / self.tstepsperday
        self.break_simulation = False

    def run_simulation(self, fig, save_file="contagion.mp4"):
        """
        Animate the simulation on ``fig``, save it to ``save_file`` and show it.

        The buffers are (re)initialised by init_plot(), which FuncAnimation
        calls as its init function, so no explicit initialise() is needed.
        """
        self.fig = fig
        self.ani = animation.FuncAnimation(
            self.fig, self.update_plot, interval=100, init_func=self.init_plot,
            blit=True, frames=self.nt, repeat=False,
        )
        self.ani.save(save_file)
        plt.show()

    @staticmethod
    def _leaves_half(home, new):
        """Return True if ``new`` is outside the half of [0, 1] containing ``home``.

        The lower half is [0, 0.5) and the upper half is [0.5, 1].
        """
        if home < 0.5:
            return (new < 0) | (new >= 0.5)
        return (new < 0.5) | (new > 1)

    def check_bdry(self, ia, xnew, ynew):
        """
        Return True if moving agent ``ia`` to (xnew, ynew) would hit a wall.

        Sick agents are bounded by the unit square [0, 1] x [0, 1]. All other
        agents are bounded by the quadrant containing their starting position:
        [0, 0.5)x[0, 0.5), [0, 0.5)x[0.5, 1], [0.5, 1]x[0, 0.5) or
        [0.5, 1]x[0.5, 1].
        """
        if self.isSick[ia]:
            return (xnew < 0) | (xnew > 1) | (ynew < 0) | (ynew > 1)
        return (self._leaves_half(self.xbak[ia], xnew)
                | self._leaves_half(self.ybak[ia], ynew))

    def take_step(self, ia):
        """
        Move agent ``ia`` one step along its current direction.

        If the step would hit a wall, a new random direction is drawn
        (up to 101 times). If every attempt fails, the agent stays put for
        this step but keeps the last direction drawn.
        """
        xnew = self.x[ia] + self.dx[ia]
        ynew = self.y[ia] + self.dy[ia]
        hit_wall = self.check_bdry(ia, xnew, ynew)

        ntries = 0
        while hit_wall and ntries <= 100:
            alpha = 2 * np.pi * np.random.rand()
            self.dx[ia] = self.dr * np.sin(alpha)
            self.dy[ia] = self.dr * np.cos(alpha)
            xnew = self.x[ia] + self.dx[ia]
            ynew = self.y[ia] + self.dy[ia]
            hit_wall = self.check_bdry(ia, xnew, ynew)
            ntries += 1

        # `not` rather than `~`: `~True` is -2 (truthy) for a Python bool.
        if not hit_wall:
            self.x[ia] = xnew
            self.y[ia] = ynew

    def stay_put(self, ia):
        """
        Do nothing (used for dead agents).
        """
        pass

    def check_health(self, ia):
        """
        Advance the disease timers of agent ``ia``.

        An exposed agent becomes sick once its exposure timer exceeds
        ``dtGetSick``. A sick agent dies once its sickness timer exceeds
        ``dtDeath``. Other agents are unaffected.
        """
        if self.isExposed[ia]:
            self.tExposure[ia] += 1
            if self.tExposure[ia] > self.dtGetSick:
                self.isExposed[ia] = False
                self.isSick[ia] = True
                self.agentStatus[ia] = 2
        elif self.isSick[ia]:
            self.tSick[ia] += 1
            if self.tSick[ia] > self.dtDeath:
                self.isSick[ia] = False
                self.isDead[ia] = True
                self.agentStatus[ia] = 3

    def check_natural_death(self, it):
        """
        Age every living agent by one step and kill those older than ``dtLife``.

        Only agents alive at this step can die, so each natural death is
        reported exactly once. ``it`` is used only in the message.
        """
        alive = self.isVulnerable | self.isExposed | self.isSick
        self.tLife[alive] += 1
        # Dead agents keep their final tLife, so without the `alive` mask they
        # would be "killed" and reported again on every later step.
        inds = alive & (self.tLife > self.dtLife)
        ndied = inds.sum()
        if ndied == 0:
            return
        self.isVulnerable[inds] = False
        self.isExposed[inds] = False
        self.isSick[inds] = False
        self.isDead[inds] = True
        self.agentStatus[inds] = 3
        print(ndied, " dead at ", it)

    def check_contagion(self, ia):
        """
        Let sick agent ``ia`` expose vulnerable agents within ``rcont``.

        Each vulnerable agent in range is infected independently with
        probability ``pContagion``. Newly exposed agents have their pregnancy
        timer reset.
        """
        if not self.isSick[ia]:
            return
        r = np.hypot(self.x[ia] - self.x, self.y[ia] - self.y)
        # Indices (not a tuple, unlike np.where) of vulnerable agents in range.
        candidates = np.flatnonzero((r <= self.rcont) & self.isVulnerable)
        # One independent draw per candidate.
        infected = candidates[np.random.rand(candidates.size) < self.pContagion]
        self.isVulnerable[infected] = False
        self.isExposed[infected] = True
        self.agentStatus[infected] = 1
        self.tPregnancy[infected] = 0

    def check_pregnancy(self, ia):
        """
        Advance the pregnancy timer of vulnerable agent ``ia`` and maybe give birth.

        When the timer reaches ``dtPregnancy`` it is reset. The pregnancy
        succeeds with a probability that falls as the living population
        approaches ``carriercapacity``. On success ``noffspring`` agents are
        added.
        """
        if not self.isVulnerable[ia]:
            return
        self.tPregnancy[ia] += 1
        if self.tPregnancy[ia] != self.dtPregnancy:
            return
        self.tPregnancy[ia] = 0

        alive_agents = self.isVulnerable.sum() + self.isExposed.sum() + self.isSick.sum()
        crowding = alive_agents / self.carriercapacity
        if crowding > 0.95:
            # Near or above capacity: succeed only if a N(0.5, var=0.05) draw
            # exceeds 0.95.
            thresh = 0.95
            prob = np.random.normal(0.5, np.sqrt(0.05))
        else:
            thresh = crowding
            prob = np.random.rand()

        if prob > thresh:
            print("successfull pregnancy for", ia)
            self._add_offspring(ia)

    def _add_offspring(self, ia):
        """Append ``noffspring`` vulnerable newborns at parent ``ia``'s position."""
        n = self.noffspring
        self.nagents += n
        # Newborns start where the parent is and share its home quadrant.
        self.x = np.append(self.x, np.full(n, self.x[ia]))
        self.y = np.append(self.y, np.full(n, self.y[ia]))
        self.xbak = np.append(self.xbak, np.full(n, self.xbak[ia]))
        self.ybak = np.append(self.ybak, np.full(n, self.ybak[ia]))
        # Only the newborns get fresh directions; existing agents keep theirs.
        alpha = 2 * np.pi * np.random.rand(n)
        self.dx = np.append(self.dx, self.dr * np.sin(alpha))
        self.dy = np.append(self.dy, self.dr * np.cos(alpha))
        # Health status: newborns are vulnerable with all timers at zero.
        self.isVulnerable = np.append(self.isVulnerable, np.ones(n, dtype=bool))
        self.isExposed = np.append(self.isExposed, np.zeros(n, dtype=bool))
        self.isSick = np.append(self.isSick, np.zeros(n, dtype=bool))
        self.isDead = np.append(self.isDead, np.zeros(n, dtype=bool))
        self.agentStatus = np.append(self.agentStatus, np.zeros(n, dtype=int))
        self.tExposure = np.append(self.tExposure, np.zeros(n))
        self.tSick = np.append(self.tSick, np.zeros(n))
        self.tPregnancy = np.append(self.tPregnancy, np.zeros(n))
        self.tLife = np.append(self.tLife, np.zeros(n))

    def count_cases(self, it):
        """
        Record the number of agents in each health state at step ``it``.

        Sets ``break_simulation`` if the counts do not add up to ``nagents``.
        """
        self.nvulnerable[it] = self.isVulnerable.sum()
        self.nexposed[it] = self.isExposed.sum()
        self.nsick[it] = self.isSick.sum()
        self.ndead[it] = self.isDead.sum()
        nall = self.nvulnerable[it] + self.nsick[it] + self.ndead[it] + self.nexposed[it]
        if nall != self.nagents:
            print("Counts don't add up", nall, " ", self.nagents)
            self.break_simulation = True

    def simulation_step(self, it):
        """
        Perform one simulation step: move, update health, breed and infect
        per agent, then apply natural deaths and record the counts.
        """
        # range() is evaluated once, so agents born during this step are first
        # processed on the next step.
        for ia in range(self.nagents):
            if self.isDead[ia]:
                self.stay_put(ia)
            else:
                self.take_step(ia)
            self.check_health(ia)
            self.check_pregnancy(ia)
            self.check_contagion(ia)

        self.check_natural_death(it)
        self.count_cases(it)

    def init_plot(self):
        """
        Initialise the simulation state and draw the animation's first frame.

        Used as FuncAnimation's ``init_func``. The figure is cleared first, so
        that a repeated call (matplotlib may invoke the init function again,
        e.g. when the figure is shown after saving) does not stack new axes on
        top of the old ones.

        Returns the artists that update_plot() modifies, as needed for blitting.
        """
        self.initialise()
        self.fig.clf()

        # Lines dividing the domain into the four home quadrants.
        line1 = patches.ConnectionPatch(xyA=(0, 0.5), xyB=(1, 0.5), coordsA="data", coordsB="data", color="black")
        line2 = patches.ConnectionPatch(xyA=(0.5, 0), xyB=(0.5, 1), coordsA="data", coordsB="data", color="black")
        # One colour per health status code 0..3.
        self.colours = ["c", "m", "g", "y"]
        cm = LinearSegmentedColormap.from_list("agent_colours", self.colours, N=4)

        gs = gridspec.GridSpec(ncols=1, nrows=4, figure=self.fig)

        # Panel 1: time series of health status counts.
        self.ax_plt = self.fig.add_subplot(gs[:2])
        self.nvulnerable_plt, = self.ax_plt.plot(self.tt[0], self.nvulnerable[0], "-", c="c", label="Susceptible")
        self.nexposed_plt, = self.ax_plt.plot(self.tt[0], self.nexposed[0], "-", c="m", label="Exposed")
        self.nsick_plt, = self.ax_plt.plot(self.tt[0], self.nsick[0], "-", c="g", label="Sick")
        self.ndead_plt, = self.ax_plt.plot(self.tt[0], self.ndead[0], "-", c="y", label="Dead")
        self.ax_plt.set_xlim((0, self.nt / self.tstepsperday))
        # Headroom above the initial population for newborns.
        self.ax_plt.set_ylim((-1, 3.1 * self.nagents))
        self.ax_plt.legend(ncol=1, loc="upper right")
        self.ax_plt.set_ylabel("Count")
        self.ax_plt.set_xlabel("Time [days]")

        # Panel 2: scatter plot of agents coloured by status.
        self.ax_sct = self.fig.add_subplot(gs[2:])
        self.ax_sct.add_patch(line1)
        self.ax_sct.add_patch(line2)
        self.sct = self.ax_sct.scatter(self.x, self.y, c=self.agentStatus, cmap=cm, vmin=0, vmax=3)
        self.ax_sct.set_xlim((0, 1))
        self.ax_sct.set_ylim((0, 1))
        self.ax_sct.set_xticks([])
        self.ax_sct.set_yticks([])
        self.ax_sct.set_title("\nagents: " + str(self.nagents) +"    carrier capacity: " + str(self.carriercapacity) + "    dtGetSick: " + str(self.dtGetSick) + 
                              "\ndtDeath: " + str(self.dtDeath) + "    dtPregnancy: " + str(self.dtPregnancy) + "    pContagion: " + str(self.pContagion))
        self.fig.tight_layout()
        self.fig.subplots_adjust(top=0.95, bottom=0.05, left=0.1, right=0.95)

        return self.nvulnerable_plt, self.nexposed_plt, self.nsick_plt, self.ndead_plt, self.sct,

    def update_plot(self, it):
        """
        Advance the simulation by step ``it`` and refresh every artist.

        Returns the modified artists, as needed for blitting.
        """
        self.simulation_step(it)
        self.sct.set_offsets(np.c_[self.x, self.y])
        self.sct.set_array(self.agentStatus)
        # Include step `it`, which count_cases() has just filled in.
        upto = it + 1
        self.nvulnerable_plt.set_data(self.tt[:upto], self.nvulnerable[:upto])
        self.nexposed_plt.set_data(self.tt[:upto], self.nexposed[:upto])
        self.nsick_plt.set_data(self.tt[:upto], self.nsick[:upto])
        self.ndead_plt.set_data(self.tt[:upto], self.ndead[:upto])

        return self.nvulnerable_plt, self.nexposed_plt, self.nsick_plt, self.ndead_plt, self.sct,


# Model parameters. Defined unconditionally (not under an
# `if __name__ == "__main__"` guard) because the next cell uses `params`
# at top level. Durations are in simulation time steps; with
# tstepsperday = 5, divide by 5 to get days.
params = {
    "nagents": 50,             # No. of agents at the start
    "carriercapacity": 30,     # Carrying capacity of the environment (limits births)
    "tstepsperday": 5,         # No. of time steps per day (affects display only)
    "nt": 5000,                # No. of time steps of simulation (1000 days, ~2.7 years)
    "dtGetSick": 200,          # Steps from exposure to falling sick (~40 days)
    "dtDeath": 40,             # Steps from falling sick to dying of the disease (~8 days)
    "dtLife": 5500,            # Steps after which agents die of other causes (~1100 days, ~3 years)
    "dtPregnancy": 250,        # How long a pregnancy takes (~50 days)
    "pPregnancy": 0.2,         # Former fixed pregnancy success probability (unused since carriercapacity was introduced)
    "pContagion": 0.5,         # Probability that a contact within rcont transmits the disease
    "noffspring": 4,           # Offspring per successful pregnancy (3-5)
    "rcont": .04,              # Contact distance: agents at or within it can pass on the disease
    "dr": .02,                 # Step length per time step
}

In [None]:
# Build the simulator and attach every model parameter as an attribute.
simulator = ContagionSimulator()
simulator.set_params(params=params)

In [None]:
# Run the animation on a tall figure (time series on top, agent map below), save it and display it.
simulator.run_simulation(plt.figure(figsize=(8, 12)), save_file="contagion.mp4")