# Evaluate Llama

In [1]:
# Load libraries
import torch
from trl import SFTTrainer
from datasets import load_dataset

import pickle
from datasets import Dataset
import pandas as pd
import dspy
import ast
import tqdm
import re
import csv
import string

from collections import OrderedDict, defaultdict
from typing import List, Tuple, Set, Dict
from collections import Counter

import re
import difflib
from collections import OrderedDict

  from .autonotebook import tqdm as notebook_tqdm


## Helper functions

In [9]:
def extract_counter(name):
    """
    Split a trailing "(n)" counter off a medication name.

    Parameters:
    - name: medication name, optionally ending in a parenthesised integer

    Returns:
    - (name, counter): the stripped name without the counter and the counter
      digits as a string; counter is None when the name has no trailing "(n)".
    """
    trailing = re.search(r'\s*\((\d+)\)\s*$', name)
    if trailing is None:
        return name.strip(), None
    return name[:trailing.start()].strip(), trailing.group(1)

def preprocess_med_name(name):
    """
    Normalise a medication name for fuzzy comparison.

    Steps, in order:
    - split off a trailing "(n)" counter via extract_counter
    - lowercase the remaining name
    - drop every other parenthesised group
    - drop all characters that are not word characters, whitespace or hyphens
    - collapse runs of whitespace into single spaces and trim the ends

    Returns:
    - (cleaned name, counter); counter is whatever extract_counter returned.
    """
    base_name, counter = extract_counter(name)
    lowered = base_name.lower()
    without_parens = re.sub(r'\([^)]*\)', '', lowered)
    without_punct = re.sub(r'[^\w\s-]', '', without_parens)
    return ' '.join(without_punct.split()), counter

def are_med_names_matching(name1, name2, threshold):
    """
    Decide whether two medication names refer to the same medication.

    Both names are preprocessed with preprocess_med_name; afterwards a
    standalone token "n", "en" or "e" is glued onto the word before it.

    Parameters:
    - name1, name2: medication names to compare
    - threshold: minimum difflib similarity ratio that counts as a match

    Returns:
    - (is_match, ratio). The result is (False, 0.0) when either cleaned name
      is empty or the first characters differ. When the trailing counters
      differ, is_match is False but the computed ratio is still returned.
    """
    suffix_pattern = r'\b(\w+)\s+(n|en|e)\b'
    cleaned_names = []
    counters = []
    for raw_name in (name1, name2):
        cleaned, counter = preprocess_med_name(raw_name)
        cleaned_names.append(re.sub(suffix_pattern, r'\1\2', cleaned))
        counters.append(counter)
    first, second = cleaned_names

    # Cheap rejection: both names must be non-empty and share a first character.
    if not (first and second) or first[0] != second[0]:
        return False, 0.0

    ratio = difflib.SequenceMatcher(None, first, second).ratio()

    # Different counters never match, regardless of how similar the names are.
    if counters[0] != counters[1]:
        return False, ratio

    return ratio >= threshold, ratio

def adapt_medication_names(gold_dict, pred_dict, threshold):
    """
    Rename predicted medications to the best-matching gold medication name.

    Every predicted medication is compared against every gold medication with
    are_med_names_matching. If at least one gold name matches, the prediction's
    'medication' entry is overwritten in place with the gold name that has the
    highest similarity ratio.

    Parameters:
    - gold_dict: gold sample dict with a 'medications' list
    - pred_dict: predicted sample dict with a 'medications' list
    - threshold: similarity ratio required for a match

    Returns:
    - pred_dict, the same object that was passed in (modified in place)
    """
    reference_meds = gold_dict.get('medications', [])
    candidate_meds = pred_dict.get('medications', [])

    for candidate in candidate_meds:
        candidate_name = candidate.get('medication', '')
        winner, winner_ratio = None, 0

        for reference in reference_meds:
            reference_name = reference.get('medication', '')
            matched, score = are_med_names_matching(candidate_name, reference_name, threshold)
            if not matched or score <= winner_ratio:
                continue
            winner, winner_ratio = reference_name, score

        # Only overwrite when a (non-empty) gold name was selected.
        if winner:
            candidate['medication'] = winner

    return pred_dict

def reorder_dict(gold_samples, pred_samples):
    """
    Give every predicted medication the key order of the gold sample. In-place.

    For each (gold, pred) sample pair, the key order of the FIRST gold
    medication is used as the template. Every predicted medication is replaced
    by an OrderedDict containing exactly those keys in that order; keys the
    prediction lacks are filled with '' and keys not in the template are
    dropped. Pairs whose gold sample has no medications are left untouched.

    Parameters:
    - gold_samples: list of (text, gold dict) samples
    - pred_samples: list of (text, pred dict) samples, paired by position
    """
    for gold_sample, pred_sample in zip(gold_samples, pred_samples):
        # A missing/empty 'medications' entry is treated as "no medications",
        # matching the .get('medications', []) lookups used elsewhere.
        gold_meds = gold_sample[1].get('medications') or []
        pred_meds = pred_sample[1].get('medications') or []
        if not gold_meds:
            continue

        key_order = list(gold_meds[0].keys())
        for position, med in enumerate(pred_meds):
            pred_meds[position] = OrderedDict((key, med.get(key, '')) for key in key_order)

def ensure_keys_in_medications(data):
    """
    Make sure every medication entry carries all relation keys. In-place.

    The required keys are RE_CLASSES plus 'medication'; any key that is
    missing from a medication entry is added with an empty-string value.
    If an entry cannot be assigned to, the offending sample and a running
    error count are printed and processing continues.

    Parameters:
    - data: list of (text, dict) samples (gold or pred)
    """
    required_keys = RE_CLASSES.union({'medication'})
    failures = 0
    for entry in data:
        _, sample_dict = entry

        for medication_dict in sample_dict.get('medications', []):
            for key in required_keys:
                if key in medication_dict:
                    continue
                try:
                    # Fill the gap with an empty value.
                    medication_dict[key] = ''
                except Exception as e:
                    print(entry)
                    failures += 1
                    print(f"Error: {e} - {failures}")

## Generate classification reports

In [3]:
def extract_relations(gold_samples: List[Tuple[str, OrderedDict]], pred_samples: List[Tuple[str, OrderedDict]], use_medication_threshold: bool, medication_threshold: float = 0.75) -> Tuple[List[Tuple[str, str, str, str]], List[Tuple[str, str, str, str]]]:
    """
    Flatten gold and predicted samples into index-aligned relation tuples.

    The two returned lists have the same length and are paired by position:
    entry i of both lists carries the same text and the same relation class.

    - Drug present in gold and pred: instances are paired in order of
      appearance and only the first min(#gold, #pred) instances are emitted;
      surplus instances on either side are not emitted. For each pair, the
      relation classes of the gold instance are walked and the predicted value
      is looked up by class ('' if absent).
    - Drug only in gold: the pred side gets a ('no drugs', class, '') placeholder.
    - Drug only in pred: the gold side gets a ('no drugs', class, '') placeholder.

    Note: the inputs are modified in place (missing keys are added, predicted
    medications are re-keyed, and - when use_medication_threshold is set -
    predicted medication names are replaced by their fuzzy-matched gold name).

    Parameters:
    - gold_samples: list of (text, dict) gold samples
    - pred_samples: list of (text, dict) predicted samples, paired by position
    - use_medication_threshold: fuzzy-align predicted medication names to gold
    - medication_threshold: similarity ratio used for that alignment

    Returns:
    - (gold_triples, pred_triples), lists of (text, drug, class, value) tuples.

    Raises:
    - ValueError: if a gold/pred pair does not share the same text.
    """
    gold_triples = []
    pred_triples = []

    ensure_keys_in_medications(gold_samples)
    ensure_keys_in_medications(pred_samples)

    reorder_dict(gold_samples, pred_samples)

    def group_by_drug(medications):
        # Drug name -> list of its medication dicts, in order of appearance.
        grouped = {}
        for med in medications:
            grouped.setdefault(med['medication'], []).append(med)
        return grouped

    def relation_items(props):
        # All (class, value) pairs of a medication except its own name.
        return [(rel_class, rel_value) for rel_class, rel_value in props.items() if rel_class != 'medication']

    for gold, pred in zip(gold_samples, pred_samples):
        gold_text, gold_dict = gold
        pred_text, pred_dict = pred

        # Explicit check instead of assert so it survives `python -O`.
        if gold_text != pred_text:
            raise ValueError("Mismatch in texts between gold and pred.")

        if use_medication_threshold:
            pred_dict = adapt_medication_names(gold_dict, pred_dict, threshold=medication_threshold)

        gold_by_drug = group_by_drug(gold_dict.get('medications', []))
        pred_by_drug = group_by_drug(pred_dict.get('medications', []))

        for drug, gold_instances in gold_by_drug.items():
            pred_instances = pred_by_drug.get(drug)

            if pred_instances is None:
                # Gold-only drug: every gold relation is paired with a placeholder.
                for gold_props in gold_instances:
                    for rel_class, rel_value in relation_items(gold_props):
                        gold_triples.append((gold_text, drug, rel_class, rel_value))
                        pred_triples.append((pred_text, 'no drugs', rel_class, ''))
                continue

            # Shared drug: zip() stops at the shorter list, i.e. min(#gold, #pred).
            for gold_props, pred_props in zip(gold_instances, pred_instances):
                for rel_class, rel_value in relation_items(gold_props):
                    # Look the prediction up by class so both lists stay aligned
                    # even if the two dicts differ in key order or key set.
                    gold_triples.append((gold_text, drug, rel_class, rel_value))
                    pred_triples.append((pred_text, drug, rel_class, pred_props.get(rel_class, '')))

        for drug, pred_instances in pred_by_drug.items():
            if drug in gold_by_drug:
                continue
            # Pred-only drug: every predicted relation is paired with a placeholder.
            for pred_props in pred_instances:
                for rel_class, rel_value in relation_items(pred_props):
                    pred_triples.append((pred_text, drug, rel_class, rel_value))
                    gold_triples.append((gold_text, "no drugs", rel_class, ''))

    return gold_triples, pred_triples

In [4]:
def compute_re_metrics(gold_relations: List[Tuple[str, str, str, str]], pred_relations: List[Tuple[str, str, str, str]], classes: Set[str], complete: bool, target_class: str) -> Dict[str, Dict[str, int]]:
    """
    Compute exact, lenient and 'any' TP/FP/FN counts for each relation class.

    gold_relations[i] and pred_relations[i] are evaluated as a pair; the
    relation class is taken from the gold tuple. Pairs whose class is not in
    `classes` are skipped.

    Matching (only when the normalised drug names are equal):
    - Exact: a gold value and a predicted value are equal after normalisation.
    - Lenient: exact matches, plus gold/pred values that share at least one
      token. Exact matches are assigned first, so each predicted value is
      consumed by at most one gold value.
    - Any: one count per pair - TP if any gold value shares a token with any
      predicted value, FN if gold has values but nothing matches, FP if only
      the prediction has values.

    Parameters:
    - gold_relations: list of gold (text, drug, class, value) tuples.
    - pred_relations: list of pred (text, drug, class, value) tuples.
    - classes: set of relation classes to evaluate.
    - complete: what to count when the drug names differ. True: every gold
      value is a FN and every predicted value is a FP. False: at most one FN
      (if gold has values) and at most one FP (if pred has values).
    - target_class: currently unused; kept for interface compatibility.

    Returns:
    - Dictionary with class as key and a counter dict as value. The
      'Exact_TN' and 'Any_TN' counters are present but never incremented.
    """
    counter_names = (
        'Exact_TP', 'Exact_FP', 'Exact_FN', 'Exact_TN',
        'Any_TP', 'Any_FP', 'Any_FN', 'Any_TN',
        'Lenient_TP', 'Lenient_FP', 'Lenient_FN',
        'Support',
    )
    metrics = {cls: dict.fromkeys(counter_names, 0) for cls in classes}

    def normalize_word(word):
        # Lowercase, drop periods and "(s)" plural markers, trim whitespace.
        word = re.sub(r'\.', '', word.strip().lower())
        word = re.sub(r'\s*\(s\)\s*', '', word)
        return word.strip()

    def normalize_relation_value(value):
        # Strings and lists of strings are normalised; anything else passes through.
        if isinstance(value, str):
            return normalize_word(value)
        if isinstance(value, list):
            return [normalize_word(item) for item in value]
        return value

    def as_value_list(value):
        # A list stays a list; a truthy scalar becomes [scalar]; falsy becomes [].
        if isinstance(value, list):
            return value
        return [value] if value else []

    def normalize_drug_name(drug_name):
        # Lowercase and keep only word characters (drops spaces, hyphens, punctuation).
        return re.sub(r'\W+', '', drug_name.lower())

    def tokenize(value):
        return [word.strip(string.punctuation).lower() for word in value.split()]

    def share_token(gold_value, pred_value):
        pred_tokens = tokenize(pred_value)
        return any(token in pred_tokens for token in tokenize(gold_value))

    for index, gold_sample in enumerate(gold_relations):
        pred_sample = pred_relations[index]

        _, gold_drug, relation_class, gold_raw_value = gold_sample
        _, pred_drug, _, pred_raw_value = pred_sample

        # Only the requested classes are evaluated (e.g. hallucinated keys are ignored).
        if relation_class not in metrics:
            continue
        class_metrics = metrics[relation_class]

        gold_values = as_value_list(normalize_relation_value(gold_raw_value))
        pred_values = as_value_list(normalize_relation_value(pred_raw_value))

        class_metrics['Support'] += len(gold_values)

        exact_tp = exact_fp = exact_fn = 0
        lenient_tp = lenient_fp = lenient_fn = 0

        if normalize_drug_name(gold_drug) == normalize_drug_name(pred_drug):
            # Pass 1: greedy exact matching, each prediction used at most once.
            exact_gold_indices = set()
            exact_pred_indices = set()
            for gold_idx, gold_value in enumerate(gold_values):
                for pred_idx, pred_value in enumerate(pred_values):
                    if pred_idx not in exact_pred_indices and gold_value == pred_value:
                        exact_gold_indices.add(gold_idx)
                        exact_pred_indices.add(pred_idx)
                        break

            exact_tp = len(exact_pred_indices)
            exact_fn = len(gold_values) - exact_tp
            exact_fp = len(pred_values) - exact_tp

            # Pass 2: lenient matching for the remaining gold values, using only
            # predictions that no exact or earlier lenient match has consumed.
            lenient_pred_indices = set(exact_pred_indices)
            lenient_tp = exact_tp
            for gold_idx, gold_value in enumerate(gold_values):
                if gold_idx in exact_gold_indices:
                    continue
                for pred_idx, pred_value in enumerate(pred_values):
                    if pred_idx not in lenient_pred_indices and share_token(gold_value, pred_value):
                        lenient_pred_indices.add(pred_idx)
                        lenient_tp += 1
                        break
                else:
                    lenient_fn += 1

            lenient_fp = len(pred_values) - len(lenient_pred_indices)

        elif complete:
            # Drug names differ: every gold value is missed, every prediction is spurious.
            exact_fn = lenient_fn = len(gold_values)
            exact_fp = lenient_fp = len(pred_values)
        else:
            # Drug names differ: count the missed/spurious drug once per relation.
            exact_fn = lenient_fn = 1 if gold_values else 0
            exact_fp = lenient_fp = 1 if pred_values else 0

        class_metrics['Exact_TP'] += exact_tp
        class_metrics['Exact_FP'] += exact_fp
        class_metrics['Exact_FN'] += exact_fn

        class_metrics['Lenient_TP'] += lenient_tp
        class_metrics['Lenient_FP'] += lenient_fp
        class_metrics['Lenient_FN'] += lenient_fn

        # 'Any' metrics: a single count per gold/pred pair.
        if gold_values and pred_values:
            any_match_found = any(
                share_token(gold_value, pred_value)
                for gold_value in gold_values
                for pred_value in pred_values
            )
            class_metrics['Any_TP' if any_match_found else 'Any_FN'] += 1
        elif gold_values:
            class_metrics['Any_FN'] += 1
        elif pred_values:
            class_metrics['Any_FP'] += 1

    return metrics

In [5]:
def print_classification_report(metrics: Dict[str, Dict[str, int]], classes: Set[str], task: str, evaluation_type: str):
    """
    Print a per-class classification report with micro and macro averages.

    Parameters:
    - metrics: dictionary containing the counters per class; for every class
      it must provide 'Support' and '<evaluation_type>_TP/_FP/_FN'.
    - classes: set of classes to report (printed in sorted order).
    - task: label printed in the report title.
    - evaluation_type: 'Exact', 'Lenient' or 'Any'. Any other value reports
      zero TP/FP/FN for every class.

    Precision, recall and F1 are 0.0 whenever their denominator is 0. The
    macro average is the unweighted mean over all classes, including classes
    without support.
    """
    print(f"\n{task} {evaluation_type} Classification Report\n")
    header = f"{'Class':<15} {'Support':<8} {'TP':<5} {'FP':<5} {'FN':<5} {'Precision':<10} {'Recall':<10} {'F1-Score':<10}"
    print(header)
    print("-" * len(header))

    known_type = evaluation_type in ('Exact', 'Lenient', 'Any')

    def counts_for(cls):
        # (TP, FP, FN) for one class under the selected evaluation type.
        if not known_type:
            return 0, 0, 0
        return tuple(metrics[cls][f"{evaluation_type}_{kind}"] for kind in ('TP', 'FP', 'FN'))

    def scores(tp, fp, fn):
        # (precision, recall, F1), guarding every division against zero.
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
        return precision, recall, f1

    precision_sum = 0.0
    recall_sum = 0.0
    f1_sum = 0.0
    tp_micro = fp_micro = fn_micro = 0
    num_classes = len(classes)

    for cls in sorted(classes):
        tp, fp, fn = counts_for(cls)
        precision, recall, f1 = scores(tp, fp, fn)

        precision_sum += precision
        recall_sum += recall
        f1_sum += f1
        tp_micro += tp
        fp_micro += fp
        fn_micro += fn

        print(f"{cls:<15} {metrics[cls]['Support']:<8} {tp:<5} {fp:<5} {fn:<5} {precision:<10.3f} {recall:<10.3f} {f1:<10.3f}")

    # Micro average: scores computed from the pooled counts.
    precision_micro, recall_micro, f1_micro = scores(tp_micro, fp_micro, fn_micro)

    # Macro average: unweighted mean of the per-class scores.
    precision_macro = precision_sum / num_classes if num_classes > 0 else 0.0
    recall_macro = recall_sum / num_classes if num_classes > 0 else 0.0
    f1_macro = f1_sum / num_classes if num_classes > 0 else 0.0

    print("-" * len(header))
    # Four placeholder columns (Support, TP, FP, FN) keep the averages under their headers.
    print(f"{'Micro Avg':<15} {'-':<8} {'-':<5} {'-':<5} {'-':<5} {precision_micro:<10.3f} {recall_micro:<10.3f} {f1_micro:<10.3f}")
    print(f"{'Macro Avg':<15} {'-':<8} {'-':<5} {'-':<5} {'-':<5} {precision_macro:<10.3f} {recall_macro:<10.3f} {f1_macro:<10.3f}")
    print("\n")

def evaluate_model(gold_samples: List[OrderedDict], pred_samples: List[OrderedDict], complete: bool, use_medication_threshold = False, target_class = ""):
    """
    Run the relation-extraction evaluation and print its reports.

    The samples are flattened into aligned relation tuples, scored per class
    in RE_CLASSES, and reported twice: once with exact and once with lenient
    matching.

    Parameters:
    - gold_samples: list of gold standard samples.
    - pred_samples: list of predicted samples, paired by position with gold.
    - complete: forwarded to compute_re_metrics (how drug mismatches count).
    - use_medication_threshold: forwarded to extract_relations (fuzzy-align
      predicted medication names to gold names).
    - target_class: forwarded to compute_re_metrics.
    """
    task_label = 'Relation Extraction (RE)'

    gold_relations, pred_relations = extract_relations(gold_samples, pred_samples, use_medication_threshold)

    re_metrics = compute_re_metrics(
        gold_relations,
        pred_relations,
        classes=RE_CLASSES,
        complete=complete,
        target_class=target_class,
    )

    # One report per matching strategy, exact first.
    for evaluation_type in ('Exact', 'Lenient'):
        print_classification_report(re_metrics, RE_CLASSES, task=task_label, evaluation_type=evaluation_type)

In [14]:
# Define relation classes
RE_CLASSES = {'ade', 'dosage', 'duration', 'form', 'frequency', 'reason', 'route', 'strength'}

# Load (text, gold, pred) rows from the pipe-delimited results file.
golds = []
preds = []
with open('output.csv', 'r', newline='') as csvfile:
    reader = csv.DictReader(csvfile, delimiter='|')

    for row in reader:
        text = row['text']

        # SECURITY NOTE: eval() executes whatever expression the CSV cell
        # contains (builtins are still reachable through these globals).
        # Only run this on files you produced yourself; switch to a data-only
        # format such as JSON if the file can come from anywhere else.
        gold = eval(row['gold'], {"OrderedDict": OrderedDict}, {})
        pred = eval(row['pred'], {"OrderedDict": OrderedDict}, {})

        golds.append((text, gold))
        preds.append((text, pred))

# Evaluate after the file has been closed; nothing below needs the handle.
evaluate_model(golds, preds, complete=True, use_medication_threshold=True, target_class=RE_CLASSES)



Relation Extraction (RE) Exact Classification Reportn
Class           Support  TP    FP    FN    Precision  Recall     F1-Score  
---------------------------------------------------------------------------
ade             3        2     0     1     1.000      0.667      0.800     
dosage          0        0     0     0     0.000      0.000      0.000     
duration        5        5     0     0     1.000      1.000      1.000     
form            0        0     0     0     0.000      0.000      0.000     
frequency       1        1     0     0     1.000      1.000      1.000     
reason          13       12    4     1     0.750      0.923      0.828     
route           1        1     0     0     1.000      1.000      1.000     
strength        2        2     0     0     1.000      1.000      1.000     
---------------------------------------------------------------------------
Micro Avg       -     -     -     0.852      0.920      0.885     
Macro Avg       -     -     -     0.719   