In [None]:
import sys
sys.path.append("../")

from tqdm.auto import tqdm
from pathlib import Path
from argparse import Namespace

import pandas as pd

import torch
from transformers import AutoTokenizer
from transformers import logging
logging.set_verbosity_error()

from utilities.preprocess import train_valid_test_split, pad_int_icd
from utilities.utils import load_json, set_seeds, build_reverse_dict
from utilities.model import BertNERModel, BiEncoder, BertDxModel, encoder_names_mapping

from icda import ICDA
from finding_extractor import FindingExtractor, Recognizer, Normalizer
from diagnosis_classifier import DiagnosisClassifier
from term_suggester import TermSuggester, UMLSClassifier
from emr_preprocessor import EMRPreprocessor

## Configuration

In [None]:
# Notebook-wide configuration: dataset and checkpoint locations, term-suggestion
# settings, data-split ratios and ICDA system modes.
args = Namespace(
    # Datasets
    full_emr_path="../datasets/notes_B_full.json",
    unnorm_states_path="../datasets/notes_B_unnorm.json",
    norm_states_path="../datasets/notes_B_norm.json",
    in_icds_path="../datasets/in_icds.json",
    out_icds_path="../datasets/out_icds.json",

    # NER model directory; batch_size is shared by the recognizer, the
    # normalizer (as cui_batch_size) and the diagnosis classifier.
    ner_model_path="../models/ner",
    batch_size=16,

    # NEN (normalization) model directory
    nen_model_path="../models/nen",

    # Diagnosis model directory; target_metric selects the
    # "best_<target_metric>.pth" checkpoint file.
    dx_model_path="../models/dx",
    target_metric="hat5",

    # Term suggester settings
    score_matrix_path="../models/term/fisher_matrix_mink-3_minp-0.05.csv",
    term2id_path="../models/term/term2id.json",
    inequality="lesser",
    threshold=0.05,
    ndx=5,

    # Seed and train/valid/test split ratios
    seed=7,
    train_size=0.8,
    valid_size=0.1,
    test_size=0.1,

    # ICDA system settings
    system_mode="deploy",
    extract_mode="umls",
    front_end="unified",
    # Fall back to CPU so the notebook still runs on hosts without CUDA.
    # NOTE: timings measured on CPU are not comparable to GPU timings.
    device="cuda:0" if torch.cuda.is_available() else "cpu"
)

# Fix the seed before any model or data split is created (set_seeds is a
# project helper; presumably it seeds the RNGs used downstream — TODO confirm).
set_seeds(args.seed)

## Load ICDA Modules

In [None]:
# Models
# Every checkpoint is loaded with map_location=args.device so that loading does
# not depend on the device the checkpoint was saved from.

# Named-entity recognition (NER) model and tokenizer
ner_model = BertNERModel(encoder=encoder_names_mapping["BioLinkBERT"], num_tags=5)
ner_model.load_state_dict(torch.load(Path(args.ner_model_path) / "best_model.pth", map_location=args.device))
ner_tokenizer = AutoTokenizer.from_pretrained(Path(args.ner_model_path) / "tokenizer", use_fast=True)

# Named-entity normalization (NEN) bi-encoder, tokenizer and entity embeddings
nen_model = BiEncoder(encoder_name=encoder_names_mapping["BioLinkBERT"])
nen_model.load_state_dict(torch.load(Path(args.nen_model_path) / "best_valid_acc.pth", map_location=args.device))
nen_tokenizer = AutoTokenizer.from_pretrained(Path(args.nen_model_path) / "tokenizer", use_fast=True)
# NOTE(review): loaded without map_location, so the tensor is restored onto the
# device it was saved from. Confirm which device Normalizer expects before
# adding one here.
entity_embeddings = torch.load(Path(args.nen_model_path) / "entity_embeddings.pt")

# UMLS lookup tables stored alongside the NEN model
cui2name = load_json(Path(args.nen_model_path) / "smcui2name.json")
cui2typeinfo = load_json(Path(args.nen_model_path) / "smcui2typeinfo.json")
cat2typenames = load_json(Path(args.nen_model_path) / "cat2typenames.json")

# Diagnosis model; the output size is the number of entries in id2dx
id2dx = load_json(Path(args.dx_model_path) / "id2icd.json")
dx2name = load_json(Path(args.dx_model_path) / "icdnine2name_en.json")
dx_model = BertDxModel(encoder_name=encoder_names_mapping["BioLinkBERT"], num_dxs=len(id2dx))
# map_location added for consistency with the NER/NEN checkpoints above;
# without it the load fails when the checkpoint's saved device is unavailable.
dx_model.load_state_dict(torch.load(Path(args.dx_model_path) / f"best_{args.target_metric}.pth", map_location=args.device))
# NOTE(review): the tokenizer is loaded from the NER model directory, not from
# dx_model_path — confirm this is intentional (all three models use the same
# "BioLinkBERT" encoder name).
dx_tokenizer = AutoTokenizer.from_pretrained(Path(args.ner_model_path) / "tokenizer", use_fast=True)

# Term suggestion resources: score matrix indexed by term_id, plus term <-> id maps
fisher_matrix = pd.read_csv(args.score_matrix_path, index_col="term_id")
term2id = load_json(args.term2id_path)
id2term = build_reverse_dict(term2id)

# Sub-components
# The recognizer wraps the NER model; the normalizer wraps the NEN bi-encoder
# together with its entity embeddings and the CUI-to-name table.
recognizer = Recognizer(model=ner_model, tokenizer=ner_tokenizer, batch_size=args.batch_size, device=args.device)
normalizer = Normalizer(
    model=nen_model, tokenizer=nen_tokenizer,
    entity_embeddings=entity_embeddings, cui2name=cui2name,
    emr_batch_size=1, cui_batch_size=args.batch_size,
    device=args.device
)
umls_classifier = UMLSClassifier(cui2name=cui2name, cui2typeinfo=cui2typeinfo, cat2typenames=cat2typenames)

# Components
# The finding extractor chains recognizer and normalizer, and is shared by the
# EMR preprocessor and the ICDA system below.
finding_extractor = FindingExtractor(recognizer=recognizer, normalizer=normalizer)
emr_preprocessor = EMRPreprocessor(finding_extractor=finding_extractor)

dx_classifier = DiagnosisClassifier(
    model=dx_model, tokenizer=dx_tokenizer,
    id2dx=id2dx, dx2name=dx2name,
    batch_size=args.batch_size, device=args.device
)

term_suggester = TermSuggester(
    score_matrix=fisher_matrix, id2term=id2term,
    inequality=args.inequality, threshold=args.threshold,
    diagnosis_classifier=dx_classifier, umls_classifier=umls_classifier,
    top_k_dxs=args.ndx
)

# Top-level system that ties all components together
icda = ICDA(
    system_mode=args.system_mode, extract_mode=args.extract_mode, front_end=args.front_end,
    finding_extractor=finding_extractor,
    diagnosis_classifier=dx_classifier,
    term_suggester=term_suggester,
    emr_preprocessor=emr_preprocessor
)

## Data

In [None]:
# Load the EMR texts and the ICD code associated with each of them.
text_l = load_json(args.full_emr_path)
out_icds = load_json(args.out_icds_path)

# Build the ICD -> class-id lookup from the diagnosis model's id -> ICD table.
# Keys are passed through pad_int_icd (project helper) so that they match the
# code format used in out_icds — TODO confirm the exact padding rule.
id2icd = load_json(Path(args.dx_model_path) / "id2icd.json")
icd2id = {pad_int_icd(icd): int(id_) for id_, icd in id2icd.items()}
labels = [icd2id[icd] for icd in out_icds]

# Seeded train/valid/test split; only test_inputs is used for the timings below.
train_inputs, valid_inputs, test_inputs, train_outs, valid_outs, test_outs = train_valid_test_split(
    inputs=text_l,
    labels=labels,
    train_size=args.train_size,
    valid_size=args.valid_size,
    # Fixed: previously passed args.valid_size, which only worked because the
    # two ratios happen to be equal in the current configuration.
    test_size=args.test_size,
    seed=args.seed
)

## Inference

### Instance-by-Instance

In [None]:
import time

def time_it(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)`` once and return the elapsed time in seconds.

    The return value of ``func`` is discarded. Any exception raised by ``func``
    propagates to the caller.

    Args:
        func: Callable to time.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        float: Elapsed wall-clock time of the call, in seconds.
    """
    # perf_counter is monotonic and high-resolution, unlike time.time(), which
    # can jump when the system clock is adjusted.
    start = time.perf_counter()
    func(*args, **kwargs)
    return time.perf_counter() - start

# Latency of the pipeline when each test EMR is submitted on its own.
times = list()

for text in tqdm(test_inputs):
    # generate_support is called with a list of EMRs, so wrap the single note.
    # n_dx comes from the shared configuration instead of a repeated literal.
    t = time_it(icda.generate_support, [text], n_dx=args.ndx)
    times.append(t)

### Batch Inference

In [None]:
# Number of EMRs submitted per generate_support call in the batched benchmark.
batch_size = 16
# One elapsed-time entry per batch (the last batch may be smaller).
batch_times = list()

for i in tqdm(range(0, len(test_inputs), batch_size)):
    input_l = test_inputs[i:i + batch_size]
    # n_dx comes from the shared configuration instead of a repeated literal.
    batch_t = time_it(icda.generate_support, input_l, n_dx=args.ndx)
    batch_times.append(batch_t)

## Plot

In [None]:
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

# Render figures at high resolution.
plt.rcParams["figure.dpi"] = 500

# Single-column frame of the per-instance timings, labelled for the x-axis.
times_df = pd.DataFrame(times).rename(columns={0: "Execution time (seconds)"})

# Histogram of the timings with a kernel density estimate overlaid.
sns.histplot(data=times_df, x="Execution time (seconds)", bins=40, kde=True)

In [None]:
# Summary statistics of the per-instance execution times collected in `times`.
pd.Series(times).describe()