## We are running CT_Repo 500 Sample data using ct_repo indexed stores

In [1]:
%load_ext autoreload
%autoreload 2

In [18]:
import json 
import pandas as pd 
from dotenv import load_dotenv
from model_list import models 
import chromadb
import openai
import module 
import os 
from tqdm import tqdm 

In [3]:
load_dotenv()
os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY_TEAM')
os.environ['HF_TOKEN'] = os.getenv('HF_TOKEN')
os.environ['GROQ_API_KEY'] = os.getenv('GROQ_API_KEY')

In [4]:
#data = pd.read_csv('../data/CT-Repo-With-Examples-Corrected-500-Sample.csv')
data = pd.read_csv('../data/hidden_data/biobert_embed/CT-Repo-500-Samples-rag-biobert-allgen.csv')
data.head(3)

Unnamed: 0,NCTId,BriefTitle,EligibilityCriteria,BriefSummary,Conditions,Interventions,PrimaryOutcomes,TrialGroup,API_BaselineMeasures,API_BaselineMeasures_Corrected,gpt4o_rag_ts_gen,llama3_70b_it_rag_ts_gen
0,NCT01743001,Clinical Study to Evaluate the Effects of Maci...,Inclusion Criteria:\n\n* Subjects:\n\n * not ...,"Clinical study to assess the efficacy, safety,...","Pulmonary Arterial Hypertension,","Macitentan 10 mg, Placebo,",Change From Baseline to Week 16 in Exercise Ca...,hypertension,"Age, Continuous, Age, Customized, Sex: Female,...","`Age, Continuous`, `Age, Customized`, `Sex: Fe...","`Age, Continuous`, `Sex: Female, Male`, `Race ...",
1,NCT03439189,Companion Protocol for Methacetin Breath Test ...,Inclusion Criteria:\n\n1. Male or female subje...,To validate the ability of the Methacetin Brea...,"NASH - Nonalcoholic Steatohepatitis, Cirrhosis...","Methacetin Breath Test, Emricasan, Placebo ora...",Number of Subjects With Matched Clinically Sig...,hypertension,"Age, Continuous, Sex: Female, Male, Ethnicity ...","`Age, Continuous`, `Sex: Female, Male`, `Ethni...","`Age, Categorical`, `Age, Continuous`, `Sex: F...",
2,NCT00458003,Phenylephrine in Spinal Anesthesia in Preeclam...,Inclusion Criteria:\n\n* ASA PS II - III women...,Hypotension remains a common clinical problem ...,"Preeclampsia, Hypotension,","Ephedrine, Phenylephrine,","The Umbilical Artery pH,",hypertension,"Age, Categorical, Sex: Female, Male, Region of...","`Age, Categorical`, `Sex: Female, Male`, `Regi...","`Age, Continuous`, `Sex: Female, Male`, `Body ...",


In [10]:
#main_path = '/Users/nafisneehal/Desktop/CTBench_RAG/'

# Initialize ChromaDB
# Directory containing JSON files
input_json_directory = '../data/ctrepo_500_sample_json' #source of input json files
rag_json_directory = '../data/ctrepo_json' #source of example json files
vector_store_path = '../chroma_db_ctrepo_indexedby_biobert' #search the ctrepo vector store for examples 

# #initialize client 
c_db = chromadb.PersistentClient(path=vector_store_path)
#get collection
chroma_collection = c_db.get_or_create_collection("ctrepo_all_biobert") #located in df_to_json_and_indexing.ipynb file 

In [11]:
def json_file_to_rag_query(file_name, json_directory):
    with open(f"{json_directory}/{file_name}", "r") as file:
        trial_data = json.load(file)
        trial_query = f"""
        BriefTitle: {trial_data['BriefTitle']}\n
        EligibilityCriteria: {trial_data['EligibilityCriteria']}\n 
        BriefSummary: {trial_data['BriefSummary']}\n
        Conditions: {trial_data['Conditions']}\n
        Interventions: {trial_data['Interventions']}\n
        PrimaryOutcomes: {trial_data['PrimaryOutcomes']}
        """
        return trial_query

# GPT-4o Generation with RAG (3 shot) 

In [14]:
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core import Settings

number_of_similar_trials = 3
model_name   = models['gpt-4o']
embed_model = HuggingFaceEmbedding(model_name="dmis-lab/biobert-base-cased-v1.2")
Settings.embed_model = embed_model
Settings.llm = None 

data['gpt4o_rag_ts_gen'] = None 

for index, row in tqdm(data.iterrows()):

    #print(row['NCTId'])

    #AVOID these trials that were used as three-shot examples only
    example_trials_to_avoid = ['NCT00000620', 'NCT01483560', 'NCT04280783']
    if row['NCTId'] in example_trials_to_avoid:
        continue

    file_name = row['NCTId'] + '.json'
    query_trial = json_file_to_rag_query(file_name, input_json_directory) #string query for searching rag database only

    #run rag query and find the filenames of similar trials
    similar_trials = module.query_clinical_trials_using_llamaindex(chroma_collection, vector_store_path,  
                                                                   query_trial, top_k=number_of_similar_trials,
                                                                   embed_model=embed_model)
    
    #print(similar_trials)

    #get system message, question 
    system_message, question = module.build_three_shot_prompt(row, similar_trials, rag_json_directory, 
                                                              ref_col_name='API_BaselineMeasures_Corrected') #asking ct_repo examples 

    model_query = module.system_user_template(system_message, question)
    model_response = module.run_generation_single_openai(model_query, model_name = model_name, 
                                                openai_token = os.environ["OPENAI_API_KEY"], temperature=0.0)

    
    data.at[index, 'gpt4o_rag_ts_gen'] = model_response
    
    # print 
    # print(system_message)
    # print(question)
    # print(model_response)
    # break

  from .autonotebook import tqdm as notebook_tqdm
No sentence-transformers model found with name dmis-lab/biobert-base-cased-v1.2. Creating a new one with mean pooling.


LLM is explicitly disabled. Using MockLLM.


500it [22:52,  2.75s/it]


In [16]:
data.head(3)

Unnamed: 0,NCTId,BriefTitle,EligibilityCriteria,BriefSummary,Conditions,Interventions,PrimaryOutcomes,TrialGroup,API_BaselineMeasures,API_BaselineMeasures_Corrected,gpt4o_rag_ts_gen,llama3_70b_it_rag_ts_gen
0,NCT01743001,Clinical Study to Evaluate the Effects of Maci...,Inclusion Criteria:\n\n* Subjects:\n\n * not ...,"Clinical study to assess the efficacy, safety,...","Pulmonary Arterial Hypertension,","Macitentan 10 mg, Placebo,",Change From Baseline to Week 16 in Exercise Ca...,hypertension,"Age, Continuous, Age, Customized, Sex: Female,...","`Age, Continuous`, `Age, Customized`, `Sex: Fe...","`Age, Continuous`, `Sex: Female, Male`, `Race ...","`Age, Continuous`, `Sex: Female, Male`, `WHO f..."
1,NCT03439189,Companion Protocol for Methacetin Breath Test ...,Inclusion Criteria:\n\n1. Male or female subje...,To validate the ability of the Methacetin Brea...,"NASH - Nonalcoholic Steatohepatitis, Cirrhosis...","Methacetin Breath Test, Emricasan, Placebo ora...",Number of Subjects With Matched Clinically Sig...,hypertension,"Age, Continuous, Sex: Female, Male, Ethnicity ...","`Age, Continuous`, `Sex: Female, Male`, `Ethni...","`Age, Categorical`, `Age, Continuous`, `Sex: F...","`Age, Categorical`, `Age, Continuous`, `Sex: F..."
2,NCT00458003,Phenylephrine in Spinal Anesthesia in Preeclam...,Inclusion Criteria:\n\n* ASA PS II - III women...,Hypotension remains a common clinical problem ...,"Preeclampsia, Hypotension,","Ephedrine, Phenylephrine,","The Umbilical Artery pH,",hypertension,"Age, Categorical, Sex: Female, Male, Region of...","`Age, Categorical`, `Sex: Female, Male`, `Regi...","`Age, Continuous`, `Sex: Female, Male`, `Body ...","`Age, Continuous`, `ASA Physical Status, Categ..."


# Llama3.1 Generation with RAG (3-shot) 

In [21]:
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core import Settings

number_of_similar_trials = 3
model_name   = models['gpt-4o']
model_hf_endpoint = models['llama3_1-70b-it']

embed_model = HuggingFaceEmbedding(model_name="dmis-lab/biobert-base-cased-v1.2")
Settings.embed_model = embed_model
Settings.llm = None 

data['llama3_70b_it_rag_ts_gen'] = None 

for index, row in tqdm(data.iterrows()):

    #print(row['NCTId'])

    #AVOID these trials that were used as three-shot examples only
    example_trials_to_avoid = ['NCT00000620', 'NCT01483560', 'NCT04280783']
    if row['NCTId'] in example_trials_to_avoid:
        continue

    file_name = row['NCTId'] + '.json'
    query_trial = json_file_to_rag_query(file_name, input_json_directory) #string query for searching rag database only

    try:
        #run rag query and find the filenames of similar trials
        similar_trials = module.query_clinical_trials_using_llamaindex(chroma_collection, vector_store_path,  
                                                                    query_trial, top_k=number_of_similar_trials,
                                                                    embed_model=embed_model)
        
        # print(similar_trials)

        #get system message, question 
        system_message, question = module.build_three_shot_prompt(row, similar_trials, rag_json_directory, 
                                                                ref_col_name='API_BaselineMeasures_Corrected') #which column to use for reference in RAG examples from ctrepo

        model_query = module.system_user_template(system_message, question)
        model_response = module.run_generation_single_hf_models(model_query, model_hf_endpoint, 
                                                        os.environ['HF_TOKEN'], temperature=0.0)
        # model_response = module.run_generation_single_hf_models_groq(model_query, model_hf_endpoint, 
        #                                                 os.environ['GROQ_API_KEY'], temperature=0.0)

        
        data.at[index, 'llama3_70b_it_rag_ts_gen'] = model_response
    
    #Note: Weird Bug: UnprocessableEntityError: Error code: 422 - {'error': 'Input validation error: `inputs` tokens + `max_new_tokens` must be <= 8192. Given: 7573 `inputs` tokens and 1000 `max_new_tokens`', 'error_type': 'validation'}
    except Exception as e:
        print(f"An error occurred while processing trial {row['NCTId']}: {e}")
        continue
    
    # if index%50==0:
    #     import time 
    #     time.sleep(5)

    #print 
    # print(system_message)
    # print(question)
    # print(model_response)
    # break

No sentence-transformers model found with name dmis-lab/biobert-base-cased-v1.2. Creating a new one with mean pooling.


LLM is explicitly disabled. Using MockLLM.


45it [04:15,  6.92s/it]

An error occurred while processing trial NCT02796560: Error code: 429 - {'error': 'Model is overloaded', 'error_type': 'overloaded'}


500it [1:03:41,  7.64s/it]


In [22]:
data.head(3)

Unnamed: 0,NCTId,BriefTitle,EligibilityCriteria,BriefSummary,Conditions,Interventions,PrimaryOutcomes,TrialGroup,API_BaselineMeasures,API_BaselineMeasures_Corrected,gpt4o_rag_ts_gen,llama3_70b_it_rag_ts_gen
0,NCT01743001,Clinical Study to Evaluate the Effects of Maci...,Inclusion Criteria:\n\n* Subjects:\n\n * not ...,"Clinical study to assess the efficacy, safety,...","Pulmonary Arterial Hypertension,","Macitentan 10 mg, Placebo,",Change From Baseline to Week 16 in Exercise Ca...,hypertension,"Age, Continuous, Age, Customized, Sex: Female,...","`Age, Continuous`, `Age, Customized`, `Sex: Fe...","`Age, Continuous`, `Sex: Female, Male`, `Race ...","`Age, Continuous`, `Sex: Female, Male`, `WHO f..."
1,NCT03439189,Companion Protocol for Methacetin Breath Test ...,Inclusion Criteria:\n\n1. Male or female subje...,To validate the ability of the Methacetin Brea...,"NASH - Nonalcoholic Steatohepatitis, Cirrhosis...","Methacetin Breath Test, Emricasan, Placebo ora...",Number of Subjects With Matched Clinically Sig...,hypertension,"Age, Continuous, Sex: Female, Male, Ethnicity ...","`Age, Continuous`, `Sex: Female, Male`, `Ethni...","`Age, Categorical`, `Age, Continuous`, `Sex: F...","`Age, Categorical`, `Age, Continuous`, `Sex: F..."
2,NCT00458003,Phenylephrine in Spinal Anesthesia in Preeclam...,Inclusion Criteria:\n\n* ASA PS II - III women...,Hypotension remains a common clinical problem ...,"Preeclampsia, Hypotension,","Ephedrine, Phenylephrine,","The Umbilical Artery pH,",hypertension,"Age, Categorical, Sex: Female, Male, Region of...","`Age, Categorical`, `Sex: Female, Male`, `Regi...","`Age, Continuous`, `Sex: Female, Male`, `Body ...","`Age, Continuous`, `Body Mass Index (BMI)`, `G..."


In [23]:
data.to_csv('../data/hidden_data/biobert_embed/CT-Repo-500-Samples-rag-biobert-allgen.csv', index=False)

## GPT4 Evaluation

In [24]:
eval_data = pd.DataFrame()
eval_data['NCTId'] = data['NCTId']

In [25]:
data.head(3)

Unnamed: 0,NCTId,BriefTitle,EligibilityCriteria,BriefSummary,Conditions,Interventions,PrimaryOutcomes,TrialGroup,API_BaselineMeasures,API_BaselineMeasures_Corrected,gpt4o_rag_ts_gen,llama3_70b_it_rag_ts_gen
0,NCT01743001,Clinical Study to Evaluate the Effects of Maci...,Inclusion Criteria:\n\n* Subjects:\n\n * not ...,"Clinical study to assess the efficacy, safety,...","Pulmonary Arterial Hypertension,","Macitentan 10 mg, Placebo,",Change From Baseline to Week 16 in Exercise Ca...,hypertension,"Age, Continuous, Age, Customized, Sex: Female,...","`Age, Continuous`, `Age, Customized`, `Sex: Fe...","`Age, Continuous`, `Sex: Female, Male`, `Race ...","`Age, Continuous`, `Sex: Female, Male`, `WHO f..."
1,NCT03439189,Companion Protocol for Methacetin Breath Test ...,Inclusion Criteria:\n\n1. Male or female subje...,To validate the ability of the Methacetin Brea...,"NASH - Nonalcoholic Steatohepatitis, Cirrhosis...","Methacetin Breath Test, Emricasan, Placebo ora...",Number of Subjects With Matched Clinically Sig...,hypertension,"Age, Continuous, Sex: Female, Male, Ethnicity ...","`Age, Continuous`, `Sex: Female, Male`, `Ethni...","`Age, Categorical`, `Age, Continuous`, `Sex: F...","`Age, Categorical`, `Age, Continuous`, `Sex: F..."
2,NCT00458003,Phenylephrine in Spinal Anesthesia in Preeclam...,Inclusion Criteria:\n\n* ASA PS II - III women...,Hypotension remains a common clinical problem ...,"Preeclampsia, Hypotension,","Ephedrine, Phenylephrine,","The Umbilical Artery pH,",hypertension,"Age, Categorical, Sex: Female, Male, Region of...","`Age, Categorical`, `Sex: Female, Male`, `Regi...","`Age, Continuous`, `Sex: Female, Male`, `Body ...","`Age, Continuous`, `Body Mass Index (BMI)`, `G..."


In [26]:
import json
from tqdm import tqdm

for index, row in tqdm(data.iterrows()):

    example_trials_to_avoid = ['NCT00000620', 'NCT01483560', 'NCT04280783']
    if row['NCTId'] in example_trials_to_avoid:
        continue

    qstart = module.get_question_from_row(row)

    if row['llama3_70b_it_rag_ts_gen'] is None:
        continue

    reference_list = module.extract_elements_v2(row['API_BaselineMeasures_Corrected']) #which column to use for reference as input from ctpub
    
    candidate_gts = module.extract_elements_v2(row['gpt4o_rag_ts_gen'])
    candidate_lts = module.extract_elements_v2(row['llama3_70b_it_rag_ts_gen'])
    
    system_message_gts, question_gts = module.build_gpt4_eval_prompt(reference_list,
                                                            candidate_gts,
                                                            qstart)
    
    
    system_message_lts, question_lts = module.build_gpt4_eval_prompt(reference_list,
                                                            candidate_lts,
                                                            qstart)

    
    eval_model_response_gts = module.run_evaluation_with_gpt4o(system_message_gts, question_gts, os.environ["OPENAI_API_KEY"])
    eval_model_response_lts = module.run_evaluation_with_gpt4o(system_message_lts, question_lts, os.environ["OPENAI_API_KEY"])
    
    #Convert eval_model_response to a JSON string and store in the dataframe
    eval_data.at[index, 'gpt4o_ts_gen_matches'] = eval_model_response_gts
    eval_data.at[index, 'llama3_70b_it_ts_gen_matches'] = eval_model_response_lts

    #break
    


500it [58:13,  6.99s/it]


In [27]:
eval_data.to_csv('../data/hidden_data/biobert_embed/CT-Repo-500-Samples-rag-biobert-alleval.csv', index=False)

In [28]:
eval_data.head(3)

Unnamed: 0,NCTId,gpt4o_ts_gen_matches,llama3_70b_it_ts_gen_matches
0,NCT01743001,"{\n ""matched_features"": [\n [""Age, Continu...","{\n ""matched_features"": [\n [""Age, Continu..."
1,NCT03439189,"{\n ""matched_features"": [\n [""Age, C...","{\n ""matched_features"": [\n [""Age, C..."
2,NCT00458003,"{\n ""matched_features"": [\n [""Age, C...","{\n ""matched_features"": [\n [""Age, C..."


In [29]:
print(eval_data.at[2, 'gpt4o_ts_gen_matches'])

{
    "matched_features": [
        ["Age, Categorical", "Age, Continuous"],
        ["Sex: Female, Male", "Sex: Female, Male"],
        ["Estimated gestational age (weeks)", "Gestational age (weeks)"],
        ["Multiple gestation", "Number of fetuses"],
        ["Baseline systolic blood pressure", "Systolic blood pressure (mmHg)"],
        ["Baseline heart rate", "Heart rate (bpm)"],
        ["Body mass index", "Body mass index (kg/m2)"]
    ],
    "remaining_reference_features": [
        "Region of Enrollment",
        "Preeclampsia",
        "Number of infants delivered per group"
    ],
    "remaining_candidate_features": [
        "Diastolic blood pressure (mmHg)",
        "Parity",
        "Previous cesarean delivery",
        "Use of antihypertensive medication"
    ]
}


## Score Calculation

In [30]:
score_df = pd.DataFrame()
score_df['NCTId'] = data['NCTId']

In [32]:
import json 

gts_sum = {'precision':0, 'recall':0, 'f1':0}
lts_sum = {'precision':0, 'recall':0, 'f1':0}

total_count = 0

for index, row in tqdm(eval_data.iterrows()):
    avoid_ids = ['NCT00000620', 'NCT01483560', 'NCT04280783'] #these were used as examples for 3-shot generation
    if row['NCTId'] in avoid_ids:
        continue
    if not row['llama3_70b_it_ts_gen_matches'] or not row['gpt4o_ts_gen_matches']:
        continue

    #print(row['NCTId'])
    
    try:
        # gts
        gts_dict = json.loads(row['gpt4o_ts_gen_matches'])
        gts_matches = gts_dict['matched_features']
        gts_remaining_references = gts_dict['remaining_reference_features']
        gts_remaining_candidates = gts_dict['remaining_candidate_features']
        gts_score = module.match_to_score(gts_matches, gts_remaining_references, gts_remaining_candidates)
        score_df.at[index, 'gpt4o_ts_gen_scores'] = json.dumps(gts_score)
        gts_sum['precision'] += gts_score['precision']
        gts_sum['recall'] += gts_score['recall']
        gts_sum['f1'] += gts_score['f1']

        # lts 
        lts_dict = json.loads(row['llama3_70b_it_ts_gen_matches'])
        lts_matches = lts_dict['matched_features']
        lts_remaining_references = lts_dict['remaining_reference_features']
        lts_remaining_candidates = lts_dict['remaining_candidate_features']
        lts_score = module.match_to_score(lts_matches, lts_remaining_references, lts_remaining_candidates)
        score_df.at[index, 'llama3_70b_it_ts_gen_scores'] = json.dumps(lts_score)
        lts_sum['precision'] += lts_score['precision']
        lts_sum['recall'] += lts_score['recall']
        lts_sum['f1'] += lts_score['f1']

        total_count += 1

    except Exception as e:
        print(f"An error occurred while processing trial {row['NCTId']}: {e}")
        continue

    #break



500it [00:00, 16045.66it/s]

An error occurred while processing trial NCT01673373: Unterminated string starting at: line 75 column 9 (char 4218)
An error occurred while processing trial NCT02796560: the JSON object must be str, bytes or bytearray, not float
An error occurred while processing trial NCT03321929: Unterminated string starting at: line 97 column 5 (char 4716)
An error occurred while processing trial NCT03463265: Unterminated string starting at: line 93 column 9 (char 3428)
An error occurred while processing trial NCT03343301: Unterminated string starting at: line 78 column 5 (char 3689)
An error occurred while processing trial NCT01478048: Unterminated string starting at: line 91 column 9 (char 4782)
An error occurred while processing trial NCT01308840: Expecting value: line 136 column 1 (char 5109)
An error occurred while processing trial NCT01160484: Unterminated string starting at: line 103 column 9 (char 4221)
An error occurred while processing trial NCT02337946: Unterminated string starting at: li




In [33]:
score_df.head()

Unnamed: 0,NCTId,gpt4o_ts_gen_scores,llama3_70b_it_ts_gen_scores
0,NCT01743001,"{""precision"": 0.4, ""recall"": 0.666666666666666...","{""precision"": 0.3, ""recall"": 0.5, ""f1"": 0.3749..."
1,NCT03439189,"{""precision"": 0.24, ""recall"": 0.75, ""f1"": 0.36...","{""precision"": 0.24, ""recall"": 0.75, ""f1"": 0.36..."
2,NCT00458003,"{""precision"": 0.6363636363636364, ""recall"": 0....","{""precision"": 0.375, ""recall"": 0.6, ""f1"": 0.46..."
3,NCT02056626,"{""precision"": 0.3333333333333333, ""recall"": 0....","{""precision"": 0.2, ""recall"": 0.5, ""f1"": 0.2857..."
4,NCT01780974,"{""precision"": 0.47058823529411764, ""recall"": 0...","{""precision"": 0.25, ""recall"": 0.53846153846153..."


Old Score

In [21]:
# print(f"GPT-4o Three Shot Scores: Precision={gts_sum['precision']/total_count} Recall={gts_sum['recall']/total_count} F1={gts_sum['f1']/total_count}")
# print(f"Llama-3 70B Instruct Three Shot Scores: Precision={lts_sum['precision']/total_count} Recall={lts_sum['recall']/total_count} F1={lts_sum['f1']/total_count}")


GPT-4o Three Shot Scores: Precision=0.4803974116518465 Recall=0.6404057933646666 F1=0.5228506767138614
Llama-3 70B Instruct Three Shot Scores: Precision=0.5076440399356804 Recall=0.6186471641231414 F1=0.5271666553959126


New Score

In [34]:
print(f"GPT-4o Three Shot Scores: Precision={gts_sum['precision']/total_count} Recall={gts_sum['recall']/total_count} F1={gts_sum['f1']/total_count}")
print(f"Llama-3 70B Instruct Three Shot Scores: Precision={lts_sum['precision']/total_count} Recall={lts_sum['recall']/total_count} F1={lts_sum['f1']/total_count}")


GPT-4o Three Shot Scores: Precision=0.4986016784613313 Recall=0.6576455854041053 F1=0.5380519574760396
Llama-3 70B Instruct Three Shot Scores: Precision=0.43393577323500737 Recall=0.6605854303834549 F1=0.4941895214793782
