1. Reconnect to Milvus

2. Recreate the same embedder (for query embedding only)

3. Point LangChain’s Milvus wrapper at the existing collection

4. Use that as your retriever

In [54]:
import logging
import numpy as np
from typing import List, Tuple, Dict, Any, Optional
from collections import defaultdict
from pymilvus import connections
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import Milvus
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain.agents import AgentExecutor, create_react_agent
from langchain.tools import Tool
from dotenv import load_dotenv
from pathlib import Path
import sacrebleu
import os
from tqdm.notebook import tqdm

In [55]:
to_import = None
from datasets import load_dataset

from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import requests

# Set-Up

In [56]:
# Reconnect
connections.connect(host="localhost", port="19530")

# Reuse embedding model
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# Connect to existing collection
vector_store = Milvus(
    embedding_function=embedder,
    collection_name="medical_knowledge_base_v2",
    connection_args={"host": "localhost", "port": "19530"}
)

# Build retriever
retriever = vector_store.as_retriever(search_kwargs={"k": 5})

2025-05-05 16:45:23,427 - INFO - Use pytorch device_name: mps
2025-05-05 16:45:23,428 - INFO - Load pretrained SentenceTransformer: sentence-transformers/all-MiniLM-L6-v2


In [57]:
# Set up logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Open AI Key
env_path = Path('.env')
load_dotenv(dotenv_path=env_path)
# print(os.getenv("OPENAI_API_KEY"))


False

# RL

In [58]:
class RLFeedbackSystem:
    def __init__(self, rag_system):
        self.rag_system = rag_system
        self.feedback_history = []
        self.exploration_rate = 0.2
        self.min_k = 3
        self.max_k = 10

    def record_feedback(self, query: str, retrieved_docs: List[Any], answer: str, feedback_score: float):
        self.feedback_history.append({
            "query": query,
            "docs": retrieved_docs,
            "answer": answer,
            "score": feedback_score,
            "params": self.rag_system.retrieval_params.copy()
        })

        if len(self.feedback_history) % 10 == 0:
            self._update_parameters()

    def _update_parameters(self):
        if len(self.feedback_history) < 10:
            return

        recent = self.feedback_history[-10:]
        avg = sum(item["score"] for item in recent) / 10
        current_k = self.rag_system.retrieval_params.get("k", 5)

        if np.random.random() < self.exploration_rate:
            new_k = np.random.randint(self.min_k, self.max_k + 1)
        elif avg < 0.6:
            new_k = min(current_k + 1, self.max_k)
        elif avg > 0.8:
            new_k = max(current_k - 1, self.min_k)
        else:
            new_k = current_k

        self.rag_system.update_retrieval_params({"k": new_k})


# Vector Store Manager

In [59]:
class VectorStoreManager:
    def __init__(self, embedder, host, port, collection_name):
        connections.connect(host=host, port=port)
        self.vector_store = Milvus(
            embedding_function=embedder,
            collection_name=collection_name,
            connection_args={"host": host, "port": port}
        )

    def get_retriever(self, search_kwargs=None):
        if search_kwargs is None:
            search_kwargs = {"k": 5}
        return self.vector_store.as_retriever(search_kwargs=search_kwargs)


# RAG

In [60]:
class RAGSystem:
    def __init__(self, vsm: VectorStoreManager,
                 clip_model: CLIPModel, clip_proc: CLIPProcessor,
                 model_name="gpt-4o", temp=0.1):
        self.llm = ChatOpenAI(model=model_name, temperature=temp)
        self.vsm = vsm
        self.clip_model = clip_model
        self.clip_proc  = clip_proc

        self.text_rtr = self.vsm.vector_store.as_retriever(
            search_kwargs={"k":5, "filter":{"modality":"text"}}
        )

        self.prompt = ChatPromptTemplate.from_template(
            """
            Context:
            {context}

            Question: {question}
            """
        )
        self.chain = (
            {"context": self.text_rtr, "question": RunnablePassthrough()}
            | self.prompt
            | self.llm
            | StrOutputParser()
        )

    def answer_text(self, question: str) -> Tuple[str, List[Any]]:
        docs = self.text_rtr.invoke(question)
        ctx  = "\n\n".join(d.page_content for d in docs)
        ans  = self.chain.invoke({"context": ctx, "question": question})
        return ans, docs

    def answer_multimodal(self, question: str, image_input) -> Tuple[str, List[Any]]:
        docs_text = self.text_rtr.invoke(question)

        if isinstance(image_input, str):
            img = Image.open(image_input).convert("RGB")
        elif isinstance(image_input, Image.Image):
            img = image_input.convert("RGB")
        else:
            raise ValueError("answer_multimodal: image_input must be path or PIL.Image")

        inputs = self.clip_proc(images=img, return_tensors="pt")
        emb    = self.clip_model.get_image_features(**inputs).detach().cpu().numpy()[0]

        docs_img = self.vsm.vector_store.similarity_search_by_vector(
            emb.tolist(),
            k=5,
            filter={"modality":"image"},
        )

        docs = docs_text + docs_img
        ctx  = "\n\n".join(d.page_content for d in docs)
        ans  = self.chain.invoke({"context": ctx, "question": question})
        return ans, docs

# Agent

In [61]:
class AgentSystem:
    def __init__(
        self,
        rag_system: RAGSystem,
        clip_model: CLIPModel,
        clip_proc: CLIPProcessor,
    ):
        self.rag_system = rag_system
        self.clip_model  = clip_model
        self.clip_proc   = clip_proc

        self.tools = [
            Tool(
                name="SearchDocs",
                description="Retrieve documents from the knowledge base",
                func=self._search_docs
            ),
            Tool(
                name="SynthesizeAnswer",
                description="Generate a final answer from retrieved documents",
                func=self._synthesize_answer
            )
        ]
        prompt_template = PromptTemplate(
            input_variables=["input", "agent_scratchpad", "tools", "tool_names"],
            template="""You are an intelligent agent using tools to answer questions.

                    Tools:
                    {tools}

                    Use this format:
                    Question: {input}
                    Thought: I need to think about how to solve this
                    Action: the action to take, should be one of [{tool_names}]
                    Action Input: the input to the action
                    Observation: the result of the action
                    ... (this Thought/Action/Action Input/Observation can repeat N times)
                    Thought: I now know the final answer
                    Final Answer: the final answer to the original input question

                    {agent_scratchpad}"""
        )
        self.agent = create_react_agent(
            llm=self.rag_system.llm,
            tools=self.tools,
            prompt=prompt_template
        )
        self.agent_executor = AgentExecutor(
            agent=self.agent,
            tools=self.tools,
            verbose=True,
            max_iterations=10,
            max_execution_time=60.0,
            handle_parsing_errors=True,
        )

    def _search_docs(self, query: str) -> str:
        docs = self.rag_system.text_rtr.invoke(query)
        return "\n\n".join(d.page_content for d in docs)

    def _synthesize_answer(self, query: str) -> str:
        answer, _ = self.rag_system.answer_text(query)
        return answer

    def process_query(self, question: str, image_input=None) -> Dict[str,Any]:
        if image_input is not None:
            answer, docs = self.rag_system.answer_multimodal(question, image_input)
            return {"output": answer, "docs": docs}
        else:
            return self.agent_executor.invoke({"input": question})

# Comprehensive RAG

In [62]:
class ComprehensiveRAGSystem:
    def __init__(
        self,
        embedding_model="sentence-transformers/all-MiniLM-L6-v2",
        llm_model="gpt-4o",
        milvus_host="localhost",
        milvus_port="19530",
        collection_name="medical_knowledge_base_v2"
    ):
        clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        clip_proc  = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

        embedder = HuggingFaceEmbeddings(model_name=embedding_model)
        self.vsm = VectorStoreManager(embedder, milvus_host, milvus_port, collection_name)
        self.rag = RAGSystem(self.vsm, model_name=llm_model, clip_model=clip_model, clip_proc=clip_proc)
        self.rl = RLFeedbackSystem(self.rag)
        self.agent = AgentSystem(self.rag, self.rag.clip_model, self.rag.clip_proc)

    def answer_question(self, question: str, use_agent=False):
        if use_agent:
            return self.agent.process_query(question)["output"]
        return self.rag.answer_text(question)[0]

    def provide_feedback(self, query, answer, feedback_score):
        docs = self.rag.retriever.invoke(query)
        self.rl.record_feedback(query, docs, answer, feedback_score)


# Run System

In [63]:
system = ComprehensiveRAGSystem(
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",
    llm_model="gpt-4o",
    milvus_host="localhost",
    milvus_port="19530",
    collection_name="medical_knowledge_base_v2"
)

query = "What are the symptoms of pneumonia?"
rag_answer = system.answer_question(query)
print("\n🧠 RAG Answer:", rag_answer)

agent_answer = system.answer_question(query, use_agent=True)
print("\n🤖 Agent Answer:", agent_answer)


2025-05-05 16:45:26,433 - INFO - Use pytorch device_name: mps
2025-05-05 16:45:26,434 - INFO - Load pretrained SentenceTransformer: sentence-transformers/all-MiniLM-L6-v2
2025-05-05 16:45:30,414 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"



🧠 RAG Answer: Pneumonia is an infection that inflames the air sacs in one or both lungs, which can fill with fluid or pus. The symptoms of pneumonia can vary from mild to severe and may include:

1. Cough, which may produce phlegm (mucus)
2. Fever, sweating, and chills
3. Shortness of breath
4. Rapid, shallow breathing
5. Sharp or stabbing chest pain that worsens when you breathe deeply or cough
6. Loss of appetite, low energy, and fatigue
7. Nausea and vomiting, especially in young children
8. Confusion, especially in older adults
9. Headache
10. Muscle pain

It's important to seek medical attention if you suspect pneumonia, especially if you have difficulty breathing, chest pain, persistent fever, or a persistent cough.


[1m> Entering new AgentExecutor chain...[0m


2025-05-05 16:45:30,732 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


[32;1m[1;3mQuestion: What are the symptoms of pneumonia?
Thought: I need to think about how to solve this
Action: SearchDocs
Action Input: "symptoms of pneumonia"[0m[36;1m[1;3m[0m

2025-05-05 16:45:31,574 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


[32;1m[1;3mI have retrieved relevant documents about the symptoms of pneumonia. Now, I need to synthesize this information to provide a comprehensive answer.
Action: SynthesizeAnswer
Action Input: "symptoms of pneumonia"[0m

2025-05-05 16:45:36,599 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


[33;1m[1;3mPneumonia is an infection that inflames the air sacs in one or both lungs, which can fill with fluid or pus. The symptoms of pneumonia can vary from mild to severe and may include:

1. **Cough**: Often producing phlegm or mucus.
2. **Fever**: Usually high, sometimes with chills.
3. **Shortness of Breath**: Difficulty breathing or rapid breathing.
4. **Chest Pain**: Sharp or stabbing pain that worsens with deep breathing or coughing.
5. **Fatigue**: Feeling very tired or weak.
6. **Sweating and Shaking Chills**: Often accompanying fever.
7. **Nausea, Vomiting, or Diarrhea**: Sometimes present.
8. **Confusion or Changes in Mental Awareness**: Particularly in older adults.
9. **Lower than Normal Body Temperature**: Especially in older adults and people with weak immune systems.

It's important to seek medical attention if you suspect pneumonia, especially if you have difficulty breathing, chest pain, persistent fever, or a persistent cough.[0m

2025-05-05 16:45:36,878 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


[32;1m[1;3mI now know the final answer. 

Final Answer: The symptoms of pneumonia can vary from mild to severe and typically include cough (often producing phlegm), fever (usually high, sometimes with chills), shortness of breath, chest pain, fatigue, sweating and shaking chills, nausea, vomiting, or diarrhea, confusion or changes in mental awareness (particularly in older adults), and lower than normal body temperature (especially in older adults and people with weak immune systems). It's important to seek medical attention if you suspect pneumonia, especially if you experience difficulty breathing, chest pain, persistent fever, or a persistent cough.[0m

[1m> Finished chain.[0m

🤖 Agent Answer: The symptoms of pneumonia can vary from mild to severe and typically include cough (often producing phlegm), fever (usually high, sometimes with chills), shortness of breath, chest pain, fatigue, sweating and shaking chills, nausea, vomiting, or diarrhea, confusion or changes in mental aw

In [64]:
# Activate RL?
#system.provide_feedback(query, answer, feedback_score=0.9)

In [66]:
txt_embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vsm = VectorStoreManager(txt_embedder, "localhost", 19530, "medical_knowledge_base_v2")

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_proc  = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

ds = load_dataset("flaviagiammarino/vqa-rad", split="test").shuffle(42).select(range(5))
data = []
for ex in ds:
    q, a, img = ex["question"], ex["answer"], ex["image"]  # PIL.Image
    qt = "binary" if a.lower() in ("yes","no") else "open-ended"
    data.append({"question":q,"answer":a,"image":img,"qtype":qt})


2025-05-05 16:46:13,219 - INFO - Use pytorch device_name: mps
2025-05-05 16:46:13,220 - INFO - Load pretrained SentenceTransformer: sentence-transformers/all-MiniLM-L6-v2


In [67]:
rag   = RAGSystem(vsm, clip_model, clip_proc)
agent = AgentSystem(rag, clip_model, clip_proc)

In [68]:
res1 = agent.process_query("What are the symptoms of pneumonia?")
print(res1["output"])



[1m> Entering new AgentExecutor chain...[0m


2025-05-05 16:46:15,991 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


[32;1m[1;3mQuestion: What are the symptoms of pneumonia?
Thought: I need to think about how to solve this
Action: SearchDocs
Action Input: "symptoms of pneumonia"[0m[36;1m[1;3m[0m

2025-05-05 16:46:16,545 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


[32;1m[1;3mI have retrieved relevant documents about the symptoms of pneumonia. Now, I need to synthesize this information to provide a comprehensive answer.
Action: SynthesizeAnswer
Action Input: "symptoms of pneumonia"[0m

2025-05-05 16:46:24,032 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


[33;1m[1;3mPneumonia is an infection that inflames the air sacs in one or both lungs, which can fill with fluid or pus. The symptoms of pneumonia can vary from mild to severe, depending on factors such as the type of germ causing the infection, age, and overall health. Common symptoms include:

1. **Cough**: Often producing phlegm or mucus.
2. **Fever**: Usually high, sometimes with chills.
3. **Shortness of Breath**: Difficulty breathing or rapid breathing.
4. **Chest Pain**: Sharp or stabbing pain that worsens with deep breathing or coughing.
5. **Fatigue**: Feeling very tired or weak.
6. **Sweating and Shaking Chills**: Often accompanying fever.
7. **Nausea, Vomiting, or Diarrhea**: Sometimes present, especially in children.
8. **Confusion or Changes in Mental Awareness**: Particularly in older adults.
9. **Lower than Normal Body Temperature**: Especially in older adults and people with weak immune systems.

If you suspect you or someone else has pneumonia, it is important to seek

2025-05-05 16:46:24,341 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


[32;1m[1;3mI now know the final answer. 

Final Answer: The symptoms of pneumonia can vary from mild to severe and include cough (often producing phlegm), fever (usually high, sometimes with chills), shortness of breath, chest pain, fatigue, sweating and shaking chills, nausea, vomiting, or diarrhea, confusion or changes in mental awareness (particularly in older adults), and lower than normal body temperature (especially in older adults and people with weak immune systems). It is important to seek medical attention if pneumonia is suspected.[0m

[1m> Finished chain.[0m
The symptoms of pneumonia can vary from mild to severe and include cough (often producing phlegm), fever (usually high, sometimes with chills), shortness of breath, chest pain, fatigue, sweating and shaking chills, nausea, vomiting, or diarrhea, confusion or changes in mental awareness (particularly in older adults), and lower than normal body temperature (especially in older adults and people with weak immune syst

In [69]:
import re, math, numpy as np
from typing import List
from sentence_transformers import SentenceTransformer, util
from rapidfuzz import fuzz

In [70]:
_number_re = re.compile(r"[-+]?\d*\.?\d+")
_SBERT     = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

def _norm(s: str) -> str:
    return re.sub(r"[^\w\s]", " ", s.lower().strip())

def _extract_number(s: str):
    m = _number_re.search(s)
    return float(m.group()) if m else None

def vqa_soft_accuracy(pred: str, ref: str, num_tol: float = 0.05) -> float:
    """
    1) yes/no or categorical → exact match
    2) numeric → within ±tol% (floor 0.5 units)
    3) short phrase → fuzzy token-set ratio then SBERT cosine
    """
    p, r = _norm(pred), _norm(ref)
    # binary / yes-no
    if r in {"yes", "no", "true", "false"}:
        return 1.0 if p == r else 0.0
    # numeric
    rn, pn = _extract_number(r), _extract_number(p)
    if rn is not None and pn is not None:
        tol = max(abs(rn) * num_tol, 0.5)
        return 1.0 if abs(rn - pn) <= tol else 0.0
    # fuzzy for very short
    if len(r.split()) <= 4 and fuzz.token_set_ratio(p, r) >= 90:
        return 1.0
    # SBERT cosine fallback
    e1, e2 = _SBERT.encode([pred, ref], convert_to_tensor=True)
    return float(util.cos_sim(e1, e2).item())

2025-05-05 16:46:26,337 - INFO - Use pytorch device_name: mps
2025-05-05 16:46:26,337 - INFO - Load pretrained SentenceTransformer: sentence-transformers/all-MiniLM-L6-v2


In [71]:
def compute_bleu(pred: str, ref: str) -> float:
        return sacrebleu.corpus_bleu([pred], [[ref]]).score / 100.0

In [72]:
def compute_exact_match(prediction: str, reference: str) -> float:
    return 1.0 if prediction.strip().lower() == reference.strip().lower() else 0.0

In [73]:
def compute_recall_at_k(docs: List[Any], gt: str, k: int) -> float:
    return float(any(gt.lower() in d.page_content.lower() for d in docs[:k]))

In [74]:
def run_baseline_llm(rag: RAGSystem, data: List[Dict[str,Any]]):
    scores = []
    for ex in data:
        pred, _ = rag.answer_text(ex["question"])
        scores.append(vqa_soft_accuracy(pred, ex["answer"]))
    return {"avg_accuracy": float(np.mean(scores))}

In [75]:
def run_modality_ablation(rag: RAGSystem, data: List[Dict[str,Any]]):
    out = {"text":[], "mm":[]}

    for ex in data:
        qt, qa, qi = ex["question"], ex["answer"], ex["image"]
        p_txt,_ = rag.answer_text(qt)
        p_mm,_  = rag.answer_multimodal(qt, qi)
        out["text"].append(vqa_soft_accuracy(p_txt, qa))
        out["mm"].append(vqa_soft_accuracy(p_mm, qa))

    return {
        mode: {"avg_accuracy": float(np.mean(vals))}
        for mode, vals in out.items()
    }

In [76]:
print("Baseline LLM:", run_baseline_llm(rag, data))
print("Modality Ablation:", run_modality_ablation(rag, data))

2025-05-05 16:46:29,017 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2025-05-05 16:46:36,701 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2025-05-05 16:46:39,217 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-05-05 16:46:40,525 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-05-05 16:46:42,182 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


Baseline LLM: {'avg_accuracy': 0.0816872701048851}


2025-05-05 16:46:44,481 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-05-05 16:46:46,177 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2025-05-05 16:46:52,737 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-05-05 16:46:58,260 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2025-05-05 16:47:00,686 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-05-05 16:47:03,707 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-05-05 16:47:05,030 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-05-05 16:47:06,426 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-05-05 16:47:10,148 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-05-05 16:47:13,003 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


Modality Ablation: {'text': {'avg_accuracy': 0.07756358161568641}, 'mm': {'avg_accuracy': 0.07804577872157097}}


In [80]:
SHORT_PROMPT = """\
{context}

Question: {question}
**Answer in ≤4 words.
If yes/no, output only yes or no.
If numeric, output only the number.**"""

rag.prompt = ChatPromptTemplate.from_template(SHORT_PROMPT)

scores = run_baseline_llm(rag, data)          # expect jump from ~0.08 → >0.3
print(scores)

2025-05-05 16:50:59,561 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2025-05-05 16:51:03,748 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2025-05-05 16:51:06,015 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-05-05 16:51:07,382 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-05-05 16:51:10,150 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


{'avg_accuracy': 0.07649514973163604}


In [77]:
def run_prompt_refinement(rag: RAGSystem, data: List[Dict[str,Any]], max_iter=2):
    results = {i:[] for i in range(max_iter+1)}
    for ex in data:
        q, a, img = ex["question"], ex["answer"], ex["image"]
        ans0, docs = rag.answer_multimodal(q, img)
        ctx = "\n\n".join(d.page_content for d in docs)
        prev = ans0

        for i in range(max_iter+1):
            if i>0:
                prompt = (
                    f"Context:\n{ctx}\nQuestion: {q}\n"
                    f"Your previous answer: {prev}\nPlease improve."
                )
                gen = rag.llm.invoke(prompt)
                prev = gen.content if hasattr(gen, "content") else str(gen)
            results[i].append(vqa_soft_accuracy(prev, a))

    return {
        i: {"avg_accuracy": float(np.mean(scores))}
        for i, scores in results.items()
    }

In [78]:
def run_zero_shot(rag: RAGSystem, data: List[Dict[str,Any]], qtype: str):
    subset = [ex for ex in data if ex["qtype"]==qtype]
    scores = []

    for ex in subset:
        pred, _ = rag.answer_multimodal(ex["question"], ex["image"])
        scores.append(vqa_soft_accuracy(pred, ex["answer"]))

    return {"avg_accuracy": float(np.mean(scores))}

In [79]:
print("Prompt Refinement:", run_prompt_refinement(rag, data, max_iter=2))
print("Zero-Shot (binary):", run_zero_shot(rag, data, "binary"))
print("Zero-Shot (open-ended):", run_zero_shot(rag, data, "open-ended"))

2025-05-05 16:47:15,847 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2025-05-05 16:47:17,719 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2025-05-05 16:47:19,342 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2025-05-05 16:47:24,400 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2025-05-05 16:47:31,205 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2025-05-05 16:47:39,484 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2025-05-05 16:47:41,568 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-05-05 16:47:43,501 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-05-05 16:47:45,461 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-05-05 16:47:47,170 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-05-05 16:47:49,335 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-05-05 16:47:53,408 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-05-05 16:47:55,927 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-05-05 16:47:57,556 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-05-05 16:47:58,654 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "

Prompt Refinement: {0: {'avg_accuracy': 0.06146879494190216}, 1: {'avg_accuracy': 0.07250726372003555}, 2: {'avg_accuracy': 0.07340931445360184}}


2025-05-05 16:48:01,588 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-05-05 16:48:03,136 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-05-05 16:48:06,036 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


Zero-Shot (binary): {'avg_accuracy': 0.0}


2025-05-05 16:48:08,155 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2025-05-05 16:48:14,042 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Zero-Shot (open-ended): {'avg_accuracy': 0.18615511618554592}
