In [None]:
import operator
import warnings
from typing import *
import traceback
import random
from tqdm.notebook import tqdm
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor

import os
import torch
import base64
from PIL import Image
import io
from dotenv import load_dotenv
from IPython.display import Image
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_openai import ChatOpenAI
from transformers import logging
import matplotlib.pyplot as plt
import numpy as np
import re

from medrax.agent import *
from medrax.tools import *
from medrax.utils import *
from medrax.models import ModelFactory

import json
import openai
import os
import glob
import time
import logging
from datetime import datetime
from tenacity import retry, wait_exponential, stop_after_attempt

# Silence noisy third-party loggers and warnings so benchmark progress output stays readable.
warnings.filterwarnings("ignore")
_QUIET_LOGGERS = (
    "httpx",
    "openai",
    "langchain",
    "langchain_core",
    "langchain_openai",
    "langchain_xai",
    "urllib3",
    "requests",
    "pinecone",
    "cohere",
    "datasets",
    "transformers",
    "langchain.chains.retrieval_qa",
    "langchain.chains",
    "langchain.retrievers",
)
for _logger_name in _QUIET_LOGGERS:
    logging.getLogger(_logger_name).setLevel(logging.ERROR)

# Load API keys (and the optional path/device overrides read below) from a local .env file.
_ = load_dotenv()

# Directory layout. Defaults match the original author's machine; set MEDRAX_ROOT /
# MEDRAX_MODEL_DIR in the environment (or .env) to run the notebook elsewhere.
ROOT = os.environ.get("MEDRAX_ROOT", "/home/adib/MedRAX2")
PROMPT_FILE = f"{ROOT}/medrax/docs/system_prompts.txt"
MODEL_DIR = os.environ.get("MEDRAX_MODEL_DIR", "/model-weights")
BENCHMARK_DIR = f"{ROOT}/chestagentbench"

# Model / sampling settings forwarded to ModelFactory.create_model in get_agent().
model_name = "grok-4"
temperature = 0.7
top_p = 1
max_tokens = 128000

# Per-question result logs and the run summary are written here.
medrax_logs = f"{ROOT}/experiments/grok4_logs"
os.makedirs(medrax_logs, exist_ok=True)

# Device used by the local vision tools; override with MEDRAX_DEVICE (e.g. "cuda:1").
device = os.environ.get("MEDRAX_DEVICE", "cuda:0")

In [None]:
def _build_rag_config():
    """Return the RAGConfig used by the optional MedicalRAGTool."""
    return RAGConfig(
        model="command-a-03-2025",  # Chat model for generating responses
        embedding_model="embed-v4.0",  # Embedding model for the RAG system
        rerank_model="rerank-v3.5",  # Reranking model for the RAG system
        temperature=0.3,
        pinecone_index_name="medrax2",  # Name for the Pinecone index
        chunk_size=1500,
        chunk_overlap=300,
        retriever_k=7,
        local_docs_dir="rag_docs",  # Change this to the path of the documents for RAG
        huggingface_datasets=["VictorLJZ/medrax2"],  # List of HuggingFace datasets to load
        dataset_split="train",  # Which split of the datasets to use
    )


def get_tools(enable_rag=False):
    """Instantiate the benchmark's tool set, skipping any tool that fails to load.

    Every tool is built from a zero-argument factory inside a per-tool try/except,
    so one failing tool (missing weights, unavailable device, missing API key, ...)
    is reported and skipped instead of aborting the whole set. The vision tools use
    the module-level ``device`` and ``MODEL_DIR`` settings.

    Args:
        enable_rag: When True, also build the MedicalRAGTool. Off by default, which
            matches the previous behaviour where that tool was commented out.

    Returns:
        list: The successfully constructed tool instances, in factory order.
    """
    tool_factories = {
        "TorchXRayVisionClassifierTool": lambda: TorchXRayVisionClassifierTool(device=device),
        # "ArcPlusClassifierTool": lambda: ArcPlusClassifierTool(cache_dir=MODEL_DIR, device=device),
        "XRayVQATool": lambda: XRayVQATool(cache_dir=MODEL_DIR, device=device),
        "ChestXRayReportGeneratorTool": lambda: ChestXRayReportGeneratorTool(
            cache_dir=MODEL_DIR, device=device
        ),
        "XRayPhraseGroundingTool": lambda: XRayPhraseGroundingTool(
            cache_dir=MODEL_DIR, temp_dir="temp", load_in_8bit=True, device=device
        ),
    }
    if enable_rag:
        # The config is built lazily so runs without RAG never touch RAGConfig.
        tool_factories["MedicalRAGTool"] = lambda: RAGTool(config=_build_rag_config())
    tool_factories["WebBrowserTool"] = lambda: WebBrowserTool()
    # create_python_sandbox() only runs inside the loop below, so its failures are
    # caught and reported there like those of any other tool.
    tool_factories["PythonSandboxTool"] = lambda: create_python_sandbox()

    tools = []
    for tool_name, factory in tool_factories.items():
        try:
            tools.append(factory())
        except Exception as e:
            print(f"Error initializing {tool_name}: {e}")
            print(f"Skipping {tool_name}")

    return tools


def get_agent(tools):
    """Build a fresh MedRAX agent around ``tools`` together with its thread config.

    The system prompt is the "MEDICAL_ASSISTANT" entry of PROMPT_FILE, and the LLM
    is created from the module-level model settings. Each call constructs its own
    MemorySaver, so the fixed thread id "1" presumably does not collide across
    agents (assumes Agent keeps conversation state in this checkpointer — confirm).

    NOTE(review): log_dir="grok4_logs" is relative to the working directory,
    unlike medrax_logs — confirm this is intended.

    Args:
        tools: Tool instances the agent may call (as returned by ``get_tools``).

    Returns:
        tuple: ``(agent, thread)``, where ``thread`` is the config dict to pass
        to the agent's workflow.
    """
    system_prompt = load_prompts_from_file(PROMPT_FILE)["MEDICAL_ASSISTANT"]

    model = ModelFactory.create_model(
        model_name=model_name,
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
    )

    medrax_agent = Agent(
        model,
        tools=tools,
        log_tools=True,
        log_dir="grok4_logs",
        system_prompt=system_prompt,
        checkpointer=MemorySaver(),
    )
    return medrax_agent, {"configurable": {"thread_id": "1"}}


def run_medrax(agent, thread, prompt, image_urls=None):
    """Send one multimodal user message through the agent workflow and return its answer.

    Streams the workflow with ``stream_mode="updates"`` (the same approach as the
    Gradio interface) and keeps the text of the most recent AI message seen. If the
    stream yields no text, it falls back to the last AI message in the final state.

    Args:
        agent: Object exposing a compiled LangGraph ``workflow``.
        thread: LangGraph run config, e.g. ``{"configurable": {"thread_id": "1"}}``.
        prompt: Text part of the user message.
        image_urls: Optional image URLs (base64 data URLs here), attached after the text.

    Returns:
        tuple: ``(answer_text, state_repr)``. ``answer_text`` is the stripped AI text,
        or ``"No AI response found"``. ``state_repr`` is ``str()`` of the final state.
    """
    from langchain_core.messages import HumanMessage, AIMessage, AIMessageChunk

    content = [{"type": "text", "text": prompt}]
    content += [{"type": "image_url", "image_url": {"url": url}} for url in (image_urls or [])]
    messages = [HumanMessage(content=content)]

    final_content = ""
    for chunk in agent.workflow.stream({"messages": messages}, thread, stream_mode="updates"):
        for node_output in chunk.values():
            # A node may emit an empty/None update that carries no messages.
            if not node_output or "messages" not in node_output:
                continue
            for msg in node_output["messages"]:
                # AIMessageChunk subclasses AIMessage, so it must be checked first.
                if isinstance(msg, AIMessageChunk) and msg.content:
                    final_content += msg.content
                elif isinstance(msg, AIMessage) and msg.content:
                    # A full AIMessage replaces anything accumulated so far.
                    final_content = msg.content

    # After the stream is finished, get the final state for logging.
    final_state = agent.workflow.get_state(thread)

    if final_content:
        return final_content.strip(), str(final_state)

    # Fallback: get_state() returns a StateSnapshot whose channel values live in
    # `.values`; the snapshot itself has no dict-style `.get`.
    state_values = getattr(final_state, "values", None)
    if not isinstance(state_values, dict):
        state_values = final_state if isinstance(final_state, dict) else {}
    for msg in reversed(state_values.get("messages", [])):
        if isinstance(msg, AIMessage) and msg.content:
            return msg.content.strip(), str(final_state)

    return "No AI response found", str(final_state)

In [None]:
def _parse_required_figures(images):
    """Normalize a question's ``images`` field into a list of image paths.

    Accepts a list, a JSON-encoded string, or a bare path string. Any other value
    is stringified into a one-element list, and ``None`` (missing field) gives [].
    """
    if images is None:
        return []
    if isinstance(images, list):
        return images
    if isinstance(images, str):
        try:
            parsed = json.loads(images)
        except json.JSONDecodeError:
            return [images]
        # A JSON scalar (e.g. '"a.png"') must not be iterated character by character.
        return parsed if isinstance(parsed, list) else [str(parsed)]
    return [str(images)]


def _encode_image_as_data_url(path, max_size=1024, quality=85):
    """Load an image, downscale it to fit ``max_size``, and return a base64 JPEG data URL."""
    # Import locally: the notebook's global `Image` is ambiguous (`from IPython.display
    # import Image` rebinds PIL's `Image`, and later wildcard imports may rebind it
    # again), so rely on an explicit PIL reference here.
    from PIL import Image as PILImage

    with PILImage.open(path) as img:
        rgb = img if img.mode == "RGB" else img.convert("RGB")
        # Downscale large images to save tokens; thumbnail() keeps the aspect ratio.
        if rgb.width > max_size or rgb.height > max_size:
            rgb.thumbnail((max_size, max_size), PILImage.Resampling.LANCZOS)
        buffer = io.BytesIO()
        rgb.save(buffer, format="JPEG", quality=quality)
    return "data:image/jpeg;base64," + base64.b64encode(buffer.getvalue()).decode("utf-8")


def _write_detailed_log(case_id, question_id, entry):
    """Write ``entry`` to ``{medrax_logs}/{case_id}_{question_id}_detailed.json``."""
    detailed_log_file = f"{medrax_logs}/{case_id}_{question_id}_detailed.json"
    with open(detailed_log_file, "w") as f:
        # default=str keeps a non-JSON value (e.g. in metadata) from aborting the log.
        json.dump(entry, f, indent=2, default=str)


def create_multimodal_request(question_data, case_details, case_id, question_id, agent, thread):
    """Run one benchmark question through the agent and write a detailed JSON log.

    Loads the question's images from disk and re-encodes them as base64 JPEG data
    URLs. It then builds the text prompt, runs the agent via ``run_medrax``, and
    parses the ``<|X|>`` answer tag from the response.

    Args:
        question_data: Question record. Reads ``question`` and ``answer``, plus
            ``images``, ``explanation`` and ``metadata`` when present.
        case_details: Currently unused; kept for call-site compatibility.
        case_id: Case identifier, used in the log record and log file name.
        question_id: Question identifier, used in the log record and log file name.
        agent: Agent whose ``workflow`` is driven by ``run_medrax``.
        thread: LangGraph config dict passed through to ``run_medrax``.

    Returns:
        tuple: ``(final_response, model_answer)``. ``model_answer`` is an upper-case
        letter A-F, or ``""`` if no answer tag was found. If running the agent,
        parsing, or writing the log raises, an error log is written and
        ``("", "")`` is returned.
    """
    required_figures = _parse_required_figures(question_data.get("images"))

    # Encode every readable image; missing or unreadable files are reported and skipped.
    image_data_urls = []
    valid_figures = []
    for fig_path in required_figures:
        try:
            # fig_path is expected to be absolute already (main() resolves it).
            if not os.path.exists(fig_path):
                print(f"Warning: Image file not found: {fig_path}")
                continue
            data_url = _encode_image_as_data_url(fig_path)
        except Exception as e:
            print(f"Error processing image {fig_path}: {e}")
            continue
        image_data_urls.append(data_url)
        valid_figures.append(fig_path)

    figure_prompt = ""
    if valid_figures:
        figure_prompt = "The following images are provided for this question:\n" + "".join(
            f"Image {i}: {fig_path}\n" for i, fig_path in enumerate(valid_figures, start=1)
        )
    else:
        print(f"Warning: No valid images found for question {question_id}")

    prompt = (
        "Answer this question using your own vision and reasoning and then "
        "use tools to complement your reasoning. Trust your own judgement over any tools.\n\n"
        "After using tools, you MUST provide a final reasoning and answer. Do not stop after a tool call.\n\n"
        f"{question_data['question']}\n\n{figure_prompt}\n\n"
        "Your final response should end with Final answer: <|X|> where X is the letter of your choice (A, B, C, D, E, or F)."
    )

    # Shared by the success and error log records; uses .get so building the error
    # record can never raise on a missing optional field.
    input_log = {
        "messages": prompt,
        "question_data": {
            "question": question_data["question"],
            "explanation": question_data.get("explanation"),
            "metadata": question_data.get("metadata", {}),
            "figures": question_data.get("images"),
        },
        "image_paths": valid_figures,
        "images_processed": len(image_data_urls),
        "images_found": len(required_figures),
    }

    try:
        start_time = time.time()

        final_response, agent_state = run_medrax(
            agent=agent, thread=thread, prompt=prompt, image_urls=image_data_urls
        )

        # Take the LAST tag: the answer is requested at the end, and earlier text
        # may mention other options in the same <|X|> form.
        answer_tags = re.findall(r"<\|([A-F])\|>", final_response, re.IGNORECASE)
        model_answer = answer_tags[-1].upper() if answer_tags else ""
        if not answer_tags:
            print(f"Warning: Could not parse final answer from response for question {question_id}")

        duration = time.time() - start_time

        _write_detailed_log(
            case_id,
            question_id,
            {
                "case_id": case_id,
                "question_id": question_id,
                "timestamp": datetime.now().isoformat(),
                "model": model_name,
                "temperature": temperature,
                "duration": round(duration, 2),
                "usage": "",
                "cost": 0,
                "raw_response": final_response,
                "model_answer": model_answer,
                "correct_answer": question_data["answer"],
                "input": input_log,
                "agent_state": agent_state,
            },
        )
        return final_response, model_answer

    except Exception as e:
        _write_detailed_log(
            case_id,
            question_id,
            {
                "case_id": case_id,
                "question_id": question_id,
                "timestamp": datetime.now().isoformat(),
                "model": model_name,
                "temperature": temperature,
                "status": "error",
                "error": str(e),
                "cost": 0,
                "input": input_log,
            },
        )
        print(f"Error processing case {case_id}, question {question_id}: {str(e)}")
        return "", ""

In [None]:
def process_single_question(question_data, tools):
    """Process a single question and save its result to an individual JSON file.

    Builds a fresh agent, delegates to ``create_multimodal_request``, and writes
    ``{case_id}_{question_id}.json`` under ``medrax_logs``. Exceptions raised while
    doing so are caught: they are recorded in ``{case_id}_{question_id}_ERROR.json``
    and returned as an error record, so one bad question does not abort a batch.

    Args:
        question_data: Question record (reads ``case_id``, ``question_id``,
            ``images``, ``answer`` and optionally ``image_source_urls``).
        tools: Tool instances handed to ``get_agent``.

    Returns:
        dict: Result record with ``success`` and ``skipped`` flags.
    """
    # Resolve the ids before anything can fail, so the error handler below can
    # always name its log file (they must never be unbound inside `except`).
    case_id = question_data.get("case_id", "unknown_case")
    question_id = question_data.get("question_id", "unknown_question")

    try:
        # Get a fresh agent for each question
        agent, thread = get_agent(tools)

        print(f"Processing Case ID: {case_id}, Question ID: {question_id}")

        # Create case details from the question data - simplified structure
        case_details = {
            "case_id": case_id,
            "images": question_data["images"],
            "image_source_urls": question_data.get("image_source_urls", []),
        }

        final_response, model_answer = create_multimodal_request(
            question_data, case_details, case_id, question_id, agent, thread
        )

        result = {
            "case_id": case_id,
            "question_id": question_id,
            "success": True,
            "final_response": final_response,
            "model_answer": model_answer,
            "correct_answer": question_data["answer"],
            # create_multimodal_request returns "" when the agent run failed.
            "skipped": final_response is None or final_response == "",
        }

        individual_log_file = f"{medrax_logs}/{case_id}_{question_id}.json"
        with open(individual_log_file, "w") as f:
            json.dump(result, f, indent=2)

        return result

    except Exception as e:
        error_result = {
            "case_id": case_id,
            "question_id": question_id,
            "success": False,
            "error": str(e),
            "skipped": True,
        }

        error_log_file = f"{medrax_logs}/{case_id}_{question_id}_ERROR.json"
        with open(error_log_file, "w") as f:
            json.dump(error_result, f, indent=2)

        print(f"Error processing Case ID: {case_id}, Question ID: {question_id}: {str(e)}")
        return error_result

In [None]:
def _absolutize_image_paths(images, base_dir):
    """Resolve a question's ``images`` field to paths joined onto ``base_dir``.

    Accepts a list of paths, a JSON-encoded string (list or scalar), or a bare
    path string, and returns a list for any of those. Other types are returned
    unchanged.
    """
    if isinstance(images, list):
        return [os.path.join(base_dir, img) for img in images]
    if not isinstance(images, str):
        return images
    try:
        parsed = json.loads(images)
    except json.JSONDecodeError:
        # Not JSON: the string itself is a single relative path.
        return [os.path.join(base_dir, images)]
    if isinstance(parsed, list):
        return [os.path.join(base_dir, img) for img in parsed]
    return [os.path.join(base_dir, str(parsed))]


def _load_questions(metadata_path, base_dir):
    """Read one JSON question per line, resolving image paths relative to ``base_dir``."""
    questions = []
    with open(metadata_path, "r") as file:
        for raw_line in file:
            line = raw_line.strip()
            if not line:
                continue  # tolerate blank or trailing lines in the jsonl file
            question_data = json.loads(line)
            if "images" in question_data:
                question_data["images"] = _absolutize_image_paths(question_data["images"], base_dir)
            questions.append(question_data)
    return questions


def _collect_result(future, question_data):
    """Return a finished future's result, or an error record if it raised."""
    try:
        return future.result()
    except Exception as e:
        # process_single_question records its own failures; this guard only keeps an
        # unexpected escape from aborting every remaining question.
        print(
            f"Unhandled error for Case ID: {question_data.get('case_id')}, "
            f"Question ID: {question_data.get('question_id')}: {e}"
        )
        return {
            "case_id": question_data.get("case_id"),
            "question_id": question_data.get("question_id"),
            "success": False,
            "error": str(e),
            "skipped": True,
        }


def main(tools, batch_size=5, max_questions=None):
    """Run the ChestAgentBench evaluation and write a timestamped summary JSON.

    Questions are read from ``{BENCHMARK_DIR}/metadata.jsonl`` and shuffled with the
    fixed seed 23, so the order is reproducible. They are then processed
    ``batch_size`` at a time in a thread pool; each batch completes before the next
    one starts.

    Args:
        tools: Tool instances shared by every per-question agent.
        batch_size: Number of questions run concurrently per batch (default 5).
        max_questions: Optional cap applied after shuffling (e.g. 20 for debugging).

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    questions_data = _load_questions(os.path.join(BENCHMARK_DIR, "metadata.jsonl"), BENCHMARK_DIR)

    # Shuffle with seed 23
    random.seed(23)
    random.shuffle(questions_data)

    if max_questions is not None:
        questions_data = questions_data[:max_questions]
    total_questions = len(questions_data)
    num_batches = (total_questions + batch_size - 1) // batch_size

    print(f"Beginning benchmark evaluation for model {model_name}\n")
    print(f"Total questions to process: {total_questions}\n")
    print("Processing questions in parallel...\n")

    results = []
    with tqdm(total=total_questions, desc="Processing questions", unit="question") as pbar:
        for batch_num, start in enumerate(range(0, total_questions, batch_size), start=1):
            batch = questions_data[start : start + batch_size]
            print(f"Processing batch {batch_num}/{num_batches} ({len(batch)} questions)")

            with ThreadPoolExecutor(max_workers=batch_size) as executor:
                future_to_question = {
                    executor.submit(process_single_question, question_data, tools): question_data
                    for question_data in batch
                }
                for future in concurrent.futures.as_completed(future_to_question):
                    results.append(_collect_result(future, future_to_question[future]))
                    pbar.update(1)  # one tick per finished question

            print(f"Completed batch {batch_num}\n")

    successful = sum(1 for r in results if r["success"] and not r["skipped"])
    skipped = sum(1 for r in results if r["skipped"])
    errors = sum(1 for r in results if not r["success"])

    print("\nBenchmark Summary:")
    print(f"Total Questions: {total_questions}")
    print(f"Successfully Processed: {successful}")
    print(f"Skipped: {skipped}")
    print(f"Errors: {errors}")

    summary = {
        "model": model_name,
        "temperature": temperature,
        "total_questions": total_questions,
        "successful": successful,
        "skipped": skipped,
        "errors": errors,
        "timestamp": datetime.now().isoformat(),
        "results": results,
    }

    summary_file = (
        f"{medrax_logs}/benchmark_summary_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    )
    with open(summary_file, "w") as f:
        json.dump(summary, f, indent=2)

    print(f"Summary saved to: {summary_file}")

In [None]:
# Build the tool instances once; main() hands this same list to every per-question agent.
# NOTE(review): with batch_size > 1 these instances are used from several worker
# threads at the same time — confirm the tools are thread-safe.
tools = get_tools()

In [None]:
# Run the full benchmark; per-question logs and a timestamped summary are written under medrax_logs.
main(tools)