In [224]:
import sys
import os
sys.path.append(os.getcwd()+"/../..")
from src import paths

import torch
import json
import pandas as pd

from src.utils import get_default_pydantic_model

In [225]:
# Do not truncate long cell values when DataFrames are displayed.
pd.set_option("display.max_colwidth", None)

# NOTE: torch.load may unpickle arbitrary objects; only load result files you trust.
results_file = paths.RESULTS_PATH/"medication/medication_outlines_Llama2-MedTuned-7b_4bit__shot_instruction_examples_10.pt"
results = torch.load(results_file)
df = pd.DataFrame(results)

In [226]:
# Broken JSON is repaired by cutting after the last "}" and appending "]}".
# default_model is the fallback used by fix_json when an answer contains no "}" at all.
default_model = get_default_pydantic_model("medication")
def fix_json(json_str):
    """Repair a truncated JSON answer by closing it after its last "}".

    Everything after the last "}" in ``json_str`` is discarded and "]}" is
    appended (presumably closing the medications list and the enclosing
    object of an answer that was cut off mid-list).

    Parameters:
        json_str (str): Raw, possibly truncated, model answer.

    Returns:
        str: The repaired string, or the JSON dump of ``default_model`` when
        ``json_str`` contains no "}" at all.
    """
    cut = json_str.rfind("}")
    if cut == -1:
        # No "}" anywhere: nothing to salvage, fall back to the default answer.
        return default_model.model_dump_json()
    return json_str[:cut + 1] + "]}"

In [227]:
# Repair model_answers only for the rows where "successful" is False.
# Series.map is used on the single column that is needed; unlike a row-wise
# DataFrame.apply it stays a Series even when no row is flagged.
failed = ~df["successful"]
_df_fixed = df.loc[failed, "model_answers"].map(fix_json)
df_fixed = df.copy()  # keep the raw answers in df untouched
df_fixed.loc[failed, "model_answers"] = _df_fixed

In [228]:
# Flatten the (repaired) JSON answers into one record per predicted medication.
# Each record carries the source text and the positional id of that text.
dfs = []
for idx, (answer, text) in enumerate(zip(df_fixed["model_answers"], df_fixed["text"])):
    try:
        answer = json.loads(answer)
        medications = answer["medications"]
        for med in medications:
            med["text"] = text
            med["id"] = idx
            dfs.append(med)
    except (ValueError, KeyError, TypeError) as exc:
        # ValueError covers json.JSONDecodeError; KeyError/TypeError cover answers
        # that parse but do not have the expected {"medications": [ {...}, ... ]} shape.
        # Anything else (e.g. KeyboardInterrupt) is deliberately not swallowed.
        print(f"Error at index {idx}: {exc!r}")
res = pd.DataFrame(dfs)

# If all the values of morning, noon, evening, night are 0, then set them all to -99
def check_and_replace(df, cols_to_check):
    """
    Replace the checked columns with -99 in every row where all of them are 0.

    The check is vectorised and uses a positional mask, so rows that share an
    index label are handled independently. For the usual four intake columns
    the result is the same as checking each row for four zeros.

    Parameters:
        df (pandas.DataFrame): Input DataFrame. It is modified in place.
        cols_to_check (list): List of column names to check.

    Returns:
        pandas.DataFrame: The same DataFrame object, with replacements applied.
    """
    cols = list(cols_to_check)
    # Nothing to replace in an empty frame (it may not even have the columns),
    # and an empty column list must not count as "all zero".
    if df.empty or not cols:
        return df
    all_zero = (df[cols] == 0).all(axis=1).to_numpy()
    df.loc[all_zero, cols] = -99
    return df

# An intake schedule of 0-0-0-0 is replaced by -99 in all four intake columns.
res = check_and_replace(res, ["morning", "noon", "evening", "night"])

# Normalise every cell to a lowercase string, then strip a trailing ".0", ".00", ...
# so that e.g. "1.0" and "1" compare equal.
expression = r"\.0+$"
res = res.map(lambda value: str(value).lower()).replace(expression, "", regex=True)

- False Positives (FP): Predicted but not in the ground truth. This covers two cases: an attribute (e.g. the dose) that is predicted wrongly for a correctly matched (TP) medication, and a predicted medication that is not in the ground truth at all (in which case its attributes, such as the dose, are not in the ground truth either and also count as FP).
- True Positives (TP): Predicted and in ground truth.
- False Negatives (FN): Not predicted but in the ground truth. This can only happen if the medication itself is not predicted, because whenever a medication is predicted, a value for every attribute (e.g. a unit) is always predicted with it.
- True Negatives (TN): Not predicted and not in ground truth (not relevant for this task).

Stuff it doesn't catch
- Spelling mistakes: defalgan instead of dafalgan. Example 73 for E/ml vs IE/ml. Example 79 for wrong schema (but I would also not know what to do)
- Sometimes didn't write full name or splits it up (Irfen Dolo; Excipial U Lipolotion)
- Often, when no intake schema was given, it just put 0 0 0 0 for the intake. It had a tendency to write 0 (instead of -99) for things it didn't find. Also, 1-0-0-0 is not really safe. The rest seems robust.
    - Could convert to -99 if all of them are 0, which doesn't make sense anyway. Check whether this gives better metrics.
- Couldn't catch split stuff in dose like 80/10, 800/160
- Texts 48 and 60 are good examples of what a rule-based dose-intake extraction would not catch
- Also example 63 for dosis (2 für beideseitig)
- Example 68 for something you had to google as well (no mg, so model puts Filmtabl)
- Example 77 for something that is very hard to catch for a model like this
- Example 78 for hallucination of medication (because Insurance sounds like medication)
- Example 87 for a hard case (with no medication implicit), could probably be alleviated with correct example

# Evaluation

In [229]:
# Load the ground-truth annotations.
labels = pd.read_excel(paths.RESULTS_PATH/"medication/labels.xlsx")

# Apply the same normalisation as for the predictions: lowercase strings
# with a trailing ".0", ".00", ... removed.
expression = r"\.0+$"
labels = labels.map(lambda value: str(value).lower()).replace(expression, "", regex=True)

In [230]:
# Choose only the relevant columns: the attributes that are scored per medication.
# "name" is additionally used to match a prediction to a ground-truth entry.
relevant_columns = ["name", "dose", "dose_unit", "morning", "noon", "evening", "night"]

In [231]:
def calculate_precision_recall(ground_truth, predicted, columns=None):
    """
    Compute per-attribute precision, recall and F1 for the medications of one text.

    Each prediction is matched to the first remaining ground-truth entry whose
    name contains the predicted name or is contained in it; a ground-truth
    entry can be matched at most once. For a matched pair, "name" counts as a
    true positive and every other predicted attribute counts as a true
    positive if it equals the ground-truth value, otherwise as a false
    positive. All attributes of an unmatched prediction are false positives;
    all attributes of an unmatched ground-truth entry are false negatives.

    Neither input list nor the dicts inside them are modified.

    Parameters:
        ground_truth (list[dict]): Ground-truth records; each needs a "name" key.
        predicted (list[dict]): Predicted records; each needs a "name" key.
        columns (list | None): Attributes to report metrics for. Defaults to
            the notebook-level ``relevant_columns``.

    Returns:
        tuple[dict, dict, dict]: precision, recall and F1 score, each keyed by
        attribute. An attribute with no counts at all gets 0.0.
    """
    if columns is None:
        columns = relevant_columns

    eps = 1e-10  # avoids division by zero when a denominator has no counts
    remaining_truth = list(ground_truth)  # consumed as entries get matched
    true_positives = {}
    false_positives = {}
    false_negatives = {}

    for pred in predicted:
        pred_name = pred["name"]
        # NOTE(review): substring matching means an empty predicted name
        # matches the first remaining ground-truth entry.
        match_index = next(
            (
                i
                for i, truth in enumerate(remaining_truth)
                if truth["name"] in pred_name or pred_name in truth["name"]
            ),
            None,
        )

        if match_index is None:
            # No matching medication: every predicted attribute is a false positive.
            for key in pred:
                false_positives[key] = false_positives.get(key, 0) + 1
            continue

        truth = remaining_truth.pop(match_index)
        true_positives["name"] = true_positives.get("name", 0) + 1
        for key, value in pred.items():
            if key == "name":
                continue  # already counted above; skipped instead of popped from pred
            counts = true_positives if value == truth[key] else false_positives
            counts[key] = counts.get(key, 0) + 1

    # Whatever is left in the ground truth was never predicted.
    for truth in remaining_truth:
        for key in truth:
            false_negatives[key] = false_negatives.get(key, 0) + 1

    precision = {}
    recall = {}
    f1_score = {}
    for key in columns:
        tp = true_positives.get(key, 0)
        fp = false_positives.get(key, 0)
        fn = false_negatives.get(key, 0)

        # Precision: TP / (TP + FP)
        precision[key] = tp / (tp + fp + eps)

        # Recall: TP / (TP + FN)
        recall[key] = tp / (tp + fn + eps)

        # F1 Score: 2 * (Precision * Recall) / (Precision + Recall)
        f1_score[key] = 2 * (precision[key] * recall[key]) / (precision[key] + recall[key] + eps)

    return precision, recall, f1_score

# Evaluate every labelled text and collect one row of metrics per text.
evaluated = []

# Labels and predictions must cover exactly the same ids. Comparing the sets
# (not just their sizes) reports a mismatch here instead of as an IndexError
# in the loop below.
label_ids = set(labels.id.unique())
predicted_ids = set(res.id.unique())
assert label_ids == predicted_ids, (
    f"ids only in labels: {sorted(label_ids - predicted_ids)}; "
    f"ids only in predictions: {sorted(predicted_ids - label_ids)}"
)

for idx in labels.id.unique():
    predicted_rows = res[res.id == idx]
    ground_truth = labels[labels.id == idx][relevant_columns].to_dict("records")
    predicted = predicted_rows[relevant_columns].to_dict("records")
    precision, recall, f1_score = calculate_precision_recall(ground_truth, predicted)
    precision_dict = {f"precision_{key}": value for key, value in precision.items()}
    recall_dict = {f"recall_{key}": value for key, value in recall.items()}
    f1_score_dict = {f"f1_score_{key}": value for key, value in f1_score.items()}
    merged = {**precision_dict, **recall_dict, **f1_score_dict}
    # All prediction rows of one id carry the same source text; take the first.
    merged["text"] = predicted_rows["text"].values[0]
    merged["id"] = idx
    evaluated.append(merged)

evaluated_df = pd.DataFrame(evaluated)

In [232]:
# Average every metric over all texts, then collapse the four intake columns
# (morning, noon, evening, night) into one "intake_metrics_*" value per metric.
intake_columns = ["morning", "noon", "evening", "night"]
metric_names = ["precision", "recall", "f1_score"]

agg_df = evaluated_df.drop(columns=["id", "text"]).mean()
for metric in metric_names:
    per_intake = [f"{metric}_{column}" for column in intake_columns]
    agg_df[f"intake_metrics_{metric}"] = agg_df[per_intake].mean()

# The individual intake metrics are no longer needed once they are averaged.
agg_df = agg_df.drop([f"{metric}_{column}" for metric in metric_names for column in intake_columns])


In [233]:
# Display the aggregated metrics (means over all evaluated texts).
agg_df

precision_name              0.911667
precision_dose              0.707667
precision_dose_unit         0.676667
recall_name                 0.842143
recall_dose                 0.651000
recall_dose_unit            0.635143
f1_score_name               0.856270
f1_score_dose               0.660603
f1_score_dose_unit          0.639603
intake_metrics_precision    0.592292
intake_metrics_recall       0.560083
intake_metrics_f1_score     0.561190
dtype: float64

precision_name              0.978333
precision_dose              0.868333
precision_dose_unit         0.908333
recall_name                 0.988238
recall_dose                 0.886905
recall_dose_unit            0.926905
f1_score_name               0.981120
f1_score_dose               0.874612
f1_score_dose_unit          0.914612
intake_metrics_precision    0.455167
intake_metrics_recall       0.473750
intake_metrics_f1_score     0.461512

precision_name              0.978333
precision_dose              0.868333
precision_dose_unit         0.908333
recall_name                 0.988238
recall_dose                 0.886905
recall_dose_unit            0.926905
f1_score_name               0.981120
f1_score_dose               0.874612
f1_score_dose_unit          0.914612
intake_metrics_precision    0.601417
intake_metrics_recall       0.622750
intake_metrics_f1_score     0.608456
dtype: float64