In [58]:
import pandas as pd
from openai import OpenAI
import os
import random
import json
from dotenv import load_dotenv
from typing import List, Dict

from tqdm.auto import tqdm

In [2]:
load_dotenv('/home/ubuntu/medical_assistant_rag/.envrc')
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')

## Ingestion

In [3]:
df = pd.read_csv('./data/data_metadata_small.csv')
df.head()

Unnamed: 0,id,question,answer,medical_department,condition_type,patient_demographics,common_symptoms,treatment_or_management,severity
0,0,A 23-year-old pregnant woman at 22 weeks gesta...,Nitrofurantoin,Obstetrics & Gynecology,Infectious,"Age Group: Adult, Gender: Female, Pregnancy St...","Burning sensation (e.g., urination)",Medication,Mild
1,1,A 3-month-old baby died suddenly at night whil...,Placing the infant in a supine position on a f...,Pediatrics,Idiopathic,"Age Group: Infant (1-12 months), Gender: Male,...","Fever, Altered Mental Status","Preventive Measures (e.g., vaccinations)",Life-threatening
2,2,A mother brings her 3-week-old infant to the p...,Abnormal migration of ventral pancreatic bud,Pediatrics,Infectious,"Age Group: Neonate (0-28 days), Gender: Male, ...","Fussiness, Nausea/Vomiting",Observation/Monitoring,Moderate
3,3,A pulmonary autopsy specimen from a 58-year-ol...,Thromboembolism,Pulmonology,Acute,"Age Group: Adult, Gender: Female, Pregnancy St...","Dyspnea (Shortness of breath), Fatigue","Supportive Care (e.g., oxygen therapy)",Life-threatening
4,4,A 20-year-old woman presents with menorrhagia ...,Von Willebrand disease,Obstetrics & Gynecology,Chronic,"Age Group: Adult, Gender: Female, Pregnancy St...","Bleeding (e.g., menorrhagia), Easy bruising",Medication,Moderate


In [4]:
len(df)

99

In [5]:
documents = df.to_dict(orient='records')
documents[0]

{'id': 0,
 'question': 'A 23-year-old pregnant woman at 22 weeks gestation presents with burning upon urination. She states it started 1 day ago and has been worsening despite drinking more water and taking cranberry extract. She otherwise feels well and is followed by a doctor for her pregnancy. Her temperature is 97.7°F (36.5°C), blood pressure is 122/77 mmHg, pulse is 80/min, respirations are 19/min, and oxygen saturation is 98% on room air. Physical exam is notable for an absence of costovertebral angle tenderness and a gravid uterus. Which of the following is the best treatment for this patient?',
 'answer': 'Nitrofurantoin',
 'medical_department': 'Obstetrics & Gynecology',
 'condition_type': 'Infectious',
 'patient_demographics': 'Age Group: Adult, Gender: Female, Pregnancy Status: Pregnant',
 'common_symptoms': 'Burning sensation (e.g., urination)',
 'treatment_or_management': 'Medication',
 'severity': 'Mild'}

## Elastic Search

In [21]:
from sentence_transformers import SentenceTransformer
from elasticsearch import Elasticsearch

In [22]:
es_url = 'http://localhost:9200'
es_client = Elasticsearch(es_url)

Indexing

In [23]:
index_settings = {
    "settings": {
        "number_of_shards": 1,
        "number_of_replicas": 0
    },
    "mappings": {
        "properties": {
            "id": {"type": "keyword"},
            "question": {"type": "text"},
            "answer": {"type": "text"},
            "medical_department": {"type": "keyword"},
            "condition_type": {"type": "keyword"},
            "patient_demographics": {"type": "text"},
            "common_symptoms": {"type": "text"},
            "treatment_or_management": {"type": "text"},
            "severity": {"type": "keyword"},
            "question_vector": {
                "type": "dense_vector",
                "dims": 384,
                "index": True,
                "similarity": "cosine"
            },
            "answer_vector": {
                "type": "dense_vector",
                "dims": 384,
                "index": True,
                "similarity": "cosine"
            },
            "question_answer_vector": {
                "type": "dense_vector",
                "dims": 384,
                "index": True,
                "similarity": "cosine"
            },
        }
    }
}

In [24]:
index_name = "medical-questions"

es_client.indices.delete(index=index_name, ignore_unavailable=True)
es_client.indices.create(index=index_name, body=index_settings)

ObjectApiResponse({'acknowledged': True, 'shards_acknowledged': True, 'index': 'medical-questions'})

In [25]:
model = SentenceTransformer('sentence-transformers/multi-qa-MiniLM-L6-cos-v1')



In [26]:
for doc in tqdm(documents):
    question = doc.get('question', 'No question provided')
    answer = doc.get('answer', 'No answer provided')
    qa_combined = question + ' ' + answer

    doc['question'] = question
    doc['answer'] = answer
    doc['question_vector'] = model.encode(question).tolist()
    doc['answer_vector'] = model.encode(answer).tolist()
    doc['question_answer_vector'] = model.encode(qa_combined).tolist()

    # Use the document's 'id' field as the Elasticsearch document ID
    es_client.index(index=index_name, id=doc['id'], document=doc)


  0%|          | 0/99 [00:00<?, ?it/s]

Retrieval

In [27]:
from langchain.embeddings import SentenceTransformerEmbeddings
from typing import Dict
from langchain_elasticsearch import ElasticsearchRetriever

In [28]:
# Original query
query = "A 23-year-old pregnant woman at 22 weeks gestation presents with burning upon urination. She states it started 1 day ago and has been worsening despite drinking more water and taking cranberry extract. She otherwise feels well and is followed by a doctor for her pregnancy. Her temperature is 97.7°F (36.5°C), blood pressure is 122/77 mmHg, pulse is 80/min, respirations are 19/min, and oxygen saturation is 98% on room air. Physical exam is notable for an absence of costovertebral angle tenderness and a gravid uterus. Which of the following is the best treatment for this patient?"

# Generated query
# query = "A 30-year-old woman in her second trimester of pregnancy presents with symptoms of dysuria and urinary urgency. She has no significant medical history and is not allergic to any medications. Physical examination and vital signs are within normal limits. Which antibiotic is considered safe and effective for treating her urinary tract infection during pregnancy?"

Embeddings:

In [29]:
embeddings = SentenceTransformerEmbeddings(model_name="sentence-transformers/multi-qa-MiniLM-L6-cos-v1")

In [67]:
import logging
logging.basicConfig(level=logging.INFO)

In [68]:
def hybrid_query(search_query: str) -> List[Dict]:
    vector = embeddings.embed_query(search_query)
    query_body = {
        "query": {
            "bool": {
                "must": {
                    "multi_match": {
                        "query": search_query,
                        "fields": [
                            "question",
                            "answer",
                            "common_symptoms",
                            "condition_type",
                            "medical_department",
                            "patient_demographics"
                        ],
                        "type": "best_fields",
                        "boost": 0.5,
                    }
                }
            }
        },
        "knn": {
            "field": "question_answer_vector",
            "query_vector": vector,
            "k": 5,
            "num_candidates": 10000,
            "boost": 0.5
        },
        "size": 5,
        "_source": ["question", "id", "answer"]
    }

    logging.info(f"Query body being sent to Elasticsearch: {query_body}")

    hybrid_retriever = ElasticsearchRetriever.from_es_params(
        index_name=index_name,
        body_func=lambda q: query_body,
        content_field='question',
        url=es_url,
    )
    
    hybrid_results = hybrid_retriever.invoke(search_query)
    
    result_docs = []
    for doc in hybrid_results:
        source = doc.metadata
        if 'id' in source:
            # Add a log to inspect the structure of each document
            logging.info(f"Parsed document: {source}")
            result_docs.append({'id': source['id'], 'question': source.get('question', ''), 'answer': source.get('answer', '')})

    logging.info(f"Final parsed results: {result_docs}")
    return result_docs

def question_text_hybrid(q):
    question = q['question']
    return hybrid_query(question)

def hit_rate(relevance_total):
    cnt = 0

    for line in relevance_total:
        if True in line:
            cnt = cnt + 1

    return cnt / len(relevance_total)

def mrr(relevance_total):
    total_score = 0.0

    for line in relevance_total:
        for rank in range(len(line)):
            if line[rank] == True:
                total_score = total_score + 1 / (rank + 1)

    return total_score / len(relevance_total)

def evaluate(documents, search_function):
    relevance_total = []

    for q in tqdm(documents):
        doc_id = q['id']
        results = search_function(q)
        logging.info(f"Results for query '{q['question']}': {results}")
        
        # Check for relevance by matching 'id'
        relevance = [d['id'] == doc_id for d in results if 'id' in d]
        relevance_total.append(relevance)

    hit_rate_value = hit_rate(relevance_total)
    mrr_value = mrr(relevance_total)
    logging.info(f"Hit rate: {hit_rate_value}, MRR: {mrr_value}")

    return {
        'hit_rate': hit_rate_value,
        'mrr': mrr_value,
    }

In [69]:
evaluate(documents, question_text_hybrid)

  0%|          | 0/99 [00:00<?, ?it/s]

INFO:root:Query body being sent to Elasticsearch: {'query': {'bool': {'must': {'multi_match': {'query': 'A 23-year-old pregnant woman at 22 weeks gestation presents with burning upon urination. She states it started 1 day ago and has been worsening despite drinking more water and taking cranberry extract. She otherwise feels well and is followed by a doctor for her pregnancy. Her temperature is 97.7°F (36.5°C), blood pressure is 122/77 mmHg, pulse is 80/min, respirations are 19/min, and oxygen saturation is 98% on room air. Physical exam is notable for an absence of costovertebral angle tenderness and a gravid uterus. Which of the following is the best treatment for this patient?', 'fields': ['question', 'answer', 'common_symptoms', 'condition_type', 'medical_department', 'patient_demographics'], 'type': 'best_fields', 'boost': 0.5}}}}, 'knn': {'field': 'question_answer_vector', 'query_vector': [0.0015250662108883262, 0.012687389738857746, 0.06418988108634949, 0.056518349796533585, -0.

{'hit_rate': 0.0, 'mrr': 0.0}

## RAG

In [12]:
client = OpenAI()

In [13]:
def search(query):
    boost = {}

    results = index.search(
        query=query,
        filter_dict={},
        boost_dict=boost,
        num_results=10
    )

    return results

prompt_template = """
You are a knowledgeable medical assistant. Answer the QUESTION based solely on the information provided in the CONTEXT from the medical database.

Use only the facts from the CONTEXT when formulating your answer.

QUESTION: {question}

CONTEXT:
{context}
""".strip()

entry_template = """
 medical_department: {medical_department},
 condition_type: {condition_type},
 patient_demographics: {patient_demographics}, 
 common_symptoms: {common_symptoms},
 treatment_or_management: {treatment_or_management},
 severity: {severity},
""".strip()
    
def build_prompt(query, search_results):
    context = ""
    for doc in search_results:
        context = context + entry_template.format(**doc) + "\n\n"
    
    prompt = prompt_template.format(question=query, context=context).strip()
    return prompt

def llm(prompt, model='gpt-4o-mini'):
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}]
    )
    
    return response.choices[0].message.content

def rag(query, model='gpt-4o-mini'):
    search_results = search(query)
    prompt = build_prompt(query, search_results)
    answer = llm(prompt, model=model)
    return answer

In [14]:
query = "Given that this patient is at 22 weeks of gestation and without signs of systemic infection, how does the choice of antibiotic like nitrofurantoin compare to other options in terms of safety during pregnancy, and what factors should be considered when prescribing antibiotics to pregnant patients?"

In [15]:
answer = rag(query)
print(answer)

The context provided does not specifically mention nitrofurantoin or safety comparisons of antibiotics during pregnancy. However, in general, nitrofurantoin is considered safe for use during pregnancy, particularly in the second trimester, as it is less likely to cause adverse effects on the developing fetus compared to some other antibiotics. 

When prescribing antibiotics to pregnant patients, several factors should be considered, including:

1. **Gestational Age**: The safety of certain antibiotics can vary depending on the stage of pregnancy.
2. **Potential Side Effects**: The risk of side effects to both the mother and the fetus, as well as any known teratogenic effects of the antibiotic.
3. **Drug Efficacy**: The effectiveness of the antibiotic against the specific infection being treated.
4. **Maternal Health**: The overall health of the pregnant patient and any comorbid conditions.
5. **Infection Severity**: The severity of the infection and the urgency of treatment.

In summar

## Retrieval Evaluation

In [16]:
df_question = pd.read_csv('./data/ground_truth_retrieval_small.csv')

In [17]:
ground_truth = df_question.to_dict(orient='records')
ground_truth[0]

{'id': 0,
 'question': 'What are the potential causes of dysuria in a 23-year-old pregnant woman at 22 weeks gestation who presents with burning upon urination, and how do the symptoms of a urinary tract infection compare to other conditions like vulvovaginitis?'}

In [18]:
def hit_rate(relevance_total):
    cnt = 0

    for line in relevance_total:
        if True in line:
            cnt = cnt + 1

    return cnt / len(relevance_total)
    
def mrr(relevance_total):
    total_score = 0.0

    for line in relevance_total:
        for rank in range(len(line)):
            if line[rank] == True:
                total_score = total_score + 1 / (rank + 1)

    return total_score / len(relevance_total)

In [19]:
def minsearch_search(query):
    boost = {}

    results = index.search(
        query=query,
        filter_dict={},
        boost_dict=boost,
        num_results=10
    )

    return results

In [20]:
def evaluate(ground_truth, search_function):
    relevance_total = []

    for q in tqdm(ground_truth):
        doc_id = q['id']
        results = search_function(q)
        relevance = [d['id'] == doc_id for d in results]
        relevance_total.append(relevance)

    return {
        'hit_rate': hit_rate(relevance_total),
        'mrr': mrr(relevance_total),
    }

In [21]:
evaluate(ground_truth, lambda q: minsearch_search(q['question']))

  0%|          | 0/495 [00:00<?, ?it/s]

{'hit_rate': 0.8, 'mrr': 0.48799823633156975}

Baseline: 

Hit Rate: 80%,
MRR: 48.79%

## Finding best parameters

In [22]:
df_validation = df_question[:100]
df_test = df_question[100:]

In [23]:
def simple_optimize(param_ranges, objective_function, n_iterations=10):
    best_params = None
    best_score = float('-inf')  # Assuming we're minimizing. Use float('-inf') if maximizing.

    for _ in range(n_iterations):
        # Generate random parameters
        current_params = {}
        for param, (min_val, max_val) in param_ranges.items():
            if isinstance(min_val, int) and isinstance(max_val, int):
                
                current_params[param] = random.randint(min_val, max_val)
            else:
                current_params[param] = random.uniform(min_val, max_val)

        # Evaluate the objective function
        current_score = objective_function(current_params)
    
        # Update best if current is better
        if current_score > best_score:  # Change to > if maximizing
            best_score = current_score
            best_params = current_params
    
    return best_params, best_score

In [38]:
gt_val = df_validation.to_dict(orient = 'records')

In [39]:
evaluate(gt_val, lambda q: minsearch_search(q['question']))

  0%|          | 0/100 [00:00<?, ?it/s]

{'hit_rate': 0.86, 'mrr': 0.5577460317460318}

In [40]:
def minsearch_search(query, boost=None):
    if boost is None:
        boost = {}

    results = index.search(
        query=query,
        filter_dict={},
        boost_dict=boost,
        num_results=10
    )

    return results

In [41]:
param_ranges = {
     'medical_department': (0.0, 3.0),
     'condition_type': (0.0, 3.0),
     'patient_demographics': (0.0, 3.0),
     'common_symptoms': (0.0, 3.0),
     'treatment_or_management': (0.0, 3.0),
     'severity': (0.0, 3.0),
     'answer': (0.0, 3.0)
}

def objective(boost_params):
    def search_function(q):
        return minsearch_search(q['question'], boost_params)

    results = evaluate(gt_val, search_function)
    return results['mrr']

In [44]:
simple_optimize(param_ranges, objective, n_iterations=10)

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

({'medical_department': 1.4632450010042994,
  'condition_type': 1.5779555497958626,
  'patient_demographics': 0.41401976060622836,
  'common_symptoms': 0.08444334512926921,
  'treatment_or_management': 1.301704814209371,
  'severity': 0.25149696425642243,
  'answer': 0.05077291909600945},
 0.7250515873015874)

In [46]:
def minsearch_search_improved(query, boost=None):
    if boost is None:
       
        boost = {'medical_department': 1.46,
                 'condition_type': 1.57,
                 'patient_demographics': 0.41,
                 'common_symptoms': 0.08,
                 'treatment_or_management': 1.30,
                 'severity': 0.25,
                 'answer': 0.050
                }

    results = index.search(
        query=query,
        filter_dict={},
        boost_dict=boost,
        num_results=10
    )

    return results

evaluate(ground_truth, lambda q: minsearch_search_improved(q['question']))

  0%|          | 0/495 [00:00<?, ?it/s]

{'hit_rate': 0.8828282828282829, 'mrr': 0.6887742504409167}

Boost parameters tuned: 

Hit Rate: 88.28%,
MRR: 68.87%

## RAG Evaluation - LLM-as-a-Judge

In [47]:
prompt1_template = """
You are an expert evaluator for a Retrieval-Augmented Generation (RAG) system.
Your task is to analyze the relevance of the generated answer compared to the original answer provided.
Based on the relevance and similarity of the generated answer to the original answer, you will classify
it as "NON_RELEVANT", "PARTLY_RELEVANT", or "RELEVANT".

Here is the data for evaluation:

Original Answer: {answer_orig}
Generated Question: {question}
Generated Answer: {answer_llm}

Please analyze the content and context of the generated answer in relation to the original
answer and provide your evaluation in parsable JSON without using code blocks:

{{
  "Relevance": "NON_RELEVANT" | "PARTLY_RELEVANT" | "RELEVANT",
  "Explanation": "[Provide a brief explanation for your evaluation]"
}}
""".strip()

prompt2_template = """
You are an expert evaluator for a Retrieval-Augmented Generation (RAG) system.
Your task is to analyze the relevance of the generated answer to the given question.
Based on the relevance of the generated answer, you will classify it
as "NON_RELEVANT", "PARTLY_RELEVANT", or "RELEVANT".

Here is the data for evaluation:

Question: {question}
Generated Answer: {answer_llm}

Please analyze the content and context of the generated answer in relation to the question
and provide your evaluation in parsable JSON without using code blocks:

{{
  "Relevance": "NON_RELEVANT" | "PARTLY_RELEVANT" | "RELEVANT",
  "Explanation": "[Provide a brief explanation for your evaluation]"
}}
""".strip()

In [48]:
record = ground_truth[0]
question = record['question']
answer_orig = documents[0]['answer']
answer_llm = rag(question)
print(answer_llm)

The potential causes of dysuria (burning upon urination) in a 23-year-old pregnant woman at 22 weeks gestation may include a urinary tract infection (UTI) or conditions such as vulvovaginitis. UTIs are characterized by a burning sensation during urination, increased frequency of urination, and possibly urgency. In contrast, vulvovaginitis may also present with a burning sensation but is more likely associated with vaginal discharge, itching, or irritation. 

While both conditions can cause similar symptoms, the presence of specific additional symptoms such as discharge or itchiness may help differentiate vulvovaginitis from a UTI. Thus, a thorough evaluation is essential to determine the underlying cause of dysuria.


In [49]:
prompt = prompt1_template.format(question=question, answer_orig=answer_orig, answer_llm=answer_llm)
print(prompt)

You are an expert evaluator for a Retrieval-Augmented Generation (RAG) system.
Your task is to analyze the relevance of the generated answer compared to the original answer provided.
Based on the relevance and similarity of the generated answer to the original answer, you will classify
it as "NON_RELEVANT", "PARTLY_RELEVANT", or "RELEVANT".

Here is the data for evaluation:

Original Answer: Nitrofurantoin
Generated Question: What are the potential causes of dysuria in a 23-year-old pregnant woman at 22 weeks gestation who presents with burning upon urination, and how do the symptoms of a urinary tract infection compare to other conditions like vulvovaginitis?
Generated Answer: The potential causes of dysuria (burning upon urination) in a 23-year-old pregnant woman at 22 weeks gestation may include a urinary tract infection (UTI) or conditions such as vulvovaginitis. UTIs are characterized by a burning sensation during urination, increased frequency of urination, and possibly urgency. 

In [50]:
llm(prompt)

'{\n  "Relevance": "NON_RELEVANT",\n  "Explanation": "The generated answer does not address the original answer \'Nitrofurantoin\', which is a medication commonly used to treat urinary tract infections. Instead, it discusses potential causes of dysuria and symptoms related to UTI and vulvovaginitis, which is unrelated to the provided original answer."\n}'

In [51]:
prompt = prompt2_template.format(question=question, answer_llm=answer_llm)
print(prompt)

You are an expert evaluator for a Retrieval-Augmented Generation (RAG) system.
Your task is to analyze the relevance of the generated answer to the given question.
Based on the relevance of the generated answer, you will classify it
as "NON_RELEVANT", "PARTLY_RELEVANT", or "RELEVANT".

Here is the data for evaluation:

Question: What are the potential causes of dysuria in a 23-year-old pregnant woman at 22 weeks gestation who presents with burning upon urination, and how do the symptoms of a urinary tract infection compare to other conditions like vulvovaginitis?
Generated Answer: The potential causes of dysuria (burning upon urination) in a 23-year-old pregnant woman at 22 weeks gestation may include a urinary tract infection (UTI) or conditions such as vulvovaginitis. UTIs are characterized by a burning sensation during urination, increased frequency of urination, and possibly urgency. In contrast, vulvovaginitis may also present with a burning sensation but is more likely associated

In [52]:
llm(prompt)

'{\n  "Relevance": "RELEVANT",\n  "Explanation": "The generated answer directly addresses the question by outlining potential causes of dysuria in a pregnant woman, specifically mentioning urinary tract infections and vulvovaginitis. It also compares the symptoms of a UTI with those of vulvovaginitis, thereby providing a clear and relevant response to the inquiry about both conditions and their differential diagnosis."\n}'

## GPT4o-mini

In [53]:
model = 'gpt-4o-mini'

## Evaluation for Prompt 1

In [71]:
evaluations_1 = []

In [72]:
documents_dict = {doc['id']: doc for doc in documents}

In [73]:
for record in tqdm(ground_truth):
    question = record['question']
    record_id = record['id']
    
    if record_id not in documents_dict:
        print(f"Missing index: {record_id}")
        continue

    answer_orig = documents_dict[record_id]['answer']
    answer_llm = rag(question)

    prompt = prompt1_template.format(
        question=question,
        answer_orig=answer_orig,
        answer_llm=answer_llm
    )
    evaluation = llm(prompt, model)
    evaluation = json.loads(evaluation)

    evaluations_1.append((record, answer_orig, answer_llm, evaluation))

  0%|          | 0/495 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
df_eval = pd.DataFrame(evaluations_1, columns = ['record', 'answer_llm', 'evaluation'])

df_eval['id'] = df_eval.record.apply(lambda d: d['id'])
df_eval['question'] = df_eval.record.apply(lambda d: d['question'])

df_eval['relevance'] = df_eval.evaluation.apply(lambda d: d['Relevance'])
df_eval['explanation'] = df_eval.evaluation.apply(lambda d: d['Explanation'])

del df_eval['record']
del df_eval['evaluation']

df_eval.to_csv('llm_as_a_judge_prompt1_gpt-4o-mini.csv', index=False)

df_eval.relevance.value_counts(normalize=True)

## Evaluation for Prompt 2

In [210]:
evaluations_2 = []

In [211]:
for record in tqdm(ground_truth):

    question = record['question']
    answer_llm = rag(question)
    
    prompt = prompt2_template.format(
        question=question,
        answer_llm=answer_llm
    )
    evaluation = llm(prompt, model)
    evaluation = json.loads(evaluation)
    
    evaluations_2.append((record, answer_llm, evaluation))

  0%|          | 0/500 [00:00<?, ?it/s]

In [None]:
df_eval = pd.DataFrame(evaluations_2, columns = ['record', 'answer_llm', 'evaluation'])

df_eval['id'] = df_eval.record.apply(lambda d: d['id'])
df_eval['question'] = df_eval.record.apply(lambda d: d['question'])

df_eval['relevance'] = df_eval.evaluation.apply(lambda d: d['Relevance'])
df_eval['explanation'] = df_eval.evaluation.apply(lambda d: d['Explanation'])

del df_eval['record']
del df_eval['evaluation']

df_eval.to_csv('llm_as_a_judge_prompt2_gpt-4o-mini.csv', index=False)

df_eval.relevance.value_counts(normalize=True)

In [231]:
df_eval.relevance.value_counts(normalize=True)

relevance
RELEVANT           0.956
NON_RELEVANT       0.030
PARTLY_RELEVANT    0.014
Name: proportion, dtype: float64

## GPT4o

In [235]:
model = 'gpt-4o'

## Evaluation for Prompt 1

In [None]:
evaluations_3 = []

In [None]:
for record in tqdm(ground_truth):

    question = record['question']
    answer_orig = documents[record['id']]['answer']
    answer_llm = rag(question)
    
    prompt = prompt1_template.format(
        question=question,
        answer_orig=answer_orig,
        answer_llm=answer_llm
    )
    evaluation = llm(prompt, model)
    evaluation = json.loads(evaluation)
    
    evaluations_3.append((record, answer_orig, answer_llm, evaluation))

In [None]:
df_eval = pd.DataFrame(evaluations_3, columns = ['record', 'answer_llm', 'evaluation'])

df_eval['id'] = df_eval.record.apply(lambda d: d['id'])
df_eval['question'] = df_eval.record.apply(lambda d: d['question'])

df_eval['relevance'] = df_eval.evaluation.apply(lambda d: d['Relevance'])
df_eval['explanation'] = df_eval.evaluation.apply(lambda d: d['Explanation'])

del df_eval['record']
del df_eval['evaluation']

df_eval.to_csv('llm_as_a_judge_prompt1_gpt-4o.csv', index=False)

df_eval.relevance.value_counts(normalize=True)

## Evaluation for Prompt 2

In [236]:
evaluations_4 = []

In [237]:
for record in tqdm(ground_truth):

    question = record['question']
    answer_llm = rag(question)
    
    prompt = prompt2_template.format(
        question=question,
        answer_llm=answer_llm
    )
    evaluation = llm(prompt, model)
    evaluation = json.loads(evaluation)
    
    evaluations_4.append((record, answer_llm, evaluation))

  0%|          | 0/500 [00:00<?, ?it/s]

In [None]:
df_eval = pd.DataFrame(evaluations_4, columns = ['record', 'answer_llm', 'evaluation'])

df_eval['id'] = df_eval.record.apply(lambda d: d['id'])
df_eval['question'] = df_eval.record.apply(lambda d: d['question'])

df_eval['relevance'] = df_eval.evaluation.apply(lambda d: d['Relevance'])
df_eval['explanation'] = df_eval.evaluation.apply(lambda d: d['Explanation'])

del df_eval['record']
del df_eval['evaluation']

df_eval.to_csv('llm_as_a_judge_prompt2_gpt-4o.csv', index=False)

df_eval.relevance.value_counts(normalize=True)

In [242]:
df_eval.relevance.value_counts(normalize=True)

relevance
RELEVANT           0.872
PARTLY_RELEVANT    0.100
NON_RELEVANT       0.028
Name: proportion, dtype: float64