In [None]:
import numpy as np

from lymph.models import Unilateral
from lymixture import LymphMixture
from lymixture.utils import map_to_simplex
from fixtures import (
    get_graph,
    get_patient_data,
    SIMPLE_SUBSITE,
)

In [None]:
# Setup: a mixture of `num_components` unilateral lymph models sharing the
# "medium" test graph, with patients split into subgroups by SIMPLE_SUBSITE.
num_components = 3
graph = get_graph(size="medium")
patient_data = get_patient_data()

mixture = LymphMixture(
    model_cls=Unilateral,
    model_kwargs=dict(graph_dict=graph),
    num_components=num_components,
)
# Modalities are registered before the data is loaded.
# NOTE(review): the two values are presumably [specificity, sensitivity] — confirm with lymph docs.
mixture.update_subgroup_modalities(dict(max_llh=[1.0, 1.0]))
mixture.load_patient_data(patient_data, split_by=SIMPLE_SUBSITE)
mixture.subgroups

In [None]:
# Inspect the modalities attached to one subgroup. The setup cell already
# registered the modalities and loaded the patient data, so neither step is repeated here.
mixture.subgroups["C05"].modalities

In [None]:
# Random initial responsibilities. Draw one point per patient in the
# (num_components - 1)-dimensional unit cube, then map each point with
# map_to_simplex. NOTE(review): presumably each mapped row then has num_components
# entries that sum to one — confirm against lymixture.utils.
# The generator is seeded so that Restart & Run All gives the same responsibilities.
resp_rng = np.random.default_rng(42)
resp_from_cube = resp_rng.uniform(size=(len(patient_data), num_components - 1))
resp = np.array([map_to_simplex(line) for line in resp_from_cube])

mixture.assign_responsibilities(resp)

In [None]:
# Responsibilities of all patients for the component with index 1.
component_1_resp = mixture.get_responsibilities(component=1)
component_1_resp

In [None]:
# Shape of the full responsibility table (all patients, all components).
all_resp = mixture.get_responsibilities()
all_resp.shape

In [None]:
# T-stages known to the mixture (presumably taken from the loaded patient data — confirm).
mixture_t_stages = mixture.t_stages
mixture_t_stages

In [None]:
# Random mixture coefficients with one row per component and one column per subgroup.
# Each column is normalized to sum to one across the components.
# The generator is seeded so that Restart & Run All is reproducible.
coef_rng = np.random.default_rng(42)
init_mixture_coefs = coef_rng.uniform(size=(num_components, len(mixture.subgroups)))
init_mixture_coefs /= init_mixture_coefs.sum(axis=0)
mixture.set_mixture_coefs(init_mixture_coefs)

In [None]:
# Show each subgroup's modalities, prefixed with the subgroup key so the lines
# can be told apart.
for subgroup_name, subgroup in mixture.subgroups.items():
    print(f"{subgroup_name}: {subgroup.modalities}")

In [None]:
# Diagnosis-time weights over 11 time steps (0..10) for each T-stage. They are
# applied one T-stage per call, "early" first, then "late".
# NOTE(review): the weights are not normalized here; presumably the library normalizes them — confirm.
# NOTE(review): "early" puts increasing weight on LATER time steps and "late" the
# reverse — intentional for this test, or swapped?
diag_time_weights = {
    "early": np.linspace(0, 10, 11),
    "late": np.linspace(10, 0, 11),
}
for t_stage, weights in diag_time_weights.items():
    mixture.update_component_diag_time_dists({t_stage: weights})

In [None]:
# Show each component's diagnosis-time distributions, prefixed with the component
# index so the output can be matched to its component.
for comp_idx, comp in enumerate(mixture.components):
    print(f"component {comp_idx}: {comp.diag_time_dists}")

In [None]:
# Seeded generator. It is reused by later cells, so keep the name `rng`.
rng = np.random.default_rng(42)

# Assign one uniform draw to every flattened component parameter. Parameters are
# assigned one at a time, in the order get_component_params lists them.
params_to_set = mixture.get_component_params(flatten=True)
for param in params_to_set:
    drawn_value = rng.uniform()
    mixture.assign_component_params(**{param: drawn_value})

mixture.get_component_params()

In [None]:
# Deliberately set raw uniform draws (not normalized) as mixture coefficients,
# so a later cell can demonstrate normalize_mixture_coefs.
raw_coefs = rng.uniform(size=(num_components, len(mixture.subgroups)))
mixture.set_mixture_coefs(raw_coefs)

In [None]:
# Per-subgroup sums of the current mixture coefficients (summed across components).
coef_column_sums = mixture.get_mixture_coefs().sum(axis=0)
coef_column_sums

In [None]:
# Normalize the mixture coefficients, then show the per-subgroup sums again for comparison.
mixture.normalize_mixture_coefs()
normalized_column_sums = mixture.get_mixture_coefs().sum(axis=0)
normalized_column_sums

In [None]:
# Complete-data likelihood under the current parameters, coefficients and
# responsibilities. NOTE(review): presumably a log-likelihood — confirm with lymixture docs.
complete_llh = mixture.complete_data_likelihood()
complete_llh