In [1]:
import numpy as np

from lymph.models import Unilateral
from lymixture import LymphMixture
from lymixture.utils import map_to_simplex
from fixtures import (
    get_graph,
    get_patient_data,
    SIMPLE_SUBSITE,
)

In [2]:
# Fixtures: the medium graph and the patient data table used throughout.
graph = get_graph(size="medium")
patient_data = get_patient_data()
num_components = 3

# A mixture of `num_components` unilateral models that all share the same graph.
mixture_kwargs = dict(
    model_cls=Unilateral,
    model_kwargs=dict(graph_dict=graph),
    num_components=num_components,
)
mixture = LymphMixture(**mixture_kwargs)

# Perfect "max_llh" modality (specificity = sensitivity = 1.0) for the subgroups,
# then split the patients into one subgroup per subsite.
mixture.update_subgroup_modalities({"max_llh": [1.0, 1.0]})
mixture.load_patient_data(patient_data, split_by=SIMPLE_SUBSITE)
mixture.subgroups

{'C01': <lymph.models.unilateral.Unilateral at 0x7f7b25618d60>,
 'C02': <lymph.models.unilateral.Unilateral at 0x7f7b25619390>,
 'C03': <lymph.models.unilateral.Unilateral at 0x7f7b256b8040>,
 'C04': <lymph.models.unilateral.Unilateral at 0x7f7b258b6d10>,
 'C05': <lymph.models.unilateral.Unilateral at 0x7f7b256e6fb0>,
 'C06': <lymph.models.unilateral.Unilateral at 0x7f7b258b6740>,
 'C09': <lymph.models.unilateral.Unilateral at 0x7f7b256b9090>,
 'C10': <lymph.models.unilateral.Unilateral at 0x7f7b255c22f0>}

In [3]:
# NOTE(review): this repeats the modality update and data load from the setup
# cell, this time with the subgroups already created. Confirm whether the first
# pass alone is sufficient before dropping this duplication.
subgroup_modalities = {"max_llh": [1.0, 1.0]}
mixture.update_subgroup_modalities(subgroup_modalities)
mixture.load_patient_data(patient_data, split_by=SIMPLE_SUBSITE)

# Spot-check one subgroup to confirm the modality reached it.
mixture.subgroups["C05"].modalities

{'max_llh': Clinical(specificity=1.0, sensitivity=1.0, is_trinary=False)}

In [4]:
# Random per-patient responsibilities. Each patient gets a point drawn from the
# (num_components - 1)-dimensional unit cube, which map_to_simplex turns into
# num_components weights (presumably summing to one; confirm in lymixture).
# Seeded so that Restart & Run All reproduces the outputs below.
rng = np.random.default_rng(42)
resp_from_cube = rng.uniform(size=(len(patient_data), num_components - 1))
resp = np.array([map_to_simplex(point) for point in resp_from_cube])
assert resp.shape == (len(patient_data), num_components)

mixture.assign_responsibilities(resp)

In [5]:
# Responsibilities for component 1: one value per patient row, returned as a pandas Series.
mixture.get_responsibilities(component=1)

0       0.429476
1       0.648337
2       0.818446
3       0.531305
4       0.479799
          ...   
1237    0.344419
1238    0.618100
1239    0.021707
1240    0.202453
1241    0.250294
Name: (_mixture, responsibility, 1), Length: 1242, dtype: float64

In [6]:
# Full responsibility table: one row per patient, one column per component.
# Assert the invariant instead of relying on reading the displayed shape.
responsibilities = mixture.get_responsibilities()
assert responsibilities.shape == (len(patient_data), num_components)
responsibilities.shape

(1242, 3)

In [7]:
# T-stages known to the mixture. Empty at this point, presumably because no
# diagnosis-time distributions have been assigned yet (done further down) — confirm.
mixture.t_stages

set()

In [8]:
# Random mixture coefficients with shape (components, subgroups). Each column
# is normalized so that a subgroup's component weights sum to one.
# Seeded for reproducibility under Restart & Run All.
rng = np.random.default_rng(43)
raw_coefs = rng.uniform(size=(num_components, len(mixture.subgroups)))
mixture_coefs = raw_coefs / raw_coefs.sum(axis=0)
assert np.allclose(mixture_coefs.sum(axis=0), 1.0)

mixture.assign_mixture_coefs(mixture_coefs)

In [9]:
# List the modalities of every subgroup model, one line per subgroup.
for subgroup_model in mixture.subgroups.values():
    modalities = subgroup_model.modalities
    print(modalities)

{'max_llh': Clinical(specificity=1.0, sensitivity=1.0, is_trinary=False)}
{'max_llh': Clinical(specificity=1.0, sensitivity=1.0, is_trinary=False)}
{'max_llh': Clinical(specificity=1.0, sensitivity=1.0, is_trinary=False)}
{'max_llh': Clinical(specificity=1.0, sensitivity=1.0, is_trinary=False)}
{'max_llh': Clinical(specificity=1.0, sensitivity=1.0, is_trinary=False)}
{'max_llh': Clinical(specificity=1.0, sensitivity=1.0, is_trinary=False)}
{'max_llh': Clinical(specificity=1.0, sensitivity=1.0, is_trinary=False)}
{'max_llh': Clinical(specificity=1.0, sensitivity=1.0, is_trinary=False)}


In [10]:
# Diagnosis-time distributions over 11 time steps, one per T-stage. They are
# handed to the mixture components one T-stage at a time ("early", then "late").
# NOTE(review): "early" increases over time while "late" decreases, which looks
# reversed for T-stages. Confirm this is intentional for the demo.
# NOTE(review): the arrays are not normalized here; presumably the distribution
# class normalizes them. Confirm.
diag_time_dists = {
    "early": np.linspace(0, 10, 11),
    "late": np.linspace(10, 0, 11),
}
for t_stage, dist_values in diag_time_dists.items():
    mixture.update_component_diag_time_dists({t_stage: dist_values})

In [11]:
# Show the diagnosis-time distributions attached to each mixture component.
for component in mixture.components:
    component_dists = component.diag_time_dists
    print(component_dists)

{'early': <lymph.diagnose_times.Distribution object at 0x7f7b257e9c90>, 'late': <lymph.diagnose_times.Distribution object at 0x7f7b258b7040>}
{'early': <lymph.diagnose_times.Distribution object at 0x7f7b257e9a80>, 'late': <lymph.diagnose_times.Distribution object at 0x7f7b257e8850>}
{'early': <lymph.diagnose_times.Distribution object at 0x7f7b257e9ba0>, 'late': <lymph.diagnose_times.Distribution object at 0x7f7b257e92a0>}


In [14]:
# NOTE(review): the recorded output is the tuple (16, 20), not a single
# likelihood value. Confirm what complete_data_likelihood() is meant to return.
# The execution count jumps from 11 to 14, so cells 12-13 were possibly run and
# removed. Re-run from a fresh kernel to make sure no hidden state is needed here.
mixture.complete_data_likelihood()

(16, 20)