In [None]:
import os
import sys
sys.path.append(os.getcwd()+"/../..")
from src import paths

from src.utils import (
                        line_label_id2label, 
                        line_label_label2id, 
                        line_label_token_id2label, 
                        plot_embeddings, 
                        get_df_classificationreport,
                        pretty_confusion_matrix,
)

import pandas as pd

import torch

# MedBERT Finetune on Lines

In [None]:
# Load the saved test outputs of the line-level classifier run.
results = torch.load(
    paths.RESULTS_PATH/"line-label"/"line-label_medbert-512_class_test.pt"
)

# Reference labels and predictions (label ids, mapped through line_label_id2label below).
y_true, y_pred = results["labels"], results["preds"]

# If inference was run without --output_hidden_states, this will be None
# and none of the embedding plots can be produced.
last_hidden_state = results["last_hidden_state"]

In [None]:
# The last hidden states are a list of tensors, each of shape (seq_len, hidden_size).
# Row 0 of each tensor (the CLS token) is taken as the embedding of that line.
if last_hidden_state is None:
    # Inference was run without --output_hidden_states: there is nothing to stack.
    # Leave the embeddings empty instead of failing, so the metric cells can still run.
    cls_tokens = []
    embeddings = None
    print("No hidden states were saved (inference ran without --output_hidden_states); embeddings are unavailable.")
else:
    cls_tokens = [tensor[0, :] for tensor in last_hidden_state]
    # One row per line: (n_lines, hidden_size).
    embeddings = torch.stack(cls_tokens, dim=0)

In [None]:
# Label names as shown in figures: same ids, underscores replaced by slashes.
display_labels = {
    label_id: label_name.replace("_", "/")
    for label_id, label_name in line_label_id2label.items()
}

In [None]:
# Confusion matrix of reference vs. predicted label ids, annotated with the display names.
pretty_confusion_matrix(
    y_true,
    y_pred,
    display_labels,
)

In [None]:
# Translate label ids into label names for the per-class report.
# Both lists are rebuilt from `results` rather than from `y_pred` itself, so this
# cell can be re-run: mapping `y_pred` onto itself raised a KeyError the second
# time, because it already held names instead of ids.
# NOTE: from here on `y_pred` holds label *names*; the accuracy cell below relies on that.
y_valid = [line_label_id2label[label] for label in results["labels"]]
y_pred = [line_label_id2label[pred] for pred in results["preds"]]
get_df_classificationreport(y_valid, y_pred, labels = sorted(line_label_id2label.values())).round(2)

In [None]:
# Share of lines whose predicted label name equals the reference label name.
is_correct = pd.Series(y_valid) == pd.Series(y_pred)
print("Accuracy: ", is_correct.mean())

In [None]:
# UMAP projection of the CLS embeddings, coloured by the reference label.
if last_hidden_state is None:
    # Inference was run without --output_hidden_states, so there is nothing to plot.
    print("Skipping embedding plot: no hidden states were saved during inference.")
else:
    plot_embeddings(
        embeddings,
        [line_label_id2label[label].replace("_", "/") for label in y_true],
        method="umap",
    )

In [None]:
# Analysis of FP and TP
# results.pop("last_hidden_state")
# results_df = pd.DataFrame(results)
# results_df.replace(line_label_id2label, inplace = True)

In [None]:
# pd.set_option('display.max_rows', 100)
# pd.set_option('display.max_colwidth', None)
# results_df[results_df["preds"] != results_df["labels"]]

In [None]:
# results_df[results_df["labels"] == "head"]

# MedBERT Token Classification

In [None]:
# Load the saved test outputs of the token-classification run.
# NOTE: this rebinds `results`; from here on it is iterated as a sequence of
# per-observation dicts (see the next cell), not the dict loaded for the line model.
results = torch.load(
    paths.RESULTS_PATH/"line-label/line-label_medbert-512_token_test.pt"
)

In [None]:
# Flatten the per-observation label / prediction sequences into two flat lists.
labels = [lab for observation in results for lab in observation["labs"]]
preds = [pred for observation in results for pred in observation["preds"]]

# Same sequences expressed as label ids (used for the confusion matrix and accuracy).
y_true = [line_label_label2id[name] for name in labels]
y_pred = [line_label_label2id[name] for name in preds]


In [None]:
# Confusion matrix for the token-level model, also written to the thesis folder.
# NOTE(review): unlike the other figure paths in this notebook, this one has no
# ".png" suffix - confirm that pretty_confusion_matrix handles that as intended.
pretty_confusion_matrix(
    y_true,
    y_pred,
    display_labels,
    save_dir = paths.THESIS_PATH/"token-level-cm",
)

In [None]:
# Per-class report for the token-level model, rounded to two decimals,
# saved as CSV for the thesis and displayed as the cell output.
report_classes = sorted(line_label_id2label.values())
token_results = get_df_classificationreport(labels, preds, labels = report_classes).round(2)
token_results.to_csv(paths.THESIS_PATH/"line-label_token_results.csv")
token_results

In [None]:
# Share of positions whose predicted label id equals the reference label id.
is_correct = pd.Series(y_true) == pd.Series(y_pred)
print("Accuracy: ", is_correct.mean())

In [None]:
# Hidden states of the token-classification run.
# NOTE: this rebinds `last_hidden_state` to the loaded object, which is indexed
# below by "labels" and "last_hidden_states".
last_hidden_state = torch.load(paths.RESULTS_PATH/"line-label/line-label_medbert-512_token_test_hidden_states.pt")

# Only plot B-labels: token label ids 0..7.
# NOTE(review): presumably these ids are the "B-" tags - confirm against line_label_token_id2label.
N_B_LABELS = 8
b_label_ids = tuple(range(N_B_LABELS))  # built once instead of once per element

# Single pass: remember the position of every B-labelled entry together with
# its display name (underscores shown as slashes).
b_label_idx, b_labels = [], []
for i, label in enumerate(last_hidden_state["labels"]):
    if label in b_label_ids:
        b_label_idx.append(i)
        b_labels.append(line_label_token_id2label[label].replace("_", "/"))

# Hidden states at exactly those positions.
b_hidden_states = last_hidden_state["last_hidden_states"][b_label_idx]

In [None]:
# UMAP projection of the B-label hidden states, written to the thesis folder.
plot_embeddings(
    b_hidden_states,
    b_labels,
    method="umap",
    save_dir = paths.THESIS_PATH/"token-label-embeddings.png",
)

# Comparing Both models
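The outputs of the line classifier must be truncated as well: for each report, only as many lines are kept as the token classifier produced labels for, so that both models are compared on the same lines.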

In [None]:
# Reload both result files under separate names so they can be compared side by side.
res_line = torch.load(
    paths.RESULTS_PATH/"line-label"/"line-label_medbert-512_class_test.pt"
)
res_token = torch.load(
    paths.RESULTS_PATH/"line-label/line-label_medbert-512_token_test.pt"
)

In [None]:
# One-column frame of the rid of every line-classifier row; its index is the row position.
res_line_rids = pd.DataFrame(
    data = res_line["rid"],
    columns = ["rid"],
)

In [None]:
# For every report (rid), keep only as many line-classifier rows as the token
# classifier has labels for that report.
indexes = []
missing_rids = []
for rid, data in res_line_rids.groupby("rid"):
    # Number of labels in the first token-level observation with this rid;
    # None when the token results contain no observation for it.
    length = next((len(obs["labs"]) for obs in res_token if obs["rid"] == rid), None)
    if length is None:
        # Without this check the length of the previously processed report was
        # silently reused (or a NameError was raised for the very first report).
        # Reports that have no token-level counterpart are left out instead.
        missing_rids.append(rid)
        continue
    # Select the indexes up to the length of the token obs
    indexes.extend(data.index[:length].tolist())

if missing_rids:
    print(f"Skipped {len(missing_rids)} report(s) without token-level results: {missing_rids}")

In [None]:
# All line-classifier outputs as one frame (one column per key of res_line),
# restricted to the rows selected in `indexes`.
all_line_rows = pd.DataFrame(res_line, columns=list(res_line.keys()))
res_line_df = all_line_rows.loc[indexes]

In [None]:
# Reference labels and predictions of the truncated line-classifier results.
y_true, y_pred = res_line_df["labels"], res_line_df["preds"]

# If inference was run without --output_hidden_states, this will be None
# and none of the embedding plots can be produced.
last_hidden_state = res_line_df["last_hidden_state"]

In [None]:
# The last hidden states hold one tensor per line, each of shape (seq_len, hidden_size).
# Row 0 of each tensor (the CLS token) is taken as the embedding of that line.
if res_line["last_hidden_state"] is None:
    # Inference was run without --output_hidden_states: there is nothing to stack.
    # Leave the embeddings empty instead of failing, so the metric cells can still run.
    cls_tokens = []
    embeddings = None
    print("No hidden states were saved (inference ran without --output_hidden_states); embeddings are unavailable.")
else:
    cls_tokens = [tensor[0, :] for tensor in last_hidden_state]
    # One row per line: (n_lines, hidden_size).
    embeddings = torch.stack(cls_tokens, dim=0)

In [None]:
# Label names as shown in figures: same ids, underscores replaced by slashes.
# (Same mapping as built earlier in the notebook; rebuilt so this section runs on its own.)
display_labels = {
    label_id: label_name.replace("_", "/")
    for label_id, label_name in line_label_id2label.items()
}

In [None]:
# Confusion matrix for the truncated line-level results, written to the thesis folder.
pretty_confusion_matrix(
    y_true,
    y_pred,
    display_labels,
    save_dir = paths.THESIS_PATH/"line-level-cm.png",
)

In [None]:
# Translate label ids into label names for the per-class report.
# Both lists are rebuilt from `res_line_df` rather than from `y_pred` itself, so
# this cell can be re-run: mapping `y_pred` onto itself raised a KeyError the
# second time, because it already held names instead of ids.
# NOTE: from here on `y_pred` holds label *names*; the accuracy cell below relies on that.
y_valid = [line_label_id2label[label] for label in res_line_df["labels"]]
y_pred = [line_label_id2label[pred] for pred in res_line_df["preds"]]

# Rounded report, saved as CSV for the thesis and displayed as the cell output.
line_results = get_df_classificationreport(y_valid, y_pred, labels = sorted(line_label_id2label.values())).round(2)
line_results.to_csv(paths.THESIS_PATH/"line-label_line_results.csv")
line_results

In [None]:
# Share of lines whose predicted label name equals the reference label name.
is_correct = pd.Series(y_valid) == pd.Series(y_pred)
print("Accuracy: ", is_correct.mean())

In [None]:
# UMAP projection of the CLS embeddings, coloured by the reference label,
# written to the thesis folder.
if res_line["last_hidden_state"] is None:
    # Inference was run without --output_hidden_states, so there is nothing to plot.
    print("Skipping embedding plot: no hidden states were saved during inference.")
else:
    plot_embeddings(
        embeddings,
        [line_label_id2label[label].replace("_", "/") for label in y_true],
        method="umap",
        save_dir = paths.THESIS_PATH/"line-label-embeddings.png",
    )

In [None]:
# Side-by-side LaTeX table: the line-level report without its last column,
# followed by the token-level report without its first and last columns.
line_part = line_results.iloc[:, :-1]
token_part = token_results.iloc[:, 1:-1]
comparison = pd.concat([line_part, token_part], axis = 1)
print(comparison.to_latex(index=False, float_format="%.2f"))