In [1]:
import os
import sys
sys.path.append(os.getcwd()+"/../..")
from src import paths

import pandas as pd
import numpy as np

import torch

from sklearn.metrics.pairwise import cosine_similarity

from sklearn.metrics import f1_score, recall_score, precision_score, classification_report

import warnings
warnings.filterwarnings("ignore")

In [2]:
def load_results(filename:str, labels:str, load_line:bool = False, load_outlines:bool = False, label_correction:dict = None):
    """ 
    Load results from a file and return a pandas dataframe with the results.

    Args:
        filename (str): The name of the file to load.
        labels (str): The name of the file containing the labels.
        load_line (bool, optional): If True, the function will expect a line labelled dataset and return the aggregated results per rid.
        load_outlines (bool, optional): If True, the function will expect the output of an outlines prompt (no hidden states and model answers are labels)
        label_correcttion (dict, optional): A dictionary with rids as keys and the corrected labels as values.

    Returns:
        pd.DataFrame: A pandas dataframe with the results.

    """
    results = torch.load(paths.RESULTS_PATH/"ms-diag"/filename, map_location=torch.device('cpu'))
    labels = torch.load(paths.RESULTS_PATH/"ms-diag"/labels, map_location=torch.device('cpu'))

    output = {"model_answers": results.pop("model_answers")}

    if not load_outlines:
        
        # Get prediction through cosine similarity
        last_hidden_states = results.pop("last_hidden_states")
        last_hidden_states = last_hidden_states.cpu()
        preds = get_prediction(last_hidden_states, labels_hs=labels[1], label_names=labels[0])

        # If rag
        if "rag" in filename:
            output["model_answers_mapped"] = preds + ["no information found" for _ in range(len(output["model_answers"])-len(preds))]
        else:
            output["model_answers_mapped"] = preds
    
    # Map from string answer to int
    label2id = {'primär progrediente Multiple Sklerose': 0,
            'sekundär progrediente Multiple Sklerose': 2,
            'schubförmige remittierende Multiple Sklerose': 1,
            'other': 3,
            'no information found': 3}
    
    if load_outlines:
        key = "model_answers"
    else:
        key = "model_answers_mapped"
    
    output["preds"] = [label2id[i] for i in output[key]]
    output["exact_match"] = [res.lower().split("### output:\n")[-1] in [key.lower() for key in label2id.keys()] for res in output["model_answers"]]

    output["labels"] = results["labels"]
    output["rid"] = results["rid"]
    output["text"] = results["text"]

    if "whole_prompt" in results.keys():
        whole_prompt = results["whole_prompt"]
        if "rag" in filename:
            output["whole_prompt"] = whole_prompt + ["no information found" for _ in range(len(output["model_answers"])-len(whole_prompt))]
        else:
            output["whole_prompt"] = whole_prompt
    
    df = pd.DataFrame(output)

    # Correcting wrong labels from analysis:
    if label_correction:
        for rid, label in label_correction.items():
            df.loc[df.rid == rid, "labels"] = label
    

    if load_line:
        return convert_line2report(df, filename)
    else:
        return df

def get_prediction(hs: torch.Tensor, labels_hs: torch.Tensor, label_names: list):
    """ 
    Get the prediction for a given hidden state and labels.

    Args:
        hs (torch.Tensor): The hidden state to use for prediction. Shape (n_samples, n_features).
        labels_hs (torch.Tensor): The hidden states of the labels. Shape (n_labels, n_features).
        label_names (list): The names of the labels. The order of the names should match the order of the hs-labels in dim 0.

    Returns:
        torch.Tensor: The predicted labels.

    """
    answer_idx = np.argmax(cosine_similarity(hs, labels_hs), axis=1)
    model_answer = [label_names[i] for i in answer_idx]
    return model_answer

def convert_line2report(rids:list[str], preds:list[int])->list[int]:
    """ 
    Aggregeates the results of a line labelled dataset in a majority vote fashion.

    Args:
        rids (list[str]): The rids of the samples.
        preds (list[int]): The predictions of the samples.

    Returns:
        list[int]: The aggregated predictions.
        
    """
    df = pd.DataFrame({"rid":rids, "preds":preds})

    # If len value counts >1 and value counts[0] is 3, then majority vote is value counts[1]
    results = []
    rid = []
    for _df in df.groupby("rid"):
        if _df[1].preds.value_counts().shape[0] > 1 and _df[1].preds.value_counts().iloc[0] == 3:
            result = df.loc[df.rid == _df[0], "preds"] = _df[1].preds.value_counts().index[1]
        else:
            result = df.loc[df.rid == _df[0], "preds"] = _df[1].preds.value_counts().index[0]
        results.append(result)
        rid.append(_df[0])
    return result


def summarize_performance(files: list[str], *args, **kwargs):

    """ 
    Summarizes the performance of a given strategy (Base, RAG, Outlines) for all prompting strategies.

    Args:
        files (list[str]): The files to summarize.
        *args: Variable length argument list.
        **kwargs: Arbitrary keyword arguments.

    Returns:
        pd.DataFrame: A pandas dataframe with the summarized results.

    """
    dfs = []
    for filename in files:
        strategies = ["zero_shot_vanilla", "zero_shot_instruction", "few_shot_vanilla", "few_shot_instruction"]
        
        # Set strategy name to whatever is found in the filename
        for strategy in strategies:
            if strategy in filename:
                break

        results = load_results(filename=filename, labels=kwargs.get("labels"), load_line=kwargs.get("load_lines"), load_outlines=kwargs.get("load_outlines"), label_correction=kwargs.get("label_correction"))

        # Target names
        metric_dict = classification_report(y_true=results["labels"], y_pred=results["preds"], output_dict=True)

        # Create a dictionary with flattened keys
        _df = pd.json_normalize(metric_dict)
        

        # Add additional information
        _df["strategy"] = strategy

        # Look for exact label
        _df["valid_label"] = sum(results["exact_match"])/len(results)

        # Reorder columns
        reordered_cols = _df.columns[-4:].append(_df.columns[:-4])
        _df = _df[reordered_cols]

        dfs.append(_df)

    return pd.concat(dfs)

def generate_latex_table(df: pd.DataFrame, caption: str, label: str, metrics = list[str]) -> str:
    """
    Generate LaTeX code for a table from a pandas DataFrame.

    Args:
        df (pd.DataFrame): DataFrame containing the table data.
        caption (str): Caption for the table.
        label (str): Label for referencing the table.
        metrics (list[str]): List of metrics to include in the table.

    Returns:
        str: LaTeX code for the table.
    """
    # Create the LaTeX code for the table
    latex = "\\begin{table}[h]\n"
    latex += "    \\centering\n"

    # Create tables for each metric
    for metric in metrics:
        latex += f"    \\begin{{tabular}}{{lcccc}}\n"
        latex += f"        \\toprule\n"
        latex += f"        & \\multicolumn{{4}}{{c}}{{{metric}}}\\\\\n"
        latex += f"        \\cmidrule(lr){{2-5}}\n"
        latex += f"        & Base & +S2A & +Outlines & +S2A \\& Outlines \\\\\n"
        latex += f"        \\midrule\n"

        for strategy in df['strategy'].unique():
            base_value = df[(df['strategy'] == strategy) & (df['approach'] == 'base')][metric].iloc[0]
            rag_value = df[(df['strategy'] == strategy) & (df['approach'] == 'rag')][metric].iloc[0]
            outlines_value = df[(df['strategy'] == strategy) & (df['approach'] == 'outlines')][metric].iloc[0]
            outlines_rag_value = df[(df['strategy'] == strategy) & (df['approach'] == 'outlines_rag')][metric].iloc[0]
            
            latex += f"        {strategy.replace('_', '-')} & {base_value:.2f} & {rag_value:.2f} & {outlines_value:.2f} & {outlines_rag_value:.2f} \\\\\n"

        latex += f"        \\bottomrule\n"
        latex += f"    \\end{{tabular}}\n\n"

    # Add caption and label
    latex += f"    \\caption{{{caption}}}\n"
    latex += f"    \\label{{tab:{label}}}\n"
    latex += "\\end{table}\n"

    return latex

# Llama2-MedTuned-13B

## RAG

In [3]:
pd.set_option('display.max_columns', None)
summarize_performance(["ms-diag_Llama2-MedTuned-13b_4bit_all_test_zero_shot_vanilla_rag.pt",
                       "ms-diag_Llama2-MedTuned-13b_4bit_all_test_zero_shot_instruction_rag.pt",
                       "ms-diag_Llama2-MedTuned-13b_4bit_all_test_few_shot_vanilla_rag.pt",
                       "ms-diag_Llama2-MedTuned-13b_4bit_all_test_few_shot_instruction_rag.pt",
                       ], labels="label_encodings_Llama2-MedTuned-13b.pt", load_line=False, load_outlines=False)

Unnamed: 0,weighted avg.f1-score,weighted avg.support,strategy,valid_label,accuracy,0.precision,0.recall,0.f1-score,0.support,1.precision,1.recall,1.f1-score,1.support,2.precision,2.recall,2.f1-score,2.support,3.precision,3.recall,3.f1-score,3.support,macro avg.precision,macro avg.recall,macro avg.f1-score,macro avg.support,weighted avg.precision,weighted avg.recall
0,0.92268,58.0,zero_shot_vanilla,0.62069,0.913793,0.5,1.0,0.666667,4.0,1.0,0.871795,0.931507,39.0,1.0,1.0,1.0,3.0,0.923077,1.0,0.96,12.0,0.855769,0.967949,0.889543,58.0,0.949602,0.913793
0,0.918585,58.0,zero_shot_instruction,0.827586,0.913793,0.571429,1.0,0.727273,4.0,1.0,0.897436,0.945946,39.0,1.0,0.666667,0.8,3.0,0.857143,1.0,0.923077,12.0,0.857143,0.891026,0.849074,58.0,0.940887,0.913793
0,0.683707,58.0,few_shot_vanilla,0.844828,0.62069,0.16,1.0,0.275862,4.0,1.0,0.461538,0.631579,39.0,1.0,0.666667,0.8,3.0,0.923077,1.0,0.96,12.0,0.770769,0.782051,0.66686,58.0,0.926154,0.62069
0,0.947542,58.0,few_shot_instruction,0.931034,0.948276,1.0,1.0,1.0,4.0,0.973684,0.948718,0.961039,39.0,1.0,0.666667,0.8,3.0,0.857143,1.0,0.923077,12.0,0.957707,0.903846,0.921029,58.0,0.952748,0.948276


## Outlines

In [4]:
pd.set_option('display.max_columns', None)
summarize_performance(["ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_zero_shot_vanilla.pt",
                       "ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_zero_shot_instruction.pt",
                       "ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_few_shot_vanilla.pt",
                       "ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_few_shot_instruction.pt",
                       ], labels="label_encodings_Llama2-MedTuned-7b.pt", load_line=False, load_outlines=True)

Unnamed: 0,weighted avg.f1-score,weighted avg.support,strategy,valid_label,accuracy,0.precision,0.recall,0.f1-score,0.support,1.precision,1.recall,1.f1-score,1.support,2.precision,2.recall,2.f1-score,2.support,3.precision,3.recall,3.f1-score,3.support,macro avg.precision,macro avg.recall,macro avg.f1-score,macro avg.support,weighted avg.precision,weighted avg.recall
0,0.101649,58.0,zero_shot_vanilla,1.0,0.12069,0.095238,1.0,0.173913,4.0,0.5,0.076923,0.133333,39.0,0.0,0.0,0.0,3.0,0.0,0.0,0.0,12.0,0.14881,0.269231,0.076812,58.0,0.342775,0.12069
0,0.907478,58.0,zero_shot_instruction,1.0,0.896552,0.428571,0.75,0.545455,4.0,1.0,0.897436,0.945946,39.0,1.0,1.0,1.0,3.0,0.846154,0.916667,0.88,12.0,0.818681,0.891026,0.84285,58.0,0.928761,0.896552
0,0.27036,58.0,few_shot_vanilla,1.0,0.241379,0.093023,1.0,0.170213,4.0,0.769231,0.25641,0.384615,39.0,0.0,0.0,0.0,3.0,0.0,0.0,0.0,12.0,0.215564,0.314103,0.138707,58.0,0.523657,0.241379
0,0.933945,58.0,few_shot_instruction,1.0,0.931034,0.666667,1.0,0.8,4.0,1.0,0.897436,0.945946,39.0,1.0,1.0,1.0,3.0,0.857143,1.0,0.923077,12.0,0.880952,0.974359,0.917256,58.0,0.947455,0.931034


## RAG + Outlines

In [5]:
pd.set_option('display.max_columns', None)
summarize_performance(["ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_zero_shot_vanilla_rag.pt",
                       "ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_zero_shot_instruction_rag.pt",
                       "ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_few_shot_vanilla_rag.pt",
                       "ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_few_shot_instruction_rag.pt",
                       ], labels="label_encodings_Llama2-MedTuned-13b.pt", load_line=False, load_outlines=True)

Unnamed: 0,weighted avg.f1-score,weighted avg.support,strategy,valid_label,accuracy,0.precision,0.recall,0.f1-score,0.support,1.precision,1.recall,1.f1-score,1.support,2.precision,2.recall,2.f1-score,2.support,3.precision,3.recall,3.f1-score,3.support,macro avg.precision,macro avg.recall,macro avg.f1-score,macro avg.support,weighted avg.precision,weighted avg.recall
0,0.263187,58.0,zero_shot_vanilla,1.0,0.310345,0.25,0.5,0.333333,4.0,1.0,0.025641,0.05,39.0,0.083333,1.0,0.153846,3.0,0.923077,1.0,0.96,12.0,0.564103,0.63141,0.374295,58.0,0.884947,0.310345
0,0.96639,58.0,zero_shot_instruction,1.0,0.965517,1.0,1.0,1.0,4.0,1.0,0.948718,0.973684,39.0,1.0,1.0,1.0,3.0,0.857143,1.0,0.923077,12.0,0.964286,0.987179,0.97419,58.0,0.970443,0.965517
0,0.774999,58.0,few_shot_vanilla,1.0,0.758621,0.285714,1.0,0.444444,4.0,0.933333,0.717949,0.811594,39.0,0.0,0.0,0.0,3.0,0.923077,1.0,0.96,12.0,0.535531,0.679487,0.55401,58.0,0.838272,0.758621
0,0.947542,58.0,few_shot_instruction,1.0,0.948276,1.0,1.0,1.0,4.0,0.973684,0.948718,0.961039,39.0,1.0,0.666667,0.8,3.0,0.857143,1.0,0.923077,12.0,0.957707,0.903846,0.921029,58.0,0.952748,0.948276


### Label Corrections
- rid: 96BC21AA-235F-4EED-A74F-58EFC11C1176, 
    - text: Hochgradiger V.a. entzündliche ZNS-Erkrankung ED 03.05.2019, EM 30.04.2019\nINDENT ätiologisch: möglicherweise multiple Sklerose
    - label: 1
    - corrected label: 3
- rid: B5B6D014-7E02-44E5-8390-F571F7C1D4E5
    - text: Multiple Sklerose (EM 05/2011, ED 09/2011)
    - label: 1
    - corrected label: 3

In [6]:
label_correction = {"96BC21AA-235F-4EED-A74F-58EFC11C1176" : 3, "B5B6D014-7E02-44E5-8390-F571F7C1D4E5": 3}

In [21]:
overall_results = pd.concat([summarize_performance(["ms-diag_Llama2-MedTuned-13b_4bit_all_test_zero_shot_vanilla.pt",
                                                       "ms-diag_Llama2-MedTuned-13b_4bit_all_test_zero_shot_instruction.pt",
                                                       "ms-diag_Llama2-MedTuned-13b_4bit_all_test_few_shot_vanilla.pt",
                                                       "ms-diag_Llama2-MedTuned-13b_4bit_all_test_few_shot_instruction.pt",
                                                       ], labels="label_encodings_Llama2-MedTuned-13b.pt", load_line=False, load_outlines=False, label_correction=label_correction),
                              summarize_performance(["ms-diag_Llama2-MedTuned-13b_4bit_all_test_zero_shot_vanilla_rag.pt",
                                                      "ms-diag_Llama2-MedTuned-13b_4bit_all_test_zero_shot_instruction_rag.pt",
                                                         "ms-diag_Llama2-MedTuned-13b_4bit_all_test_few_shot_vanilla_rag.pt",
                                                         "ms-diag_Llama2-MedTuned-13b_4bit_all_test_few_shot_instruction_rag.pt",
                                                         ], labels="label_encodings_Llama2-MedTuned-13b.pt", load_line=False, load_outlines=False, label_correction=label_correction),
                            summarize_performance(["ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_zero_shot_vanilla.pt",
                                                       "ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_zero_shot_instruction.pt",
                                                       "ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_few_shot_vanilla.pt",
                                                       "ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_few_shot_instruction.pt",
                                                       ], labels="label_encodings_Llama2-MedTuned-7b.pt", load_line=False, load_outlines=True, label_correction=label_correction),
                            summarize_performance(["ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_zero_shot_vanilla_rag.pt",
                                                       "ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_zero_shot_instruction_rag.pt",
                                                       "ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_few_shot_vanilla_rag.pt",
                                                       "ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_few_shot_instruction_rag.pt",
                                                       ], labels="label_encodings_Llama2-MedTuned-13b.pt", load_line=False, load_outlines=True, label_correction=label_correction)], ignore_index=True)
overall_results["approach"] = ["base"] * 4 + ["rag"] * 4 + ["outlines"] * 4 + ["outlines_rag"] * 4
overall_results.rename(columns={"macro avg.f1-score": "Macro F1-Score",
                                "macro avg.precision": "Macro Precision",
                                "macro avg.recall": "Macro Recall"}, inplace=True)

In [22]:
print(generate_latex_table(overall_results, 
                           caption="Macro Precision, Recall and F1-Score for different prompting strategies using Llama2-MedTuned-13B \cite{rohanian2023exploring}",
                           label="tab:ms-diag-13B",
                           metrics=["Macro Precision", "Macro Recall", "Macro F1-Score"]))

\begin{table}[h]
    \centering
    \begin{tabular}{lcccc}
        \toprule
        & \multicolumn{4}{c}{Macro Precision}\\
        \cmidrule(lr){2-5}
        & Base & +S2A & +Outlines & +S2A \& Outlines \\
        \midrule
        zero-shot-vanilla & 0.56 & 0.88 & 0.15 & 0.56 \\
        zero-shot-instruction & 0.79 & 0.89 & 0.84 & 1.00 \\
        few-shot-vanilla & 0.44 & 0.79 & 0.20 & 0.52 \\
        few-shot-instruction & 0.88 & 0.99 & 0.90 & 0.99 \\
        \bottomrule
    \end{tabular}

    \begin{tabular}{lcccc}
        \toprule
        & \multicolumn{4}{c}{Macro Recall}\\
        \cmidrule(lr){2-5}
        & Base & +S2A & +Outlines & +S2A \& Outlines \\
        \midrule
        zero-shot-vanilla & 0.42 & 0.96 & 0.27 & 0.60 \\
        zero-shot-instruction & 0.77 & 0.90 & 0.89 & 1.00 \\
        few-shot-vanilla & 0.47 & 0.77 & 0.31 & 0.64 \\
        few-shot-instruction & 0.91 & 0.92 & 0.97 & 0.92 \\
        \bottomrule
    \end{tabular}

    \begin{tabular}{lcccc}
        \topru

In [23]:
overall_results.to_csv(paths.THESIS_PATH/"ms_pred_results_prompt13b.csv")

## Valid Label

In [9]:
overall_results[['approach', 'strategy', 'valid_label']]

Unnamed: 0,approach,strategy,valid_label
0,base,zero_shot_vanilla,0.310345
1,base,zero_shot_instruction,0.655172
2,base,few_shot_vanilla,0.689655
3,base,few_shot_instruction,0.965517
4,rag,zero_shot_vanilla,0.62069
5,rag,zero_shot_instruction,0.827586
6,rag,few_shot_vanilla,0.844828
7,rag,few_shot_instruction,0.931034
8,outlines,zero_shot_vanilla,1.0
9,outlines,zero_shot_instruction,1.0


# Llama 7B

## Base

In [10]:
pd.set_option('display.max_columns', None)
summarize_performance(["ms-diag_Llama2-MedTuned-7b_4bit_all_test_zero_shot_vanilla.pt",
                       "ms-diag_Llama2-MedTuned-7b_4bit_all_test_zero_shot_instruction.pt",
                       "ms-diag_Llama2-MedTuned-7b_4bit_all_test_few_shot_vanilla.pt",
                       "ms-diag_Llama2-MedTuned-7b_4bit_all_test_few_shot_instruction.pt",
                       ], labels="label_encodings_Llama2-MedTuned-7b.pt", load_line=False, load_outlines=False)

Unnamed: 0,weighted avg.f1-score,weighted avg.support,strategy,valid_label,accuracy,0.precision,0.recall,0.f1-score,0.support,1.precision,1.recall,1.f1-score,1.support,2.precision,2.recall,2.f1-score,2.support,3.precision,3.recall,3.f1-score,3.support,macro avg.precision,macro avg.recall,macro avg.f1-score,macro avg.support,weighted avg.precision,weighted avg.recall
0,0.358741,58.0,zero_shot_vanilla,0.258621,0.310345,0.125,0.5,0.2,4.0,0.888889,0.205128,0.333333,39.0,0.086957,0.666667,0.153846,3.0,0.6,0.5,0.545455,12.0,0.425211,0.467949,0.308159,58.0,0.734958,0.310345
0,0.096026,58.0,zero_shot_instruction,0.172414,0.068966,0.0,0.0,0.0,4.0,0.75,0.076923,0.139535,39.0,0.022727,0.333333,0.042553,3.0,0.0,0.0,0.0,12.0,0.193182,0.102564,0.045522,58.0,0.505486,0.068966
0,0.241462,58.0,few_shot_vanilla,0.137931,0.189655,0.111111,0.25,0.153846,4.0,1.0,0.128205,0.227273,39.0,0.029412,0.333333,0.054054,3.0,0.4,0.333333,0.363636,12.0,0.385131,0.261218,0.199702,58.0,0.764357,0.189655
0,0.036495,58.0,few_shot_instruction,0.551724,0.051724,0.0,0.0,0.0,4.0,0.5,0.025641,0.04878,39.0,0.037736,0.666667,0.071429,3.0,0.0,0.0,0.0,12.0,0.134434,0.173077,0.030052,58.0,0.338159,0.051724


## Outlines

In [11]:
pd.set_option('display.max_columns', None)
summarize_performance(["ms-diag_outlines_Llama2-MedTuned-7b_4bit_all_test_zero_shot_vanilla.pt",
                       "ms-diag_outlines_Llama2-MedTuned-7b_4bit_all_test_zero_shot_instruction.pt",
                       "ms-diag_outlines_Llama2-MedTuned-7b_4bit_all_test_few_shot_vanilla.pt",
                       "ms-diag_outlines_Llama2-MedTuned-7b_4bit_all_test_few_shot_instruction.pt",
                       ], labels="label_encodings_Llama2-MedTuned-7b.pt", load_line=False, load_outlines=True)

Unnamed: 0,weighted avg.f1-score,weighted avg.support,strategy,valid_label,accuracy,0.precision,0.recall,0.f1-score,0.support,1.precision,1.recall,1.f1-score,1.support,2.precision,2.recall,2.f1-score,2.support,3.precision,3.recall,3.f1-score,3.support,macro avg.precision,macro avg.recall,macro avg.f1-score,macro avg.support,weighted avg.precision,weighted avg.recall
0,0.169644,58.0,zero_shot_vanilla,1.0,0.189655,0.088235,0.75,0.157895,4.0,1.0,0.076923,0.142857,39.0,0.0,0.0,0.0,3.0,0.238095,0.416667,0.30303,12.0,0.331583,0.310897,0.150946,58.0,0.72776,0.189655
0,0.730447,58.0,zero_shot_instruction,1.0,0.689655,0.230769,0.75,0.352941,4.0,1.0,0.717949,0.835821,39.0,1.0,0.333333,0.5,3.0,0.5,0.666667,0.571429,12.0,0.682692,0.616987,0.565048,58.0,0.843501,0.689655
0,0.248468,58.0,few_shot_vanilla,1.0,0.293103,0.076923,0.25,0.117647,4.0,1.0,0.128205,0.227273,39.0,0.0,0.0,0.0,3.0,0.275,0.916667,0.423077,12.0,0.337981,0.323718,0.191999,58.0,0.734615,0.293103
0,0.770662,58.0,few_shot_instruction,1.0,0.758621,0.8,1.0,0.888889,4.0,1.0,0.666667,0.8,39.0,0.5,0.666667,0.571429,3.0,0.521739,1.0,0.685714,12.0,0.705435,0.833333,0.736508,58.0,0.861394,0.758621


## RAG + Outlines

In [12]:
summarize_performance(["ms-diag_outlines_Llama2-MedTuned-7b_4bit_all_test_zero_shot_vanilla_rag.pt",
                          "ms-diag_outlines_Llama2-MedTuned-7b_4bit_all_test_zero_shot_instruction_rag.pt",
                          "ms-diag_outlines_Llama2-MedTuned-7b_4bit_all_test_few_shot_vanilla_rag.pt",
                          "ms-diag_outlines_Llama2-MedTuned-7b_4bit_all_test_few_shot_instruction_rag.pt",
                          ], labels="label_encodings_Llama2-MedTuned-7b.pt", load_line=False, load_outlines=True)

Unnamed: 0,weighted avg.f1-score,weighted avg.support,strategy,valid_label,accuracy,0.precision,0.recall,0.f1-score,0.support,1.precision,1.recall,1.f1-score,1.support,2.precision,2.recall,2.f1-score,2.support,3.precision,3.recall,3.f1-score,3.support,macro avg.precision,macro avg.recall,macro avg.f1-score,macro avg.support,weighted avg.precision,weighted avg.recall
0,0.852706,58.0,zero_shot_vanilla,1.0,0.862069,0.444444,1.0,0.615385,4.0,1.0,0.871795,0.931507,39.0,0.0,0.0,0.0,3.0,0.8,1.0,0.888889,12.0,0.561111,0.717949,0.608945,58.0,0.868582,0.862069
0,0.96639,58.0,zero_shot_instruction,1.0,0.965517,1.0,1.0,1.0,4.0,1.0,0.948718,0.973684,39.0,1.0,1.0,1.0,3.0,0.857143,1.0,0.923077,12.0,0.964286,0.987179,0.97419,58.0,0.970443,0.965517
0,0.18999,58.0,few_shot_vanilla,1.0,0.310345,0.222222,1.0,0.363636,4.0,1.0,0.051282,0.097561,39.0,0.0,0.0,0.0,3.0,0.315789,1.0,0.48,12.0,0.384503,0.512821,0.235299,58.0,0.753075,0.310345
0,0.96639,58.0,few_shot_instruction,1.0,0.965517,1.0,1.0,1.0,4.0,1.0,0.948718,0.973684,39.0,1.0,1.0,1.0,3.0,0.857143,1.0,0.923077,12.0,0.964286,0.987179,0.97419,58.0,0.970443,0.965517


In [26]:
overall_results = pd.concat([summarize_performance(["ms-diag_Llama2-MedTuned-7b_4bit_all_test_zero_shot_vanilla.pt",
                            "ms-diag_Llama2-MedTuned-7b_4bit_all_test_zero_shot_instruction.pt",
                            "ms-diag_Llama2-MedTuned-7b_4bit_all_test_few_shot_vanilla.pt",
                            "ms-diag_Llama2-MedTuned-7b_4bit_all_test_few_shot_instruction.pt",
                            ], labels="label_encodings_Llama2-MedTuned-7b.pt", load_line=False, load_outlines=False, label_correction=label_correction),
                                 summarize_performance(["ms-diag_Llama2-MedTuned-7b_4bit_all_test_zero_shot_vanilla_rag.pt",
                            "ms-diag_Llama2-MedTuned-7b_4bit_all_test_zero_shot_instruction_rag.pt",
                            "ms-diag_Llama2-MedTuned-7b_4bit_all_test_few_shot_vanilla_rag.pt",
                            "ms-diag_Llama2-MedTuned-7b_4bit_all_test_few_shot_instruction_rag.pt",
                            ], labels="label_encodings_Llama2-MedTuned-7b.pt", load_line=False, load_outlines=False, label_correction=label_correction),
                                 summarize_performance(["ms-diag_outlines_Llama2-MedTuned-7b_4bit_all_test_zero_shot_vanilla.pt",
                            "ms-diag_outlines_Llama2-MedTuned-7b_4bit_all_test_zero_shot_instruction.pt",
                            "ms-diag_outlines_Llama2-MedTuned-7b_4bit_all_test_few_shot_vanilla.pt",
                            "ms-diag_outlines_Llama2-MedTuned-7b_4bit_all_test_few_shot_instruction.pt",
                            ], labels="label_encodings_Llama2-MedTuned-7b.pt", load_line=False, load_outlines=True, label_correction=label_correction),
                                 summarize_performance(["ms-diag_outlines_Llama2-MedTuned-7b_4bit_all_test_zero_shot_vanilla_rag.pt",
                            "ms-diag_outlines_Llama2-MedTuned-7b_4bit_all_test_zero_shot_instruction_rag.pt",
                            "ms-diag_outlines_Llama2-MedTuned-7b_4bit_all_test_few_shot_vanilla_rag.pt",
                            "ms-diag_outlines_Llama2-MedTuned-7b_4bit_all_test_few_shot_instruction_rag.pt",
                            ], labels="label_encodings_Llama2-MedTuned-7b.pt", load_line=False, load_outlines=True, label_correction=label_correction)])
overall_results["approach"] = ["base"] * 4 + ["rag"] * 4 + ["outlines"] * 4 + ["outlines_rag"] * 4
overall_results.rename(columns={"macro avg.f1-score": "Macro F1-Score",
                                "macro avg.precision": "Macro Precision",
                                "macro avg.recall": "Macro Recall"}, inplace=True)

In [27]:
overall_results.to_csv(paths.THESIS_PATH/"ms_pred_results_prompt7b.csv")

In [25]:
print(generate_latex_table(overall_results, 
                           caption="Macro Precision, Recall and F1-Score for different prompting strategies using Llama2-MedTuned-13B \cite{rohanian2023exploring}",
                           label="tab:ms-diag-13B",
                           metrics=["Macro Precision", "Macro Recall", "Macro F1-Score"]))

\begin{table}[h]
    \centering
    \begin{tabular}{lcccc}
        \toprule
        & \multicolumn{4}{c}{Macro Precision}\\
        \cmidrule(lr){2-5}
        & Base & +S2A & +Outlines & +S2A \& Outlines \\
        \midrule
        zero-shot-vanilla & 0.45 & 0.67 & 0.36 & 0.58 \\
        zero-shot-instruction & 0.13 & 0.26 & 0.70 & 1.00 \\
        few-shot-vanilla & 0.39 & 0.88 & 0.35 & 0.39 \\
        few-shot-instruction & 0.01 & 0.27 & 0.73 & 1.00 \\
        \bottomrule
    \end{tabular}

    \begin{tabular}{lcccc}
        \toprule
        & \multicolumn{4}{c}{Macro Recall}\\
        \cmidrule(lr){2-5}
        & Base & +S2A & +Outlines & +S2A \& Outlines \\
        \midrule
        zero-shot-vanilla & 0.47 & 0.76 & 0.33 & 0.71 \\
        zero-shot-instruction & 0.10 & 0.38 & 0.62 & 1.00 \\
        few-shot-vanilla & 0.25 & 0.98 & 0.33 & 0.50 \\
        few-shot-instruction & 0.17 & 0.46 & 0.84 & 1.00 \\
        \bottomrule
    \end{tabular}

    \begin{tabular}{lcccc}
        \topru

## Valid Label

In [15]:
overall_results[['approach', 'strategy', 'valid_label']]

Unnamed: 0,approach,strategy,valid_label
0,base,zero_shot_vanilla,0.258621
0,base,zero_shot_instruction,0.172414
0,base,few_shot_vanilla,0.137931
0,base,few_shot_instruction,0.551724
0,rag,zero_shot_vanilla,0.534483
0,rag,zero_shot_instruction,0.551724
0,rag,few_shot_vanilla,0.965517
0,rag,few_shot_instruction,0.982759
0,outlines,zero_shot_vanilla,1.0
0,outlines,zero_shot_instruction,1.0


# Comparison with Old Approach (no other class)

In [16]:
results_fso = torch.load(paths.RESULTS_PATH/"ms-diag/ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_few_shot_instruction.pt")
other_idx = other_idx = [i for i, label in enumerate(results_fso["labels"]) if label == 3]

results_fso_no = {}
for key in results_fso.keys():
    results_fso_no[key] = [value for i, value in enumerate(results_fso[key]) if i not in other_idx]

torch.save(results_fso_no, paths.RESULTS_PATH/"ms-diag/ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_few_shot_instruction_no.pt")

In [17]:
results_fsos2a = torch.load(paths.RESULTS_PATH/"ms-diag/ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_few_shot_instruction_rag.pt")
other_idx = other_idx = [i for i, label in enumerate(results_fsos2a["labels"]) if label == 3]

results_fsos2a_no = {}
for key in results_fsos2a.keys():
    results_fsos2a_no[key] = [value for i, value in enumerate(results_fsos2a[key]) if i not in other_idx]

torch.save(results_fsos2a_no, paths.RESULTS_PATH/"ms-diag/ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_few_shot_instruction_rag_no.pt")

In [28]:
summary_no = summarize_performance(["ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_few_shot_instruction_no.pt",
                                    "ms-diag_outlines_Llama2-MedTuned-13b_4bit_all_test_few_shot_instruction_rag_no.pt"], labels="label_encodings_Llama2-MedTuned-13b.pt", load_line=False, load_outlines=True)
summary_no["strategy"] = ["FSI Outlines", "FSI Outlines + S2A"]

In [29]:
summary_no

Unnamed: 0,weighted avg.f1-score,weighted avg.support,strategy,valid_label,accuracy,0.precision,0.recall,0.f1-score,0.support,1.precision,1.recall,1.f1-score,1.support,2.precision,2.recall,2.f1-score,2.support,3.precision,3.recall,3.f1-score,3.support,macro avg.precision,macro avg.recall,macro avg.f1-score,macro avg.support,weighted avg.precision,weighted avg.recall
0,0.93678,46.0,FSI Outlines,1.0,0.913043,0.666667,1.0,0.8,4.0,1.0,0.897436,0.945946,39.0,1.0,1.0,1.0,3.0,0.0,0.0,0.0,0.0,0.666667,0.724359,0.686486,46.0,0.971014,0.913043
0,0.953924,46.0,FSI Outlines + S2A,1.0,0.934783,1.0,1.0,1.0,4.0,0.973684,0.948718,0.961039,39.0,1.0,0.666667,0.8,3.0,0.0,0.0,0.0,0.0,0.743421,0.653846,0.69026,46.0,0.977689,0.934783


In [30]:
summary_no.to_csv(paths.THESIS_PATH/"ms_pred_results_prompt13b_no.csv")

In [20]:
for metric in ["precision", "recall", "f1"]:
    cols = [col for col in summary_no.columns if metric in col and not col.startswith("3.")][:3]
    # Row mean
    summary = pd.DataFrame(summary_no[cols].mean(axis=1).round(2), columns = [metric])
    summary["strategy"] = ["FSI Outlines", "FSI Outlines + S2A"]
    display(summary)

Unnamed: 0,precision,strategy
0,0.89,FSI Outlines
0,0.99,FSI Outlines + S2A


Unnamed: 0,recall,strategy
0,0.97,FSI Outlines
0,0.87,FSI Outlines + S2A


Unnamed: 0,f1,strategy
0,0.89,FSI Outlines
0,0.97,FSI Outlines + S2A
