In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from collections import defaultdict
from math import isclose
import secrets
from pprint import pprint
import numpy as np
import pandas as pd
import scipy
from microsim.test._validation.fixture import StorePopulationValidationFixture
from microsim.test._validation.helper import person_obj_to_person_record

In [None]:
# Population under validation: how many person records to request and the
# NHANES year handed to the fixture's record loader.
nhanes_year = 2013
num_persons = 1000

# Number of independent population loads and number of advances compared per load.
num_advances = 10
num_populations = 1

# Fresh 32-bit seeds are drawn on every run: one per population load and one
# per advance. They are not fixed, so each execution exercises new seeds.
all_load_seeds = [secrets.randbits(32) for _load in range(num_populations)]
all_advance_seeds = [secrets.randbits(32) for _advance in range(num_advances)]

In [None]:
# Properties that diff_exactly_eq_props compares with `==`.
exactly_eq_props = [
    "gender",
    "raceEthnicity",
    "education",
    "smokingStatus",
    "gcpRandomEffect",
    "otherLipidLowerMedication",
    "selfReportMIAge",
    "selfReportStrokeAge",
    "age",
]

# Properties that diff_approx_eq_props compares with math.isclose.
approx_eq_props = [
    "sbp",
    "dbp",
    "a1c",
    "hdl",
    "trig",
    "totChol",
    "bmi",
    "waist",
]

# Properties not covered by the per-person checks above.
# NOTE(review): presumably meant for diff_statistically_eq_props, which is
# still a stub — confirm once it is implemented.
statistically_eq_props = [
    "antiHypertensiveCount",
    "statin",
    "bpMedsAdded",
    "afib",
    # The comma after "qalys" was missing, which silently concatenated it with
    # the next literal into the single bogus name "qalysanyPhysicalActivity".
    "qalys",
    "anyPhysicalActivity",
    "alcoholPerWeek",
]

def diff_exactly_eq_props(vec_rec, store_rec):
    """Collect the `exactly_eq_props` attributes on which two records disagree.

    Args:
        vec_rec: first record; its value comes first in each reported pair.
        store_rec: second record; its value comes second in each reported pair.

    Returns:
        dict mapping each mismatching property name to a
        ``(vec_value, store_value)`` tuple, in `exactly_eq_props` order.
        Empty when every property compares equal.
    """
    value_triples = (
        (name, getattr(vec_rec, name), getattr(store_rec, name))
        for name in exactly_eq_props
    )
    # The test is `not (a == b)` rather than `a != b`, so the outcome is decided
    # by the values' `==` alone.
    return {
        name: (vec_val, store_val)
        for name, vec_val, store_val in value_triples
        if not (vec_val == store_val)
    }

def diff_approx_eq_props(vec_rec, store_rec):
    """Collect the `approx_eq_props` attributes whose values are not close.

    Closeness is decided by `math.isclose` called with its default tolerances.

    Args:
        vec_rec: first record; its value comes first in each reported pair.
        store_rec: second record; its value comes second in each reported pair.

    Returns:
        dict mapping each mismatching property name to a
        ``(vec_value, store_value)`` tuple; empty when all values are close.
    """
    mismatches = {}
    for name in approx_eq_props:
        pair = (getattr(vec_rec, name), getattr(store_rec, name))
        if isclose(*pair):
            continue
        mismatches[name] = pair
    return mismatches

def diff_person_props(vec_rec, store_rec):
    """Run the exact and approximate per-person comparisons on a record pair.

    Args:
        vec_rec: first record, forwarded unchanged to both comparison helpers.
        store_rec: second record, forwarded unchanged to both comparison helpers.

    Returns:
        dict with an ``'eq'`` entry holding the non-empty result of
        `diff_exactly_eq_props` and an ``'approx_eq'`` entry holding the
        non-empty result of `diff_approx_eq_props`. A key is omitted when its
        comparison found nothing, so the dict is empty for matching records.
    """
    # Exact comparison runs first, then the approximate one.
    checks = (
        ('eq', diff_exactly_eq_props),
        ('approx_eq', diff_approx_eq_props),
    )
    person_diffs = {}
    for label, differ in checks:
        found = differ(vec_rec, store_rec)
        if found:
            person_diffs[label] = found
    return person_diffs

def diff_statistically_eq_props(vec_records, store_records):
    """Placeholder for a population-level comparison; currently does nothing.

    TODO: not implemented. Both arguments are ignored and None is returned.
    NOTE(review): presumably intended to compare `statistically_eq_props`
    across the two record collections — confirm before implementing.
    """
    return None

def get_all_person_diffs(vec_rec, store_rec):
    """Placeholder; currently does nothing.

    TODO: not implemented. Both arguments are ignored and None is returned.
    """
    return None

In [None]:
# One result list per load seed. Each list receives one
# (person_store, vec_people) pair per advance seed, in `all_advance_seeds` order.
people_by_seeds = [[] for _ in all_load_seeds]
fixture = StorePopulationValidationFixture()
for pop_idx, load_seed in enumerate(all_load_seeds):
    results = people_by_seeds[pop_idx]

    # Point the fixture at this load seed before requesting the person records.
    fixture._loader_seed = load_seed
    fixture._person_records = fixture.get_or_init_person_records(num_persons, nhanes_year)

    for advance_seed in all_advance_seeds:
        # setUp() runs before every advance; presumably it rebuilds store_pop
        # and vec_pop from the person records — confirm in the fixture.
        fixture.setUp()

        # NOTE(review): the global NumPy generator is seeded once, then the
        # store population advances before the vectorized one. Confirm the two
        # advances are not expected to start from the same generator state.
        np.random.seed(advance_seed)

        fixture.store_pop.advance()
        fixture.vec_pop.advance_vectorized(years=1)

        advanced_pair = (fixture.store_pop.person_store, fixture.vec_pop._people)
        results.append(advanced_pair)

In [None]:
# Compare each vectorized record with its person-store counterpart, run by run,
# and fail on the first run that shows any per-person difference.
for i, results_by_load_seed in enumerate(people_by_seeds):
    for j, (person_store, vec_people) in enumerate(results_by_load_seed):
        # i and j index the seed lists in the order the runs were recorded.
        # The seeds are drawn at random, so they are reported on failure to
        # make the failing run reproducible.
        run_label = f"load_seed={all_load_seeds[i]}, advance_seed={all_advance_seeds[j]}"

        vec_records = [person_obj_to_person_record(p, 1) for p in vec_people]
        store_records = [r.next for r in person_store.get_population_at(t=0)]

        # zip() stops at the shorter input, so a size mismatch would otherwise
        # leave the unmatched tail unchecked.
        assert len(vec_records) == len(store_records), (
            f"population size mismatch ({run_label}): "
            f"{len(vec_records)} vectorized vs {len(store_records)} store records"
        )

        all_person_diffs = []
        for k, (vec_rec, store_rec) in enumerate(zip(vec_records, store_records)):
            cur_person_diffs = diff_person_props(vec_rec, store_rec)
            if cur_person_diffs:
                # k is the person's position in the paired record lists.
                all_person_diffs.append((k, cur_person_diffs))
        if all_person_diffs:
            print(run_label)
            pprint(all_person_diffs)
        assert not all_person_diffs, (
            f"{len(all_person_diffs)} person(s) differ ({run_label})"
        )