# Demonstrate observables

In [1]:
import mira
from mira.metamodel import Observable, SympyExprStr
from mira.modeling import ModelObservable
import sympy
import torch
from copy import deepcopy as _d
from mira.metamodel.ops import stratify
from sympytorch import SymPyModule
from pyciemss.PetriNetODE.interfaces import (
    load_petri_model,
    setup_petri_model,
    sample_petri,
    calibrate,
    load_and_sample_petri_model,
    load_and_calibrate_and_sample_petri_model, 
)
from pyciemss.Ensemble.interfaces import (
    load_and_sample_petri_ensemble
)
import pandas as pd
from pyciemss.utils.interface_utils import convert_to_output_format, csv_to_list, solutions_to_observations
from typing import Dict, List 
Solution = Dict[str, torch.tensor]

In [2]:
def get_sampled_value(name, model):
    trace_handler = pyro.poutine.trace(model)
    trace = trace_handler.get_trace()
    if name in trace.nodes and trace.nodes[name]['type'] == 'sample':
        return trace.nodes[name]['value']
    raise KeyError(f'{name} not in trace {trace.nodes}')



In [3]:
from sympy import symbols, exp
from sympy.printing.mathml import mathml

I, S, R, beta, gamma = symbols('infected_population susceptible_population immune_population beta gamma')

expr = I*S*beta/(S+I+R)
mathml_str = mathml(expr)

print(mathml_str)


<apply><divide/><apply><times/><ci>&#946;</ci><ci><mml:msub><mml:mi>infected</mml:mi><mml:mi>population</mml:mi></mml:msub></ci><ci><mml:msub><mml:mi>susceptible</mml:mi><mml:mi>population</mml:mi></mml:msub></ci></apply><apply><plus/><ci><mml:msub><mml:mi>immune</mml:mi><mml:mi>population</mml:mi></mml:msub></ci><ci><mml:msub><mml:mi>infected</mml:mi><mml:mi>population</mml:mi></mml:msub></ci><ci><mml:msub><mml:mi>susceptible</mml:mi><mml:mi>population</mml:mi></mml:msub></ci></apply></apply>


## Observe half the true population

In [4]:
from mira.examples.sir import sir_parameterized
tm = _d(sir_parameterized)
tm.initials
symbols = set(tm.initials)
expr = sympy.Add(*[sympy.Symbol(s) for s in symbols])
tm.observables = {'half_population': Observable(
        name='half_population',
        expression=SympyExprStr(expr/2))
    }

tm.observables['half_population'].expression.args[0]

immune_population/2 + infected_population/2 + susceptible_population/2

In [5]:
tm.initials

{'susceptible_population': Initial(concept=Concept(name='susceptible_population', display_name=None, description=None, identifiers={'ido': '0000514'}, context={}, units=None), value=1.0),
 'infected_population': Initial(concept=Concept(name='infected_population', display_name=None, description=None, identifiers={'ido': '0000511'}, context={}, units=None), value=2.0),
 'immune_population': Initial(concept=Concept(name='immune_population', display_name=None, description=None, identifiers={'ido': '0000592'}, context={}, units=None), value=3.0)}

## Compile the observable expression to pytorch

In [6]:
half_population = SymPyModule(expressions=[observable.expression.args[0] 
                                           for observable in tm.observables.values()])

## Expected observable value

In [7]:
expected_total_population = dict(
    infected_population=torch.tensor(1.0),
    immune_population=torch.tensor(0.0),
    susceptible_population=torch.tensor(100.0)
)

expected_half_population = half_population(**expected_total_population)                                
assert expected_half_population == torch.tensor([50.5])

## Generate samples from the template model

In [8]:
G = mira.modeling.Model(tm)
G.observables

{'half_population': <mira.modeling.ModelObservable at 0x14bea49a0>}

In [9]:
import torch
import pyro
num_samples = 2
data_path = 'sir_data.csv'
#sir_path = '../../test/models/AMR_examples/sir_typed.json'
timepoints = [0.1, 0.2, 0.3]
raw_sir = load_petri_model(tm, compile_observables_p=True)

In [10]:
def observation_model(solution: Solution, var_name: str) -> None:
    pass

In [11]:
raw_sir.G.variables

{('susceptible_population',
  ('identity', 'ido:0000514')): <mira.modeling.Variable at 0x14bea5a80>,
 ('infected_population',
  ('identity', 'ido:0000511')): <mira.modeling.Variable at 0x14bea6680>,
 ('immune_population',
  ('identity', 'ido:0000592')): <mira.modeling.Variable at 0x14bea6500>}

In [12]:
sir = setup_petri_model(raw_sir, 0.0, dict(susceptible_population=1000.0, infected_population=1.0, immune_population=0.0))
sir_samples = sample_petri(sir, timepoints , num_samples)
sir_samples


{'beta': tensor([0.0999, 0.1084]),
 'gamma': tensor([0.2114, 0.1813]),
 'immune_population_sol': tensor([[0.0210, 0.0418, 0.0624],
         [0.0181, 0.0360, 0.0538]]),
 'infected_population_sol': tensor([[0.9889, 0.9779, 0.9671],
         [0.9927, 0.9855, 0.9784]]),
 'susceptible_population_sol': tensor([[999.9901, 999.9809, 999.9705],
         [999.9892, 999.9785, 999.9679]])}

In [13]:
sir_sample_df = convert_to_output_format(sir_samples, timepoints)
sir_sample_df

Unnamed: 0,timepoint_id,sample_id,beta_param,gamma_param,immune_population_sol,infected_population_sol,susceptible_population_sol,timepoint_(unknown)
0,0,0,0.099863,0.21141,0.021023,0.988897,999.990051,0.1
1,1,0,0.099863,0.21141,0.041813,0.977919,999.980896,0.2
2,2,0,0.099863,0.21141,0.062373,0.96706,999.97052,0.3
3,0,1,0.108438,0.181257,0.01806,0.992734,999.989197,0.1
4,1,1,0.108438,0.181257,0.035988,0.98552,999.978455,0.2
5,2,1,0.108438,0.181257,0.053787,0.978359,999.967896,0.3


In [14]:
observations = solutions_to_observations(timepoints, sir_sample_df.set_index(['timepoint_id', 'sample_id']))
observations[0]

Unnamed: 0_level_0,Unnamed: 1_level_0,Timestep,immune_population,infected_population,susceptible_population
timepoint_id,sample_id,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,0,0.1,0.021023,0.988897,999.990051
1,0,0.2,0.041813,0.977919,999.980896
2,0,0.3,0.062373,0.96706,999.97052


In [15]:
observations[0].to_csv(data_path, index=False)
sir_data = csv_to_list(data_path)
for timepoint, data in sir_data:
    data['half_population'] = (
        data['immune_population'] + data['susceptible_population'] + data['infected_population'])/2
sir_data
    
                      

[(0.1,
  {'immune_population': 0.0210234597325325,
   'infected_population': 0.9888972043991089,
   'susceptible_population': 999.9900512695312,
   'half_population': 500.49998596683145}),
 (0.2,
  {'immune_population': 0.04181347414851189,
   'infected_population': 0.9779185652732849,
   'susceptible_population': 999.9808959960938,
   'half_population': 500.5003140177578}),
 (0.3,
  {'immune_population': 0.06237271428108215,
   'infected_population': 0.9670599699020386,
   'susceptible_population': 999.9705200195312,
   'half_population': 500.4999763518572})]

In [16]:
get_sampled_value('infected_population_sol', sir)

tensor([])

In [17]:
inferred_parameters = calibrate(sir, sir_data, num_iterations=10)


In [18]:
calibrated_samples = sample_petri(sir, timepoints, num_samples, inferred_parameters)

In [19]:
calibrated_samples

{'beta': tensor([0.0995, 0.0993]),
 'gamma': tensor([0.2044, 0.2032]),
 'immune_population_sol': tensor([[0.0203, 0.0405, 0.0604],
         [0.0202, 0.0402, 0.0600]]),
 'infected_population_sol': tensor([[0.9896, 0.9792, 0.9690],
         [0.9897, 0.9794, 0.9693]]),
 'susceptible_population_sol': tensor([[999.9902, 999.9803, 999.9701],
         [999.9901, 999.9804, 999.9707]])}

In [20]:
convert_to_output_format(calibrated_samples, timepoints)

Unnamed: 0,timepoint_id,sample_id,beta_param,gamma_param,immune_population_sol,infected_population_sol,susceptible_population_sol,timepoint_(unknown)
0,0,0,0.099521,0.204428,0.020336,0.989554,999.990173,0.1
1,1,0,0.099521,0.204428,0.040459,0.979217,999.980347,0.2
2,2,0,0.099521,0.204428,0.060373,0.968987,999.970093,0.3
3,0,1,0.0993,0.203188,0.020213,0.989655,999.990112,0.1
4,1,1,0.0993,0.203188,0.040218,0.979417,999.980408,0.2
5,2,1,0.0993,0.203188,0.060015,0.969285,999.970703,0.3


## Nonobservable data

In [21]:
sir_data = csv_to_list(data_path)                     
for timepoint, data in sir_data:
    data['definitely_not_half_population'] = (
        data['immune_population'] + data['susceptible_population'] + data['infected_population'])/2
sir_data
    

[(0.1,
  {'immune_population': 0.0210234597325325,
   'infected_population': 0.9888972043991089,
   'susceptible_population': 999.9900512695312,
   'definitely_not_half_population': 500.49998596683145}),
 (0.2,
  {'immune_population': 0.04181347414851189,
   'infected_population': 0.9779185652732849,
   'susceptible_population': 999.9808959960938,
   'definitely_not_half_population': 500.5003140177578}),
 (0.3,
  {'immune_population': 0.06237271428108215,
   'infected_population': 0.9670599699020386,
   'susceptible_population': 999.9705200195312,
   'definitely_not_half_population': 500.4999763518572})]

In [22]:
try:
    inferred_parameters = calibrate(sir, sir_data, num_iterations=10)
except KeyError as k:
    print(k)

ERROR:root:
                ###############################

                There was an exception in pyciemss
                
                Error occured in function: calibrate_petri

                Function docs : 
    Use variational inference with a mean-field variational family to infer the parameters of the model.
    

                ################################
            
Traceback (most recent call last):
  File "/Users/zuck016/Projects/Proposals/ASKEM/build/clean-build/src/pyciemss/custom_decorators.py", line 9, in wrapped
    result = function(*args, **kwargs)
  File "/Users/zuck016/Projects/Proposals/ASKEM/build/clean-build/src/pyciemss/PetriNetODE/interfaces.py", line 755, in calibrate_petri
    loss = svi.step(method=method)
  File "/Users/zuck016/.pyenv/versions/clean-build/lib/python3.10/site-packages/pyro/infer/svi.py", line 145, in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
  File "/Users/zuck016/.pyenv/versions/clean-buil

'definitely_not_half_population'


#### Lambdify

In [23]:
from mira.metamodel.template_model import Observable
from sympy import lambdify, symbols
S, I = symbols('Susceptible Infected')
observables = {'S+I': Observable(expression=S + I, name='S+I')}
observable_function = {k: lambdify([S, I], S+I)
                                                for k, v in observables.items()}
        


### SympyModule

In [24]:
from mira.metamodel.template_model import Observable
from sympytorch import SymPyModule
S, I = symbols('Susceptible Infected')
observables = {'S+I': Observable(expression=S + I, name='S+I')}
observable_function = {k: SymPyModule(expressions=[S+I])
                                                for k, v in observables.items()}
        


In [25]:
observables['S+I'].expression.args[0]

Infected + Susceptible

### Test observable function

In [26]:
observable_function['S+I'](**dict(Susceptible=100.0, Infected=1.0))

tensor([101.])

In [27]:
from mira.sources.askenet import model_from_json_file
from pyciemss.utils.interface_utils import solutions_to_observations
sidarthe_data_path = '../../test/test_mira/sidarthe_data.csv'
sidarthe_model_path = '../../test/models/AMR_examples/SIDARTHE.amr.json'
idx = ['timepoint_id', 'sample_id']
timepoints = [t*0.1 for t in range(100)]
sidarthe_mira = model_from_json_file(sidarthe_model_path)        
sidarthe_samples = load_and_sample_petri_model(sidarthe_mira, 1, timepoints)['data'].set_index(idx)
sol_obs = [c for c in sidarthe_samples.columns if c[-4:] in ['_sol', '_obs']]

sidarthe_samples.reset_index(level='sample_id', drop=True)

expression_vars: {'Ailing': tensor([1.6667e-08, 6.5087e-08, 1.1267e-07, 1.5951e-07, 2.0570e-07, 2.5130e-07,
        2.9642e-07, 3.4112e-07, 3.8549e-07, 4.2960e-07, 4.7353e-07, 5.1735e-07,
        5.6114e-07, 6.0499e-07, 6.4895e-07, 6.9311e-07, 7.3750e-07, 7.8219e-07,
        8.2723e-07, 8.7270e-07, 9.1863e-07, 9.6511e-07, 1.0122e-06, 1.0599e-06,
        1.1084e-06, 1.1576e-06, 1.2078e-06, 1.2588e-06, 1.3108e-06, 1.3639e-06,
        1.4181e-06, 1.4735e-06, 1.5302e-06, 1.5883e-06, 1.6477e-06, 1.7086e-06,
        1.7711e-06, 1.8351e-06, 1.9008e-06, 1.9682e-06, 2.0374e-06, 2.1085e-06,
        2.1816e-06, 2.2566e-06, 2.3338e-06, 2.4131e-06, 2.4948e-06, 2.5788e-06,
        2.6652e-06, 2.7542e-06, 2.8458e-06, 2.9402e-06, 3.0374e-06, 3.1375e-06,
        3.2407e-06, 3.3470e-06, 3.4565e-06, 3.5695e-06, 3.6859e-06, 3.8059e-06,
        3.9296e-06, 4.0571e-06, 4.1886e-06, 4.3242e-06, 4.4640e-06, 4.6081e-06,
        4.7568e-06, 4.9102e-06, 5.0684e-06, 5.2315e-06, 5.3998e-06, 5.5734e-06,
        5.75

Unnamed: 0_level_0,beta_param,gamma_param,delta_param,alpha_param,epsilon_param,zeta_param,lambda_param,eta_param,rho_param,theta_param,...,Diagnosed_sol,Extinct_sol,Healed_sol,Infected_sol,Recognized_sol,Susceptible_sol,Threatened_sol,Cases_obs,Hospitalizations_obs,Deaths_obs
timepoint_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,0.009201,0.519076,0.009801,0.547397,0.165703,0.148593,0.033269,0.140474,0.037794,0.361518,...,3.333333e-07,5.958760e-32,1.243178e-17,0.000003,3.333333e-08,0.999996,1.246214e-19,3.666667e-07,3.333333e-08,5.958760e-32
1,0.009201,0.519076,0.009801,0.547397,0.165703,0.148593,0.033269,0.140474,0.037794,0.361518,...,3.827578e-07,7.577925e-14,1.268764e-08,0.000003,3.968124e-08,0.999996,1.756296e-10,4.226147e-07,3.985687e-08,7.577925e-14
2,0.009201,0.519076,0.009801,0.547397,0.165703,0.148593,0.033269,0.140474,0.037794,0.361518,...,4.324895e-07,3.694298e-13,2.589829e-08,0.000003,4.842554e-08,0.999996,4.567468e-10,4.813718e-07,4.888229e-08,3.694298e-13
3,0.009201,0.519076,0.009801,0.547397,0.165703,0.148593,0.033269,0.140474,0.037794,0.361518,...,4.825880e-07,9.840602e-13,3.964909e-08,0.000004,5.953089e-08,0.999996,8.485231e-10,5.429674e-07,6.037941e-08,9.840602e-13
4,0.009201,0.519076,0.009801,0.547397,0.165703,0.148593,0.033269,0.140474,0.037794,0.361518,...,5.331121e-07,2.027941e-12,5.395750e-08,0.000004,7.296917e-08,0.999996,1.356097e-09,6.074374e-07,7.432527e-08,2.027941e-12
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,0.009201,0.519076,0.009801,0.547397,0.165703,0.148593,0.033269,0.140474,0.037794,0.361518,...,1.904889e-05,4.082430e-08,9.264608e-06,0.000057,1.830585e-05,0.999883,1.858121e-06,3.921286e-05,2.016397e-05,4.082430e-08
96,0.009201,0.519076,0.009801,0.547397,0.165703,0.148593,0.033269,0.140474,0.037794,0.361518,...,1.966099e-05,4.263458e-08,9.583470e-06,0.000059,1.893099e-05,0.999879,1.928117e-06,4.052010e-05,2.085910e-05,4.263458e-08
97,0.009201,0.519076,0.009801,0.547397,0.165703,0.148593,0.033269,0.140474,0.037794,0.361518,...,2.029257e-05,4.451286e-08,9.912605e-06,0.000061,1.957596e-05,0.999875,2.000437e-06,4.186897e-05,2.157640e-05,4.451286e-08
98,0.009201,0.519076,0.009801,0.547397,0.165703,0.148593,0.033269,0.140474,0.037794,0.361518,...,2.094426e-05,4.646142e-08,1.025235e-05,0.000062,2.024142e-05,0.999871,2.075157e-06,4.326084e-05,2.231658e-05,4.646142e-08


In [28]:
sidarthe_mira.observables

{'Cases': Observable(name='Cases', display_name=None, description=None, identifiers={}, context={}, units=None, expression=Diagnosed + Recognized + Threatened),
 'Hospitalizations': Observable(name='Hospitalizations', display_name=None, description=None, identifiers={}, context={}, units=None, expression=Recognized + Threatened),
 'Deaths': Observable(name='Deaths', display_name=None, description=None, identifiers={}, context={}, units=None, expression=Extinct)}

In [29]:
sidarthe_model = load_petri_model(sidarthe_mira, compile_rate_law_p=True, 
                                  compile_observables_p=True)
sidarthe_model.compiled_observables

{'Cases': SymPyModule(expressions=(Diagnosed + Recognized + Threatened,)),
 'Hospitalizations': SymPyModule(expressions=(Recognized + Threatened,)),
 'Deaths': SymPyModule(expressions=(Extinct,))}

In [30]:
sidarthe_samples.reset_index(level='sample_id', drop=True)

Unnamed: 0_level_0,beta_param,gamma_param,delta_param,alpha_param,epsilon_param,zeta_param,lambda_param,eta_param,rho_param,theta_param,...,Diagnosed_sol,Extinct_sol,Healed_sol,Infected_sol,Recognized_sol,Susceptible_sol,Threatened_sol,Cases_obs,Hospitalizations_obs,Deaths_obs
timepoint_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,0.009201,0.519076,0.009801,0.547397,0.165703,0.148593,0.033269,0.140474,0.037794,0.361518,...,3.333333e-07,5.958760e-32,1.243178e-17,0.000003,3.333333e-08,0.999996,1.246214e-19,3.666667e-07,3.333333e-08,5.958760e-32
1,0.009201,0.519076,0.009801,0.547397,0.165703,0.148593,0.033269,0.140474,0.037794,0.361518,...,3.827578e-07,7.577925e-14,1.268764e-08,0.000003,3.968124e-08,0.999996,1.756296e-10,4.226147e-07,3.985687e-08,7.577925e-14
2,0.009201,0.519076,0.009801,0.547397,0.165703,0.148593,0.033269,0.140474,0.037794,0.361518,...,4.324895e-07,3.694298e-13,2.589829e-08,0.000003,4.842554e-08,0.999996,4.567468e-10,4.813718e-07,4.888229e-08,3.694298e-13
3,0.009201,0.519076,0.009801,0.547397,0.165703,0.148593,0.033269,0.140474,0.037794,0.361518,...,4.825880e-07,9.840602e-13,3.964909e-08,0.000004,5.953089e-08,0.999996,8.485231e-10,5.429674e-07,6.037941e-08,9.840602e-13
4,0.009201,0.519076,0.009801,0.547397,0.165703,0.148593,0.033269,0.140474,0.037794,0.361518,...,5.331121e-07,2.027941e-12,5.395750e-08,0.000004,7.296917e-08,0.999996,1.356097e-09,6.074374e-07,7.432527e-08,2.027941e-12
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,0.009201,0.519076,0.009801,0.547397,0.165703,0.148593,0.033269,0.140474,0.037794,0.361518,...,1.904889e-05,4.082430e-08,9.264608e-06,0.000057,1.830585e-05,0.999883,1.858121e-06,3.921286e-05,2.016397e-05,4.082430e-08
96,0.009201,0.519076,0.009801,0.547397,0.165703,0.148593,0.033269,0.140474,0.037794,0.361518,...,1.966099e-05,4.263458e-08,9.583470e-06,0.000059,1.893099e-05,0.999879,1.928117e-06,4.052010e-05,2.085910e-05,4.263458e-08
97,0.009201,0.519076,0.009801,0.547397,0.165703,0.148593,0.033269,0.140474,0.037794,0.361518,...,2.029257e-05,4.451286e-08,9.912605e-06,0.000061,1.957596e-05,0.999875,2.000437e-06,4.186897e-05,2.157640e-05,4.451286e-08
98,0.009201,0.519076,0.009801,0.547397,0.165703,0.148593,0.033269,0.140474,0.037794,0.361518,...,2.094426e-05,4.646142e-08,1.025235e-05,0.000062,2.024142e-05,0.999871,2.075157e-06,4.326084e-05,2.231658e-05,4.646142e-08


In [31]:
sidarthe_data = solutions_to_observations(timepoints, sidarthe_samples)
sidarthe_data[0].to_csv(sidarthe_data_path, index=False)
sidarthe_calibrated_samples = load_and_calibrate_and_sample_petri_model(
    sidarthe_model_path, sidarthe_data_path, num_samples=100, timepoints=timepoints
)['data'].set_index(idx).groupby(level='timepoint_id').mean()

In [32]:
sidarthe_calibrated_samples

Unnamed: 0_level_0,beta_param,gamma_param,delta_param,alpha_param,epsilon_param,zeta_param,lambda_param,eta_param,rho_param,theta_param,...,tau_param,sigma_param,Ailing_sol,Diagnosed_sol,Extinct_sol,Healed_sol,Infected_sol,Recognized_sol,Susceptible_sol,Threatened_sol
timepoint_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,0.011091,0.436484,0.011136,0.568219,0.169467,0.124141,0.034058,0.124929,0.034371,0.377407,...,0.009925,0.017033,1.666667e-08,3.333333e-07,5.763366e-32,1.258381e-17,0.000003,3.333333e-08,0.999996,1.164111e-19
1,0.011091,0.436484,0.011136,0.568219,0.169467,0.124141,0.034058,0.124929,0.034371,0.377407,...,0.009925,0.017033,5.704003e-08,3.847988e-07,7.092378e-14,1.285209e-08,0.000003,3.905254e-08,0.999996,1.569321e-10
2,0.011091,0.436484,0.011136,0.568219,0.169467,0.124141,0.034058,0.124929,0.034371,0.377407,...,0.009925,0.017033,9.682137e-08,4.368750e-07,3.382763e-13,2.625306e-08,0.000004,4.689707e-08,0.999996,3.979242e-10
3,0.011091,0.436484,0.011136,0.568219,0.169467,0.124141,0.034058,0.124929,0.034371,0.377407,...,0.009925,0.017033,1.360843e-07,4.896199e-07,8.871602e-13,4.022162e-08,0.000004,5.684367e-08,0.999996,7.274722e-10
4,0.011091,0.436484,0.011136,0.568219,0.169467,0.124141,0.034058,0.124929,0.034371,0.377407,...,0.009925,0.017033,1.748982e-07,5.430905e-07,1.807375e-12,5.477694e-08,0.000004,6.887450e-08,0.999995,1.150064e-09
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,0.011091,0.436484,0.011136,0.568219,0.169467,0.124141,0.034058,0.124929,0.034371,0.377407,...,0.009925,0.017033,1.180371e-05,2.392104e-05,3.684765e-08,1.062271e-05,0.000072,1.851401e-05,0.999861,1.662657e-06
96,0.011091,0.436484,0.011136,0.568219,0.169467,0.124141,0.034058,0.124929,0.034371,0.377407,...,0.009925,0.017033,1.221523e-05,2.475419e-05,3.852763e-08,1.101256e-05,0.000075,1.918711e-05,0.999857,1.728138e-06
97,0.011091,0.436484,0.011136,0.568219,0.169467,0.124141,0.034058,0.124929,0.034371,0.377407,...,0.009925,0.017033,1.264135e-05,2.561665e-05,4.027366e-08,1.141618e-05,0.000077,1.988358e-05,0.999851,1.795951e-06
98,0.011091,0.436484,0.011136,0.568219,0.169467,0.124141,0.034058,0.124929,0.034371,0.377407,...,0.009925,0.017033,1.308261e-05,2.650949e-05,4.208813e-08,1.183405e-05,0.000080,2.060427e-05,0.999846,1.866178e-06


In [33]:
import numpy as np
np.testing.assert_allclose(sidarthe_samples.reset_index(level='sample_id', drop=True),
                           sidarthe_calibrated_samples,
                           rtol=.2)


AssertionError: 
Not equal to tolerance rtol=0.2, atol=0

(shapes (100, 27), (100, 24) mismatch)
 x: array([[9.201490e-03, 5.190758e-01, 9.800992e-03, ..., 3.666667e-07,
        3.333333e-08, 5.958760e-32],
       [9.201490e-03, 5.190758e-01, 9.800992e-03, ..., 4.226147e-07,...
 y: array([[1.109056e-02, 4.364837e-01, 1.113598e-02, ..., 3.333333e-08,
        9.999963e-01, 1.164111e-19],
       [1.109056e-02, 4.364837e-01, 1.113598e-02, ..., 3.905254e-08,...