In [None]:
import os
import sys
sys.path.append(os.getcwd()+"/../..")
from src import paths

import torch

import pandas as pd

from src.utils import plot_embeddings, pretty_confusion_matrix, ms_label2id, ms_id2label

from sklearn.metrics import classification_report

import matplotlib.pyplot as plt


In [None]:

def show_results(file_name: str, plot_hidden_states = True, plot_title:str = "") -> None:
    """Load a saved result file and display its evaluation summary.

    Shows a confusion matrix, optionally a UMAP plot of the CLS embeddings,
    prints a classification report and finally prints every misclassified
    sample together with its class probabilities.

    Args:
        file_name: Name of the result file inside ``paths.RESULTS_PATH / "ms-diag"``.
            The loaded object must provide the keys ``labels``, ``preds``,
            ``logits``, ``text`` (or ``original_text``) and, when
            ``plot_hidden_states`` is true, ``last_hidden_state``.
        plot_hidden_states: Whether to plot the embeddings of the samples.
        plot_title: Title passed to the confusion matrix and the embedding plot.
    """
    # NOTE: torch.load unpickles the file, so only load result files created by this project.
    results = torch.load(paths.RESULTS_PATH / "ms-diag" / f"{file_name}")

    # Plot confusion matrix
    display_label_mapping = {0: "PPMS", 1: "RRMS", 2: "SPMS", 3: "Other"}
    pretty_confusion_matrix(y_true=results["labels"], y_pred=results["preds"], label_dict=display_label_mapping, title= plot_title)

    # Plot embeddings
    if plot_hidden_states:
        # Exclude None values (for pipeline approach this is to be expected).
        # Hidden states and labels are filtered together so that every
        # embedding keeps the label of the sample it was computed for.
        kept = [
            (state, label)
            for state, label in zip(results["last_hidden_state"], results["labels"])
            if state is not None
        ]

        if kept:
            # Last hidden states is a list of tensors of shape (seq_len, hidden_size)
            cls_vectors = [state[0, :] for state, _ in kept]  # Use CLS token
            plot_labels = [label for _, label in kept]
            embeddings = torch.stack(cls_vectors, dim=0).to(torch.float16)
            plot_embeddings(embeddings=embeddings, labels=plot_labels, title=plot_title, method="umap", display_label_mapping=display_label_mapping)
            plt.show()
        else:
            # torch.stack raises on an empty list, so skip the plot instead
            print("No hidden states available - skipping embedding plot.")

    # Print classification report
    labels = [display_label_mapping[label] for label in results["labels"]]
    preds = [display_label_mapping[pred] for pred in results["preds"]]
    print(classification_report(y_true=labels, y_pred=preds), "\n\n")

    # Show all wrongly classified samples
    for i, (label, pred) in enumerate(zip(labels, preds)):
        if label == pred:
            continue
        print(f"Observation: {i}")
        print(f"Label: {label} - Prediction: {pred}")
        # Fall back to "original_text" only when "text" is missing for this sample;
        # any other error should surface instead of being swallowed.
        try:
            print(results["text"][i])
        except (KeyError, IndexError):
            print(results["original_text"][i])
        # Print the probabilities for each class by converting the logits to probabilities, then rounding them
        probabilities = torch.softmax(torch.tensor(results["logits"][i]), dim=0).numpy()
        # presumably the key order of ms_label2id matches the logit order; verify against training code
        print("Probabilities:", dict(zip(ms_label2id.keys(), [round(prob, 3) for prob in probabilities])), "\n\n")

# MedBERT 512

## Strategy: Classify on single lines, 4 Labels (including no MS) and oversampling for training

In [None]:
# Confusion matrix, embedding plot, classification report and misclassified samples for the line-level results
show_results("ms-diag_medbert-512_class_line_oversample_test.pt", plot_title="MedBERT Classification Line")

In [None]:
# Line-level test results; aggregated per report id (rid) by majority_vote below
result = torch.load(paths.RESULTS_PATH / "ms-diag" / "ms-diag_medbert-512_class_line_oversample_test.pt")
# Create a DataFrame from rid, pred, label, text
def majority_vote(result):
    """Collapse line-level results into one row per report id (``rid``).

    For every ``rid`` the prediction and the label are each chosen as the most
    frequent class of the group; class 3 (no MS) is only chosen when it is the
    sole class present. The logits and last hidden state are taken from the
    first line of the group that was predicted as the chosen class, and the
    texts of all lines are joined with newlines.

    Args:
        result: Mapping with the per-line sequences ``rid``, ``preds``,
            ``labels``, ``text``, ``logits``, ``last_hidden_state`` and
            ``index_within_rid``.

    Returns:
        DataFrame with the columns ``preds``, ``logits``,
        ``last_hidden_state``, ``labels``, ``rid`` and ``text``, one row per rid.
    """
    no_ms_class = 3

    def dominant_class(values):
        # value_counts is sorted by frequency; skip the no-MS class when it
        # ranks first but another class occurs in the group as well.
        counts = values.value_counts()
        if len(counts) > 1 and counts.index[0] == no_ms_class:
            return counts.index[1]
        return counts.index[0]

    column_names = ("rid", "preds", "labels", "text", "logits", "last_hidden_state", "index_within_rid")
    lines = pd.DataFrame({name: result[name] for name in column_names})

    aggregated_rows = []
    for rid, group in lines.groupby("rid"):
        voted_pred = dominant_class(group["preds"])
        # Lines of this rid that were predicted as the voted class; the first one
        # supplies the representative logits and hidden state.
        supporting_lines = group[group["preds"] == voted_pred]

        aggregated_rows.append({
            "preds": voted_pred,
            "logits": supporting_lines["logits"].iloc[0],
            "last_hidden_state": supporting_lines["last_hidden_state"].iloc[0],
            # There should only be one kind of label other than 3 (no MS), or just 3
            "labels": dominant_class(group["labels"]),
            "rid": rid,
            "text": "\n".join(group["text"].tolist()),
        })

    return pd.DataFrame(aggregated_rows)

# Aggregate the line-level results per rid and store them as a dict of lists,
# the layout that show_results indexes by key after torch.load.
df_agg = majority_vote(result)
agg_results_path = paths.RESULTS_PATH / "ms-diag" / "ms-diag_medbert-512_class_line_oversample_test_agg.pt"
torch.save(df_agg.to_dict("list"), agg_results_path)


In [None]:
# Same evaluation as above, now on the per-rid aggregated results written by the previous cell
show_results("ms-diag_medbert-512_class_line_oversample_test_agg.pt", plot_hidden_states=True, plot_title="MedBERT Classification Line Aggregated")

The low precision stems from the fact that the dataset is imbalanced. Even though only 4 RRMS samples are misclassified, this has a large effect on the precision of PPMS and SPMS because there are so few examples of those classes.

## Strategy: Classify on single lines, 3 Labels (original approach with only dm samples) and oversampling for training

In [None]:
# Evaluation of the line-level classifier trained with the original 3-label approach
show_results("ms-diag_medbert-512_class_line_original_approach_test.pt", plot_title="MedBERT Classification Line Original Approach")

## Strategy: Classify on whole report, 4 Labels (including no MS) and oversampling for training

In [None]:
# Evaluation of the whole-report classifier (4 labels, oversampled training data)
show_results("ms-diag_medbert-512_class_all_oversample_test.pt", plot_title="MedBERT Classification Whole Prompt")

In [None]:
# Re-create the whole-report evaluation, this time saving the figures and the
# classification report for the thesis.
results = torch.load(paths.RESULTS_PATH / "ms-diag" / "ms-diag_medbert-512_class_all_oversample_test.pt")
plot_title = "MedBERT Classification Whole Prompt"
# Plot confusion matrix
display_label_mapping = {0: "PPMS", 1: "RRMS", 2: "SPMS", 3: "Other"}
pretty_confusion_matrix(y_true=results["labels"], y_pred=results["preds"], label_dict=display_label_mapping, title= plot_title, save_dir = paths.THESIS_PATH/"ms_diag_medbert_cm_base.png")
# Exclude None values (for pipeline approach this is to be expected).
# Hidden states and labels are filtered together so that every embedding
# keeps the label of the sample it was computed for.
kept_pairs = [
    (state, label)
    for state, label in zip(results["last_hidden_state"], results["labels"])
    if state is not None
]
results["last_hidden_state"] = [state for state, _ in kept_pairs]
plot_labels = [label for _, label in kept_pairs]

# Last hidden states is a list of tensors of shape (seq_len, hidden_size)
last_hidden_state = [batch[0, :] for batch in results["last_hidden_state"]]  # Use CLS token
embeddings = torch.stack(last_hidden_state, dim=0).to(torch.float16)
plot_embeddings(embeddings=embeddings, labels=plot_labels, title=plot_title, method="umap", display_label_mapping=display_label_mapping, save_dir = paths.THESIS_PATH/"ms_diag_medbert_embeddings_base.png")

# Save classification report as CSV (written to the current working directory)
labels = [display_label_mapping[label] for label in results["labels"]]
preds = [display_label_mapping[pred] for pred in results["preds"]]
pd.DataFrame(classification_report(y_true=labels, y_pred=preds, output_dict = True)).transpose().round(2).to_csv("ms-diag_medbert-512_class_all_oversample_test.csv")

## Strategy: Classify on whole report, 3 Labels (original approach with only reports containing at least one dm line), oversampling for training

In [None]:
# Evaluation of the whole-report classifier trained with the original 3-label approach
show_results("ms-diag_medbert-512_class_all_original_approach_test.pt", plot_title="MedBERT Classification Whole Prompt Original Approach")

## Pipeline Approach:

In [None]:
# Evaluation of the pipeline approach; no plot_title is passed, so the default empty title is used
show_results("ms-diag_medbert-512_pipeline_test.pt")

In [None]:
# Observation 25 carries a wrong label in the pipeline test results; it should be "Other" (class id 3).
# The corrected copy is stored next to the original file.
ms_diag_results_dir = paths.RESULTS_PATH / "ms-diag"
results_corrected = torch.load(ms_diag_results_dir / "ms-diag_medbert-512_pipeline_test.pt")
results_corrected["labels"][25] = 3

torch.save(results_corrected, ms_diag_results_dir / "ms-diag_medbert-512_pipeline_test_corrected.pt")

In [None]:
# Evaluation of the pipeline approach after correcting the label of observation 25
show_results("ms-diag_medbert-512_pipeline_test_corrected.pt")

In [None]:
# Re-create the evaluation of the corrected pipeline results, this time saving
# the figures and the classification report for the thesis.
results = torch.load(paths.RESULTS_PATH / "ms-diag" / "ms-diag_medbert-512_pipeline_test_corrected.pt")
plot_title = "MedBERT Classification S2A"
# Plot confusion matrix
display_label_mapping = {0: "PPMS", 1: "RRMS", 2: "SPMS", 3: "Other"}
pretty_confusion_matrix(y_true=results["labels"], y_pred=results["preds"], label_dict=display_label_mapping, title= plot_title, save_dir = paths.THESIS_PATH/"ms_diag_medbert_cm_s2a.png")
# Exclude None values (for pipeline approach this is to be expected).
# Hidden states and labels are filtered together so that every embedding
# keeps the label of the sample it was computed for.
kept_pairs = [
    (state, label)
    for state, label in zip(results["last_hidden_state"], results["labels"])
    if state is not None
]
results["last_hidden_state"] = [state for state, _ in kept_pairs]
plot_labels = [label for _, label in kept_pairs]

# Last hidden states is a list of tensors of shape (seq_len, hidden_size)
last_hidden_state = [batch[0, :] for batch in results["last_hidden_state"]]  # Use CLS token
embeddings = torch.stack(last_hidden_state, dim=0).to(torch.float16)
plot_embeddings(embeddings=embeddings, labels=plot_labels, title=plot_title, method="umap", display_label_mapping=display_label_mapping, save_dir = paths.THESIS_PATH/"ms_diag_medbert_embeddings_s2a.png")


# Save classification report as CSV (written to the current working directory).
# The file holds CSV text, so it gets a .csv extension instead of reusing the
# .pt name of the torch result file.
labels = [display_label_mapping[label] for label in results["labels"]]
preds = [display_label_mapping[pred] for pred in results["preds"]]
pd.DataFrame(classification_report(y_true=labels, y_pred=preds, output_dict = True)).transpose().round(2).to_csv("ms-diag_medbert-512_pipeline_test_corrected.csv")

In [None]:
# Additionally, to compare with the old approach, remove all samples labelled "Other" (class id 3)
other_idx = [i for i, label in enumerate(results_corrected["labels"]) if label == 3]
# Set for constant-time membership tests while filtering every field
other_idx_lookup = set(other_idx)
# assumes every value in results_corrected is a per-sample sequence aligned with "labels" - TODO confirm
results_corrected_no = {
    key: [value for i, value in enumerate(values) if i not in other_idx_lookup]
    for key, values in results_corrected.items()
}

torch.save(results_corrected_no, paths.RESULTS_PATH / "ms-diag" / "ms-diag_medbert-512_pipeline_test_corrected_no_other.pt")

In [None]:
# Evaluation of the corrected pipeline results restricted to samples whose label is not "Other"
show_results("ms-diag_medbert-512_pipeline_test_corrected_no_other.pt")

In [None]:
results = torch.load(paths.RESULTS_PATH / "ms-diag" / "ms-diag_medbert-512_pipeline_test_corrected_no_other.pt")
# Defined here as well so this cell does not depend on an earlier cell having been run
display_label_mapping = {0: "PPMS", 1: "RRMS", 2: "SPMS", 3: "Other"}
# Save classification report as CSV (written to the current working directory).
# The file holds CSV text, so it gets a .csv extension instead of reusing the
# .pt name of the torch result file.
labels = [display_label_mapping[label] for label in results["labels"]]
preds = [display_label_mapping[pred] for pred in results["preds"]]
pd.DataFrame(classification_report(y_true=labels, y_pred=preds, output_dict = True)).transpose().round(2).to_csv("ms-diag_medbert-512_pipeline_test_corrected_no_other.csv")

In [None]:
# Display the rounded classification report for the labels/preds computed in the previous cell
pd.DataFrame(classification_report(y_true=labels, y_pred=preds, output_dict = True)).transpose().round(2)

In [None]:
# NOTE(review): load_ms_data is not imported in any visible cell of this notebook, so this
# line raises NameError on a fresh kernel - confirm which module defines it and add it to the import cell.
ds = load_ms_data("all")["test"]

In [None]:
# Print every test text that contains the search phrase below
search_phrase = "V.a. entzündliche ZNS-"
for text in (candidate for candidate in ds["text"] if search_phrase in candidate):
    print(text)