In [4]:
import time
import numpy as np
import pandas as pd
import pylab as plt
from mesa import Agent, Model
from mesa.time import RandomActivation
from mesa.space import MultiGrid
from mesa.datacollection import DataCollector
import plotly.express as px


In [31]:
class State:
    """Integer codes for an agent's infection state (stored in ``MyAgent.state``)."""

    # 0 = susceptible, 1 = infected, 2 = removed
    SUSCEPTIBLE, INFECTED, REMOVED = range(3)


class IModel(Model):
    """A model for infection spread.

    Agents are placed at random on a wrapping (torus) MultiGrid. Each agent is
    SUSCEPTIBLE, INFECTED or REMOVED; the per-agent rules live in ``MyAgent``.

    Args:
        N: number of agents to create.
        width: grid width in cells.
        height: grid height in cells.
        ptrans: probability that an infected agent infects a given susceptible
            cellmate on contact.
        death_rate: per-step probability that an infected agent dies.
        recovery_days: mean of the normal distribution that recovery times
            (in steps) are drawn from.
        recovery_sd: standard deviation of that distribution.
        initial_infected: probability that each agent starts out infected
            (default 0.02, the value previously hard-coded).
    """

    def __init__(
        self,
        N=100,
        width=100,
        height=100,
        ptrans=0.5,
        death_rate=0.02,
        recovery_days=21,
        recovery_sd=7,
        initial_infected=0.02,
    ):
        # Let Mesa's Model set up its own bookkeeping before we add agents.
        super().__init__()
        self.num_agents = N
        self.recovery_days = recovery_days
        self.recovery_sd = recovery_sd
        self.ptrans = ptrans
        self.death_rate = death_rate
        self.grid = MultiGrid(width, height, True)  # True -> grid wraps around
        self.schedule = RandomActivation(self)
        self.running = True
        # Agents that have died are appended here by the agents themselves.
        self.dead_agents = []

        for i in range(self.num_agents):
            agent = MyAgent(i, self)
            self.schedule.add(agent)
            # Place the agent in a uniformly random cell.
            x = self.random.randrange(self.grid.width)
            y = self.random.randrange(self.grid.height)
            self.grid.place_agent(agent, (x, y))
            # Seed the epidemic. Use the model's RNG (not numpy's global one)
            # so every random draw in the simulation comes from one generator.
            if self.random.random() < initial_infected:
                agent.state = State.INFECTED
                agent.recovery_time = self.get_recovery_time()

        # Records every scheduled agent's `state` attribute on each collect().
        self.datacollector = DataCollector(agent_reporters={"State": "state"})

    def get_recovery_time(self):
        """Draw a recovery duration (in steps), truncated to an int."""
        return int(self.random.normalvariate(self.recovery_days, self.recovery_sd))

    def step(self):
        """Record agent states, then activate every scheduled agent once.

        Collection happens before the agents act, so the state produced by
        the most recent call is only recorded on the following call.
        """
        self.datacollector.collect(self)
        self.schedule.step()


class MyAgent(Agent):
    """An agent that moves randomly on the grid and can catch/spread infection.

    Attributes:
        age: drawn once at creation; not used by the simulation rules.
        state: one of the ``State`` codes.
        infection_time: schedule time at which the agent was infected.
        recovery_time: number of steps after ``infection_time`` before the
            agent becomes REMOVED; ``None`` until the agent is infected.
    """

    def __init__(self, unique_id, model):
        super().__init__(unique_id, model)
        # NOTE(review): mean 20 / sd 80 yields negative ages for a large share
        # of agents; the intended distribution is unclear (age is unused).
        self.age = self.random.normalvariate(20, 80)
        self.state = State.SUSCEPTIBLE
        self.infection_time = 0
        # Set when the agent becomes infected; initialised so it always exists.
        self.recovery_time = None

    def move(self):
        """Move to a random neighbouring cell (Moore neighbourhood, no centre)."""
        possible_steps = self.model.grid.get_neighborhood(
            self.pos, moore=True, include_center=False
        )
        new_position = self.random.choice(possible_steps)
        self.model.grid.move_agent(self, new_position)

    def status(self):
        """Apply the per-step death and recovery rules to an infected agent.

        A dying agent is removed from both the grid and the schedule and is
        appended to ``model.dead_agents``.

        Returns:
            False if the agent died during this call, True otherwise.
        """
        if self.state != State.INFECTED:
            return True
        # Use the agent's RNG (shared with the model) instead of numpy's global
        # generator so the whole run draws from a single random stream.
        if self.random.random() < self.model.death_rate:
            # Remove from the grid too, otherwise the corpse keeps occupying a
            # cell and shows up as a cellmate of living agents.
            self.model.grid.remove_agent(self)
            self.model.schedule.remove(self)
            self.model.dead_agents.append(self)
            return False
        elapsed = self.model.schedule.time - self.infection_time
        if elapsed >= self.recovery_time:
            self.state = State.REMOVED
        return True

    def contact(self):
        """If infected, try to infect each susceptible agent in the same cell."""
        if self.state != State.INFECTED:
            return
        cellmates = self.model.grid.get_cell_list_contents([self.pos])
        for other in cellmates:
            if other is self or other.state != State.SUSCEPTIBLE:
                continue
            # Each susceptible cellmate is infected with probability ptrans.
            if self.random.random() < self.model.ptrans:
                other.state = State.INFECTED
                other.infection_time = self.model.schedule.time
                other.recovery_time = self.model.get_recovery_time()

    def step(self):
        """Update health status, then move and make contact if still alive."""
        if not self.status():
            return  # dead agents neither move nor infect anyone
        self.move()
        self.contact()

In [50]:
# Build the model and advance it a fixed number of steps.
N_AGENTS = 1000  # number of agents in the simulation
GRID_WIDTH = 20  # grid width in cells
GRID_HEIGHT = 20  # grid height in cells
N_STEPS = 50  # number of model steps to run

model = IModel(N_AGENTS, GRID_WIDTH, GRID_HEIGHT)
for _ in range(N_STEPS):
    model.step()
# Raw per-agent, per-step states recorded by the model's DataCollector.
agent_state = model.datacollector.get_agent_vars_dataframe()

In [51]:
def get_output_data(model):
    """Count how many agents are in each state at every collected step.

    Args:
        model: an object whose ``datacollector.get_agent_vars_dataframe()``
            returns a frame indexed by (Step, AgentID) with a ``State`` column.

    Returns:
        DataFrame indexed by ``Step`` with columns ``Susceptible``,
        ``Infected`` and ``Removed``. A state never observed gets an all-zero
        column, so downstream code can always rely on all three columns.
    """
    # Map state codes to labels explicitly (mirrors the State class codes)
    # instead of by column position, which breaks when a state is absent.
    labels = {0: "Susceptible", 1: "Infected", 2: "Removed"}
    agent_state = model.datacollector.get_agent_vars_dataframe().reset_index()
    counts = (
        agent_state.groupby(["Step", "State"])
        .size()
        .unstack(fill_value=0)
        .reindex(columns=list(labels), fill_value=0)
        .rename(columns=labels)
        .rename_axis(columns=None)
    )
    return counts

In [52]:
output = get_output_data(model)  # per-step counts of agents in each state

In [53]:
# Reshape the per-step state counts into long form (Step, variable, value)
# for plotting. Derived from `model` rather than from `output` itself, so
# re-running this cell does not fail on an already-melted frame.
output = (
    get_output_data(model)
    .reset_index()
    .melt(id_vars="Step", value_vars=["Susceptible", "Infected", "Removed"])
)

In [54]:
# Number of agents in each state at every recorded step.
fig = px.line(
    output,
    x="Step",
    y="value",
    color="variable",
    labels={"value": "Number of agents", "variable": "State"},
    title="Agent states over time",
)
fig.show()