In [1]:
# Auto-reload edited local modules (e.g. `module`) before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import json 
import pandas as pd 
from dotenv import load_dotenv
from model_list import models 
import chromadb
import openai
import module 
import os 
from tqdm import tqdm 

In [13]:
# Load API credentials from .env and expose them under the names the
# generation/evaluation cells read (OPENAI_API_KEY, HF_TOKEN).
load_dotenv()

# Fail fast with a readable message instead of the opaque TypeError that
# `os.environ[...] = None` raises when a variable is missing.
_missing_env_vars = [name for name in ('OPENAI_API_KEY_TEAM', 'HF_TOKEN') if not os.getenv(name)]
if _missing_env_vars:
    raise RuntimeError(f"Missing required environment variable(s): {', '.join(_missing_env_vars)} (set them in .env)")

# The team key is the one used for all OpenAI calls.
os.environ['OPENAI_API_KEY'] = os.environ['OPENAI_API_KEY_TEAM']
# HF_TOKEN is already in os.environ (verified above), so it needs no re-assignment.

In [4]:
# Load the CT-Pub benchmark table (one row per trial, keyed by NCTId) and preview it.
data_pub = pd.read_csv('../data/CT-Pub-With-Examples-Corrected.csv')
data_pub.iloc[:3]

Unnamed: 0,NCTId,BriefTitle,EligibilityCriteria,BriefSummary,Conditions,Interventions,PrimaryOutcomes,TrialGroup,API_BaselineMeasures,API_BaselineMeasures_Corrected,Paper_BaselineMeasures,Paper_BaselineMeasures_Corrected
0,NCT00000620,Action to Control Cardiovascular Risk in Diabe...,Inclusion Criteria:\n\n* Diagnosed with type 2...,The purpose of this study is to prevent major ...,"Atherosclerosis, Cardiovascular Diseases, Hype...","Anti-hyperglycemic Agents, Anti-hypertensive A...",First Occurrence of a Major Cardiovascular Eve...,hypertension,"Age, Continuous, Gender, Ethnicity (NIH/OMB), ...","`Age, Continuous`, `Gender`, `Ethnicity (NIH/O...","Age, Female sex, Median duration of diabetes, ...","`Age`, `Female sex`, `Median duration of diabe..."
1,NCT00126737,Home-Based Exercise and Weight Control Program...,Inclusion Criteria:\n\n* Male \& female 50 yea...,The purpose of this study is to determine whet...,"Chronic Diseases, Obesity, Osteoarthritis, Pain,","Weight Control Nutritional Program, Home-based...","WOMAC Function, Physical Scale SF-36v, Mental ...",obesity,"Age, Continuous, Sex: Female, Male, Race/Ethni...","`Age, Continuous`, `Sex: Female, Male`, `Race/...","Age, Duration of OA, Kellgren-Lawrence Classif...","`Age`, `Duration of OA`, `Kellgren-Lawrence Cl..."
2,NCT00283686,HALT Progression of Polycystic Kidney Disease ...,Inclusion Criteria:\n\n* Diagnosis of ADPKD.\n...,The efficacy of interruption of the renin-angi...,"Kidney, Polycystic,","Lisinopril, Telmisartan, Placebo, Standard Blo...",Study A: Percent Annual Change in Total Kidney...,hypertension,"Age, Continuous, Sex: Female, Male, Race (NIH/...","`Age, Continuous`, `Sex: Female, Male`, `Race ...","Age, Weight, Height, BMI, BSA, SBP, DBP, Liver...","`Age`, `Weight`, `Height`, `BMI`, `BSA`, `SBP`..."


In [5]:
# RAG inputs: the per-trial JSON files and the on-disk Chroma vector store.
json_directory = '../data/ctpub_json'
vector_store_path = '../chroma_db'

# Open the persistent ChromaDB client and fetch (or create) the "ctbench_rag"
# collection that the retrieval step queries below.
c_db = chromadb.PersistentClient(path=vector_store_path)
chroma_collection = c_db.get_or_create_collection("ctbench_rag")



In [6]:
def json_file_to_rag_query(file_name, json_dir="../data/ctpub_json"):
    """Build the free-text retrieval query for one trial stored as JSON.

    Reads ``<json_dir>/<file_name>`` and formats its BriefTitle,
    EligibilityCriteria, BriefSummary, Conditions, Interventions and
    PrimaryOutcomes fields into one string. The string is only used to search
    the RAG vector store for similar trials.

    Args:
        file_name: JSON file name, e.g. ``"NCT00126737.json"``.
        json_dir: Directory holding the per-trial JSON files. The default is
            the path that used to be hard-coded here (same as ``json_directory``).

    Returns:
        The formatted query string.

    Raises:
        FileNotFoundError: if the file does not exist.
        KeyError: if one of the six fields above is missing from the JSON.
    """
    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(os.path.join(json_dir, file_name), "r", encoding="utf-8") as file:
        trial_data = json.load(file)
        # NOTE: the exact layout below (indentation, literal "\n" plus real
        # newlines) is part of the query text sent to the retriever; it is kept
        # unchanged so retrieval results stay comparable with earlier runs.
        trial_query = f"""
        BriefTitle: {trial_data['BriefTitle']}\n
        EligibilityCriteria: {trial_data['EligibilityCriteria']}\n 
        BriefSummary: {trial_data['BriefSummary']}\n
        Conditions: {trial_data['Conditions']}\n
        Interventions: {trial_data['Interventions']}\n
        PrimaryOutcomes: {trial_data['PrimaryOutcomes']}
        """
        return trial_query

# GPT-4o Generation with RAG (3-shot)

In [None]:
number_of_similar_trials = 3
model_name = models['gpt-4o']

# Trials used as the fixed three-shot examples; no generation is run for them.
example_trials_to_avoid = {'NCT00000620', 'NCT01483560', 'NCT04280783'}

data_pub['gpt4o_rag_ts_gen'] = None

for index, row in tqdm(data_pub.iterrows(), total=len(data_pub)):
    if row['NCTId'] in example_trials_to_avoid:
        continue

    # Query string used only for searching the RAG database.
    query_trial = json_file_to_rag_query(row['NCTId'] + '.json')

    # Retrieve the similar trials that become the in-context examples of the prompt.
    similar_trials = module.query_clinical_trials_using_llamaindex(
        chroma_collection, vector_store_path, json_directory,
        query_trial, top_k=number_of_similar_trials)

    system_message, question = module.build_three_shot_prompt(
        row, similar_trials, json_directory,
        ref_col_name='Paper_BaselineMeasures_Corrected')

    model_query = module.system_user_template(system_message, question)
    model_response = module.run_generation_single_openai(
        model_query, model_name=model_name,
        openai_token=os.environ["OPENAI_API_KEY"], temperature=0.0)

    data_pub.at[index, 'gpt4o_rag_ts_gen'] = model_response

# Llama3 Generation with RAG (3-shot) 

In [19]:
number_of_similar_trials = 3
model_hf_endpoint = models['llama3-70b-it']

# Trials used as the fixed three-shot examples; no generation is run for them.
example_trials_to_avoid = {'NCT00000620', 'NCT01483560', 'NCT04280783'}

data_pub['llama3_70b_it_rag_ts_gen'] = None

for index, row in tqdm(data_pub.iterrows(), total=len(data_pub)):
    if row['NCTId'] in example_trials_to_avoid:
        continue

    # Query string used only for searching the RAG database.
    query_trial = json_file_to_rag_query(row['NCTId'] + '.json')

    try:
        # Retrieve the similar trials that become the in-context examples of the prompt.
        similar_trials = module.query_clinical_trials_using_llamaindex(
            chroma_collection, vector_store_path, json_directory,
            query_trial, top_k=number_of_similar_trials)

        system_message, question = module.build_three_shot_prompt(
            row, similar_trials, json_directory,
            ref_col_name='Paper_BaselineMeasures_Corrected')

        model_query = module.system_user_template(system_message, question)
        model_response = module.run_generation_single_hf_models(
            model_query, model_hf_endpoint,
            os.environ['HF_TOKEN'], temperature=0.0)

        data_pub.at[index, 'llama3_70b_it_rag_ts_gen'] = model_response

    # Known failure: the HF endpoint rejects a request (HTTP 422) when the prompt's
    # input tokens plus the 1000 max_new_tokens exceed 8192, i.e. a context-length
    # limit rather than a bug. Such trials are logged and left as None.
    except Exception as e:
        print(f"An error occurred while processing trial {row['NCTId']}: {e}")

44it [00:21,  2.00it/s]

An error occurred while processing trial NCT01767155: Error code: 422 - {'error': 'Input validation error: `inputs` tokens + `max_new_tokens` must be <= 8192. Given: 7573 `inputs` tokens and 1000 `max_new_tokens`', 'error_type': 'validation'}


70it [00:46,  1.58s/it]

An error occurred while processing trial NCT02643966: Error code: 422 - {'error': 'Input validation error: `inputs` tokens + `max_new_tokens` must be <= 8192. Given: 7236 `inputs` tokens and 1000 `max_new_tokens`', 'error_type': 'validation'}


94it [01:51,  2.42s/it]

An error occurred while processing trial NCT03394027: Error code: 422 - {'error': 'Input validation error: `inputs` tokens + `max_new_tokens` must be <= 8192. Given: 7582 `inputs` tokens and 1000 `max_new_tokens`', 'error_type': 'validation'}


96it [01:53,  1.69s/it]

An error occurred while processing trial NCT03674112: Error code: 422 - {'error': 'Input validation error: `inputs` tokens + `max_new_tokens` must be <= 8192. Given: 7546 `inputs` tokens and 1000 `max_new_tokens`', 'error_type': 'validation'}


103it [02:08,  1.25s/it]


In [20]:
# Preview the first three rows, now including both generation columns.
data_pub.iloc[:3]

Unnamed: 0,NCTId,BriefTitle,EligibilityCriteria,BriefSummary,Conditions,Interventions,PrimaryOutcomes,TrialGroup,API_BaselineMeasures,API_BaselineMeasures_Corrected,Paper_BaselineMeasures,Paper_BaselineMeasures_Corrected,gpt4o_rag_ts_gen,llama3_70b_it_rag_ts_gen
0,NCT00000620,Action to Control Cardiovascular Risk in Diabe...,Inclusion Criteria:\n\n* Diagnosed with type 2...,The purpose of this study is to prevent major ...,"Atherosclerosis, Cardiovascular Diseases, Hype...","Anti-hyperglycemic Agents, Anti-hypertensive A...",First Occurrence of a Major Cardiovascular Eve...,hypertension,"Age, Continuous, Gender, Ethnicity (NIH/OMB), ...","`Age, Continuous`, `Gender`, `Ethnicity (NIH/O...","Age, Female sex, Median duration of diabetes, ...","`Age`, `Female sex`, `Median duration of diabe...",,
1,NCT00126737,Home-Based Exercise and Weight Control Program...,Inclusion Criteria:\n\n* Male \& female 50 yea...,The purpose of this study is to determine whet...,"Chronic Diseases, Obesity, Osteoarthritis, Pain,","Weight Control Nutritional Program, Home-based...","WOMAC Function, Physical Scale SF-36v, Mental ...",obesity,"Age, Continuous, Sex: Female, Male, Race/Ethni...","`Age, Continuous`, `Sex: Female, Male`, `Race/...","Age, Duration of OA, Kellgren-Lawrence Classif...","`Age`, `Duration of OA`, `Kellgren-Lawrence Cl...","`Age`, `Sex`, `Race`, `BMI`, `Kellgren and Law...","`Age`, `Sex`, `Race`, `Education`, `Employment..."
2,NCT00283686,HALT Progression of Polycystic Kidney Disease ...,Inclusion Criteria:\n\n* Diagnosis of ADPKD.\n...,The efficacy of interruption of the renin-angi...,"Kidney, Polycystic,","Lisinopril, Telmisartan, Placebo, Standard Blo...",Study A: Percent Annual Change in Total Kidney...,hypertension,"Age, Continuous, Sex: Female, Male, Race (NIH/...","`Age, Continuous`, `Sex: Female, Male`, `Race ...","Age, Weight, Height, BMI, BSA, SBP, DBP, Liver...","`Age`, `Weight`, `Height`, `BMI`, `BSA`, `SBP`...","`Age`, `Sex`, `Race`, `Ethnicity`, `GFR catego...","`Age`, `Sex`, `Race`, `Ethnicity`, `Body Mass ..."


In [21]:
# Persist the generations. The hidden_data directory may not exist yet (e.g. if it
# is not checked in), so create it rather than letting to_csv fail.
gen_output_path = '../data/hidden_data/CT-Pub-With-Examples-Corrected-all-rag-gen.csv'
os.makedirs(os.path.dirname(gen_output_path), exist_ok=True)
data_pub.to_csv(gen_output_path, index=False)

## GPT-4o Evaluation

In [23]:
# Evaluation results table: one row per trial, sharing data_pub's index.
eval_data_pub = pd.DataFrame({'NCTId': data_pub['NCTId']})

In [27]:
# Re-display the first rows of data_pub before running the evaluation.
data_pub.iloc[:3]

Unnamed: 0,NCTId,BriefTitle,EligibilityCriteria,BriefSummary,Conditions,Interventions,PrimaryOutcomes,TrialGroup,API_BaselineMeasures,API_BaselineMeasures_Corrected,Paper_BaselineMeasures,Paper_BaselineMeasures_Corrected,gpt4o_rag_ts_gen,llama3_70b_it_rag_ts_gen
0,NCT00000620,Action to Control Cardiovascular Risk in Diabe...,Inclusion Criteria:\n\n* Diagnosed with type 2...,The purpose of this study is to prevent major ...,"Atherosclerosis, Cardiovascular Diseases, Hype...","Anti-hyperglycemic Agents, Anti-hypertensive A...",First Occurrence of a Major Cardiovascular Eve...,hypertension,"Age, Continuous, Gender, Ethnicity (NIH/OMB), ...","`Age, Continuous`, `Gender`, `Ethnicity (NIH/O...","Age, Female sex, Median duration of diabetes, ...","`Age`, `Female sex`, `Median duration of diabe...",,
1,NCT00126737,Home-Based Exercise and Weight Control Program...,Inclusion Criteria:\n\n* Male \& female 50 yea...,The purpose of this study is to determine whet...,"Chronic Diseases, Obesity, Osteoarthritis, Pain,","Weight Control Nutritional Program, Home-based...","WOMAC Function, Physical Scale SF-36v, Mental ...",obesity,"Age, Continuous, Sex: Female, Male, Race/Ethni...","`Age, Continuous`, `Sex: Female, Male`, `Race/...","Age, Duration of OA, Kellgren-Lawrence Classif...","`Age`, `Duration of OA`, `Kellgren-Lawrence Cl...","`Age`, `Sex`, `Race`, `BMI`, `Kellgren and Law...","`Age`, `Sex`, `Race`, `Education`, `Employment..."
2,NCT00283686,HALT Progression of Polycystic Kidney Disease ...,Inclusion Criteria:\n\n* Diagnosis of ADPKD.\n...,The efficacy of interruption of the renin-angi...,"Kidney, Polycystic,","Lisinopril, Telmisartan, Placebo, Standard Blo...",Study A: Percent Annual Change in Total Kidney...,hypertension,"Age, Continuous, Sex: Female, Male, Race (NIH/...","`Age, Continuous`, `Sex: Female, Male`, `Race ...","Age, Weight, Height, BMI, BSA, SBP, DBP, Liver...","`Age`, `Weight`, `Height`, `BMI`, `BSA`, `SBP`...","`Age`, `Sex`, `Race`, `Ethnicity`, `GFR catego...","`Age`, `Sex`, `Race`, `Ethnicity`, `Body Mass ..."


In [30]:
# Trials used as the fixed three-shot examples are excluded from evaluation.
example_trials_to_avoid = {'NCT00000620', 'NCT01483560', 'NCT04280783'}

# (generation column in data_pub, column in eval_data_pub receiving the evaluator's response)
generation_to_match_cols = [
    ('gpt4o_rag_ts_gen', 'gpt4o_ts_gen_matches'),
    ('llama3_70b_it_rag_ts_gen', 'llama3_70b_it_ts_gen_matches'),
]

for index, row in tqdm(data_pub.iterrows(), total=len(data_pub)):
    if row['NCTId'] in example_trials_to_avoid:
        continue

    # Evaluate a trial only if BOTH models produced output, so both are scored on the
    # same trials. pd.isna (unlike `is None`) also catches NaN, which is what missing
    # generations become if data_pub is reloaded from the saved CSV.
    if any(pd.isna(row[gen_col]) for gen_col, _ in generation_to_match_cols):
        continue

    qstart = module.get_question_from_row(row)
    reference_list = module.extract_elements_v2(row['Paper_BaselineMeasures_Corrected'])

    # Collect both evaluator responses first, so a row is written only after both calls return.
    responses = {}
    for gen_col, match_col in generation_to_match_cols:
        candidate_list = module.extract_elements_v2(row[gen_col])
        system_message, question = module.build_gpt4_eval_prompt(reference_list,
                                                                 candidate_list,
                                                                 qstart)
        responses[match_col] = module.run_evaluation_with_gpt4o(system_message, question,
                                                                os.environ["OPENAI_API_KEY"])

    # Store each response exactly as returned (no conversion happens here).
    for match_col, response in responses.items():
        eval_data_pub.at[index, match_col] = response
    


103it [10:12,  5.95s/it]


In [32]:
# Persist the evaluator responses. The hidden_data directory may not exist yet
# (e.g. if it is not checked in), so create it rather than letting to_csv fail.
eval_output_path = '../data/hidden_data/CT-Pub-With-Examples-Corrected-all-rag-eval.csv'
os.makedirs(os.path.dirname(eval_output_path), exist_ok=True)
eval_data_pub.to_csv(eval_output_path, index=False)

In [33]:
# Preview the first three evaluation rows.
eval_data_pub.iloc[:3]

Unnamed: 0,NCTId,gpt4o_ts_gen_matches,llama3_70b_it_ts_gen_matches
0,NCT00000620,,
1,NCT00126737,"{\n ""matched_features"": [\n [""Age"", ...","{\n ""matched_features"": [\n [""Age"", ..."
2,NCT00283686,"{\n ""matched_features"": [\n [""Age"", ...","{\n ""matched_features"": [\n [""Age"", ..."


In [37]:
# Print one raw evaluator response (row label 2) in full to inspect its structure.
print(eval_data_pub.loc[2, 'gpt4o_ts_gen_matches'])

{
    "matched_features": [
        ["Age", "Age"],
        ["SBP", "Blood Pressure"],
        ["DBP", "Blood Pressure"],
        ["eGFR", "GFR category"],
        ["serum albumin", "Albumin-to-Creatinine Ratio (ACR)"],
        ["serum potassium", "Serum potassium level"]
    ],
    "remaining_reference_features": [
        "Weight",
        "Height",
        "BMI",
        "BSA",
        "Liver cysts",
        "Liver volume",
        "height adjusted liver volume",
        "liver cyst volume",
        "height adjusted liver cyst volume",
        "liver parenchymal volume",
        "height adjusted liver parnehcymal volume",
        "total kidney volume",
        "height adjusted total kidney volume",
        "spleen volume",
        "height adjusted spleen volume",
        "serum sodium",
        "hemoglobin",
        "WBC",
        "platelets",
        "AST",
        "ALT",
        "alkaline phosphatase",
        "bilirubin",
        "physical functioning QOL",
        "physical role

## Score Calculation

In [35]:
# Per-trial score table: one row per trial, sharing data_pub's index.
score_df = pd.DataFrame({'NCTId': data_pub['NCTId']})

In [None]:
gts_sum = {'precision': 0, 'recall': 0, 'f1': 0}
lts_sum = {'precision': 0, 'recall': 0, 'f1': 0}

total_count = 0

# These trials were the fixed examples for 3-shot generation, so they are not scored.
avoid_ids = {'NCT00000620', 'NCT01483560', 'NCT04280783'}


def _score_from_matches_json(raw_matches):
    """Parse one evaluator response (a JSON string) and score it.

    The JSON must contain ``matched_features``, ``remaining_reference_features``
    and ``remaining_candidate_features``. Returns whatever
    ``module.match_to_score`` returns (indexed below by ``precision``, ``recall``
    and ``f1``).
    """
    parsed = json.loads(raw_matches)
    return module.match_to_score(parsed['matched_features'],
                                 parsed['remaining_reference_features'],
                                 parsed['remaining_candidate_features'])


for index, row in eval_data_pub.iterrows():
    if row['NCTId'] in avoid_ids:
        continue

    raw_gts = row['gpt4o_ts_gen_matches']
    raw_lts = row['llama3_70b_it_ts_gen_matches']
    # Trials skipped during evaluation hold NaN here. NaN is truthy, so a plain
    # `not value` check would let them through; pd.isna catches NaN and None.
    if pd.isna(raw_gts) or pd.isna(raw_lts) or not raw_gts or not raw_lts:
        continue

    try:
        # Score BOTH models before touching any totals. Otherwise a failure on the
        # second model would leave the first model's sums updated without
        # incrementing total_count, biasing its average.
        gts_score = _score_from_matches_json(raw_gts)
        lts_score = _score_from_matches_json(raw_lts)
        gts_score_json = json.dumps(gts_score)
        lts_score_json = json.dumps(lts_score)
    except Exception as e:
        print(f"An error occurred while processing trial {row['NCTId']}: {e}")
        continue

    score_df.at[index, 'gpt4o_ts_gen_scores'] = gts_score_json
    score_df.at[index, 'llama3_70b_it_ts_gen_scores'] = lts_score_json
    for metric in ('precision', 'recall', 'f1'):
        gts_sum[metric] += gts_score[metric]
        lts_sum[metric] += lts_score[metric]
    total_count += 1



In [46]:
# Preview the first five per-trial score rows.
score_df.iloc[:5]

Unnamed: 0,NCTId,gpt4o_ts_gen_scores,llama3_70b_it_ts_gen_scores
0,NCT00000620,,
1,NCT00126737,"{""precision"": 0.29411764705882354, ""recall"": 0...","{""precision"": 0.3333333333333333, ""recall"": 0...."
2,NCT00283686,"{""precision"": 0.35294117647058826, ""recall"": 0...","{""precision"": 0.2, ""recall"": 0.078947368421052..."
3,NCT00329030,"{""precision"": 0.4166666666666667, ""recall"": 0....","{""precision"": 0.3076923076923077, ""recall"": 0...."
4,NCT00360334,"{""precision"": 0.4444444444444444, ""recall"": 0....","{""precision"": 0.7272727272727273, ""recall"": 0...."


In [47]:
if total_count == 0:
    # Avoid a ZeroDivisionError when no trial could be scored.
    print("No trials were scored, so there are no averages to report.")
else:
    # Macro-average: mean of the per-trial precision / recall / F1 over the scored trials.
    gts_avg = {metric: total / total_count for metric, total in gts_sum.items()}
    lts_avg = {metric: total / total_count for metric, total in lts_sum.items()}
    print(f"GPT-4o Three Shot Scores: Precision={gts_avg['precision']} Recall={gts_avg['recall']} F1={gts_avg['f1']}")
    print(f"Llama-3 70B Instruct Three Shot Scores: Precision={lts_avg['precision']} Recall={lts_avg['recall']} F1={lts_avg['f1']}")


GPT-4o Three Shot Scores: Precision=0.44889296382376 Recall=0.58177675886661 F1=0.480418600379706
Llama-3 70B Instruct Three Shot Scores: Precision=0.4875790637155919 Recall=0.540484539473257 F1=0.483547904868379
