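# Demonstrate observables

This notebook attaches a `half_population` observable to the MIRA SIR template model, compiles it to PyTorch, samples a PyCIEMSS Petri-net model, and calibrates that model against data that includes the observable.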

In [1]:
import mira
from mira.metamodel import Observable, SympyExprStr
from mira.modeling import ModelObservable
import sympy
import torch
from copy import deepcopy as _d
from mira.metamodel.ops import stratify
from sympytorch import SymPyModule
from pyciemss.PetriNetODE.interfaces import (
    load_petri_model,
    setup_petri_model,
    sample_petri,
    calibrate,
    load_and_sample_petri_model,
    load_and_calibrate_and_sample_petri_model, 
)
from pyciemss.Ensemble.interfaces import (
    load_and_sample_petri_ensemble
)
import pandas as pd
from pyciemss.utils.interface_utils import convert_to_output_format, csv_to_list, solutions_to_observations
from typing import Dict, List 

# A solved trajectory: maps each variable name to its tensor of values.
# (`torch.Tensor` is the type; `torch.tensor` is the factory function.)
Solution = Dict[str, torch.Tensor]

In [2]:
def get_sampled_value(name, model, *args, **kwargs):
    """Trace ``model`` once with Pyro and return the value drawn at sample site ``name``.

    Args:
        name: Name of the Pyro sample site to look up.
        model: Pyro model (callable) to trace.
        *args, **kwargs: Forwarded to the model when it is traced. With none given,
            the model is traced with no arguments, exactly as before.

    Returns:
        The ``value`` recorded in the trace for sample site ``name``.

    Raises:
        KeyError: If ``name`` is not in the trace, or is in the trace but is not a
            ``'sample'`` site.
    """
    # Imported here so this helper does not depend on a later cell having run `import pyro`.
    import pyro

    trace = pyro.poutine.trace(model).get_trace(*args, **kwargs)
    if name not in trace.nodes:
        raise KeyError(f'{name} not in trace {trace.nodes}')
    node = trace.nodes[name]
    if node['type'] != 'sample':
        raise KeyError(f"{name} is a {node['type']!r} site, not a 'sample' site")
    return node['value']



In [3]:
from sympy import symbols, exp
from sympy.printing.mathml import mathml

# Sympy symbols for the three SIR states and the two rate parameters.
I, S, R, beta, gamma = symbols('I S R beta gamma')

# beta*S*I divided by the total S + I + R; sympy orders factors canonically,
# so this is the same expression as I*S*beta/(S+I+R).
expr = beta * S * I / (S + I + R)
mathml_str = mathml(expr)

print(mathml_str)


<apply><divide/><apply><times/><ci>I</ci><ci>S</ci><ci>&#946;</ci></apply><apply><plus/><ci>I</ci><ci>R</ci><ci>S</ci></apply></apply>


## Observe half the true population

In [4]:
from mira.examples.sir import sir_parameterized
# Work on a copy so the shared example template model is not mutated.
tm = _d(sir_parameterized)

# Named `state_names` (not `symbols`) so sympy's `symbols` function imported in the
# previous cell is not shadowed; sorted for a deterministic iteration order.
state_names = sorted(tm.initials)
population_sum = sympy.Add(*[sympy.Symbol(state) for state in state_names])
tm.observables = {
    'half_population': Observable(
        name='half_population',
        expression=SympyExprStr(population_sum / 2),
    )
}

tm.observables['half_population'].expression.args[0]

immune_population/2 + infected_population/2 + susceptible_population/2

In [5]:
# Inspect the initial conditions on the template model; the observable above
# was built from these keys.
tm.initials

{'susceptible_population': Initial(concept=Concept(name='susceptible_population', display_name=None, description=None, identifiers={'ido': '0000514'}, context={}, units=None), value=1.0),
 'infected_population': Initial(concept=Concept(name='infected_population', display_name=None, description=None, identifiers={'ido': '0000511'}, context={}, units=None), value=2.0),
 'immune_population': Initial(concept=Concept(name='immune_population', display_name=None, description=None, identifiers={'ido': '0000592'}, context={}, units=None), value=3.0)}

## Compile the observable expression to PyTorch

In [6]:
# Compile the `half_population` observable into a torch module.
# It is selected by name, rather than compiling every observable on `tm`.
# That keeps the module's single output matching this variable's name even if
# more observables are added later.
half_population = SymPyModule(
    expressions=[tm.observables['half_population'].expression.args[0]]
)

## Check the compiled observable against a hand-computed value

In [7]:
# Hand-picked compartment values for a deterministic check of the compiled observable.
expected_total_population = dict(
    infected_population=torch.tensor(1.0),
    immune_population=torch.tensor(0.0),
    susceptible_population=torch.tensor(100.0)
)

expected_half_population = half_population(**expected_total_population)
# Derive the reference from the inputs, (1 + 0 + 100) / 2 = 50.5, instead of a magic
# number. full_like matches the output's shape and dtype. Compare with a float
# tolerance rather than exact `==`, which is brittle and raises for multi-element tensors.
reference_half_population = torch.full_like(
    expected_half_population, float(sum(expected_total_population.values()) / 2)
)
assert torch.allclose(expected_half_population, reference_half_population), (
    f'expected {reference_half_population}, got {expected_half_population}'
)

## Generate samples from the template model

In [8]:
# Build a mira.modeling.Model from the template model and show the observables
# it exposes (the `half_population` observable defined on `tm`).
G = mira.modeling.Model(tm)
G.observables

{'half_population': <mira.modeling.ModelObservable at 0x158189c90>}

In [9]:
import torch
import pyro

# Sampling and calibration below are stochastic. Fix the seed so that
# Restart & Run All reproduces the same outputs.
pyro.set_rng_seed(0)

num_samples = 2
# CSV that a later cell writes and reads back as calibration data.
data_path = 'sir_data.csv'
timepoints = [0.1, 0.2, 0.3]
# compile_observables_p=True presumably makes load_petri_model compile
# tm.observables (here `half_population`) into the model. TODO confirm.
raw_sir = load_petri_model(tm, compile_observables_p=True)

In [10]:
def observation_model(solution: Solution, var_name: str) -> None:
    """No-op placeholder observation model.

    Ignores ``solution`` and ``var_name`` and returns ``None``.

    NOTE(review): there is no call site in the visible cells; presumably a stub for
    a custom observation model. Confirm before relying on it.
    """
    return None

In [11]:
# Show the state variables of the loaded model's underlying mira Model.
# The output below shows them keyed by (name, identity) tuples.
raw_sir.G.variables

{('susceptible_population',
  ('identity', 'ido:0000514')): <mira.modeling.Variable at 0x158189b70>,
 ('infected_population',
  ('identity', 'ido:0000511')): <mira.modeling.Variable at 0x15818bc70>,
 ('immune_population',
  ('identity', 'ido:0000592')): <mira.modeling.Variable at 0x1581fc280>}

In [12]:
# Initialise the model at t=0.0 with 1000 susceptible, 1 infected, 0 immune,
# then draw `num_samples` prior trajectories at `timepoints`.
sir = setup_petri_model(
    raw_sir,
    0.0,
    dict(
        susceptible_population=1000.0,
        infected_population=1.0,
        immune_population=0.0,
    ),
)
sir_samples = sample_petri(sir, timepoints, num_samples)
sir_samples


{'beta': tensor([0.0978, 0.1035]),
 'gamma': tensor([0.1866, 0.2067]),
 'immune_population_sol': tensor([[0.0186, 0.0370, 0.0553],
         [0.0206, 0.0409, 0.0611]]),
 'infected_population_sol': tensor([[0.9911, 0.9824, 0.9737],
         [0.9897, 0.9796, 0.9695]]),
 'susceptible_population_sol': tensor([[999.9902, 999.9807, 999.9711],
         [999.9897, 999.9794, 999.9683]])}

In [13]:
# Flatten the sampled tensors into a long-format DataFrame.
# The output below has one row per (sample_id, timepoint_id).
sir_sample_df = convert_to_output_format(sir_samples, timepoints)
sir_sample_df

Unnamed: 0,timepoint_id,sample_id,beta_param,gamma_param,immune_population_sol,infected_population_sol,susceptible_population_sol,timepoint_(unknown)
0,0,0,0.097803,0.186643,0.018582,0.991145,999.990173,0.1
1,1,0,0.097803,0.186643,0.036999,0.98237,999.980652,0.2
2,2,0,0.097803,0.186643,0.055253,0.973672,999.971069,0.3
3,0,1,0.103537,0.206742,0.020568,0.989722,999.989685,0.1
4,1,1,0.103537,0.206742,0.040924,0.97955,999.97937,0.2
5,2,1,0.103537,0.206742,0.061071,0.969481,999.968323,0.3


In [14]:
# Convert the sampled trajectories into observation tables and show the first one.
# The output below contains only sample_id 0, so this presumably returns one
# table per sample.
observations = solutions_to_observations(timepoints, sir_sample_df.set_index(['timepoint_id', 'sample_id']))
observations[0]

Unnamed: 0_level_0,Unnamed: 1_level_0,Timestep,immune_population,infected_population,susceptible_population
timepoint_id,sample_id,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,0,0.1,0.018582,0.991145,999.990173
1,0,0.2,0.036999,0.98237,999.980652
2,0,0.3,0.055253,0.973672,999.971069


In [15]:
# Write sample 0's trajectory to disk and read it back as (timepoint, {state: value})
# pairs. Attach the hand-computed `half_population` observation to each row.
observations[0].to_csv(data_path, index=False)
sir_data = [
    (
        timepoint,
        {
            **row,
            'half_population': (
                row['immune_population'] + row['susceptible_population'] + row['infected_population']
            ) / 2,
        },
    )
    for timepoint, row in csv_to_list(data_path)
]
sir_data
    
                      

[(0.1,
  {'immune_population': 0.01858154684305191,
   'infected_population': 0.9911453723907471,
   'susceptible_population': 999.9901733398438,
   'half_population': 500.4999501295388}),
 (0.2,
  {'immune_population': 0.036998581141233444,
   'infected_population': 0.982369601726532,
   'susceptible_population': 999.9806518554688,
   'half_population': 500.50001001916826}),
 (0.3,
  {'immune_population': 0.05525253713130951,
   'infected_population': 0.9736716151237488,
   'susceptible_population': 999.9710693359375,
   'half_population': 500.4999967440963})]

In [16]:
# Trace the model once and read back the `infected_population_sol` sample site.
# NOTE(review): `sir` is traced with no arguments and the result below is an empty
# tensor, presumably because no timepoints are passed. Confirm.
get_sampled_value('infected_population_sol', sir)

tensor([])

In [17]:
# Calibrate the model against `sir_data`, which includes the `half_population`
# observation. num_iterations=10 is small, presumably to keep this demo fast.
inferred_parameters = calibrate(sir, sir_data, num_iterations=10)


immune_population, torch.Size([3])
infected_population, torch.Size([3])
susceptible_population, torch.Size([3])
half_population, torch.Size([3])
immune_population, torch.Size([3])
infected_population, torch.Size([3])
susceptible_population, torch.Size([3])
half_population, torch.Size([3])
immune_population, torch.Size([3])
infected_population, torch.Size([3])
susceptible_population, torch.Size([3])
half_population, torch.Size([3])
immune_population, torch.Size([3])
infected_population, torch.Size([3])
susceptible_population, torch.Size([3])
half_population, torch.Size([3])
immune_population, torch.Size([3])
infected_population, torch.Size([3])
susceptible_population, torch.Size([3])
half_population, torch.Size([3])
immune_population, torch.Size([3])
infected_population, torch.Size([3])
susceptible_population, torch.Size([3])
half_population, torch.Size([3])
immune_population, torch.Size([3])
infected_population, torch.Size([3])
susceptible_population, torch.Size([3])
half_population, t

In [18]:
# Sample again, this time passing the parameters inferred by calibration.
calibrated_samples = sample_petri(sir, timepoints, num_samples, inferred_parameters)

In [19]:
# Raw tensors for the calibrated samples.
calibrated_samples

{'beta': tensor([0.0994, 0.1013]),
 'gamma': tensor([0.1945, 0.1967]),
 'immune_population_sol': tensor([[0.0194, 0.0385, 0.0575],
         [0.0196, 0.0390, 0.0582]]),
 'infected_population_sol': tensor([[0.9905, 0.9811, 0.9718],
         [0.9905, 0.9811, 0.9718]]),
 'susceptible_population_sol': tensor([[999.9902, 999.9793, 999.9706],
         [999.9899, 999.9794, 999.9702]])}

In [20]:
# Same long-format table as for the prior samples, now for the calibrated ones.
convert_to_output_format(calibrated_samples, timepoints)

Unnamed: 0,timepoint_id,sample_id,beta_param,gamma_param,immune_population_sol,infected_population_sol,susceptible_population_sol,timepoint_(unknown)
0,0,0,0.099411,0.194494,0.019357,0.990527,999.990173,0.1
1,1,0,0.099411,0.194494,0.038531,0.981143,999.979309,0.2
2,2,0,0.099411,0.194494,0.057523,0.971849,999.970581,0.3
3,0,1,0.10134,0.196656,0.019572,0.990504,999.989929,0.1
4,1,1,0.10134,0.196656,0.038958,0.981098,999.979431,0.2
5,2,1,0.10134,0.196656,0.05816,0.97178,999.970154,0.3
