In [None]:
from summer2 import CompartmentalModel, Stratification
from summer2.parameters import CompartmentValues, Parameter, Time, Function
from summer2.functions import time as stf

from jax import numpy as jnp
import numpy as np
import pandas as pd

In [None]:
from summer2 import inspect as mi

In [None]:
class NStrat:
    """A named stratification: the strata to split into plus a selector query.

    Args:
        name: Stratification name. NModel.stratify uses it as the key added to
            each new compartment's ``strata`` dict.
        strata: The strata each selected compartment is split into.
        stratifies: Query dict passed to ``NModel.query_compartments`` to pick
            the compartments to split. Any falsy value is replaced by ``{}``.
        is_base: Flag marking a model's initial stratification; stored as-is.
    """

    def __init__(self, name, strata, stratifies=None, is_base=False):
        self.name = name
        self.strata = strata
        # Falsy selectors (None, {}) collapse to a fresh empty dict.
        self.stratifies = stratifies if stratifies else {}
        self.is_base = is_base

    def __repr__(self):
        return f"{self.name}: {self.strata}"

class NComp:
    """A single model compartment together with the strata it belongs to.

    Args:
        name: Compartment name; also used as its ``repr`` and as the hash key.
        strata: Dict of stratification name -> stratum for this compartment.
        idx: Position in the owning model's compartment list (``None`` until
            NModel assigns it).
    """

    def __init__(self, name, strata, idx=None):
        self.idx = idx
        self.strata = strata
        self.name = name

    def __repr__(self):
        return self.name

    # Hash by name only; equality is still Python's default identity comparison.
    def __hash__(self) -> int:
        return hash(self.name)

In [None]:
class CompartmentQuery:
    """Thin wrapper over a list of compartments returned by a query.

    Exposes the matched compartments' names and their integer indices.
    """

    def __init__(self, data: list[NComp]):
        self.compartments = data

    @property
    def names(self) -> list[str]:
        """Names of the wrapped compartments, in stored order."""
        result = []
        for comp in self.compartments:
            result.append(comp.name)
        return result

    @property
    def index(self) -> np.ndarray[int]:
        """Integer array of the wrapped compartments' ``idx`` values."""
        positions = [comp.idx for comp in self.compartments]
        return np.array(positions, dtype=int)

    def __repr__(self):
        return f"CompartmentQuery: {self.compartments!r}"


In [None]:
class NModel:
    """A compartment list that can be split by successive stratifications.

    Args:
        init_comps: Names of the initial, unstratified compartments.
        init_strat: Name of the initial stratification. Each initial compartment
            ``k`` gets ``{init_strat: k}`` as its strata.
    """

    def __init__(self, init_comps, init_strat="state"):
        # Materialise once: supports generators and decouples us from the caller's list.
        init_comps = list(init_comps)
        self.compartments = [NComp(k, {init_strat: k}, i) for i, k in enumerate(init_comps)]
        self.flows = []
        # is_base must be passed by keyword: NStrat's third positional parameter
        # is `stratifies`, so a positional True would be stored there instead.
        self.stratifications = {init_strat: NStrat(init_strat, init_comps, is_base=True)}

    def query_compartments(self, q: dict) -> CompartmentQuery:
        """Return the compartments matching query ``q``, via summer2.inspect."""
        return CompartmentQuery(mi.query_compartments(self, q))

    def stratify(self, strat):
        """Split the compartments selected by ``strat.stratifies`` into ``strat.strata``.

        Each selected compartment ``c`` is replaced, at its position, by one new
        compartment per stratum, named ``f"{c.name}_{stratum}"``, whose strata
        are ``c.strata`` plus ``{strat.name: stratum}``. Unselected compartments
        are kept. Afterwards every compartment's ``idx`` is renumbered.

        Raises:
            ValueError: if a stratification named ``strat.name`` already exists,
                or if ``strat.strata`` is empty (which would delete the selected
                compartments).
        """
        if strat.name in self.stratifications:
            raise ValueError(f"Stratification {strat.name!r} has already been applied")
        if not strat.strata:
            raise ValueError(f"Stratification {strat.name!r} has no strata")

        # A set uses NComp's hash/eq, so membership agrees with the list check
        # it replaces while avoiding a linear scan per compartment.
        to_split = set(self.query_compartments(strat.stratifies).compartments)

        new_comps = []
        for c in self.compartments:
            if c in to_split:
                new_comps.extend(
                    NComp(f"{c.name}_{stratum}", c.strata | {strat.name: stratum})
                    for stratum in strat.strata
                )
            else:
                new_comps.append(c)

        for i, c in enumerate(new_comps):
            c.idx = i

        self.compartments = new_comps
        self.stratifications[strat.name] = strat

def get_category_indexer(m: "NModel", query: list[dict]) -> np.ndarray:
    """Stack the compartment indices matched by each query into one integer array.

    Args:
        m: Model to query; each query dict is passed to ``m.query_compartments``.
        query: One query dict per category.

    Returns:
        Integer array whose row ``k`` holds the indices matched by ``query[k]``.
        Every query must match the same number of compartments; ragged results
        make numpy raise.
    """
    # dtype=int keeps the result usable as an index even when `query` is empty
    # (np.array([]) would otherwise default to float64).
    return np.array([m.query_compartments(q).index for q in query], dtype=int)



In [None]:
def proportional(weights, axis=None):
    """Normalise ``weights`` so they sum to 1.

    Args:
        weights: Array with a ``.sum`` method (e.g. a numpy or jax array).
        axis: If ``None`` (default), normalise over all elements. Otherwise
            normalise along this axis, so each slice along it sums to 1.

    Returns:
        Array of the same shape as ``weights``. No zero-total check is made;
        a zero sum produces nan/inf entries.
    """
    if axis is None:
        return weights / weights.sum()
    # keepdims so the per-slice totals broadcast back against `weights`.
    return weights / weights.sum(axis=axis, keepdims=True)

In [None]:
# Quick check: weights [1, 2, 1] normalise to [0.25, 0.5, 0.25].
proportional(np.asarray([1.0, 2.0, 1.0]))

In [None]:
# infection_[S->I]
# 1:1 infection_[S->I]_age
# infection_[S->I]

In [None]:
# Strains with exclusivity:
# S -> [I1, I2] -> R
# Strains with simultaneous infection:
# S -> [I1, I2] -> [R1, R2]
# I1 -> I2, I2 -> I1

# Weighted adjustments; the weights are normalised to sum to 1.0:
# x = [1.0, mod2, mod3] -> x / x.sum()

# add_adjustments(flow, source, dest, adjp)
# adjp = proportional(x)

In [None]:
# stratify

# If a stratification affects any compartments involved in a flow:
# if both source and dest are affected (i.e. the mapping remains equivalent), then do nothing?
# if only dest is affected (i.e. we are branching outward), apply adjustments (defaulting to an even split with weights summing to 1.0)
# if only source is affected (i.e. we are branching inward), we probably don't need anything?

# Flow params can be:
# a scalar
# the shape of the flow, iff source and dest are the same size
# the shape of either source or dest if these two differ in size (with broadcasting rules specified)

In [None]:
# Demo: build a model with two initial compartments; "base" names the initial
# stratification. NOTE: stratify() mutates nm in place, so the cells below are
# not idempotent -- restart and run from the top rather than re-running one cell.
nm = NModel(["pop", "extras"], "base")
nm.stratifications

In [None]:
# Split "pop" into S/I/R; the {"base": "pop"} selector presumably leaves "extras" untouched.
nm.stratify(NStrat("state", ["S","I","R"], {"base": "pop"}))

In [None]:
# Empty query -- presumably matches every compartment (confirm against summer2.inspect).
nm.query_compartments({})

In [None]:
# New compartments keep base="pop" in their strata, so this presumably selects all pop_* compartments.
nm.stratify(NStrat("age", ["child", "adult"], {"base": "pop"}))

In [None]:
# Job strata only for compartments whose age stratum is "adult".
nm.stratify(NStrat("job", ["frontline", "office", "unemployed"], {"age": "adult"}))

In [None]:
nm.compartments

In [None]:
# Severity strata only for compartments in state "I".
nm.stratify(NStrat("severity", ["asymptomatic", "mild", "severe"], {"state": "I"}))
nm.compartments

In [None]:
# Query by a stratum added last in the chain.
nm.query_compartments({"severity": "mild"})