In [1]:
import os

# Configure the Hugging Face cache location and silence the tokenizers
# fork-parallelism warning before any HF library is imported.
# (Optionally also pin OpenMP: add "OMP_NUM_THREADS": "1" to the mapping.)
os.environ.update(
    {
        "HF_HOME": "/projects/bhuang/.cache/huggingface",
        "TOKENIZERS_PARALLELISM": "false",
    }
)

In [2]:
import numpy as np
import pandas as pd
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


## Load data

In [None]:
# drbenchmark_quaero

# Directory holding the resplit Quaero ICD-10 files. Change this one constant
# when the data lives somewhere else.
QUAERO_DIR = "/home/bhuang/icd_10/data/quaero_icd10_by_category_resplitted"

# Split name -> list of JSON-lines files, in the form expected by
# `load_dataset("json", data_files=...)`.
data_files = {
    "train": [
        f"{QUAERO_DIR}/drbenchmark_quaero-medline-train-cls-mistral_large_instruct_2407-processed.jsonl",
        f"{QUAERO_DIR}/drbenchmark_quaero-emea-train-cls-mistral_large_instruct_2407-processed.jsonl",
    ],
    "valid": [
        f"{QUAERO_DIR}/drbenchmark_quaero-medline-validation-cls-mistral_large_instruct_2407-processed.jsonl",
        f"{QUAERO_DIR}/drbenchmark_quaero-emea-validation-cls-mistral_large_instruct_2407-processed.jsonl",
    ],
    # A combined "test" split (medline + emea test files together) is disabled;
    # the two test files are exposed as separate splits instead.
    "test_quaero_medline": [
        f"{QUAERO_DIR}/drbenchmark_quaero-medline-test-cls-mistral_large_instruct_2407-processed.jsonl",
    ],
    "test_quaero_emea": [
        f"{QUAERO_DIR}/drbenchmark_quaero-emea-test-cls-mistral_large_instruct_2407-processed.jsonl",
    ],
}

In [3]:
# synthetic

# Directories holding the synthetic data. Change these constants when the data
# lives somewhere else.
SYNTHETIC_DIR = "/home/bhuang/icd_10/data/synthetic"
SYNTHETIC_TEST_DIR = "/home/bhuang/icd_10/data/synthetic_test"

# Split name -> list of JSON-lines files, in the form expected by
# `load_dataset("json", data_files=...)`.
# NOTE: this rebinds `data_files`; if the Quaero cell above was run first, its
# value is replaced by this one.
data_files = {
    "train": [
        f"{SYNTHETIC_DIR}/synthetic-mistral_large_instruct_2407-240909-processed-train-10k.jsonl",
    ],
    # validation pools the main validation file with the head/medium/tail ones
    "valid": [
        f"{SYNTHETIC_DIR}/synthetic-mistral_large_instruct_2407-240909-processed-validation.jsonl",
        f"{SYNTHETIC_TEST_DIR}/synthetic-head-processed-validation.jsonl",
        f"{SYNTHETIC_TEST_DIR}/synthetic-medium-processed-validation.jsonl",
        f"{SYNTHETIC_TEST_DIR}/synthetic-tail-processed-validation.jsonl",
    ],
    # each test file is kept as its own split
    "test_synthetic": [
        f"{SYNTHETIC_DIR}/synthetic-mistral_large_instruct_2407-240909-processed-test.jsonl",
    ],
    "test_synthetic_head": [
        f"{SYNTHETIC_TEST_DIR}/synthetic-head-processed-test.jsonl",
    ],
    "test_synthetic_medium": [
        f"{SYNTHETIC_TEST_DIR}/synthetic-medium-processed-test.jsonl",
    ],
    "test_synthetic_tail": [
        f"{SYNTHETIC_TEST_DIR}/synthetic-tail-processed-test.jsonl",
    ],
}

In [4]:
# Read every split listed in `data_files` with the generic JSON loader;
# the result has one entry per key of `data_files`.
dataset = load_dataset(path="json", data_files=data_files)
dataset

DatasetDict({
    train: Dataset({
        features: ['labels', 'text', 'has_diso'],
        num_rows: 10000
    })
    valid: Dataset({
        features: ['labels', 'text', 'has_diso'],
        num_rows: 4000
    })
    test_synthetic: Dataset({
        features: ['labels', 'text', 'has_diso'],
        num_rows: 1000
    })
    test_synthetic_head: Dataset({
        features: ['labels', 'text', 'has_diso'],
        num_rows: 1000
    })
    test_synthetic_medium: Dataset({
        features: ['labels', 'text', 'has_diso'],
        num_rows: 1000
    })
    test_synthetic_tail: Dataset({
        features: ['labels', 'text', 'has_diso'],
        num_rows: 1000
    })
})

In [5]:
# Per-split inputs and label lists, keyed by split name
# (replaces the former separate x_train / x_valid / x_test variables).
x = {split_name: split_ds["text"] for split_name, split_ds in dataset.items()}
y = {split_name: split_ds["labels"] for split_name, split_ds in dataset.items()}

## Transform data

In [6]:
from itertools import chain

from sklearn.preprocessing import MultiLabelBinarizer

# prepare labels

# Fit the binarizer on the label lists of every split (train, valid and all
# test splits) so each label seen anywhere gets a column in the encoding.
# chain.from_iterable flattens the per-split collections in one linear pass;
# sum(lists, []) re-copies the accumulated list at every step and requires
# each value to support `list + value`.
y_all = list(chain.from_iterable(y.values()))

mlb = MultiLabelBinarizer()
mlb.fit(y_all)

# classes[i] is the label encoded by column i of the binarized matrices
classes = mlb.classes_
num_classes = len(classes)
num_classes

933

In [7]:
# Binarize the label lists of every split with the fitted `mlb`;
# keys mirror those of `y`.
y_encoded = {
    split_name: mlb.transform(label_lists)
    for split_name, label_lists in y.items()
}

## Train and evaluate

### Evaluate helper

In [8]:
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score


# fmt: off
def evaluate(y, preds, average="micro", verbose=True):
    """Score predictions against ground truth on precision, recall, F1 and ROC AUC.

    Args:
        y: ground-truth labels, passed as-is to the scikit-learn metric functions.
        preds: predicted labels; also used as the score input of `roc_auc_score`.
        average: averaging mode forwarded to both metric functions.
        verbose: when true, print the four metrics on one line.

    Returns:
        dict with keys "precision", "recall", "f1" and "auc_score".
    """
    # zero_division=1 is forwarded to scikit-learn, which uses it as the metric
    # value whenever a division by zero occurs.
    precision, recall, f1, _support = precision_recall_fscore_support(
        y, preds, average=average, zero_division=1
    )
    auc_score = roc_auc_score(y, preds, average=average)

    metrics = {"precision": precision, "recall": recall, "f1": f1, "auc_score": auc_score}
    if not verbose:
        return metrics

    print(f"precision: {precision:.4f}, recall: {recall:.4f}, f1: {f1:.4f}, auc_score: {auc_score:.4f}")
    return metrics
# fmt: on

In [9]:
from collections import Counter

# Label frequencies are counted on the training split only.
# (To count over every split instead, flatten `y_all` rather than `y["train"]`.)
y_flattened = [item for sublist in y["train"] for item in sublist]

# get label id: position of each label in `mlb.classes_`.
# A dict lookup replaces a full np.argwhere scan of `mlb.classes_` for every
# single label occurrence.
class_to_id = {label: idx for idx, label in enumerate(mlb.classes_)}
y_id_flattened = [class_to_id[item] for item in y_flattened]
print("Total num of labels:", len(y_id_flattened))

# counter: label id -> number of occurrences in the training split
counter = Counter(y_id_flattened)


def k_most_frequent(k):
    """Return the ids of the `k` most common labels in `counter`, most frequent first."""
    top_pairs = counter.most_common(k)
    return [label_id for label_id, _count in top_pairs]


k_most_frequent(5)

Total num of labels: 30384


[48, 323, 316, 910, 512]

In [10]:
def predict(n, k):
    """Build the constant baseline prediction for `n` samples.

    Every row is identical: the columns of the `k` most frequent labels
    (as returned by `k_most_frequent`) are 1, all other columns are 0.

    Returns:
        int64 array of shape (n, num_classes).
    """
    top_label_ids = k_most_frequent(k)
    preds = np.zeros((n, num_classes), dtype=np.int64)
    preds[:, top_label_ids] = 1
    return preds


def predict_evaluate(x, y, k, verbose=True):
    """Predict the `k` most frequent labels for every sample, then score them.

    `x` is accepted for interface symmetry but is not read: the baseline
    prediction depends only on the number of samples, taken from `len(y)`.

    Returns:
        the metrics dict produced by `evaluate`.
    """
    n_samples = len(y)
    return evaluate(y, predict(n_samples, k), verbose=verbose)

In [11]:
# grid search k on validation set

# one record per candidate k in 0..50: the k itself plus its validation metrics
perf_by_k = [
    {"k": k, **predict_evaluate(x["valid"], y_encoded["valid"], k=k, verbose=False)}
    for k in range(51)
]

# best f1 first
df_perf_by_k = pd.DataFrame(perf_by_k).sort_values("f1", ascending=False)
df_perf_by_k.head()

Unnamed: 0,k,precision,recall,f1,auc_score
24,24,0.020146,0.162276,0.035842,0.568495
39,39,0.018904,0.247441,0.035124,0.60315
19,19,0.020303,0.129468,0.035101,0.554727
26,26,0.019558,0.170666,0.035094,0.571628
43,43,0.018762,0.270767,0.035092,0.612699


In [12]:
# eval on test set using determined k
# df_perf_by_k is sorted by f1 in descending order, so its first row holds the best k
best_k = int(df_perf_by_k.iloc[0]["k"])
print(f"best k: {best_k}")

# Evaluate every split whose name starts with "test"; the metrics of all
# splits are flattened into one record keyed "<split>_<metric>".
result = {}
for split in x:
    if not split.startswith("test"):
        continue
    print(f"perf on test set {split} --> ", end="")
    scores = predict_evaluate(x[split], y_encoded[split], k=best_k)
    # `metric` (not `k`) so the grid-search variable name is not shadowed
    result.update({f"{split}_{metric}": value for metric, value in scores.items()})

# save result
# to_csv does not create missing parent directories, so make sure the output
# directory exists before writing.
os.makedirs("tmp_result", exist_ok=True)
df_result = pd.DataFrame([result])
# df_result.to_json("tmp_result/tmp.json", orient="records", lines=True, force_ascii=False)
df_result.to_csv("tmp_result/tmp.csv")

best k: 24
perf on test set test_synthetic --> precision: 0.0225, recall: 0.1763, f1: 0.0400, auc_score: 0.5756
perf on test set test_synthetic_head --> precision: 0.0519, recall: 0.4090, f1: 0.0921, auc_score: 0.6923
perf on test set test_synthetic_medium --> precision: 0.0000, recall: 0.0000, f1: 0.0000, auc_score: 0.4871
perf on test set test_synthetic_tail --> precision: 0.0056, recall: 0.0444, f1: 0.0100, auc_score: 0.5094
