In [16]:
from utils import get_args
import os
import pandas as pd
import ehr_diagnosis_env
from ehr_diagnosis_env.envs import EHRDiagnosisEnv
import gymnasium
from tqdm import tqdm
from collections import Counter


# Parsed experiment configuration; get_env below reads args.data and args.env from it.
args = get_args(
    'config.yaml'
)
def get_env(split, truncate=None):
    """Build an ``EHRDiagnosisEnv`` for one data split.

    Reads ``<args.data.path>/<args.data.dataset>/<split>.data`` as a
    gzip-compressed CSV, optionally keeps only its first ``truncate`` rows, and
    passes the rows to ``gymnasium.make`` together with the environment options
    taken from ``args.env``. Relies on the module-level ``args`` already being
    defined.

    Args:
        split: split name (e.g. ``'train'``); selects both the data file and the
            ``args.env['<split>_cache_path']`` cache location.
        truncate: if not None, keep only this many leading rows (positional).

    Returns:
        The environment object created by ``gymnasium.make``.
    """
    print(f'loading {split} dataset...')
    split_df = pd.read_csv(os.path.join(
        args.data.path, args.data.dataset, f'{split}.data'),
        compression='gzip')
    print(f'length={len(split_df)}')
    if truncate is not None:
        print(f'Truncating to {truncate}')
        # .iloc makes the positional row slice explicit instead of relying on
        # DataFrame.__getitem__ interpreting an integer slice positionally.
        split_df = split_df.iloc[:truncate]
    env: EHRDiagnosisEnv = gymnasium.make(
        'ehr_diagnosis_env/EHRDiagnosisEnv-v0',
        instances=split_df,
        cache_path=args.env[f'{split}_cache_path'],
        llm_name_or_interface=None,
        fmm_name_or_interface=None,
        fuzzy_matching_threshold=None,
        reward_type=args.env.reward_type,
        num_future_diagnoses_threshold=args.env.num_future_diagnoses_threshold,
        # Progress bars are removed from the output once they finish.
        progress_bar=lambda *a, **kwa: tqdm(*a, **kwa, leave=False),
        top_k_evidence=args.env.top_k_evidence,
        verbosity=1, # don't print anything when an environment is dead
        add_risk_factor_queries=args.env.add_risk_factor_queries,
        limit_options_with_llm=args.env.limit_options_with_llm,
        add_none_of_the_above_option=args.env.add_none_of_the_above_option,
        true_positive_minimum=args.env.true_positive_minimum,
        use_confident_diagnosis_mapping=
            args.env.use_confident_diagnosis_mapping,
        skip_instances_with_gt_n_reports=
            args.env.skip_instances_with_gt_n_reports,
    ) # type: ignore
    return env


import math
def inverse_sigmoid(x):
    """Return the logit ``log(x / (1 - x))`` of a probability ``x``.

    This is the inverse of the logistic sigmoid; here it converts label
    prevalences into log-odds.

    Args:
        x: a probability strictly between 0 and 1.

    Returns:
        The log-odds of ``x`` as a float.

    Raises:
        ValueError: if ``x`` is not strictly inside (0, 1) (including NaN).
            Without this check, ``x == 1`` raised ZeroDivisionError,
            ``x <= 0`` or ``x > 1`` raised an opaque "math domain error",
            and NaN silently produced NaN.
    """
    if not 0 < x < 1:
        raise ValueError(f'inverse_sigmoid expects 0 < x < 1, got {x!r}')
    return math.log(x / (1 - x))


def _has_valid_timestep(is_valid_timestep):
    """Return a truthy value if a cached 'is valid timestep' entry marks any timestep valid.

    Missing entries are None or NaN (NaN is the only value not equal to itself);
    otherwise the entry is summed, so it is presumably a sequence of booleans or
    0/1 flags -- TODO confirm against the env's cache format.
    """
    return (
        is_valid_timestep is not None
        and is_valid_timestep == is_valid_timestep
        and sum(is_valid_timestep) > 0
    )


def get_counts(env, subset=None):
    """Count valid cached instances overall, per target diagnosis, and with no target.

    Reads ``env.get_cached_instance_dataframe()`` and, for every row with at
    least one valid timestep, increments:

    * ``'total'``: the number of such rows;
    * one key per target diagnosis in the row's first
      ``'target diagnosis countdown'`` entry;
    * ``'negatives'``: rows whose first countdown entry has no targets.

    Prints the counts, each count's prevalence (count / total), and the inverse
    sigmoid (log-odds) of each prevalence except ``'total'``. A prevalence of
    exactly 1 is reported as ``inf`` log-odds instead of raising.

    Args:
        env: object exposing ``get_cached_instance_dataframe()``.
        subset: optional container of row index labels to restrict counting to;
            ``None`` counts every row.

    Returns:
        collections.Counter with the counts described above.
    """
    print('getting cached instance info')
    cache_info = env.get_cached_instance_dataframe()
    # Membership is tested once per row; a list (e.g. loaded from a pickle)
    # would make that a linear scan, so convert it to a set once. Other
    # container types keep their own `in` semantics untouched.
    if isinstance(subset, (list, tuple)):
        subset = set(subset)
    print('counting...')
    counts = Counter()
    for i, row in cache_info.iterrows():
        if subset is not None and i not in subset:
            continue
        if not _has_valid_timestep(row['is valid timestep']):
            continue
        counts['total'] += 1
        first_countdown = row['target diagnosis countdown'][0]
        for target in first_countdown.keys():
            counts[target] += 1
        if len(first_countdown) == 0:
            counts['negatives'] += 1
    print('counts')
    print(counts)
    # Counter returns 0 for a missing key without inserting it; when `counts`
    # is empty the comprehensions below never divide.
    total = counts['total']
    prevalence = {k: v / total for k, v in counts.items()}
    print('prevelance')
    print(prevalence)
    print('inverse sigmoid of prevelance')
    # Every count is > 0 and <= total, so each prevalence is in (0, 1]. A
    # prevalence of exactly 1 (e.g. every valid instance is negative) has
    # infinite log-odds; the plain logit would raise ZeroDivisionError.
    print({k: inverse_sigmoid(p) if p < 1 else float('inf')
           for k, p in prevalence.items() if k != 'total'})
    return counts

In [None]:
# Build the training-split environment from the gzip-compressed train.data file.
train_env = get_env(split='train')

In [17]:
# Label statistics over every cached training instance (no subset filter).
train_counts = get_counts(train_env)

getting cached instance info
counting...
counts
Counter({'total': 30252, 'negatives': 21774, 'pulmonary edema': 5098, 'pneumonia': 3744, 'cancer': 2010})
prevelance
{'total': 1.0, 'negatives': 0.7197540658468862, 'pulmonary edema': 0.16851778394816871, 'pneumonia': 0.12376041253470844, 'cancer': 0.06644188813962713}
inverse sigmoid of prevelance
{'negatives': 0.943242024528543, 'pulmonary edema': -1.5961686236710402, 'pneumonia': -1.9572920156106348, 'cancer': -2.642675509576076}


In [18]:
import pickle as pkl

# Recount the training split restricted to the saved subset, presumably a
# collection of row index labels (get_counts tests `label in subset`).
# NOTE(review): pickle.load can execute arbitrary code -- only load this file
# if it was written by a trusted step of this project.
with open('train_subset.pkl', 'rb') as subset_file:
    subset = pkl.load(subset_file)
modified_train_counts = get_counts(train_env, subset=subset)


getting cached instance info
counting...
counts
Counter({'total': 10491, 'pulmonary edema': 5098, 'pneumonia': 3744, 'negatives': 2013, 'cancer': 2010})
prevelance
{'total': 1.0, 'pulmonary edema': 0.4859403298065008, 'pneumonia': 0.35687732342007433, 'cancer': 0.1915927938232771, 'negatives': 0.19187875321704317}
inverse sigmoid of prevelance
{'pulmonary edema': -0.05625351040116628, 'pneumonia': -0.5889434030299426, 'cancer': -1.4396936453085785, 'negatives': -1.437848426272791}
