In [None]:
import sys
sys.path.append("../../")

import random
from tqdm import tqdm
from typing import List
from pathlib import Path
from argparse import Namespace
from collections import Counter

import pandas as pd
from sklearn.model_selection import train_test_split

import torch
from torch.utils.data import DataLoader
from torch.nn.functional import softmax
from transformers import AutoTokenizer
from transformers import logging
logging.set_verbosity_error()

from utilities.preprocess import preprocess_patient_state_tuples, train_valid_test_split, pad_int_icd
from utilities.utils import load_json, save_json, set_seeds, build_reverse_dict, move_bert_input_to_device
from utilities.model import BertNERModel, BiEncoder, BertDxModel, encoder_names_mapping
from utilities.data import MedicalDxDataset
from utilities.term import build_term_ids_lists

from icda import ICDA
from finding_extractor import FindingExtractor, Recognizer, Normalizer
from diagnosis_classifier import DiagnosisClassifier
from term_suggester import TermSuggester
from emr_preprocessor import EMRPreprocessor
from state_tracker import PatientStateTracker

## Utility Functions

In [None]:
def make_incremental_states(states):
    """Expand every patient state into its growing prefixes, rendered as text.

    Each state is a sequence of findings whose element ``[0]`` is the concept
    and element ``[1]`` the polarity. For a state with n findings, n strings are
    produced. The k-th string verbalises the first k findings as
    ``"<polarity> <concept>"`` pairs joined by single spaces. Polarity ``0`` is
    written as ``"positive"``; any other value as ``"negative"``. An empty state
    yields ``[""]``, so every note contributes at least one document.

    Args:
        states: Iterable of patient states.

    Returns:
        One list of prefix strings per input state, in input order.
    """
    docs = []
    for state in states:
        words = []
        prefixes = []
        for finding in state:
            # Extending one running token list gives exactly the prefix text
            # of the first k findings after the k-th extension.
            words.append("positive" if finding[1] == 0 else "negative")
            words.append(finding[0])
            prefixes.append(" ".join(words))
        docs.append(prefixes or [""])
    return docs

def calc_all_cnfds(model: BertDxModel, tokenizer: AutoTokenizer, docs: List[List[str]], ys: List[int], device: str = "cuda", batch_size: int = 16):
    """Compute the model's probability of the gold class for every document prefix.

    For each note, ``docs`` holds its list of prefix texts and ``ys`` its gold
    class index. The prefixes are batched through ``model`` with gradients
    disabled. The softmax probability of column ``y`` is collected per prefix.

    Args:
        model: Diagnosis model whose forward returns logits over classes.
        tokenizer: Tokenizer passed to ``MedicalDxDataset``.
        docs: One list of prefix strings per note.
        ys: Gold class index per note, aligned with ``docs``.
        device: Device the tokenized batches are moved to. It must match the
            device ``model`` lives on. Defaults to ``"cuda"``, the previously
            hard-coded value.
        batch_size: Prefixes per forward pass (previously fixed at 16).

    Returns:
        One list of floats per note, one confidence per prefix. A note with no
        prefixes yields an empty list.
    """
    model.eval()  # only needs to be set once for the whole evaluation
    all_cnfds = []

    # Materialising the pairs lets tqdm show a total.
    pairs = list(zip(docs, ys))
    with torch.no_grad():
        for subdocs, y in tqdm(pairs):
            if not subdocs:
                # Avoid torch.cat([]) raising on a note with nothing to score.
                all_cnfds.append([])
                continue

            valid_set = MedicalDxDataset(emrs=subdocs, dx_labels=[y] * len(subdocs), tokenizer=tokenizer)
            valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False, pin_memory=True, collate_fn=valid_set.collate_fn)

            all_logits = []
            for X, _ in valid_loader:  # labels are not needed for inference
                X = move_bert_input_to_device(X, device)
                all_logits.append(model(X))

            logits = torch.cat(all_logits, dim=0)
            all_cnfds.append(softmax(logits, dim=-1)[:, y].cpu().tolist())

    return all_cnfds

def select_states_gt_threshold(states, labels, cnfds_l, th_low: float) -> tuple:
    """Split each note at the first prefix whose confidence strictly exceeds ``th_low``.

    ``cnfds_l[k][i]`` is taken as the confidence after the first ``i + 1``
    findings of ``states[k]``. The note is cut right after the first such
    ``i`` with confidence ``> th_low``. The prefix becomes ``"obs"`` and the
    remainder ``"rem"``. Empty notes and notes that never exceed the
    threshold are dropped.

    Returns:
        A ``(sel_states, sel_labels, sel_cnfds_l)`` tuple of aligned lists.
        Each ``sel_states`` entry is a dict with keys ``"obs"`` and ``"rem"``.
    """
    assert len(states) == len(labels) == len(cnfds_l)
    picked_states, picked_labels, picked_cnfds = [], [], []

    for state, label, cnfds in zip(states, labels, cnfds_l):
        if not state:
            continue  # nothing to split
        assert len(state) == len(cnfds)

        cut = next((idx for idx, conf in enumerate(cnfds) if conf > th_low), None)
        if cut is None:
            continue  # confidence never rose above the lower threshold

        picked_states.append({"obs": state[:cut + 1], "rem": state[cut + 1:]})
        picked_labels.append(label)
        picked_cnfds.append(cnfds)

    return picked_states, picked_labels, picked_cnfds

def static_rank_rem(obs, new_rem, label, icda: ICDA):
    """Order remaining findings by ascending score in the term-suggester matrix.

    Each ``[term, polarity]`` in ``new_rem`` is scored as
    ``score_matrix.at[term2id[term], label]``. The result is sorted from
    lowest to highest score. Python's sort is stable, so ties keep their
    input order. ``obs`` is accepted for signature parity with
    ``dynamic_rank_rem`` but is not used.

    Returns:
        A new list of freshly built ``[term, polarity]`` lists.
    """
    suggester = icda.term_suggester
    scored = [
        (suggester.score_matrix.at[suggester.term2id[term], label], [term, pol])
        for term, pol in new_rem
    ]
    scored.sort(key=lambda pair: pair[0])
    return [term_pol for _, term_pol in scored]

def dynamic_rank_rem(obs, new_rem, label, icda: ICDA, top_k: int = 113):
    """Greedily order the remaining findings by how much each one supports ``label``.

    Each step works as follows:
    1. Every still-unranked finding's term is tentatively appended to the
       current observed state with polarity 0 (positive).
    2. The diagnosis classifier scores all these hypothetical states.
    3. The finding whose hypothetical state gives ``label`` the highest
       probability is selected.
    4. The selected finding is appended to the observed state with its
       actual polarity, and the process repeats.

    Args:
        obs: Observed findings, a sequence of ``[term, polarity]``. Not modified.
        new_rem: Findings to rank. Not modified; a copy is consumed.
        label: Target diagnosis, looked up in each top-diagnosis list.
        icda: System whose ``diagnosis_classifier`` provides ``predict`` and
            ``get_top_dxs_with_probs``.
        top_k: Number of diagnoses requested from the classifier. It must be
            large enough that ``label`` always appears. The default 113 keeps
            the previous hard-coded value; presumably it is the full number of
            diagnosis classes (NOTE(review): confirm against the dx model).

    Returns:
        The findings of ``new_rem`` in the order they were selected.

    Raises:
        ValueError: If ``label`` is absent from a classifier top-``top_k`` list.
    """
    obs = list(obs)
    # Work on a copy: the original removed items from the caller's list,
    # leaving it empty after the call.
    pending = list(new_rem)
    ranked_new_rem = []

    while pending:
        candidates = [obs + [[term_pol[0], 0]] for term_pol in pending]
        candidate_texts = preprocess_patient_state_tuples(candidates, label2token={0: "positive", 1: "negative"})

        # Predict the diagnosis distribution for each hypothetical state.
        all_logits = icda.diagnosis_classifier.predict(candidate_texts)
        dxs_l, probs_l = icda.diagnosis_classifier.get_top_dxs_with_probs(all_logits, top_k=top_k)
        assert len(dxs_l) == len(probs_l) == len(pending)

        scores = [probs[dxs.index(label)] for dxs, probs in zip(dxs_l, probs_l)]
        # max() returns the first maximal candidate, the same winner the
        # original stable descending sort put at index 0.
        best = max(range(len(pending)), key=scores.__getitem__)
        top_term_pol = pending.pop(best)

        obs.append(top_term_pol)
        ranked_new_rem.append(top_term_pol)

    return ranked_new_rem

def suggest_term_by_scheme(sel_states, sel_labels, icda: ICDA, scheme: str):
    """Reorder the unobserved findings (``"rem"``) of each selected state.

    Schemes:
        ``"random"``: shuffled in place with the global ``random`` generator.
        ``"physician"``: left in the order recorded in the note.
        ``"static"``: ordered by ``static_rank_rem``.
        ``"dynamic"``: ordered by ``dynamic_rank_rem``.

    Args:
        sel_states: Dicts with ``"obs"`` and ``"rem"`` lists.
        sel_labels: Target diagnosis per state, aligned with ``sel_states``.
        icda: System passed through to the static and dynamic rankers.
        scheme: One of the scheme names above.

    Returns:
        One reordered copy of ``"rem"`` per state. The input dicts are not
        modified.

    Raises:
        ValueError: If ``scheme`` is not a supported scheme name.
    """
    assert len(sel_states) == len(sel_labels)
    valid_schemes = ("random", "physician", "static", "dynamic")
    if scheme not in valid_schemes:
        raise ValueError(f"Unknown scheme {scheme!r}; expected one of {valid_schemes}")

    rem_states = []
    for state, label in tqdm(zip(sel_states, sel_labels), total=len(sel_states)):
        obs = state["obs"]
        new_rem = list(state["rem"])  # copy so the selected state is left intact
        if scheme == "random":
            random.shuffle(new_rem)
        elif scheme == "static":
            new_rem = static_rank_rem(obs, new_rem, label, icda)
        elif scheme == "dynamic":
            new_rem = dynamic_rank_rem(obs, new_rem, label, icda)
        # "physician": keep the recorded order unchanged.
        rem_states.append(new_rem)

    return rem_states

def calc_cnfds_l_from_states(obs_states, rem_states, dx_model, tokenizer, labels):
    """Score every prefix of each reassembled state (observed + reordered remainder).

    Each observed list is concatenated with its reordered remainder. The result
    is expanded into incremental text prefixes with ``make_incremental_states``,
    and ``calc_all_cnfds`` returns the gold-class probability for each prefix.

    Args:
        obs_states: Observed findings per note.
        rem_states: Reordered remaining findings per note.
        dx_model: Diagnosis model passed to ``calc_all_cnfds``.
        tokenizer: Tokenizer passed to ``calc_all_cnfds``.
        labels: Gold class index per note.

    Returns:
        One list of per-prefix confidences per note.
    """
    assert len(obs_states) == len(rem_states) == len(labels)
    incremental_docs = make_incremental_states(
        states=[observed + remaining for observed, remaining in zip(obs_states, rem_states)]
    )
    return calc_all_cnfds(model=dx_model, tokenizer=tokenizer, docs=incremental_docs, ys=labels)

def calc_term_usage(obs_states, cnfds_l, th_highs: List[float], default_upper_limit: int = 10):
    """Summarise how many extra findings are needed to pass each high threshold.

    For a note whose observed prefix has k findings, ``cnfds[k:]`` holds the
    confidences after adding 1, 2, ... further findings. For each threshold, the
    note's usage is the smallest count n < ``default_upper_limit`` whose
    confidence strictly exceeds the threshold. If no such count exists,
    including when nothing remains, the usage is ``default_upper_limit``.

    Returns:
        A dict mapping each threshold to ``pandas.Series.describe()`` of the
        per-note usages.
    """
    assert len(obs_states) == len(cnfds_l)

    term_usage = {th: [] for th in th_highs}
    for obs_state, cnfds in zip(obs_states, cnfds_l):
        rem_cnfds = cnfds[len(obs_state):]
        for th in th_highs:
            # Counts at or beyond the limit are never inspected: the
            # `n < limit` test short-circuits before the comparison.
            needed = next(
                (
                    n
                    for n, cnfd in enumerate(rem_cnfds, start=1)
                    if n < default_upper_limit and cnfd > th
                ),
                default_upper_limit,
            )
            term_usage[th].append(needed)

    for term_counts in term_usage.values():
        assert len(term_counts) == len(cnfds_l)

    return {th: pd.Series(term_usage[th]).describe() for th in th_highs}

## Configuration

In [None]:
args = Namespace(**{
    # Datasets
    "full_emr_path": "../../datasets/notes_B_full.json",
    "unnorm_states_path": "../../datasets/notes_B_unnorm.json",
    "norm_states_path": "../../datasets/notes_B_norm.json",
    "in_icds_path": "../../datasets/in_icds.json",
    "out_icds_path": "../../datasets/out_icds.json",

    # Named-entity recognition model
    "ner_model_path": "../../models/ner",
    "batch_size": 16,

    # Entity normalisation model
    "nen_model_path": "../../models/nen",

    # Diagnosis model; target_metric selects the checkpoint file best_<metric>.pth
    "dx_model_path": "../../models/dx",
    "target_metric": "hat5",

    # Term suggester (score matrix and its selection rule)
    "score_matrix_path": "../../models/term/fisher_matrix_mink-3_minp-0.05.csv",
    "term2id_path": "../../models/term/term2id.json",
    "inequality": "lesser",
    "threshold": 0.10,
    "ndx": 5,

    # Reproducibility and data split proportions
    "seed": 7,
    "train_size": 0.8,
    "valid_size": 0.1,
    "test_size": 0.1,

    # ICDA system modes and runtime device
    "system_mode": "test",
    "extract_mode": "umls",
    "front_end": "unified",
    "device": "cuda:0",
})

set_seeds(args.seed)

## Data

In [None]:
# Load data
norm_states = load_json(args.norm_states_path)
in_icds = load_json(args.in_icds_path)
out_icds = load_json(args.out_icds_path)

# Map (padded) ICD codes to the dx model's integer class ids.
id2icd = load_json(Path(args.dx_model_path) / "id2icd.json")
icd2id = {pad_int_icd(icd): int(id_) for id_, icd in id2icd.items()}
labels = [icd2id[icd] for icd in out_icds]

# Split data.
# Both calls use the same labels, proportions and seed so that the note splits
# and the in-ICD splits line up row by row. NOTE(review): this assumes
# train_valid_test_split is deterministic for a given seed.
# test_size now reads args.test_size; it previously passed args.valid_size by mistake.
train_inputs, valid_inputs, test_inputs, train_outs, valid_outs, test_outs = train_valid_test_split(
    inputs=norm_states,
    labels=out_icds,
    train_size=args.train_size,
    valid_size=args.valid_size,
    test_size=args.test_size,
    seed=args.seed
)

train_ins, valid_ins, test_ins, _, _, _ = train_valid_test_split(
    inputs=in_icds,
    labels=out_icds,
    train_size=args.train_size,
    valid_size=args.valid_size,
    test_size=args.test_size,
    seed=args.seed
)

## Model

In [None]:
# Models
ner_model = BertNERModel(encoder=encoder_names_mapping["BioLinkBERT"], num_tags=5)
ner_model.load_state_dict(torch.load(Path(args.ner_model_path) / "best_model.pth", map_location=args.device))
ner_tokenizer = AutoTokenizer.from_pretrained(Path(args.ner_model_path) / "tokenizer", use_fast=True)

nen_model = BiEncoder(encoder_name=encoder_names_mapping["BioLinkBERT"])
nen_model.load_state_dict(torch.load(Path(args.nen_model_path) / "best_valid_acc.pth", map_location=args.device))
nen_tokenizer = AutoTokenizer.from_pretrained(Path(args.nen_model_path) / "tokenizer", use_fast=True)
entity_embeddings = torch.load(Path(args.nen_model_path) / "entity_embeddings.pt")
cui2name = load_json(Path(args.nen_model_path) / "smcui2name.json")

# The number of diagnosis classes is the number of distinct output ICD codes.
dx_model = BertDxModel(encoder_name=encoder_names_mapping["BioLinkBERT"], num_dxs=len(set(out_icds)))
# map_location matches the other checkpoint loads; without it, a GPU-saved
# checkpoint cannot be loaded on a machine where that device is unavailable.
dx_model.load_state_dict(torch.load(Path(args.dx_model_path) / f"best_{args.target_metric}.pth", map_location=args.device))
# NOTE(review): the dx tokenizer is loaded from the NER model directory, so it is
# the same tokenizer as ner_tokenizer. Confirm this is the tokenizer the dx model
# was trained with.
dx_tokenizer = AutoTokenizer.from_pretrained(Path(args.ner_model_path) / "tokenizer", use_fast=True)
id2dx = load_json(Path(args.dx_model_path) / "id2icd.json")
dx2name = load_json(Path(args.dx_model_path) / "icdnine2name_en.json")

fisher_matrix = pd.read_csv(args.score_matrix_path, index_col="term_id")
term2id = load_json(args.term2id_path)
id2term = build_reverse_dict(term2id)

# Components
recognizer = Recognizer(
    model=ner_model,
    tokenizer=ner_tokenizer,
    batch_size=args.batch_size,
    device=args.device
)

normalizer = Normalizer(
    model=nen_model,
    tokenizer=nen_tokenizer,
    entity_embeddings=entity_embeddings,
    cui2name=cui2name,
    device=args.device,
    emr_batch_size=1,
    cui_batch_size=args.batch_size
)

finding_extractor = FindingExtractor(
    recognizer=recognizer,
    normalizer=normalizer
)

emr_preprocessor = EMRPreprocessor(
    finding_extractor=finding_extractor
)

dx_classifier = DiagnosisClassifier(
    model=dx_model,
    tokenizer=dx_tokenizer,
    id2dx=id2dx,
    dx2name=dx2name,
    batch_size=args.batch_size,
    device=args.device
)

term_suggester = TermSuggester(
    score_matrix=fisher_matrix,
    id2term=id2term,
    inequality=args.inequality,
    threshold=args.threshold,
    diagnosis_classifier=dx_classifier,
    umls_classifier=None,
    top_k_dxs=args.ndx
)

icda = ICDA(
    system_mode=args.system_mode,
    extract_mode=args.extract_mode,
    front_end=args.front_end,
    finding_extractor=finding_extractor,
    diagnosis_classifier=dx_classifier,
    term_suggester=term_suggester,
    emr_preprocessor=emr_preprocessor
)

## Evaluation

In [None]:
dx_model = dx_model.to(args.device)
# NOTE(review): calc_all_cnfds moves batches to "cuda" by default; make sure that
# matches args.device when running on a non-default GPU.

# Estimate model confidence on the incremental states of test set notes.
# dx_tokenizer is the tokenizer the diagnosis classifier was built with. It is
# loaded from the same directory as ner_tokenizer, so results are unchanged.
test_label_ids = [dx_classifier.dx2id[dx] for dx in test_outs]
physician_incremental_docs = make_incremental_states(states=test_inputs)
all_cnfds = calc_all_cnfds(model=dx_model, tokenizer=dx_tokenizer, docs=physician_incremental_docs, ys=test_label_ids)

# Select notes that have a prefix where the model's confidence in y_d is strictly
# greater than the lower threshold.
sel_states, sel_labels, sel_cnfds_l = select_states_gt_threshold(states=test_inputs, labels=test_outs, cnfds_l=all_cnfds, th_low=0.25)

# Ranking approaches
approaches = ["random", "physician", "static", "dynamic"]

obs_states = [state["obs"] for state in sel_states]
sel_label_ids = [dx_classifier.dx2id[dx] for dx in sel_labels]  # computed once, reused below

# Suggest terms by different ranking approaches (call order kept: "random" consumes RNG state)
random_rem_states = suggest_term_by_scheme(sel_states, sel_labels, icda, scheme="random")
physician_rem_states = suggest_term_by_scheme(sel_states, sel_labels, icda, scheme="physician")
static_rem_states = suggest_term_by_scheme(sel_states, sel_labels, icda, scheme="static")
dynamic_rem_states = suggest_term_by_scheme(sel_states, sel_labels, icda, scheme="dynamic")

# Calculate model confidence for y_d
random_cnfds_l = calc_cnfds_l_from_states(obs_states, rem_states=random_rem_states, dx_model=dx_model, tokenizer=dx_tokenizer, labels=sel_label_ids)
physician_cnfds_l = calc_cnfds_l_from_states(obs_states, rem_states=physician_rem_states, dx_model=dx_model, tokenizer=dx_tokenizer, labels=sel_label_ids)
static_cnfds_l = calc_cnfds_l_from_states(obs_states, rem_states=static_rem_states, dx_model=dx_model, tokenizer=dx_tokenizer, labels=sel_label_ids)
dynamic_cnfds_l = calc_cnfds_l_from_states(obs_states, rem_states=dynamic_rem_states, dx_model=dx_model, tokenizer=dx_tokenizer, labels=sel_label_ids)

# Calculate turns
th_highs = [0.5, 0.7, 0.9]

random_term_usage = calc_term_usage(obs_states, cnfds_l=random_cnfds_l, th_highs=th_highs)
physician_term_usage = calc_term_usage(obs_states, cnfds_l=physician_cnfds_l, th_highs=th_highs)
static_term_usage = calc_term_usage(obs_states, cnfds_l=static_cnfds_l, th_highs=th_highs)
dynamic_term_usage = calc_term_usage(obs_states, cnfds_l=dynamic_cnfds_l, th_highs=th_highs)

# Collect the per-threshold summaries under their approach names.
turns_by_approach = dict(zip(approaches, [random_term_usage, physician_term_usage, static_term_usage, dynamic_term_usage]))