In [None]:
import os
import sys
sys.path.append(os.getcwd()+"/../..")
from src import paths

import pandas as pd
import numpy as np

import torch

from sklearn.metrics.pairwise import cosine_similarity

from sklearn.metrics import f1_score, recall_score, precision_score, classification_report

import warnings
warnings.filterwarnings("ignore")

In [None]:
def load_results(filename:str, labels:str, load_line:bool = False, load_outlines:bool = False, label_correction:dict = None):
    """
    Load one results file and return a pandas DataFrame with one row per sample.

    Args:
        filename (str): Name of the results file inside ``paths.RESULTS_PATH / "ms-diag"``.
        labels (str): Name of the label-encodings file (same directory), indexed as
            ``[0]`` = label names and ``[1]`` = label hidden states. Only read when
            ``load_outlines`` is False.
        load_line (bool, optional): If True, expect a line-labelled dataset and return one row
            per rid, with "preds" and "labels" aggregated by ``convert_line2report``.
        load_outlines (bool, optional): If True, expect the output of an outlines prompt
            (no hidden states; the model answers are already label strings).
        label_correction (dict, optional): Maps rids to corrected label ids; applied to the
            "labels" column before any aggregation.

    Returns:
        pd.DataFrame: Columns "model_answers", "model_answers_mapped" (not for outlines),
        "preds", "exact_match", "labels", "rid", "text" and, if stored in the file,
        "whole_prompt".
    """
    results_dir = paths.RESULTS_PATH / "ms-diag"
    cpu = torch.device("cpu")
    # NOTE: torch.load unpickles the file - only load result files produced by this project.
    results = torch.load(results_dir / filename, map_location=cpu)

    output = {"model_answers": results.pop("model_answers")}
    n_answers = len(output["model_answers"])
    is_rag = "rag" in filename

    def pad_to_answers(values):
        # For RAG runs the stored lists may be shorter than the model answers;
        # pad the tail with "no information found" so all columns have the same length.
        return list(values) + ["no information found"] * (n_answers - len(values))

    if not load_outlines:
        # Label encodings are only needed to map hidden states to label names via cosine similarity.
        label_encodings = torch.load(results_dir / labels, map_location=cpu)
        last_hidden_states = results.pop("last_hidden_states").cpu()
        preds = get_prediction(last_hidden_states, labels_hs=label_encodings[1], label_names=label_encodings[0])
        output["model_answers_mapped"] = pad_to_answers(preds) if is_rag else preds

    # Map from string answer to int
    label2id = {'primär progrediente Multiple Sklerose': 0,
            'sekundär progrediente Multiple Sklerose': 2,
            'schubförmige remittierende Multiple Sklerose': 1,
            'other': 3,
            'no information found': 3}

    # Outlines answers are label strings already; otherwise use the similarity-mapped answers.
    answer_key = "model_answers" if load_outlines else "model_answers_mapped"
    output["preds"] = [label2id[answer] for answer in output[answer_key]]

    # An answer is "valid" if the text after the last "### output:\n" marker is exactly a
    # known label (case-insensitive). The set is built once instead of once per answer.
    valid_answers = {name.lower() for name in label2id}
    output["exact_match"] = [answer.lower().split("### output:\n")[-1] in valid_answers for answer in output["model_answers"]]

    output["labels"] = results["labels"]
    output["rid"] = results["rid"]
    output["text"] = results["text"]

    if "whole_prompt" in results:
        whole_prompt = results["whole_prompt"]
        output["whole_prompt"] = pad_to_answers(whole_prompt) if is_rag else whole_prompt

    df = pd.DataFrame(output)

    # Correcting wrong labels from analysis:
    if label_correction:
        for rid, corrected_label in label_correction.items():
            df.loc[df["rid"] == rid, "labels"] = corrected_label

    if not load_line:
        return df

    # Line-labelled data: collapse to one row per rid. Rows are in sorted-rid order, which is
    # also the groupby order convert_line2report returns its votes in.
    # NOTE(review): "labels" are aggregated with the same vote as "preds", and all other columns
    # take the first line of each rid - TODO confirm this matches how line labels are defined.
    rids = df["rid"].tolist()
    report_df = df.groupby("rid", sort=True).first().reset_index()
    report_df["preds"] = convert_line2report(rids, df["preds"].tolist())
    report_df["labels"] = convert_line2report(rids, df["labels"].tolist())
    return report_df

def get_prediction(hs: torch.Tensor, labels_hs: torch.Tensor, label_names: list):
    """
    Map each hidden state to the name of the label whose hidden state is most cosine-similar.

    Args:
        hs (torch.Tensor): Hidden states to classify. Shape (n_samples, n_features).
        labels_hs (torch.Tensor): Hidden states of the candidate labels. Shape (n_labels, n_features).
        label_names (list): Label names, in the same order as the rows of ``labels_hs``.

    Returns:
        list: One label name per row of ``hs`` (ties resolve to the first label).
    """
    similarity = cosine_similarity(hs, labels_hs)  # (n_samples, n_labels)
    best_rows = similarity.argmax(axis=1)
    return [label_names[row] for row in best_rows]

def convert_line2report(rids:list[str], preds:list[int], other_label:int = 3)->list[int]:
    """
    Aggregate line-level predictions into one prediction per rid by majority vote.

    The most frequent prediction of a rid wins. The exception is when that prediction is
    ``other_label`` and the rid also has other predictions; then the second most frequent
    prediction is used instead.

    Args:
        rids (list[str]): The rid of each line.
        preds (list[int]): The prediction of each line, aligned with ``rids``.
        other_label (int, optional): Label id that only wins the vote when it is the sole
            prediction of a rid. Defaults to 3 ("other" / "no information found").

    Returns:
        list[int]: One aggregated prediction per unique rid, ordered by sorted rid (the
        ``groupby`` order). Empty input yields an empty list.
    """
    df = pd.DataFrame({"rid": rids, "preds": preds})

    aggregated = []
    for _, group in df.groupby("rid"):
        # Distinct predictions of this rid, most frequent first
        ranked = group["preds"].value_counts().index.tolist()
        # Compare the top *label* (not its count) against the "other" class
        if len(ranked) > 1 and ranked[0] == other_label:
            aggregated.append(ranked[1])
        else:
            aggregated.append(ranked[0])
    return aggregated


def summarize_performance(files: list[str], *args, **kwargs):
    """
    Summarize classification metrics for a set of result files, one row per file.

    Args:
        files (list[str]): Result file names, each passed to ``load_results``.
        *args: Unused; accepted for backward compatibility.
        **kwargs: Forwarded to ``load_results``: ``labels``, ``load_line`` (the legacy
            spelling ``load_lines`` is also accepted), ``load_outlines`` and ``label_correction``.

    Returns:
        pd.DataFrame: One row per file. "strategy" and "valid_label" come first, followed by
        the flattened ``classification_report`` columns in report order (e.g. "0.precision",
        "macro avg.f1-score"). An empty DataFrame if ``files`` is empty.
    """
    strategies = ["zero_shot_vanilla", "zero_shot_instruction", "few_shot_vanilla", "few_shot_instruction"]
    # Callers pass ``load_line``; ``load_lines`` is kept as a legacy alias.
    load_line = kwargs.get("load_line", kwargs.get("load_lines", False))
    info_cols = ["strategy", "valid_label"]

    dfs = []
    for filename in files:
        # Strategy = first known strategy name contained in the filename ("unknown" if none)
        strategy = next((name for name in strategies if name in filename), "unknown")

        results = load_results(filename=filename, labels=kwargs.get("labels"), load_line=load_line,
                               load_outlines=kwargs.get("load_outlines"), label_correction=kwargs.get("label_correction"))

        metric_dict = classification_report(y_true=results["labels"], y_pred=results["preds"], output_dict=True)

        # Flatten the nested report into columns such as "macro avg.f1-score"
        _df = pd.json_normalize(metric_dict)
        _df["strategy"] = strategy

        # Fraction of raw model answers that exactly match a known label string
        _df["valid_label"] = sum(results["exact_match"]) / len(results)

        # Move the info columns to the front by name so the metric columns keep their report order
        _df = _df[info_cols + [col for col in _df.columns if col not in info_cols]]

        dfs.append(_df)

    if not dfs:
        return pd.DataFrame()
    return pd.concat(dfs)

def generate_latex_table(df: pd.DataFrame, caption: str, label: str, metrics: list[str] = None) -> str:
    """
    Generate LaTeX code for a table comparing the approaches for each prompting strategy.

    One ``tabular`` is emitted per metric. Its rows are the unique values of ``df["strategy"]``;
    its columns are the approaches "base", "rag", "outlines" and "outlines_rag" (headed
    Base, +S2A, +Outlines, +S2A & Outlines).

    Args:
        df (pd.DataFrame): Must contain "strategy", "approach" and every column in ``metrics``.
        caption (str): Caption for the table, inserted verbatim (may contain LaTeX).
        label (str): Label for referencing the table. "tab:" is prepended unless already present.
        metrics (list[str], optional): Metric columns to tabulate. Defaults to
            ["Macro Precision", "Macro Recall", "Macro F1-Score"].

    Returns:
        str: LaTeX code for the table. A strategy/approach combination missing from ``df``
        is rendered as "--".
    """
    if metrics is None:
        metrics = ["Macro Precision", "Macro Recall", "Macro F1-Score"]
    approaches = ["base", "rag", "outlines", "outlines_rag"]
    table_label = label if label.startswith("tab:") else f"tab:{label}"

    def format_cell(strategy, approach, metric):
        # First matching value with two decimals, or "--" when the combination is missing
        values = df.loc[(df["strategy"] == strategy) & (df["approach"] == approach), metric]
        return f"{values.iloc[0]:.2f}" if len(values) else "--"

    lines = ["\\begin{table}[h]", "    \\centering"]

    # Create tables for each metric
    for metric in metrics:
        lines += [
            "    \\begin{tabular}{lcccc}",
            "        \\toprule",
            f"        & \\multicolumn{{4}}{{c}}{{{metric}}}\\\\",
            "        \\cmidrule(lr){2-5}",
            "        & Base & +S2A & +Outlines & +S2A \\& Outlines \\\\",
            "        \\midrule",
        ]
        for strategy in df["strategy"].unique():
            cells = [format_cell(strategy, approach, metric) for approach in approaches]
            lines.append(f"        {strategy.replace('_', '-')} & {' & '.join(cells)} \\\\")
        # The trailing "" leaves a blank line after each tabular
        lines += ["        \\bottomrule", "    \\end{tabular}", ""]

    # Add caption and label
    lines += [f"    \\caption{{{caption}}}", f"    \\label{{{table_label}}}", "\\end{table}", ""]
    return "\n".join(lines)

# Llama2-MedTuned-13B

## RAG

In [None]:
pd.options.display.max_columns = None
# Llama2-MedTuned-13B with retrieval (RAG), scored via hidden-state similarity to the label encodings
summarize_performance(
    [
        "ms-diag_Llama2-MedTuned-13b_4bit_all_test_zero_shot_vanilla_rag.pt",
        "ms-diag_Llama2-MedTuned-13b_4bit_all_test_zero_shot_instruction_rag.pt",
        "ms-diag_Llama2-MedTuned-13b_4bit_all_test_few_shot_vanilla_rag.pt",
        "ms-diag_Llama2-MedTuned-13b_4bit_all_test_few_shot_instruction_rag.pt",
    ],
    labels="label_encodings_Llama2-MedTuned-13b.pt",
    load_line=False,
    load_outlines=False,
)

## Outlines

In [None]:
pd.set_option('display.max_columns', None)
# Outlines answers are already label strings, so the label encodings are not used to map them.
# The 13B encodings are passed anyway so the call refers to the model being evaluated.
summarize_performance(["ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_zero_shot_vanilla.pt",
                       "ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_zero_shot_instruction.pt",
                       "ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_few_shot_vanilla.pt",
                       "ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_few_shot_instruction.pt",
                       ], labels="label_encodings_Llama2-MedTuned-13b.pt", load_line=False, load_outlines=True)

## RAG + Outlines

In [None]:
pd.options.display.max_columns = None
# Llama2-MedTuned-13B with retrieval and outlines-constrained answers
summarize_performance(
    [
        "ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_zero_shot_vanilla_rag.pt",
        "ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_zero_shot_instruction_rag.pt",
        "ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_few_shot_vanilla_rag.pt",
        "ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_few_shot_instruction_rag.pt",
    ],
    labels="label_encodings_Llama2-MedTuned-13b.pt",
    load_line=False,
    load_outlines=True,
)

### Label Corrections
Manual review found two reports labelled as relapsing-remitting MS (label 1) whose text does not establish that course; both are relabelled as "other" (label 3).
- rid: 96BC21AA-235F-4EED-A74F-58EFC11C1176
    - text: Hochgradiger V.a. entzündliche ZNS-Erkrankung ED 03.05.2019, EM 30.04.2019\nINDENT ätiologisch: möglicherweise multiple Sklerose
    - label: 1
    - corrected label: 3
- rid: B5B6D014-7E02-44E5-8390-F571F7C1D4E5
    - text: Multiple Sklerose (EM 05/2011, ED 09/2011)
    - label: 1
    - corrected label: 3

In [None]:
# rid -> corrected label id for the manually reviewed reports listed in "Label Corrections"
label_correction = dict.fromkeys(
    [
        "96BC21AA-235F-4EED-A74F-58EFC11C1176",
        "B5B6D014-7E02-44E5-8390-F571F7C1D4E5",
    ],
    3,
)

In [None]:
# Evaluate every prompting strategy per approach; the dict key becomes the "approach" column.
# Outlines answers are not mapped through the label encodings, but the 13B encodings are passed
# throughout so every call refers to the evaluated model.
per_approach_13b = {
    "base": summarize_performance(["ms-diag_Llama2-MedTuned-13b_4bit_all_test_zero_shot_vanilla.pt",
                                   "ms-diag_Llama2-MedTuned-13b_4bit_all_test_zero_shot_instruction.pt",
                                   "ms-diag_Llama2-MedTuned-13b_4bit_all_test_few_shot_vanilla.pt",
                                   "ms-diag_Llama2-MedTuned-13b_4bit_all_test_few_shot_instruction.pt",
                                   ], labels="label_encodings_Llama2-MedTuned-13b.pt", load_line=False, load_outlines=False, label_correction=label_correction),
    "rag": summarize_performance(["ms-diag_Llama2-MedTuned-13b_4bit_all_test_zero_shot_vanilla_rag.pt",
                                  "ms-diag_Llama2-MedTuned-13b_4bit_all_test_zero_shot_instruction_rag.pt",
                                  "ms-diag_Llama2-MedTuned-13b_4bit_all_test_few_shot_vanilla_rag.pt",
                                  "ms-diag_Llama2-MedTuned-13b_4bit_all_test_few_shot_instruction_rag.pt",
                                  ], labels="label_encodings_Llama2-MedTuned-13b.pt", load_line=False, load_outlines=False, label_correction=label_correction),
    "outlines": summarize_performance(["ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_zero_shot_vanilla.pt",
                                       "ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_zero_shot_instruction.pt",
                                       "ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_few_shot_vanilla.pt",
                                       "ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_few_shot_instruction.pt",
                                       ], labels="label_encodings_Llama2-MedTuned-13b.pt", load_line=False, load_outlines=True, label_correction=label_correction),
    "outlines_rag": summarize_performance(["ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_zero_shot_vanilla_rag.pt",
                                           "ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_zero_shot_instruction_rag.pt",
                                           "ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_few_shot_vanilla_rag.pt",
                                           "ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_few_shot_instruction_rag.pt",
                                           ], labels="label_encodings_Llama2-MedTuned-13b.pt", load_line=False, load_outlines=True, label_correction=label_correction),
}
# Tag each frame with its approach instead of assuming exactly four rows per approach
overall_results = pd.concat(
    [frame.assign(approach=approach) for approach, frame in per_approach_13b.items()],
    ignore_index=True,
)
overall_results = overall_results.rename(columns={"macro avg.f1-score": "Macro F1-Score",
                                                  "macro avg.precision": "Macro Precision",
                                                  "macro avg.recall": "Macro Recall"})

In [None]:
# Raw string keeps "\cite" intact; generate_latex_table adds the "tab:" prefix to the label itself
print(generate_latex_table(overall_results,
                           caption=r"Macro Precision, Recall and F1-Score for different prompting strategies using Llama2-MedTuned-13B \cite{rohanian2023exploring}",
                           label="ms-diag-13B",
                           metrics=["Macro Precision", "Macro Recall", "Macro F1-Score"]))

## Valid Label

In [None]:
# Share of raw answers that exactly match a label string, per approach and strategy
overall_results.loc[:, ["approach", "strategy", "valid_label"]]

# Llama2-MedTuned-7B

## Base

In [None]:
pd.options.display.max_columns = None
# Llama2-MedTuned-7B without retrieval, scored via hidden-state similarity to the label encodings
summarize_performance(
    [
        "ms-diag_Llama2-MedTuned-7b_4bit_all_test_zero_shot_vanilla.pt",
        "ms-diag_Llama2-MedTuned-7b_4bit_all_test_zero_shot_instruction.pt",
        "ms-diag_Llama2-MedTuned-7b_4bit_all_test_few_shot_vanilla.pt",
        "ms-diag_Llama2-MedTuned-7b_4bit_all_test_few_shot_instruction.pt",
    ],
    labels="label_encodings_Llama2-MedTuned-7b.pt",
    load_line=False,
    load_outlines=False,
)

## Outlines

In [None]:
pd.options.display.max_columns = None
# Llama2-MedTuned-7B with outlines-constrained answers (answers are label strings already)
summarize_performance(
    [
        "ms-diag_outlines_Llama2-MedTuned-7b_4bit_all_test_zero_shot_vanilla.pt",
        "ms-diag_outlines_Llama2-MedTuned-7b_4bit_all_test_zero_shot_instruction.pt",
        "ms-diag_outlines_Llama2-MedTuned-7b_4bit_all_test_few_shot_vanilla.pt",
        "ms-diag_outlines_Llama2-MedTuned-7b_4bit_all_test_few_shot_instruction.pt",
    ],
    labels="label_encodings_Llama2-MedTuned-7b.pt",
    load_line=False,
    load_outlines=True,
)

## RAG + Outlines

In [None]:
# Llama2-MedTuned-7B with retrieval and outlines-constrained answers
summarize_performance(
    [
        "ms-diag_outlines_Llama2-MedTuned-7b_4bit_all_test_zero_shot_vanilla_rag.pt",
        "ms-diag_outlines_Llama2-MedTuned-7b_4bit_all_test_zero_shot_instruction_rag.pt",
        "ms-diag_outlines_Llama2-MedTuned-7b_4bit_all_test_few_shot_vanilla_rag.pt",
        "ms-diag_outlines_Llama2-MedTuned-7b_4bit_all_test_few_shot_instruction_rag.pt",
    ],
    labels="label_encodings_Llama2-MedTuned-7b.pt",
    load_line=False,
    load_outlines=True,
)

In [None]:
# Evaluate every prompting strategy per approach; the dict key becomes the "approach" column.
per_approach_7b = {
    "base": summarize_performance(["ms-diag_Llama2-MedTuned-7b_4bit_all_test_zero_shot_vanilla.pt",
                                   "ms-diag_Llama2-MedTuned-7b_4bit_all_test_zero_shot_instruction.pt",
                                   "ms-diag_Llama2-MedTuned-7b_4bit_all_test_few_shot_vanilla.pt",
                                   "ms-diag_Llama2-MedTuned-7b_4bit_all_test_few_shot_instruction.pt",
                                   ], labels="label_encodings_Llama2-MedTuned-7b.pt", load_line=False, load_outlines=False, label_correction=label_correction),
    "rag": summarize_performance(["ms-diag_Llama2-MedTuned-7b_4bit_all_test_zero_shot_vanilla_rag.pt",
                                  "ms-diag_Llama2-MedTuned-7b_4bit_all_test_zero_shot_instruction_rag.pt",
                                  "ms-diag_Llama2-MedTuned-7b_4bit_all_test_few_shot_vanilla_rag.pt",
                                  "ms-diag_Llama2-MedTuned-7b_4bit_all_test_few_shot_instruction_rag.pt",
                                  ], labels="label_encodings_Llama2-MedTuned-7b.pt", load_line=False, load_outlines=False, label_correction=label_correction),
    "outlines": summarize_performance(["ms-diag_outlines_Llama2-MedTuned-7b_4bit_all_test_zero_shot_vanilla.pt",
                                       "ms-diag_outlines_Llama2-MedTuned-7b_4bit_all_test_zero_shot_instruction.pt",
                                       "ms-diag_outlines_Llama2-MedTuned-7b_4bit_all_test_few_shot_vanilla.pt",
                                       "ms-diag_outlines_Llama2-MedTuned-7b_4bit_all_test_few_shot_instruction.pt",
                                       ], labels="label_encodings_Llama2-MedTuned-7b.pt", load_line=False, load_outlines=True, label_correction=label_correction),
    "outlines_rag": summarize_performance(["ms-diag_outlines_Llama2-MedTuned-7b_4bit_all_test_zero_shot_vanilla_rag.pt",
                                           "ms-diag_outlines_Llama2-MedTuned-7b_4bit_all_test_zero_shot_instruction_rag.pt",
                                           "ms-diag_outlines_Llama2-MedTuned-7b_4bit_all_test_few_shot_vanilla_rag.pt",
                                           "ms-diag_outlines_Llama2-MedTuned-7b_4bit_all_test_few_shot_instruction_rag.pt",
                                           ], labels="label_encodings_Llama2-MedTuned-7b.pt", load_line=False, load_outlines=True, label_correction=label_correction),
}
# Tag rows by approach and renumber the index (each summary frame restarts its index at 0),
# matching the ignore_index=True used for the 13B results.
overall_results = pd.concat(
    [frame.assign(approach=approach) for approach, frame in per_approach_7b.items()],
    ignore_index=True,
)
overall_results = overall_results.rename(columns={"macro avg.f1-score": "Macro F1-Score",
                                                  "macro avg.precision": "Macro Precision",
                                                  "macro avg.recall": "Macro Recall"})

In [None]:
# These are the 7B results: caption and label name the 7B model (previously copied from the 13B cell)
print(generate_latex_table(overall_results,
                           caption=r"Macro Precision, Recall and F1-Score for different prompting strategies using Llama2-MedTuned-7B \cite{rohanian2023exploring}",
                           label="ms-diag-7B",
                           metrics=["Macro Precision", "Macro Recall", "Macro F1-Score"]))

## Valid Label

In [None]:
# Share of raw answers that exactly match a label string, per approach and strategy
overall_results.loc[:, ["approach", "strategy", "valid_label"]]

# Comparison with the Old Approach (without the "other" class)

In [None]:
# Drop every sample labelled 3 ("other" / "no information found") from the few-shot-instruction
# outlines run, to compare against the old setup without that class.
results_fso = torch.load(paths.RESULTS_PATH/"ms-diag/ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_few_shot_instruction.pt",
                         map_location=torch.device("cpu"))
# A set gives O(1) membership checks while filtering every stored field
other_idx = {i for i, label in enumerate(results_fso["labels"]) if label == 3}

results_fso_no = {
    key: [value for i, value in enumerate(values) if i not in other_idx]
    for key, values in results_fso.items()
}

torch.save(results_fso_no, paths.RESULTS_PATH/"ms-diag/ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_few_shot_instruction_no.pt")

In [None]:
# Same filtering for the RAG (S2A) variant: drop every sample labelled 3 ("other" / "no information found").
results_fsos2a = torch.load(paths.RESULTS_PATH/"ms-diag/ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_few_shot_instruction_rag.pt",
                            map_location=torch.device("cpu"))
# A set gives O(1) membership checks while filtering every stored field
other_idx = {i for i, label in enumerate(results_fsos2a["labels"]) if label == 3}

results_fsos2a_no = {
    key: [value for i, value in enumerate(values) if i not in other_idx]
    for key, values in results_fsos2a.items()
}

torch.save(results_fsos2a_no, paths.RESULTS_PATH/"ms-diag/ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_few_shot_instruction_rag_no.pt")

In [None]:
# Score both filtered runs (label 3 removed) as outlines output
summary_no = summarize_performance(
    [
        "ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_few_shot_instruction_no.pt",
        "ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_few_shot_instruction_rag_no.pt",
    ],
    labels="label_encodings_Llama2-MedTuned-13b.pt",
    load_line=False,
    load_outlines=True,
)

In [None]:
# Metrics of the two few-shot-instruction outlines runs after removing samples labelled 3
summary_no

In [None]:
# Macro-average over classes 0-2 only, since class 3 was removed from these runs.
# Columns are selected by name: picking the first three matches by position depends on
# summarize_performance's column order and can pick up "weighted avg.f1-score" instead.
metric_suffixes = {"precision": "precision", "recall": "recall", "f1": "f1-score"}
for metric, suffix in metric_suffixes.items():
    cols = [f"{cls}.{suffix}" for cls in ("0", "1", "2") if f"{cls}.{suffix}" in summary_no.columns]
    # Row mean over the selected per-class scores
    summary = summary_no[cols].mean(axis=1).round(2).to_frame(name=metric)
    summary["strategy"] = ["FSI Outlines", "FSI Outlines + S2A"]
    display(summary)