In [95]:
import sys
import os
sys.path.append(os.getcwd()+"/../..")
from src import paths

from src.utils import get_default_pydantic_model, SideEffect, SideEffectList

import pandas as pd
import torch

import json

import numpy as np

from bert_score import BERTScorer

from transformers import BertModel, BertTokenizer
from sklearn.metrics.pairwise import cosine_similarity

In [96]:
with open(paths.DATA_PATH_PREPROCESSED/"side-effects/examples.json", "r") as f:
        examples = json.load(f)

# Load examples to check if model was confused
examples = [json.loads(e["labels"]) for e in examples]
medication_side_effects = []
for item in examples:
    for med_effect_dict in item:
        medication_side_effects.append((med_effect_dict['medication'].lower(), med_effect_dict['side_effect'].lower()))

medication_side_effects = list(set(medication_side_effects) - set(("unknown", "unknown")))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

default_model = get_default_pydantic_model("side_effects")
model = BertModel.from_pretrained(os.path.join(paths.MODEL_PATH, "medbert-512")).to(device)
tokenizer = BertTokenizer.from_pretrained(os.path.join(paths.MODEL_PATH, "medbert-512"))

# scorer = BERTScorer(model_type=os.path.join(paths.MODEL_PATH, "medbert-512"), num_layers=4)

Some weights of BertModel were not initialized from the model checkpoint at /mnt/c/Users/marc_/OneDrive/ETH/MSC_Thesis/inf-extr/resources/models/medbert-512 and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [97]:
def high_cosine_similarity(ground_truth:str, prediction:str, model:BertModel, tokenizer:BertTokenizer, threshhold:float = 0.7):
    with torch.no_grad():
        ground_truth = model(**tokenizer(ground_truth, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device))["last_hidden_state"].to("cpu")
        prediction = model(**tokenizer(prediction, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device))["last_hidden_state"].to("cpu")
        ground_truth = ground_truth.mean(dim=1)
        prediction = prediction.mean(dim=1)
        return cosine_similarity(ground_truth, prediction).item() > threshhold


In [98]:
def preprocess_group(grouped_df: pd.DataFrame)->pd.DataFrame:
    """
    Preprocess a group of rows (grouped by obs index) from the dataframe. The preprocessing consists of:
    - Fusing all the side effects per medication with "," (so only one row per medication is left)
    - Replacing the medication name and side effects with "unknown" if the group has the same medication and side effects as example, which is a sign that the model was confused.

    Args:
        grouped_df (pd.DataFrame): A group of rows from the dataframe.

    Returns:
        pd.DataFrame: A dataframe with the group preprocessed.
    """

    df = []
    duplicated = False

    # If the group has the same medication and side effects as example just put "unknown" for both, as model was confused
    duplicate_count = 0
    for med, side_effect in medication_side_effects:
         for index, row in grouped_df.iterrows():
            if row["medication"].lower() == med and row["side_effect"].lower() == side_effect:
                duplicate_count += 1
    # if duplicate_count >= 3:
    #     # print("Changed at index: ", grouped_df["index"].iloc[0])
    #     grouped_df = pd.DataFrame({"medication": "unknown", "side_effect": "unknown", "text": grouped_df["text"].iloc[0], "original_text": grouped_df["original_text"].iloc[0], "index": grouped_df["index"].iloc[0]}, index=[0])
    if duplicate_count >=1:
        duplicated = True
    
    # group by medication name
    for med, group in grouped_df.groupby("medication"):
       # Fuse all the side effects with ","
        side_effects = group["side_effect"].str.cat(sep=",")
        text = group["text"].iloc[0]
        original_text = group["original_text"].iloc[0]
        index = group["index"].iloc[0]
        # Create a new row
        new_row = {"medication": med, "side_effect": side_effects, "text": text, "original_text": original_text, "index": index, "duplicated": duplicated, "successful": group["successful"].iloc[0]}
        # Append the new row to the dataframe
        df.append(new_row)

    return pd.DataFrame(df)

def prepare_results(path: str)->pd.DataFrame:
    """ 
    Prepare the results from the model to be compared with the labels. The preprocessing consists of:
    - Fixing the model answers, replacing them with the default model answer if they are invalid (in regards to JSON structure).
    - Grouping the results by index and preprocessing the group. (See preprocess_group() function for more details.)
    - Lowercasing and converting everything to string.

    Args:
        path (str): Path to the results file.
        
    Returns:
        pd.DataFrame: A dataframe with the results preprocessed.
    """
    results = torch.load(path)
    df = pd.DataFrame(results)

    df["successful"] = True

    # Fix model_answers
    for idx, row in df.iterrows():
        try:
            # Does model answer have key "side_effects"?
            answer = json.loads(row["model_answers"])

            # Is answer a valid SideEffectList?
            SideEffectList(**answer)

            # Is the element an empty list?
            if len(answer["side_effects"]) == 0:
                raise ValueError("Empty list")
            
            # Are all the elements valid SideEffect?
            for med in answer["side_effects"]:
                SideEffect(**med)
            
            df.at[idx, "model_answers"] = json.dumps(answer)

            
        except:
            df.at[idx, "model_answers"] = default_model.model_dump_json()
            # print(f"Error at index {idx}")
            df.at[idx, "successful"] = False
    
    dfs = []
    for idx, row in df.iterrows():
        row = row.to_dict()
        answer = json.loads(row["model_answers"])
        medications = answer["side_effects"]
        for med in medications:
            med["text"] = row.get("text", "")
            med["original_text"] = row.get("original_text", "")
            med["index"] = row.get("index", "")
            med["successful"] = row.get("successful", "")
            dfs.append((med))

    res = pd.DataFrame(dfs)

    res = res.sort_values(by="index")

    # Group by index and preprocess
    _dfs = []

    for index, group in res.groupby("index"):
        _dfs.append(preprocess_group(group))

    res = pd.concat(_dfs)

    # Convert everything to string and lowercase
    res = res.map(lambda x: str(x).lower())

    return res

def prepare_labels(path: str)->pd.DataFrame:
    """
    Prepare the labels to be compared with the results. The preprocessing consists of:
    - Lowercasing and converting everything to string.

    Args:
        path (str): Path to the labels file.

    Returns:

        pd.DataFrame: A dataframe with the labels preprocessed.
    """
    labels = pd.read_excel(path) 

    # Convert everything to string and lowercase
    labels = labels.map(lambda x: str(x).lower())

    return labels

def calculate_precision_recall_f1(ground_truth:list, predicted:list)->tuple[float, float, float]:
    """
    Calculate precision and recall from two lists, allowing for partial string matches.
    A string does not have to be exactly the same to be considered a match, it is enough if one string is a substring of the other.
    This function is used to evaluate the medication names.

    Args:
        ground_truth (list): List of ground truth medication names.
        predicted (list): List of predicted medication names.

    Returns:
        float: Precision
        float: Recall
        float: F1 Score
    """
    # Convert lists to sets for easier comparison
    ground_truth_set = set(ground_truth)
    predicted_set = set(predicted)

    # Initialize counters
    true_positives = 0
    false_positives = 0

    # Calculate True Positives and False Positives
    for pred_med in predicted_set:
        found_match = False
        for truth_med in ground_truth_set:
            if pred_med in truth_med or truth_med in pred_med:
                true_positives += 1
                found_match = True
                break
        if not found_match:
            false_positives += 1

    # Calculate False Negatives
    false_negatives = len(ground_truth_set) - true_positives

    # Calculate Precision
    if true_positives + false_positives == 0:
        precision = 0  # Handle division by zero
    else:
        precision = true_positives / (true_positives + false_positives)

    # Calculate Recall
    if true_positives + false_negatives == 0:
        recall = 0  # Handle division by zero
    else:
        recall = true_positives / (true_positives + false_negatives)

    # Calculate F1 Score
    if precision + recall == 0:
        f1 = 0
    else:
        f1 = 2 * (precision * recall) / (precision + recall)

    return precision, recall, f1

    # texts_gt = ""
    # texts_pred = ""

    # # Reorder the lists with approximate matching
    # for pred in predicted:
    #     for gt in ground_truth:
    #         if pred in gt or gt in pred:
    #             texts_gt += gt + " "
    #             texts_pred += pred + " "
    #             ground_truth.remove(gt)
    #             predicted.remove(pred)
    #             break
    
    # # Leftover ground truth and predictions
    # texts_gt += " ".join(ground_truth)
    # texts_pred += " ".join(predicted)

    # # Calculate bert score
    # precision, recall, f1 = scorer.score([texts_pred], [texts_gt])

    # return precision.item(), recall.item(), f1.item()


def calculate_bert_score(ground_truth:list[dict], predicted:list[dict], model:BertModel, tokenizer:BertTokenizer, threshold:float=0.9)->tuple[float, float, float]:
    """ 
    Calculate bert score from two lists of dictionaries. This function is used to evaluate the side effects.

    Args:
        ground_truth (list): List of ground truth medication names and their side effect.
        predicted (list): List of predicted medication names and their side effect.
        model (BertModel): BertModel object to calculate the cosine similarity.
        tokenizer (BertTokenizer): BertTokenizer object to tokenize the text.
    """

    # Match the prediction to the ground truth with medication name
    tp = 0
    fn = 0
    fp = 0

    for pred_id, pred in enumerate(predicted):
        for gt_id, gt in enumerate(ground_truth):
            if pred["medication"] in gt["medication"] or gt["medication"] in pred["medication"]:
                if high_cosine_similarity([pred["side_effect"]], [gt["side_effect"]], model, tokenizer, threshold):
                    tp += 1
                else:
                    fp += 1
                    fn += 1
                # Remove the ground truth and prediction from the list
                ground_truth.pop(gt_id)
                predicted.pop(pred_id)

    # For remaining ground truth
    fn += len(ground_truth)
    fp += len(predicted)

    # Calculate Precision
    if tp + fp == 0:
        precision = 0
    else:
        precision = tp / (tp + fp)

    # Calculate Recall
    if tp + fn == 0:
        recall = 0
    else:
        recall = tp / (tp + fn)

    # Calculate F1 Score
    if precision + recall == 0:
        f1 = 0
    else:
        f1 = 2 * (precision * recall) / (precision + recall)

    return precision, recall, f1


def create_metrics_df(labels_path:str, results_path:str, *args, **kwargs)->pd.DataFrame:
    """
    Create a dataframe with the metrics for each index. The metrics are:
    - Precision, Recall and F1 score for medication names
    - Precision, Recall and F1 score for side effects

    Args:
        labels_path (str): Path to the labels file.
        results_path (str): Path to the results file.

    Returns:
        pd.DataFrame: A dataframe with the metrics for each index.
    """
    
    labels = prepare_labels(labels_path)
    res = prepare_results(results_path)

    dfs = []
    for index, group in res.groupby("index"):
        text = group["text"].iloc[0]
        original_text = group["original_text"].iloc[0]
        index = group["index"].iloc[0]
        preds_med = group["medication"].to_list()
        preds_se = group["side_effect"].to_list()
        labels_med = labels[labels["index"] == index]["medication"].to_list()
        labels_se = labels[labels["index"] == index]["side_effect"].to_list()

        # Side effects
        ground_truth = labels[labels["index"] == index].to_dict(orient="records")
        predicted = group.to_dict(orient="records")
        model = kwargs.get("model")
        tokenizer = kwargs.get("tokenizer")
        threshold = kwargs.get("threshold", 0.9)
        assert model is not None and tokenizer is not None, "Model and tokenizer must be provided as keyword arguments for side effects evaluation."
        side_effects_precision, side_effects_recall, side_effects_f1 = calculate_bert_score(ground_truth, predicted, model, tokenizer, threshold)

        # Medication
        ground_truth = labels[labels["index"] == index]["medication"].to_list()
        predicted = group["medication"].to_list()
        medication_precision, medication_recall, medication_f1 = calculate_precision_recall_f1(ground_truth, predicted)

        # Check if duplicated is the same for all in group
        assert len(group["duplicated"].unique()) == 1
        assert len(group["successful"].unique()) == 1

        if group["duplicated"].iloc[0] == "false":
            duplicated = False
        else:
            duplicated = True
        
        if group["successful"].iloc[0] == "false":
            successful = False
        else:
            successful = True

        dfs.append({"text": text, 
                    "original_text": original_text, 
                    "index": index, 
                    "preds_med": preds_med,
                    "labels_med": labels_med,
                    "preds_se": preds_se,
                    "labels_se": labels_se,
                    "duplicated": duplicated,
                    "successful": successful,
                    "side_effects_precision": side_effects_precision, "side_effects_recall": side_effects_recall, "side_effects_f1": side_effects_f1, "medication_precision": medication_precision, "medication_recall": medication_recall, "medication_f1": medication_f1})

    return pd.DataFrame(dfs)

def create_agg_metrics_df(labels_path:str, results_path:str, *args, **kwargs)->pd.DataFrame:
    """
    Create a dataframe with the aggregated metrics for the whole dataset. The metrics are:
    - Precision, Recall and F1 score for medication names
    - Precision, Recall and F1 score for side effects

    Args:
        labels_path (str): Path to the labels file.
        results_path (str): Path to the results file.
        scorer (BERTScorer): BERTScorer object to score the similarity between the side effects.

    Returns:
        pd.DataFrame: A dataframe with the aggregated metrics for the whole dataset.
    """
    metrics = create_metrics_df(labels_path, results_path, *args, **kwargs)

    # Calculate the average of the metrics
    agg_metrics = metrics.agg({"side_effects_precision": "mean", "side_effects_recall": "mean", "side_effects_f1": "mean", "medication_precision": "mean", "medication_recall": "mean", "medication_f1": "mean", "duplicated": "sum", "successful": "sum"})

    return agg_metrics

In [99]:
def summarize(paths:list[str], labels:str, *args, **kwargs)->pd.DataFrame:
    """
    Summarize the results from the model. The summary consists of:
    - Precision, Recall and F1 score for medication names
    - Precision, Recall and F1 score for side effects
    - Aggregated metrics for the whole dataset

    Args:
        paths (list): List of paths to the results files.
        labels (str): Path to the labels file.

    Returns:
        pd.DataFrame: A dataframe with the summary of the results.
    """
    dfs = []
    for path in paths:
        metrics = create_agg_metrics_df(labels, path, *args, **kwargs)
        path = str(path)
        if path.endswith("rag.pt"):
            metrics["approach"] = "S2A-1"
        elif path.endswith("s2a.pt"):
            metrics["approach"] = "S2A-2"
        else:
            metrics["approach"] = "Base"
        
        if "few_shot_vanilla" in path:
            metrics["strategy"] = "Few-Shot Base"
        elif "few_shot_instruction" in path:
            metrics["strategy"] = "Few-Shot Instruction"
        elif "zero_shot_vanilla" in path:
            metrics["strategy"] = "Zero-Shot Base"
        elif "zero_shot_instruction" in path:
            metrics["strategy"] = "Zero-Shot Instruction"
        else:
            ValueError("Unknown strategy")
        dfs.append(metrics)
    return pd.DataFrame(dfs)

In [100]:
file_paths = [paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_zero_shot_vanilla.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_zero_shot_instruction.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_few_shot_vanilla_10_examples.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_few_shot_instruction_10_examples.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_zero_shot_vanilla_rag.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_zero_shot_instruction_rag.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_few_shot_vanilla_10_examples_rag.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_few_shot_instruction_10_examples_rag.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_zero_shot_vanilla_s2a.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_zero_shot_instruction_s2a.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_few_shot_vanilla_10_examples_s2a.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_few_shot_instruction_10_examples_s2a.pt",]

# 13B

## Base

In [107]:
# Zero Shot Vanilla
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_zero_shot_vanilla.pt", model=model, tokenizer=tokenizer, threshold = 0.5)


side_effects_precision     0.188633
side_effects_recall        0.460000
side_effects_f1            0.253973
medication_precision       0.310531
medication_recall          0.750000
medication_f1              0.415100
duplicated                11.000000
successful                96.000000
dtype: float64

In [108]:
# Zero Shot instruction
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_zero_shot_instruction.pt", model=model, tokenizer=tokenizer, threshold = 0.5)

side_effects_precision     0.291452
side_effects_recall        0.576667
side_effects_f1            0.361858
medication_precision       0.427310
medication_recall          0.881667
medication_f1              0.538481
duplicated                13.000000
successful                94.000000
dtype: float64

In [109]:
# Few Shot Vanilla
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_few_shot_vanilla_10_examples.pt", model=model, tokenizer=tokenizer, threshold = 0.5)

side_effects_precision     0.359897
side_effects_recall        0.586667
side_effects_f1            0.420048
medication_precision       0.435992
medication_recall          0.725000
medication_f1              0.509984
duplicated                41.000000
successful                95.000000
dtype: float64

In [110]:
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_few_shot_instruction_10_examples.pt", model=model, tokenizer=tokenizer, threshold = 0.5)

side_effects_precision      0.343980
side_effects_recall         0.656667
side_effects_f1             0.421826
medication_precision        0.432687
medication_recall           0.845000
medication_f1               0.532716
duplicated                 41.000000
successful                100.000000
dtype: float64

## RAG

In [111]:
# Zero Shot Vanilla
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_zero_shot_vanilla_rag.pt", model=model, tokenizer=tokenizer, threshold = 0.5)


side_effects_precision     0.289988
side_effects_recall        0.530000
side_effects_f1            0.346468
medication_precision       0.401992
medication_recall          0.778333
medication_f1              0.490021
duplicated                24.000000
successful                99.000000
dtype: float64

In [112]:
# Zero Shot instruction
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_zero_shot_instruction_rag.pt", model=model, tokenizer=tokenizer, threshold = 0.5)


side_effects_precision     0.313845
side_effects_recall        0.540000
side_effects_f1            0.369103
medication_precision       0.422183
medication_recall          0.768333
medication_f1              0.506132
duplicated                25.000000
successful                99.000000
dtype: float64

In [113]:
# Few Shot Vanilla
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_few_shot_vanilla_10_examples_rag.pt", model=model, tokenizer=tokenizer, threshold = 0.5)

side_effects_precision     0.366198
side_effects_recall        0.538333
side_effects_f1            0.409179
medication_precision       0.429722
medication_recall          0.680000
medication_f1              0.489893
duplicated                50.000000
successful                96.000000
dtype: float64

In [114]:
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_few_shot_instruction_10_examples_rag.pt", model=model, tokenizer=tokenizer, threshold = 0.5)

side_effects_precision     0.335929
side_effects_recall        0.605000
side_effects_f1            0.405405
medication_precision       0.436619
medication_recall          0.791667
medication_f1              0.526373
duplicated                46.000000
successful                97.000000
dtype: float64

## S2A

In [115]:
# Zero Shot Vanilla
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_zero_shot_vanilla_s2a.pt", model=model, tokenizer=tokenizer, threshold = 0.5)


side_effects_precision      0.546429
side_effects_recall         0.586667
side_effects_f1             0.554500
medication_precision        0.806429
medication_recall           0.870000
medication_f1               0.820833
duplicated                 12.000000
successful                100.000000
dtype: float64

In [116]:
# Zero Shot instruction
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_zero_shot_instruction_s2a.pt", model=model, tokenizer=tokenizer, threshold = 0.5)


side_effects_precision     0.697262
side_effects_recall        0.731667
side_effects_f1            0.706500
medication_precision       0.953095
medication_recall          0.990000
medication_f1              0.961500
duplicated                19.000000
successful                92.000000
dtype: float64

In [117]:
# Few Shot Vanilla
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_few_shot_vanilla_10_examples_s2a.pt", model=model, tokenizer=tokenizer, threshold = 0.5)

side_effects_precision     0.719524
side_effects_recall        0.728333
side_effects_f1            0.710667
medication_precision       0.882857
medication_recall          0.900000
medication_f1              0.874667
duplicated                32.000000
successful                97.000000
dtype: float64

In [118]:
# Few Shot Instruction
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_few_shot_instruction_10_examples_s2a.pt", model=model, tokenizer=tokenizer, threshold = 0.5)

side_effects_precision     0.671429
side_effects_recall        0.713333
side_effects_f1            0.670389
medication_precision       0.851429
medication_recall          0.900000
medication_f1              0.851056
duplicated                29.000000
successful                98.000000
dtype: float64

## Summary

In [119]:
def latex_metrics(metrics:pd.DataFrame)->str:
    """
    Convert the metrics dataframe to a latex table.
    """
    dfs = []
    for strat, group in metrics.groupby("strategy"):
        cols = [col for col in metrics.columns if col not in ["strategy", "approach"]]
        df_base = group[group["approach"] == "Base"][cols].reset_index(drop=True)
        df_base.columns = [f"{col}_base" for col in cols]
        df_rag = group[group["approach"] == "S2A-1"][cols].reset_index(drop=True)
        df_rag.columns = [f"{col}_rag" for col in cols]
        df_s2a = group[group["approach"] == "S2A-2"][cols].reset_index(drop=True)
        df_s2a.columns = [f"{col}_s2a" for col in cols]
        df = pd.concat([df_base, df_rag, df_s2a], axis=1)
        df["strategy"] = strat
        dfs.append(df)

    return pd.concat(dfs, axis=0)

In [120]:
results = summarize(file_paths, paths.RESULTS_PATH/"side-effects/labels.xlsx", model=model, tokenizer=tokenizer, threshold = 0.5)
results = results.round(2)
# results.to_csv(paths.RESULTS_PATH/"side-effects/summary13b.csv", index=False)

In [121]:
results

Unnamed: 0,side_effects_precision,side_effects_recall,side_effects_f1,medication_precision,medication_recall,medication_f1,duplicated,successful,approach,strategy
0,0.19,0.46,0.25,0.31,0.75,0.42,11.0,96.0,Base,Zero-Shot Base
1,0.29,0.58,0.36,0.43,0.88,0.54,13.0,94.0,Base,Zero-Shot Instruction
2,0.36,0.59,0.42,0.44,0.72,0.51,41.0,95.0,Base,Few-Shot Base
3,0.34,0.66,0.42,0.43,0.84,0.53,41.0,100.0,Base,Few-Shot Instruction
4,0.29,0.53,0.35,0.4,0.78,0.49,24.0,99.0,S2A-1,Zero-Shot Base
5,0.31,0.54,0.37,0.42,0.77,0.51,25.0,99.0,S2A-1,Zero-Shot Instruction
6,0.37,0.54,0.41,0.43,0.68,0.49,50.0,96.0,S2A-1,Few-Shot Base
7,0.34,0.6,0.41,0.44,0.79,0.53,46.0,97.0,S2A-1,Few-Shot Instruction
8,0.55,0.59,0.55,0.81,0.87,0.82,12.0,100.0,S2A-2,Zero-Shot Base
9,0.7,0.73,0.71,0.95,0.99,0.96,19.0,92.0,S2A-2,Zero-Shot Instruction


In [None]:
thesis13b = latex_metrics(results)
thesis13b_precision = thesis13b[["strategy"] + [col for col in thesis13b.columns if "precision" in col]]
thesis13b_recall = thesis13b[["strategy"] + [col for col in thesis13b.columns if "recall" in col]]
thesis13b_f1 = thesis13b[["strategy"] + [col for col in thesis13b.columns if "f1" in col]]

# thesis13b_precision.to_csv(paths.RESULTS_PATH/"side-effects/thesis13b_precision.csv", index=False)
# thesis13b_recall.to_csv(paths.RESULTS_PATH/"side-effects/thesis13b_recall.csv", index=False)
# thesis13b_f1.to_csv(paths.RESULTS_PATH/"side-effects/thesis13b_f1.csv", index=False)

- Model seems to have difficulties if there is no medication mentioned with side effects. It repeats the examples. I could set a threshold of 3 or more repeated examples is indication that model is confused, so just put "unknown/unknown"
- also split medications that contain "/" into two lists
- concatenate the side effects for the same medications with ","

# 7B

## Base

In [122]:
# Zero Shot Vanilla
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_zero_shot_vanilla.pt", model=model, tokenizer=tokenizer, threshold = 0.5)


side_effects_precision     0.142353
side_effects_recall        0.198333
side_effects_f1            0.149508
medication_precision       0.232186
medication_recall          0.301667
medication_f1              0.239222
duplicated                35.000000
successful                65.000000
dtype: float64

In [123]:
# Zero Shot instruction
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_zero_shot_instruction.pt", model=model, tokenizer=tokenizer, threshold = 0.5)


side_effects_precision     0.131762
side_effects_recall        0.170000
side_effects_f1            0.138556
medication_precision       0.202595
medication_recall          0.248333
medication_f1              0.209222
duplicated                 5.000000
successful                98.000000
dtype: float64

In [124]:
# Few Shot Vanilla
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_few_shot_vanilla_10_examples.pt", model=model, tokenizer=tokenizer, threshold = 0.5)

side_effects_precision     0.328333
side_effects_recall        0.330000
side_effects_f1            0.325000
medication_precision       0.398333
medication_recall          0.378333
medication_f1              0.380000
duplicated                18.000000
successful                98.000000
dtype: float64

In [125]:
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_few_shot_instruction_10_examples.pt", model=model, tokenizer=tokenizer, threshold = 0.5)

side_effects_precision     0.343333
side_effects_recall        0.331667
side_effects_f1            0.334444
medication_precision       0.393333
medication_recall          0.366667
medication_f1              0.374444
duplicated                22.000000
successful                93.000000
dtype: float64

## RAG

In [126]:
# Zero Shot Vanilla
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_zero_shot_vanilla_rag.pt", model=model, tokenizer=tokenizer, threshold = 0.5)


side_effects_precision      0.285000
side_effects_recall         0.295000
side_effects_f1             0.284000
medication_precision        0.420357
medication_recall           0.428333
medication_f1               0.413444
duplicated                 18.000000
successful                100.000000
dtype: float64

In [127]:
# Zero Shot instruction
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_zero_shot_instruction_rag.pt", model=model, tokenizer=tokenizer, threshold = 0.5)


side_effects_precision     0.260333
side_effects_recall        0.270000
side_effects_f1            0.261190
medication_precision       0.412333
medication_recall          0.408333
medication_f1              0.402381
duplicated                19.000000
successful                99.000000
dtype: float64

In [128]:
# Few Shot Vanilla
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_few_shot_vanilla_10_examples_rag.pt", model=model, tokenizer=tokenizer, threshold = 0.5)

side_effects_precision      0.295000
side_effects_recall         0.300000
side_effects_f1             0.296667
medication_precision        0.385000
medication_recall           0.378333
medication_f1               0.378333
duplicated                 24.000000
successful                100.000000
dtype: float64

In [129]:
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_few_shot_instruction_10_examples_rag.pt", model=model, tokenizer=tokenizer, threshold = 0.5)

side_effects_precision     0.313929
side_effects_recall        0.325000
side_effects_f1            0.313167
medication_precision       0.397262
medication_recall          0.403333
medication_f1              0.389833
duplicated                29.000000
successful                97.000000
dtype: float64

## S2A

In [130]:
# Zero Shot Vanilla
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_zero_shot_vanilla_s2a.pt", model=model, tokenizer=tokenizer, threshold = 0.5)


side_effects_precision      0.533333
side_effects_recall         0.488333
side_effects_f1             0.503333
medication_precision        0.790000
medication_recall           0.720000
medication_f1               0.743333
duplicated                  1.000000
successful                100.000000
dtype: float64

In [131]:
# Zero Shot instruction
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_zero_shot_instruction_s2a.pt", model=model, tokenizer=tokenizer, threshold = 0.5)


side_effects_precision     0.556667
side_effects_recall        0.500000
side_effects_f1            0.518333
medication_precision       0.820000
medication_recall          0.743333
medication_f1              0.768333
duplicated                 5.000000
successful                99.000000
dtype: float64

In [132]:
# Few Shot Vanilla
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_few_shot_vanilla_10_examples_s2a.pt", model=model, tokenizer=tokenizer, threshold = 0.5)

side_effects_precision     0.370000
side_effects_recall        0.331667
side_effects_f1            0.343333
medication_precision       0.550000
medication_recall          0.486667
medication_f1              0.506667
duplicated                12.000000
successful                99.000000
dtype: float64

In [133]:
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_few_shot_instruction_10_examples_s2a.pt", model=model, tokenizer=tokenizer, threshold = 0.5)

side_effects_precision     0.340000
side_effects_recall        0.313333
side_effects_f1            0.321667
medication_precision       0.470000
medication_recall          0.418333
medication_f1              0.435000
duplicated                20.000000
successful                96.000000
dtype: float64

In [134]:
file_paths7b = [paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_zero_shot_vanilla.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_zero_shot_instruction.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_few_shot_vanilla_10_examples.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_few_shot_instruction_10_examples.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_zero_shot_vanilla_rag.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_zero_shot_instruction_rag.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_few_shot_vanilla_10_examples_rag.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_few_shot_instruction_10_examples_rag.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_zero_shot_vanilla_s2a.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_zero_shot_instruction_s2a.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_few_shot_vanilla_10_examples_s2a.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_few_shot_instruction_10_examples_s2a.pt",]

In [135]:
results7b = summarize(file_paths7b, paths.RESULTS_PATH/"side-effects/labels.xlsx", model=model, tokenizer=tokenizer, threshold = 0.5)
results7b = results7b.round(2)
# results7b.to_csv(paths.RESULTS_PATH/"side-effects/summary7b.csv", index=False)

In [136]:
results7b

Unnamed: 0,side_effects_precision,side_effects_recall,side_effects_f1,medication_precision,medication_recall,medication_f1,duplicated,successful,approach,strategy
0,0.14,0.2,0.15,0.23,0.3,0.24,35.0,65.0,Base,Zero-Shot Base
1,0.13,0.17,0.14,0.2,0.25,0.21,5.0,98.0,Base,Zero-Shot Instruction
2,0.33,0.33,0.32,0.4,0.38,0.38,18.0,98.0,Base,Few-Shot Base
3,0.34,0.33,0.33,0.39,0.37,0.37,22.0,93.0,Base,Few-Shot Instruction
4,0.28,0.3,0.28,0.42,0.43,0.41,18.0,100.0,S2A-1,Zero-Shot Base
5,0.26,0.27,0.26,0.41,0.41,0.4,19.0,99.0,S2A-1,Zero-Shot Instruction
6,0.3,0.3,0.3,0.38,0.38,0.38,24.0,100.0,S2A-1,Few-Shot Base
7,0.31,0.32,0.31,0.4,0.4,0.39,29.0,97.0,S2A-1,Few-Shot Instruction
8,0.53,0.49,0.5,0.79,0.72,0.74,1.0,100.0,S2A-2,Zero-Shot Base
9,0.56,0.5,0.52,0.82,0.74,0.77,5.0,99.0,S2A-2,Zero-Shot Instruction


In [None]:
thesis7b = latex_metrics(results7b)
thesis7b_precision = thesis7b[["strategy"] + [col for col in thesis7b.columns if "precision" in col]]
thesis7b_recall = thesis7b[["strategy"] + [col for col in thesis7b.columns if "recall" in col]]
thesis7b_f1 = thesis7b[["strategy"] + [col for col in thesis7b.columns if "f1" in col]]

# thesis7b_precision.to_csv(paths.RESULTS_PATH/"side-effects/thesis7b_precision.csv", index=False)
# thesis7b_recall.to_csv(paths.RESULTS_PATH/"side-effects/thesis7b_recall.csv", index=False)
# thesis7b_f1.to_csv(paths.RESULTS_PATH/"side-effects/thesis7b_f1.csv", index=False)

# 13B Lora

In [137]:
file_paths13b_lora = [
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b-LoRa-merged_4bit_few_shot_instruction_10_examples.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b-LoRa-merged_4bit_zero_shot_vanilla_rag.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b-LoRa-merged_4bit_zero_shot_instruction_rag.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b-LoRa-merged_4bit_few_shot_vanilla_10_examples_rag.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b-LoRa-merged_4bit_few_shot_instruction_10_examples_rag.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b-LoRa-merged_4bit_zero_shot_vanilla_s2a.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b-LoRa-merged_4bit_zero_shot_instruction_s2a.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b-LoRa-merged_4bit_few_shot_vanilla_10_examples_s2a.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b-LoRa-merged_4bit_few_shot_instruction_10_examples_s2a.pt",]

In [138]:
results13b_lora = summarize(file_paths13b_lora, paths.RESULTS_PATH/"side-effects/labels.xlsx", model=model, tokenizer=tokenizer, threshold = 0.5)

In [139]:
results13b_lora = results13b_lora.round(2)
results13b_lora

Unnamed: 0,side_effects_precision,side_effects_recall,side_effects_f1,medication_precision,medication_recall,medication_f1,duplicated,successful,approach,strategy
0,0.14,0.34,0.18,0.2,0.54,0.26,79.0,97.0,Base,Few-Shot Instruction
1,0.25,0.34,0.28,0.31,0.44,0.35,63.0,56.0,S2A-1,Zero-Shot Base
2,0.28,0.52,0.34,0.38,0.73,0.47,20.0,100.0,S2A-1,Zero-Shot Instruction
3,0.34,0.36,0.34,0.43,0.45,0.43,46.0,77.0,S2A-1,Few-Shot Base
4,0.32,0.48,0.36,0.42,0.74,0.5,39.0,98.0,S2A-1,Few-Shot Instruction
5,0.42,0.44,0.42,0.76,0.82,0.77,13.0,95.0,S2A-2,Zero-Shot Base
6,0.49,0.52,0.49,0.82,0.88,0.83,14.0,97.0,S2A-2,Zero-Shot Instruction
7,0.66,0.61,0.63,0.84,0.77,0.79,24.0,96.0,S2A-2,Few-Shot Base
8,0.61,0.64,0.61,0.82,0.85,0.81,37.0,98.0,S2A-2,Few-Shot Instruction
