In [862]:
import sys
import os
sys.path.append(os.getcwd()+"/../..")
from src import paths

from src.utils import get_default_pydantic_model, SideEffect, SideEffectList

import pandas as pd
import torch

import json

import numpy as np

from bert_score import BERTScorer

In [863]:
with open(paths.DATA_PATH_PREPROCESSED/"side-effects/examples.json", "r") as f:
        examples = json.load(f)

# Load examples to check if model was confused
examples = [json.loads(e["labels"]) for e in examples]
medication_side_effects = []
for item in examples:
    for med_effect_dict in item:
        medication_side_effects.append((med_effect_dict['medication'].lower(), med_effect_dict['side_effect'].lower()))

medication_side_effects = list(set(medication_side_effects) - set(("unknown", "unknown")))

default_model = get_default_pydantic_model("side_effects")

scorer = BERTScorer(model_type=os.path.join(paths.MODEL_PATH, "medbert-512"), num_layers=4)

Some weights of BertModel were not initialized from the model checkpoint at /mnt/c/Users/marc_/OneDrive/ETH/MSC_Thesis/inf-extr/resources/models/medbert-512 and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [864]:
def preprocess_group(grouped_df: pd.DataFrame)->pd.DataFrame:
    """
    Preprocess a group of rows (grouped by obs index) from the dataframe. The preprocessing consists of:
    - Fusing all the side effects per medication with "," (so only one row per medication is left)
    - Replacing the medication name and side effects with "unknown" if the group has the same medication and side effects as example, which is a sign that the model was confused.

    Args:
        grouped_df (pd.DataFrame): A group of rows from the dataframe.

    Returns:
        pd.DataFrame: A dataframe with the group preprocessed.
    """

    df = []
    duplicated = False

    # If the group has the same medication and side effects as example just put "unknown" for both, as model was confused
    duplicate_count = 0
    for med, side_effect in medication_side_effects:
         for index, row in grouped_df.iterrows():
            if row["medication"].lower() == med and row["side_effect"].lower() == side_effect:
                duplicate_count += 1
    # if duplicate_count >= 3:
    #     # print("Changed at index: ", grouped_df["index"].iloc[0])
    #     grouped_df = pd.DataFrame({"medication": "unknown", "side_effect": "unknown", "text": grouped_df["text"].iloc[0], "original_text": grouped_df["original_text"].iloc[0], "index": grouped_df["index"].iloc[0]}, index=[0])
    if duplicate_count >=1:
        duplicated = True
    
    # group by medication name
    for med, group in grouped_df.groupby("medication"):
       # Fuse all the side effects with ","
        side_effects = group["side_effect"].str.cat(sep=",")
        text = group["text"].iloc[0]
        original_text = group["original_text"].iloc[0]
        index = group["index"].iloc[0]
        # Create a new row
        new_row = {"medication": med, "side_effect": side_effects, "text": text, "original_text": original_text, "index": index, "duplicated": duplicated, "successful": group["successful"].iloc[0]}
        # Append the new row to the dataframe
        df.append(new_row)

    return pd.DataFrame(df)

def prepare_results(path: str)->pd.DataFrame:
    """ 
    Prepare the results from the model to be compared with the labels. The preprocessing consists of:
    - Fixing the model answers, replacing them with the default model answer if they are invalid (in regards to JSON structure).
    - Grouping the results by index and preprocessing the group. (See preprocess_group() function for more details.)
    - Lowercasing and converting everything to string.

    Args:
        path (str): Path to the results file.
        
    Returns:
        pd.DataFrame: A dataframe with the results preprocessed.
    """
    results = torch.load(path)
    df = pd.DataFrame(results)

    df["successful"] = True

    # Fix model_answers
    for idx, row in df.iterrows():
        try:
            # Does model answer have key "side_effects"?
            answer = json.loads(row["model_answers"])

            # Is answer a valid SideEffectList?
            SideEffectList(**answer)

            # Is the element an empty list?
            if len(answer["side_effects"]) == 0:
                raise ValueError("Empty list")
            
            # Are all the elements valid SideEffect?
            for med in answer["side_effects"]:
                SideEffect(**med)
            
            df.at[idx, "model_answers"] = json.dumps(answer)

            
        except:
            df.at[idx, "model_answers"] = default_model.model_dump_json()
            # print(f"Error at index {idx}")
            df.at[idx, "successful"] = False
    
    dfs = []
    for idx, row in df.iterrows():
        row = row.to_dict()
        answer = json.loads(row["model_answers"])
        medications = answer["side_effects"]
        for med in medications:
            med["text"] = row.get("text", "")
            med["original_text"] = row.get("original_text", "")
            med["index"] = row.get("index", "")
            med["successful"] = row.get("successful", "")
            dfs.append((med))

    res = pd.DataFrame(dfs)

    res = res.sort_values(by="index")

    # Group by index and preprocess
    _dfs = []

    for index, group in res.groupby("index"):
        _dfs.append(preprocess_group(group))

    res = pd.concat(_dfs)

    # Convert everything to string and lowercase
    res = res.map(lambda x: str(x).lower())

    return res

def prepare_labels(path: str)->pd.DataFrame:
    """
    Prepare the labels to be compared with the results. The preprocessing consists of:
    - Lowercasing and converting everything to string.

    Args:
        path (str): Path to the labels file.

    Returns:

        pd.DataFrame: A dataframe with the labels preprocessed.
    """
    labels = pd.read_excel(path) 

    # Convert everything to string and lowercase
    labels = labels.map(lambda x: str(x).lower())

    return labels

def calculate_precision_recall_f1(ground_truth:list, predicted:list)->tuple[float, float, float]:
    """
    Calculate precision and recall from two lists, allowing for partial string matches.
    A string does not have to be exactly the same to be considered a match, it is enough if one string is a substring of the other.
    This function is used to evaluate the medication names.

    Args:
        ground_truth (list): List of ground truth medication names.
        predicted (list): List of predicted medication names.

    Returns:
        float: Precision
        float: Recall
        float: F1 Score
    """
    # Convert lists to sets for easier comparison
    ground_truth_set = set(ground_truth)
    predicted_set = set(predicted)

    # Initialize counters
    true_positives = 0
    false_positives = 0

    # Calculate True Positives and False Positives
    for pred_med in predicted_set:
        found_match = False
        for truth_med in ground_truth_set:
            if pred_med in truth_med or truth_med in pred_med:
                true_positives += 1
                found_match = True
                break
        if not found_match:
            false_positives += 1

    # Calculate False Negatives
    false_negatives = len(ground_truth_set) - true_positives

    # Calculate Precision
    if true_positives + false_positives == 0:
        precision = 0  # Handle division by zero
    else:
        precision = true_positives / (true_positives + false_positives)

    # Calculate Recall
    if true_positives + false_negatives == 0:
        recall = 0  # Handle division by zero
    else:
        recall = true_positives / (true_positives + false_negatives)

    # Calculate F1 Score
    if precision + recall == 0:
        f1 = 0
    else:
        f1 = 2 * (precision * recall) / (precision + recall)

    return precision, recall, f1

    # texts_gt = ""
    # texts_pred = ""

    # # Reorder the lists with approximate matching
    # for pred in predicted:
    #     for gt in ground_truth:
    #         if pred in gt or gt in pred:
    #             texts_gt += gt + " "
    #             texts_pred += pred + " "
    #             ground_truth.remove(gt)
    #             predicted.remove(pred)
    #             break
    
    # # Leftover ground truth and predictions
    # texts_gt += " ".join(ground_truth)
    # texts_pred += " ".join(predicted)

    # # Calculate bert score
    # precision, recall, f1 = scorer.score([texts_pred], [texts_gt])

    # return precision.item(), recall.item(), f1.item()


def calculate_bert_score(ground_truth:list[dict], predicted:list[dict], scorer:BERTScorer)->tuple[float, float, float]:
    """ 
    Calculate bert score from two lists of dictionaries. This function is used to evaluate the side effects.

    Args:
        ground_truth (list): List of ground truth medication names and their side effect.
        predicted (list): List of predicted medication names and their side effect.
        scorer (BERTScorer): BERTScorer object to score the similarity between the side effects.
    """

    # Match the prediction to the ground truth with medication name
    precision = []
    recall = []
    f1 = []
    texts_gt = ""
    texts_pred = ""

    for pred_id, pred in enumerate(predicted):
        for gt_id, gt in enumerate(ground_truth):
            if pred["medication"] in gt["medication"] or gt["medication"] in pred["medication"]:
                texts_gt += gt["side_effect"] + " "
                texts_pred += pred["side_effect"] + " "
                ground_truth.pop(gt_id)
                predicted.pop(pred_id)
                break
    
    # Leftover ground truth and predictions
    texts_gt += " ".join([gt["side_effect"] for gt in ground_truth])
    texts_pred += " ".join([pred["side_effect"] for pred in predicted])
    
    # Calculate bert score
    precision, recall, f1 = scorer.score([texts_pred], [texts_gt])

    return precision.item(), recall.item(), f1.item()

    # for pred_id, pred in enumerate(predicted):
    #     for gt_id, gt in enumerate(ground_truth):
    #         if pred["medication"] in gt["medication"] or gt["medication"] in pred["medication"]:
    #             p, r, f = scorer.score([pred["side_effect"]], [gt["side_effect"]])
    #             precision.append(p)
    #             recall.append(r)
    #             f1.append(f)
    #             # Remove the ground truth and prediction from the list
    #             ground_truth.pop(gt_id)
    #             predicted.pop(pred_id)
    
    # # Append 0 to the bert score for the remaining ground truth and predictions
    # precision.extend([torch.zeros(1)]*(len(ground_truth)+len(predicted)))
    # recall.extend([torch.zeros(1)]*(len(ground_truth)+len(predicted)))
    # f1.extend([torch.zeros(1)]*(len(ground_truth)+len(predicted)))

    # precision = torch.cat(precision, dim=0)
    # recall = torch.cat(recall, dim=0)
    # f1 = torch.cat(f1, dim=0)


    # return torch.mean(precision).item(), torch.mean(recall).item(), torch.mean(f1).item()

def create_metrics_df(labels_path:str, results_path:str, scorer:BERTScorer)->pd.DataFrame:
    """
    Create a dataframe with the metrics for each index. The metrics are:
    - Precision, Recall and F1 score for medication names
    - Precision, Recall and F1 score for side effects

    Args:
        labels_path (str): Path to the labels file.
        results_path (str): Path to the results file.
        scorer (BERTScorer): BERTScorer object to score the similarity between the side effects.

    Returns:
        pd.DataFrame: A dataframe with the metrics for each index.
    """
    
    labels = prepare_labels(labels_path)
    res = prepare_results(results_path)

    dfs = []
    for index, group in res.groupby("index"):
        text = group["text"].iloc[0]
        original_text = group["original_text"].iloc[0]
        index = group["index"].iloc[0]
        preds_med = group["medication"].to_list()
        preds_se = group["side_effect"].to_list()
        labels_med = labels[labels["index"] == index]["medication"].to_list()
        labels_se = labels[labels["index"] == index]["side_effect"].to_list()

        # Side effects
        ground_truth = labels[labels["index"] == index].to_dict(orient="records")
        predicted = group.to_dict(orient="records")
        side_effects_precision, side_effects_recall, side_effects_f1 = calculate_bert_score(ground_truth, predicted, scorer)

        # Medication
        ground_truth = labels[labels["index"] == index]["medication"].to_list()
        predicted = group["medication"].to_list()
        medication_precision, medication_recall, medication_f1 = calculate_precision_recall_f1(ground_truth, predicted)

        # Check if duplicated is the same for all in group
        assert len(group["duplicated"].unique()) == 1
        assert len(group["successful"].unique()) == 1

        if group["duplicated"].iloc[0] == "false":
            duplicated = False
        else:
            duplicated = True
        
        if group["successful"].iloc[0] == "false":
            successful = False
        else:
            successful = True

        dfs.append({"text": text, 
                    "original_text": original_text, 
                    "index": index, 
                    "preds_med": preds_med,
                    "labels_med": labels_med,
                    "preds_se": preds_se,
                    "labels_se": labels_se,
                    "duplicated": duplicated,
                    "successful": successful,
                    "side_effects_precision": side_effects_precision, "side_effects_recall": side_effects_recall, "side_effects_f1": side_effects_f1, "medication_precision": medication_precision, "medication_recall": medication_recall, "medication_f1": medication_f1})

    return pd.DataFrame(dfs)

def create_agg_metrics_df(labels_path:str, results_path:str, scorer:BERTScorer)->pd.DataFrame:
    """
    Create a dataframe with the aggregated metrics for the whole dataset. The metrics are:
    - Precision, Recall and F1 score for medication names
    - Precision, Recall and F1 score for side effects

    Args:
        labels_path (str): Path to the labels file.
        results_path (str): Path to the results file.
        scorer (BERTScorer): BERTScorer object to score the similarity between the side effects.

    Returns:
        pd.DataFrame: A dataframe with the aggregated metrics for the whole dataset.
    """
    metrics = create_metrics_df(labels_path, results_path, scorer)

    # Calculate the average of the metrics
    agg_metrics = metrics.agg({"side_effects_precision": "mean", "side_effects_recall": "mean", "side_effects_f1": "mean", "medication_precision": "mean", "medication_recall": "mean", "medication_f1": "mean", "duplicated": "sum", "successful": "sum"})

    return agg_metrics

In [865]:
def summarize(paths:list[str], labels:str, scorer:BERTScorer)->pd.DataFrame:
    """
    Summarize the results from the model. The summary consists of:
    - Precision, Recall and F1 score for medication names
    - Precision, Recall and F1 score for side effects
    - Aggregated metrics for the whole dataset

    Args:
        paths (list): List of paths to the results files.
        labels (str): Path to the labels file.
        scorer (BERTScorer): BERTScorer object to score the similarity between the side effects.

    Returns:
        pd.DataFrame: A dataframe with the summary of the results.
    """
    dfs = []
    for path in paths:
        metrics = create_agg_metrics_df(labels, path, scorer)
        path = str(path)
        if path.endswith("rag.pt"):
            metrics["approach"] = "S2A-1"
        elif path.endswith("s2a.pt"):
            metrics["approach"] = "S2A-2"
        else:
            metrics["approach"] = "Base"
        
        if "few_shot_vanilla" in path:
            metrics["strategy"] = "Few-Shot Base"
        elif "few_shot_instruction" in path:
            metrics["strategy"] = "Few-Shot Instruction"
        elif "zero_shot_vanilla" in path:
            metrics["strategy"] = "Zero-Shot Base"
        elif "zero_shot_instruction" in path:
            metrics["strategy"] = "Zero-Shot Instruction"
        else:
            ValueError("Unknown strategy")
        dfs.append(metrics)
    return pd.DataFrame(dfs)

In [866]:
file_paths = [paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_zero_shot_vanilla.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_zero_shot_instruction.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_few_shot_vanilla_10_examples.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_few_shot_instruction_10_examples.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_zero_shot_vanilla_rag.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_zero_shot_instruction_rag.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_few_shot_vanilla_10_examples_rag.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_few_shot_instruction_10_examples_rag.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_zero_shot_vanilla_s2a.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_zero_shot_instruction_s2a.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_few_shot_vanilla_10_examples_s2a.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_few_shot_instruction_10_examples_s2a.pt",]

# 13B

## Base

In [867]:
# Zero Shot Vanilla
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_zero_shot_vanilla.pt", scorer)


side_effects_precision     0.476860
side_effects_recall        0.581228
side_effects_f1            0.513299
medication_precision       0.310531
medication_recall          0.750000
medication_f1              0.415100
duplicated                11.000000
successful                96.000000
dtype: float64

In [868]:
# Zero Shot instruction
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_zero_shot_instruction.pt", scorer)


side_effects_precision     0.530966
side_effects_recall        0.637905
side_effects_f1            0.568499
medication_precision       0.427310
medication_recall          0.881667
medication_f1              0.538481
duplicated                13.000000
successful                94.000000
dtype: float64

In [869]:
# Few Shot Vanilla
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_few_shot_vanilla_10_examples.pt", scorer)

side_effects_precision     0.612221
side_effects_recall        0.739852
side_effects_f1            0.658761
medication_precision       0.435992
medication_recall          0.725000
medication_f1              0.509984
duplicated                41.000000
successful                95.000000
dtype: float64

In [870]:
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_few_shot_instruction_10_examples.pt", scorer)

side_effects_precision     0.579748
side_effects_recall        0.727668
side_effects_f1            0.632492
medication_precision       0.403071
medication_recall          0.743333
medication_f1              0.485891
duplicated                41.000000
successful                95.000000
dtype: float64

## RAG

In [871]:
# Zero Shot Vanilla
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_zero_shot_vanilla_rag.pt", scorer)


side_effects_precision     0.552860
side_effects_recall        0.621523
side_effects_f1            0.574738
medication_precision       0.401992
medication_recall          0.778333
medication_f1              0.490021
duplicated                24.000000
successful                99.000000
dtype: float64

In [872]:
# Zero Shot instruction
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_zero_shot_instruction_rag.pt", scorer)


side_effects_precision     0.560098
side_effects_recall        0.621091
side_effects_f1            0.579904
medication_precision       0.422183
medication_recall          0.768333
medication_f1              0.506132
duplicated                25.000000
successful                99.000000
dtype: float64

In [873]:
# Few Shot Vanilla
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_few_shot_vanilla_10_examples_rag.pt", scorer)

side_effects_precision     0.621593
side_effects_recall        0.696528
side_effects_f1            0.643909
medication_precision       0.429722
medication_recall          0.680000
medication_f1              0.489893
duplicated                50.000000
successful                96.000000
dtype: float64

In [874]:
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_few_shot_instruction_10_examples_rag.pt", scorer)

side_effects_precision     0.604449
side_effects_recall        0.725081
side_effects_f1            0.646291
medication_precision       0.436619
medication_recall          0.791667
medication_f1              0.526373
duplicated                46.000000
successful                97.000000
dtype: float64

## S2A

In [875]:
# Zero Shot Vanilla
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_zero_shot_vanilla_s2a.pt", scorer)


side_effects_precision      0.652418
side_effects_recall         0.647912
side_effects_f1             0.642934
medication_precision        0.806429
medication_recall           0.870000
medication_f1               0.820833
duplicated                 12.000000
successful                100.000000
dtype: float64

In [876]:
# Zero Shot instruction
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_zero_shot_instruction_s2a.pt", scorer)


side_effects_precision     0.741831
side_effects_recall        0.726890
side_effects_f1            0.728542
medication_precision       0.953095
medication_recall          0.990000
medication_f1              0.961500
duplicated                19.000000
successful                92.000000
dtype: float64

In [877]:
# Few Shot Vanilla
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_few_shot_vanilla_10_examples_s2a.pt", scorer)

side_effects_precision     0.768935
side_effects_recall        0.785101
side_effects_f1            0.768471
medication_precision       0.882857
medication_recall          0.900000
medication_f1              0.874667
duplicated                32.000000
successful                97.000000
dtype: float64

In [878]:
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-13b_4bit_few_shot_instruction_10_examples_s2a.pt", scorer)

side_effects_precision     0.757953
side_effects_recall        0.812161
side_effects_f1            0.775305
medication_precision       0.851429
medication_recall          0.900000
medication_f1              0.851056
duplicated                29.000000
successful                98.000000
dtype: float64

## Summary

In [879]:
def latex_metrics(metrics:pd.DataFrame)->str:
    """
    Convert the metrics dataframe to a latex table.
    """
    dfs = []
    for strat, group in metrics.groupby("strategy"):
        cols = [col for col in metrics.columns if col not in ["strategy", "approach"]]
        df_base = group[group["approach"] == "Base"][cols].reset_index(drop=True)
        df_base.columns = [f"{col}_base" for col in cols]
        df_rag = group[group["approach"] == "S2A-1"][cols].reset_index(drop=True)
        df_rag.columns = [f"{col}_rag" for col in cols]
        df_s2a = group[group["approach"] == "S2A-2"][cols].reset_index(drop=True)
        df_s2a.columns = [f"{col}_s2a" for col in cols]
        df = pd.concat([df_base, df_rag, df_s2a], axis=1)
        df["strategy"] = strat
        dfs.append(df)

    return pd.concat(dfs, axis=0)

In [880]:
results = summarize(file_paths, paths.RESULTS_PATH/"side-effects/labels.xlsx", scorer)
results = results.round(2)
results.to_csv(paths.RESULTS_PATH/"side-effects/summary13b.csv", index=False)

In [881]:
thesis13b = latex_metrics(results)
thesis13b_precision = thesis13b[["strategy"] + [col for col in thesis13b.columns if "precision" in col]]
thesis13b_recall = thesis13b[["strategy"] + [col for col in thesis13b.columns if "recall" in col]]
thesis13b_f1 = thesis13b[["strategy"] + [col for col in thesis13b.columns if "f1" in col]]

thesis13b_precision.to_csv(paths.RESULTS_PATH/"side-effects/thesis13b_precision.csv", index=False)
thesis13b_recall.to_csv(paths.RESULTS_PATH/"side-effects/thesis13b_recall.csv", index=False)
thesis13b_f1.to_csv(paths.RESULTS_PATH/"side-effects/thesis13b_f1.csv", index=False)

- Model seems to have difficulties if there is no medication mentioned with side effects. It repeats the examples. I could set a threshold of 3 or more repeated examples is indication that model is confused, so just put "unknown/unknown"
- also split medications that contain "/" into two lists
- concatenate the side effects for the same medications with ","

In [882]:
results

Unnamed: 0,side_effects_precision,side_effects_recall,side_effects_f1,medication_precision,medication_recall,medication_f1,duplicated,successful,approach,strategy
0,0.48,0.58,0.51,0.31,0.75,0.42,11.0,96.0,Base,Zero-Shot Base
1,0.53,0.64,0.57,0.43,0.88,0.54,13.0,94.0,Base,Zero-Shot Instruction
2,0.61,0.74,0.66,0.44,0.72,0.51,41.0,95.0,Base,Few-Shot Base
3,0.58,0.73,0.63,0.4,0.74,0.49,41.0,95.0,Base,Few-Shot Instruction
4,0.55,0.62,0.57,0.4,0.78,0.49,24.0,99.0,S2A-1,Zero-Shot Base
5,0.56,0.62,0.58,0.42,0.77,0.51,25.0,99.0,S2A-1,Zero-Shot Instruction
6,0.62,0.7,0.64,0.43,0.68,0.49,50.0,96.0,S2A-1,Few-Shot Base
7,0.6,0.73,0.65,0.44,0.79,0.53,46.0,97.0,S2A-1,Few-Shot Instruction
8,0.65,0.65,0.64,0.81,0.87,0.82,12.0,100.0,S2A-2,Zero-Shot Base
9,0.74,0.73,0.73,0.95,0.99,0.96,19.0,92.0,S2A-2,Zero-Shot Instruction


# 7B

## Base

In [883]:
# Zero Shot Vanilla
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_zero_shot_vanilla.pt", scorer)


side_effects_precision     0.483710
side_effects_recall        0.470272
side_effects_f1            0.469532
medication_precision       0.232186
medication_recall          0.301667
medication_f1              0.239222
duplicated                35.000000
successful                65.000000
dtype: float64

In [884]:
# Zero Shot instruction
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_zero_shot_instruction.pt", scorer)


side_effects_precision     0.439665
side_effects_recall        0.433439
side_effects_f1            0.429156
medication_precision       0.202595
medication_recall          0.248333
medication_f1              0.209222
duplicated                 5.000000
successful                98.000000
dtype: float64

In [885]:
# Few Shot Vanilla
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_few_shot_vanilla_10_examples.pt", scorer)

side_effects_precision     0.567207
side_effects_recall        0.562819
side_effects_f1            0.559396
medication_precision       0.398333
medication_recall          0.378333
medication_f1              0.380000
duplicated                18.000000
successful                98.000000
dtype: float64

In [886]:
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_few_shot_instruction_10_examples.pt", scorer)

side_effects_precision     0.577289
side_effects_recall        0.569135
side_effects_f1            0.566868
medication_precision       0.393333
medication_recall          0.366667
medication_f1              0.374444
duplicated                22.000000
successful                93.000000
dtype: float64

## RAG

In [887]:
# Zero Shot Vanilla
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_zero_shot_vanilla_rag.pt", scorer)


side_effects_precision      0.565181
side_effects_recall         0.514139
side_effects_f1             0.529431
medication_precision        0.420357
medication_recall           0.428333
medication_f1               0.413444
duplicated                 18.000000
successful                100.000000
dtype: float64

In [888]:
# Zero Shot instruction
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_zero_shot_instruction_rag.pt", scorer)


side_effects_precision     0.509794
side_effects_recall        0.470241
side_effects_f1            0.481800
medication_precision       0.412333
medication_recall          0.408333
medication_f1              0.402381
duplicated                19.000000
successful                99.000000
dtype: float64

In [889]:
# Few Shot Vanilla
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_few_shot_vanilla_10_examples_rag.pt", scorer)

side_effects_precision      0.577987
side_effects_recall         0.577308
side_effects_f1             0.571423
medication_precision        0.385000
medication_recall           0.378333
medication_f1               0.378333
duplicated                 24.000000
successful                100.000000
dtype: float64

In [890]:
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_few_shot_instruction_10_examples_rag.pt", scorer)

side_effects_precision     0.603300
side_effects_recall        0.605233
side_effects_f1            0.598004
medication_precision       0.397262
medication_recall          0.403333
medication_f1              0.389833
duplicated                29.000000
successful                97.000000
dtype: float64

## S2A

In [891]:
# Zero Shot Vanilla
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_zero_shot_vanilla_s2a.pt", scorer)


side_effects_precision      0.577548
side_effects_recall         0.547854
side_effects_f1             0.553986
medication_precision        0.790000
medication_recall           0.720000
medication_f1               0.743333
duplicated                  1.000000
successful                100.000000
dtype: float64

In [892]:
# Zero Shot instruction
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_zero_shot_instruction_s2a.pt", scorer)


side_effects_precision     0.596390
side_effects_recall        0.545501
side_effects_f1            0.559707
medication_precision       0.820000
medication_recall          0.743333
medication_f1              0.768333
duplicated                 5.000000
successful                99.000000
dtype: float64

In [893]:
# Few Shot Vanilla
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_few_shot_vanilla_10_examples_s2a.pt", scorer)

side_effects_precision     0.584967
side_effects_recall        0.582505
side_effects_f1            0.579639
medication_precision       0.550000
medication_recall          0.486667
medication_f1              0.506667
duplicated                12.000000
successful                99.000000
dtype: float64

In [894]:
create_agg_metrics_df(paths.RESULTS_PATH/"side-effects/labels.xlsx", paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_few_shot_instruction_10_examples_s2a.pt", scorer)

side_effects_precision     0.607340
side_effects_recall        0.592319
side_effects_f1            0.594689
medication_precision       0.470000
medication_recall          0.418333
medication_f1              0.435000
duplicated                20.000000
successful                96.000000
dtype: float64

In [895]:
file_paths7b = [paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_zero_shot_vanilla.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_zero_shot_instruction.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_few_shot_vanilla_10_examples.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_few_shot_instruction_10_examples.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_zero_shot_vanilla_rag.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_zero_shot_instruction_rag.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_few_shot_vanilla_10_examples_rag.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_few_shot_instruction_10_examples_rag.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_zero_shot_vanilla_s2a.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_zero_shot_instruction_s2a.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_few_shot_vanilla_10_examples_s2a.pt",
              paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_few_shot_instruction_10_examples_s2a.pt",]

In [896]:
results7b = summarize(file_paths7b, paths.RESULTS_PATH/"side-effects/labels.xlsx", scorer)
results7b = results7b.round(2)
results7b.to_csv(paths.RESULTS_PATH/"side-effects/summary7b.csv", index=False)

In [897]:
thesis7b = latex_metrics(results7b)
thesis7b_precision = thesis7b[["strategy"] + [col for col in thesis7b.columns if "precision" in col]]
thesis7b_recall = thesis7b[["strategy"] + [col for col in thesis7b.columns if "recall" in col]]
thesis7b_f1 = thesis7b[["strategy"] + [col for col in thesis7b.columns if "f1" in col]]

thesis7b_precision.to_csv(paths.RESULTS_PATH/"side-effects/thesis7b_precision.csv", index=False)
thesis7b_recall.to_csv(paths.RESULTS_PATH/"side-effects/thesis7b_recall.csv", index=False)
thesis7b_f1.to_csv(paths.RESULTS_PATH/"side-effects/thesis7b_f1.csv", index=False)

In [898]:
results7b

Unnamed: 0,side_effects_precision,side_effects_recall,side_effects_f1,medication_precision,medication_recall,medication_f1,duplicated,successful,approach,strategy
0,0.48,0.47,0.47,0.23,0.3,0.24,35.0,65.0,Base,Zero-Shot Base
1,0.44,0.43,0.43,0.2,0.25,0.21,5.0,98.0,Base,Zero-Shot Instruction
2,0.57,0.56,0.56,0.4,0.38,0.38,18.0,98.0,Base,Few-Shot Base
3,0.58,0.57,0.57,0.39,0.37,0.37,22.0,93.0,Base,Few-Shot Instruction
4,0.57,0.51,0.53,0.42,0.43,0.41,18.0,100.0,S2A-1,Zero-Shot Base
5,0.51,0.47,0.48,0.41,0.41,0.4,19.0,99.0,S2A-1,Zero-Shot Instruction
6,0.58,0.58,0.57,0.38,0.38,0.38,24.0,100.0,S2A-1,Few-Shot Base
7,0.6,0.61,0.6,0.4,0.4,0.39,29.0,97.0,S2A-1,Few-Shot Instruction
8,0.58,0.55,0.55,0.79,0.72,0.74,1.0,100.0,S2A-2,Zero-Shot Base
9,0.6,0.55,0.56,0.82,0.74,0.77,5.0,99.0,S2A-2,Zero-Shot Instruction


In [899]:
torch.load(paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_few_shot_instruction_10_examples_s2a.pt")["model_answers"]

['{"side_effects":[{"medication":"Phenytoinaufsättigung","side_effect":"Intoxikation"}]}',
 '{"side_effects":[{"medication":"Lymphopenie","side_effect":"bekannte NW"}]}',
 '{"side_effects":[{"medication":"kutane Verhärtungen","side_effect":"grippeähnliche Nebenwirkungen"}]}',
 '{"side_effects":[{"medication":"Prednison","side_effect":"Diabetes mellitus (Prednison), Osteoporose (Prednison), Neutropenie (Imurek), Niereninsuffizienz (z."}]}',
 '{"side_effects":[{"medication":"grippeähnliche Nebenwirkungen","side_effect":"strong"}]}',
 '{"side_effects":[{"medication":"grippeähnliche Nebenwirkungen","side_effect":"unknown"}]}',
 '{"side_effects":[{"medication":"kutanen Verhärtungen","side_effect":"grippeähnliche Nebenwirkungen"}]}',
 '{"side_effects":[{"medication":"Valproat/Lamotrigin","side_effect":"akzentuiert seit 10.07.2009, am ehesten Nebenwirkung"}]}',
 '{"side_effects":[{"medication":"grippeähnliche Nebenwirkungen","side_effect":"persistierend"}]}',
 '{"side_effects":[{"medication":

In [900]:
prepare_results(paths.RESULTS_PATH/"side-effects/side-effects_outlines_Llama2-MedTuned-7b_4bit_zero_shot_vanilla.pt")["successful"].value_counts()

successful
true     124
false     35
Name: count, dtype: int64