In [None]:
import os
import sys
sys.path.append(os.getcwd()+"/../..")
from src import paths

import pandas as pd
import numpy as np

import torch

from datasets import load_dataset, concatenate_datasets

from sklearn.metrics import f1_score, recall_score, precision_score, ConfusionMatrixDisplay, confusion_matrix, classification_report
from sklearn.metrics.pairwise import cosine_similarity

from src.utils import load_ms_data, ms_label2id

from typing import Callable

import warnings
warnings.filterwarnings('ignore')

# import seaborn as sns
import matplotlib.pyplot as plt

# Prompting Strategies 13B

## Base

### Whole Report

In [None]:
# Load one raw result file and the 13B label encodings for ad-hoc inspection.
# map_location mirrors load_results(): tensors saved from a GPU run can then be
# opened on a CPU-only machine.
results = torch.load(
    paths.RESULTS_PATH/"ms-diag"/"ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_few_shot_vanilla_rag.pt",
    map_location=torch.device('cpu'),
)
base_labels = torch.load(
    paths.RESULTS_PATH/"ms-diag"/"label_encodings_Llama2-MedTuned-13b.pt",
    map_location=torch.device('cpu'),
)

In [None]:
# Show the "model_answers" entry of the result file loaded above as the cell output.
results["model_answers"]

In [None]:
# German diagnosis strings mapped to integer class ids.
# 'other' and 'no information found' intentionally share id 3.
label2id = {'primär progrediente Multiple Sklerose': 0,
              'sekundär progrediente Multiple Sklerose': 2,
              'schubförmige remittierende Multiple Sklerose': 1,
              'other': 3,
              'no information found': 3}

In [None]:
# NOTE: an unfinished `def load_p` fragment used to live in this cell; it was a
# SyntaxError that stopped "Restart & Run All", so only this placeholder remains.

In [None]:
def load_results(filename: str, labels: str = None, load_line: bool = False):
    """Load a result file and attach a nearest-label prediction to every row.

    Args:
        filename: Name of the result file inside ``RESULTS_PATH/"ms-diag"``.
            It must contain a ``"last_hidden_states"`` entry.
        labels: Name of the label-encoding file in the same directory. The
            loaded object is indexed as ``[0]`` (label names) and ``[1]``
            (label embeddings). When omitted, the name is derived from the
            model part of ``filename``:
            ``"ms-diag_<model>_..."`` -> ``"label_encodings_<model>.pt"``.
        load_line: If True, pass the frame through ``convert_line2report``
            before returning it.

    Returns:
        A DataFrame built from the result dict (without the hidden states) plus
        a ``"preds"`` column holding the label whose embedding has the highest
        cosine similarity to each row's hidden state.
    """
    result_dir = paths.RESULTS_PATH/"ms-diag"

    if labels is None:
        # Every call in this notebook passes only the result filename, so fall
        # back to the encodings that belong to the model named in that filename.
        model_name = filename.split("_")[1]
        labels = f"label_encodings_{model_name}.pt"

    results = torch.load(result_dir/filename, map_location=torch.device('cpu'))
    labels = torch.load(result_dir/labels, map_location=torch.device('cpu'))

    # The hidden states are only needed for the similarity lookup; popping them
    # keeps the tensor out of the DataFrame.
    last_hidden_states = results.pop("last_hidden_states")
    last_hidden_states = last_hidden_states.cpu()

    # One row per sample, one column per label embedding.
    pred_idx = np.argmax(cosine_similarity(last_hidden_states, labels[1]), axis=1)
    # NOTE(review): "preds" holds entries of labels[0]; confirm they are in the
    # same encoding as the "labels" column they are later compared against.
    results["preds"] = [labels[0][i] for i in pred_idx]

    df = pd.DataFrame(results)

    if load_line:
        return convert_line2report(df, filename)
    return df

def show_results(results:pd.DataFrame):
    """Print label-format validity and a classification report, then plot the confusion matrix.

    Reads the "prediction", "labels" and "preds" columns of ``results`` and the
    notebook-level ``base_labels`` / ``ms_label2id`` objects.
    """
    # A prediction counts as well formatted when it contains any known label name.
    label_pattern = '|'.join(base_labels[0])
    has_known_label = results["prediction"].str.contains(label_pattern, case=False)
    print("Percent of correctly formatted labels: ", sum(has_known_label)/len(results))

    y_true = results["labels"]
    y_pred = results["preds"]
    print(classification_report(y_true = y_true, y_pred = y_pred))
    ConfusionMatrixDisplay.from_predictions(y_true = y_true, y_pred = y_pred, display_labels=ms_label2id, xticks_rotation="vertical")

def convert_line2report(results: pd.DataFrame, filename: str):
    """Collapse line-level results into one row per report id ("rid").

    The split is read from the fifth "_"-separated token of ``filename``; the
    matching line-level dataset supplies ``index_within_rid``, which is written
    into ``results`` in place. For every rid the row with index 0 is kept, its
    "report" becomes all lines joined in order, and its label is the most
    frequent one, unless that is 3 and another label exists, in which case the
    runner-up is used.
    """
    split_name = filename.split("_")[4]
    line_data = load_ms_data("line")

    if split_name == "all":
        line_data = concatenate_datasets([line_data["train"], line_data["val"], line_data["test"]])
    else:
        line_data = line_data[split_name]

    results['index_within_rid'] = line_data['index_within_rid']

    report_rows = []

    for rid, rid_lines in results.groupby("rid"):
        # Copy so the assignments below do not touch the grouped frame.
        first_line = rid_lines[rid_lines["index_within_rid"] == 0].copy()

        ordered_lines = rid_lines.sort_values('index_within_rid')["report"].tolist()
        first_line.loc[:, "report"] = "\n".join(ordered_lines)

        # There should only be one kind label other than 3 (no MS) or just 3
        label_counts = rid_lines["labels"].value_counts()
        most_common = label_counts.index[0]
        if most_common == 3 and len(label_counts) > 1:
            first_line.loc[:, "labels"] = label_counts.index[1]
        else:
            first_line.loc[:, "labels"] = most_common

        first_line.loc[:, "rid"] = rid
        report_rows.append(first_line)

    return pd.concat(report_rows, ignore_index=True)

def summarize_performance(files: list[str], target_names: list[str] = None):
    """Build one table of classification metrics, one row per result file.

    Filenames are split on "_": token 1 is the model name, token 3 the level,
    token 4 the split, and the remaining tokens (minus the ".pt" suffix) the
    prompting strategy, e.g.
    "ms-diag_Llama2-MedTuned-13b_4bit_line_all_zero_shot_vanilla.pt".

    Args:
        files: Result filenames understood by ``load_results``.
        target_names: Optional display names forwarded to
            ``classification_report``. With the default ``None`` the report is
            keyed by the label values themselves.

    Returns:
        A DataFrame with the columns "strategy", "level", "split" and
        "valid_label" first, followed by the flattened classification report.
        An empty DataFrame is returned when ``files`` is empty.
    """
    dfs = []
    for file in files:
        splitted_filename = file.split("_")
        model_name = splitted_filename[1]
        level = splitted_filename[3]
        split = splitted_filename[4]
        strategy = " ".join(splitted_filename[5:])[:-3]  # drop the ".pt" suffix

        # load_results needs the label encodings that belong to this model.
        results = load_results(file, f"label_encodings_{model_name}.pt")

        metric_dict = classification_report(y_true=results["labels"], y_pred=results["preds"], output_dict=True,
                                            target_names=target_names)

        # Create a dictionary with flattened keys
        _df = pd.json_normalize(metric_dict)

        # Add additional information
        _df["strategy"] = strategy
        _df["level"] = level
        _df["split"] = split
        # Share of raw predictions containing any known label name; relies on
        # the notebook-level `base_labels` being loaded beforehand.
        contains_any = results["prediction"].str.contains('|'.join(base_labels[0]), case=False)
        _df["valid_label"] = sum(contains_any)/len(results)

        # Reorder columns: move the four columns added above to the front
        reordered_cols = _df.columns[-4:].append(_df.columns[:-4])
        _df = _df[reordered_cols]

        dfs.append(_df)

    if not dfs:
        # pd.concat raises on an empty list
        return pd.DataFrame()

    return pd.concat(dfs)



In [None]:
pd.set_option('display.max_columns', 50)
# map_location mirrors load_results(): encodings saved from a GPU run can then
# be opened on a CPU-only machine.
base_labels = torch.load(
    paths.RESULTS_PATH/"ms-diag"/"label_encodings_Llama2-MedTuned-13b.pt",
    map_location=torch.device('cpu'),
)
# Bare expression kept from the original cell; it is only displayed when it is the last line.
base_labels[0], ms_label2id
# Reformat base labels such that dim 0 of tensor corresponds to ms_label2id index
labels_encoded = base_labels[1][[2,0,1,3],:].cpu()
# Whole-report ("all") files first, then line-level files, each for the four prompting strategies.
summarize_performance([
    "ms-diag_Llama2-MedTuned-13b_4bit_all_all_zero_shot_vanilla.pt",
    "ms-diag_Llama2-MedTuned-13b_4bit_all_all_zero_shot_instruction.pt",
    "ms-diag_Llama2-MedTuned-13b_4bit_all_all_few_shot_vanilla.pt",
    "ms-diag_Llama2-MedTuned-13b_4bit_all_all_few_shot_instruction.pt",
    "ms-diag_Llama2-MedTuned-13b_4bit_line_all_zero_shot_vanilla.pt",
    "ms-diag_Llama2-MedTuned-13b_4bit_line_all_zero_shot_instruction.pt",
    "ms-diag_Llama2-MedTuned-13b_4bit_line_all_few_shot_vanilla.pt",
    "ms-diag_Llama2-MedTuned-13b_4bit_line_all_few_shot_instruction.pt",
])

- Improvement from zero shot vanilla to few shot vanilla. This helps especially with valid_label output (so model knows better how to structure output)
- Improvement from vanilla to instruct (leveraging pretraining)
- Combining few-shot with instruct has mixed results. It is unclear how the pretraining objective and the few-shot examples interact; the model may be confused by the mixed input format.
- Zero-Shot Instruct seems to perform best. Is in line with their fine-tune objective
- Improvement from the whole report to just a line. The prompt gets less washed out and the model can focus on the relevant input. Recall for the actual MS classes is especially good, indicating that the model is quite adept at picking out the diagnosis when it is presented only with the relevant information. Precision is often lower because the model struggles when a line does not contain enough information (low no_ms recall) and is then prone to hallucinations. Because of the class imbalance, just a few misclassified no_ms samples can have a huge impact.

## Zero-Shot

### Vanilla

In [None]:
# Whole report, zero-shot vanilla prompt (13B)
results = load_results("ms-diag_Llama2-MedTuned-13b_4bit_all_all_zero_shot_vanilla.pt")
show_results(results)

In [None]:
# Line report, zero-shot vanilla prompt (13B)
results = load_results("ms-diag_Llama2-MedTuned-13b_4bit_line_all_zero_shot_vanilla.pt")
show_results(results)

### Instruction

In [None]:
# Whole report, zero-shot instruction prompt (13B)
results = load_results("ms-diag_Llama2-MedTuned-13b_4bit_all_all_zero_shot_instruction.pt")
show_results(results)

In [None]:
# Line report, zero-shot instruction prompt (13B)
results = load_results("ms-diag_Llama2-MedTuned-13b_4bit_line_all_zero_shot_instruction.pt")
show_results(results)

## Few-Shot

### Vanilla

In [None]:
# Whole report, few-shot vanilla prompt (13B)
results = load_results("ms-diag_Llama2-MedTuned-13b_4bit_all_all_few_shot_vanilla.pt")
show_results(results)

In [None]:
# Line report, few-shot vanilla prompt (13B)
results = load_results("ms-diag_Llama2-MedTuned-13b_4bit_line_all_few_shot_vanilla.pt")
show_results(results)

### Instruction

In [None]:
# Whole report, few-shot instruction prompt (13B)
results = load_results("ms-diag_Llama2-MedTuned-13b_4bit_all_all_few_shot_instruction.pt")
show_results(results)

In [None]:
# Line report, few-shot instruction prompt (13B)
results = load_results("ms-diag_Llama2-MedTuned-13b_4bit_line_all_few_shot_instruction.pt")
show_results(results)

## Multi-Turn

In [None]:
# Whole report, test split, two-step prompting (13B)
results = load_results("ms-diag_Llama2-MedTuned-13b_4bit_all_test_two_steps.pt")
show_results(results)

In [None]:
# Line-level zero-shot instruction results (13B); load_line=True routes them
# through convert_line2report, which returns one row per report id.
results = load_results("ms-diag_Llama2-MedTuned-13b_4bit_line_all_zero_shot_instruction.pt", load_line = True)
show_results(results)

# Prompting Strategies 7B

# Summary

In [None]:
# map_location mirrors load_results(): encodings saved from a GPU run can then
# be opened on a CPU-only machine.
base_labels = torch.load(
    paths.RESULTS_PATH/"ms-diag"/"label_encodings_Llama2-MedTuned-7b.pt",
    map_location=torch.device('cpu'),
)
# Bare expression kept from the original cell; it is only displayed when it is the last line.
base_labels[0], ms_label2id
# Reformat base labels such that dim 0 of tensor corresponds to ms_label2id index
labels_encoded = base_labels[1][[2,0,1,3],:].cpu()
# Whole-report ("all") files first, then line-level files, each for the four prompting strategies.
summarize_performance([
    "ms-diag_Llama2-MedTuned-7b_4bit_all_all_zero_shot_vanilla.pt",
    "ms-diag_Llama2-MedTuned-7b_4bit_all_all_zero_shot_instruction.pt",
    "ms-diag_Llama2-MedTuned-7b_4bit_all_all_few_shot_vanilla.pt",
    "ms-diag_Llama2-MedTuned-7b_4bit_all_all_few_shot_instruction.pt",
    "ms-diag_Llama2-MedTuned-7b_4bit_line_all_zero_shot_vanilla.pt",
    "ms-diag_Llama2-MedTuned-7b_4bit_line_all_zero_shot_instruction.pt",
    "ms-diag_Llama2-MedTuned-7b_4bit_line_all_few_shot_vanilla.pt",
    "ms-diag_Llama2-MedTuned-7b_4bit_line_all_few_shot_instruction.pt",
])

- Improvement from zero-shot vanilla to few-shot vanilla. This helps especially with valid_label output (so the model knows better how to structure its output)
- Improvement from vanilla to instruct (leveraging pretraining)
- Few Shot Vanilla and Zero Shot instruct sometimes seem to perform equally well.
- Combining few-shot with instruct has mixed results. It is unclear how the pretraining objective and the few-shot examples interact; the model may be confused by the mixed input format.
- Improvement from the whole report to just a line. The prompt gets less washed out and the model can focus on the relevant input. Recall for the actual MS classes is especially good, indicating that the model is quite adept at picking out the diagnosis when it is presented only with the relevant information. Precision is often lower because the model struggles when a line does not contain enough information (low no_ms recall) and is then prone to hallucinations. Because of the class imbalance, just a few misclassified no_ms samples can have a huge impact.

## Zero-Shot

### Vanilla

In [None]:
# Whole report, zero-shot vanilla prompt (7B)
results = load_results("ms-diag_Llama2-MedTuned-7b_4bit_all_all_zero_shot_vanilla.pt")
show_results(results)

In [None]:
# Line report, zero-shot vanilla prompt (7B)
results = load_results("ms-diag_Llama2-MedTuned-7b_4bit_line_all_zero_shot_vanilla.pt")
show_results(results)

### Instruction

In [None]:
# Whole report, zero-shot instruction prompt (7B)
results = load_results("ms-diag_Llama2-MedTuned-7b_4bit_all_all_zero_shot_instruction.pt")
show_results(results)

In [None]:
# Line report, zero-shot instruction prompt (7B)
results = load_results("ms-diag_Llama2-MedTuned-7b_4bit_line_all_zero_shot_instruction.pt")
show_results(results)

## Few-Shot

### Vanilla

In [None]:
# Whole report, few-shot vanilla prompt (7B)
results = load_results("ms-diag_Llama2-MedTuned-7b_4bit_all_all_few_shot_vanilla.pt")
show_results(results)

In [None]:
# Line report, few-shot vanilla prompt (7B)
results = load_results("ms-diag_Llama2-MedTuned-7b_4bit_line_all_few_shot_vanilla.pt")
show_results(results)

### Instruction

In [None]:
# Whole report, few-shot instruction prompt (7B)
results = load_results("ms-diag_Llama2-MedTuned-7b_4bit_all_all_few_shot_instruction.pt")
show_results(results)

In [None]:
# Line report, few-shot instruction prompt (7B)
results = load_results("ms-diag_Llama2-MedTuned-7b_4bit_line_all_few_shot_instruction.pt")
show_results(results)

## Multi-Turn

In [None]:
# Whole report, all splits, two-step prompting (7B)
results = load_results("ms-diag_Llama2-MedTuned-7b_4bit_all_all_two_steps.pt")
show_results(results)

# Generation Strategies

In [None]:
# NOTE(review): job_id is not used in this cell and differs from the id embedded
# in the CSV filename below (3595180) — confirm which run is intended.
job_id = 3592235
# Generated answers of the zero-shot generation-strategy runs.
output = pd.read_csv(paths.RESULTS_PATH/'ms-diag/ms_diag-llama2-chat_zero-shot_generation-strats_3595180.csv')
# Map from output to label
# Dictionary to map keywords to labels
# Keys are lower-case substrings (abbreviations and German terms); lookup order
# follows insertion order, so earlier keys win when several match.
keyword_label_mapping = {
    "rrms": 'relapsing_remitting_multiple_sclerosis',
    "spms": 'secondary_progressive_multiple_sclerosis',
    "ppms": 'primary_progressive_multiple_sclerosis',
    "remittierend": 'relapsing_remitting_multiple_sclerosis',
    "schubförmig": 'relapsing_remitting_multiple_sclerosis',
    "sekundär": 'secondary_progressive_multiple_sclerosis',
    "primär": 'primary_progressive_multiple_sclerosis',
}

# Number of columns in the results dataframe
# Recorded before the mapped "<col>_r" columns are appended, so
# output.columns[n_cols:] later selects exactly those new columns.
n_cols = len(output.columns)

# Function to assign labels based on text content
def assign_label(text, mapping=None):
    """Map a generated answer to a class label by case-insensitive keyword search.

    Args:
        text: The generated answer. Non-string values (e.g. NaN from empty CSV
            cells, or None) are treated as containing no keyword.
        mapping: Optional ``{keyword: label}`` dict with lower-case keywords.
            Defaults to the notebook-level ``keyword_label_mapping``.

    Returns:
        The label of the first keyword (in mapping order) found in ``text``,
        or ``"unknown"`` when nothing matches.
    """
    if mapping is None:
        mapping = keyword_label_mapping
    if not isinstance(text, str):
        # Empty CSV cells arrive as float NaN, which has no .lower()
        return "unknown"

    lowered = text.lower()  # lower-case once instead of once per keyword
    for keyword, label in mapping.items():
        if keyword in lowered:
            return label
    return "unknown"  # Default label if no keyword is found

# Assign labels to each text in the list
# Every column except the first gets a mapped companion column named "<col>_r".
# NOTE(review): if the true "labels" column is not the first column it is mapped
# too and will show up among the prediction columns — confirm the CSV layout.
for col in output.columns[1:]:
    output[f'{col}_r'] = output[col].apply(assign_label)

In [None]:
# Count "unknown" labels per column
# Comparing and summing yields 0 for columns without any "unknown", where
# value_counts()["unknown"] would raise a KeyError.
for col in output.columns[n_cols:]:
    print(f'{col}: {(output[col] == "unknown").sum()}')

In [None]:
def calculate_metrics(cm: np.ndarray):
    """Compute weighted, macro and micro precision/recall/F1 from a confusion matrix.

    Rows are true classes and columns are predicted classes; class ``i`` is
    row ``i`` and column ``i``. The matrix does not have to be square: a true
    class without a matching prediction column gets zero true positives and
    zero predictions, and extra prediction columns only add to the total.

    Any ratio with a zero denominator is reported as 0, so an all-zero matrix
    gives all-zero metrics instead of NaN.

    Args:
        cm: 2-D array of counts.

    Returns:
        Dict with the keys ``weighted_*``, ``macro_*`` and ``micro_*`` for
        ``precision``, ``recall`` and ``f1_score``. Weighted metrics use the
        per-class row sums (support) as weights.
    """
    cm = np.asarray(cm)
    n_true, n_pred = cm.shape
    shared = min(n_true, n_pred)  # classes that have both a row and a column

    total = cm.sum()
    support = cm.sum(axis=1).astype(float)       # samples per true class
    true_pos = np.zeros(n_true)
    predicted = np.zeros(n_true)                 # samples per predicted class
    true_pos[:shared] = cm.diagonal()[:shared]
    predicted[:shared] = cm.sum(axis=0)[:shared]

    def _safe_divide(numerator, denominator):
        # Element-wise division that yields 0 wherever the denominator is 0.
        quotient = np.zeros_like(numerator, dtype=float)
        np.divide(numerator, denominator, out=quotient, where=denominator > 0)
        return quotient

    precision = _safe_divide(true_pos, predicted)
    recall = _safe_divide(true_pos, support)
    f1 = _safe_divide(2 * precision * recall, precision + recall)

    def _weighted(per_class):
        return float(np.sum(per_class * support) / total) if total > 0 else 0.0

    def _macro(per_class):
        return float(np.mean(per_class)) if n_true > 0 else 0.0

    # Micro precision and recall both reduce to the share of the diagonal.
    micro_precision = float(true_pos.sum() / total) if total > 0 else 0.0
    micro_recall = micro_precision
    micro_sum = micro_precision + micro_recall
    micro_f1_score = 2 * (micro_precision * micro_recall) / micro_sum if micro_sum > 0 else 0.0

    return {
        'weighted_precision': _weighted(precision),
        'weighted_recall': _weighted(recall),
        'weighted_f1_score': _weighted(f1),
        'macro_precision': _macro(precision),
        'macro_recall': _macro(recall),
        'macro_f1_score': _macro(f1),
        'micro_precision': micro_precision,
        'micro_recall': micro_recall,
        'micro_f1_score': micro_f1_score
    }


In [None]:
# Confusion matrix and aggregate metrics for every mapped prediction column.
# Sorted so the class order is reproducible across kernel restarts; "unknown"
# always comes last.
labels = sorted(set(output["labels"]) - {"unknown"}) + ["unknown"]
pred_cols = output.columns[n_cols:]
results = []
for col in pred_cols:
    conf_mat = confusion_matrix(y_true=output['labels'], y_pred=output[col], labels=labels)

    # Plotted with the already imported sklearn display (seaborn is not imported
    # in this notebook); one figure per column so plots do not overwrite each other.
    fig, ax = plt.subplots()
    ConfusionMatrixDisplay(conf_mat, display_labels=labels).plot(ax=ax, xticks_rotation="vertical")
    ax.set_title(col)
    plt.show()

    results.append(calculate_metrics(conf_mat))

# One metrics row per prediction column, so the index length matches the rows.
pd.DataFrame(results, index=pred_cols)

In [None]:
labels = ['primary_progressive_multiple_sclerosis', 'relapsing_remitting_multiple_sclerosis', 'secondary_progressive_multiple_sclerosis', 'unknown']
label_names = ['PPMS', 'RRMS', 'SPMS', 'unknown']
results = []

pred_cols = output.columns[n_cols:]

# Set the number of columns in the grid; rows follow from the number of
# prediction columns (ceil division) instead of a hard-coded value.
grid_cols = 2
grid_rows = max(1, -(-len(pred_cols) // grid_cols))

# Create subplots (squeeze=False keeps `axes` 2-D even for a single row)
fig, axes = plt.subplots(grid_rows, grid_cols, figsize=(20, 10), squeeze=False)

for idx, col in enumerate(pred_cols):
    row_idx, col_idx = divmod(idx, grid_cols)
    ax = axes[row_idx, col_idx]

    # Calculate confusion matrix
    conf_mat = confusion_matrix(y_true=output['labels'], y_pred=output[col], labels=labels)

    # Plot heatmap in the corresponding subplot. The last row ("unknown" as the
    # true class) is left out of the picture; "unknown" stays as a predicted column.
    shown = conf_mat[:-1, :]
    image = ax.imshow(shown, cmap='Blues', aspect='auto')
    threshold = shown.max() / 2 if shown.size else 0
    for (r, c), count in np.ndenumerate(shown):
        ax.text(c, r, f"{count:d}", ha="center", va="center",
                color="white" if count > threshold else "black")

    ax.set_title(col)
    ax.set_xticks(range(shown.shape[1]))
    ax.set_yticks(range(shown.shape[0]))
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.tick_params(left=False, bottom=False)

    if col_idx == 0:
        # Left column carries the class names and no colorbar
        ax.set_yticklabels(label_names[:-1], rotation=0, horizontalalignment='right')
    else:
        fig.colorbar(image, ax=ax)

    if row_idx == grid_rows - 1:
        ax.set_xticklabels(label_names)

    # Calculate metrics and append to the results
    results.append(calculate_metrics(conf_mat))

# Hide grid cells that have no prediction column
for unused_ax in axes.ravel()[len(pred_cols):]:
    unused_ax.set_visible(False)

# Adjust layout
plt.tight_layout()

# Invisible full-figure axis that only carries the shared axis labels
outer_ax = fig.add_subplot(111, frameon=False)
# hide tick and tick label of the big axis
outer_ax.tick_params(labelcolor='none', which='both', top=False, bottom=False, left=False, right=False)
outer_ax.set_xlabel("Predicted Class", labelpad=20, fontsize=16)
outer_ax.set_ylabel("True Class", labelpad=20, fontsize=16)

# Display the grid of heatmaps
plt.show()

# Display the results in a DataFrame
results_df = pd.DataFrame(results, index=pred_cols).round(2)
display(results_df)


In [None]:
# Latex Table
# Prints one table row per entry of results_df; every block of `num_strategies`
# rows starts with a \midrule and a \multirow cell holding the truncation size.
num_strategies = 4
latex_string_first_row = r"\multirow{{4}}{{*}}{{{truncation_size}}} & {{{strategy}}} & {{{precision}}} & {{{recall}}} & {{{f1}}} \\"
latex_string_other_rows = r" & {{{strategy}}} & {{{precision}}} & {{{recall}}} & {{{f1}}} \\"

for idx, index in enumerate(results_df.index):
    # The second and fourth "_"-separated tokens of the index are used as
    # truncation size and strategy name.
    name_parts = index.split('_')
    truncation_size = name_parts[1]
    strategy = name_parts[3].capitalize()

    # The first three columns of results_df are used as precision, recall and F1.
    row_values = dict(
        strategy=strategy,
        precision=results_df.iloc[idx, 0],
        recall=results_df.iloc[idx, 1],
        f1=results_df.iloc[idx, 2],
    )

    if idx % num_strategies:
        print(latex_string_other_rows.format(**row_values))
    else:
        print(r"\midrule")
        print(latex_string_first_row.format(truncation_size=truncation_size, **row_values))