In [1]:
import sys, os, pandas as pd, ast, json
from glob import glob
import warnings
warnings.filterwarnings('ignore')


In [2]:
# Move the working directory two levels up; the relative paths used by later cells
# are resolved from there (presumably the project root — TODO confirm).
# NOTE(review): a relative chdir is not idempotent — re-running this cell climbs two more levels.
os.chdir('../..')

In [3]:
# ================== GENERAL IMPORTS ==================
import os
import json
from dotenv import load_dotenv

# ================== UTIL FUNCTIONS ==================
from utils.embedding import get_context_db, retrieve_context
from utils.prompt import get_prompt
from llm.run_RAGLLM import run_RAG


# ================== MODEL & API IMPORTS ==================
from mistralai.client import MistralClient
from openai import OpenAI
from llm.inference import run_llm
import faiss


----


In [60]:
# Gold answers: map each test patient ID to the set of lower-cased first-line treatments.
df = pd.read_csv('external-validation/panel-sequencing/reports/filtered-data/first_line_treatments_post_msk.csv')
answers = {}
for f in glob('external-validation/panel-sequencing/reports/test/*.txt'):
    # The report file name without its extension is the patient ID
    patient = os.path.splitext(os.path.basename(f))[0]
    # FIRST_LINE_TREATMENT is stored as the text of a Python literal, so parse it before lower-casing
    treatment = {i.lower() for i in ast.literal_eval(df[df['PATIENT_ID'] == patient]['FIRST_LINE_TREATMENT'].iloc[0])}
    answers[patient] = treatment
# Spot-check: show the gold answer of the last patient processed
answers[patient]

{'regorafenib'}

----
### functions

In [100]:
#utils/evaluation.py
import sys
import os
# The notebook's current working directory doubles as the root for data files.
script_dir = os.getcwd()
root_dir = script_dir
# Make the parent of the working directory importable as well.
sys.path.append(os.path.dirname(os.path.abspath(script_dir)))

from sklearn.metrics import precision_score, recall_score, f1_score
from utils.io import load_object
import pandas as pd
import re

In [101]:
# Load the drug-name mapping dictionary from the data directory under root_dir.
# Presumably keys are brand names and values are lists of generic-name collections — verify against the pickle.
drug_names_mapping_dict = load_object(
    filename=os.path.join(root_dir, 'data/fda_drug_names_mapping_dict.pkl')
)

# Collect every mapped name collection once, as a hashable frozenset
all_generic_names_set = {
    frozenset(name)
    for lst in drug_names_mapping_dict.values()
    for name in lst
}

# Function to extract drug names
def extract_drug_names(drug_lines):
    """
    Parse drug names out of `"drug name": ...` / `"drug names": ...` lines.

    The value after the key may be a bracketed list, a quoted string or bare text.
    Text in parentheses is dropped, and the remainder is split on `+` or `,`.

    Arguments:
        drug_lines (list): Lines of text to scan; the key is matched in lower case only.

    Returns:
        list: One set of cleaned drug names per line that yields at least one name.
            Lines without the key, or whose value is empty (e.g. `[]` or `""`), are skipped.
    """
    pattern = r'"drug names?":\s?(?:\[(.*?)\]|"(.*?)"|([^"\[\]]+))'
    cleaned_drug_list = []
    for line in drug_lines:
        match = re.search(pattern, line)
        if not match:
            continue
        # First non-empty capturing group; None when the value is empty (`[]` or `""`),
        # which previously raised StopIteration
        drug_string = next(filter(None, match.groups()), None)
        if drug_string is None:
            continue
        drug_string = re.sub(r'\s?\([^)]*\)', '', drug_string) # Remove any text within parentheses
        drugs = set()
        for drug in re.split(r'\s*\+\s*|\s*,\s*', drug_string): # Split by `+` or `,`
            # Clean spaces, quotes and the registered-trademark sign
            name = drug.strip().strip('"').replace('®', '').strip()
            if name: # Trailing separators produce empty fragments; they are not drugs
                drugs.add(name)
        if drugs:
            cleaned_drug_list.append(drugs)
    return(cleaned_drug_list)

In [102]:
# Function to check if the model explicitly says there are no FDA-approved drugs.
def is_no_drug_output(output: str) -> bool:
    """
    Tell whether the text explicitly states that there is no drug to recommend.

    Phrases are matched case-insensitively and on word boundaries, so "none" as a
    standalone word counts while a drug name that merely contains those letters
    (e.g. "finerenone") does not.

    Arguments:
        output (str): Text to inspect.

    Returns:
        bool: True if any of the no-drug phrases occurs in the text.
    """
    no_drug_phrases = ["no fda-approved drugs", "none", "no approved therapies", "no therapies available"]
    lowered = output.lower()
    return any(
        re.search(r'\b' + re.escape(phrase) + r'\b', lowered) is not None
        for phrase in no_drug_phrases
    )

# Function to evaluate predicted drugs from synthetic queries
def calc_eval_metrics(
    output_test_ls: list[str], 
    query_ls: list[str], 
    prompt_groundtruth_dict: dict[str, list]
    ) -> dict:
    """
    Evaluate the drugs predicted by an LLM against the ground-truth drugs.

    Relies on the module-level `drug_names_mapping_dict`, `all_generic_names_set`,
    `extract_drug_names` and `is_no_drug_output`.

    Arguments: 
        output_test_ls (list): List of full outputs generated by the LLM.
        query_ls (list): List of input prompts; query_ls[i] is looked up for output_test_ls[i].
        prompt_groundtruth_dict (dict): Dictionary of input prompts and matching ground-truth drugs
            (each value is an iterable of drug-name collections, empty when there is no therapy).

    Returns:
        dict: Per-query lists and their averages.
            - 'exact_match_acc': True when every ground-truth entry is among the predictions
              (or, without ground truth, when nothing was predicted).
            - 'partial_match_acc': True when at least one ground-truth entry is predicted
              (or, without ground truth, when nothing was predicted).
            - 'precision_ls', 'recall_ls', 'f1_ls', 'specificity_ls': per-query scores,
              None for queries without ground truth.
            - 'avg_*': mean of the corresponding list, ignoring None entries; None when
              there is nothing to average (e.g. empty input).
            - 'pred_drugs_generic_set_ls', 'true_drugs_generic_set_ls': the sets of
              frozensets that were compared for each query.
    """
    def _mean_or_none(values):
        """Mean of the non-None entries, or None when there is nothing to average."""
        kept = [x for x in values if x is not None]
        return sum(kept) / len(kept) if kept else None

    exact_match_acc, partial_match_acc = [], []
    precision_ls, recall_ls, f1_ls, specificity_ls = [], [], [], []
    pred_drugs_generic_set_ls = []
    true_drugs_generic_set_ls = []
    
    for i, output in enumerate(output_test_ls):
        
        # Extract drug name lines (the filter is case-sensitive; kept lines are lower-cased for parsing)
        drug_lines = [line.lower() for line in output.split("\n") if "Drug Name" in line]
        
        # If the model explicitly returned "none" or "no fda-approved drugs", treat it as empty
        if any(is_no_drug_output(line) for line in drug_lines):
            pred_drugs_names_ls = []
        else:
            pred_drugs_names_ls = extract_drug_names(drug_lines) # One set of drug names per parsed line
        
        # Convert brand names to generic names and collect the result as unique frozensets
        pred_drugs_generic_set = set()
        for subset in pred_drugs_names_ls:
            if all(drug not in drug_names_mapping_dict for drug in subset):
                # No name in this set is a key of the mapping dict: keep the set as it is
                pred_drugs_generic_set.add(frozenset(subset))
                continue
            for drug in subset:
                if drug in drug_names_mapping_dict:
                    # A mapped name contributes each of its mapped generic-name collections
                    pred_drugs_generic_set.update(
                        frozenset(generic) for generic in drug_names_mapping_dict[drug]
                    )
                else:
                    pred_drugs_generic_set.add(frozenset({drug}))
        pred_drugs_generic_set_ls.append(pred_drugs_generic_set)
        
        # Convert the list of ground-truth generic names to unique frozensets
        true_drugs_generic_set = {frozenset(_set) for _set in prompt_groundtruth_dict[query_ls[i]]}
        true_drugs_generic_set_ls.append(true_drugs_generic_set)
        
        # Cases with ground-truth therapies
        if len(true_drugs_generic_set) != 0:
            
            # Compute exact match accuracy (if all true drugs are in the predicted drug output)
            exact_match_acc.append(all(subset in pred_drugs_generic_set for subset in true_drugs_generic_set))
            
            # Compute partial match accuracy (if one or more true drugs are in the predicted drug output)
            partial_match_acc.append(len(pred_drugs_generic_set & true_drugs_generic_set) > 0)
            
            # All possible FDA-approved drugs
            all_drugs_set = true_drugs_generic_set | pred_drugs_generic_set | all_generic_names_set
            
            # Calculate true positive, false positive, false negative, true negative
            tp = len(pred_drugs_generic_set & true_drugs_generic_set)
            fp = len(pred_drugs_generic_set - true_drugs_generic_set)
            fn = len(true_drugs_generic_set - pred_drugs_generic_set)
            tn = len(all_drugs_set - true_drugs_generic_set - pred_drugs_generic_set)
            
            # Calculate precision, recall, F1 and specificity (0 when the denominator is 0)
            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
            f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
            specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
            precision_ls.append(precision)
            recall_ls.append(recall)
            f1_ls.append(f1)
            specificity_ls.append(specificity)
        
        # Cases with no ground-truth therapies (no FDA-approved drugs available)
        else:
            # Correct only when the model predicted nothing either
            nothing_predicted = not pred_drugs_generic_set
            partial_match_acc.append(nothing_predicted)
            exact_match_acc.append(nothing_predicted)
            precision_ls.append(None)
            recall_ls.append(None)
            f1_ls.append(None)
            specificity_ls.append(None)

    # Averages skip the None placeholders and fall back to None instead of dividing by zero
    result={
        'avg_exact_match_acc':_mean_or_none(exact_match_acc),
        'avg_partial_match_acc':_mean_or_none(partial_match_acc),
        'avg_precision':_mean_or_none(precision_ls),
        'avg_recall':_mean_or_none(recall_ls),
        'avg_f1':_mean_or_none(f1_ls),
        'avg_specificity':_mean_or_none(specificity_ls),
        'exact_match_acc':exact_match_acc,
        'partial_match_acc':partial_match_acc,
        'precision_ls':precision_ls,
        'recall_ls':recall_ls,
        'f1_ls':f1_ls,
        'specificity_ls':specificity_ls,
        'pred_drugs_generic_set_ls':pred_drugs_generic_set_ls,
        'true_drugs_generic_set_ls':true_drugs_generic_set_ls
    }
    
    return(result)

---
### analysis

In [106]:
def _read_json_field(file_path, key):
    """Return (path, value-of-key) pairs for every JSON file matching the glob pattern, sorted by path."""
    pairs = []
    # glob yields matches in arbitrary order; sorting keeps outputs and prompts aligned by index
    for path in sorted(glob(file_path)):
        with open(path, 'r', encoding='utf-8') as file:
            pairs.append((path, json.load(file)[key]))
    return pairs

def read_output(file_path):
    """
    Read the 'response' field of every JSON file matching the glob pattern.

    Returns:
        list: Responses ordered by file path, index-aligned with read_prompt for the same pattern.
    """
    return [response for _, response in _read_json_field(file_path, 'response')]

def read_prompt(file_path):
    """
    Read the 'prompt' field of every JSON file matching the glob pattern.

    Returns:
        list: Prompts ordered by file path, index-aligned with read_output for the same pattern.
    """
    return [prompt for _, prompt in _read_json_field(file_path, 'prompt')]

def read_truth(file_path):
    """
    Map each file's prompt to the gold answer of the patient the file is named after.

    The patient ID is the file name without its extension and is looked up in the
    module-level `answers` dictionary; a missing patient raises KeyError.

    Returns:
        dict: prompt -> single-element list holding that patient's gold answer.
    """
    truth = {}
    for path, prompt in _read_json_field(file_path, 'prompt'):
        patient = os.path.splitext(os.path.basename(path))[0]
        truth[prompt] = [answers[patient]]
    return truth

In [110]:
# Evaluate each experiment directory and print its average partial / exact match accuracy
for file_path in ['external-validation/panel-sequencing/experiments/basic-llm-prompt0/*',
                  'external-validation/panel-sequencing/experiments/basic-llm-prompt5/*',
                  'external-validation/panel-sequencing/experiments/basic-rag-prompt0/*',
                  'external-validation/panel-sequencing/experiments/basic-rag-prompt5/*']:
    # Named eval_metrics rather than `eval` so the builtin is not shadowed for the rest of the session
    eval_metrics = calc_eval_metrics(read_output(file_path), read_prompt(file_path), read_truth(file_path))
    print(file_path.split('/')[-2]) # Experiment name = directory holding the result files
    print('Partial Match:', eval_metrics['avg_partial_match_acc'])
    print('Exact Match:', eval_metrics['avg_exact_match_acc'])
    print('=========')

basic-llm-prompt0
Partial Match: 0.12244897959183673
Exact Match: 0.12244897959183673
basic-llm-prompt5
Partial Match: 0.02040816326530612
Exact Match: 0.02040816326530612
basic-rag-prompt0
Partial Match: 0.02040816326530612
Exact Match: 0.02040816326530612
basic-rag-prompt5
Partial Match: 0.04081632653061224
Exact Match: 0.04081632653061224
