In [None]:
import time
import numpy as np
import pandas as pd


from enum import IntEnum
from mesa.agent import Agent
from mesa.model import Model
from mesa.time import RandomActivation
from mesa.space import MultiGrid
from mesa.datacollection import DataCollector

# Bokeh & Panel imports
from bokeh.plotting import figure
from bokeh.models import ColumnDataSource, HoverTool, LinearColorMapper
from bokeh.palettes import Category10
import panel as pn
pn.extension()

# -- MODEL COMPONENTS -------------------------------------------------------

class State(IntEnum):
    """SIR compartment of an agent; the integer codes are what grids/plots store."""
    SUSCEPTIBLE = 0
    INFECTED = 1
    REMOVED = 2

class MyAgent(Agent):
    '''An agent in the epidemic model.

    Attributes
    ----------
    age : float
        Drawn once from a normal distribution (mu=20, sigma=40) clipped at 0;
        not read by the disease dynamics in this file.
    state : State
        Current SIR compartment.
    infection_time : int or None
        ``model.current_step`` at which the agent became infected.
    recovery_time : int or None
        Steps after ``infection_time`` at which the agent is removed by recovery.
    '''
    def __init__(self, unique_id, model):
        super().__init__(unique_id, model)
        self.age = max(0, self.random.normalvariate(20, 40))
        self.state = State.SUSCEPTIBLE
        self.infection_time = None
        self.recovery_time = None

    def step(self):
        '''Advance one step: progress/transmit if infected, then move if not removed.'''
        if self.state == State.INFECTED:
            self._progress_disease()
            # Re-check: _progress_disease may have just removed this agent
            # (death or recovery); a removed agent must not infect cellmates.
            if self.state == State.INFECTED:
                self._contact_infection()
        if self.state in (State.SUSCEPTIBLE, State.INFECTED):
            self._move()

    def _progress_disease(self):
        '''Remove the agent by death (prob. ``death_rate``) or once recovery time elapses.'''
        # Death check
        if self.random.random() < self.model.death_rate:
            self.state = State.REMOVED
            return
        # Recovery check
        elapsed = self.model.current_step - self.infection_time
        if self.recovery_time is not None and elapsed >= self.recovery_time:
            self.state = State.REMOVED

    def _move(self):
        '''Move to a random cell of the Moore neighbourhood (centre excluded).'''
        neigh = self.model.grid.get_neighborhood(self.pos, moore=True, include_center=False)
        new_pos = self.random.choice(neigh)
        self.model.grid.move_agent(self, new_pos)

    def _contact_infection(self):
        '''Infect each susceptible cellmate independently with probability ``ptrans``.'''
        cellmates = self.model.grid.get_cell_list_contents([self.pos])
        for other in cellmates:
            if other is self or other.state != State.SUSCEPTIBLE:
                continue
            if self.random.random() < self.model.ptrans:
                other.state = State.INFECTED
                other.infection_time = self.model.current_step
                other.recovery_time = self.model.get_recovery_time()

class InfectionModel(Model):
    '''Grid-based SIR model.

    Agents are placed uniformly at random on a toroidal ``MultiGrid``. Each
    agent independently starts infected with probability
    ``init_infected_prob``. Infected agents infect susceptible cellmates with
    probability ``ptrans`` per step and are removed by death with probability
    ``death_rate`` per step. Otherwise they are removed after a recovery time
    drawn from N(``recovery_days``, ``recovery_sd``), floored at 1 step.

    The DataCollector records the per-agent ``State`` and the model-level
    ``Susceptible`` / ``Infected`` / ``Removed`` counts. Records are taken once
    for the initial configuration and then after every step.
    '''
    def __init__(self, N=100, width=20, height=20, ptrans=0.5, death_rate=0.02,
                 recovery_days=21, recovery_sd=7, init_infected_prob=0.02):
        super().__init__()
        self.num_agents = N
        self.grid = MultiGrid(width, height, True)
        self.schedule = RandomActivation(self)
        self.ptrans = ptrans
        self.death_rate = death_rate
        self.recovery_days = recovery_days
        self.recovery_sd = recovery_sd
        self.init_infected_prob = init_infected_prob
        self.current_step = 0

        # Create agents
        for i in range(self.num_agents):
            a = MyAgent(i, self)
            # initial infection
            if self.random.random() < init_infected_prob:
                a.state = State.INFECTED
                a.infection_time = 0
                a.recovery_time = self.get_recovery_time()
            self.schedule.add(a)
            x = self.random.randrange(self.grid.width)
            y = self.random.randrange(self.grid.height)
            self.grid.place_agent(a, (x, y))

        self.datacollector = DataCollector(
            # Model-level counts, e.g. for chart modules that plot by label.
            model_reporters={
                "Susceptible": lambda m: m.count_state(State.SUSCEPTIBLE),
                "Infected": lambda m: m.count_state(State.INFECTED),
                "Removed": lambda m: m.count_state(State.REMOVED),
            },
            agent_reporters={"State": "state"},
        )
        # Record the initial configuration so the history starts at step 0.
        self.datacollector.collect(self)

    def count_state(self, state):
        '''Return how many scheduled agents are currently in ``state``.'''
        return sum(1 for agent in self.schedule.agents if agent.state == state)

    def get_recovery_time(self):
        '''Draw a recovery duration in steps from N(recovery_days, recovery_sd), at least 1.'''
        return max(1, int(self.random.normalvariate(self.recovery_days, self.recovery_sd)))

    def step(self):
        '''Activate every agent once (random order), then record the new state.'''
        self.schedule.step()
        self.current_step += 1
        # Collect *after* stepping. Collecting first made the latest record one
        # step stale and never captured the state after the final step.
        self.datacollector.collect(self)

# -- DATA TRANSFORM & PLOTTING FUNCTIONS -----------------------------------

def get_column_data(model):
    """
    Pivot the agent-level data into per-step counts of each state.

    Returns a DataFrame indexed by Step with integer-count columns
    ['Susceptible', 'Infected', 'Removed']. A state with no agents at some
    step is reported as 0, not NaN.
    """
    df = model.datacollector.get_agent_vars_dataframe().reset_index()
    df['State'] = df['State'].astype(int)
    # fill_value=0 is required here: reindex's fill_value below only fills
    # columns it *adds*. Missing (Step, State) combinations inside existing
    # columns would otherwise stay NaN and break the plotted curves.
    pivot = pd.pivot_table(
        df,
        index='Step',
        columns='State',
        values='AgentID',
        aggfunc='count',
        fill_value=0,
    )
    # Guarantee all three state columns exist, in code order.
    pivot = pivot.reindex(columns=[0, 1, 2], fill_value=0)
    label_map = {0: 'Susceptible', 1: 'Infected', 2: 'Removed'}
    return pivot.rename(columns=label_map)[['Susceptible', 'Infected', 'Removed']]


def plot_states_bokeh(model, title=''):
    """Build a Bokeh line chart with one curve per SIR state.

    Per-step counts come from ``get_column_data``; x is the step number.
    """
    series = ('Susceptible', 'Infected', 'Removed')
    counts = ColumnDataSource(get_column_data(model).reset_index())
    fig = figure(width=600, height=400, title=title,
                 x_axis_label='Step', y_axis_label='Count')
    for label, color in zip(series, Category10[3]):
        fig.line('Step', label, source=counts, line_width=3,
                 line_color=color, alpha=0.8, legend_label=label)
    fig.legend.location = 'top_right'
    fig.background_fill_color = '#f5f5f5'
    fig.toolbar.logo = None
    return fig


def grid_values(model):
    """Return a DataFrame of grid cell "last" state codes.

    Entry [x, y] is the state of the last agent listed in cell (x, y), or -1
    when the cell is empty. Rows are x, columns are y.
    """
    # Assumes coord_iter yields (cell_list, (x, y)) pairs, as on some Mesa versions.
    codes = np.full((model.grid.width, model.grid.height), -1, dtype=int)
    for occupants, (cx, cy) in model.grid.coord_iter():
        if not occupants:
            continue
        codes[cx, cy] = occupants[-1].state
    return pd.DataFrame(codes)


def plot_cells_bokeh(model):
    """Return a Bokeh figure of the grid cell states.

    Each cell is a unit square coloured by the code from ``grid_values``
    (0 susceptible, 1 infected, 2 removed). Empty cells (code -1) are drawn
    white. Hovering shows the cell coordinates and code.
    """
    df = grid_values(model)
    stacked = df.stack().reset_index()
    stacked.columns = ['x', 'y', 'value']
    source = ColumnDataSource(stacked)
    # Without low_color, values below `low` (the -1 "empty" code) are mapped
    # to the first palette colour, making empty cells look susceptible.
    mapper = LinearColorMapper(palette=Category10[3], low=0, high=2,
                               low_color='white')
    p = figure(width=500, height=500,
               x_range=(-0.5, df.shape[0]-0.5),
               y_range=(-0.5, df.shape[1]-0.5),
               tools='hover',
               tooltips=[('x', '@x'), ('y', '@y'), ('state', '@value')])
    p.rect('x', 'y', 1, 1, source=source,
           fill_color={'field': 'value', 'transform': mapper},
           line_color='black')
    p.grid.grid_line_color = None
    p.axis.visible = False
    p.toolbar.logo = None
    return p

# -- REAL-TIME DISPLAY WITH PANEL ------------------------------------------

from IPython.display import display

plot_pane = pn.pane.Bokeh()
grid_pane = pn.pane.Bokeh()
layout = pn.Row(plot_pane, grid_pane, sizing_mode='stretch_width')

# Simulation parameters
duration = 100
population = 400

# Instantiate model
model = InfectionModel(population, 20, 20, ptrans=0.25, death_rate=0.01)

# Display layout *before* running. A trailing bare `layout` expression is only
# rendered after the cell finishes, so none of the per-step pane updates below
# would be visible while the simulation runs.
display(layout)

# Run and update in real time
for step in range(duration):
    model.step()
    plot_pane.object = plot_states_bokeh(model, title=f"Step {step}")
    grid_pane.object = plot_cells_bokeh(model)
    time.sleep(0.1)


## ANIMATIONS

In [None]:
# -*- coding: utf-8 -*-
"""
Jupyter notebook: InfectionModel simulation, real-time plots, and interactive animation with Mesa
"""

import time
import numpy as np
import pandas as pd

# Mesa core imports
from enum import IntEnum
from mesa.agent import Agent
from mesa.model import Model
from mesa.time import RandomActivation
from mesa.space import MultiGrid
from mesa.datacollection import DataCollector

# Visualization imports
from bokeh.plotting import figure
from bokeh.models import ColumnDataSource, HoverTool, LinearColorMapper
from bokeh.palettes import Category10
import panel as pn
from mesa.visualization.modules import CanvasGrid, ChartModule
from mesa.visualization.ModularVisualization import ModularServer
pn.extension()

# -- MODEL COMPONENTS -------------------------------------------------------

class State(IntEnum):
    """SIR compartment codes used by MyAgent.state (0 S, 1 I, 2 R)."""
    SUSCEPTIBLE = 0
    INFECTED = 1
    REMOVED = 2

class MyAgent(Agent):
    '''An agent in the epidemic model.

    Attributes
    ----------
    age : float
        Drawn once from a normal distribution (mu=20, sigma=40) clipped at 0;
        not read by the disease dynamics in this file.
    state : State
        Current SIR compartment.
    infection_time : int or None
        ``model.current_step`` at which the agent became infected.
    recovery_time : int or None
        Steps after ``infection_time`` at which the agent is removed by recovery.
    '''
    def __init__(self, unique_id, model):
        super().__init__(unique_id, model)
        self.age = max(0, self.random.normalvariate(20, 40))
        self.state = State.SUSCEPTIBLE
        self.infection_time = None
        self.recovery_time = None

    def step(self):
        '''Advance one step: progress/transmit if infected, then move if not removed.'''
        if self.state == State.INFECTED:
            self._progress_disease()
            # Re-check: _progress_disease may have just removed this agent
            # (death or recovery); a removed agent must not infect cellmates.
            if self.state == State.INFECTED:
                self._contact_infection()
        if self.state in (State.SUSCEPTIBLE, State.INFECTED):
            self._move()

    def _progress_disease(self):
        '''Remove the agent by death (prob. ``death_rate``) or once recovery time elapses.'''
        # Death (removed) check
        if self.random.random() < self.model.death_rate:
            self.state = State.REMOVED
            return
        # Recovery check
        elapsed = self.model.current_step - self.infection_time
        if self.recovery_time is not None and elapsed >= self.recovery_time:
            self.state = State.REMOVED

    def _move(self):
        '''Move to a random cell of the Moore neighbourhood (centre excluded).'''
        neighbors = self.model.grid.get_neighborhood(self.pos, moore=True, include_center=False)
        new_pos = self.random.choice(neighbors)
        self.model.grid.move_agent(self, new_pos)

    def _contact_infection(self):
        '''Infect each susceptible cellmate independently with probability ``ptrans``.'''
        cellmates = self.model.grid.get_cell_list_contents([self.pos])
        for other in cellmates:
            if other is self or other.state != State.SUSCEPTIBLE:
                continue
            if self.random.random() < self.model.ptrans:
                other.state = State.INFECTED
                other.infection_time = self.model.current_step
                other.recovery_time = self.model.get_recovery_time()

class InfectionModel(Model):
    '''Grid-based SIR model.

    Agents are placed uniformly at random on a toroidal ``MultiGrid``. Each
    agent independently starts infected with probability
    ``init_infected_prob``. Infected agents infect susceptible cellmates with
    probability ``ptrans`` per step and are removed by death with probability
    ``death_rate`` per step. Otherwise they are removed after a recovery time
    drawn from N(``recovery_days``, ``recovery_sd``), floored at 1 step.

    The DataCollector records the per-agent ``State`` and the model-level
    ``Susceptible`` / ``Infected`` / ``Removed`` counts. The model-level
    labels match the ChartModule series used by the Mesa server. Records are
    taken once for the initial configuration and then after every step.
    '''
    def __init__(self, N=100, width=20, height=20,
                 ptrans=0.5, death_rate=0.02,
                 recovery_days=21, recovery_sd=7, init_infected_prob=0.02):
        super().__init__()
        self.num_agents = N
        self.grid = MultiGrid(width, height, torus=True)
        self.schedule = RandomActivation(self)
        self.ptrans = ptrans
        self.death_rate = death_rate
        self.recovery_days = recovery_days
        self.recovery_sd = recovery_sd
        self.init_infected_prob = init_infected_prob
        self.current_step = 0

        for i in range(self.num_agents):
            agent = MyAgent(i, self)
            if self.random.random() < init_infected_prob:
                agent.state = State.INFECTED
                agent.infection_time = 0
                agent.recovery_time = self.get_recovery_time()
            self.schedule.add(agent)
            x = self.random.randrange(self.grid.width)
            y = self.random.randrange(self.grid.height)
            self.grid.place_agent(agent, (x, y))

        self.datacollector = DataCollector(
            # Model-level counts; ChartModule looks series up by these labels.
            model_reporters={
                "Susceptible": lambda m: m.count_state(State.SUSCEPTIBLE),
                "Infected": lambda m: m.count_state(State.INFECTED),
                "Removed": lambda m: m.count_state(State.REMOVED),
            },
            agent_reporters={"State": "state"},
        )
        # Record the initial configuration so the history starts at step 0.
        self.datacollector.collect(self)

    def count_state(self, state):
        '''Return how many scheduled agents are currently in ``state``.'''
        return sum(1 for agent in self.schedule.agents if agent.state == state)

    def get_recovery_time(self):
        '''Draw a recovery duration in steps from N(recovery_days, recovery_sd), at least 1.'''
        return max(1, int(self.random.normalvariate(self.recovery_days, self.recovery_sd)))

    def step(self):
        '''Activate every agent once (random order), then record the new state.'''
        self.schedule.step()
        self.current_step += 1
        # Collect *after* stepping. Collecting first made the latest record one
        # step stale and never captured the state after the final step.
        self.datacollector.collect(self)

# -- BOKEH & PANEL PLOTTING -----------------------------------------------

def get_column_data(model):
    """Pivot agent-level records into per-step state counts.

    Returns a DataFrame indexed by Step with columns
    ['Susceptible', 'Infected', 'Removed']. A state with no agents at a step
    is counted as 0, not NaN.
    """
    df = model.datacollector.get_agent_vars_dataframe().reset_index()
    df['State'] = df['State'].astype(int)
    # fill_value=0 here is what removes NaNs for missing (Step, State) pairs;
    # reindex's fill_value only applies to columns it adds.
    pivot = pd.pivot_table(
        df, index='Step', columns='State', values='AgentID', aggfunc='count',
        fill_value=0)
    pivot = pivot.reindex(columns=[0, 1, 2], fill_value=0)
    return pivot.rename(columns={0: 'Susceptible', 1: 'Infected', 2: 'Removed'})


def plot_states_bokeh(model, title=''):
    """Line chart of Susceptible / Infected / Removed counts per step."""
    frame = get_column_data(model).reset_index()
    shared = ColumnDataSource(frame)
    fig = figure(width=600, height=400, title=title)
    fig.xaxis.axis_label = 'Step'
    fig.yaxis.axis_label = 'Count'
    for label, color in zip(['Susceptible', 'Infected', 'Removed'], Category10[3]):
        fig.line('Step', label, source=shared, line_width=3,
                 line_color=color, legend_label=label)
    fig.legend.location = 'top_right'
    return fig


def grid_values(model):
    """Return a (width x height) DataFrame of per-cell state codes.

    Entry [x, y] is the state of the last agent in cell (x, y), or -1 if empty.
    """
    grid = model.grid
    codes = np.full((grid.width, grid.height), -1)
    for occupants, (cx, cy) in grid.coord_iter():
        if not occupants:
            continue
        codes[cx, cy] = occupants[-1].state
    return pd.DataFrame(codes)


def plot_cells_bokeh(model):
    """Return a Bokeh heat-map of the grid, one unit square per cell.

    Cells are coloured by the code from ``grid_values`` (0 S, 1 I, 2 R).
    Empty cells (code -1) are white. Hover shows coordinates and code.
    """
    df = grid_values(model)
    st = df.stack().reset_index()
    st.columns = ['x', 'y', 'value']
    src = ColumnDataSource(st)
    # low_color keeps the -1 "empty" code from being clamped onto the first
    # palette colour (which would make empty cells look susceptible).
    mapper = LinearColorMapper(palette=Category10[3], low=0, high=2,
                               low_color='white')
    # Fixed ranges so each square lines up exactly with one grid cell.
    p = figure(width=500, height=500,
               x_range=(-0.5, df.shape[0] - 0.5),
               y_range=(-0.5, df.shape[1] - 0.5),
               tools='hover',
               tooltips=[('x', '@x'), ('y', '@y'), ('state', '@value')])
    p.rect('x', 'y', 1, 1, source=src,
           fill_color={'field': 'value', 'transform': mapper})
    p.axis.visible = False
    return p

# -- REAL-TIME PANEL LAYOUT ----------------------------------------------
from IPython.display import display

plot_pane = pn.pane.Bokeh()
grid_pane = pn.pane.Bokeh()
layout = pn.Row(plot_pane, grid_pane)

# Render the (still empty) panes first so each per-step update below is shown
# live. A trailing bare `layout` expression is only displayed once the whole
# loop has finished.
display(layout)

# Example run
steps = 100
pop = 400
m = InfectionModel(pop, 20, 20, ptrans=0.25, death_rate=0.01)
for i in range(steps):
    m.step()
    plot_pane.object = plot_states_bokeh(m, title=f"Step {i}")
    grid_pane.object = plot_cells_bokeh(m)
    time.sleep(0.1)

# -- INTERACTIVE ANIMATION WITH MESA VISUALIZATION ------------------------

def portrayal(agent):
    '''Return the CanvasGrid drawing spec for one agent.

    Draws a filled circle on layer 0, coloured by SIR state
    (blue S, red I, green R; grey for any other value).
    '''
    palette = {
        State.SUSCEPTIBLE: 'blue',
        State.INFECTED: 'red',
        State.REMOVED: 'green',
    }
    drawing = dict(Shape="circle", r=0.8, Filled=True, Layer=0)
    drawing["Color"] = palette.get(agent.state, 'gray')
    return drawing

# Fixed parameters (no UserSettableParameter)
model_params = {
    "N": 200,
    "width": 20,
    "height": 20,
    "ptrans": 0.5,
    "death_rate": 0.02
}

# Build visualization modules. The canvas grid is sized from model_params so
# the drawn grid cannot drift out of sync with the model's grid dimensions.
grid_viz = CanvasGrid(portrayal, model_params["width"], model_params["height"], 500, 500)
# NOTE(review): ChartModule looks each "Label" up among the model's
# DataCollector *model* reporters. In Mesa's implementation a label without a
# matching reporter is charted as 0; verify for the installed Mesa version.
chart_viz = ChartModule([
    {"Label": "Susceptible", "Color": "blue"},
    {"Label": "Infected",    "Color": "red"},
    {"Label": "Removed",     "Color": "green"}
])

# Launch server
server = ModularServer(
    InfectionModel,
    [grid_viz, chart_viz],
    "Interactive SIR Grid Model",
    model_params
)
server.port = 8521
# NOTE(review): launch() starts the server's event loop; expect this cell to
# block until the kernel is interrupted.
server.launch()


2025-05-14 03:31:56.650 /opt/miniconda3/envs/swordfish/lib/python3.9/site-packages/IPython/core/__init__.py modified; restarting server


In [None]:
# app.py
import time
import streamlit as st
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from enum import IntEnum
from mesa.agent import Agent
from mesa.model import Model
from mesa.time import RandomActivation
from mesa.space import MultiGrid
from mesa.datacollection import DataCollector
from bokeh.plotting import figure
from bokeh.models import ColumnDataSource, LinearColorMapper
from bokeh.palettes import Category10

# Must be the first Streamlit call in the script.
st.set_page_config(layout="wide", page_title="HIV ABM Dashboard")

# --- MODEL CLASSES ---------------------------------------------------------
class HIVStatus(IntEnum):
    """Disease compartment of a PersonAgent; codes 0-5 are also drawn on the grid plot."""
    SUSCEPTIBLE = 0
    INFECTED_UNTREATED = 1
    INFECTED_TREATED = 2
    AIDS_UNTREATED = 3
    AIDS_TREATED = 4
    DEAD = 5

class PersonAgent(Agent):
    """A person who can acquire, live with, and die from HIV.

    Attributes
    ----------
    hiv_status : HIVStatus
        Current disease compartment.
    infection_time : int or None
        ``model.schedule.time`` at infection; None while never infected.
    progression_time : int or None
        Effective days of infection after which the agent progresses to AIDS.
    knows_status : bool
        Set once the agent has been tested.
    on_treatment : bool
        Set when a tested agent starts treatment.
    """
    def __init__(self, unique_id, model):
        super().__init__(unique_id, model)
        self.hiv_status = HIVStatus.SUSCEPTIBLE
        self.infection_time = None
        self.progression_time = None
        self.knows_status = False
        self.on_treatment = False

    def move(self):
        """Move to a random neighbouring cell (Moore neighbourhood, centre excluded)."""
        steps = self.model.grid.get_neighborhood(self.pos, moore=True, include_center=False)
        new_pos = self.random.choice(steps)
        self.model.grid.move_agent(self, new_pos)

    def update_disease(self):
        """Apply one day of testing/treatment uptake, progression and mortality.

        Susceptible and dead agents are left unchanged.
        """
        if self.hiv_status in (HIVStatus.SUSCEPTIBLE, HIVStatus.DEAD): return
        # Testing (daily probability); a tested agent starts treatment with
        # probability treatment_coverage.
        if not self.knows_status and self.random.random() < self.model.testing_rate:
            self.knows_status = True
            if self.random.random() < self.model.treatment_coverage:
                self.on_treatment = True
                if self.hiv_status == HIVStatus.INFECTED_UNTREATED:
                    self.hiv_status = HIVStatus.INFECTED_TREATED
                elif self.hiv_status == HIVStatus.AIDS_UNTREATED:
                    self.hiv_status = HIVStatus.AIDS_TREATED
        # Progression to AIDS; treatment halves the effective elapsed time.
        if self.hiv_status in (HIVStatus.INFECTED_UNTREATED, HIVStatus.INFECTED_TREATED):
            elapsed = self.model.schedule.time - self.infection_time
            eff = elapsed/2 if self.on_treatment else elapsed
            if eff >= self.progression_time:
                self.hiv_status = HIVStatus.AIDS_TREATED if self.on_treatment else HIVStatus.AIDS_UNTREATED
        # Daily death probability by compartment (0 for any status not listed).
        drate = {
            HIVStatus.AIDS_UNTREATED: self.model.death_rate_untreated,
            HIVStatus.AIDS_TREATED:   self.model.death_rate_treated,
            HIVStatus.INFECTED_UNTREATED: self.model.death_rate_untreated/10
        }.get(self.hiv_status, 0)
        if self.random.random() < drate:
            self.hiv_status = HIVStatus.DEAD
            self.model.deaths += 1

    def interact(self):
        """Expose every susceptible cellmate to transmission from this (infected) agent."""
        if self.hiv_status not in (HIVStatus.INFECTED_UNTREATED, HIVStatus.INFECTED_TREATED,
                                   HIVStatus.AIDS_UNTREATED, HIVStatus.AIDS_TREATED): return
        mates = self.model.grid.get_cell_list_contents([self.pos])
        for other in mates:
            if other.hiv_status != HIVStatus.SUSCEPTIBLE: continue
            prob = self.model.base_transmission_rate
            if self.hiv_status == HIVStatus.AIDS_UNTREATED: prob *= 2.5
            elif self.hiv_status == HIVStatus.AIDS_TREATED: prob *= 0.5
            if self.on_treatment: prob *= (1 - self.model.treatment_effectiveness)
            # Condom use scales per-contact risk down proportionally. This is a
            # modelling assumption: a rate of 1.0 blocks transmission entirely.
            prob *= (1 - self.model.condom_use_rate)
            if self.random.random() < prob:
                other.hiv_status = HIVStatus.INFECTED_UNTREATED
                other.infection_time = self.model.schedule.time
                other.progression_time = self.model.get_progression_time()
                self.model.infections += 1

    def step(self):
        """Advance one day: update disease, then move and interact if still alive."""
        self.update_disease()
        # Dead agents remain on the grid but must not keep moving around.
        if self.hiv_status == HIVStatus.DEAD:
            return
        self.move()
        self.interact()

class HIVModel(Model):
    """Spatial HIV agent-based model on a toroidal ``MultiGrid``.

    One step is one day. The annual rates ``dr_unc``, ``dr_tr`` and
    ``test_rate`` are divided by 365 to give daily probabilities.
    ``init_inf`` randomly chosen agents start as INFECTED_UNTREATED.

    The DataCollector records, after every step, the compartment counts plus
    ``NewInf`` / ``Deaths``. These are the infections and deaths that occurred
    during that step.
    """
    def __init__(self, N, width, height, init_inf, base_trans, condom_use,
                 treat_cov, treat_eff, prog_time, prog_sd, dr_unc, dr_tr, test_rate):
        super().__init__()
        self.grid = MultiGrid(width, height, True)
        self.schedule = RandomActivation(self)
        self.base_transmission_rate = base_trans
        self.condom_use_rate = condom_use
        self.treatment_coverage = treat_cov
        self.treatment_effectiveness = treat_eff
        self.death_rate_untreated = dr_unc/365
        self.death_rate_treated = dr_tr/365
        self.testing_rate = test_rate/365
        self.prog_time = prog_time
        self.prog_sd = prog_sd
        # Per-step event counters, reset at the start of each step.
        self.infections = 0; self.deaths = 0
        for i in range(N):
            a = PersonAgent(i, self)
            self.schedule.add(a)
            x, y = self.random.randrange(width), self.random.randrange(height)
            self.grid.place_agent(a, (x, y))
        infected = self.random.sample(self.schedule.agents, init_inf)
        for a in infected:
            a.hiv_status = HIVStatus.INFECTED_UNTREATED
            a.infection_time = 0
            a.progression_time = self.get_progression_time()
        self.datacollector = DataCollector(
            model_reporters={
                'Step': lambda m: m.schedule.time,
                'Sus': lambda m: m.count_status(HIVStatus.SUSCEPTIBLE),
                'Inf_U': lambda m: m.count_status(HIVStatus.INFECTED_UNTREATED),
                'Inf_T': lambda m: m.count_status(HIVStatus.INFECTED_TREATED),
                'AIDS_U': lambda m: m.count_status(HIVStatus.AIDS_UNTREATED),
                'AIDS_T': lambda m: m.count_status(HIVStatus.AIDS_TREATED),
                'Dead': lambda m: m.count_status(HIVStatus.DEAD),
                'NewInf': lambda m: m.infections,
                'Deaths': lambda m: m.deaths
            }
        )

    def count_status(self, status):
        """Return how many scheduled agents currently have ``hiv_status == status``."""
        return sum(1 for ag in self.schedule.agents if ag.hiv_status == status)

    def get_progression_time(self):
        """Draw days-to-AIDS from N(prog_time, prog_sd), floored at 1.

        Uses the model's own RNG (``self.random``) so runs follow the model
        seed, unlike the global ``np.random`` used previously. The floor guards
        against zero or negative draws, which are likely for slider settings
        such as Prog Time 365 with Prog SD 2000 and would make an agent
        progress on its first update.
        """
        return max(1, int(self.random.normalvariate(self.prog_time, self.prog_sd)))

    def step(self):
        """Reset per-step counters, activate every agent once, then record."""
        self.infections = 0; self.deaths = 0
        self.schedule.step()
        self.datacollector.collect(self)

# --- STREAMLIT LAYOUT with Animation --------------------------------------
st.sidebar.title('Parameters')
cols = st.sidebar.columns(2)
params = {
    'N': cols[0].slider('Population', 100, 2000, 500),
    'init_inf': cols[0].slider('Initial Infected', 1, 50, 10),
    'base_trans': cols[0].slider('Transmission Prob.', 0.01, 1.0, 0.1, 0.01),
    'condom_use': cols[1].slider('Condom Use Rate', 0.0, 1.0, 0.5, 0.05),
    'treat_cov': cols[1].slider('Treatment Coverage', 0.0, 1.0, 0.2, 0.05),
    'treat_eff': cols[1].slider('Treatment Efficacy', 0.0, 1.0, 0.96, 0.01),
    'prog_time': cols[0].number_input('Prog Time (d)', 365, 3650, 3650),
    'prog_sd': cols[0].number_input('Prog SD (d)', 100, 2000, 1095),
    'dr_unc': cols[1].slider('Death Untr (ann)', 0.0, 1.0, 0.1, 0.01),
    'dr_tr': cols[1].slider('Death Tr (ann)', 0.0, 1.0, 0.02, 0.005),
    'test_rate': cols[0].slider('Testing Rate (ann)', 0.0, 1.0, 0.3, 0.01),
    'days': st.sidebar.slider('Days', 10, 1000, 100)
}
if st.sidebar.button('Run Animation'):
    days = params.pop('days')
    model = HIVModel(**params, width=20, height=20)

    ts_pl = st.empty()
    grid_pl = st.empty()

    # DataCollector reporter names drawn on the epidemic curve.
    series = ['Sus', 'Inf_U', 'Inf_T', 'AIDS_U', 'AIDS_T', 'Dead']

    for step in range(days):
        model.step()
        # The collector already keeps each reporter's history as a list.
        # Reading it directly avoids rebuilding a DataFrame of every previous
        # day (O(days^2) overall) and keeping a second copy of the data.
        history = model.datacollector.model_vars

        # Time-series plot with matplotlib
        fig1, ax1 = plt.subplots()
        for key in series:
            ax1.plot(history['Step'], history[key], label=key)
        ax1.set_xlabel('Day')
        ax1.set_ylabel('Count')
        ax1.legend()
        ax1.set_title('Epidemic Curve')
        ts_pl.pyplot(fig1)
        plt.close(fig1)

        # Grid plot with matplotlib imshow: each cell shows the status code of
        # the last agent listed there; empty cells are left as NaN.
        arr = np.full((model.grid.width, model.grid.height), np.nan)
        for cell, (x, y) in model.grid.coord_iter():
            if cell:
                arr[x, y] = cell[-1].hiv_status
        fig2, ax2 = plt.subplots()
        ax2.imshow(arr, vmin=0, vmax=5, cmap='tab20')
        ax2.set_title('Spatial Distribution')
        ax2.axis('off')
        grid_pl.pyplot(fig2)
        plt.close(fig2)
        # tiny pause
        time.sleep(0.05)


In [None]:
# HOLD ON TO THIS 

# app.py
import time
import streamlit as st
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from enum import IntEnum
from mesa.agent import Agent
from mesa.model import Model
from mesa.time import RandomActivation
from mesa.space import MultiGrid
from mesa.datacollection import DataCollector

# --- PAGE CONFIG & STYLES -------------------------------------------------
_APP_TITLE = "HIV MODEL - AGENT BASED SIMULATION"
_PAGE_CSS = """
<style>
.main .block-container { padding: 1rem 2rem; }
</style>
"""
_INTRO_MD = """
**Objective:** Simulate how HIV spreads and evolves in a population over time, and evaluate the impact of interventions (e.g., increased treatment coverage, condom use) on HIV prevalence, incidence, and mortality.

**Aims:**
1. Model individual-level interactions and disease progression.  
2. Compare epidemic trajectories under different prevention and treatment scenarios.  
3. Provide an interactive dashboard to explore parameter influence on outcomes.
"""
_EXPLAINER_MD = """
- We observe a phenomenon (HIV transmission patterns).  
- We hypothesize key drivers (treatment, condom use, testing).  
- We build an approximate representation (agent-based model).  
- If the model replicates real-world behavior, our hypotheses gain support.
"""

# Page configuration must be the first Streamlit call; the CSS tightens padding.
st.set_page_config(layout="wide", page_title=_APP_TITLE)
st.markdown(_PAGE_CSS, unsafe_allow_html=True)

# --- HEADER & DESCRIPTION --------------------------------------------------
st.title(_APP_TITLE)
st.markdown(_INTRO_MD)

# Explanatory Modeling: collapsible primer on the modelling approach.
with st.expander("What is Explanatory Modeling?"):
    st.markdown(_EXPLAINER_MD)

# --- MODEL CLASSES ---------------------------------------------------------
class HIVStatus(IntEnum):
    """Disease compartment of a PersonAgent; codes 0-5 are also drawn on the grid plot."""
    SUSCEPTIBLE = 0
    INFECTED_UNTREATED = 1
    INFECTED_TREATED = 2
    AIDS_UNTREATED = 3
    AIDS_TREATED = 4
    DEAD = 5

class PersonAgent(Agent):
    """A person who can acquire, live with, and die from HIV.

    Attributes
    ----------
    hiv_status : HIVStatus
        Current disease compartment.
    infection_time : int or None
        ``model.schedule.time`` at infection; None while never infected.
    progression_time : int or None
        Effective days of infection after which the agent progresses to AIDS.
    knows_status : bool
        Set once the agent has been tested.
    on_treatment : bool
        Set when a tested agent starts treatment.
    """
    def __init__(self, unique_id, model):
        super().__init__(unique_id, model)
        self.hiv_status = HIVStatus.SUSCEPTIBLE
        self.infection_time = None
        self.progression_time = None
        self.knows_status = False
        self.on_treatment = False

    def move(self):
        """Move to a random neighbouring cell (Moore neighbourhood, centre excluded)."""
        neighbors = self.model.grid.get_neighborhood(self.pos, moore=True, include_center=False)
        self.model.grid.move_agent(self, self.random.choice(neighbors))

    def update_disease(self):
        """Apply one day of testing/treatment uptake, progression and mortality.

        Susceptible and dead agents are left unchanged.
        """
        if self.hiv_status in (HIVStatus.SUSCEPTIBLE, HIVStatus.DEAD): return
        # Testing & treatment
        if not self.knows_status and self.random.random() < self.model.testing_rate:
            self.knows_status = True
            if self.random.random() < self.model.treatment_coverage:
                self.on_treatment = True
                if self.hiv_status == HIVStatus.INFECTED_UNTREATED:
                    self.hiv_status = HIVStatus.INFECTED_TREATED
                elif self.hiv_status == HIVStatus.AIDS_UNTREATED:
                    self.hiv_status = HIVStatus.AIDS_TREATED
        # Progression (treatment halves the effective elapsed time)
        if self.hiv_status in (HIVStatus.INFECTED_UNTREATED, HIVStatus.INFECTED_TREATED):
            elapsed = self.model.schedule.time - self.infection_time
            eff = elapsed/2 if self.on_treatment else elapsed
            if eff >= self.progression_time:
                self.hiv_status = HIVStatus.AIDS_TREATED if self.on_treatment else HIVStatus.AIDS_UNTREATED
        # Death (daily probability; 0 for statuses not listed)
        drate = {
            HIVStatus.AIDS_UNTREATED: self.model.death_rate_untreated,
            HIVStatus.AIDS_TREATED:   self.model.death_rate_treated,
            HIVStatus.INFECTED_UNTREATED: self.model.death_rate_untreated/10
        }.get(self.hiv_status, 0)
        if self.random.random() < drate:
            self.hiv_status = HIVStatus.DEAD
            self.model.deaths += 1

    def interact(self):
        """Expose every susceptible cellmate to transmission from this (infected) agent."""
        if self.hiv_status not in (
            HIVStatus.INFECTED_UNTREATED, HIVStatus.INFECTED_TREATED,
            HIVStatus.AIDS_UNTREATED, HIVStatus.AIDS_TREATED): return
        mates = self.model.grid.get_cell_list_contents([self.pos])
        for other in mates:
            if other.hiv_status != HIVStatus.SUSCEPTIBLE: continue
            prob = self.model.base_transmission_rate
            if self.hiv_status == HIVStatus.AIDS_UNTREATED: prob *= 2.5
            elif self.hiv_status == HIVStatus.AIDS_TREATED: prob *= 0.5
            if self.on_treatment: prob *= (1 - self.model.treatment_effectiveness)
            # Condom use scales per-contact risk down proportionally. This is a
            # modelling assumption: a rate of 1.0 blocks transmission entirely.
            prob *= (1 - self.model.condom_use_rate)
            if self.random.random() < prob:
                other.hiv_status = HIVStatus.INFECTED_UNTREATED
                other.infection_time = self.model.schedule.time
                other.progression_time = self.model.get_progression_time()
                self.model.infections += 1

    def step(self):
        """Advance one day: update disease, then move and interact if still alive."""
        self.update_disease()
        # Dead agents remain on the grid but must not keep moving around.
        if self.hiv_status == HIVStatus.DEAD:
            return
        self.move()
        self.interact()

class HIVModel(Model):
    """Spatial HIV agent-based model on a toroidal ``MultiGrid``.

    One step is one day. The annual rates ``dr_unc``, ``dr_tr`` and
    ``test_rate`` are divided by 365 to give daily probabilities.
    ``init_inf`` randomly chosen agents start as INFECTED_UNTREATED.

    The DataCollector records, after every step, the compartment counts plus
    ``NewInf`` / ``Deaths``. These are the infections and deaths that occurred
    during that step.
    """
    def __init__(self, N, width, height, init_inf,
                 base_trans, condom_use, treat_cov, treat_eff,
                 prog_time, prog_sd, dr_unc, dr_tr, test_rate):
        super().__init__()
        self.grid = MultiGrid(width, height, True)
        self.schedule = RandomActivation(self)
        # Parameters
        self.base_transmission_rate = base_trans
        self.condom_use_rate = condom_use
        self.treatment_coverage = treat_cov
        self.treatment_effectiveness = treat_eff
        self.death_rate_untreated = dr_unc/365
        self.death_rate_treated = dr_tr/365
        self.testing_rate = test_rate/365
        self.prog_time = prog_time
        self.prog_sd = prog_sd
        # Per-step event counters, reset at the start of each step.
        self.infections = 0
        self.deaths = 0
        # Initialize agents
        for i in range(N):
            agent = PersonAgent(i, self)
            self.schedule.add(agent)
            x, y = self.random.randrange(width), self.random.randrange(height)
            self.grid.place_agent(agent, (x, y))
        # Seed infections
        initially = self.random.sample(self.schedule.agents, init_inf)
        for a in initially:
            a.hiv_status = HIVStatus.INFECTED_UNTREATED
            a.infection_time = 0
            a.progression_time = self.get_progression_time()
        # Data collector
        self.datacollector = DataCollector(
            model_reporters={
                'Step': lambda m: m.schedule.time,
                'Sus': lambda m: m.count_status(HIVStatus.SUSCEPTIBLE),
                'Inf_U': lambda m: m.count_status(HIVStatus.INFECTED_UNTREATED),
                'Inf_T': lambda m: m.count_status(HIVStatus.INFECTED_TREATED),
                'AIDS_U': lambda m: m.count_status(HIVStatus.AIDS_UNTREATED),
                'AIDS_T': lambda m: m.count_status(HIVStatus.AIDS_TREATED),
                'Dead': lambda m: m.count_status(HIVStatus.DEAD),
                'NewInf': lambda m: m.infections,
                'Deaths': lambda m: m.deaths
            }
        )

    def count_status(self, status):
        """Return how many scheduled agents currently have ``hiv_status == status``."""
        return sum(a.hiv_status == status for a in self.schedule.agents)

    def get_progression_time(self):
        """Draw days-to-AIDS from N(prog_time, prog_sd), floored at 1.

        Uses the model's own RNG (``self.random``) so runs follow the model
        seed, unlike the global ``np.random`` used previously. The floor guards
        against zero or negative draws, which are likely for slider settings
        such as Prog. Time 365 with Prog. SD 2000.
        """
        return max(1, int(self.random.normalvariate(self.prog_time, self.prog_sd)))

    def step(self):
        """Reset per-step counters, activate every agent once, then record."""
        self.infections = 0
        self.deaths = 0
        self.schedule.step()
        self.datacollector.collect(self)

# --- STREAMLIT SIDEBAR -----------------------------------------------------
st.sidebar.header('Simulation Parameters')
cols = st.sidebar.columns(2)
params = {
    'N': cols[0].slider('Population', 100, 2000, 500, help='Total agents'),
    'init_inf': cols[0].slider('Initial Infected', 1, 50, 10, help='Seed infections'),
    'base_trans': cols[0].slider('Transmission Prob.', 0.01, 1.0, 0.1, 0.01, help='Per-contact risk'),
    'condom_use': cols[1].slider('Condom Use Rate', 0.0, 1.0, 0.5, 0.05, help='Reduces transmission'),
    # Help texts below describe what the model actually does with each value:
    # coverage is the chance a tested agent starts treatment, and efficacy
    # scales down transmission from treated agents.
    'treat_cov': cols[1].slider('Treatment Coverage', 0.0, 1.0, 0.2, 0.05,
                                help='Probability a tested agent starts treatment'),
    'treat_eff': cols[1].slider('Treatment Efficacy', 0.0, 1.0, 0.96, 0.01,
                                help='Reduces transmission from treated agents'),
    'prog_time': cols[0].number_input('Prog. Time (d)', 365, 3650, 3650, help='Untreated→AIDS'),
    'prog_sd': cols[0].number_input('Prog. SD (d)', 100, 2000, 1095, help='Variation'),
    'dr_unc': cols[1].slider('Death Rate Untreated', 0.0, 1.0, 0.1, 0.01, help='Annual untreated'),
    'dr_tr': cols[1].slider('Death Rate Treated', 0.0, 1.0, 0.02, 0.005, help='Annual treated'),
    'test_rate': cols[1].slider('Testing Rate', 0.0, 1.0, 0.3, 0.01, help='Annual probability'),
    'days': st.sidebar.slider('Days', 10, 1000, 100, help='Simulation length')
}
# The stop flag persists across Streamlit reruns in session_state.
if 'stop' not in st.session_state:
    st.session_state.stop = False

# 'Run Animation' only clears the stop flag; the simulation itself is started
# by the separate 'Start' button.
col_run, col_stop = st.sidebar.columns(2)
if col_run.button('Run Animation'):
    st.session_state.stop = False
if col_stop.button('Stop Simulation'):
    st.session_state.stop = True

# --- MAIN CONTENT AREA -----------------------------------------------------
ts_placeholder = st.empty()
gr_placeholder = st.empty()

if st.sidebar.button('Start'):
    # Start always begins a fresh run. Without this reset, a Stop pressed on an
    # earlier rerun stays True in session_state, and the loop below would
    # break on its very first iteration.
    st.session_state.stop = False
    days = params.pop('days')
    model = HIVModel(**params, width=20, height=20)
    # DataCollector reporter names drawn on the epidemic curve.
    series = ['Sus', 'Inf_U', 'Inf_T', 'AIDS_U', 'AIDS_T', 'Dead']
    for day in range(days):
        if st.session_state.stop:
            break
        model.step()
        # Read the collector's per-reporter history lists directly instead of
        # rebuilding a DataFrame of all previous days on every frame.
        history = model.datacollector.model_vars
        # Plot size reduced
        fig1, ax1 = plt.subplots(figsize=(4, 3))
        for k in series:
            ax1.plot(history['Step'], history[k], label=k)
        ax1.legend(fontsize='small')
        ts_placeholder.pyplot(fig1)
        plt.close(fig1)
        # Grid plot small: last agent's status code per cell; empty cells stay NaN.
        arr = np.full((model.grid.width, model.grid.height), np.nan)
        for cell, (x, y) in model.grid.coord_iter():
            if cell:
                arr[x, y] = cell[-1].hiv_status
        fig2, ax2 = plt.subplots(figsize=(4, 3))
        ax2.imshow(arr, vmin=0, vmax=5, cmap='tab20')
        ax2.axis('off')
        gr_placeholder.pyplot(fig2)
        plt.close(fig2)
        time.sleep(0.05)  # controls speed

# Add acknowledgement in side panel
with st.sidebar:
    # Divider followed by a short credits footer.
    st.markdown('---')
    st.markdown('**HIV ABM Modeling**')
    st.caption('Developed using Mesa & Streamlit')
# Final Tweaks

In [None]:
# app.py
import time
import streamlit as st
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from enum import IntEnum
from mesa.agent import Agent
from mesa.model import Model
from mesa.time import RandomActivation
from mesa.space import MultiGrid
from mesa.datacollection import DataCollector

# PARAMETER DEFINITIONS:
# N: Total number of agents (people) in the simulation
# init_inf: Initial number of infected agents at the start
# base_trans: Base probability of HIV transmission per contact
# condom_use: Reduces base transmission rate (as a proportion)
# treat_cov: Proportion of tested individuals who receive treatment
# treat_eff: Effectiveness of treatment in reducing onward HIV transmission (treated agents separately progress at half speed)
# prog_time: Average time (in days) from infection to AIDS without treatment
# prog_sd: Standard deviation in progression time across agents
# dr_unc: Annual death rate for untreated AIDS cases
# dr_tr: Annual death rate for treated AIDS cases
# test_rate: Probability per year that an individual gets tested
# days: Number of steps/days the simulation runs

# --- PAGE CONFIG & STYLES -------------------------------------------------
# set_page_config has to be the first Streamlit call of the script.
st.set_page_config(layout="wide", page_title="HIV MODEL - AGENT BASED SIMULATION")
# Tighten the default padding around the main content container.
st.markdown(
    """
<style>
.main .block-container { padding: 1rem 2rem; }
</style>
""",
    unsafe_allow_html=True,
)

# --- HEADER & DESCRIPTION --------------------------------------------------
st.title("HIV MODEL - AGENT BASED SIMULATION")
st.markdown(
    """
**Objective:** Simulate how HIV spreads and evolves in a population over time, and evaluate the impact of interventions (e.g., increased treatment coverage, condom use) on HIV prevalence, incidence, and mortality.

**Aims:**
1. Model individual-level interactions and disease progression.  
2. Compare epidemic trajectories under different prevention and treatment scenarios.  
3. Provide an interactive dashboard to explore parameter influence on outcomes.
"""
)

# Explanatory Modeling (collapsed by default)
with st.expander("What is Explanatory Modeling?"):
    st.markdown(
        """
- We observe a phenomenon (HIV transmission patterns).  
- We hypothesize key drivers (treatment, condom use, testing).  
- We build an approximate representation (agent-based model).  
- If the model replicates real-world behavior, our hypotheses gain support.
"""
    )

# --- MODEL CLASSES ---------------------------------------------------------
class HIVStatus(IntEnum):
    """Disease/treatment state of a PersonAgent.

    Integer values are used directly as colour indices in the grid plot
    (``vmin=0, vmax=5``), so they must stay contiguous from 0 to 5.
    """
    SUSCEPTIBLE = 0          # never infected
    INFECTED_UNTREATED = 1   # HIV+, not on treatment
    INFECTED_TREATED = 2     # HIV+, on treatment
    AIDS_UNTREATED = 3       # progressed to AIDS, not on treatment
    AIDS_TREATED = 4         # progressed to AIDS, on treatment
    DEAD = 5                 # absorbing state

class PersonAgent(Agent):
    """One person on the grid, with an HIV status and a treatment state.

    Each step the agent updates its disease state (testing, treatment uptake,
    progression, death). If it is still alive, it then moves to a random
    neighbouring cell and may infect susceptible agents sharing that cell.
    """

    def __init__(self, unique_id, model):
        super().__init__(unique_id, model)
        self.hiv_status = HIVStatus.SUSCEPTIBLE
        self.infection_time = None     # schedule.time at which the agent was infected
        self.progression_time = None   # drawn time until AIDS (model.get_progression_time)
        self.knows_status = False      # True once the agent has been tested
        self.on_treatment = False

    def move(self):
        """Move to a random Moore neighbour (the current cell is excluded)."""
        neighbors = self.model.grid.get_neighborhood(self.pos, moore=True, include_center=False)
        self.model.grid.move_agent(self, self.random.choice(neighbors))

    def update_disease(self):
        """Advance this agent's infection by one step.

        The update runs in three stages:

        1. An untested infected agent gets tested with per-step probability
           ``model.testing_rate``. If tested, it starts treatment with
           probability ``model.treatment_coverage``. This is a single chance:
           an agent that was tested but not treated is never offered
           treatment again.
        2. A non-AIDS infection becomes AIDS once the time since infection
           reaches ``progression_time``. Treated agents progress at half speed.
        3. The agent may die with a status-dependent per-step probability,
           which increments ``model.deaths``.

        Susceptible and dead agents are left unchanged.
        """
        if self.hiv_status in (HIVStatus.SUSCEPTIBLE, HIVStatus.DEAD):
            return
        # Testing & treatment
        if not self.knows_status and self.random.random() < self.model.testing_rate:
            self.knows_status = True
            if self.random.random() < self.model.treatment_coverage:
                self.on_treatment = True
                if self.hiv_status == HIVStatus.INFECTED_UNTREATED:
                    self.hiv_status = HIVStatus.INFECTED_TREATED
                elif self.hiv_status == HIVStatus.AIDS_UNTREATED:
                    self.hiv_status = HIVStatus.AIDS_TREATED
        # Progression: treatment halves the effective elapsed time.
        # NOTE(review): this factor is fixed and does not use treatment_effectiveness.
        if self.hiv_status in (HIVStatus.INFECTED_UNTREATED, HIVStatus.INFECTED_TREATED):
            elapsed = self.model.schedule.time - self.infection_time
            eff = elapsed / 2 if self.on_treatment else elapsed
            if eff >= self.progression_time:
                self.hiv_status = HIVStatus.AIDS_TREATED if self.on_treatment else HIVStatus.AIDS_UNTREATED
        # Death: per-step probabilities. Statuses not listed (INFECTED_TREATED) never die.
        drate = {
            HIVStatus.AIDS_UNTREATED: self.model.death_rate_untreated,
            HIVStatus.AIDS_TREATED:   self.model.death_rate_treated,
            HIVStatus.INFECTED_UNTREATED: self.model.death_rate_untreated/10
        }.get(self.hiv_status, 0)
        if self.random.random() < drate:
            self.hiv_status = HIVStatus.DEAD
            self.model.deaths += 1

    def interact(self):
        """Try to infect every susceptible agent sharing this agent's cell.

        The per-contact probability starts from ``model.base_transmission_rate``.
        It is reduced proportionally by ``model.condom_use_rate``, then
        multiplied by 2.5 for untreated AIDS or 0.5 for treated AIDS. Finally it
        is scaled by ``1 - treatment_effectiveness`` when this agent is on
        treatment.
        """
        if self.hiv_status not in (
                HIVStatus.INFECTED_UNTREATED, HIVStatus.INFECTED_TREATED,
                HIVStatus.AIDS_UNTREATED, HIVStatus.AIDS_TREATED):
            return
        mates = self.model.grid.get_cell_list_contents([self.pos])
        for other in mates:
            # This check also skips self, because this agent is infected.
            if other.hiv_status != HIVStatus.SUSCEPTIBLE:
                continue
            # Condom use proportionally lowers per-contact risk; without this
            # factor the condom_use parameter would have no effect.
            prob = self.model.base_transmission_rate * (1 - self.model.condom_use_rate)
            if self.hiv_status == HIVStatus.AIDS_UNTREATED:
                prob *= 2.5
            elif self.hiv_status == HIVStatus.AIDS_TREATED:
                prob *= 0.5
            if self.on_treatment:
                prob *= (1 - self.model.treatment_effectiveness)
            if self.random.random() < prob:
                other.hiv_status = HIVStatus.INFECTED_UNTREATED
                other.infection_time = self.model.schedule.time
                other.progression_time = self.model.get_progression_time()
                self.model.infections += 1

    def step(self):
        """Advance one day: update disease state, then move and interact if still alive."""
        self.update_disease()
        if self.hiv_status == HIVStatus.DEAD:
            # Dead agents stay on the grid (the 'Dead' reporter counts them)
            # but no longer wander around.
            return
        self.move()
        self.interact()

class HIVModel(Model):
    """HIV epidemic on a toroidal ``MultiGrid`` with random agent activation.

    One step is treated as one day: the annual rates ``dr_unc``, ``dr_tr`` and
    ``test_rate`` are divided by 365 to give per-step probabilities.

    Args:
        N: number of agents.
        width, height: grid dimensions.
        init_inf: number of agents seeded as INFECTED_UNTREATED, clamped to
            the range [0, N].
        base_trans: per-contact transmission probability.
        condom_use: stored as ``condom_use_rate`` for the agents to read.
        treat_cov: probability that a newly tested agent starts treatment.
        treat_eff: stored as ``treatment_effectiveness`` for the agents.
        prog_time, prog_sd: mean and SD (days) of the infection-to-AIDS time.
        dr_unc, dr_tr: annual death rates for untreated / treated AIDS.
        test_rate: annual testing probability.
    """
    def __init__(self, N, width, height, init_inf,
                 base_trans, condom_use, treat_cov, treat_eff,
                 prog_time, prog_sd, dr_unc, dr_tr, test_rate):
        super().__init__()
        self.grid = MultiGrid(width, height, True)  # torus=True: edges wrap around
        self.schedule = RandomActivation(self)
        # Parameters
        self.base_transmission_rate = base_trans
        self.condom_use_rate = condom_use
        self.treatment_coverage = treat_cov
        self.treatment_effectiveness = treat_eff
        self.death_rate_untreated = dr_unc/365
        self.death_rate_treated = dr_tr/365
        self.testing_rate = test_rate/365
        self.prog_time = prog_time
        self.prog_sd = prog_sd
        # Per-step counters, reset at the start of every step()
        self.infections = 0
        self.deaths = 0
        # Initialize agents at uniformly random cells
        for i in range(N):
            agent = PersonAgent(i, self)
            self.schedule.add(agent)
            x, y = self.random.randrange(width), self.random.randrange(height)
            self.grid.place_agent(agent, (x, y))
        # Seed infections. Clamp the count so random.sample cannot raise when
        # init_inf is negative or exceeds the population.
        n_seed = max(0, min(init_inf, N))
        initially = self.random.sample(self.schedule.agents, n_seed)
        for a in initially:
            a.hiv_status = HIVStatus.INFECTED_UNTREATED
            a.infection_time = 0
            a.progression_time = self.get_progression_time()
        # Data collector: status counts plus the per-step counters above
        self.datacollector = DataCollector(
            model_reporters={
                'Step': lambda m: m.schedule.time,
                'Sus': lambda m: sum(a.hiv_status==HIVStatus.SUSCEPTIBLE for a in m.schedule.agents),
                'Inf_U': lambda m: sum(a.hiv_status==HIVStatus.INFECTED_UNTREATED for a in m.schedule.agents),
                'Inf_T': lambda m: sum(a.hiv_status==HIVStatus.INFECTED_TREATED for a in m.schedule.agents),
                'AIDS_U': lambda m: sum(a.hiv_status==HIVStatus.AIDS_UNTREATED for a in m.schedule.agents),
                'AIDS_T': lambda m: sum(a.hiv_status==HIVStatus.AIDS_TREATED for a in m.schedule.agents),
                'Dead': lambda m: sum(a.hiv_status==HIVStatus.DEAD for a in m.schedule.agents),
                'NewInf': lambda m: m.infections,
                'Deaths': lambda m: m.deaths
            }
        )

    def get_progression_time(self):
        """Draw an untreated infection-to-AIDS time, in days (steps).

        Uses the model's own ``self.random``, the same generator the agents
        draw from, rather than NumPy's global RNG. That keeps a run governed
        by a single random stream. A normal draw can be negative when
        ``prog_sd`` is large relative to ``prog_time``, so the result is
        floored at 0. Any value <= 0 already triggered AIDS at the next
        update, so this floor does not change progression timing.
        """
        return max(0, int(self.random.normalvariate(self.prog_time, self.prog_sd)))

    def step(self):
        """Reset per-step counters, activate all agents once, then record data."""
        self.infections = 0
        self.deaths = 0
        self.schedule.step()
        self.datacollector.collect(self)

# --- STREAMLIT SIDEBAR -----------------------------------------------------
st.sidebar.header('Simulation Parameters')
cols = st.sidebar.columns(2)
# Every key except 'days' is passed straight through as an HIVModel keyword argument.
params = {
    'N': cols[0].slider('Population', 100, 2000, 500, help='Total agents'),
    'init_inf': cols[0].slider('Initial Infected', 1, 50, 10, help='Seed infections'),
    'base_trans': cols[0].slider('Transmission Prob.', 0.01, 1.0, 0.1, 0.01, help='Per-contact risk'),
    'condom_use': cols[1].slider('Condom Use Rate', 0.0, 1.0, 0.5, 0.05, help='Reduces transmission'),
    'treat_cov': cols[1].slider('Treatment Coverage', 0.0, 1.0, 0.2, 0.05, help='Proportion tested and treated'),
    'treat_eff': cols[1].slider('Treatment Efficacy', 0.0, 1.0, 0.96, 0.01, help='Reduces progression'),
    'prog_time': cols[0].number_input('Prog. Time (d)', 365, 3650, 3650, help='Untreated → AIDS'),
    'prog_sd': cols[0].number_input('Prog. SD (d)', 100, 2000, 1095, help='Variation in progression'),
    'dr_unc': cols[1].slider('Death Rate Untreated', 0.0, 1.0, 0.1, 0.01, help='Annual untreated'),
    'dr_tr': cols[1].slider('Death Rate Treated', 0.0, 1.0, 0.02, 0.005, help='Annual treated'),
    'test_rate': cols[1].slider('Testing Rate', 0.0, 1.0, 0.3, 0.01, help='Annual probability'),
    'days': st.sidebar.slider('Days', 10, 1000, 100, help='Simulation length')
}
# The stop flag must persist across Streamlit reruns, so it lives in session_state.
if 'stop' not in st.session_state:
    st.session_state.stop = False

col_run, col_stop = st.sidebar.columns(2)
# Use explicit `if` statements for these side effects rather than
# `button(...) and st.session_state.__setitem__(...)`.
if col_run.button('Start'):
    st.session_state.stop = False
if col_stop.button('Stop'):
    st.session_state.stop = True

# --- MAIN CONTENT AREA with Animation -------------------------------------
ts_placeholder = st.empty()
gr_placeholder = st.empty()

# Animation loop
if st.sidebar.button('Run'):
    # A new run must clear any stop request left over from an earlier rerun.
    # Otherwise the loop below would exit on its very first iteration.
    st.session_state.stop = False
    days = params['days']
    # Every other sidebar entry is an HIVModel keyword argument. Build a
    # filtered copy instead of popping 'days' out of ``params`` in place.
    model_kwargs = {k: v for k, v in params.items() if k != 'days'}
    model = HIVModel(**model_kwargs, width=20, height=20)
    series = ['Sus', 'Inf_U', 'Inf_T', 'AIDS_U', 'AIDS_T', 'Dead']
    for _ in range(days):
        if st.session_state.stop:
            break
        model.step()
        # The DataCollector already keeps one list per reporter. Plotting those
        # lists directly avoids rebuilding a DataFrame of the whole history
        # on every frame, which is quadratic in the number of days.
        history = model.datacollector.model_vars
        # Smaller plots
        fig1, ax1 = plt.subplots(figsize=(3, 2))
        for k in series:
            ax1.plot(history['Step'], history[k], label=k)
        ax1.legend(fontsize='x-small')
        ts_placeholder.pyplot(fig1)
        plt.close(fig1)

        # Grid snapshot: colour each occupied cell by the status of the last
        # agent listed in it. imshow indexes [row, col] == [y, x], so fill the
        # array that way and draw with the origin at the bottom-left.
        fig2, ax2 = plt.subplots(figsize=(3, 2))
        w, h = model.grid.width, model.grid.height
        arr = np.full((h, w), np.nan)
        for cell, (x, y) in model.grid.coord_iter():
            if cell:
                arr[y, x] = cell[-1].hiv_status
        ax2.imshow(arr, vmin=0, vmax=5, cmap='tab20', origin='lower')
        ax2.axis('off')
        gr_placeholder.pyplot(fig2)
        plt.close(fig2)
        time.sleep(0.05)  # controls animation speed

# Sidebar branding, rendered inside the sidebar container.
with st.sidebar:
    st.markdown('---')
    st.markdown('**HIV ABM Modeling**')
    st.caption('Powered by Mesa & Streamlit')


## Disadvantages of This Agent-Based HIV Model

### Computational Load

More computationally intensive than traditional compartmental models (e.g., SIR), because every individual and every interaction is simulated explicitly.

### Parameter Sensitivity

Results are highly dependent on assumptions about transmission, treatment, and behavior.

### Scalability Issues

Difficult to scale for large populations (e.g., 100,000+) in real-time dashboards.

### Requires More Data

Accurate calibration needs rich empirical data, which may not always be available.<br><br>

## Difference in Assumptions: SIR vs. ABM

### Assumptions
**SIR Model**: Homogeneous mixing<br> 
**Agent-Based Model**: Heterogeneous agents & stochastic interactions <br>

### Population
**SIR Model**: Compartmentalized (S, I, R)<br> 
**Agent-Based Model**: Individuals with attributes (age, status, memory)<br>

### Transmission
**SIR Model**: Rate-based<br> 
**Agent-Based Model**: Behavior + local proximity-driven<br>

### Flexibility
**SIR Model**: Low<br> 
**Agent-Based Model**: High (can include social networks, movement, etc.)<br>

### Realism
**SIR Model**: Less realistic<br> 
**Agent-Based Model**: More realistic, complex<br>

## What Could Be Improved In This Project
**Add Real World Data Inputs** - Include data-driven parameters from UNAIDS or CDC<br>
**Export Features** - Enable export to CSV or PDF for reporting<br>
**Sensitivity Analysis Panel** - Allow the user to run multiple scenarios and compare outcomes.<br>
**Live Heatmaps or Network Visuals** - Show evolving spatial spread or social networks for agents<br>
**Parallelization** - Use multiprocessing or Mesa's `batch_run` utility to run many simulations at once<br>

## Questions The Class/Lecturer Might Ask
**Why use an Agent-Based Model over a simpler model like SIR?**<br> ABMs simulate individuals rather than aggregated populations. This allows us to incorporate heterogeneity in behavior, treatment access, and network structure - key aspects of HIV spread that SIR models oversimplify.<br>

**How were the parameter values selected?**<br> We chose baseline values based on epidemiological estimates from the literature (e.g., 96% treatment efficacy). The dashboard allows real-time sensitivity analysis by adjusting sliders.<br>

**What are the main insights this model can provide?**<br> It demonstrates how intervention strategies (condom use, testing, treatment) impact HIV prevalence and mortality - helping policy makers understand dynamic effects over time.<br>

**What are the limitations of this model?**<br> The model simplifies sexual behavior patterns and does not currently include reinfections, age structure, or vertical transmission. It also assumes consistent agent behavior across time.<br>

**Can this be extended to model co-infection like TB-HIV?**<br> Yes - ABMs are flexible. Adding co-infection would involve defining additional state and interaction rules between infections.<br>


In [None]:
import time
import streamlit as st
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from enum import IntEnum
from mesa.agent import Agent
from mesa.model import Model
from mesa.time import RandomActivation
from mesa.space import MultiGrid
from mesa.datacollection import DataCollector

# PARAMETER DEFINITIONS:
# N: Total number of agents (people) in the simulation
# init_inf: Initial number of infected agents at the start
# base_trans: Base probability of HIV transmission per contact
# condom_use: Reduces base transmission rate (as a proportion)
# treat_cov: Proportion of tested individuals who receive treatment
# treat_eff: Effectiveness of treatment in reducing onward HIV transmission (treated agents separately progress at half speed)
# prog_time: Average time (in days) from infection to AIDS without treatment
# prog_sd: Standard deviation in progression time across agents
# dr_unc: Annual death rate for untreated AIDS cases
# dr_tr: Annual death rate for treated AIDS cases
# test_rate: Probability per year that an individual gets tested
# days: Number of steps/days the simulation runs

# --- PAGE CONFIG & STYLES -------------------------------------------------
# set_page_config has to be the first Streamlit call of the script.
st.set_page_config(layout="wide", page_title="HIV MODEL - AGENT BASED SIMULATION")
# Tighten the default padding around the main content container.
st.markdown(
    """
<style>
.main .block-container { padding: 1rem 2rem; }
</style>
""",
    unsafe_allow_html=True,
)

# --- HEADER & DESCRIPTION --------------------------------------------------
st.title("HIV MODEL - AGENT BASED SIMULATION")
st.markdown(
    """
**Objective:** Simulate how HIV spreads and evolves in a population over time, and evaluate the impact of interventions (e.g., increased treatment coverage, condom use) on HIV prevalence, incidence, and mortality.

**Aims:**
1. Model individual-level interactions and disease progression.  
2. Compare epidemic trajectories under different prevention and treatment scenarios.  
3. Provide an interactive dashboard to explore parameter influence on outcomes.
"""
)

# Explanatory Modeling (collapsed by default)
with st.expander("What is Explanatory Modeling?"):
    st.markdown(
        """
- We observe a phenomenon (HIV transmission patterns).  
- We hypothesize key drivers (treatment, condom use, testing).  
- We build an approximate representation (agent-based model).  
- If the model replicates real-world behavior, our hypotheses gain support.
"""
    )

# --- MODEL CLASSES ---------------------------------------------------------
class HIVStatus(IntEnum):
    """Disease/treatment state of a PersonAgent.

    Integer values are used directly as colour indices in the grid plot
    (``vmin=0, vmax=5``), so they must stay contiguous from 0 to 5.
    """
    SUSCEPTIBLE = 0          # never infected
    INFECTED_UNTREATED = 1   # HIV+, not on treatment
    INFECTED_TREATED = 2     # HIV+, on treatment
    AIDS_UNTREATED = 3       # progressed to AIDS, not on treatment
    AIDS_TREATED = 4         # progressed to AIDS, on treatment
    DEAD = 5                 # absorbing state

class PersonAgent(Agent):
    """One person on the grid, with an HIV status and a treatment state.

    Each step the agent updates its disease state (testing, treatment uptake,
    progression, death). If it is still alive, it then moves to a random
    neighbouring cell and may infect susceptible agents sharing that cell.
    """

    def __init__(self, unique_id, model):
        super().__init__(unique_id, model)
        self.hiv_status = HIVStatus.SUSCEPTIBLE
        self.infection_time = None     # schedule.time at which the agent was infected
        self.progression_time = None   # drawn time until AIDS (model.get_progression_time)
        self.knows_status = False      # True once the agent has been tested
        self.on_treatment = False

    def move(self):
        """Move to a random Moore neighbour (the current cell is excluded)."""
        neighbors = self.model.grid.get_neighborhood(self.pos, moore=True, include_center=False)
        self.model.grid.move_agent(self, self.random.choice(neighbors))

    def update_disease(self):
        """Advance this agent's infection by one step.

        The update runs in three stages:

        1. An untested infected agent gets tested with per-step probability
           ``model.testing_rate``. If tested, it starts treatment with
           probability ``model.treatment_coverage``. This is a single chance:
           an agent that was tested but not treated is never offered
           treatment again.
        2. A non-AIDS infection becomes AIDS once the time since infection
           reaches ``progression_time``. Treated agents progress at half speed.
        3. The agent may die with a status-dependent per-step probability,
           which increments ``model.deaths``.

        Susceptible and dead agents are left unchanged.
        """
        if self.hiv_status in (HIVStatus.SUSCEPTIBLE, HIVStatus.DEAD):
            return
        # Testing & treatment
        if not self.knows_status and self.random.random() < self.model.testing_rate:
            self.knows_status = True
            if self.random.random() < self.model.treatment_coverage:
                self.on_treatment = True
                if self.hiv_status == HIVStatus.INFECTED_UNTREATED:
                    self.hiv_status = HIVStatus.INFECTED_TREATED
                elif self.hiv_status == HIVStatus.AIDS_UNTREATED:
                    self.hiv_status = HIVStatus.AIDS_TREATED
        # Progression: treatment halves the effective elapsed time.
        # NOTE(review): this factor is fixed and does not use treatment_effectiveness.
        if self.hiv_status in (HIVStatus.INFECTED_UNTREATED, HIVStatus.INFECTED_TREATED):
            elapsed = self.model.schedule.time - self.infection_time
            eff = elapsed / 2 if self.on_treatment else elapsed
            if eff >= self.progression_time:
                self.hiv_status = HIVStatus.AIDS_TREATED if self.on_treatment else HIVStatus.AIDS_UNTREATED
        # Death: per-step probabilities. Statuses not listed (INFECTED_TREATED) never die.
        drate = {
            HIVStatus.AIDS_UNTREATED: self.model.death_rate_untreated,
            HIVStatus.AIDS_TREATED:   self.model.death_rate_treated,
            HIVStatus.INFECTED_UNTREATED: self.model.death_rate_untreated/10
        }.get(self.hiv_status, 0)
        if self.random.random() < drate:
            self.hiv_status = HIVStatus.DEAD
            self.model.deaths += 1

    def interact(self):
        """Try to infect every susceptible agent sharing this agent's cell.

        The per-contact probability starts from ``model.base_transmission_rate``.
        It is reduced proportionally by ``model.condom_use_rate``, then
        multiplied by 2.5 for untreated AIDS or 0.5 for treated AIDS. Finally it
        is scaled by ``1 - treatment_effectiveness`` when this agent is on
        treatment.
        """
        if self.hiv_status not in (
                HIVStatus.INFECTED_UNTREATED, HIVStatus.INFECTED_TREATED,
                HIVStatus.AIDS_UNTREATED, HIVStatus.AIDS_TREATED):
            return
        mates = self.model.grid.get_cell_list_contents([self.pos])
        for other in mates:
            # This check also skips self, because this agent is infected.
            if other.hiv_status != HIVStatus.SUSCEPTIBLE:
                continue
            # Condom use proportionally lowers per-contact risk; without this
            # factor the condom_use parameter would have no effect.
            prob = self.model.base_transmission_rate * (1 - self.model.condom_use_rate)
            if self.hiv_status == HIVStatus.AIDS_UNTREATED:
                prob *= 2.5
            elif self.hiv_status == HIVStatus.AIDS_TREATED:
                prob *= 0.5
            if self.on_treatment:
                prob *= (1 - self.model.treatment_effectiveness)
            if self.random.random() < prob:
                other.hiv_status = HIVStatus.INFECTED_UNTREATED
                other.infection_time = self.model.schedule.time
                other.progression_time = self.model.get_progression_time()
                self.model.infections += 1

    def step(self):
        """Advance one day: update disease state, then move and interact if still alive."""
        self.update_disease()
        if self.hiv_status == HIVStatus.DEAD:
            # Dead agents stay on the grid (the 'Dead' reporter counts them)
            # but no longer wander around.
            return
        self.move()
        self.interact()

class HIVModel(Model):
    """HIV epidemic on a toroidal ``MultiGrid`` with random agent activation.

    One step is treated as one day: the annual rates ``dr_unc``, ``dr_tr`` and
    ``test_rate`` are divided by 365 to give per-step probabilities.

    Args:
        N: number of agents.
        width, height: grid dimensions.
        init_inf: number of agents seeded as INFECTED_UNTREATED, clamped to
            the range [0, N].
        base_trans: per-contact transmission probability.
        condom_use: stored as ``condom_use_rate`` for the agents to read.
        treat_cov: probability that a newly tested agent starts treatment.
        treat_eff: stored as ``treatment_effectiveness`` for the agents.
        prog_time, prog_sd: mean and SD (days) of the infection-to-AIDS time.
        dr_unc, dr_tr: annual death rates for untreated / treated AIDS.
        test_rate: annual testing probability.
    """
    def __init__(self, N, width, height, init_inf,
                 base_trans, condom_use, treat_cov, treat_eff,
                 prog_time, prog_sd, dr_unc, dr_tr, test_rate):
        super().__init__()
        self.grid = MultiGrid(width, height, True)  # torus=True: edges wrap around
        self.schedule = RandomActivation(self)
        # Parameters
        self.base_transmission_rate = base_trans
        self.condom_use_rate = condom_use
        self.treatment_coverage = treat_cov
        self.treatment_effectiveness = treat_eff
        self.death_rate_untreated = dr_unc/365
        self.death_rate_treated = dr_tr/365
        self.testing_rate = test_rate/365
        self.prog_time = prog_time
        self.prog_sd = prog_sd
        # Per-step counters, reset at the start of every step()
        self.infections = 0
        self.deaths = 0
        # Initialize agents at uniformly random cells
        for i in range(N):
            agent = PersonAgent(i, self)
            self.schedule.add(agent)
            x, y = self.random.randrange(width), self.random.randrange(height)
            self.grid.place_agent(agent, (x, y))
        # Seed infections. Clamp the count so random.sample cannot raise when
        # init_inf is negative or exceeds the population.
        n_seed = max(0, min(init_inf, N))
        initially = self.random.sample(self.schedule.agents, n_seed)
        for a in initially:
            a.hiv_status = HIVStatus.INFECTED_UNTREATED
            a.infection_time = 0
            a.progression_time = self.get_progression_time()
        # Data collector: status counts plus the per-step counters above
        self.datacollector = DataCollector(
            model_reporters={
                'Step': lambda m: m.schedule.time,
                'Sus': lambda m: sum(a.hiv_status==HIVStatus.SUSCEPTIBLE for a in m.schedule.agents),
                'Inf_U': lambda m: sum(a.hiv_status==HIVStatus.INFECTED_UNTREATED for a in m.schedule.agents),
                'Inf_T': lambda m: sum(a.hiv_status==HIVStatus.INFECTED_TREATED for a in m.schedule.agents),
                'AIDS_U': lambda m: sum(a.hiv_status==HIVStatus.AIDS_UNTREATED for a in m.schedule.agents),
                'AIDS_T': lambda m: sum(a.hiv_status==HIVStatus.AIDS_TREATED for a in m.schedule.agents),
                'Dead': lambda m: sum(a.hiv_status==HIVStatus.DEAD for a in m.schedule.agents),
                'NewInf': lambda m: m.infections,
                'Deaths': lambda m: m.deaths
            }
        )

    def get_progression_time(self):
        """Draw an untreated infection-to-AIDS time, in days (steps).

        Uses the model's own ``self.random``, the same generator the agents
        draw from, rather than NumPy's global RNG. That keeps a run governed
        by a single random stream. A normal draw can be negative when
        ``prog_sd`` is large relative to ``prog_time``, so the result is
        floored at 0. Any value <= 0 already triggered AIDS at the next
        update, so this floor does not change progression timing.
        """
        return max(0, int(self.random.normalvariate(self.prog_time, self.prog_sd)))

    def step(self):
        """Reset per-step counters, activate all agents once, then record data."""
        self.infections = 0
        self.deaths = 0
        self.schedule.step()
        self.datacollector.collect(self)

# --- STREAMLIT SIDEBAR -----------------------------------------------------
st.sidebar.header('Simulation Parameters')
cols = st.sidebar.columns(2)
# Every key except 'days' is passed straight through as an HIVModel keyword argument.
params = {
    'N': cols[0].slider('Population', 100, 2000, 500, help='Total agents'),
    'init_inf': cols[0].slider('Initial Infected', 1, 50, 10, help='Seed infections'),
    'base_trans': cols[0].slider('Transmission Prob.', 0.01, 1.0, 0.1, 0.01, help='Per-contact risk'),
    'condom_use': cols[1].slider('Condom Use Rate', 0.0, 1.0, 0.5, 0.05, help='Reduces transmission'),
    'treat_cov': cols[1].slider('Treatment Coverage', 0.0, 1.0, 0.2, 0.05, help='Proportion tested and treated'),
    'treat_eff': cols[1].slider('Treatment Efficacy', 0.0, 1.0, 0.96, 0.01, help='Reduces progression'),
    'prog_time': cols[0].number_input('Prog. Time (d)', 365, 3650, 3650, help='Untreated → AIDS'),
    'prog_sd': cols[0].number_input('Prog. SD (d)', 100, 2000, 1095, help='Variation in progression'),
    'dr_unc': cols[1].slider('Death Rate Untreated', 0.0, 1.0, 0.1, 0.01, help='Annual untreated'),
    'dr_tr': cols[1].slider('Death Rate Treated', 0.0, 1.0, 0.02, 0.005, help='Annual treated'),
    'test_rate': cols[1].slider('Testing Rate', 0.0, 1.0, 0.3, 0.01, help='Annual probability'),
    'days': st.sidebar.slider('Days', 10, 1000, 100, help='Simulation length')
}
# The stop flag must persist across Streamlit reruns, so it lives in session_state.
if 'stop' not in st.session_state:
    st.session_state.stop = False

col_run, col_stop = st.sidebar.columns(2)
# Use explicit `if` statements for these side effects rather than
# `button(...) and st.session_state.__setitem__(...)`.
if col_run.button('Start'):
    st.session_state.stop = False
if col_stop.button('Stop'):
    st.session_state.stop = True

# --- MAIN CONTENT AREA with Animation -------------------------------------
ts_placeholder = st.empty()
gr_placeholder = st.empty()

# Animation loop
if st.sidebar.button('Run'):
    # A new run must clear any stop request left over from an earlier rerun.
    # Otherwise the loop below would exit on its very first iteration.
    st.session_state.stop = False
    days = params['days']
    # Every other sidebar entry is an HIVModel keyword argument. Build a
    # filtered copy instead of popping 'days' out of ``params`` in place.
    model_kwargs = {k: v for k, v in params.items() if k != 'days'}
    model = HIVModel(**model_kwargs, width=20, height=20)
    series = ['Sus', 'Inf_U', 'Inf_T', 'AIDS_U', 'AIDS_T', 'Dead']
    for _ in range(days):
        if st.session_state.stop:
            break
        model.step()
        # The DataCollector already keeps one list per reporter. Plotting those
        # lists directly avoids rebuilding a DataFrame of the whole history
        # on every frame, which is quadratic in the number of days.
        history = model.datacollector.model_vars
        # Smaller plots
        fig1, ax1 = plt.subplots(figsize=(3, 2))
        for k in series:
            ax1.plot(history['Step'], history[k], label=k)
        ax1.legend(fontsize='x-small')
        ts_placeholder.pyplot(fig1)
        plt.close(fig1)

        # Grid snapshot: colour each occupied cell by the status of the last
        # agent listed in it. imshow indexes [row, col] == [y, x], so fill the
        # array that way and draw with the origin at the bottom-left.
        fig2, ax2 = plt.subplots(figsize=(3, 2))
        w, h = model.grid.width, model.grid.height
        arr = np.full((h, w), np.nan)
        for cell, (x, y) in model.grid.coord_iter():
            if cell:
                arr[y, x] = cell[-1].hiv_status
        ax2.imshow(arr, vmin=0, vmax=5, cmap='tab20', origin='lower')
        ax2.axis('off')
        gr_placeholder.pyplot(fig2)
        plt.close(fig2)
        time.sleep(0.05)  # controls animation speed

# Sidebar branding, rendered inside the sidebar container.
with st.sidebar:
    st.markdown('---')
    st.markdown('**HIV ABM Modeling**')
    st.caption('Powered by Mesa & Streamlit')
