In [None]:
# Automatically reload edited modules (e.g. the microsim package imported below) before
# executing each cell, so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from collections import defaultdict
from dataclasses import dataclass
import itertools
from math import isclose
from pathlib import Path
from pprint import pprint
import re
import secrets

import numpy as np
import pandas as pd
import scipy
import scipy.stats

from microsim.person.bpcog_person_records import (
    BPCOGPersonStaticRecordProtocol,
    BPCOGPersonDynamicRecordProtocol,
)
from microsim.store.numpy_record_mapping import NumpyRecordMapping, NumpyEventRecordMapping
from microsim.store.numpy_person_store import NumpyPersonStore
from microsim.test._validation.fixture import StorePopulationValidationFixture
from microsim.test._validation.helper import person_obj_to_person_record

In [None]:
# Directory holding the saved .npz advance results. `_dh` is IPython's directory history;
# outside IPython the fallback '' makes results_dir relative to the current working directory.
results_dir = Path(globals().get('_dh', [''])[0], 'advance_results').resolve()
num_persons = 1000
nhanes_year = 2013
num_populations = 1
num_advances = 50
# Seeds are pinned for reproducibility; to draw fresh ones, use:
#all_load_seeds = [secrets.randbits(32) for _ in range(num_populations)]
#all_advance_seeds = [secrets.randbits(32) for _ in range(num_advances)]
all_load_seeds = [1324982946]
all_advance_seeds = [
    3454524075,
    1015365140,
    158755841,
    4277030764,
    1899620253,
    163344017,
    483132396,
    4225139234,
    2427105107,
    3195315108,
    2377931612,
    2619141431,
    64413361,
    4009282022,
    623668877,
    998728319,
    3168347689,
    832960366,
    1880079587,
    1752605368,
    3001635028,
    908814873,
    1054896205,
    2956781680,
    1191678910,
    4147698952,
    2916075797,
    4206615114,
    420251420,
    2517593098,
    3588285603,
    367289534,
    3854127762,
    524933424,
    3796129397,
    3612644617,
    3381654696,
    1813454644,
    1921100965,
    1575663705,
    900397490,
    2503793528,
    2218278465,
    1134845661,
    2172656593,
    573979210,
    161619223,
    500497969,
    798739760,
    1773047922,
]

# The commented-out generators above derive the seed lists from these counts, so the pinned
# lists must stay in sync with them if either is edited.
assert len(all_load_seeds) == num_populations, (len(all_load_seeds), num_populations)
assert len(all_advance_seeds) == num_advances, (len(all_advance_seeds), num_advances)

In [None]:
# Fail fast: the advance loop below saves its .npz results into this directory, which is
# deliberately not created automatically. `is_dir()` (rather than `exists()`) also rejects
# a regular file at that path, which would otherwise only fail later when saving.
if not results_dir.is_dir():
    raise ValueError(
        "Results dir does not exist or is not a directory; please create it before"
        f" continuing: '{results_dir}'"
    )

In [None]:
# Properties expected to be identical between the vectorized and store populations.
exactly_eq_props = [
    "gender",
    "raceEthnicity",
    "education",
    "smokingStatus",
    "gcpRandomEffect",
    "otherLipidLowerMedication",
    "selfReportMIAge",
    "selfReportStrokeAge",
    "age",
]
# Floating-point properties compared with a tolerance (see diff_approx_eq_props).
approx_eq_props = [
    "sbp",
    "dbp",
    "a1c",
    "hdl",
    "trig",
    "totChol",
    "bmi",
    "waist",
]
# Properties that can only be compared in distribution across many advances.
# (Every entry needs a trailing comma: adjacent string literals would otherwise be
# silently concatenated into a single bogus property name.)
statistically_eq_props = [
    "antiHypertensiveCount",
    "statin",
    "bpMedsAdded",
    "afib",
    "qalys",
    "anyPhysicalActivity",
    "alcoholPerWeek",
]

def diff_exactly_eq_props(vec_rec, store_rec, prop_names=None):
    """Compare properties that must match exactly between a vectorized and a store record.

    Two NaN values are treated as equal: NaN never compares equal to itself, so a value
    that is missing (NaN) in both records would otherwise be reported as a spurious diff.

    Args:
        vec_rec: record produced from the vectorized population.
        store_rec: record produced from the person-store population.
        prop_names: attribute names to compare; defaults to ``exactly_eq_props``.

    Returns:
        Dict mapping each differing property name to a ``(vec_val, store_val)`` tuple;
        empty when every compared property matches.
    """
    if prop_names is None:
        prop_names = exactly_eq_props

    def _is_nan(value):
        # np.isnan raises TypeError for non-numeric values (strings, enums, None, ...)
        try:
            return bool(np.isnan(value))
        except TypeError:
            return False

    diffs = {}
    for prop_name in prop_names:
        vec_val = getattr(vec_rec, prop_name)
        store_val = getattr(store_rec, prop_name)
        if vec_val == store_val:
            continue
        if _is_nan(vec_val) and _is_nan(store_val):
            continue
        diffs[prop_name] = (vec_val, store_val)
    return diffs

def diff_approx_eq_props(vec_rec, store_rec, prop_names=None, rel_tol=1e-09, abs_tol=0.0):
    """Compare numeric properties that should match within a floating-point tolerance.

    Uses ``math.isclose`` (the default tolerances are isclose's own defaults). Two NaN
    values are treated as equal, since ``isclose(nan, nan)`` is False.

    Args:
        vec_rec: record produced from the vectorized population.
        store_rec: record produced from the person-store population.
        prop_names: attribute names to compare; defaults to ``approx_eq_props``.
        rel_tol: relative tolerance passed to ``math.isclose``.
        abs_tol: absolute tolerance passed to ``math.isclose``; raise it when values can
            legitimately be near zero, where a purely relative tolerance is too strict.

    Returns:
        Dict mapping each differing property name to a ``(vec_val, store_val)`` tuple;
        empty when every compared property matches.
    """
    if prop_names is None:
        prop_names = approx_eq_props
    diffs = {}
    for prop_name in prop_names:
        vec_val = getattr(vec_rec, prop_name)
        store_val = getattr(store_rec, prop_name)
        if isclose(vec_val, store_val, rel_tol=rel_tol, abs_tol=abs_tol):
            continue
        # a value missing (NaN) in both records is not a difference
        if np.isnan(vec_val) and np.isnan(store_val):
            continue
        diffs[prop_name] = (vec_val, store_val)
    return diffs

def diff_person_props(vec_rec, store_rec):
    """Collect all exact and approximate property mismatches for a single person.

    Returns:
        A dict that may contain ``'eq'`` (result of diff_exactly_eq_props) and/or
        ``'approx_eq'`` (result of diff_approx_eq_props). A key is present only when
        that comparison found at least one difference, so an empty dict means the
        two records agree.
    """
    comparisons = (
        ('eq', diff_exactly_eq_props),
        ('approx_eq', diff_approx_eq_props),
    )
    person_diffs = {}
    for category, compare in comparisons:
        category_diffs = compare(vec_rec, store_rec)
        if category_diffs:
            person_diffs[category] = category_diffs
    return person_diffs

def diff_statistically_eq_props(vec_records, store_records):
    """Placeholder for a statistical comparison of two record collections.

    Not implemented. Raises instead of silently returning None, which a caller could
    mistake for "no differences". The statistical comparison is currently done inline
    with TTestEqualityTest later in this notebook.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("diff_statistically_eq_props is not implemented yet")

def get_all_person_diffs(vec_rec, store_rec):
    """Placeholder for collecting every kind of diff for one person.

    Not implemented. Raises instead of silently returning None, which is falsy and would
    read as "no differences" to a caller such as ``if get_all_person_diffs(...)``.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("get_all_person_diffs is not implemented yet")


In [None]:
static_record_mapping = NumpyRecordMapping(BPCOGPersonStaticRecordProtocol)
dynamic_record_mapping = NumpyRecordMapping(BPCOGPersonDynamicRecordProtocol)
event_record_mapping = NumpyEventRecordMapping()
# Set of every property name known to any of the static, dynamic or event record mappings.
all_prop_names = set().union(
    *(
        mapping.property_mappings.keys()
        for mapping in (static_record_mapping, dynamic_record_mapping, event_record_mapping)
    )
)

def load_store_from_file(npz_file):
    """Rebuild a NumpyPersonStore from a previously saved ``.npz`` file.

    The store is constructed with the notebook-level ``num_persons`` and record
    mappings; presumably these must match the ones used when the file was saved
    (confirm against NumpyPersonStore).
    """
    # NOTE(review): the positional `1` is interpreted by NumpyPersonStore; confirm its meaning.
    store_args = (
        num_persons,
        1,
        static_record_mapping,
        dynamic_record_mapping,
        event_record_mapping,
    )
    return NumpyPersonStore(*store_args, npz_file=npz_file)

def vec_results_to_person_store(fixture):
    """Copy the vectorized population's state into a NumpyPersonStore.

    The store is seeded with ``person_obj_to_person_record(person, 0)`` for every
    person in ``fixture.vec_pop``. Then every property in ``all_prop_names`` taken from
    ``person_obj_to_person_record(person, 1)`` is written into the matching store
    record's ``.next`` slot. Presumably the second helper argument selects the time
    step (before / after the advance); confirm against the helper.

    Returns:
        The populated NumpyPersonStore.
    """
    people = fixture.vec_pop._people
    initial_records = [person_obj_to_person_record(person, 0) for person in people]
    vec_results_store = NumpyPersonStore(
        num_persons,
        1,
        static_record_mapping,
        dynamic_record_mapping,
        event_record_mapping,
        initial_records,
    )
    population = vec_results_store.get_population_at(t=0)
    for store_record, person in zip(population, people):
        advanced_record = person_obj_to_person_record(person, 1)
        for prop in all_prop_names:
            setattr(store_record.next, prop, getattr(advanced_record, prop))
    return vec_results_store

# File names produced by get_result_save_path, e.g. "result.123.456.store.npz".
# Groups: (load seed, advance seed, population type). Every literal dot is escaped; an
# unescaped "." would match any character (e.g. "resultX1.2.vec.npz").
RESULT_NPZ_PATH_PATTERN = re.compile(r'^result\.(-?\d+)\.(-?\d+)\.(store|vec)\.npz$', re.ASCII)

def get_result_save_path(load_seed, advance_seed, pop_type):
    """Build the file name used to save/load one advance result.

    Args:
        load_seed: integer seed used to load the population.
        advance_seed: integer seed used for the advance.
        pop_type: ``"store"`` or ``"vec"``.

    Returns:
        The bare file name ``result.<load_seed>.<advance_seed>.<pop_type>.npz``
        (callers join it with the results directory).

    Raises:
        ValueError: if the name does not match RESULT_NPZ_PATH_PATTERN, e.g. for a
            non-integer seed or an unknown pop_type.
    """
    npz_path = f"result.{load_seed}.{advance_seed}.{pop_type}.npz"
    if RESULT_NPZ_PATH_PATTERN.match(npz_path) is None:
        raise ValueError(f"Invalid result save path: {npz_path}")
    return npz_path

In [None]:
# cache results of last run to confirm if save/load results are same as original, if desired
# (At notebook top level locals() is the module namespace, so this picks up `results` from a
# previous execution of this cell; the round-trip check cell below compares against `_results`.)
if 'results' in locals():
    _results = locals()['results']

# results[load_seed][advance_seed][pop_type] -> person store, where pop_type is "store" or "vec"
results = defaultdict(lambda: defaultdict(dict))
fixture = StorePopulationValidationFixture()
for load_seed in all_load_seeds:
    # NOTE(review): sets private fixture state; presumably the loader seed determines which
    # population get_or_init_person_records produces — confirm against the fixture.
    fixture._loader_seed = load_seed
    fixture._person_records = fixture.get_or_init_person_records(num_persons, nhanes_year)

    for advance_seed in all_advance_seeds:
        vec_result_path = Path(results_dir, get_result_save_path(load_seed, advance_seed, 'vec'))
        store_result_path = Path(
            results_dir, get_result_save_path(load_seed, advance_seed, 'store')
        )
        # Reuse previously saved results when both files exist. The `else: continue` only runs
        # when loading succeeded; on any load error we fall through and recompute, and the
        # recompute below overwrites any entries that were set before the error.
        if vec_result_path.exists() and store_result_path.exists():
            try:
                vec_person_store = load_store_from_file(vec_result_path)
                person_store = load_store_from_file(store_result_path)
                results[load_seed][advance_seed]["store"] = person_store
                results[load_seed][advance_seed]["vec"] = vec_person_store
            except Exception as e:
                print(
                    "Unexpected error occurred while loading results (will proceed with regular"
                    f" advance): {e}"
                )
            else:
                continue

        # NOTE(review): presumably setUp() rebuilds store_pop / vec_pop from the fixture's
        # person records — confirm.
        fixture.setUp()
        # The seed is set once before both advances. If both draw from NumPy's global RNG
        # (assumption — confirm), the vectorized advance starts from whatever RNG state the
        # store advance left behind, so the two populations do not consume identical draws.
        np.random.seed(advance_seed)
        fixture.store_pop.advance()
        fixture.vec_pop.advance_vectorized(years=1)

        vec_person_store = vec_results_to_person_store(fixture)
        vec_person_store.save_to_file(vec_result_path)
        person_store = fixture.store_pop.person_store

        person_store.save_to_file(store_result_path)
        results[load_seed][advance_seed]["store"] = person_store
        results[load_seed][advance_seed]["vec"] = vec_person_store

In [None]:
def _stored_values_match(a, b):
    """Equality for the save/load round-trip check that also treats two NaNs as equal."""
    if a == b:
        return True
    # NaN != NaN, so a value missing in both stores would otherwise fail the check;
    # np.isnan raises TypeError for non-numeric values, which are simply unequal here.
    try:
        return bool(np.isnan(a)) and bool(np.isnan(b))
    except TypeError:
        return False


# Sentinel used by zip_longest to detect populations of different sizes.
_MISSING = object()

# Runs only after the advance cell has been executed at least twice: `_results` then holds the
# previous run's stores and `results` the reloaded ones, which must match value for value.
if '_results' in locals() and 'results' in locals():
    for load_seed, advance_seed, result_type in itertools.product(
        all_load_seeds,
        all_advance_seeds,
        ["store", "vec"],
    ):
        original_store = _results[load_seed][advance_seed][result_type]
        loaded_store = results[load_seed][advance_seed][result_type]
        orig_pop = original_store.get_population_at(t=0)
        load_pop = loaded_store.get_population_at(t=0)
        # zip_longest (not zip) so a population-size mismatch fails instead of being
        # silently truncated to the shorter population.
        for i, (p, q) in enumerate(itertools.zip_longest(orig_pop, load_pop, fillvalue=_MISSING)):
            assert p is not _MISSING and q is not _MISSING, (
                f"({load_seed}, {advance_seed}, {result_type}, {i}) population size mismatch"
            )
            for n in all_prop_names:
                p_val = getattr(p.current, n)
                q_val = getattr(q.current, n)
                assert _stored_values_match(p_val, q_val), (
                    f"({load_seed}, {advance_seed}, {result_type}, {i}, 0, {n})"
                )

                p_val = getattr(p.next, n)
                q_val = getattr(q.next, n)
                assert _stored_values_match(p_val, q_val), (
                    f"({load_seed}, {advance_seed}, {result_type}, {i}, 1, {n})"
                )


In [None]:
# get person diffs
# For every (load seed, advance seed) pair, compare each person's post-advance (`.next`)
# record between the vectorized and store results; print and fail on any mismatch.
for load_seed, advance_results in results.items():
    for advance_seed, result_stores in advance_results.items():
        vec_records, store_records = (
            [rec.next for rec in result_stores[pop_type].get_population_at(t=0)]
            for pop_type in ("vec", "store")
        )

        indexed_diffs = enumerate(map(diff_person_props, vec_records, store_records))
        all_person_diffs = [(idx, person_diff) for idx, person_diff in indexed_diffs if person_diff]
        if all_person_diffs:
            pprint(all_person_diffs)
        assert not all_person_diffs

In [None]:
# run statistical person diffs
class TTestEqualityTest:
    """Callable that checks whether two samples are statistically indistinguishable.

    Runs an independent two-sample t-test (``scipy.stats.ttest_ind``) and treats the
    samples as equal when the p-value exceeds ``p_threshold``. When both samples have
    zero variance the t-test is undefined, so exact equality is checked instead.
    """

    def __init__(self, p_threshold):
        self._p_threshold = p_threshold

    @property
    def p_threshold(self):
        """p-value above which the two samples are considered equal."""
        return self._p_threshold

    def __call__(self, vec_data, store_data):
        """Compare two 1-D samples (lists or arrays).

        Returns:
            ExactEqualityResult when both samples have zero variance, otherwise a
            TTestEqualityResult carrying the threshold, p-value and t-statistic.
        """
        # t-test will fail if vec_data and store_data each have zero variance
        # so handle the exact equality case before running the t-test
        if 0 == np.var(vec_data) == np.var(store_data):
            # np.array_equal yields a single bool for both lists and arrays; `==` on
            # ndarrays would produce an element-wise array instead.
            return ExactEqualityResult(bool(np.array_equal(vec_data, store_data)))

        t_statistic, p_value = scipy.stats.ttest_ind(vec_data, store_data)
        # a NaN p-value compares False, so an undefined test is reported as not equal
        is_stat_eq = bool(p_value > self.p_threshold)
        return TTestEqualityResult(is_stat_eq, self.p_threshold, p_value, t_statistic)


@dataclass
class ExactEqualityResult:
    """Outcome of an exact comparison (used when neither sample varies)."""
    is_equal: bool

@dataclass
class TTestEqualityResult:
    """Outcome of a two-sample t-test comparison."""
    is_equal: bool  # True when p_value > p_threshold
    p_threshold: float
    p_value: float
    t_statistic: float

all_stat_eq_tests = [
    ("statin", TTestEqualityTest(0.05)),
    ("afib", TTestEqualityTest(0.05)),
    ("anyPhysicalActivity", TTestEqualityTest(0.05)),
    # TODO: pair each of these with an equality test before enabling them
    #"antiHypertensiveCount",
    #"bpMedsAdded",
    #"qalys",
    #"alcoholPerWeek",
]

# {prop_name: [result for person 0, result for person 1, ...]}
# Bound before the loop so the summary cell never reads a stale value left over from an
# earlier run (or raises NameError) when `results` is empty.
stat_eq_results = defaultdict(list)
for advance_results in results.values():
    # prop_values[pop_type][person_index][prop_name] -> one value per advance seed, in the
    # iteration order of `advance_results`
    prop_values = {
        "store": [defaultdict(list) for _ in range(num_persons)],
        "vec": [defaultdict(list) for _ in range(num_persons)],
    }
    # Read each population once and collect every tested property from it, instead of
    # re-reading every population once per property.
    for result_stores in advance_results.values():
        for pop_type in ("vec", "store"):
            for i, rec in enumerate(result_stores[pop_type].get_population_at(t=0)):
                for prop_name, _ in all_stat_eq_tests:
                    prop_values[pop_type][i][prop_name].append(getattr(rec.next, prop_name))

    # NOTE: one test runs per person, so at p_threshold=0.05 roughly up to 5% of persons may
    # "fail" by chance alone even when both implementations agree in distribution.
    for prop_name, stat_eq_test in all_stat_eq_tests:
        for i in range(num_persons):
            stat_test_result = stat_eq_test(
                prop_values["vec"][i][prop_name],
                prop_values["store"][i][prop_name],
            )
            stat_eq_results[prop_name].append(stat_test_result)
    break  # focusing only on one load seed for now

In [None]:
# print test failure info
# For each tested property, list the (person index, result) pairs whose check failed.
for prop_name in stat_eq_results:
    failures = [
        pair for pair in enumerate(stat_eq_results[prop_name]) if not pair[1].is_equal
    ]
    print(f"For prop '{prop_name}':")
    print(f"'{prop_name}': {len(failures)} failures")
    pprint(failures)
    print()