## We are running CT_Pub data using ct_repo indexed stores

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json 
import pandas as pd 
from dotenv import load_dotenv
from model_list import models 
import chromadb
import openai
import module 
import os 
from tqdm import tqdm 

In [3]:
load_dotenv()
os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY_TEAM')
os.environ['HF_TOKEN'] = os.getenv('HF_TOKEN')

In [4]:
data = pd.read_csv('../data/CT-Repo-With-Examples-Corrected-500-Sample.csv')
data.head(3)

Unnamed: 0,NCTId,BriefTitle,EligibilityCriteria,BriefSummary,Conditions,Interventions,PrimaryOutcomes,TrialGroup,API_BaselineMeasures,API_BaselineMeasures_Corrected
0,NCT01743001,Clinical Study to Evaluate the Effects of Maci...,Inclusion Criteria:\n\n* Subjects:\n\n * not ...,"Clinical study to assess the efficacy, safety,...","Pulmonary Arterial Hypertension,","Macitentan 10 mg, Placebo,",Change From Baseline to Week 16 in Exercise Ca...,hypertension,"Age, Continuous, Age, Customized, Sex: Female,...","`Age, Continuous`, `Age, Customized`, `Sex: Fe..."
1,NCT03439189,Companion Protocol for Methacetin Breath Test ...,Inclusion Criteria:\n\n1. Male or female subje...,To validate the ability of the Methacetin Brea...,"NASH - Nonalcoholic Steatohepatitis, Cirrhosis...","Methacetin Breath Test, Emricasan, Placebo ora...",Number of Subjects With Matched Clinically Sig...,hypertension,"Age, Continuous, Sex: Female, Male, Ethnicity ...","`Age, Continuous`, `Sex: Female, Male`, `Ethni..."
2,NCT00458003,Phenylephrine in Spinal Anesthesia in Preeclam...,Inclusion Criteria:\n\n* ASA PS II - III women...,Hypotension remains a common clinical problem ...,"Preeclampsia, Hypotension,","Ephedrine, Phenylephrine,","The Umbilical Artery pH,",hypertension,"Age, Categorical, Sex: Female, Male, Region of...","`Age, Categorical`, `Sex: Female, Male`, `Regi..."


In [5]:
#main_path = '/Users/nafisneehal/Desktop/CTBench_RAG/'

# Initialize ChromaDB
# Directory containing JSON files
input_json_directory = '../data/ctrepo_500_sample_json' #source of input json files
rag_json_directory = '../data/ctrepo_json' #source of example json files
vector_store_path = '../chroma_db_ctrepo' #search the ctrepo vector store for examples 

# #initialize client 
c_db = chromadb.PersistentClient(path=vector_store_path)
#get collection
chroma_collection = c_db.get_or_create_collection("ctrepo_all") #located in df_to_json_and_indexing.ipynb file 



In [6]:
def json_file_to_rag_query(file_name, json_directory):
    with open(f"{json_directory}/{file_name}", "r") as file:
        trial_data = json.load(file)
        trial_query = f"""
        BriefTitle: {trial_data['BriefTitle']}\n
        EligibilityCriteria: {trial_data['EligibilityCriteria']}\n 
        BriefSummary: {trial_data['BriefSummary']}\n
        Conditions: {trial_data['Conditions']}\n
        Interventions: {trial_data['Interventions']}\n
        PrimaryOutcomes: {trial_data['PrimaryOutcomes']}
        """
        return trial_query

# GPT-4o Generation with RAG (3 shot) 

In [7]:
number_of_similar_trials = 3
model_name   = models['gpt-4o']

data['gpt4o_rag_ts_gen'] = None 

for index, row in tqdm(data.iterrows()):

    #print(row['NCTId'])

    #AVOID these trials that were used as three-shot examples only
    example_trials_to_avoid = ['NCT00000620', 'NCT01483560', 'NCT04280783']
    if row['NCTId'] in example_trials_to_avoid:
        continue

    file_name = row['NCTId'] + '.json'
    query_trial = json_file_to_rag_query(file_name, input_json_directory) #string query for searching rag database only

    #run rag query and find the filenames of similar trials
    similar_trials = module.query_clinical_trials_using_llamaindex(chroma_collection, vector_store_path,  
                                                                   query_trial, top_k=number_of_similar_trials)
    
    #print(similar_trials)

    #get system message, question 
    system_message, question = module.build_three_shot_prompt(row, similar_trials, rag_json_directory, 
                                                              ref_col_name='API_BaselineMeasures_Corrected') #asking ct_repo examples 

    model_query = module.system_user_template(system_message, question)
    model_response = module.run_generation_single_openai(model_query, model_name = model_name, 
                                                openai_token = os.environ["OPENAI_API_KEY"], temperature=0.0)

    
    data.at[index, 'gpt4o_rag_ts_gen'] = model_response
    
    # print 
    # print(system_message)
    # print(question)
    # print(model_response)
    # break

500it [29:27,  3.54s/it]


In [8]:
data.head(3)

Unnamed: 0,NCTId,BriefTitle,EligibilityCriteria,BriefSummary,Conditions,Interventions,PrimaryOutcomes,TrialGroup,API_BaselineMeasures,API_BaselineMeasures_Corrected,gpt4o_rag_ts_gen
0,NCT01743001,Clinical Study to Evaluate the Effects of Maci...,Inclusion Criteria:\n\n* Subjects:\n\n * not ...,"Clinical study to assess the efficacy, safety,...","Pulmonary Arterial Hypertension,","Macitentan 10 mg, Placebo,",Change From Baseline to Week 16 in Exercise Ca...,hypertension,"Age, Continuous, Age, Customized, Sex: Female,...","`Age, Continuous`, `Age, Customized`, `Sex: Fe...","`Age, Continuous`, `Age, Categorical`, `Sex: F..."
1,NCT03439189,Companion Protocol for Methacetin Breath Test ...,Inclusion Criteria:\n\n1. Male or female subje...,To validate the ability of the Methacetin Brea...,"NASH - Nonalcoholic Steatohepatitis, Cirrhosis...","Methacetin Breath Test, Emricasan, Placebo ora...",Number of Subjects With Matched Clinically Sig...,hypertension,"Age, Continuous, Sex: Female, Male, Ethnicity ...","`Age, Continuous`, `Sex: Female, Male`, `Ethni...","`Age, Continuous`, `Sex: Female, Male`, `Ethni..."
2,NCT00458003,Phenylephrine in Spinal Anesthesia in Preeclam...,Inclusion Criteria:\n\n* ASA PS II - III women...,Hypotension remains a common clinical problem ...,"Preeclampsia, Hypotension,","Ephedrine, Phenylephrine,","The Umbilical Artery pH,",hypertension,"Age, Categorical, Sex: Female, Male, Region of...","`Age, Categorical`, `Sex: Female, Male`, `Regi...","`Age, Continuous`, `Sex: Female, Male`, `Race/..."


# Llama3 Generation with RAG (3-shot) 

In [9]:
number_of_similar_trials = 3
model_name   = models['gpt-4o']
model_hf_endpoint = models['llama3-70b-it']

data['llama3_70b_it_rag_ts_gen'] = None 

for index, row in tqdm(data.iterrows()):

    #print(row['NCTId'])

    #AVOID these trials that were used as three-shot examples only
    example_trials_to_avoid = ['NCT00000620', 'NCT01483560', 'NCT04280783']
    if row['NCTId'] in example_trials_to_avoid:
        continue

    file_name = row['NCTId'] + '.json'
    query_trial = json_file_to_rag_query(file_name, input_json_directory) #string query for searching rag database only

    try:
        #run rag query and find the filenames of similar trials
        similar_trials = module.query_clinical_trials_using_llamaindex(chroma_collection, vector_store_path,  
                                                                    query_trial, top_k=number_of_similar_trials)
        
        # print(similar_trials)

        #get system message, question 
        system_message, question = module.build_three_shot_prompt(row, similar_trials, rag_json_directory, 
                                                                ref_col_name='API_BaselineMeasures_Corrected') #which column to use for reference in RAG examples from ctrepo

        model_query = module.system_user_template(system_message, question)
        model_response = module.run_generation_single_hf_models(model_query, model_hf_endpoint, 
                                                        os.environ['HF_TOKEN'], temperature=0.0)

        
        data.at[index, 'llama3_70b_it_rag_ts_gen'] = model_response
    
    #Note: Weird Bug: UnprocessableEntityError: Error code: 422 - {'error': 'Input validation error: `inputs` tokens + `max_new_tokens` must be <= 8192. Given: 7573 `inputs` tokens and 1000 `max_new_tokens`', 'error_type': 'validation'}
    except Exception as e:
        print(f"An error occurred while processing trial {row['NCTId']}: {e}")
        continue


    #print 
    # print(system_message)
    # print(question)
    # print(model_response)
    # break

107it [05:39,  2.02s/it]

An error occurred while processing trial NCT02830880: Error code: 422 - {'error': 'Input validation error: `inputs` tokens + `max_new_tokens` must be <= 8192. Given: 8218 `inputs` tokens and 1000 `max_new_tokens`', 'error_type': 'validation'}


133it [07:02,  2.34s/it]

An error occurred while processing trial NCT03463265: Error code: 422 - {'error': 'Input validation error: `inputs` tokens + `max_new_tokens` must be <= 8192. Given: 7369 `inputs` tokens and 1000 `max_new_tokens`', 'error_type': 'validation'}


154it [08:06,  2.24s/it]

An error occurred while processing trial NCT01683994: Error code: 422 - {'error': 'Input validation error: `inputs` tokens + `max_new_tokens` must be <= 8192. Given: 8352 `inputs` tokens and 1000 `max_new_tokens`', 'error_type': 'validation'}


160it [08:21,  2.09s/it]

An error occurred while processing trial NCT02538510: Error code: 422 - {'error': 'Input validation error: `inputs` tokens + `max_new_tokens` must be <= 8192. Given: 7222 `inputs` tokens and 1000 `max_new_tokens`', 'error_type': 'validation'}


193it [10:24,  2.24s/it]

An error occurred while processing trial NCT01505868: Error code: 422 - {'error': 'Input validation error: `inputs` tokens + `max_new_tokens` must be <= 8192. Given: 8885 `inputs` tokens and 1000 `max_new_tokens`', 'error_type': 'validation'}


200it [10:48,  2.64s/it]

An error occurred while processing trial NCT01514201: Error code: 422 - {'error': 'Input validation error: `inputs` tokens + `max_new_tokens` must be <= 8192. Given: 7475 `inputs` tokens and 1000 `max_new_tokens`', 'error_type': 'validation'}


500it [26:11,  3.14s/it]


In [10]:
data.head(3)

Unnamed: 0,NCTId,BriefTitle,EligibilityCriteria,BriefSummary,Conditions,Interventions,PrimaryOutcomes,TrialGroup,API_BaselineMeasures,API_BaselineMeasures_Corrected,gpt4o_rag_ts_gen,llama3_70b_it_rag_ts_gen
0,NCT01743001,Clinical Study to Evaluate the Effects of Maci...,Inclusion Criteria:\n\n* Subjects:\n\n * not ...,"Clinical study to assess the efficacy, safety,...","Pulmonary Arterial Hypertension,","Macitentan 10 mg, Placebo,",Change From Baseline to Week 16 in Exercise Ca...,hypertension,"Age, Continuous, Age, Customized, Sex: Female,...","`Age, Continuous`, `Age, Customized`, `Sex: Fe...","`Age, Continuous`, `Age, Categorical`, `Sex: F...","`Age, Categorical`, `Age, Continuous`, `Sex: F..."
1,NCT03439189,Companion Protocol for Methacetin Breath Test ...,Inclusion Criteria:\n\n1. Male or female subje...,To validate the ability of the Methacetin Brea...,"NASH - Nonalcoholic Steatohepatitis, Cirrhosis...","Methacetin Breath Test, Emricasan, Placebo ora...",Number of Subjects With Matched Clinically Sig...,hypertension,"Age, Continuous, Sex: Female, Male, Ethnicity ...","`Age, Continuous`, `Sex: Female, Male`, `Ethni...","`Age, Continuous`, `Sex: Female, Male`, `Ethni...","`Age, Continuous`, `Sex: Female, Male`, `Ethni..."
2,NCT00458003,Phenylephrine in Spinal Anesthesia in Preeclam...,Inclusion Criteria:\n\n* ASA PS II - III women...,Hypotension remains a common clinical problem ...,"Preeclampsia, Hypotension,","Ephedrine, Phenylephrine,","The Umbilical Artery pH,",hypertension,"Age, Categorical, Sex: Female, Male, Region of...","`Age, Categorical`, `Sex: Female, Male`, `Regi...","`Age, Continuous`, `Sex: Female, Male`, `Race/...","`Age, Continuous`, `ASA PS`, `Body mass index ..."


In [11]:
data.to_csv('../data/hidden_data/CT-Repo-500-Samples-all-rag_on_ctrepo-gen.csv', index=False)

## GPT4 Evaluation

In [12]:
eval_data = pd.DataFrame()
eval_data['NCTId'] = data['NCTId']

In [13]:
data.head(3)

Unnamed: 0,NCTId,BriefTitle,EligibilityCriteria,BriefSummary,Conditions,Interventions,PrimaryOutcomes,TrialGroup,API_BaselineMeasures,API_BaselineMeasures_Corrected,gpt4o_rag_ts_gen,llama3_70b_it_rag_ts_gen
0,NCT01743001,Clinical Study to Evaluate the Effects of Maci...,Inclusion Criteria:\n\n* Subjects:\n\n * not ...,"Clinical study to assess the efficacy, safety,...","Pulmonary Arterial Hypertension,","Macitentan 10 mg, Placebo,",Change From Baseline to Week 16 in Exercise Ca...,hypertension,"Age, Continuous, Age, Customized, Sex: Female,...","`Age, Continuous`, `Age, Customized`, `Sex: Fe...","`Age, Continuous`, `Age, Categorical`, `Sex: F...","`Age, Categorical`, `Age, Continuous`, `Sex: F..."
1,NCT03439189,Companion Protocol for Methacetin Breath Test ...,Inclusion Criteria:\n\n1. Male or female subje...,To validate the ability of the Methacetin Brea...,"NASH - Nonalcoholic Steatohepatitis, Cirrhosis...","Methacetin Breath Test, Emricasan, Placebo ora...",Number of Subjects With Matched Clinically Sig...,hypertension,"Age, Continuous, Sex: Female, Male, Ethnicity ...","`Age, Continuous`, `Sex: Female, Male`, `Ethni...","`Age, Continuous`, `Sex: Female, Male`, `Ethni...","`Age, Continuous`, `Sex: Female, Male`, `Ethni..."
2,NCT00458003,Phenylephrine in Spinal Anesthesia in Preeclam...,Inclusion Criteria:\n\n* ASA PS II - III women...,Hypotension remains a common clinical problem ...,"Preeclampsia, Hypotension,","Ephedrine, Phenylephrine,","The Umbilical Artery pH,",hypertension,"Age, Categorical, Sex: Female, Male, Region of...","`Age, Categorical`, `Sex: Female, Male`, `Regi...","`Age, Continuous`, `Sex: Female, Male`, `Race/...","`Age, Continuous`, `ASA PS`, `Body mass index ..."


In [14]:
import json
from tqdm import tqdm

for index, row in tqdm(data.iterrows()):

    example_trials_to_avoid = ['NCT00000620', 'NCT01483560', 'NCT04280783']
    if row['NCTId'] in example_trials_to_avoid:
        continue

    qstart = module.get_question_from_row(row)

    if row['llama3_70b_it_rag_ts_gen'] is None:
        continue

    reference_list = module.extract_elements_v2(row['API_BaselineMeasures_Corrected']) #which column to use for reference as input from ctpub
    
    candidate_gts = module.extract_elements_v2(row['gpt4o_rag_ts_gen'])
    candidate_lts = module.extract_elements_v2(row['llama3_70b_it_rag_ts_gen'])
    
    system_message_gts, question_gts = module.build_gpt4_eval_prompt(reference_list,
                                                            candidate_gts,
                                                            qstart)
    
    
    system_message_lts, question_lts = module.build_gpt4_eval_prompt(reference_list,
                                                            candidate_lts,
                                                            qstart)

    
    eval_model_response_gts = module.run_evaluation_with_gpt4o(system_message_gts, question_gts, os.environ["OPENAI_API_KEY"])
    eval_model_response_lts = module.run_evaluation_with_gpt4o(system_message_lts, question_lts, os.environ["OPENAI_API_KEY"])
    
    #Convert eval_model_response to a JSON string and store in the dataframe
    eval_data.at[index, 'gpt4o_ts_gen_matches'] = eval_model_response_gts
    eval_data.at[index, 'llama3_70b_it_ts_gen_matches'] = eval_model_response_lts

    #break
    


500it [46:51,  5.62s/it]


In [15]:
eval_data.to_csv('../data/hidden_data/CT-Repo-500-Samples-all-rag_on_ctrepo-eval.csv', index=False)

In [16]:
eval_data.head(3)

Unnamed: 0,NCTId,gpt4o_ts_gen_matches,llama3_70b_it_ts_gen_matches
0,NCT01743001,"{\n ""matched_features"": [\n [""Age, C...","{\n ""matched_features"": [\n [""Age, Continu..."
1,NCT03439189,"{\n ""matched_features"": [\n [""Age, C...","{\n ""matched_features"": [\n [""Age, C..."
2,NCT00458003,"{\n ""matched_features"": [\n [""Age, C...","{\n ""matched_features"": [\n [""Age, C..."


In [17]:
print(eval_data.at[2, 'gpt4o_ts_gen_matches'])

{
    "matched_features": [
        ["Age, Categorical", "Age, Continuous"],
        ["Sex: Female, Male", "Sex: Female, Male"],
        ["Region of Enrollment", "Region of Enrollment"],
        ["Estimated gestational age (weeks)", "Gestational age (weeks)"],
        ["Preeclampsia", "Previous preeclampsia"],
        ["Baseline systolic blood pressure", "Systolic blood pressure (mmHg)"],
        ["Body mass index", "Body mass index (kg/m2)"]
    ],
    "remaining_reference_features": [
        "Multiple gestation",
        "Number of infants delivered per group",
        "Baseline heart rate"
    ],
    "remaining_candidate_features": [
        "Race/Ethnicity, Customized",
        "Parity",
        "Chronic hypertension",
        "Diastolic blood pressure (mmHg)",
        "Mode of delivery"
    ]
}


## Score Calculation

In [18]:
score_df = pd.DataFrame()
score_df['NCTId'] = data['NCTId']

In [19]:
import json 

gts_sum = {'precision':0, 'recall':0, 'f1':0}
lts_sum = {'precision':0, 'recall':0, 'f1':0}

total_count = 0

for index, row in tqdm(eval_data.iterrows()):
    avoid_ids = ['NCT00000620', 'NCT01483560', 'NCT04280783'] #these were used as examples for 3-shot generation
    if row['NCTId'] in avoid_ids:
        continue
    if not row['llama3_70b_it_ts_gen_matches'] or not row['gpt4o_ts_gen_matches']:
        continue

    #print(row['NCTId'])
    
    try:
        # gts
        gts_dict = json.loads(row['gpt4o_ts_gen_matches'])
        gts_matches = gts_dict['matched_features']
        gts_remaining_references = gts_dict['remaining_reference_features']
        gts_remaining_candidates = gts_dict['remaining_candidate_features']
        gts_score = module.match_to_score(gts_matches, gts_remaining_references, gts_remaining_candidates)
        score_df.at[index, 'gpt4o_ts_gen_scores'] = json.dumps(gts_score)
        gts_sum['precision'] += gts_score['precision']
        gts_sum['recall'] += gts_score['recall']
        gts_sum['f1'] += gts_score['f1']

        # lts 
        lts_dict = json.loads(row['llama3_70b_it_ts_gen_matches'])
        lts_matches = lts_dict['matched_features']
        lts_remaining_references = lts_dict['remaining_reference_features']
        lts_remaining_candidates = lts_dict['remaining_candidate_features']
        lts_score = module.match_to_score(lts_matches, lts_remaining_references, lts_remaining_candidates)
        score_df.at[index, 'llama3_70b_it_ts_gen_scores'] = json.dumps(lts_score)
        lts_sum['precision'] += lts_score['precision']
        lts_sum['recall'] += lts_score['recall']
        lts_sum['f1'] += lts_score['f1']

        total_count += 1

    except Exception as e:
        print(f"An error occurred while processing trial {row['NCTId']}: {e}")
        continue

    #break



500it [00:00, 16585.36it/s]

An error occurred while processing trial NCT02214186: Unterminated string starting at: line 120 column 9 (char 4519)
An error occurred while processing trial NCT02830880: the JSON object must be str, bytes or bytearray, not float
An error occurred while processing trial NCT03463265: the JSON object must be str, bytes or bytearray, not float
An error occurred while processing trial NCT01683994: the JSON object must be str, bytes or bytearray, not float
An error occurred while processing trial NCT02538510: the JSON object must be str, bytes or bytearray, not float
An error occurred while processing trial NCT00417079: Unterminated string starting at: line 146 column 9 (char 5306)
An error occurred while processing trial NCT01307579: division by zero
An error occurred while processing trial NCT01505868: the JSON object must be str, bytes or bytearray, not float
An error occurred while processing trial NCT02337530: Unterminated string starting at: line 107 column 5 (char 4103)
An error occu




In [20]:
score_df.head()

Unnamed: 0,NCTId,gpt4o_ts_gen_scores,llama3_70b_it_ts_gen_scores
0,NCT01743001,"{""precision"": 0.3333333333333333, ""recall"": 0....","{""precision"": 0.38461538461538464, ""recall"": 0..."
1,NCT03439189,"{""precision"": 0.4166666666666667, ""recall"": 0....","{""precision"": 0.3125, ""recall"": 0.714285714285..."
2,NCT00458003,"{""precision"": 0.5833333333333334, ""recall"": 0....","{""precision"": 0.5, ""recall"": 0.5, ""f1"": 0.5}"
3,NCT02056626,"{""precision"": 0.3, ""recall"": 0.5, ""f1"": 0.3749...","{""precision"": 0.2222222222222222, ""recall"": 0...."
4,NCT01780974,"{""precision"": 0.38095238095238093, ""recall"": 0...","{""precision"": 0.38095238095238093, ""recall"": 0..."


In [21]:
print(f"GPT-4o Three Shot Scores: Precision={gts_sum['precision']/total_count} Recall={gts_sum['recall']/total_count} F1={gts_sum['f1']/total_count}")
print(f"Llama-3 70B Instruct Three Shot Scores: Precision={lts_sum['precision']/total_count} Recall={lts_sum['recall']/total_count} F1={lts_sum['f1']/total_count}")


GPT-4o Three Shot Scores: Precision=0.4803974116518465 Recall=0.6404057933646666 F1=0.5228506767138614
Llama-3 70B Instruct Three Shot Scores: Precision=0.5076440399356804 Recall=0.6186471641231414 F1=0.5271666553959126
