In [None]:
import sys
import os
sys.path.append(os.getcwd()+"/../..")
from src import paths

import torch
import json
import pandas as pd

from src.utils import get_default_pydantic_model

import re

In [None]:
# Do not truncate long cell contents (e.g. full report texts) when displaying DataFrames
pd.set_option('display.max_colwidth', None)

In [None]:
# Try to fix broken JSON by searching for last "}" then adding "]"
# Default pydantic model for the "medication" task; fix_json returns its JSON dump
# when an answer contains no "}" at all and therefore cannot be repaired.
default_model = get_default_pydantic_model("medication")
def fix_json(json_str):
    """
    Try to repair a truncated JSON answer.

    Everything after the last "}" is cut off and "]}" is appended, which
    closes a list that was left open inside the top-level object. If the
    string contains no "}" at all, the JSON dump of the module-level
    default_model is returned instead.

    Parameters:
        json_str (str): Possibly broken JSON string.

    Returns:
        str: The repaired JSON string, or the default model as JSON.
    """
    cut = json_str.rfind("}")
    if cut == -1:
        # No closing brace anywhere: nothing to salvage, use the default
        return default_model.model_dump_json()

    # Keep everything up to and including the last "}", then close list and object
    return json_str[:cut + 1] + "]}"
    
# If all the values of morning, noon, evening, night are 0, then set them all to -99
def check_and_replace(df, cols_to_check):
    """
    Replace the checked columns with -99 in every row where all of them are 0.

    The check covers however many columns are passed (not a hard-coded 4) and
    is applied by row position, so rows sharing a duplicated index label are
    handled independently. The DataFrame is modified in place and returned.

    Parameters:
        df (pandas.DataFrame): Input DataFrame.
        cols_to_check (list): List of column names to check.

    Returns:
        pandas.DataFrame: The same DataFrame with replacements applied.
    """
    cols = list(cols_to_check)

    # An empty frame may not even contain the columns; nothing to replace then
    if not cols or len(df) == 0:
        return df

    # Positional boolean mask of rows in which every checked column equals 0
    all_zero = df[cols].eq(0).all(axis=1).to_numpy()
    df.loc[all_zero, cols] = -99
    return df

def prepare_results(path: str)->pd.DataFrame:
    """
    Load saved model outputs and flatten them into one row per predicted medication.

    Answers flagged as unsuccessful are first passed through fix_json. Every
    answer is then parsed as JSON and each entry of its "medications" list
    becomes one row, extended by the source text and the position of the
    text as "id". Answers that cannot be parsed are reported and skipped.
    Finally all values are converted to lowercase strings and a trailing
    ".0" (any number of zeros) is stripped.

    Parameters:
        path (str): Path to the saved results file.

    Returns:
        pandas.DataFrame: One row per predicted medication, all values as strings.
    """
    # NOTE(review): torch.load unpickles the file - only load result files from trusted sources.
    results = torch.load(path)
    df = pd.DataFrame(results)

    # Fix model_answers wherever successful is False
    failed = ~df["successful"].astype(bool)
    df_fixed = df.copy()
    # Map the column directly; a row-wise apply on an empty selection returns a DataFrame, not a Series
    df_fixed.loc[failed, "model_answers"] = df.loc[failed, "model_answers"].map(fix_json)

    rows = []
    for idx, (answer, text) in enumerate(zip(df_fixed["model_answers"], df_fixed["text"])):
        try:
            medications = json.loads(answer)["medications"]
            for med in medications:
                med["text"] = text
                med["id"] = idx
                rows.append(med)
        except (ValueError, KeyError, TypeError):
            # Invalid JSON (JSONDecodeError is a ValueError), missing "medications" key,
            # or an unexpected structure. Keep going, but do not swallow unrelated
            # exceptions such as KeyboardInterrupt.
            print(f"Error at index {idx}")
    res = pd.DataFrame(rows)

    # If all the values of morning, noon, evening, night are 0, then set them all to -99
    res = check_and_replace(res, ["morning", "noon", "evening", "night"])

    # Convert everything to string and lowercase
    res = res.map(lambda x: str(x).lower())

    # Remove .0+ from every string
    expression = r"\.0+$"
    res = res.replace(expression, "", regex=True)

    return res

- False Positives (FP): Predicted but not in the ground truth. This covers two cases: a dose that is predicted wrongly for a correctly matched (TP) medication, and a predicted medication that is not in the ground truth at all (in which case its predicted dose is not in the ground truth either).
- True Positives (TP): Predicted and in ground truth.
- False Negatives (FN): Not predicted but in the ground truth. This can only happen when the medication itself is not predicted, because whenever a medication is predicted, a unit is always predicted along with it.
- True Negatives (TN): Not predicted and not in ground truth (not relevant for this task).

Stuff it doesn't catch
- Spelling mistakes: defalgan instead of dafalgan. Example 73 for E/ml vs IE/ml. Example 79 for wrong schema (but I would also not know what to do)
- Sometimes didn't write full name or splits it up (Irfen Dolo; Excipial U Lipolotion)
- A lot of times if no schema was there it just put 0 0 0 0 for the intake. It had a tendency to write 0 (instead of -99) for stuff it didn't find. Also 1-0-0-0 not really safe. The rest seem robust.
    - Could convert to -99 if all of them are 0 which doesn't make sense anyways. Check if better metrics.
- Couldn't catch split stuff in dose like 80/10, 800/160
- Texts 48 and 60 are good examples of what rule-based dose-intake extraction would not catch
- Also example 63 for the dose (2 for "beidseitig", i.e. on both sides)
- Example 68 for something you had to google as well (no mg, so model puts Filmtabl)
- Example 77 for something that is very hard to catch for a model like this
- Example 78 for hallucination of medication (because Insurance sounds like medication)
- Example 87 for a hard case (with no medication implicit), could probably be alleviated with correct example

# Evaluation

In [None]:
def prepare_labels(path: str)->pd.DataFrame:
    """
    Read the labels from an Excel file and normalise every cell for comparison.

    Each value is converted to a lowercase string and a trailing ".0"
    (with any number of zeros) is removed.

    Parameters:
        path (str): Path to the Excel file with the labels.

    Returns:
        pandas.DataFrame: The labels with all values as normalised strings.
    """
    raw = pd.read_excel(path)

    # Everything as lowercase string
    lowered = raw.map(lambda cell: str(cell).lower())

    # Strip a trailing .0+ from every string
    return lowered.replace(r"\.0+$", "", regex=True)

def calculate_precision_recall(ground_truth, predicted, relevant_columns=None):
    """
    Compute per-key precision, recall and F1 score for the medications of one text.

    Each predicted medication is matched to the first remaining ground-truth
    medication whose name contains, or is contained in, the predicted name.
    A match is a true positive for "name"; every other key of the prediction
    is a true positive if its value equals the ground-truth value and a false
    positive otherwise. Unmatched predictions are false positives for all of
    their keys, and ground-truth medications left unmatched are false
    negatives for all of their keys. Neither input is modified.

    Parameters:
        ground_truth (list): List of dicts, one per labelled medication; each needs a "name" key.
        predicted (list): List of dicts, one per predicted medication; each needs a "name" key.
        relevant_columns (list): Keys to report. If omitted, a module-level
            `relevant_columns` is used when defined (legacy behaviour),
            otherwise the keys found in the records.

    Returns:
        tuple: (precision, recall, f1_score), each a dict mapping key to score.
    """
    if relevant_columns is None:
        # Backward compatibility: older callers relied on a module-level variable
        relevant_columns = globals().get("relevant_columns")
    if relevant_columns is None:
        # Last resort: every key seen in the records, in first-seen order
        relevant_columns = list(dict.fromkeys(
            key for record in (*ground_truth, *predicted) for key in record
        ))

    unmatched_truth = list(ground_truth)  # copy, entries are removed once matched
    true_positives = {}
    false_positives = {}
    false_negatives = {}

    for pred in predicted:
        pred_name = pred["name"]

        # First we match the medication to the corresponding ground truth
        match_pos = next(
            (
                pos for pos, truth in enumerate(unmatched_truth)
                if truth["name"] in pred_name or pred_name in truth["name"]
            ),
            None,
        )

        if match_pos is None:
            # No ground-truth medication matches: false positive for all keys
            for key in pred:
                false_positives[key] = false_positives.get(key, 0) + 1
            continue

        # Each ground-truth medication can be matched only once
        truth = unmatched_truth.pop(match_pos)
        true_positives["name"] = true_positives.get("name", 0) + 1

        # Then compare the remaining keys one by one
        for key, value in pred.items():
            if key == "name":
                continue
            counts = true_positives if value == truth[key] else false_positives
            counts[key] = counts.get(key, 0) + 1

    # Whatever is left in the ground truth was never predicted
    for truth in unmatched_truth:
        for key in truth:
            false_negatives[key] = false_negatives.get(key, 0) + 1

    precision = {}
    recall = {}
    f1_score = {}

    for key in relevant_columns:
        tp = true_positives.get(key, 0)

        # Precision: TP / (TP + FP); the epsilon avoids division by zero
        precision[key] = tp / (tp + false_positives.get(key, 0) + 1e-10)

        # Recall: TP / (TP + FN)
        recall[key] = tp / (tp + false_negatives.get(key, 0) + 1e-10)

        # F1 Score: 2 * (Precision * Recall) / (Precision + Recall)
        f1_score[key] = 2 * (precision[key] * recall[key]) / (precision[key] + recall[key] + 1e-10)

    return precision, recall, f1_score

def evaluate_df(ground_truth: pd.DataFrame, predicted: pd.DataFrame, relevant_columns: list)->pd.DataFrame:
    """
    Score the predictions against the ground truth, one text (id) at a time.

    For every id in the ground truth, the rows of both DataFrames with that id
    are reduced to the relevant columns and passed to
    calculate_precision_recall. The result has one row per id with the columns
    precision_<key>, recall_<key>, f1_score_<key>, text and id. The text is
    taken from the first predicted row with that id.

    Note that the sanity check only compares the number of unique ids in the
    two DataFrames, not the ids themselves.

    Parameters:
        ground_truth (pandas.DataFrame): DataFrame with the ground truth.
        predicted (pandas.DataFrame): DataFrame with the predicted values.
        relevant_columns (list): List of relevant columns to evaluate.

    Returns:
        pandas.DataFrame: DataFrame with the evaluated scores.
    """
    truth_ids = ground_truth.id.unique()
    assert len(truth_ids) == len(predicted.id.unique()), "The number of unique ids (texts) in the ground truth and predicted DataFrames do not match."

    score_rows = []
    for idx in truth_ids:
        truth_rows = ground_truth[ground_truth.id == idx]
        pred_rows = predicted[predicted.id == idx]

        scores = calculate_precision_recall(
            truth_rows[relevant_columns].to_dict("records"),
            pred_rows[relevant_columns].to_dict("records"),
        )

        # Flatten the three score dicts into prefixed columns, in this order
        row = {}
        for prefix, per_key in zip(("precision", "recall", "f1_score"), scores):
            for key, value in per_key.items():
                row[f"{prefix}_{key}"] = value

        row["text"] = pred_rows["text"].values[0]
        row["id"] = idx
        score_rows.append(row)

    return pd.DataFrame(score_rows)

def aggregate_scores(evaluated: pd.DataFrame, columns_to_drop: list = None)->pd.DataFrame:
    """
    Aggregates the metrics by averaging metrics over all unique texts. Also aggregates intake dosage metrics.

    For each of precision, recall and f1_score, the morning/noon/evening/night
    averages are combined into a single "<metric>_intake" entry (their mean)
    and the four individual entries are removed. The input is not modified.

    Parameters:
        evaluated (pandas.DataFrame): DataFrame with the evaluated scores.
        columns_to_drop (list): List of columns to drop. Default is ["id", "text"].

    Returns:
        pandas.Series: The aggregated scores, indexed by metric name.
    """
    # None instead of a list literal avoids a mutable default argument
    if columns_to_drop is None:
        columns_to_drop = ["id", "text"]

    agg_df = evaluated.drop(columns=columns_to_drop).mean()

    intake_times = ("morning", "noon", "evening", "night")
    intake_entries = []
    for metric in ("precision", "recall", "f1_score"):
        entries = [f"{metric}_{time}" for time in intake_times]
        agg_df[f"{metric}_intake"] = agg_df[entries].mean()
        intake_entries.extend(entries)

    # Only the combined intake scores are reported
    return agg_df.drop(intake_entries)

In [None]:
# Ground-truth labels, normalised to lowercase strings
labels = prepare_labels(paths.RESULTS_PATH/"medication/labels.xlsx")

# Predictions of one run, flattened to one row per predicted medication
res = prepare_results(paths.RESULTS_PATH/"medication/medication_outlines_Llama2-MedTuned-13b_4bit_few_shot_instruction_examples_10.pt")

# Choose only the relevant columns
# NOTE: calculate_precision_recall also reads this name as a module-level global,
# so it has to be defined before evaluate_df is called.
relevant_columns = ["name", "dose", "dose_unit", "morning", "noon", "evening", "night"]

# Per-text precision / recall / F1 for every relevant column
evaluated = evaluate_df(labels, res, relevant_columns)

# Average over all texts and combine the four intake columns
agg_df = aggregate_scores(evaluated)

type(agg_df)

In [None]:
res

In [None]:
# Llama 13B
results_13b = []
# os.listdir returns entries in arbitrary order, but the column labels below are
# assigned purely by position, so the file names are sorted to make the mapping
# deterministic.
# NOTE(review): the hard-coded labels appear to follow alphabetical file-name order
# (..._1, ..._10, ..._2, ..._4, ..._8, ...) - confirm against the actual file names.
filenames13b = sorted(filename for filename in os.listdir(paths.RESULTS_PATH/"medication") if filename.startswith("medication_outlines_Llama2-MedTuned-13b"))
for filename in filenames13b:
    res = prepare_results(paths.RESULTS_PATH/"medication"/filename)
    evaluated = evaluate_df(labels, res, relevant_columns)
    agg_df = aggregate_scores(evaluated)
    results_13b.append(agg_df)
# One column per result file, one row per metric
results_13b = pd.concat(results_13b, axis=1).round(2)
results_13b.columns = ["Few-Shot Instruction 1", "Few-Shot Instruction 10", "Few-Shot Instruction 2", "Few-Shot Instruction 4", "Few-Shot Instruction 8", "Few-Shot Base 10", "Zero-Shot Instruction 0", "Zero-Shot Base 0"]
# Reorder columns
results_13b = results_13b[["Zero-Shot Base 0", "Zero-Shot Instruction 0", "Few-Shot Base 10", "Few-Shot Instruction 10", "Few-Shot Instruction 1", "Few-Shot Instruction 2", "Few-Shot Instruction 4", "Few-Shot Instruction 8"]]

# Reorder rows
results_13b = results_13b.reindex(["precision_name", "precision_dose", "precision_dose_unit", "precision_intake", "recall_name", "recall_dose", "recall_dose_unit", "recall_intake", "f1_score_name", "f1_score_dose", "f1_score_dose_unit", "f1_score_intake"])
results_13b.to_csv(paths.RESULTS_PATH/"medication"/"results_13b.csv")

In [None]:
results_13b

In [None]:
# Llama 7B
results_7b = []
# os.listdir returns entries in arbitrary order, but the column labels below are
# assigned purely by position, so the file names are sorted to make the mapping
# deterministic.
# NOTE(review): the hard-coded labels appear to follow alphabetical file-name order
# (..._1, ..._10, ..._2, ..._4, ..._8, ...) - confirm against the actual file names.
filenames7b = sorted(filename for filename in os.listdir(paths.RESULTS_PATH/"medication") if filename.startswith("medication_outlines_Llama2-MedTuned-7b"))
for filename in filenames7b:
    res = prepare_results(paths.RESULTS_PATH/"medication"/filename)
    evaluated = evaluate_df(labels, res, relevant_columns)
    agg_df = aggregate_scores(evaluated)
    results_7b.append(agg_df)
# One column per result file, one row per metric
results_7b = pd.concat(results_7b, axis=1).round(2)
results_7b.columns = ["Few-Shot Instruction 1", "Few-Shot Instruction 10", "Few-Shot Instruction 2", "Few-Shot Instruction 4", "Few-Shot Instruction 8", "Few-Shot Base 10", "Zero-Shot Instruction 0", "Zero-Shot Base 0"]
# Reorder columns
results_7b = results_7b[["Zero-Shot Base 0", "Zero-Shot Instruction 0", "Few-Shot Base 10", "Few-Shot Instruction 10", "Few-Shot Instruction 1", "Few-Shot Instruction 2", "Few-Shot Instruction 4", "Few-Shot Instruction 8"]]
# Reorder rows
results_7b = results_7b.reindex(["precision_name", "precision_dose", "precision_dose_unit", "precision_intake", "recall_name", "recall_dose", "recall_dose_unit", "recall_intake", "f1_score_name", "f1_score_dose", "f1_score_dose_unit", "f1_score_intake"])
results_7b.to_csv(paths.RESULTS_PATH/"medication"/"results_7b.csv")

In [None]:
results_7b

# Rule Based Approach
From old project

In [None]:
def load_list_medi_ms():
    
    '''
    read the MS medication names from the resource file of the old project
    
    output:
    - list with one entry per line of the file, surrounding whitespace stripped
    '''
    
    medi_file = paths.PROJECT_ROOT/"resources/old_project/medication_for_ms.txt"
    
    with open(medi_file, "r") as handle:
        return [line.strip() for line in handle]

def _split_dose_and_unit(test_str, list_unit):
    
    '''
    split strings for dose and unit which aren't separated by a space, e.g. '120mg' and '0.5mg'
    
    '''
    
    def _contains_alpha_and_numeric(test_str):
        '''
        helper function to determine whether string could represent a dose and a unit and if it contains a dot
        '''

        # initiliaze
        status = 'no'
        
        # dose and unit start with a digit and end with a letter
        if (test_str[0].isdigit()) & (test_str[-1].isalpha()):

            # does it contain a dot
            if '.' in test_str:
                status = 'w/_dot'
            else:
                status = 'w/o_dot'

        return status

    # initialize
    list_tokens = [test_str]

    # get status whether it could be a dose and unit
    status = _contains_alpha_and_numeric(test_str)
    
    # if it contains 1 dot but no other special characters, and all letters represent a unit
    if status == 'w/_dot':

        if (test_str.count('.') == 1) & (len(re.findall('[\W]', test_str.replace('.', ''))) == 0):
            
            if re.findall('[a-zA-Z]+', test_str)[0] in list_unit:
    
                list_tokens = list(re.findall('(\d+)\.(\d+)?(\w+)', test_str)[0])
                list_tokens = ['.'.join(list_tokens[:2]), list_tokens[-1]]

    # if it doesn't contain a dot and all characters are either digits or letters
    if status == 'w/o_dot':
        
        if test_str.isalnum():
    
            list_tokens = list(re.findall('(\d+)(\w+)', test_str)[0])
    
    return list_tokens

def extract_dose_and_unit(list_tokens, list_unit_match):
    
    '''
    extract dose and unit from the tokens of one text line
    
    input:
    - list_tokens: tokens of the line
    - list_unit_match: units that were found among the tokens
    
    output:
    - dose: token directly before the unit, '' if there is none
    - unit: the matched unit, '' unless there is exactly one match
    
    raises ValueError if the single matched unit is not one of the tokens
    '''

    # extract dose and unit only if there is exactly one match
    if len(list_unit_match) != 1:
        return '', ''

    unit = list_unit_match[0]
    unit_pos = list_tokens.index(unit)
    
    # a unit at position 0 has no preceding token; index -1 would silently
    # wrap around and return the last token of the line as dose
    dose = list_tokens[unit_pos - 1] if unit_pos > 0 else ''
        
    return dose, unit

def extract_dosage_across_day(list_dose_match):
    
    '''
    read the intake schema (e.g. 1-0-1 or 1-0-0-0) from the first matched token
    
    input:
    - list_dose_match: tokens containing a '-'; only the first one is used
    
    output:
    - morning, noon, evening, night: parts of the schema as strings;
      for a schema with 3 entries night is the integer 0;
      four empty strings if there is no token or it doesn't have 3 or 4 parts
    '''
    
    no_schema = ('', '', '', '')
    
    if len(list_dose_match) < 1:
        return no_schema
    
    parts = list_dose_match[0].split('-')
    
    # 3 entries, e.g 1-1-1 -> nothing at night
    if len(parts) == 3:
        first, second, third = parts
        return first, second, third, 0
    
    # 4 entries, e.g. 1-0-0-0
    if len(parts) == 4:
        return tuple(parts)
    
    return no_schema

   
def flatten_listoflists(listoflists):
    '''
    concatenate the items of all sublists into one list (one level deep)
    
    input:
    - listoflists: nested list
    
    output:
    - flat_list: unnested list
    '''
    
    flat_list = []
    for sublist in listoflists:
        flat_list.extend(sublist)
    
    return flat_list


In [None]:
# Load the medication sample; the row index is used as the id of each text
df_medi = pd.read_csv(paths.DATA_PATH_PREPROCESSED/"medication/kisim_medication_sample.csv")
df_medi["id"] = df_medi.index

In [None]:
# list of units and MS medications
list_unit = ['mg', 'ug', 'g']
list_medi_ms = load_list_medi_ms()
# set for fast membership tests inside the token loop
set_medi_ms = set(list_medi_ms)

# intitialize
list_output = list()

# for each report (the name `id` is avoided because it would shadow the builtin)
for text_all, report_id in zip(df_medi['text'], df_medi['id']):
    
    # for each text line
    for text in text_all.splitlines():
        
        # get tokens and split dose and unit, e.g. '120mg'
        list_tokens = text.split()
        list_tokens = flatten_listoflists([_split_dose_and_unit(item, list_unit) for item in list_tokens])
        
        # match medication names, units and dosing (e.g. 1-1-1)
        # names are kept in token order so that "first" is the first name in the
        # line rather than an arbitrary element of a set
        list_name_match = [item for item in list_tokens if item in set_medi_ms]
        list_unit_match = list(set(list_tokens).intersection(list_unit))
        list_dose_match = [item for item in list_tokens if '-' in item]

        # if an MS medication name was matched
        if len(list_name_match) >=  1:

            # get (first) medication name (there are very few cases with > 1 name)
            name = list_name_match[0]

            # get dose and unit
            dose, unit = extract_dose_and_unit(list_tokens, list_unit_match)
              
            # get dosage across day
            morning, noon, evening, night = extract_dosage_across_day(list_dose_match)

            # extra field
            extra = ""

            # append
            list_output.append((name, dose, unit, morning, noon, evening, night, extra, text_all, report_id))
            
# generate output data frame
df_results = pd.DataFrame(list_output, 
                         columns = [ 
                                    'name', 'dose', 'dose_unit', 
                                    'morning', 'noon', 'evening','night',
                                    'extra',
                                    'text', "id"])
output_ids = set(df_results.id.unique())
print("Number of reports that were processed:", len(output_ids))
left_over_ids = set(df_medi.id.unique()) - output_ids

# reports without any matched medication get the default (empty) medication,
# so that every id is present in the evaluation
left_over_dfs = []
for missing_id in left_over_ids:
    _df = {**default_model.model_dump()["medications"][0], "text": df_medi[df_medi.id == missing_id].text.values[0], "id": missing_id}
    left_over_dfs.append(_df)
left_over_df = pd.DataFrame(left_over_dfs)

df_results = pd.concat([df_results, left_over_df]).sort_values("id").reset_index(drop=True)

# To get comparability need to map to string

df_results = df_results.map(lambda x: str(x).lower())
expression = r"\.0+$"
df_results = df_results.replace(expression, "", regex=True)


df_results.head()

In [None]:
df_results[df_results.id == "93"]

In [None]:
res[res["id"] == "78"]

In [None]:
labels[labels["id"] == "78"]

In [None]:
# Score the rule-based predictions against the labels and average over all texts
evaluated_rule = evaluate_df(labels, df_results, relevant_columns)
agg_df_rule = aggregate_scores(evaluated_rule)
agg_df_rule

In [None]:
# For the ones it predicted:
# output_ids are the ids of the reports in which a medication name was matched.
# They are converted to str because the ids in evaluated_rule come from the
# labels, which prepare_labels turns into strings.
predicted_examples_ids = [str(id) for id in output_ids]
aggregate_scores(evaluated_rule[evaluated_rule.id.isin(predicted_examples_ids)])

In [None]:
# First row: scores over all texts; second row: scores restricted to the texts
# in which the rule-based approach matched a medication name
res_rule = pd.DataFrame([aggregate_scores(evaluated_rule), aggregate_scores(evaluated_rule[evaluated_rule.id.isin(predicted_examples_ids)])]).round(2)
res_rule.to_csv(paths.RESULTS_PATH/"medication"/"results_rule.csv")

Notes for rule based:
- Only first example of medication is extracted.
- Only 4 out of 100 examples were even detected. (Also in original one they only detected around 6% of examples)
- Even for the ones it detected, if there are multiple medications it won't extract them. So recall not as high.
- Precision is of course very high, but there are still mistakes here. For example, a dose of 0,5 is treated as different from 0.5: the LLM gets this right because it outputs a float, while the rule-based approach does plain text matching

In [None]:
default_model

## Intermezzo
Just to check if they also just extracted so few examples

In [None]:
df_medi1 = pd.read_csv(paths.DATA_PATH_RSD/'reports_kisim_medication.csv')
# drop empty medication text
df_medi1 = df_medi1[df_medi1['medication_name'].notnull()]
# list of units and MS medications
list_unit = ['mg', 'ug', 'g']
list_medi_ms = load_list_medi_ms()
# set for fast membership tests inside the token loop
set_medi_ms = set(list_medi_ms)

# intitialize
list_output = list()

# for each report: research id and full medication text
for rid, text_all in zip(df_medi1['research_id'], df_medi1['medication_name']):
    
    # for each text line
    for text in text_all.splitlines():
        
        # get tokens and split dose and unit, e.g. '120mg'
        list_tokens = text.split()
        list_tokens = flatten_listoflists([_split_dose_and_unit(item, list_unit) for item in list_tokens])
        
        # match medication names, units and dosing (e.g. 1-1-1)
        # names are kept in token order so that "first" is the first name in the
        # line rather than an arbitrary element of a set
        list_name_match = [item for item in list_tokens if item in set_medi_ms]
        list_unit_match = list(set(list_tokens).intersection(list_unit))
        list_dose_match = [item for item in list_tokens if '-' in item]

        # if an MS medication name was matched
        if len(list_name_match) >=  1:

            # get (first) medication name (there are very few cases with > 1 name)
            name = list_name_match[0]

            # get dose and unit
            dose, unit = extract_dose_and_unit(list_tokens, list_unit_match)
              
            # get dosage across day
            morning, noon, evening, night = extract_dosage_across_day(list_dose_match)

            # append
            list_output.append((rid, name, dose, unit, morning, noon, evening, night, text, text_all))
            
# generate output data frame
# the column order has to match the tuple order above (morning, noon, evening, night)
df_results1 = pd.DataFrame(list_output, 
                         columns = ['rid', 
                                    'name', 'dose', 'unit', 
                                    'morning', 'noon', 'evening', 'night',
                                    'text_line', 'text_all'])
# extracted medication lines per report with a medication text
len(df_results1)/len(df_medi1)