# Feedback LLM

## Collect false predictions for the feedback LLM

In [4]:
# Load libraries
import torch
from trl import SFTTrainer
from datasets import load_dataset

import pickle
from datasets import Dataset
import pandas as pd
import dspy
import ast
import tqdm
import re
import csv
import string
import os

from collections import OrderedDict, defaultdict
from typing import List, Tuple, Set, Dict, Optional, Union
from collections import Counter

import re
import difflib
from collections import OrderedDict
from functools import wraps
import random
import numpy as np

from pydantic import BaseModel
from pydantic import Field
import json
from transformers import pipeline

### Helper functions

In [7]:
def extract_counter(name):
    """
    Split a trailing occurrence counter such as "(2)" off a medication name.

    Returns a tuple (name, counter): the name with the counter removed and
    surrounding whitespace stripped, and the counter as the digit string found
    inside the trailing parentheses. The counter is None when the name does
    not end in a parenthesised number.
    """
    trailing_counter = re.search(r'\s*\((\d+)\)\s*$', name)
    if trailing_counter is None:
        return name.strip(), None

    base_name = name[:trailing_counter.start()].strip()
    return base_name, trailing_counter.group(1)

def preprocess_med_name(name):
    """
    Normalise a medication name for fuzzy comparison.

    Steps, in order:
    - split off a trailing counter such as "(2)"
    - lowercase the remaining name
    - drop any other parenthesised content
    - drop every character that is not a word character, whitespace or hyphen
    - collapse whitespace runs to single spaces and trim both ends

    Returns a tuple (normalised name, counter), where counter is the digit
    string of the trailing counter or None if there was none.
    """
    base_name, counter = extract_counter(name)

    normalised = base_name.lower()
    # First remove parenthesised content, then the remaining punctuation.
    for pattern in (r'\([^)]*\)', r'[^\w\s-]'):
        normalised = re.sub(pattern, '', normalised)

    return ' '.join(normalised.split()), counter

def are_med_names_matching(name1, name2, threshold):
    """
    Decide whether two medication names refer to the same medication.

    Parameters:
    - name1, name2: raw medication names (may carry a trailing "(n)" counter)
    - threshold: minimum difflib similarity ratio required for a match

    Returns:
    - Tuple (is_match, ratio). The ratio is 0.0 when either cleaned name is
      empty or the cleaned names start with different characters; otherwise
      it is the SequenceMatcher ratio of the cleaned names. Names with
      different counters never match, but their ratio is still reported.
    """
    # A standalone "n", "en" or "e" token is glued onto the word before it.
    detached_suffix = r'\b(\w+)\s+(n|en|e)\b'

    cleaned_names = []
    counters = []
    for raw_name in (name1, name2):
        cleaned, counter = preprocess_med_name(raw_name)
        cleaned_names.append(re.sub(detached_suffix, r'\1\2', cleaned))
        counters.append(counter)
    first, second = cleaned_names

    # Empty names or different initial characters rule out a match early.
    if not first or not second or first[0] != second[0]:
        return False, 0.0

    ratio = difflib.SequenceMatcher(None, first, second).ratio()

    same_counter = counters[0] == counters[1]
    return (same_counter and ratio >= threshold), ratio

def adapt_medication_names(gold_dict, pred_dict, threshold):
    """
    Rename predicted medications to the best-matching gold medication name.

    For every entry in pred_dict['medications'] the gold entry with the
    highest similarity ratio among those accepted by are_med_names_matching
    is looked up; if one exists, the predicted 'medication' value is
    overwritten with the gold spelling. Entries without a match are left
    untouched.

    Parameters:
    - gold_dict: gold dict with a 'medications' list
    - pred_dict: pred dict with a 'medications' list (modified in place)
    - threshold: similarity ratio required for two names to match

    Returns:
    - the same pred_dict object, with aligned medication names
    """
    gold_entries = gold_dict.get('medications', [])

    for predicted in pred_dict.get('medications', []):
        predicted_name = predicted.get('medication', '')
        best_name, best_ratio = None, 0

        # Keep the first gold name that reaches the highest ratio.
        for gold_entry in gold_entries:
            candidate = gold_entry.get('medication', '')
            is_match, ratio = are_med_names_matching(predicted_name, candidate, threshold)
            if is_match and ratio > best_ratio:
                best_name, best_ratio = candidate, ratio

        if best_name:
            predicted['medication'] = best_name

    return pred_dict

def reorder_dict(gold_samples, pred_samples):
    """
    Give every predicted medication the key order of the gold sample. In-place.

    For each (gold, pred) pair the key order of the first gold medication is
    taken as the template, and every predicted medication is replaced by an
    OrderedDict holding exactly those keys in that order. Pairs whose gold
    sample has no medications are left unchanged.

    Parameters:
    - gold_samples: list of (text, gold dict) samples
    - pred_samples: list of (text, pred dict) samples, modified in place
    """
    for gold_sample, pred_sample in zip(gold_samples, pred_samples):
        gold_dict, pred_dict = gold_sample[1], pred_sample[1]

        gold_meds = gold_dict['medications']
        if len(gold_meds) == 0:
            continue

        key_order = list(gold_meds[0].keys())
        pred_meds = pred_dict['medications']
        for position, unordered in enumerate(pred_meds):
            pred_meds[position] = OrderedDict(
                (key, unordered[key]) for key in key_order
            )

def ensure_keys_in_medications(data):
    """
    Make sure every medication entry carries all expected keys. In-place.

    The expected keys are the relation classes in RE_CLASSES plus
    'medication'. A missing key is added with an empty string as value; if
    the entry cannot be updated, the sample and the error are printed
    together with a running error count.

    Parameters:
    - data: list of (text, dict) samples, gold or pred
    """
    required_keys = RE_CLASSES.union({'medication'})
    failures = 0
    for sample in data:
        text, sample_dict = sample

        for med in sample_dict.get('medications', []):
            for required in required_keys:
                if required in med:
                    continue
                try:
                    # Fill the gap with an empty value
                    med[required] = ''
                except Exception as e:
                    print(sample)
                    failures += 1
                    print(f"Error: {e} - {failures}")

def extract_relations(gold_samples: List[Tuple[str, OrderedDict]], pred_samples: List[Tuple[str, OrderedDict]], use_medication_threshold: bool) -> Tuple[List[Tuple[str, str, str, str]], List[Tuple[str, str, str, str]]]:
    """
    Flatten gold and predicted samples into two parallel lists of relation tuples.

    Both sample lists are first normalised in place (missing keys are added,
    predicted medications are re-ordered to the gold key order). Each tuple
    has the form (text, drug, relation class, relation value). A drug that
    occurs on only one side is paired with a placeholder tuple on the other
    side whose drug is 'no drugs' and whose value is ''.

    Parameters:
    - gold_samples: List of (text, gold dict) samples; modified in place
    - pred_samples: List of (text, pred dict) samples; modified in place
    - use_medication_threshold: if True, predicted medication names are
      aligned to similar gold names (similarity threshold 0.75) before the
      drugs are compared

    Returns:
    - (gold_triples, pred_triples): the flattened gold and predicted tuples.

    Raises:
    - AssertionError: if the texts of a gold/pred pair differ.
    """
        
    gold_triples = []
    pred_triples = []

    # Both helpers modify the samples in place.
    ensure_keys_in_medications(gold_samples)
    ensure_keys_in_medications(pred_samples)

    reorder_dict(gold_samples, pred_samples)
    
    for gold, pred in zip(gold_samples, pred_samples):
        gold_text, gold_dict = gold
        pred_text, pred_dict = pred

        # Ensure pred and gold texts are the same
        assert gold_text == pred_text, "Mismatch in texts between gold and pred."
        # NOTE(review): these two lookups are overwritten below before they are read.
        gold_medications = gold_dict.get('medications', [])
        pred_medications = pred_dict.get('medications', [])

        if use_medication_threshold:
            # Rewrites predicted names to the gold spelling when they are similar enough.
            pred_dict = adapt_medication_names(gold_dict, pred_dict, threshold=0.75)

        # Extract medication information from gold and pred
        gold_medications = gold_dict.get('medications', [])
        pred_medications = pred_dict.get('medications', [])
        
        # Count how often each medication name occurs on either side
        gold_drugs_counts = defaultdict(int)  
        pred_drugs_counts = defaultdict(int)  

        for med in gold_medications:  
            gold_drugs_counts[med['medication']] += 1  

        for med in pred_medications:  
            pred_drugs_counts[med['medication']] += 1

        # Split the drug names into shared, gold-only and pred-only
        common_drugs = set(gold_drugs_counts.keys()) & set(pred_drugs_counts.keys())
        gold_only = set(gold_drugs_counts.keys()) - set(pred_drugs_counts.keys())
        pred_only = set(pred_drugs_counts.keys()) - set(gold_drugs_counts.keys())

        # Process drugs
        for drug in common_drugs:
            # Extract the correct number of instances for each drug
            gold_props_list = [med for med in gold_medications if med['medication'] == drug]  
            pred_props_list = [med for med in pred_medications if med['medication'] == drug]  

            gold_instances_count = gold_drugs_counts[drug]  
            pred_instances_count = pred_drugs_counts[drug]  
            
            # Only the first min(gold, pred) instances are kept; surplus
            # instances on either side produce no tuples at all.
            common_count = min(gold_instances_count, pred_instances_count)  
            gold_props_list = gold_props_list[:common_count]  
            pred_props_list = pred_props_list[:common_count]

            # NOTE(review): gold and pred tuples are appended independently, so the
            # two output lists only stay aligned if the gold and pred medication
            # dicts hold the same keys in the same order - confirm for gold
            # medications whose key order differs from the first gold medication.
            for gold_props in gold_props_list:
                for rel_class, rel_value in gold_props.items():
                    if rel_class != 'medication':
                        gold_triples.append((gold_text, drug, rel_class, rel_value))

            for pred_props in pred_props_list:
                for rel_class, rel_value in pred_props.items():
                    if rel_class != 'medication':
                        pred_triples.append((pred_text, drug, rel_class, rel_value))
        
        # Process gold-only drugs: pair each gold tuple with a 'no drugs' placeholder
        for drug in gold_only:
            gold_props_list = [med for med in gold_medications if med['medication'] == drug]
            for gold_props in gold_props_list:
                for rel_class, rel_value in gold_props.items():
                    if rel_class != 'medication':
                        gold_triples.append((gold_text, drug, rel_class, rel_value))
                        pred_triples.append((pred_text, 'no drugs', rel_class, ''))
        
        # Process pred-only drugs: pair each pred tuple with a 'no drugs' placeholder
        for drug in pred_only:
            pred_props_list = [med for med in pred_medications if med['medication'] == drug]
            for pred_props in pred_props_list:
                for rel_class, rel_value in pred_props.items():
                    if rel_class != 'medication':
                        pred_triples.append((pred_text, drug, rel_class, rel_value))
                        gold_triples.append((gold_text, "no drugs", rel_class, ''))

    return gold_triples, pred_triples

def compute_re_metrics(gold_relations: List[Tuple[str, str, str, str]], pred_relations: List[Tuple[str, str, str, str]], classes: Set[str], complete: bool, target_class: str) -> Dict[str, Dict[str, int]]:
    """
    Collect exact and lenient TP, FP, FN counts for each relation class.

    The two lists are compared position by position: entry i of
    gold_relations is scored against entry i of pred_relations, and the
    relation class of the gold entry decides which class is updated.
    An exact match requires equal normalised values; a lenient match only
    requires the gold and predicted value to share at least one word.

    As a side effect, the lenient false positives and false negatives of
    target_class are written to "fp_fn/i2b2/<target_class>_fp.txt" and
    "fp_fn/i2b2/<target_class>_fn.txt" as fully quoted CSV rows of
    (text, str((gold triple, pred triple))). Nothing is written when
    target_class is not one of classes.

    Parameters:
    - gold_relations: list of gold (text, drug, class, value) tuples.
    - pred_relations: list of pred (text, drug, class, value) tuples.
    - classes: set of relation classes to evaluate.
    - complete: how a drug mismatch (gold_drug != pred_drug, i.e. the model
      failed to identify the drug) is counted. If True, every gold value is
      counted as FN and every predicted value as FP. If False, at most one
      FN and one FP is counted for the relation.
    - target_class: relation class whose FP/FN examples are written to disk.

    Returns:
    - Dictionary with class as key and metrics as value. The 'Any_*' and
      'Exact_TN' counters are initialised but never incremented here.

    Raises:
    - KeyError: if a gold relation class is not contained in classes.
    - IndexError: if pred_relations is shorter than gold_relations.
    """
    metric_names = (
        'Exact_TP', 'Exact_FP', 'Exact_FN', 'Exact_TN',
        'Any_TP', 'Any_FP', 'Any_FN', 'Any_TN',
        'Lenient_TP', 'Lenient_FP', 'Lenient_FN',
        'Support',
    )
    metrics = {cls: dict.fromkeys(metric_names, 0) for cls in classes}

    false_positives = {key: [] for key in classes}
    false_negatives = {key: [] for key in classes}

    def normalize_word(word):
        # Lowercase, drop full stops and an optional plural marker "(s)".
        word = re.sub(r'\.', '', word.strip().lower())
        word = re.sub(r'\s*\(s\)\s*', '', word)
        return word.strip()

    def normalize_relation_value(value):
        # Strings and lists of strings are normalised; anything else is kept.
        if isinstance(value, str):
            return normalize_word(value)
        if isinstance(value, list):
            return [normalize_word(v) for v in value]
        return value

    def normalize_drug_name(drug_name):
        # Lowercase and keep word characters only, so case, spacing,
        # hyphens and punctuation do not affect the drug comparison.
        return re.sub(r'\W+', '', drug_name.lower())

    def as_value_list(value):
        # A list is used as is, a non-empty scalar is wrapped, an empty
        # value means "no values".
        if isinstance(value, list):
            return value
        return [value] if value else []

    for index, gold_sample in enumerate(gold_relations):
        pred_sample = pred_relations[index]

        _, gold_drug, relation_class, gold_relation_value = gold_sample
        text, pred_drug, _, pred_relation_value = pred_sample

        gold_values = as_value_list(normalize_relation_value(gold_relation_value))
        pred_values = as_value_list(normalize_relation_value(pred_relation_value))

        gold_drug = normalize_drug_name(gold_drug)
        pred_drug = normalize_drug_name(pred_drug)

        # NOTE: a relation class that is not in `classes` (e.g. a key typo
        # produced by the model) raises KeyError here.
        metrics[relation_class]['Support'] += len(gold_values)

        # Counts and matched indices for this relation
        exact_tp = 0
        lenient_tp = 0
        exact_fp = 0
        lenient_fp = 0
        exact_fn = 0
        lenient_fn = 0

        exact_matched_pred_indices = []
        lenient_matched_gold_indices = []
        lenient_matched_pred_indices = []

        if gold_drug == pred_drug:
            # Greedy matching: each gold value takes the first free predicted
            # value. NOTE(review): a predicted value that was already used for
            # a lenient match can still be taken by a later exact match, so
            # one prediction can contribute two lenient TPs.
            for gold_idx, gold_value in enumerate(gold_values):
                exact_match_found = False
                lenient_match_found = False
                for pred_idx, pred_value in enumerate(pred_values):
                    if gold_value == pred_value and pred_idx not in exact_matched_pred_indices:
                        exact_tp += 1
                        exact_matched_pred_indices.append(pred_idx)
                        # An exact match also counts as a lenient match
                        lenient_tp += 1
                        lenient_matched_gold_indices.append(gold_idx)
                        lenient_matched_pred_indices.append(pred_idx)
                        exact_match_found = True
                        break
                # If no exact match, check for lenient match
                if not exact_match_found:
                    for pred_idx, pred_value in enumerate(pred_values):
                        if pred_idx not in lenient_matched_pred_indices:
                            gold_words = [word.strip(string.punctuation).lower() for word in gold_value.split()]
                            pred_words = [word.strip(string.punctuation).lower() for word in pred_value.split()]
                            if any(word in pred_words for word in gold_words):
                                lenient_tp += 1
                                lenient_matched_gold_indices.append(gold_idx)
                                lenient_matched_pred_indices.append(pred_idx)
                                lenient_match_found = True
                                break
                    # Exact FN regardless of whether a lenient match was found
                    exact_fn += 1
                    if not lenient_match_found:
                        lenient_fn += 1

            # Any unmatched predicted values are false positives
            for pred_idx in range(len(pred_values)):
                if pred_idx not in exact_matched_pred_indices:
                    exact_fp += 1
                if pred_idx not in lenient_matched_pred_indices:
                    lenient_fp += 1

        elif complete:
            # Drug names do not match; count all gold values as FN and all predicted values as FP
            exact_fn += len(gold_values)
            lenient_fn += len(gold_values)
            exact_fp += len(pred_values)
            lenient_fp += len(pred_values)
        else:
            # Drug names do not match; count at most one FN and one FP for this relation
            exact_fn += 1 if gold_values else 0
            lenient_fn += 1 if gold_values else 0
            exact_fp += 1 if pred_values else 0
            lenient_fp += 1 if pred_values else 0

        # Update metrics
        metrics[relation_class]['Exact_TP'] += exact_tp
        metrics[relation_class]['Exact_FP'] += exact_fp
        metrics[relation_class]['Exact_FN'] += exact_fn

        metrics[relation_class]['Lenient_TP'] += lenient_tp
        metrics[relation_class]['Lenient_FP'] += lenient_fp
        metrics[relation_class]['Lenient_FN'] += lenient_fn

        # Collect FP: every predicted value without a lenient match,
        # whether or not the drug names agree
        for pred_idx, pred_value in enumerate(pred_values):
            if pred_idx not in lenient_matched_pred_indices:
                false_positives[relation_class].append({
                    'text': text,
                    'triple': (gold_sample[1:], pred_sample[1:]),
                    'drug': pred_drug,
                    'predicted_value': pred_value,
                    'gold_values': gold_values
                })

        # Collect FN: gold values without a lenient match, only when the
        # drug names agree
        for gold_idx, gold_value in enumerate(gold_values):
            if gold_idx not in lenient_matched_gold_indices and gold_drug == pred_drug:
                false_negatives[relation_class].append({
                    'text': text,
                    'triple': (gold_sample[1:], pred_sample[1:]),
                    'drug': gold_drug,
                    'gold_value': gold_value,
                    'predicted_values': pred_values
                })

    # Write false positives and negatives of the target class into files.
    # Skipped for an unknown target class (such as the empty string), which
    # has no collected examples and would otherwise raise KeyError.
    if target_class in classes:
        os.makedirs("fp_fn/i2b2", exist_ok=True)

        # newline='' lets the csv module control line endings itself.
        with open(f"fp_fn/i2b2/{target_class}_fp.txt", "w", newline='') as f_p:
            writer = csv.writer(f_p, quoting=csv.QUOTE_ALL, escapechar='\\')
            for fp in false_positives[target_class]:
                # Only keep FPs that have gold values to compare against
                if len(fp['gold_values']) > 0:
                    writer.writerow([fp['text'], str(fp['triple'])])

        with open(f"fp_fn/i2b2/{target_class}_fn.txt", "w", newline='') as f_n:
            writer = csv.writer(f_n, quoting=csv.QUOTE_ALL, escapechar='\\')
            for fn in false_negatives[target_class]:
                # Only keep FNs that have predicted values to compare against
                if len(fn['predicted_values']) > 0:
                    writer.writerow([fn['text'], str(fn['triple'])])

    return metrics

def evaluate_model(gold_samples: List[Tuple[str, OrderedDict]], pred_samples: List[Tuple[str, OrderedDict]], complete: bool, use_medication_threshold = False, target_class = "") -> Dict[str, Dict[str, int]]:
    """
    Evaluate relation extraction and return the per-class metrics.

    The samples are flattened into relation tuples with extract_relations
    (which modifies both sample lists in place) and then scored with
    compute_re_metrics over RE_CLASSES. As a side effect, compute_re_metrics
    writes the false positives and false negatives of target_class to disk.

    Args:
        gold_samples: List of (text, gold dict) samples.
        pred_samples: List of (text, pred dict) samples, in the same order.
        complete: Passed on to compute_re_metrics; controls how relations of
            a drug that was not identified are counted.
        use_medication_threshold: If True, predicted medication names are
            aligned to similar gold names before comparison.
        target_class: Relation class whose FP/FN examples are written to
            disk; expected to be one of RE_CLASSES.

    Returns:
        Dictionary mapping each relation class to its metric counts.
    """
    # Extract relations for RE
    gold_relations, pred_relations = extract_relations(gold_samples, pred_samples, use_medication_threshold)

    # Compute RE metrics and hand them back instead of discarding them
    return compute_re_metrics(gold_relations, pred_relations, classes=RE_CLASSES, complete=complete, target_class=target_class)
    

### Extract false predictions

In [8]:
import copy

# Define relation classes
RE_CLASSES = {'ade', 'dosage', 'duration', 'form', 'frequency', 'reason', 'route', 'strength'}

# Parse the prediction file once; it does not depend on the target class.
golds = []
preds = []
with open('output_results_llama3_8b_i2b2_pydantic.csv', 'r', newline='') as csvfile:
    reader = csv.DictReader(csvfile, delimiter='|')
    for row in reader:
        text = row['text']

        # SECURITY NOTE: eval executes whatever expression the CSV cell
        # contains, so only run this on a trusted results file. The cells
        # presumably hold OrderedDict(...) reprs, which is why OrderedDict is
        # the only name made available to the expression.
        gold = eval(row['gold'], {"OrderedDict": OrderedDict}, {})
        pred = eval(row['pred'], {"OrderedDict": OrderedDict}, {})

        golds.append((text, gold))
        preds.append((text, pred))

# Write the FP/FN files of every relation class. The evaluation modifies the
# sample dicts in place, so each run gets its own deep copy of the parsed data.
for target_class in RE_CLASSES:
    evaluate_model(copy.deepcopy(golds), copy.deepcopy(preds), complete=True, use_medication_threshold=True, target_class=target_class)


## Apply feedback LLM

### Load model and define prompts

In [None]:
# Text-generation pipeline used as the feedback LLM; the checkpoint named
# below is loaded when this cell runs.
chatbot = pipeline("text-generation", model="unsloth/Mistral-Large-Instruct-2407-bnb-4bit")

# System prompt for the feedback LLM. It describes the input (a text plus a
# gold and a predicted triplet) and asks for a verdict phrased as
# "Result: SIMILAR" or "Result: NOT SIMILAR", followed by two worked examples.
# NOTE(review): the example section contains typos ("he patient", "mdodel",
# "these information"). They are part of the prompt sent to the model and are
# left untouched here so that the model input stays unchanged.
system_message = """
You are a doctor specializing in pharmacology. You receive a text along with two triplets containing medication information: the gold standard (Gold) and the model prediction (Pred), which may include false positives and false negatives. Each triplet includes:
- the medication name,
- the category of medication information,
- and a value.

The context of the text should be taken into account to ensure the clinical meaning of the values is fully understood and no subtleties are overlooked.

Your task is to evaluate carefully and conservatively whether the values in the two triplets are clinically comparable. A classification of *SIMILAR* should only be made if the similarity is obvious and clearly clinically meaningful!

Please proceed with the evaluation as follows:
1. Carefully read the text to understand the context of the medication information.
2. Compare the values in the Gold and Pred triplets.
3. Determine if the values are clinically equivalent or comparable in the given context.
4. If the values are clearly similar within the context of the text, provide the result as *Result: SIMILAR*.
5. If the values are not clearly comparable or have a divergent clinical meaning, provide the result as *Result: NOT SIMILAR*.

**Example of structure:**

Text: The patient administers insulin a day to manage blood sugar levels
Triplets: (('insulin', 'frequency', ''), ('insulin', 'frequency', 'day')): The model predicts a false positive. Day is not in the gold standard. Result: NOT SIMILAR

Text: he patient’s methotrexate regimen includes doses on qFri and qSat, administered weekly to ensure consistent therapeutic levels.
Triplets: (('methotrexate', 'frequency', 'qFri, qSat'), ('methotrexate', 'frequency', ['qFri', 'qSat'])): Both values are similar. The mdodel splits the information qFri and qSat into two list items, while the gold standard contains these information in a single string. Result: SIMILAR

"""




Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


### Apply feedback LLM per relation class and store answers in feedback files

#### Collect FPs

In [None]:
# Ask the feedback LLM about every collected false positive, one relation
# class at a time, and store the answers in "<class>_fp_feedback.txt".
output_fp = []
for target_class in RE_CLASSES:
    print(target_class)
    fp_file_path = f"fp_fn/i2b2/{target_class}_fp.txt"

    # Load all CSV records up front so that tqdm reports the number of
    # records; counting physical lines would overcount texts that contain
    # line breaks inside a quoted field.
    with open(fp_file_path, "r", newline="") as f_p:
        rows = list(csv.reader(f_p, quoting=csv.QUOTE_ALL, escapechar='\\'))

    with open(f"fp_fn/i2b2/{target_class}_fp_feedback.txt", "w") as f_p_feedback:
        for row in tqdm.tqdm(rows):
            # Skip blank or truncated records instead of failing on row[1]
            if len(row) < 2:
                continue
            text, triple = row[0], row[1]
            user_content = f"Text: {text}, Triplets: {triple}"
            messages = [
                {"role": "system", "content": system_message},
                {"role": "user", "content": user_content},
            ]
            # The pipeline returns the whole chat; the last message of the
            # last result is the model's reply.
            answer = chatbot(messages)[-1]
            response_text = answer['generated_text'][-1]['content']

            output_fp.append(response_text)
            # Keep each answer on a single line of the feedback file
            cleaned_response_text = response_text.replace("\n", " ")
            f_p_feedback.write(f"{user_content} ---- {cleaned_response_text}\n")
            # Flush per answer so partial results survive an interrupted run
            f_p_feedback.flush()

#### Collect FNs

In [None]:
# Ask the feedback LLM about every collected false negative, one relation
# class at a time, and store the answers in "<class>_fn_feedback.txt".
output_fn = []
for target_class in RE_CLASSES:
    print(target_class)
    # Read from the i2b2 directory, which is where this notebook writes its
    # false-negative files (the path previously pointed to "fp_fn/cardiode/").
    fn_file_path = f"fp_fn/i2b2/{target_class}_fn.txt"

    # Load all CSV records up front so that tqdm reports the number of
    # records; counting physical lines would overcount texts that contain
    # line breaks inside a quoted field.
    with open(fn_file_path, "r", newline="") as f_n:
        rows = list(csv.reader(f_n, quoting=csv.QUOTE_ALL, escapechar='\\'))

    with open(f"fp_fn/i2b2/{target_class}_fn_feedback.txt", "w") as f_n_feedback:
        for row in tqdm.tqdm(rows):
            # Skip blank or truncated records instead of failing on row[1]
            if len(row) < 2:
                continue
            text, triple = row[0], row[1]
            user_content = f"Text: {text}, Triplets: {triple}"
            messages = [
                {"role": "system", "content": system_message},
                {"role": "user", "content": user_content},
            ]
            # The pipeline returns the whole chat; the last message of the
            # last result is the model's reply.
            answer = chatbot(messages)[-1]
            response_text = answer['generated_text'][-1]['content']

            output_fn.append(response_text)
            # Done outside the f-string: a backslash inside an f-string
            # expression is a syntax error before Python 3.12.
            cleaned_response_text = response_text.replace("\n", " ")
            f_n_feedback.write(f"{user_content} ---- {cleaned_response_text}\n")
            # Flush per answer so partial results survive an interrupted run
            f_n_feedback.flush()