In [None]:
import sys
sys.path.append("../../")

from tqdm import tqdm
from typing import List, Tuple, Dict
from pathlib import Path
from argparse import Namespace
from collections import Counter

import torch
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer

from utilities.preprocess import train_valid_test_split, preprocess_patient_state_tuples, pad_int_icd
from utilities.utils import load_json
from utilities.model import BertDxModel, encoder_names_mapping

from diagnosis_classifier import DiagnosisClassifier

## Utility Functions

In [None]:
def get_disease_diff_subset(target: str, data: list, in_icds: list, out_icds: list) -> Tuple[list, list]:
    """Select the samples whose ``in_icd`` differs from their ``out_icd``.

    Args:
        target: ``"all"`` keeps every sample whose two codes differ; any other
            value keeps only the differing samples whose ``out_icd`` equals it.
        data: Samples, index-aligned with ``in_icds`` and ``out_icds``.
        in_icds: One code per sample, compared against ``out_icds``.
        out_icds: One code per sample; also used as the returned label.

    Returns:
        A ``(subset, subset_labels)`` tuple: the selected samples and, in the
        same order, their ``out_icd`` values.

    Raises:
        AssertionError: If the three input lists differ in length.
    """
    assert len(data) == len(in_icds) == len(out_icds)
    keep_any_target = target == "all"
    subset = []
    subset_labels = []
    for sample, in_icd, out_icd in zip(data, in_icds, out_icds):
        # A sample qualifies only when its codes differ; unless every target
        # is accepted, its out_icd must additionally match the requested one.
        if in_icd != out_icd and (keep_any_target or out_icd == target):
            subset.append(sample)
            subset_labels.append(out_icd)

    return subset, subset_labels

def make_incremental_docs(patient_states: List[List[Tuple[str, int]]], label2token: Dict[int, str]) -> List[List[str]]:
    """Build one preprocessed document per growing prefix of each patient state.

    For a patient state of length ``n`` this produces ``n`` documents: the
    first from ``patient_state[:1]``, the second from ``patient_state[:2]``,
    and so on up to the full state. Each prefix is passed on its own to
    ``preprocess_patient_state_tuples`` and the single returned element is kept.

    Args:
        patient_states: One sequence of tuples per patient.
        label2token: Forwarded unchanged to ``preprocess_patient_state_tuples``.

    Returns:
        A list with one entry per patient; each entry is the list of
        documents for that patient's prefixes, shortest prefix first.
    """
    all_docs = []
    for states in patient_states:
        prefix_docs = [
            preprocess_patient_state_tuples([states[:end]], label2token=label2token)[0]
            for end in range(1, len(states) + 1)
        ]
        all_docs.append(prefix_docs)

    return all_docs

def eval_dx_reminder(dx_classifier: DiagnosisClassifier, docs: List[List[str]], labels: List[str], hit_k: int) -> Tuple[float, float]:
    """Score how early and how consistently the label appears in the top-k predictions.

    For every sample, the classifier is run on all of its incremental
    documents and the top ``hit_k`` diagnoses of each document are checked
    for the sample's label.

    Args:
        dx_classifier: Classifier providing ``predict`` and
            ``get_top_dxs_with_probs``.
        docs: Per sample, the list of incremental documents.
        labels: Per sample, the expected diagnosis; index-aligned with ``docs``.
        hit_k: Number of top diagnoses considered per document.

    Returns:
        ``(mean_earliest_hit, mean_incremental_hit)`` where

        * earliest hit of a sample is ``(j + 1) / len(subdocs)`` for the first
          document index ``j`` whose top-k contains the label, or ``1`` when
          no document hits;
        * incremental hit of a sample is the fraction of its documents whose
          top-k contains the label.

        Both are averaged over all samples. When ``labels`` is empty there is
        nothing to average, so ``(nan, nan)`` is returned instead of raising
        ``ZeroDivisionError``.

    Raises:
        AssertionError: If ``docs`` and ``labels`` differ in length.
    """
    assert len(docs) == len(labels)
    if not labels:
        # No samples: the means are undefined rather than a division by zero.
        return float("nan"), float("nan")

    earliest_hits = [1] * len(labels)
    incremental_hits = [0] * len(labels)
    for i in tqdm(range(len(labels))):
        subdocs = docs[i]
        label = labels[i]

        all_logits = dx_classifier.predict(subdocs)
        # dxs_l is iterated as one collection of predicted diagnoses per
        # document (presumably aligned with subdocs - confirm against
        # DiagnosisClassifier).
        dxs_l, _ = dx_classifier.get_top_dxs_with_probs(all_logits, top_k=hit_k)

        # Explicit flag instead of comparing against the default value 1,
        # which is also a legitimate result (first hit on the last document).
        first_hit_recorded = False
        for j, dxs in enumerate(dxs_l):
            if label not in dxs:
                continue
            if not first_hit_recorded:
                earliest_hits[i] = (j + 1) / len(subdocs)
                first_hit_recorded = True
            incremental_hits[i] += 1 / len(subdocs)

    mean_earliest_hit = sum(earliest_hits) / len(earliest_hits)
    mean_incremental_hit = sum(incremental_hits) / len(incremental_hits)
    return mean_earliest_hit, mean_incremental_hit

## Configuration

In [None]:
# Central configuration for this notebook. All paths are relative and are
# therefore resolved against the kernel's current working directory.
args = Namespace(
    # Input JSON files, read with `load_json`.
    full_emr_path="../../datasets/notes_B_full.json",
    unnorm_states_path="../../datasets/notes_B_unnorm.json",
    norm_states_path="../../datasets/notes_B_norm.json",
    in_icds_path="../../datasets/in_icds.json",
    out_icds_path="../../datasets/out_icds.json",

    # Model artefacts: the tokenizer is read from `<ner_model_path>/tokenizer`;
    # the diagnosis checkpoint is `<dx_model_path>/best_models/best_<target_metric>.pth`
    # and the id-to-ICD mapping is `<dx_model_path>/id2icd.json`.
    ner_model_path="../../models/ner",
    dx_model_path="../../training/dx/models_increment/encoder-BioLinkBERT__optimizer-AdamW__scheduler-linear__lr-5e-05__n_partials-4__input_type-norm__label_type-outicd__scheme-everyk",
    target_metric="micro_f1",
    # Passed to DiagnosisClassifier.
    batch_size=16,

    # Settings for `train_valid_test_split`.
    seed=7,
    train_size=0.8,
    valid_size=0.1,
    test_size=0.1,

    # Device string passed to DiagnosisClassifier; "cuda:1" selects the CUDA
    # device with index 1, so adjust it on machines with a different GPU layout.
    device="cuda:1"
)

## Data

In [None]:
# Load data
text_l = load_json(args.full_emr_path)
norm_states = load_json(args.norm_states_path)
in_icds = load_json(args.in_icds_path)
out_icds = load_json(args.out_icds_path)

# Map every ICD code (normalised with `pad_int_icd`) to the integer id stored
# in the model's id2icd.json, then encode the out-ICD of each sample.
id2icd = load_json(Path(args.dx_model_path) / "id2icd.json")
icd2id = {pad_int_icd(icd): int(id_) for id_, icd in id2icd.items()}
labels = [icd2id[icd] for icd in out_icds]

# Split the normalised patient states. `test_size` now reads the configured
# `args.test_size`; it previously reused `args.valid_size`, which only worked
# because both happen to be 0.1.
train_inputs, valid_inputs, test_inputs, train_outs, valid_outs, test_outs = train_valid_test_split(
    inputs=norm_states,
    labels=out_icds,
    train_size=args.train_size,
    valid_size=args.valid_size,
    test_size=args.test_size,
    seed=args.seed
)

# Split the in-ICDs with exactly the same labels, sizes and seed so that they
# stay index-aligned with the inputs split above.
train_ins, valid_ins, test_ins, _train_outs, _valid_outs, _test_outs = train_valid_test_split(
    inputs=in_icds,
    labels=out_icds,
    train_size=args.train_size,
    valid_size=args.valid_size,
    test_size=args.test_size,
    seed=args.seed
)

# The alignment above relies on the split being reproducible for a fixed seed;
# fail loudly here instead of silently pairing samples with the wrong in-ICD.
assert (
    list(_train_outs) == list(train_outs)
    and list(_valid_outs) == list(valid_outs)
    and list(_test_outs) == list(test_outs)
), "The two train/valid/test splits are not aligned; in_icds cannot be paired with the inputs."

# Keep only the samples whose in-ICD differs from the out-ICD.
train_diff_set, train_diff_labels = get_disease_diff_subset(target="all", data=train_inputs, in_icds=train_ins, out_icds=train_outs)
valid_diff_set, valid_diff_labels = get_disease_diff_subset(target="all", data=valid_inputs, in_icds=valid_ins, out_icds=valid_outs)
test_diff_set, test_diff_labels = get_disease_diff_subset(target="all", data=test_inputs, in_icds=test_ins, out_icds=test_outs)

# Make incremental documents (one document per growing prefix of each patient state).
test_incremental_docs = make_incremental_docs(patient_states=test_diff_set, label2token={0: "positive", 1: "negative"})

## Model

In [None]:
# Build the diagnosis model with one output per distinct out-ICD in the dataset.
dx_model = BertDxModel(encoder_name=encoder_names_mapping["BioLinkBERT"], num_dxs=len(Counter(out_icds)))
# map_location="cpu" deserialises the checkpoint onto the CPU regardless of the
# device it was saved from, so loading does not fail when that GPU is missing.
# load_state_dict copies the values into the model's existing parameters, so
# the model's own device is unaffected by this choice.
dx_model.load_state_dict(
    torch.load(
        Path(args.dx_model_path) / "best_models" / f"best_{args.target_metric}.pth",
        map_location="cpu",
    )
)
# NOTE(review): the tokenizer is read from the NER model directory - confirm it
# matches the tokenizer the diagnosis model was trained with.
dx_tokenizer = AutoTokenizer.from_pretrained(Path(args.ner_model_path) / "tokenizer", use_fast=True)
id2dx = load_json(Path(args.dx_model_path) / "id2icd.json")
dx2name = load_json(Path("../../models/dx") / "icdnine2name_en.json")

# Bundle model, tokenizer and lookup tables; batch size and device come from the config.
dx_classifier = DiagnosisClassifier(
    model=dx_model,
    tokenizer=dx_tokenizer,
    id2dx=id2dx,
    dx2name=dx2name,
    batch_size=args.batch_size,
    device=args.device
)

## Evaluation

In [None]:
# Number of test samples per out-ICD, used to report the size of each target group.
test_outs2count = Counter(test_outs)
target_dxs = [
    "486", # pneumonia
    "428", # heart failure
    "590" # pyelonephritis
]
target_hitks = [3, 5, 8]

for target_dx in target_dxs:
    # Targets without an entry in dx2name are reported as 'all'.
    target_name = dx_classifier.dx2name[target_dx] if target_dx in dx_classifier.dx2name else 'all'
    print(f"Evaluating target diagnosis {target_dx} - {target_name}")
    test_diff_set, test_diff_labels = get_disease_diff_subset(target=target_dx, data=test_inputs, in_icds=test_ins, out_icds=test_outs)
    print(f"Number of samples: diff = {len(test_diff_labels)} (all = {test_outs2count[target_dx]})")
    if not test_diff_labels:
        # Averaging over zero samples is undefined (a ZeroDivisionError in
        # eval_dx_reminder), so report the empty group and move on.
        print("No samples with differing in/out ICD for this target; skipping evaluation.")
        continue
    test_incremental_docs = make_incremental_docs(patient_states=test_diff_set, label2token={0: "positive", 1: "negative"})
    for hitk in target_hitks:
        print(f"Evaluating hit@{hitk}...")
        earliest_hit, incremental_hit = eval_dx_reminder(dx_classifier, docs=test_incremental_docs, labels=test_diff_labels, hit_k=hitk)
        print(f"Mean earliest hit = {earliest_hit}; Mean incremental hit = {incremental_hit}")