In [None]:
!pip install datasets

In [None]:
!pip install faiss-cpu

In [None]:
!pip install gradio

In [None]:
import os
import logging
import gc
import pandas as pd
import numpy as np
import torch
from torch import nn
from torchvision import transforms, models
from datasets import load_dataset, Dataset
from tqdm import tqdm
from google.colab import drive
from PIL import Image
import faiss
from transformers import AutoTokenizer, AutoModel, T5ForConditionalGeneration, T5Tokenizer, TrainingArguments, Trainer
import multiprocessing as mp
import datetime

# === Configuration for Flexibility ===
class Config:
    """Model selection and batch size shared by the offline pipeline steps."""

    # Image encoder. Options: "chexnet", "densenet"
    IMAGE_MODEL = "chexnet"
    # Text encoder. Options: "biobert", "clinicalbert"
    TEXT_MODEL = "biobert"
    # Generation model. Options: "flan-t5-base", "flan-t5-large"
    GEN_MODEL = "google/flan-t5-base"
    # Default number of samples per preprocessing batch
    BATCH_SIZE = 32

# === Data Preprocessing ===
class MIMICPreprocessor:
    """Turn the MIMIC-CXR dataset into per-batch image tensors and report text.

    Each slice of ``batch_size`` samples is written to ``save_dir`` as
    ``images_batch_<start>.pt`` (stacked image tensors) and
    ``text_batch_<start>.csv`` (combined report text plus the position of
    each kept sample inside its batch).
    """

    def __init__(self, save_dir="<your_local_or_drive_path>", batch_size=Config.BATCH_SIZE):
        """Mount Drive, create ``save_dir``, configure logging and transforms.

        Args:
            save_dir: Directory that receives the batch files and the log.
            batch_size: Number of dataset rows handled per batch (>= 1).

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")
        self.save_dir = save_dir
        self.batch_size = batch_size
        self.image_size = 224
        self.mount_drive()
        # Logging can only be configured once the save directory exists
        # (the file handler lives inside it), and nothing may be logged before
        # that point: a premature logging.info() installs a default handler
        # and would turn the real configuration into a no-op.
        self.setup_logging()
        logging.info(f"Mounted Google Drive and created directory: {self.save_dir}")
        self.setup_transforms()

    def mount_drive(self):
        """Mount Google Drive and make sure ``save_dir`` exists."""
        drive.mount('/content/drive')
        os.makedirs(self.save_dir, exist_ok=True)

    def setup_logging(self):
        """Send INFO-level logs to ``<save_dir>/preprocessing.log`` and the console."""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[logging.FileHandler(f"{self.save_dir}/preprocessing.log"), logging.StreamHandler()],
            # Replace any handlers that are already installed; without this
            # basicConfig silently does nothing when the root logger is
            # already configured.
            force=True,
        )

    def setup_transforms(self):
        """Build the resize -> single-channel grayscale -> tensor pipeline."""
        self.transform = transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),
        ])

    def process_text(self, findings, impression):
        """Merge findings and impression into one whitespace-normalised string.

        Falsy inputs (``None``, empty string) are replaced by a placeholder
        sentence so the output always contains both sections.
        """
        findings = findings or "No findings available."
        impression = impression or "No impression available."
        combined_text = f"FINDINGS: {findings}\nIMPRESSION: {impression}"
        return " ".join(combined_text.split())

    def process_image(self, image):
        """Return the transformed tensor for a PIL image, or ``None`` on failure."""
        try:
            if not isinstance(image, Image.Image):
                raise ValueError("Input must be a PIL Image")
            if image.mode != 'L':
                image = image.convert('L')
            return self.transform(image)
        except Exception as e:
            logging.warning(f"Failed to process image: {str(e)}")
            return None

    def preprocess_batch(self, batch, batch_idx):
        """Process every sample of a column-oriented batch.

        Args:
            batch: Mapping with parallel ``image``, ``findings`` and
                ``impression`` sequences.
            batch_idx: Identifier of the batch, used only in log messages.

        Returns:
            Dict with ``image_tensors``, ``combined_text`` and
            ``valid_indices`` (positions inside the batch that succeeded).
        """
        processed_data = {'image_tensors': [], 'combined_text': [], 'valid_indices': []}
        for idx, image in enumerate(batch['image']):
            try:
                combined_text = self.process_text(batch['findings'][idx], batch['impression'][idx])
                img_tensor = self.process_image(image)
                if img_tensor is not None:
                    processed_data['image_tensors'].append(img_tensor)
                    processed_data['combined_text'].append(combined_text)
                    processed_data['valid_indices'].append(idx)
            except Exception as e:
                logging.warning(f"Failed to process item {idx} in batch {batch_idx}: {str(e)}")
            if idx % 50 == 0:
                gc.collect()
                torch.cuda.empty_cache()
        return processed_data

    def save_preprocessed_data(self, processed_data, batch_idx):
        """Write one processed batch to disk.

        Returns:
            ``True`` when both files were written, ``False`` when saving
            failed (the error is logged, not raised).
        """
        try:
            torch.save(torch.stack(processed_data['image_tensors']), f"{self.save_dir}/images_batch_{batch_idx}.pt")
            pd.DataFrame({
                'combined_text': processed_data['combined_text'],
                'valid_indices': processed_data['valid_indices']
            }).to_csv(f"{self.save_dir}/text_batch_{batch_idx}.csv", index=False)
            logging.info(f"Saved batch {batch_idx}")
            return True
        except Exception as e:
            logging.error(f"Failed to save batch {batch_idx}: {str(e)}")
            return False

    def preprocess(self):
        """Load the dataset and preprocess its ``train`` split batch by batch."""
        print("\n=== Starting MIMIC-CXR Dataset Preprocessing ===\n")
        dataset = load_dataset("itsanmolgupta/mimic-cxr-dataset")
        total_samples = len(dataset['train'])
        print(f"✓ Loaded dataset with {total_samples} samples\n")
        processed_images, failed_images = 0, 0
        for batch_idx in tqdm(range(0, total_samples, self.batch_size), desc="Processing Batches"):
            batch = dataset['train'][batch_idx:batch_idx + self.batch_size]
            # The final slice is usually shorter than batch_size, so failures
            # must be counted against the real slice length.
            batch_len = len(batch['image'])
            processed_data = self.preprocess_batch(batch, batch_idx)
            valid_count = len(processed_data['valid_indices'])
            # Samples only count as processed once they are actually on disk.
            saved_count = valid_count if valid_count and self.save_preprocessed_data(processed_data, batch_idx) else 0
            processed_images += saved_count
            failed_images += batch_len - saved_count
            if batch_idx % (self.batch_size * 10) == 0 and batch_idx > 0:
                print(f"\nProgress: {processed_images} processed, {failed_images} failed\n")
            gc.collect()
            torch.cuda.empty_cache()
        print(f"\n=== Preprocessing Complete ===\nTotal: {processed_images + failed_images}, Success: {processed_images}, Failed: {failed_images}")

# === Knowledge Base Creation with Distributed Processing ===
class EmbeddingGenerator:
    """Build the image/text embedding knowledge base and its FAISS indexes.

    Reads the ``images_batch_*`` / ``text_batch_*`` files produced by the
    preprocessing step, embeds every sample on CPU, stores per-chunk
    intermediate results and finally merges them into ``output_dir``.
    """

    def __init__(self, preprocessed_dir="<your_path>",
                 output_dir="<your_path>"):
        """Mount Drive, create ``output_dir``, configure logging and load models."""
        self.preprocessed_dir = preprocessed_dir
        self.output_dir = output_dir
        self.device = torch.device("cpu")
        drive.mount('/content/drive', force_remount=True)
        os.makedirs(self.output_dir, exist_ok=True)
        self.setup_logging()
        self.setup_models()

    def setup_logging(self):
        """Send INFO-level logs to ``<output_dir>/embedding_generation.log`` and the console."""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[logging.FileHandler(f"{self.output_dir}/embedding_generation.log"), logging.StreamHandler()],
            # The root logger is normally configured already (by the
            # preprocessing step or the host environment); force the new
            # handlers in, otherwise this call does nothing.
            force=True,
        )

    @staticmethod
    def _remap_chexnet_state_dict(raw_state_dict):
        """Rename CheXNet checkpoint keys to torchvision ``densenet121`` names.

        Strips the leading ``module.`` / ``densenet121.`` wrappers, rewrites
        the legacy dense-layer names (``norm.1`` -> ``norm1``,
        ``conv.2`` -> ``conv2``) and drops classifier weights: the classifier
        is discarded after loading and its shape differs from the stock
        model, which would make ``load_state_dict`` raise.
        """
        import re  # local import: this notebook cell does not import re at the top

        legacy_name = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$'
        )
        remapped = {}
        for key, value in raw_state_dict.items():
            for wrapper in ('module.', 'densenet121.'):
                if key.startswith(wrapper):
                    key = key[len(wrapper):]
            if key.startswith('classifier'):
                continue
            match = legacy_name.match(key)
            if match:
                key = match.group(1) + match.group(2)
            remapped[key] = value
        return remapped

    def setup_models(self):
        """Load the image backbone and the text encoder on ``self.device``.

        Raises:
            Exception: Re-raises whatever the text model loader raised.
        """
        logging.info("Loading models on CPU...")
        if Config.IMAGE_MODEL == "chexnet":
            self.image_model = models.densenet121(pretrained=False)
            chexnet_url = "https://github.com/arnoweng/CheXNet/raw/master/model.pth.tar"
            checkpoint = torch.hub.load_state_dict_from_url(chexnet_url, progress=True, map_location=self.device)
            state_dict = self._remap_chexnet_state_dict(checkpoint['state_dict'])
            load_result = self.image_model.load_state_dict(state_dict, strict=False)
            # strict=False hides naming mismatches, so report them explicitly.
            # Classifier weights are dropped on purpose and older checkpoints
            # may lack the BatchNorm counters; neither is worth a warning.
            missing = [k for k in load_result.missing_keys
                       if not k.startswith('classifier') and not k.endswith('num_batches_tracked')]
            if missing or load_result.unexpected_keys:
                logging.warning(
                    f"CheXNet checkpoint mismatch: {len(missing)} missing, "
                    f"{len(load_result.unexpected_keys)} unexpected keys"
                )
        else:
            self.image_model = models.densenet121(pretrained=True)
        # Keep everything except the last child module (the classifier head).
        self.image_model = nn.Sequential(*list(self.image_model.children())[:-1]).to(self.device).eval()

        text_model_name = "dmis-lab/biobert-v1.1" if Config.TEXT_MODEL == "biobert" else "emilyalsentzer/Bio_ClinicalBERT"
        try:
            self.text_tokenizer = AutoTokenizer.from_pretrained(text_model_name)
            self.text_model = AutoModel.from_pretrained(text_model_name).to(self.device).eval()
            logging.info(f"{text_model_name} loaded successfully")
        except Exception as e:
            logging.error(f"Failed to load text model: {str(e)}")
            raise

    def generate_text_embedding(self, text):
        """Return the mean-pooled last hidden state for ``text`` as a ``(1, dim)`` array.

        Returns ``None`` (after logging a warning) when embedding fails.
        """
        try:
            # The tokenizer call below adds special tokens on top of the raw
            # tokens counted here (two for BERT-style models — assumption),
            # so truncation already starts above 510 raw tokens.
            if len(self.text_tokenizer.encode(text, add_special_tokens=False)) > 510:
                logging.warning(f"Text truncated: {text[:50]}...")
            inputs = self.text_tokenizer(text, padding=True, truncation=True, return_tensors="pt", max_length=512).to(self.device)
            with torch.no_grad():
                outputs = self.text_model(**inputs)
            return outputs.last_hidden_state.mean(dim=1).cpu().numpy()
        except Exception as e:
            logging.warning(f"Failed to generate text embedding: {str(e)}")
            return None

    def generate_image_embedding(self, image_tensor):
        """Return the pooled backbone features for one ``[C, 224, 224]`` tensor.

        Single-channel input is repeated to three channels. Returns a
        ``(1, features)`` array, or ``None`` (after logging) on failure.
        """
        try:
            if image_tensor.dim() != 3 or image_tensor.shape[1:] != (224, 224):
                raise ValueError("Image tensor must be [C, 224, 224]")
            if image_tensor.shape[0] == 1:
                image_tensor = image_tensor.repeat(3, 1, 1)
            image_tensor = image_tensor.unsqueeze(0).to(self.device)
            with torch.no_grad():
                embedding = self.image_model(image_tensor)
                embedding = nn.functional.avg_pool2d(embedding, kernel_size=(7, 7)).squeeze().flatten().cpu().numpy()
            return embedding.reshape(1, -1)
        except Exception as e:
            logging.warning(f"Failed to generate image embedding: {str(e)}")
            return None

    def process_batch(self, args):
        """Embed every sample of one preprocessed batch.

        Args:
            args: ``(image_file, text_file)`` names inside ``preprocessed_dir``.

        Returns:
            Dict with ``image_embeddings``/``text_embeddings`` (arrays or
            ``None`` when empty), ``valid_indices`` and ``valid_texts``;
            ``None`` when the batch could not be read.
        """
        image_file, text_file = args
        try:
            image_path = os.path.join(self.preprocessed_dir, image_file)
            text_path = os.path.join(self.preprocessed_dir, text_file)
            if not os.path.exists(image_path) or not os.path.exists(text_path):
                logging.warning(f"Missing file: {image_file} or {text_file}")
                return None
            image_tensors = torch.load(image_path, map_location=self.device, weights_only=True)
            text_data = pd.read_csv(text_path)
            if len(image_tensors) != len(text_data):
                # Process the overlapping rows instead of losing the whole batch.
                logging.warning(f"Row count mismatch between {image_file} and {text_file}")
            image_embeddings, text_embeddings, valid_indices, valid_texts = [], [], [], []
            for i in range(min(len(image_tensors), len(text_data))):
                report_text = text_data['combined_text'].iloc[i]
                img_emb = self.generate_image_embedding(image_tensors[i])
                txt_emb = self.generate_text_embedding(report_text)
                if img_emb is not None and txt_emb is not None:
                    image_embeddings.append(img_emb[0])
                    text_embeddings.append(txt_emb[0])
                    valid_indices.append(text_data['valid_indices'].iloc[i])
                    valid_texts.append(report_text)
            return {
                'image_embeddings': np.array(image_embeddings) if image_embeddings else None,
                'text_embeddings': np.array(text_embeddings) if text_embeddings else None,
                'valid_indices': valid_indices,
                'valid_texts': valid_texts
            }
        except Exception as e:
            logging.error(f"Failed to process batch {image_file}: {str(e)}")
            return None

    def create_faiss_index(self, embeddings):
        """Return a flat L2 FAISS index over ``embeddings``, or ``None`` on failure."""
        try:
            dimension = embeddings.shape[1]
            index = faiss.IndexFlatL2(dimension)
            index.add(embeddings.astype('float32'))
            return index
        except Exception as e:
            logging.error(f"Failed to create FAISS index: {str(e)}")
            return None

    @staticmethod
    def _index_by_id(filenames, prefix):
        """Map the id between ``prefix`` and the extension to the file name."""
        return {name[len(prefix):].rsplit('.', 1)[0]: name
                for name in filenames if name.startswith(prefix)}

    @staticmethod
    def _id_sort_key(file_id):
        """Sort numeric ids by value, anything else alphabetically afterwards."""
        return (0, int(file_id), '') if file_id.isdecimal() else (1, 0, file_id)

    def generate_embeddings(self, chunk_size=5):
        """Embed all preprocessed batches and build the knowledge base.

        Args:
            chunk_size: Number of batches grouped into one intermediate file.
        """
        print("\n=== Starting Knowledge Base Creation with Distributed Processing ===\n")
        chunk_size = max(1, int(chunk_size))
        listing = os.listdir(self.preprocessed_dir)
        image_files = self._index_by_id(listing, 'images_batch_')
        text_files = self._index_by_id(listing, 'text_batch_')
        # Pair files by batch id. Zipping two independently sorted lists would
        # silently pair the wrong image and text files as soon as one file of
        # a batch is missing.
        orphan_ids = image_files.keys() ^ text_files.keys()
        if orphan_ids:
            logging.warning(f"Skipping {len(orphan_ids)} batch id(s) without both an image and a text file")
        shared_ids = sorted(image_files.keys() & text_files.keys(), key=self._id_sort_key)
        matched_files = [(image_files[batch_id], text_files[batch_id]) for batch_id in shared_ids]
        print(f"Found {len(matched_files)} batches to process")

        results = []
        if matched_files:
            with mp.Pool(processes=mp.cpu_count()) as pool:
                results = list(tqdm(pool.imap(self.process_batch, matched_files), total=len(matched_files), desc="Processing Batches"))

        for chunk_start in range(0, len(matched_files), chunk_size):
            image_embs, text_embs, indices, texts = [], [], [], []
            for result in results[chunk_start:chunk_start + chunk_size]:
                if result and result['image_embeddings'] is not None:
                    image_embs.append(result['image_embeddings'])
                    text_embs.append(result['text_embeddings'])
                    indices.extend(result['valid_indices'])
                    texts.extend(result['valid_texts'])
            if image_embs:
                self.save_intermediate_results(image_embs, text_embs, indices, texts, chunk_start)
            gc.collect()
        self.merge_and_save_results()

    def save_intermediate_results(self, image_embs, text_embs, indices, texts, chunk_id):
        """Write one chunk of embeddings and texts to ``<output_dir>/intermediate``.

        Failures are logged, not raised.
        """
        try:
            intermediate_dir = os.path.join(self.output_dir, "intermediate")
            os.makedirs(intermediate_dir, exist_ok=True)
            np.save(f"{intermediate_dir}/img_emb_chunk_{chunk_id}.npy", np.vstack(image_embs))
            np.save(f"{intermediate_dir}/txt_emb_chunk_{chunk_id}.npy", np.vstack(text_embs))
            pd.DataFrame({'valid_index': indices, 'combined_text': texts}).to_csv(
                f"{intermediate_dir}/text_data_chunk_{chunk_id}.csv", index=False
            )
            logging.info(f"Saved chunk {chunk_id}")
        except Exception as e:
            logging.error(f"Failed to save chunk {chunk_id}: {str(e)}")

    def merge_and_save_results(self):
        """Merge all intermediate chunks into the final arrays, CSV and FAISS indexes.

        Returns without writing anything (after logging an error) when there
        are no intermediate chunks.

        Raises:
            RuntimeError: If a FAISS index could not be built.
        """
        intermediate_dir = os.path.join(self.output_dir, "intermediate")
        if not os.path.isdir(intermediate_dir):
            logging.error(f"No intermediate results found in {intermediate_dir}; nothing to merge")
            return
        listing = os.listdir(intermediate_dir)
        img_chunks = self._index_by_id(listing, 'img_emb_chunk_')
        txt_chunks = self._index_by_id(listing, 'txt_emb_chunk_')
        meta_chunks = self._index_by_id(listing, 'text_data_chunk_')
        # Only chunks that have all three files can be merged row-aligned.
        chunk_ids = sorted(img_chunks.keys() & txt_chunks.keys() & meta_chunks.keys(), key=self._id_sort_key)
        if not chunk_ids:
            logging.error(f"No complete chunks found in {intermediate_dir}; nothing to merge")
            return

        def chunk_path(table, chunk_id):
            return os.path.join(intermediate_dir, table[chunk_id])

        total_samples = sum(np.load(chunk_path(img_chunks, c)).shape[0] for c in chunk_ids)
        if total_samples == 0:
            logging.error("Intermediate chunks contain no samples; nothing to merge")
            return
        # Take the embedding widths from the data instead of hard-coding them,
        # so a different encoder does not break the merge.
        image_dim = np.load(chunk_path(img_chunks, chunk_ids[0])).shape[1]
        text_dim = np.load(chunk_path(txt_chunks, chunk_ids[0])).shape[1]

        image_tmp = f"{self.output_dir}/image_embeddings_temp.dat"
        text_tmp = f"{self.output_dir}/text_embeddings_temp.dat"
        final_image_embs = final_text_embs = None
        try:
            final_image_embs = np.memmap(image_tmp, dtype='float32', mode='w+', shape=(total_samples, image_dim))
            final_text_embs = np.memmap(text_tmp, dtype='float32', mode='w+', shape=(total_samples, text_dim))
            offset = 0
            for chunk_id in chunk_ids:
                img_data = np.load(chunk_path(img_chunks, chunk_id))
                txt_data = np.load(chunk_path(txt_chunks, chunk_id))
                rows = img_data.shape[0]
                final_image_embs[offset:offset + rows] = img_data
                final_text_embs[offset:offset + rows] = txt_data
                offset += rows
                logging.info(f"Merged chunk: {img_chunks[chunk_id]}")
            np.save(f"{self.output_dir}/image_embeddings.npy", final_image_embs)
            np.save(f"{self.output_dir}/text_embeddings.npy", final_text_embs)
            image_index = self.create_faiss_index(final_image_embs)
            text_index = self.create_faiss_index(final_text_embs)
            if image_index is None or text_index is None:
                raise RuntimeError("FAISS index creation failed; see the log for details")
            final_text_data = pd.concat(
                [pd.read_csv(chunk_path(meta_chunks, c)) for c in chunk_ids], ignore_index=True
            )
            if len(final_text_data) != total_samples:
                logging.warning(
                    f"Text rows ({len(final_text_data)}) do not match embedding rows ({total_samples})"
                )
            faiss.write_index(image_index, f"{self.output_dir}/image_index.faiss")
            faiss.write_index(text_index, f"{self.output_dir}/text_index.faiss")
            final_text_data.to_csv(f"{self.output_dir}/text_data.csv", index=False)
        finally:
            # Drop the memory maps before deleting their backing files, and
            # clean up even when a step above failed.
            del final_image_embs, final_text_embs
            for tmp_path in (image_tmp, text_tmp):
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
        print(f"\n=== Knowledge Base Creation Complete ===\nSaved to: {self.output_dir}")

# === Fine-Tuning Data Preparation ===
def create_pairs(row):
    """Build one query/answer fine-tuning pair from a knowledge-base row.

    Args:
        row: Mapping with a ``combined_text`` string, normally of the form
            ``"FINDINGS: ... IMPRESSION: ..."``.

    Returns:
        Dict with ``input`` (prompt), ``output`` (templated answer) and
        ``condition`` (comma-separated keyword labels).
    """
    combined_text = row['combined_text']
    if 'FINDINGS:' in combined_text:
        # Stop at the IMPRESSION marker; otherwise the impression text leaks
        # into the "Radiographic Findings" line of every answer.
        findings = combined_text.split('FINDINGS:')[1].split('IMPRESSION:')[0].strip()
    else:
        findings = combined_text
    impression = combined_text.split('IMPRESSION:')[1].strip() if 'IMPRESSION:' in combined_text else ""

    # NOTE(review): plain substring matching ignores negation, so
    # "no consolidation" also triggers the pneumonia keywords.
    lowered = combined_text.lower()
    conditions = []
    if any(k in lowered for k in {"pneumonia", "consolidation", "opacity", "infiltrate"}):
        conditions.append("pneumonia")
    if any(k in lowered for k in {"pulmonary edema", "interstitial", "effusion", "haze", "cardiomegaly"}):
        conditions.append("pulmonary edema")
    if any(k in lowered for k in {"clear", "normal", "unremarkable", "no consolidation", "no effusion"}):
        conditions.append("normal")
    if not conditions:
        conditions.append("unknown")
    condition_str = ", ".join(conditions)

    if "normal" in conditions:
        query = "What are normal chest X-ray findings?"
        response = f"""This chest X-ray demonstrates normal findings:\n1. Radiographic Findings: {findings}\n2. Key Characteristics: Clear lung fields, normal cardiac silhouette.\n3. Clinical Significance: No acute abnormalities."""
    elif "unknown" in conditions:
        query = f"Describe the findings in this chest X-ray with {condition_str}."
        response = f"""Chest X-ray Analysis:\n1. Radiographic Findings: {findings}\n2. Key Observations: {impression if impression else 'No impression provided'}.\n3. Clinical Significance: Correlate with clinical presentation."""
    else:
        query = f"What does a chest X-ray with {condition_str} look like?"
        response = f"""Analysis for {condition_str}:\n1. Radiographic Findings: {findings}\n2. Key Characteristics: {condition_str}-specific features.\n3. Clinical Significance: Suggests {condition_str}, correlate clinically."""
    return {"input": f"Query: {query}\nContext: {combined_text}\nAnswer:", "output": response, "condition": condition_str}

def validate_responses(dataset, n=5):
    """Log the first ``n`` query/response pairs of ``dataset`` for inspection.

    Args:
        dataset: Object with ``input`` and ``output`` columns (for example a
            pandas DataFrame or a mapping of lists).
        n: Maximum number of samples to log.
    """
    count = min(n, len(dataset))
    if count <= 0:
        return
    inputs, outputs = dataset['input'], dataset['output']
    for i in range(count):
        # pandas columns are indexed by label with [], which breaks on any
        # frame whose index does not start at 0 (e.g. a tail slice); use
        # positional access when it is available.
        query = inputs.iloc[i] if hasattr(inputs, 'iloc') else inputs[i]
        response = outputs.iloc[i] if hasattr(outputs, 'iloc') else outputs[i]
        logging.info(f"Sample {i}: Query: {query}\nResponse: {response}")

def prepare_finetuning_data():
    """Sample the knowledge-base text, build query/answer pairs and save an 80/20 split.

    Reads the text CSV, draws up to 2000 rows with a fixed seed, converts
    them with ``create_pairs`` and writes the training and evaluation CSVs.
    """
    drive.mount('/content/drive')
    df = pd.read_csv('<your_path>')
    # Requesting more rows than exist makes DataFrame.sample raise, so cap
    # the sample at the size of the knowledge base.
    sample_size = min(2000, len(df))
    df_sample = df.sample(sample_size, random_state=42)
    data = [create_pairs(row) for _, row in df_sample.iterrows()]
    dataset = pd.DataFrame(data)
    train_size = int(0.8 * len(dataset))
    train_data = dataset[:train_size]
    eval_data = dataset[train_size:]
    validate_responses(train_data)
    train_data.to_csv('<your_path>', index=False)
    eval_data.to_csv('<your_path>', index=False)
    print(f"Prepared {len(train_data)} training samples and {len(eval_data)} evaluation samples")

# === Fine-Tuning ===
def fine_tune_model():
    """Fine-tune the configured T5 generation model on the prepared pairs.

    Loads the train/eval CSVs, tokenizes prompts and answers, trains with
    the Hugging Face ``Trainer`` and saves model and tokenizer to a
    timestamped directory.
    """
    drive.mount('/content/drive', force_remount=True)
    train_df = pd.read_csv('<your_path>')
    eval_df = pd.read_csv('<your_path>')
    train_dataset = Dataset.from_pandas(train_df)
    eval_dataset = Dataset.from_pandas(eval_df)
    model_name = Config.GEN_MODEL
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name)

    def tokenize_function(examples):
        """Tokenize a batch and mask padding in the labels with -100."""
        inputs = tokenizer(examples["input"], padding="max_length", truncation=True, max_length=512)
        outputs = tokenizer(examples["output"], padding="max_length", truncation=True, max_length=512)
        # -100 marks positions the loss must ignore.
        labels = [
            [token if token != tokenizer.pad_token_id else -100 for token in sequence]
            for sequence in outputs["input_ids"]
        ]
        return {"input_ids": inputs["input_ids"], "attention_mask": inputs["attention_mask"], "labels": labels}

    max_train_samples, max_eval_samples = 1600, 400
    if len(train_dataset) > max_train_samples:
        train_dataset = train_dataset.select(range(max_train_samples))
    if len(eval_dataset) > max_eval_samples:
        eval_dataset = eval_dataset.select(range(max_eval_samples))
    # Drop every original column: remove_unused_columns is disabled below, so
    # any leftover column would be passed straight to the model.
    tokenized_train = train_dataset.map(tokenize_function, batched=True, batch_size=200, remove_columns=train_dataset.column_names)
    tokenized_eval = eval_dataset.map(tokenize_function, batched=True, batch_size=200, remove_columns=eval_dataset.column_names)
    gc.collect()

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = f"./flan-t5-finetuned-mediquery_cpu_{timestamp}"
    model_save_path = f"<your_path>/flan-t5-finetuned-mediquery_cpu_{timestamp}"
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=5,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        # With at most 1600 samples, batch size 2 and 8 accumulation steps
        # the whole run has about 500 optimizer steps. A fixed 500-step
        # warmup would cover the entire run, so the learning rate would never
        # leave the ramp; warm up over the first 10% instead.
        warmup_ratio=0.1,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=10,
        eval_strategy="steps",
        eval_steps=100,
        save_strategy="steps",
        save_steps=100,
        load_best_model_at_end=True,
        fp16=False,
        gradient_accumulation_steps=8,
        dataloader_num_workers=0,
        seed=42,
        learning_rate=2e-5,
        lr_scheduler_type="cosine",
        remove_unused_columns=False,
        label_names=["labels"],
        optim="adamw_torch"
    )
    trainer = Trainer(model=model, args=training_args, train_dataset=tokenized_train, eval_dataset=tokenized_eval)
    trainer.train()
    gc.collect()
    torch.cuda.empty_cache()
    os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
    model.save_pretrained(model_save_path)
    tokenizer.save_pretrained(model_save_path)
    print(f"Fine-tuning completed and model saved to {model_save_path}")

# === Execution ===
# Run the offline pipeline end to end. The order matters: each step reads
# the files written by the step before it.
if __name__ == "__main__":
    # Drop handlers already attached to the root logger before the pipeline
    # classes set up their own logging.
    logging.getLogger().handlers.clear()
    # Step 1: write per-batch image tensors and report text.
    preprocessor = MIMICPreprocessor()
    preprocessor.preprocess()
    # Step 2: embed the batches and build the FAISS knowledge base.
    embedding_generator = EmbeddingGenerator()
    embedding_generator.generate_embeddings()
    # Step 3: derive query/answer pairs, then fine-tune the generator on them.
    prepare_finetuning_data()
    fine_tune_model()

In [None]:
import os
import logging
import numpy as np
import torch
from torch import nn
import pandas as pd
from torchvision import transforms, models
from PIL import Image
import faiss
from transformers import AutoTokenizer, AutoModel, T5ForConditionalGeneration, T5Tokenizer
import gradio as gr
import cv2
import traceback
from datetime import datetime
from sklearn.metrics import precision_score, recall_score, f1_score
import re
import random
import functools
import gc
from collections import OrderedDict
from google.colab import drive
import json
import sys
import time
from tqdm.auto import tqdm
import warnings
import matplotlib.pyplot as plt

# Suppress unnecessary warnings
warnings.filterwarnings("ignore", category=UserWarning)

# === Configuration ===
class Config:
    """Central settings for the MediQuery application."""

    # --- Models ---
    # Image encoder. Options: "chexnet", "densenet"
    IMAGE_MODEL = "chexnet"
    # Text encoder. Options: "biobert", "clinicalbert"
    TEXT_MODEL = "biobert"
    # Base generation model
    GEN_MODEL = "flan-t5-base-finetuned"

    # --- Resource management ---
    CACHE_SIZE = 200
    # Cache entry lifetime in seconds (one hour)
    CACHE_EXPIRY_TIME = 60 * 60
    # Load models lazily
    LAZY_LOADING = True
    # Use half precision for models if available
    USE_HALF_PRECISION = True

    # --- Feature flags ---
    # Detailed debugging output
    DEBUG = True
    # PHI detection
    PHI_DETECTION_ENABLED = True
    # Anatomical mapping
    ANATOMY_MAPPING_ENABLED = True

    # --- Thresholds and parameters ---
    # Threshold for flagging low confidence
    CONFIDENCE_THRESHOLD = 0.4
    # Number of items to retrieve from the knowledge base
    TOP_K_RETRIEVAL = 30
    # Maximum documents to include in context
    MAX_CONTEXT_DOCS = 5

    # --- Advanced retrieval ---
    # Dynamically adjust reranking weights
    DYNAMIC_RERANKING = True
    # Penalty for duplicate content
    DIVERSITY_PENALTY = 0.1

    # --- Performance ---
    # Batch size for processing
    BATCH_SIZE = 4
    # Optimize memory usage
    OPTIMIZE_MEMORY = True
    # Use caching for embeddings and queries
    USE_CACHING = True

    # --- UI ---
    IMAGE_HEIGHT = 400
    IMAGE_WIDTH = 400
    EXAMPLES_TO_SHOW = 8
    # Options: "default", "dark", "light"
    THEME = "default"

    # --- Paths ---
    DEFAULT_KNOWLEDGE_BASE_DIR = "<your_path>"
    DEFAULT_MODEL_PATH = "<your_path>"
    LOG_DIR = "./logs"

    # --- Embeddings ---
    # Options: "avg", "weighted_avg", "cls", "pooled"
    EMBEDDING_AGGREGATION = "weighted_avg"
    # Normalize embeddings to unit length
    EMBEDDING_NORMALIZE = True

    # --- Error recovery ---
    # Maximum retry attempts for model operations
    MAX_RETRIES = 3
    # Seconds to wait between retries
    RECOVERY_WAIT_TIME = 1

# Set up logging: one timestamped file per session under Config.LOG_DIR plus
# console output.
os.makedirs(Config.LOG_DIR, exist_ok=True)
logging.basicConfig(
    level=logging.DEBUG if Config.DEBUG else logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(os.path.join(Config.LOG_DIR, f"mediquery_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")),
        logging.StreamHandler()
    ],
    # basicConfig does nothing when the root logger already has handlers
    # (for example after an earlier notebook cell configured it); force the
    # handlers above to replace them.
    force=True,
)
logger = logging.getLogger("MediQuery")

def debug_print(msg):
    """Log ``msg`` at DEBUG level and echo it to stdout when debugging is on."""
    if not Config.DEBUG:
        return
    logger.debug(msg)
    print(f"DEBUG: {msg}")

# === Helper Functions for Conditions ===
def get_mimic_cxr_conditions():
    """Return a new list with the condition labels used for MIMIC-CXR."""
    labels = (
        "atelectasis",
        "cardiomegaly",
        "consolidation",
        "edema",
        "enlarged cardiomediastinum",
        "fracture",
        "lung lesion",
        "lung opacity",
        "no finding",
        "pleural effusion",
        "pleural other",
        "pneumonia",
        "pneumothorax",
        "support devices",
    )
    return list(labels)

def get_condition_synonyms():
    """Return a new mapping from condition name to its list of synonyms."""
    synonym_table = (
        ("atelectasis", ("atelectatic change", "collapsed lung", "lung collapse")),
        ("cardiomegaly", ("enlarged heart", "cardiac enlargement", "heart enlargement")),
        ("consolidation", ("airspace opacity", "air-space opacity", "alveolar opacity")),
        ("edema", ("pulmonary edema", "fluid overload", "vascular congestion")),
        ("fracture", ("broken bone", "bone fracture", "rib fracture")),
        ("lung opacity", ("pulmonary opacity", "opacification", "lung opacification")),
        ("pleural effusion", ("pleural fluid", "fluid in pleural space", "effusion")),
        ("pneumonia", ("pulmonary infection", "lung infection", "bronchopneumonia")),
        ("pneumothorax", ("air in pleural space", "collapsed lung", "ptx")),
        ("support devices", ("tube", "line", "catheter", "pacemaker", "device")),
    )
    return {condition: list(aliases) for condition, aliases in synonym_table}

def get_anatomical_regions():
    """Return a new mapping from region key to its description and condition list."""

    def region(description, *conditions):
        # Every entry gets its own list so callers can mutate one region
        # without affecting another.
        return {"description": description, "conditions": list(conditions)}

    upper_field = ("pneumonia", "lung lesion", "pneumothorax", "atelectasis")
    lower_field = ("pneumonia", "pleural effusion", "atelectasis")

    regions = {}
    regions["upper_right_lung"] = region("Upper right lung field", *upper_field)
    regions["upper_left_lung"] = region("Upper left lung field", *upper_field)
    regions["middle_right_lung"] = region(
        "Middle right lung field", "pneumonia", "lung opacity", "atelectasis"
    )
    regions["lower_right_lung"] = region("Lower right lung field", *lower_field)
    regions["lower_left_lung"] = region("Lower left lung field", *lower_field)
    regions["heart"] = region("Cardiac silhouette", "cardiomegaly", "enlarged cardiomediastinum")
    regions["hilar"] = region("Hilar regions", "enlarged cardiomediastinum", "adenopathy")
    regions["costophrenic_angles"] = region("Costophrenic angles", "pleural effusion", "pneumothorax")
    regions["spine"] = region("Spine", "fracture", "degenerative changes")
    regions["diaphragm"] = region("Diaphragm", "elevated diaphragm", "flattened diaphragm")
    return regions

# === PHI Detection and Anonymization ===
def detect_phi(text):
    """Detect potential PHI (Protected Health Information) in text.

    Args:
        text: String to scan.

    Returns:
        Dict mapping a PHI type (``name``, ``mrn``, ``ssn``, ``date``,
        ``phone``, ``email``, ``address``) to the list of full matched
        substrings; types without a match are omitted.
    """
    # Heuristic patterns: they favour recall and will also flag harmless
    # text (e.g. any two capitalised words count as a name).
    patterns = {
        'name': r'\b[A-Z][a-z]+ [A-Z][a-z]+\b',
        'mrn': r'\b[A-Z]{0,3}[0-9]{4,10}\b',
        'ssn': r'\b[0-9]{3}[-]?[0-9]{2}[-]?[0-9]{4}\b',
        'date': r'\b(0?[1-9]|1[0-2])[\/\-](0?[1-9]|[12]\d|3[01])[\/\-](19|20)\d{2}\b',
        'phone': r'\b(\+\d{1,2}\s?)?\(?\d{3}\)?[\s.-]?\d{3}[\s.-]?\d{4}\b',
        # "[A-Za-z]" rather than "[A-Z|a-z]": inside a character class the
        # bar is a literal character, not an alternation.
        'email': r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',
        'address': r'\b\d+\s+[A-Z][a-z]+\s+[A-Z][a-z]+\.?\b'
    }

    phi_detected = {}
    for phi_type, pattern in patterns.items():
        # re.findall returns the capture groups (tuples or empty strings) for
        # patterns that contain groups, such as date and phone; take the
        # whole match so callers always receive the matched text.
        matches = [found.group(0) for found in re.finditer(pattern, text)]
        if matches:
            phi_detected[phi_type] = matches

    return phi_detected

def anonymize_text(text):
    """Replace potential PHI in ``text`` with ``[REDACTED]``.

    Returns an empty string for falsy input, the text unchanged when PHI
    detection is disabled, and the original text if redaction itself fails
    (the error is reported through ``debug_print``).
    """
    if not text:
        return ""

    if not Config.PHI_DETECTION_ENABLED:
        return text

    try:
        phi_detected = detect_phi(text)

        # Keep only real, non-empty substrings: replacing "" would insert the
        # marker between every character, and non-string entries cannot be
        # passed to str.replace.
        spans = {
            match
            for matches in phi_detected.values()
            for match in matches
            if isinstance(match, str) and match
        }

        # Replace the longest matches first so that a short match (a bare
        # year caught as an MRN, say) cannot break up a longer one (the full
        # date) and leave part of it unredacted.
        anonymized = text
        for span in sorted(spans, key=lambda s: (-len(s), s)):
            anonymized = anonymized.replace(span, "[REDACTED]")

        return anonymized
    except Exception as e:
        # NOTE(review): this fails open — on error the unredacted text is
        # returned. Confirm that is acceptable for the callers.
        debug_print(f"Error in anonymize_text: {str(e)}")
        return text

# === LRU Cache Implementation with Enhanced Features ===
class LRUCache:
    """LRU (Least Recently Used) cache with per-entry expiry and size statistics."""

    # Upper bound for the estimated total payload, enforced only when
    # Config.OPTIMIZE_MEMORY is set (1 GB).
    MAX_TOTAL_BYTES = 1e9

    def __init__(self, capacity=Config.CACHE_SIZE, expiry_time=Config.CACHE_EXPIRY_TIME):
        """Create an empty cache.

        Args:
            capacity: Maximum number of entries.
            expiry_time: Entry lifetime in seconds.
        """
        self.cache = OrderedDict()
        self.capacity = capacity
        self.expiry_time = expiry_time  # in seconds
        self.timestamps = {}
        self.size_tracking = {
            "current_size_bytes": 0,
            "max_size_bytes": 0,
            "items_evicted": 0,
            "cache_hits": 0,
            "cache_misses": 0
        }

    def get(self, key):
        """Return the cached value for ``key`` or ``None`` on a miss or expiry."""
        if key not in self.cache:
            self.size_tracking["cache_misses"] += 1
            return None

        if self.is_expired(key):
            self._remove_with_tracking(key)
            self.size_tracking["cache_misses"] += 1
            return None

        self.size_tracking["cache_hits"] += 1
        self.cache.move_to_end(key)  # mark as most recently used
        return self.cache[key]

    def put(self, key, value):
        """Store ``value`` under ``key``, evicting old entries when needed."""
        if self.capacity <= 0:
            # Nothing can be stored; without this guard the eviction loop
            # below would have nothing to evict and could never terminate.
            return

        value_size = self._estimate_size(value)

        # Replacing an entry: drop the old value and its size first (this is
        # an overwrite, not an eviction).
        if key in self.cache:
            old_value = self.cache.pop(key)
            self.size_tracking["current_size_bytes"] -= self._estimate_size(old_value)

        # Make space. Stop once the cache is empty so that a single value
        # larger than the byte limit cannot cause an endless loop.
        while self.cache and (
            len(self.cache) >= self.capacity or (
                Config.OPTIMIZE_MEMORY and
                self.size_tracking["current_size_bytes"] + value_size > self.MAX_TOTAL_BYTES
            )
        ):
            self._evict_least_recently_used()

        self.cache[key] = value
        self.timestamps[key] = datetime.now().timestamp()
        self.size_tracking["current_size_bytes"] += value_size

        if self.size_tracking["current_size_bytes"] > self.size_tracking["max_size_bytes"]:
            self.size_tracking["max_size_bytes"] = self.size_tracking["current_size_bytes"]

    def is_expired(self, key):
        """Return True when ``key`` has no timestamp or is older than ``expiry_time``."""
        if key not in self.timestamps:
            return True

        current_time = datetime.now().timestamp()
        return (current_time - self.timestamps[key]) > self.expiry_time

    def _evict_least_recently_used(self):
        """Remove the least recently used entry and update the statistics."""
        if not self.cache:
            return

        # The first key of the OrderedDict is the oldest. The entry must
        # still be in the cache when _remove_with_tracking runs, otherwise
        # neither the size nor the eviction counter is updated.
        oldest_key = next(iter(self.cache))
        self._remove_with_tracking(oldest_key)

    def _remove_with_tracking(self, key):
        """Remove ``key`` (if present) and account for its size and eviction."""
        if key in self.cache:
            value = self.cache.pop(key)
            self.size_tracking["current_size_bytes"] -= self._estimate_size(value)
            self.size_tracking["items_evicted"] += 1

        self.timestamps.pop(key, None)

    def remove(self, key):
        """Remove item from cache"""
        self._remove_with_tracking(key)

    def clear(self):
        """Remove all entries and reset the current size (other counters are kept)."""
        self.cache.clear()
        self.timestamps.clear()
        self.size_tracking["current_size_bytes"] = 0

    def get_stats(self):
        """Return a dict with size, item count, capacity, evictions and hit rate."""
        hits = self.size_tracking["cache_hits"]
        misses = self.size_tracking["cache_misses"]
        return {
            "size_bytes": self.size_tracking["current_size_bytes"],
            "max_size_bytes": self.size_tracking["max_size_bytes"],
            "items": len(self.cache),
            "capacity": self.capacity,
            "items_evicted": self.size_tracking["items_evicted"],
            # The small constant avoids division by zero before any lookup.
            "hit_rate": hits / (hits + misses + 1e-8)
        }

    def _estimate_size(self, obj):
        """Estimate the memory size of ``obj`` in bytes (approximate)."""
        if obj is None:
            return 0

        if isinstance(obj, np.ndarray):
            return obj.nbytes
        elif isinstance(obj, torch.Tensor):
            return obj.element_size() * obj.nelement()
        elif isinstance(obj, (str, bytes)):
            return len(obj)
        elif isinstance(obj, (list, tuple)):
            return sum(self._estimate_size(x) for x in obj)
        elif isinstance(obj, dict):
            return sum(self._estimate_size(k) + self._estimate_size(v) for k, v in obj.items())
        else:
            # Fallback - rough, shallow estimate
            return sys.getsizeof(obj)

# === Improved Lazy Model Loading ===
class LazyModel:
    """Lazy loading wrapper for models with proper method forwarding and error recovery.

    The wrapped model is built with ``model_class.from_pretrained`` the first
    time it is needed, optionally converted to half precision, moved to
    ``device`` and put in evaluation mode. Until then ``is_loaded()`` is False.
    """
    def __init__(self, model_name, model_class, device, **kwargs):
        """Record how to build the model; nothing is loaded here.

        Args:
            model_name: Identifier passed to ``model_class.from_pretrained``.
            model_class: Class exposing a ``from_pretrained`` constructor.
            device: Device the loaded model is moved to. Its ``.type``
                attribute is read when deciding on half precision.
            **kwargs: Extra keyword arguments forwarded to ``from_pretrained``.
        """
        self.model_name = model_name
        self.model_class = model_class
        self.device = device
        self.kwargs = kwargs
        self._model = None
        self.last_error = None
        self.last_used = datetime.now()
        debug_print(f"LazyModel initialized for {model_name}")

    def _ensure_loaded(self, retries=Config.MAX_RETRIES):
        """Return the loaded model, loading it first if necessary.

        Loading is attempted up to ``retries`` times (at least once), sleeping
        ``Config.RECOVERY_WAIT_TIME`` seconds between attempts. The model is
        only published on ``self._model`` once it has been fully configured,
        so a failure part-way through never leaves a half-initialised model
        that later calls would mistake for a loaded one.

        Raises:
            RuntimeError: If every attempt fails; the last error is chained
                as the cause and its text is kept in ``self.last_error``.
        """
        if self._model is None:
            debug_print(f"Lazy loading model: {self.model_name}")
            # A non-positive retry count would otherwise skip loading and return None.
            attempts = max(1, retries)
            for attempt in range(attempts):
                model = None
                try:
                    model = self.model_class.from_pretrained(self.model_name, **self.kwargs)

                    # Apply memory optimizations
                    if Config.OPTIMIZE_MEMORY:
                        # Convert to half precision if available and enabled
                        if Config.USE_HALF_PRECISION and self.device.type == 'cuda' and hasattr(model, 'half'):
                            model = model.half()
                            debug_print(f"Using half precision for {self.model_name}")

                    model = model.to(self.device)
                    model.eval()  # Set to evaluation mode
                except Exception as e:
                    # Release the partially built model before waiting or giving up.
                    model = None
                    self.last_error = str(e)
                    debug_print(f"Error loading model {self.model_name} (attempt {attempt+1}/{attempts}): {str(e)}")
                    if attempt < attempts - 1:
                        # Wait before retrying
                        time.sleep(Config.RECOVERY_WAIT_TIME)
                    else:
                        raise RuntimeError(
                            f"Failed to load model {self.model_name} after {attempts} attempts: {str(e)}"
                        ) from e
                else:
                    self._model = model
                    debug_print(f"Model {self.model_name} loaded successfully")
                    self.last_error = None
                    break

        # Update last used timestamp
        self.last_used = datetime.now()
        return self._model

    def __call__(self, *args, **kwargs):
        """Call the underlying model, loading it on first use."""
        model = self._ensure_loaded()
        return model(*args, **kwargs)

    # Forward common model methods
    def generate(self, *args, **kwargs):
        """Forward ``generate`` to the model with one reload-and-retry.

        If the first call raises, the model is unloaded, loaded again and
        ``generate`` is called once more; an error from that second call
        propagates to the caller.
        """
        model = self._ensure_loaded()
        try:
            return model.generate(*args, **kwargs)
        except Exception as e:
            # If generation fails, try reloading the model once
            debug_print(f"Generation failed, reloading model: {str(e)}")
            self.unload()
            model = self._ensure_loaded()
            return model.generate(*args, **kwargs)

    def to(self, device):
        """Remember ``device`` and move the model there if it is already loaded.

        Returns:
            This wrapper, so calls can be chained.
        """
        self.device = device
        if self._model is not None:
            self._model = self._model.to(device)
        return self

    def eval(self):
        """Put the model in evaluation mode if it is loaded; return this wrapper."""
        if self._model is not None:
            self._model = self._model.eval()
        return self

    def unload(self):
        """Drop the loaded model, run garbage collection and empty the CUDA cache."""
        if self._model is not None:
            debug_print(f"Unloading model {self.model_name}")
            self._model = None
            # Force garbage collection
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    def is_loaded(self):
        """Return True when the underlying model has been loaded."""
        return self._model is not None

# === MIMIC-CXR Index Builder ===
def build_mimic_cxr_index(mimic_meta_path=None):
    """Build a lookup index from the ``text_data.csv`` next to ``mimic_meta_path``.

    Each row of the CSV becomes a pseudo image ``img_<row index>`` whose
    ``combined_text`` is scanned (case-insensitively) for every known
    condition name and, failing that, for any of its synonyms.

    Args:
        mimic_meta_path: Path whose directory is searched for
            ``text_data.csv``. ``None`` yields an empty index.

    Returns:
        A dict with two keys. ``"by_image_id"`` maps each pseudo id to a
        record with ``findings``, ``impression``, ``subject_id`` and
        ``study_id``. ``"by_condition"`` maps each condition to the list of
        pseudo ids whose text mentions it. Both are empty when the CSV is
        missing or cannot be processed; this function never raises.
    """
    debug_print("Building MIMIC-CXR index (synthetic version)")

    if mimic_meta_path is None:
        # Previously this case only worked by tripping a TypeError in os.path.dirname.
        debug_print("No MIMIC-CXR metadata path provided, using fully synthetic index")
        return {"by_image_id": {}, "by_condition": {}}

    try:
        # Check if text_data.csv exists and use it instead
        text_data_path = os.path.join(os.path.dirname(mimic_meta_path), "text_data.csv")
        if not os.path.exists(text_data_path):
            debug_print("Neither mimic_meta.csv nor text_data.csv found, using fully synthetic index")
            return {"by_image_id": {}, "by_condition": {}}

        debug_print("Using text_data.csv instead of mimic_meta.csv")
        meta_df = pd.read_csv(text_data_path)
        debug_print(f"Loaded text_data with {len(meta_df)} rows")

        # Extract metadata from text_data.csv
        # Assuming text_data.csv has a combined_text column
        conditions = get_mimic_cxr_conditions()
        condition_synonyms = get_condition_synonyms()

        # Build index by condition
        index_by_condition = {condition: [] for condition in conditions}
        index_by_id = {}

        # Use row indices as pseudo-image IDs
        for idx, row in meta_df.iterrows():
            pseudo_id = f"img_{idx}"
            text = row.get('combined_text', '')
            # Empty CSV cells arrive as NaN (a float). Without this coercion a
            # single blank row raises on .lower() and discards the whole index.
            if not isinstance(text, str):
                text = '' if pd.isna(text) else str(text)

            # Store in index_by_id
            index_by_id[pseudo_id] = {
                'findings': text,
                'impression': '',  # We don't have separate impression
                'subject_id': pseudo_id,
                'study_id': pseudo_id
            }

            # Check for conditions in the text
            text_lower = text.lower()
            for condition in conditions:
                if condition in text_lower:
                    index_by_condition[condition].append(pseudo_id)
                    continue

                # Fall back to synonyms; one hit is enough, so each id is added once.
                if condition in condition_synonyms:
                    if any(synonym in text_lower for synonym in condition_synonyms[condition]):
                        index_by_condition[condition].append(pseudo_id)

        # Debug output
        for condition in conditions:
            count = len(index_by_condition[condition])
            debug_print(f"Found {count} examples for {condition}")

        return {
            "by_image_id": index_by_id,
            "by_condition": index_by_condition
        }

    except Exception as e:
        # Best effort: callers get an empty index rather than an exception.
        debug_print(f"Error building MIMIC-CXR index: {str(e)}")
        debug_print(traceback.format_exc())
        return {"by_image_id": {}, "by_condition": {}}

# === Performance Monitoring Decorator ===
def performance_monitor(func):
    """Decorator that times ``func`` and records the elapsed seconds.

    When the first positional argument has a ``performance_metrics``
    attribute (the usual case for a decorated method), each call appends its
    elapsed time to ``performance_metrics[func.__name__]``. This happens for
    successful and failing calls alike; exceptions are logged and re-raised
    unchanged. Calls that take longer than one second are logged, together
    with the change in allocated CUDA memory when CUDA is available.
    """
    func_name = func.__name__

    def _record(instance, elapsed):
        """Append ``elapsed`` to the instance's metrics for this function, if it keeps any."""
        metrics = getattr(instance, 'performance_metrics', None)
        if metrics is not None:
            metrics.setdefault(func_name, []).append(elapsed)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Treat the first positional argument as the candidate instance. It is
        # only ever probed with getattr, never truth-tested: objects such as
        # arrays have no well-defined truth value, and a falsy instance must
        # still get its metrics recorded.
        instance = args[0] if args else None

        cuda_available = torch.cuda.is_available()
        # perf_counter is monotonic, so system clock adjustments cannot skew timings.
        start_time = time.perf_counter()
        start_memory = 0
        if cuda_available:
            torch.cuda.synchronize()
            start_memory = torch.cuda.memory_allocated()

        try:
            result = func(*args, **kwargs)
        except Exception as e:
            # Log performance even on failure
            elapsed = time.perf_counter() - start_time
            debug_print(f"ERROR in {func_name}: {str(e)}, took {elapsed:.2f} seconds")
            _record(instance, elapsed)
            raise  # Re-raise the exception

        end_time = time.perf_counter()
        memory_diff = 0
        if cuda_available:
            torch.cuda.synchronize()
            end_memory = torch.cuda.memory_allocated()
            memory_diff = (end_memory - start_memory) / (1024 * 1024)  # Convert to MB

        elapsed = end_time - start_time
        _record(instance, elapsed)

        # Log performance info
        if elapsed > 1.0:  # Only log if took more than 1 second
            memory_str = f", memory usage: {memory_diff:.1f} MB" if cuda_available else ""
            debug_print(f"Function {func_name} took {elapsed:.2f} seconds{memory_str}")

        return result

    return wrapper

# === Main RAG System Class ===
class MediQueryRAG:
    def __init__(self, knowledge_base_dir=Config.DEFAULT_KNOWLEDGE_BASE_DIR,
                 finetuned_model_path=Config.DEFAULT_MODEL_PATH):
        """Set up state, build the MIMIC-CXR index, then load the knowledge base and models.

        Args:
            knowledge_base_dir: Directory holding the embeddings, FAISS
                indexes and ``text_data.csv``.
            finetuned_model_path: Path or identifier from which the generator
                tokenizer and model are loaded.

        Raises:
            RuntimeError: Propagated from ``load_knowledge_base`` or
                ``load_models`` when loading fails.
        """
        self.knowledge_base_dir = knowledge_base_dir
        self.finetuned_model_path = finetuned_model_path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        debug_print(f"Using device: {self.device}")

        # Initialize evaluation metrics
        self.eval_metrics = {
            "precision": [],
            "recall": [],
            "f1": []
        }

        # Initialize performance metrics (filled in by @performance_monitor)
        self.performance_metrics = {}

        # Initialize cache
        self.cache = LRUCache(capacity=Config.CACHE_SIZE, expiry_time=Config.CACHE_EXPIRY_TIME)

        # Confidence calibration data
        # NOTE(review): presumably maps a raw confidence to a calibrated one; confirm against its consumer.
        self.confidence_calibration = {
            0.1: 0.03,
            0.2: 0.08,
            0.3: 0.15,
            0.4: 0.25,
            0.5: 0.40,
            0.6: 0.58,
            0.7: 0.72,
            0.8: 0.88,
            0.9: 0.95,
            1.0: 0.99
        }

        # Load MIMIC-CXR metadata - Use text_data.csv directly instead of non-existent mimic_meta.csv
        self.mimic_meta_path = os.path.join(knowledge_base_dir, "text_data.csv")
        self.mimic_index = build_mimic_cxr_index(self.mimic_meta_path)

        # Track if we have valid MIMIC-CXR data. The index builder signals
        # failure by returning empty mappings rather than an "error" key, so
        # an index without any entries must count as unavailable too.
        self.has_mimic_data = (
            "error" not in self.mimic_index
            and bool(self.mimic_index.get("by_image_id"))
        )

        if self.has_mimic_data:
            print("Successfully integrated MIMIC-CXR data")
        else:
            print("Warning: MIMIC-CXR data unavailable, using synthetic ground truth")

        # Load knowledge base and models
        self.load_knowledge_base()
        self.load_models()

    @performance_monitor
    def load_knowledge_base(self):
        """Load embeddings, report texts and FAISS indexes from ``knowledge_base_dir``.

        Populates ``image_embs``, ``text_embs``, ``text_data``, ``image_index``,
        ``text_index`` and ``image_distances`` (each image embedding's distance
        from the mean embedding, computed before any normalisation).
        Embeddings are L2-normalised when ``Config.EMBEDDING_NORMALIZE`` is
        set, and ``anatomical_regions`` is loaded when
        ``Config.ANATOMY_MAPPING_ENABLED`` is set.

        Raises:
            RuntimeError: If any file is missing or cannot be read.
        """
        debug_print("Loading knowledge base...")
        try:
            base = self.knowledge_base_dir
            self.image_embs = np.load(f"{base}/image_embeddings.npy")
            self.text_embs = np.load(f"{base}/text_embeddings.npy")
            self.text_data = pd.read_csv(f"{base}/text_data.csv")
            self.image_index = faiss.read_index(f"{base}/image_index.faiss")
            self.text_index = faiss.read_index(f"{base}/text_index.faiss")

            centroid = self.image_embs.mean(axis=0)
            self.image_distances = np.linalg.norm(self.image_embs - centroid, axis=1)
            debug_print(f"Knowledge base loaded. Image embeddings shape: {self.image_embs.shape}, Text embeddings shape: {self.text_embs.shape}")

            # Optionally scale every embedding row to unit length.
            if Config.EMBEDDING_NORMALIZE:
                debug_print("Normalizing embeddings for better retrieval")
                image_norms = np.linalg.norm(self.image_embs, axis=1, keepdims=True)
                self.image_embs = self.image_embs / (image_norms + 1e-8)
                text_norms = np.linalg.norm(self.text_embs, axis=1, keepdims=True)
                self.text_embs = self.text_embs / (text_norms + 1e-8)

            # Optionally load the anatomical region table.
            if Config.ANATOMY_MAPPING_ENABLED:
                self.anatomical_regions = get_anatomical_regions()
                debug_print(f"Loaded {len(self.anatomical_regions)} anatomical regions for mapping")

        except Exception as e:
            error_msg = f"Error loading knowledge base: {str(e)}"
            debug_print(error_msg)
            raise RuntimeError(error_msg)

    @performance_monitor
    def load_models(self):
        """Load the image, text and generator models plus their tokenizers.

        The image feature extractor is always built immediately. With
        ``Config.LAZY_LOADING`` the text and generator models are wrapped in
        ``LazyModel`` and loaded on first use. Otherwise they are loaded now,
        moved to ``self.device``, put in eval mode and, when
        ``Config.USE_HALF_PRECISION`` is set on CUDA, converted to half
        precision. Also sets ``self.transform``.

        Raises:
            RuntimeError: If any model or tokenizer fails to load; the
                original error is chained as the cause.
        """
        debug_print("Loading models...")
        try:
            lazy = Config.LAZY_LOADING
            if lazy:
                debug_print("Using lazy loading for models")

            # Image model (CheXNet or DenseNet) with the classifier removed
            self.image_model = self._build_image_feature_extractor()

            # Text model (BioBERT or ClinicalBERT)
            text_model_name = "dmis-lab/biobert-v1.1" if Config.TEXT_MODEL == "biobert" else "emilyalsentzer/Bio_ClinicalBERT"
            self.text_tokenizer = AutoTokenizer.from_pretrained(text_model_name)
            if lazy:
                self.text_model = LazyModel(text_model_name, AutoModel, self.device)
            else:
                self.text_model = AutoModel.from_pretrained(text_model_name).to(self.device).eval()

            # Finetuned generator model
            self.gen_tokenizer = T5Tokenizer.from_pretrained(self.finetuned_model_path)
            if lazy:
                self.gen_model = LazyModel(self.finetuned_model_path, T5ForConditionalGeneration, self.device)
            else:
                self.gen_model = T5ForConditionalGeneration.from_pretrained(self.finetuned_model_path).to(self.device).eval()

                # Apply half precision if enabled (LazyModel does this itself when it loads)
                if Config.USE_HALF_PRECISION and self.device.type == 'cuda':
                    debug_print("Using half precision for models")
                    self.text_model = self.text_model.half()
                    self.gen_model = self.gen_model.half()

            # Initialize transforms for image preprocessing
            self.transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])

            debug_print("Models loaded successfully")
        except Exception as e:
            error_msg = f"Error loading models: {str(e)}"
            debug_print(error_msg)
            debug_print(traceback.format_exc())
            raise RuntimeError(error_msg) from e

    def _build_image_feature_extractor(self):
        """Build the DenseNet-121 feature extractor (optionally with CheXNet weights) on ``self.device``."""
        image_model = models.densenet121(weights=None)
        if Config.IMAGE_MODEL == "chexnet":
            chexnet_url = "https://github.com/arnoweng/CheXNet/raw/master/model.pth.tar"
            checkpoint = torch.hub.load_state_dict_from_url(chexnet_url, progress=True, map_location=self.device)
            # Rewrite checkpoint key names before loading. Loading is
            # non-strict, so keys that still do not match are skipped silently.
            # NOTE(review): confirm the remapped keys actually match this model.
            state_dict = {k.replace('densenet121.', '').replace('.norm.', '.norm').replace('.conv.', '.conv')
                         .replace('.1', '1').replace('.2', '2').replace('classifier.0', 'classifier'): v
                         for k, v in checkpoint['state_dict'].items()}
            image_model.load_state_dict(state_dict, strict=False)

        # Remove classifier to get features
        return nn.Sequential(*list(image_model.children())[:-1]).to(self.device).eval()

    @performance_monitor
    def generate_text_embedding(self, text):
        """Embed ``text`` with the text model, consulting the cache first.

        The text is tokenized (truncated to 512 tokens), run through the text
        model without gradients, and reduced to one vector according to
        ``Config.EMBEDDING_AGGREGATION``: ``"cls"``, ``"pooled"``,
        ``"weighted_avg"``, or mean pooling for anything else. The vector is
        L2-normalised when ``Config.EMBEDDING_NORMALIZE`` is set, stored in
        the cache and returned as a NumPy array.

        Raises:
            Exception: Any tokenization or model error is logged and re-raised.
        """
        debug_print(f"Generating text embedding for: {text[:50]}...")

        # Serve repeated queries straight from the cache.
        cache_key = f"text_emb_{hash(text)}"
        hit = self.cache.get(cache_key)
        if hit is not None:
            debug_print("Using cached text embedding")
            return hit

        try:
            # Tokenize, retrying up to Config.MAX_RETRIES times.
            final_attempt = Config.MAX_RETRIES - 1
            for attempt in range(Config.MAX_RETRIES):
                try:
                    inputs = self.text_tokenizer(text, padding=True, truncation=True,
                                               return_tensors="pt", max_length=512).to(self.device)
                except Exception as e:
                    if attempt == final_attempt:
                        debug_print(f"Failed to tokenize text after {Config.MAX_RETRIES} attempts: {str(e)}")
                        raise
                    debug_print(f"Tokenization error (attempt {attempt+1}): {str(e)}")
                    time.sleep(Config.RECOVERY_WAIT_TIME)
                else:
                    break

            with torch.no_grad():
                # A LazyModel loads itself on this call; a plain model just runs.
                outputs = self.text_model(**inputs)

                mode = Config.EMBEDDING_AGGREGATION
                if mode == "cls":
                    # First token of the last hidden state
                    pooled = outputs.last_hidden_state[:, 0]
                elif mode == "pooled" and hasattr(outputs, 'pooler_output'):
                    pooled = outputs.pooler_output
                elif mode == "weighted_avg":
                    # Softmax weights derived from hidden states scaled by token ids
                    last_hidden = outputs.last_hidden_state
                    scores = torch.sum(last_hidden * inputs['input_ids'].unsqueeze(-1), dim=-1)
                    weights = torch.softmax(scores, dim=1)
                    pooled = (last_hidden * weights.unsqueeze(-1)).sum(dim=1)
                else:
                    # Mean over tokens; also used for "pooled" without a pooler output
                    pooled = outputs.last_hidden_state.mean(dim=1)
                embedding = pooled.cpu().numpy()

            debug_print(f"Text embedding shape: {embedding.shape}")

            if Config.EMBEDDING_NORMALIZE:
                row_norms = np.linalg.norm(embedding, axis=1, keepdims=True)
                embedding = embedding / (row_norms + 1e-8)

            self.cache.put(cache_key, embedding)
            return embedding
        except Exception as e:
            debug_print(f"Error in generate_text_embedding: {str(e)}")
            debug_print(traceback.format_exc())
            raise

    @performance_monitor
    def generate_image_embedding(self, image_tensor):
        """Generate embedding for image with enhanced attention weight extraction.

        Accepts a PIL image, a torch tensor or a NumPy array. The input is
        resized to 224x224, converted to a single grayscale channel, repeated
        to three channels and run through ``self.image_model``. The output is
        average-pooled into a ``(1, D)`` NumPy embedding, L2-normalised when
        ``Config.EMBEDDING_NORMALIZE`` is set.

        Side effects: updates ``raw_activations``, ``last_attention_weights``
        and ``region_attention`` from the last convolution's output, then
        calls ``detect_anatomical_regions``.

        Inputs carrying a non-empty ``filename`` attribute are cached under
        that name.
        NOTE(review): a cache hit returns early, so the attention-derived
        attributes above are not refreshed for that call.

        Raises:
            ValueError: For unsupported input types.
            Exception: Any other processing error is logged and re-raised.
        """
        debug_print(f"Generating image embedding. Input type: {type(image_tensor)}")

        # Work out the cache key now, while ``image_tensor`` is still the
        # caller's object; it is rebound to a plain tensor below, which has no
        # filename. An empty filename is skipped so that unnamed images can
        # never share one cache slot.
        cache_key = None
        source_name = getattr(image_tensor, 'filename', None)
        if source_name:
            cache_key = f"img_emb_{hash(source_name)}"
            cached_emb = self.cache.get(cache_key)
            if cached_emb is not None:
                debug_print("Using cached image embedding")
                return cached_emb

        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor()
        ])

        try:
            # Handle different input types
            if isinstance(image_tensor, Image.Image):
                debug_print(f"Input is PIL Image with mode: {image_tensor.mode}")
                if image_tensor.mode == 'RGB':
                    image_tensor = transform(image_tensor.convert('L'))
                else:
                    image_tensor = transform(image_tensor)
            elif isinstance(image_tensor, torch.Tensor):
                debug_print(f"Input is torch.Tensor with shape: {image_tensor.shape}")
                if image_tensor.dim() == 3:
                    if image_tensor.shape[0] == 3:  # RGB tensor in CHW format
                        image_tensor = transform(Image.fromarray((image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)))
                    elif image_tensor.shape[-1] == 3:  # RGB tensor in HWC format
                        image_tensor = transform(Image.fromarray((image_tensor.cpu().numpy() * 255).astype(np.uint8)))
                    else:  # Single-channel
                        image_tensor = transform(Image.fromarray((image_tensor.squeeze().cpu().numpy() * 255).astype(np.uint8)))
                elif image_tensor.dim() == 2:  # Grayscale
                    image_tensor = transform(Image.fromarray((image_tensor.cpu().numpy() * 255).astype(np.uint8)))
            elif isinstance(image_tensor, np.ndarray):
                debug_print(f"Input is numpy.ndarray with shape: {image_tensor.shape}")
                # RGB and grayscale arrays take the same path; the transform converts to one channel.
                image_tensor = transform(Image.fromarray(image_tensor))
            else:
                raise ValueError(f"Unexpected image tensor type: {type(image_tensor)}")

            # Ensure 3 channels for model
            if image_tensor.shape[0] == 1:
                image_tensor = image_tensor.repeat(3, 1, 1)

            image_tensor = image_tensor.unsqueeze(0).to(self.device)
            debug_print(f"Processed image tensor shape: {image_tensor.shape}")

            with torch.no_grad():
                # Capture the output of the last convolution through a forward hook
                activation = {}

                def capture_features(_module, _inputs, output):
                    activation['features'] = output.detach()

                hook_handle = None
                if not isinstance(self.image_model, LazyModel):
                    last_conv = None
                    for layer in self.image_model.modules():
                        if isinstance(layer, nn.Conv2d):
                            last_conv = layer
                    if last_conv is not None:
                        hook_handle = last_conv.register_forward_hook(capture_features)

                try:
                    # Get features
                    features = self.image_model(image_tensor)
                finally:
                    # Always detach the hook. Otherwise each call would stack
                    # another hook on the layer and keep its captured tensor alive.
                    if hook_handle is not None:
                        hook_handle.remove()

                # Enhanced attention weight extraction
                if 'features' in activation:
                    feature_activations = activation['features'].squeeze()
                    # Save raw activations for detailed analysis
                    self.raw_activations = feature_activations.cpu().numpy()

                    # Calculate attention weights with softmax to enhance key areas
                    if feature_activations.dim() > 1:
                        attention_weights = nn.functional.softmax(feature_activations.flatten(), dim=0).reshape(feature_activations.shape)
                        self.last_attention_weights = attention_weights.cpu().numpy()

                        # Spatial axes are the last two; "..." covers any leading channel axis.
                        h, w = attention_weights.shape[-2:]
                        # Calculate attention by regions
                        self.region_attention = {
                            'upper_left': attention_weights[..., :h//2, :w//2].mean().item(),
                            'upper_right': attention_weights[..., :h//2, w//2:].mean().item(),
                            'lower_left': attention_weights[..., h//2:, :w//2].mean().item(),
                            'lower_right': attention_weights[..., h//2:, w//2:].mean().item(),
                            'central': attention_weights[..., h//4:3*h//4, w//4:3*w//4].mean().item()
                        }
                    else:
                        self.last_attention_weights = np.ones((7, 7))
                        self.region_attention = {
                            'upper_left': 0.2, 'upper_right': 0.2,
                            'lower_left': 0.2, 'lower_right': 0.2, 'central': 0.2
                        }
                else:
                    # Fallback for LazyModel or if hook failed
                    feature_map_size = int(np.sqrt(features.shape[1]))
                    self.last_attention_weights = np.ones((feature_map_size, feature_map_size))
                    self.region_attention = {
                        'upper_left': 0.2, 'upper_right': 0.2,
                        'lower_left': 0.2, 'lower_right': 0.2, 'central': 0.2
                    }

                # Global pooling for embedding
                pooled = nn.functional.adaptive_avg_pool2d(features, (1, 1)).flatten()

            # Extract embedded anatomical regions
            self.detect_anatomical_regions()

            embedding = pooled.cpu().numpy().reshape(1, -1)
            debug_print(f"Image embedding shape: {embedding.shape}")

            # Normalize embedding if configured
            if Config.EMBEDDING_NORMALIZE:
                embedding = embedding / (np.linalg.norm(embedding) + 1e-8)

            # Cache the embedding if the input carried a filename
            if cache_key is not None:
                self.cache.put(cache_key, embedding)

            return embedding
        except Exception as e:
            debug_print(f"Error in generate_image_embedding: {str(e)}")
            debug_print(traceback.format_exc())
            raise

    def detect_anatomical_regions(self):
        """Score fixed anatomical regions against the last attention map.

        Does nothing unless ``last_attention_weights`` exists and
        ``Config.ANATOMY_MAPPING_ENABLED`` is set. Otherwise fills
        ``self.detected_regions`` with, per region, the mean attention inside
        a fixed rectangle plus the description and conditions from
        ``self.anatomical_regions`` (when that table lists the region), and
        sets ``self.primary_regions`` to the three highest-scoring names.

        The map's last two axes are treated as (height, width). Leading axes,
        such as the channel axis of a ``(C, H, W)`` map, are averaged over.
        Errors are logged and swallowed.
        """
        if not hasattr(self, 'last_attention_weights') or not Config.ANATOMY_MAPPING_ENABLED:
            return

        try:
            debug_print("Detecting anatomical regions from attention weights")

            # Get attention map
            attention_map = self.last_attention_weights

            # Initialize region detection results
            self.detected_regions = {}

            # The spatial axes are the LAST two. Reading shape[:2] on a
            # channel-first (C, H, W) map would mistake (C, H) for (h, w).
            has_spatial_axes = isinstance(attention_map, np.ndarray) and attention_map.ndim > 1
            h, w = attention_map.shape[-2:] if has_spatial_axes else (7, 7)

            # Define region coordinates (normalized to attention map size)
            regions = {
                "upper_right_lung": (slice(0, h//3), slice(2*w//3, w)),
                "upper_left_lung": (slice(0, h//3), slice(0, w//3)),
                "middle_right_lung": (slice(h//3, 2*h//3), slice(2*w//3, w)),
                "lower_right_lung": (slice(2*h//3, h), slice(2*w//3, w)),
                "lower_left_lung": (slice(2*h//3, h), slice(0, w//3)),
                "heart": (slice(h//3, 2*h//3), slice(w//3, 2*w//3)),
                "hilar": (slice(h//4, h//2), slice(w//3, 2*w//3)),
                "costophrenic_angles": (slice(2*h//3, h), slice(w//4, 3*w//4)),
                "spine": (slice(h//3, 2*h//3), slice(w//2-1, w//2+2)),
                "diaphragm": (slice(2*h//3, h), slice(w//4, 3*w//4))
            }

            # The region table is only loaded when anatomy mapping was enabled at load time.
            known_regions = getattr(self, 'anatomical_regions', {})

            # Calculate attention score for each region
            for region_name, (y_slice, x_slice) in regions.items():
                if has_spatial_axes:
                    # "..." keeps any leading axes, so 2-D and 3-D maps are handled alike.
                    attention_score = np.mean(attention_map[..., y_slice, x_slice])
                else:
                    # Fallback if attention map is incorrect shape
                    attention_score = 0.1

                region_info = known_regions.get(region_name)

                # Map score to anatomical region
                self.detected_regions[region_name] = {
                    "attention_score": float(attention_score),
                    "description": region_info["description"] if region_info is not None else "",
                    "possible_conditions": region_info["conditions"] if region_info is not None else []
                }

            # Find regions with highest attention
            sorted_regions = sorted(self.detected_regions.items(), key=lambda x: x[1]["attention_score"], reverse=True)
            self.primary_regions = [region[0] for region in sorted_regions[:3]]

            debug_print(f"Primary detected regions: {self.primary_regions}")
        except Exception as e:
            debug_print(f"Error detecting anatomical regions: {str(e)}")
            debug_print(traceback.format_exc())

    @performance_monitor
    def hybrid_retrieve(self, text_query, image_embedding_tuple=None, k=Config.TOP_K_RETRIEVAL):
        """Retrieve relevant documents using text and/or image embeddings with enhanced algorithms.

        Args:
            text_query: Query text; always embedded and searched in the text index.
            image_embedding_tuple: Optional image embedding (any sequence or
                array of numbers). When given, the image index is searched as
                well, and the union of hits is passed through
                ``rerank_results`` and ``process_ranked_results``.
            k: Number of neighbours requested from each index.

        Returns:
            With an image: whatever ``process_ranked_results`` returns.
            Text only: a ``(texts, indices, distances)`` tuple.
            On any error: a one-element fallback
            ``(["Error retrieving medical context."], array([0]), array([1.0]))``.
            This method never raises.
        """
        # Use one presence test everywhere. Truth-testing the argument raises
        # for multi-element arrays and treats an empty tuple as absent.
        has_image = image_embedding_tuple is not None
        debug_print(f"hybrid_retrieve called with query: {text_query[:50]}... and image_embedding_tuple: {'provided' if has_image else 'None'}")

        try:
            # Check cache first if caching is enabled
            cache_key = None
            if Config.USE_CACHING:
                if has_image:
                    # Hash the raw values. str() of a large array is elided
                    # with "...", which lets different embeddings collide.
                    image_part = hash(np.asarray(image_embedding_tuple).tobytes())
                else:
                    image_part = 'noimg'
                # k is part of the key so a different k never reuses a stale result.
                cache_key = f"hybrid_{hash(text_query)}_{image_part}_k{k}"
                cached_result = self.cache.get(cache_key)
                if cached_result is not None:
                    debug_print("Using cached hybrid retrieval results")
                    return cached_result

            text_embedding = self.generate_text_embedding(text_query)

            if has_image:
                debug_print(f"Processing with image embedding, type: {type(image_embedding_tuple)}")
                # Convert tuple to array
                image_embedding = np.array(image_embedding_tuple).reshape(1, -1)
                debug_print(f"Converted image embedding shape: {image_embedding.shape}")

                # Search using text embedding
                text_distances, text_indices = self.text_index.search(text_embedding.astype('float32'), k)
                debug_print(f"Text search results: {len(text_indices[0])} indices")

                # Search using image embedding
                image_distances, image_indices = self.image_index.search(image_embedding.astype('float32'), k)
                debug_print(f"Image search results: {len(image_indices[0])} indices")

                # Combine results with improved weighting
                combined_indices = np.unique(np.concatenate([text_indices[0], image_indices[0]]))
                debug_print(f"Combined unique indices: {len(combined_indices)}")

                # Enhanced reranking with query-adaptive weighting
                reranked_results = self.rerank_results(combined_indices, text_embedding, image_embedding)
                debug_print(f"Reranked results: {len(reranked_results)} items")

                # Process results
                result = self.process_ranked_results(reranked_results)
            else:
                debug_print("Processing text-only query")

                # Text-only search in the text index
                distances, indices = self.text_index.search(text_embedding.astype('float32'), k)

                # Add diversity to results if enabled
                if Config.DYNAMIC_RERANKING:
                    indices, distances = self.diversify_results(indices[0], distances[0], k)
                    indices = np.array([indices])
                    distances = np.array([distances])

                retrieved_texts = [self.text_data['combined_text'].iloc[idx] for idx in indices[0]]
                debug_print(f"Retrieved {len(retrieved_texts)} texts for query")
                result = (retrieved_texts, indices[0], distances[0])

            # Cache the result if enabled
            if cache_key is not None:
                self.cache.put(cache_key, result)

            debug_print(f"hybrid_retrieve returning tuple with {len(result)} items - {type(result)}")
            return result
        except Exception as e:
            debug_print(f"Error in hybrid_retrieve: {str(e)}")
            debug_print(traceback.format_exc())
            # Return fallback values
            empty_texts = ["Error retrieving medical context."]
            empty_indices = np.array([0])
            empty_distances = np.array([1.0])
            return empty_texts, empty_indices, empty_distances

    def diversify_results(self, indices, distances, k):
        """Re-order search hits with maximal marginal relevance (MMR).

        The first hit is always kept as the anchor. Each following pick is the
        remaining candidate that maximises
        ``lambda * (1 - distance) - (1 - lambda) * max_similarity``, where
        ``max_similarity`` is the highest cosine similarity between that
        candidate and any document selected so far.

        Args:
            indices: 1-D sequence of document positions into ``self.text_embs``.
            distances: 1-D sequence of retrieval distances aligned with
                ``indices``.
            k: Maximum number of results to return.

        Returns:
            ``(indices, distances)``. The inputs are returned untouched when
            there are too few hits to diversify or when an error occurs;
            otherwise two NumPy arrays in MMR order.
        """
        debug_print("Applying diversity optimization to search results")

        try:
            # If we have fewer than threshold results, just return them
            if len(indices) <= min(5, k // 2):
                return indices, distances

            lambda_param = 0.7  # Balance between relevance and diversity

            # Use the first item as anchor
            selected_indices = [indices[0]]
            selected_distances = [distances[0]]
            candidate_indices = list(indices[1:])
            candidate_distances = list(distances[1:])

            if candidate_indices and len(selected_indices) < k:
                # Flatten every embedding (as rerank_results does) so that
                # stored vectors of shape (d,) and (1, d) are both accepted.
                candidate_embs = np.array(
                    [np.ravel(self.text_embs[idx]) for idx in candidate_indices]
                )
                candidate_norms = np.linalg.norm(candidate_embs, axis=1)
                relevance = 1 - np.asarray(candidate_distances)

                def similarity_to(doc_idx):
                    """Cosine similarity of every candidate to one document."""
                    doc_emb = np.ravel(self.text_embs[doc_idx])
                    return (candidate_embs @ doc_emb) / (
                        candidate_norms * np.linalg.norm(doc_emb) + 1e-8
                    )

                # Running maximum of the similarity to the selected set. Only
                # the newest pick can raise it, so each round costs a single
                # matrix-vector product instead of a pass over all selections.
                max_similarity = similarity_to(selected_indices[0])
                available = np.ones(len(candidate_indices), dtype=bool)

                while len(selected_indices) < k and available.any():
                    mmr_scores = (
                        lambda_param * relevance
                        - (1 - lambda_param) * max_similarity
                    )
                    # Restrict argmax to unpicked candidates; ties resolve to
                    # the earliest candidate in the original ranking.
                    open_positions = np.flatnonzero(available)
                    best = int(open_positions[np.argmax(mmr_scores[open_positions])])

                    selected_indices.append(candidate_indices[best])
                    selected_distances.append(candidate_distances[best])
                    available[best] = False

                    if available.any() and len(selected_indices) < k:
                        max_similarity = np.maximum(
                            max_similarity, similarity_to(candidate_indices[best])
                        )

            return np.array(selected_indices), np.array(selected_distances)
        except Exception as e:
            debug_print(f"Error in diversify_results: {str(e)}")
            # Fall back to original indices
            return indices, distances

    @performance_monitor
    def rerank_results(self, indices, text_embedding, image_embedding):
        """Score candidates by a weighted mix of text and image cosine similarity.

        The text weight grows with the estimated query specificity (capped at
        0.8); the image weight is the remainder. Text and image similarities
        are computed separately because the two embedding spaces differ.

        Args:
            indices: Iterable of candidate positions into ``self.text_embs`` and
                ``self.image_embs``.
            text_embedding: Query text embedding (any shape; it is flattened).
            image_embedding: Query image embedding (any shape; it is flattened).

        Returns:
            List of ``(idx, combined_score, text_sim, image_sim)`` tuples sorted
            by ``combined_score`` descending. Candidates whose scoring fails
            get zero scores. If the whole procedure fails, the first five
            indices are returned with zero scores.
        """
        debug_print(f"Reranking {len(indices)} results")
        try:
            # Check if query appears to be looking for specific conditions
            # For medically specific queries, increase text weight
            query_specificity = self.estimate_query_specificity(text_embedding)

            # Adjust weights based on query specificity
            text_weight = min(0.8, 0.5 + query_specificity * 0.3)
            image_weight = 1.0 - text_weight

            debug_print(f"Reranking with text_weight={text_weight:.2f}, image_weight={image_weight:.2f}")

            # The query vectors and their norms do not change per candidate,
            # so compute them once instead of inside the loop.
            text_query = np.ravel(text_embedding)
            image_query = np.ravel(image_embedding)
            text_query_norm = np.linalg.norm(text_query)
            image_query_norm = np.linalg.norm(image_query)

            reranked = []
            for idx in indices:
                if idx < 0:
                    # FAISS pads result lists with -1 when it finds fewer than
                    # k neighbours; a negative index would silently address
                    # the last stored embedding, so skip it.
                    debug_print(f"Skipping placeholder index {idx}")
                    continue
                try:
                    # Compute similarities separately (no cross-comparisons)
                    text_doc = np.ravel(self.text_embs[idx])
                    text_sim = np.dot(text_query, text_doc) / (
                        text_query_norm * np.linalg.norm(text_doc) + 1e-8)

                    image_doc = np.ravel(self.image_embs[idx])
                    image_sim = np.dot(image_query, image_doc) / (
                        image_query_norm * np.linalg.norm(image_doc) + 1e-8)

                    # Compute combined score with dynamic weighting
                    combined_score = text_weight * text_sim + image_weight * image_sim

                    # Store for reranking
                    reranked.append((idx, combined_score, text_sim, image_sim))
                except Exception as item_e:
                    debug_print(f"Error processing item {idx}: {str(item_e)}")
                    # Use a fallback score
                    reranked.append((idx, 0.0, 0.0, 0.0))

            # Sort by combined score
            sorted_results = sorted(reranked, key=lambda x: x[1], reverse=True)
            debug_print(f"Reranking complete, top score: {sorted_results[0][1] if sorted_results else 'N/A'}")
            return sorted_results
        except Exception as e:
            debug_print(f"Error in rerank_results: {str(e)}")
            debug_print(traceback.format_exc())
            return [(idx, 0.0, 0.0, 0.0) for idx in indices[:5]]  # Return some indices with zero scores

    def estimate_query_specificity(self, text_embedding):
        """Estimate how specific a query is from the spread of its embedding.

        The standard deviation of the embedding values is squashed through a
        logistic curve centred at 0.05 (slope 10), so a larger spread maps to
        a value closer to 1.

        Args:
            text_embedding: Array-like query embedding.

        Returns:
            Specificity in (0, 1), or 0.5 (balanced weighting) when the
            embedding is empty, non-finite, or cannot be processed.
        """
        try:
            embedding = np.asarray(text_embedding)
            if embedding.size == 0:
                # Nothing to measure; fall back explicitly rather than relying
                # on a reduction raising for empty input.
                debug_print("Error estimating query specificity: empty embedding")
                return 0.5

            std_activation = np.std(embedding)
            if not np.isfinite(std_activation):
                # NaN/inf values would otherwise propagate into the weights.
                debug_print("Error estimating query specificity: non-finite embedding values")
                return 0.5

            # Higher std suggests more specific semantic content
            # Improved formula with logistic function for smoother scaling
            specificity = 1.0 / (1.0 + np.exp(-10 * (std_activation - 0.05)))
            debug_print(f"Query specificity estimate: {specificity:.2f}, std: {std_activation:.4f}")
            return specificity
        except Exception as e:
            debug_print(f"Error estimating query specificity: {str(e)}")
            return 0.5  # Default to balanced weight

    @performance_monitor
    def process_ranked_results(self, reranked_results):
        """Turn reranked tuples into ``(texts, indices, distances)``.

        Each input item is ``(idx, combined_score, text_sim, image_sim)``.
        Distances are ``1 - combined_score``. The per-item text and image
        similarities are stored on ``self.text_similarities`` and
        ``self.image_similarities``. Texts are looked up in
        ``self.text_data['combined_text']``, optionally anonymised, and
        de-duplicated when more than five were retrieved.

        Returns:
            ``(retrieved_texts, np.ndarray indices, np.ndarray distances)``, or
            a single-item error triple if anything fails.
        """
        debug_print(f"Processing {len(reranked_results)} ranked results")
        try:
            def column(position):
                # Pull one field out of every reranked tuple.
                return [row[position] for row in reranked_results]

            indices = column(0)
            # Convert similarity to distance
            distances = [1 - score for score in column(1)]

            # Keep the per-modality similarities for later confidence analysis.
            # NOTE(review): these lists are stored before de-duplication, so
            # they can be longer than the returned results - confirm consumers
            # expect that.
            self.text_similarities = column(2)
            self.image_similarities = column(3)

            # Look up the report text for every ranked document
            retrieved_texts = [self.text_data['combined_text'].iloc[doc_pos] for doc_pos in indices]

            # Apply PHI detection and anonymization
            if Config.PHI_DETECTION_ENABLED:
                retrieved_texts = [anonymize_text(report) for report in retrieved_texts]

            # Drop near-duplicate reports once there are more than five
            if len(retrieved_texts) > 5:
                count_before = len(retrieved_texts)
                retrieved_texts, indices, distances = self.filter_duplicative_content(
                    retrieved_texts, indices, distances
                )
                debug_print(f"Filtered {count_before - len(retrieved_texts)} duplicative documents")

            debug_print(f"Processed ranked results: {len(retrieved_texts)} texts, indices shape: {len(indices)}, distances shape: {len(distances)}")
            return retrieved_texts, np.array(indices), np.array(distances)
        except Exception as e:
            debug_print(f"Error in process_ranked_results: {str(e)}")
            debug_print(traceback.format_exc())
            return ["Error processing results."], np.array([0]), np.array([1.0])

    def filter_duplicative_content(self, texts, indices, distances, similarity_threshold=0.8):
        """Drop documents that are near-duplicates of an earlier, kept document.

        Similarity is the Jaccard index of the lower-cased, whitespace-split
        word sets. The first document is always kept, and scanning stops once
        ``Config.MAX_CONTEXT_DOCS`` documents have been kept.

        Args:
            texts: Sequence of document strings, best match first.
            indices: Sequence aligned with ``texts``.
            distances: Sequence aligned with ``texts``.
            similarity_threshold: Jaccard value above which a document counts
                as a duplicate.

        Returns:
            ``(texts, indices, distances)`` restricted to the kept documents.
            The inputs are returned unchanged when there is at most one text or
            when an error occurs.
        """
        if len(texts) <= 1:
            return texts, indices, distances

        try:
            # Positions of documents to retain; always include the top document
            keep_indices = [0]
            # Word set of every kept document, tokenised exactly once instead
            # of being rebuilt for each pairwise comparison.
            kept_word_sets = [set(texts[0].lower().split())]

            for i in range(1, len(texts)):
                candidate_words = set(texts[i].lower().split())

                # Empty word sets are never compared, so such documents are
                # never flagged as duplicates.
                is_duplicate = False
                if candidate_words:
                    for kept_words in kept_word_sets:
                        if not kept_words:
                            continue
                        intersection = len(kept_words & candidate_words)
                        union = len(kept_words) + len(candidate_words) - intersection
                        if intersection / union > similarity_threshold:
                            is_duplicate = True
                            break

                # If not a duplicate, keep it
                if not is_duplicate:
                    keep_indices.append(i)
                    kept_word_sets.append(candidate_words)

                # Don't process too many documents for efficiency
                if len(keep_indices) >= Config.MAX_CONTEXT_DOCS:
                    break

            # Create filtered lists
            filtered_texts = [texts[i] for i in keep_indices]
            filtered_indices = [indices[i] for i in keep_indices]
            filtered_distances = [distances[i] for i in keep_indices]

            return filtered_texts, filtered_indices, filtered_distances

        except Exception as e:
            debug_print(f"Error in filter_duplicative_content: {str(e)}")
            return texts, indices, distances

    def calibrate_confidence(self, raw_confidence):
        """Map a raw confidence onto the calibrated scale.

        ``self.confidence_calibration`` is read as a mapping of raw value to
        calibrated value. The result is a piecewise-linear interpolation
        between the two surrounding calibration points; values outside the
        covered range are clamped to the nearest end point.

        Args:
            raw_confidence: Uncalibrated confidence score.

        Returns:
            The calibrated confidence as a float, or ``raw_confidence``
            unchanged when the table is empty or calibration fails.
        """
        try:
            calibration = self.confidence_calibration
            if not calibration:
                debug_print("Error in confidence calibration: calibration table is empty")
                return raw_confidence

            # Sort the points so interpolation does not depend on the order in
            # which the mapping happened to be populated.
            points = sorted(calibration.items())
            raw_points = [raw for raw, _ in points]
            calibrated_points = [value for _, value in points]

            # Interpolate across the whole interval between two neighbouring
            # points. Interpolating only above the nearest key would make the
            # mapping jump at interval midpoints.
            calibrated = float(np.interp(raw_confidence, raw_points, calibrated_points))

            debug_print(f"Calibrated confidence: raw={raw_confidence:.2f}, calibrated={calibrated:.2f}")
            return calibrated
        except Exception as e:
            debug_print(f"Error in confidence calibration: {str(e)}")
            return raw_confidence  # Return raw value if calibration fails

    def explain_confidence(self, distances):
        """Build a human-readable confidence statement from retrieval distances.

        The smallest of the first five distances is bucketed into a raw
        confidence, which is passed through ``calibrate_confidence`` and then
        mapped to a level name, an explanation, and an optional caution note
        when it falls below ``Config.CONFIDENCE_THRESHOLD``.

        Args:
            distances: NumPy array of retrieval distances.

        Returns:
            The formatted explanation, or a fixed "unavailable" message when
            ``distances`` is not a non-empty NumPy array or an error occurs.
        """
        debug_print(f"Explaining confidence for distances: {type(distances)}, shape: {getattr(distances, 'shape', 'N/A')}")
        try:
            if not (isinstance(distances, np.ndarray) and distances.size):
                return "Confidence score unavailable due to no retrieved results."

            # Print raw distance values for debugging
            debug_print(f"Distance stats - min: {np.min(distances):.4f}, max: {np.max(distances):.4f}, mean: {np.mean(distances):.4f}")

            # Best (smallest) distance among the top 5 results
            min_dist = np.min(distances[:5])

            # Buckets are tuned for the observed distance range (71-76).
            # NOTE(review): process_ranked_results emits `1 - combined_score`
            # distances, which would presumably always land in the top bucket
            # here - confirm these thresholds suit both retrieval paths.
            distance_buckets = (
                (76, 0.1),  # Very low confidence
                (74, 0.3),  # Low confidence
                (72, 0.5),  # Moderate confidence
                (70, 0.7),  # Good confidence
            )
            confidence = 0.9  # High confidence unless a bucket below matches
            for lower_bound, bucket_confidence in distance_buckets:
                if min_dist > lower_bound:
                    confidence = bucket_confidence
                    break

            # Apply calibration
            confidence = self.calibrate_confidence(confidence)

            # Level name and explanation for the calibrated confidence
            confidence_tiers = (
                (0.9, "Very High", "The system found very closely matching reference cases in the medical knowledge base."),
                (0.75, "High", "The system found strong matches in the medical knowledge base."),
                (0.6, "Moderate", "The system found reasonable matches, but with some differences from reference cases."),
                (0.4, "Fair", "The retrieved references only partially match this case."),
            )
            confidence_level = "Low"
            explanation = "The system found limited matches in the knowledge base for this case."
            for floor, tier_name, tier_text in confidence_tiers:
                if confidence >= floor:
                    confidence_level = tier_name
                    explanation = tier_text
                    break

            # Flag if confidence is below threshold
            caution_note = (
                "Caution: Response confidence is below the reliability threshold. Please verify with other sources."
                if confidence < Config.CONFIDENCE_THRESHOLD
                else ""
            )

            # Format percentages
            conf_pct = f"{confidence * 100:.0f}%"

            full_explanation = f"Confidence Level: {confidence_level} ({conf_pct})\nExplanation: {explanation}\n{caution_note}"
            debug_print(f"Generated confidence explanation: {confidence_level} ({conf_pct})")

            return full_explanation

        except Exception as e:
            debug_print(f"Error generating confidence explanation: {str(e)}")
            return "Confidence assessment unavailable."

    @performance_monitor
    def handle_differentiation_query(self, query, retrieved_texts):
        """Build a side-by-side comparison of the first two conditions named in a query.

        Conditions come from ``get_mimic_cxr_conditions()`` and are matched as
        lower-case substrings of the query. For each of the first two matches,
        the sentences in ``retrieved_texts`` that mention the condition are
        collected and up to five distinct ones are listed as bullet points.

        Args:
            query: The user query string.
            retrieved_texts: Iterable of retrieved document strings.

        Returns:
            The comparison text, or ``None`` when fewer than two conditions are
            named in the query or either condition has no supporting sentence.
        """
        query_lower = query.lower()
        conditions_to_diff = [
            condition for condition in get_mimic_cxr_conditions()
            if condition in query_lower
        ]

        if len(conditions_to_diff) < 2:
            print(f"Warning: Differentiation query but found only {len(conditions_to_diff)} conditions: {conditions_to_diff}")
            return None

        print(f"Differentiating between: {conditions_to_diff[:2]}")

        condition_texts = {c: [] for c in conditions_to_diff[:2]}

        for text in retrieved_texts:
            text_lower = text.lower()
            for condition, collected in condition_texts.items():
                # Cheap whole-text check before splitting into sentences
                if condition in text_lower:
                    collected.extend(
                        sentence for sentence in re.split(r'[.!?]', text)
                        if condition in sentence.lower()
                    )

        if any(len(texts) == 0 for texts in condition_texts.values()):
            print("Warning: Insufficient information from knowledge base for differentiation")
            return None

        parts = [
            f"Differentiating {conditions_to_diff[0]} from {conditions_to_diff[1]}:\n\n",
            # Enhanced differentiation with structured comparison
            "Key Differences:\n",
        ]

        # For each condition, gather key characteristics
        for condition, texts in condition_texts.items():
            parts.append(f"\n{condition.capitalize()}:\n")

            # De-duplicate the top 5 sentences while keeping retrieval order,
            # so the bullet order is the same on every run (a set would not
            # guarantee that).
            stripped = (sentence.strip() for sentence in texts[:5])
            characteristic_sentences = dict.fromkeys(s for s in stripped if s)

            parts.extend(f"• {sentence}\n" for sentence in characteristic_sentences)

        # Add direct comparison summary if available
        for text in retrieved_texts:
            text_lower = text.lower()
            if all(condition in text_lower for condition in conditions_to_diff[:2]) and "differ" in text_lower:
                parts.append(f"\nDirect Comparison:\n{text}\n")
                break

        return "".join(parts)

    @performance_monitor
    def generate_response(self, query, retrieved_texts, indices, distances):
        """Produce the final answer text for a query from its retrieved context.

        A differentiation query (one containing "differentiate") is answered by
        ``handle_differentiation_query`` when that yields a result. Every other
        query, and any differentiation query the handler cannot answer, is
        answered by the generation model prompted with the top five retrieved
        texts. The answer is then formatted, extended with the confidence
        explanation and (when attention weights are present) the visual
        analysis, and finally anonymised.

        Args:
            query: The user query string.
            retrieved_texts: Retrieved document strings, best match first.
            indices: Retrieved document indices (not used by this method).
            distances: Retrieval distances, used for the confidence explanation.

        Returns:
            The final response string, or an "Error generating response: ..."
            message when generation fails.
        """
        debug_print(f"Generating response for query: {query[:50]}...")

        # The prompt is built from these texts, so the cache key must cover
        # exactly the same slice.
        context_texts = retrieved_texts[:5]

        # Check cache for this exact query+context combination; honour the
        # global caching switch the same way hybrid_retrieve does.
        cache_key = None
        if Config.USE_CACHING:
            cache_key = f"response_{hash(query)}_{hash(str(context_texts))}"
            cached_response = self.cache.get(cache_key)
            if cached_response is not None:
                debug_print("Using cached response")
                return cached_response

        try:
            response = None

            # Try the differentiation handler first: when it succeeds its text
            # replaces the model output, so generating beforehand is wasted.
            if "differentiate" in query.lower():
                debug_print("Processing differentiation query")
                response = self.handle_differentiation_query(query, retrieved_texts) or None

            if response is None:
                # Clean and prepare context
                cleaned_texts = [re.sub(r'\[REDACTED\]', '', text) for text in context_texts]

                context = "\n\nRetrieved Context:\n" + "\n".join(cleaned_texts)
                input_text = f"Query: {query}\nContext: {context}\nAnswer:"

                debug_print("Tokenizing input for generation")
                inputs = self.gen_tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True).to(self.device)

                debug_print("Generating text with model")
                with torch.no_grad():
                    outputs = self.gen_model.generate(
                        **inputs,
                        max_length=250,
                        num_beams=2,
                        early_stopping=True,
                        do_sample=True,  # Fix for temperature warning
                        temperature=0.7,
                        no_repeat_ngram_size=3  # Prevent repetition
                    )

                debug_print("Decoding generated text")
                response = self.gen_tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Format the response
            response = self.format_response(response, query)

            # Generate confidence explanation
            confidence_explanation = self.explain_confidence(distances)
            response_with_confidence = response + "\n\n" + confidence_explanation

            # Generate visual analysis if we have attention weights
            visual_analysis = ""
            if hasattr(self, 'last_attention_weights'):
                debug_print("Generating visual analysis")
                visual_analysis = "\n\n" + self.generate_visual_analysis()

            # Final response with PHI detection
            final_response = anonymize_text(response_with_confidence + visual_analysis)

            # Cache the result if enabled
            if cache_key is not None:
                self.cache.put(cache_key, final_response)

            debug_print(f"Response generation complete, length: {len(final_response)}")
            return final_response
        except Exception as e:
            debug_print(f"Error in generate_response: {str(e)}")
            debug_print(traceback.format_exc())
            return f"Error generating response: {str(e)}"

    def format_response(self, text, query):
        """Normalise generated text into a tidy, report-like answer.

        Steps, in order:
          1. Strip ``[REDACTED]`` markers and leading list numbering.
          2. Collapse runs of spaces/tabs and limit blank lines to one, while
             keeping line and paragraph breaks intact.
          3. If the text holds "FINDINGS: ... IMPRESSION: ..." sections, keep
             only the first complete report and render the findings as bullet
             points followed by the impression.
          4. Otherwise, when the query is about imaging, render the first five
             sentences as a bulleted "Findings" list.
          5. Make sure the result ends with sentence punctuation.

        Args:
            text: Raw response text.
            query: The user query; only used to detect imaging questions.

        Returns:
            The formatted text.
        """
        # Remove any [REDACTED] markers
        text = re.sub(r'\[\s*REDACTED\s*\]\s*:', '', text)
        text = re.sub(r'\[\s*REDACTED\s*\]', '', text)

        # Remove numbered list formatting if present
        text = re.sub(r'^\d+\.\s+', '', text, flags=re.MULTILINE)

        # Clean up spacing without destroying line structure. Collapsing every
        # whitespace run (including "\n\n") into a space would flatten
        # paragraphs and bullet lists, so newlines are handled separately.
        text = re.sub(r'[^\S\n]*\n[^\S\n]*', '\n', text)  # trim around breaks
        text = re.sub(r'\n{3,}', '\n\n', text)            # at most one blank line
        text = re.sub(r'[^\S\n]{2,}', ' ', text)          # runs of spaces/tabs

        # Extract only the first complete report (to avoid mixing multiple reports)
        # Look for patterns like "FINDINGS: ... IMPRESSION: ..." or similar
        report_pattern = r'(FINDINGS?:.*?IMPRESSIONS?:.*?)(?:FINDINGS?:|$)'
        report_match = re.search(report_pattern, text, re.DOTALL | re.IGNORECASE)
        if report_match:
            text = report_match.group(1).strip()

        # Find and extract findings and impression
        findings_match = re.search(r'FINDINGS?:(.*?)(?:IMPRESSIONS?:|$)', text, re.DOTALL | re.IGNORECASE)
        impression_match = re.search(r'IMPRESSIONS?:(.*?)(?:FINDINGS?:|$)', text, re.DOTALL | re.IGNORECASE)

        if findings_match or impression_match:
            # Structure existing findings/impression sections
            pieces = []
            if findings_match:
                findings = findings_match.group(1).strip()
                pieces.append("Findings:\n")

                if re.search(r'^\s*[•\-\*]', findings, re.MULTILINE):
                    # Already a bullet list; keep it as written
                    pieces.append(findings + "\n")
                else:
                    # Split by sentences or semi-colons, one bullet each
                    for sentence in re.split(r'[.;]\s+', findings):
                        point = sentence.strip()
                        if not point:
                            continue
                        if not point.endswith('.'):
                            point += '.'
                        pieces.append(f"• {point}\n")

            if impression_match:
                impression = impression_match.group(1).strip()
                pieces.append(f"\nImpression:\n{impression}")

            text = "".join(pieces)
        else:
            # If no structured sections found, create a simple findings section
            # for imaging questions. The short acronyms are matched as whole
            # words: a plain substring test for "ct" also fires on words such
            # as "effect", "infection" or "atelectasis".
            query_lower = query.lower()
            mentions_imaging = (
                any(term in query_lower for term in ('x-ray', 'xray', 'scan', 'imaging', 'radiograph'))
                or re.search(r'\b(?:ct|mri)s?\b', query_lower) is not None
            )
            if mentions_imaging:
                pieces = ["Findings:\n"]

                # Take only the first 5 sentences to avoid mixing reports
                for sentence in re.split(r'[.!?]\s+', text)[:5]:
                    sentence = sentence.strip()
                    if not sentence:
                        continue
                    if not sentence.endswith('.'):
                        sentence += '.'
                    pieces.append(f"• {sentence}\n")

                text = "".join(pieces)

        # Ensure text doesn't end mid-sentence
        if text and not text.rstrip().endswith(('.', '!', '?')):
            text = text.rstrip() + '.'

        return text

    @performance_monitor
    def generate_visual_analysis(self):
        """Describe the regions that received the most visual attention.

        Requires ``self.last_attention_weights`` to be set. When
        ``self.detected_regions`` is populated, the three regions with the
        highest ``attention_score`` are listed with an attention level
        (high above 0.5, moderate above 0.2, otherwise low) and up to two
        possible conditions. When ``self.region_attention`` exists, a note is
        added if the mean left and right attention differ by more than 0.15.

        Returns:
            The analysis text, a "not available" message when no attention
            weights are stored, or a generic message if an error occurs.
        """
        debug_print("Generating visual analysis from attention weights")

        if getattr(self, 'last_attention_weights', None) is None:
            return "Visual Analysis: Not available for this query."

        try:
            # Start with basic information (removed [REDACTED] label)
            lines = ["Regions of Interest:\n"]

            detected_regions = getattr(self, 'detected_regions', None)
            if detected_regions:
                # Get top 3 regions with highest attention
                top_regions = sorted(
                    detected_regions.items(),
                    key=lambda item: item[1]["attention_score"],
                    reverse=True,
                )[:3]

                for region_name, region_data in top_regions:
                    # Prefer the region's description; fall back to a readable
                    # form of its key so a missing description does not
                    # discard the whole analysis.
                    label = region_data.get('description') or str(region_name).replace('_', ' ').title()

                    # Add attention level description
                    score = region_data['attention_score']
                    if score > 0.5:
                        attention_level = "high"
                    elif score > 0.2:
                        attention_level = "moderate"
                    else:
                        attention_level = "low"

                    entry = f"• {label}: {attention_level} attention"

                    # Add potential findings if available
                    possible_conditions = region_data.get('possible_conditions') or []
                    if possible_conditions:
                        conditions = [c.capitalize() for c in possible_conditions[:2]]
                        entry += f" (may indicate {' or '.join(conditions)})"

                    lines.append(entry + "\n")

                # Add asymmetry analysis
                if hasattr(self, 'region_attention'):
                    left_attention = (self.region_attention['upper_left'] + self.region_attention['lower_left']) / 2
                    right_attention = (self.region_attention['upper_right'] + self.region_attention['lower_right']) / 2

                    if abs(left_attention - right_attention) > 0.15:  # Significant asymmetry
                        dominant_side = "left" if left_attention > right_attention else "right"
                        lines.append(f"\nAsymmetry: The {dominant_side} side shows significantly more findings of interest.\n")
            else:
                # Fallback if detailed region data isn't available
                lines.append("• Areas of high attention indicate potential findings requiring clinical correlation.\n")

            return "".join(lines)

        except Exception as e:
            debug_print(f"Error in generate_visual_analysis: {str(e)}")
            debug_print(traceback.format_exc())
            return "Visual Analysis: Attention heatmap shows regions of interest in the image."

    @performance_monitor
    def evaluate_response(self, query, response, image_id=None):
        """
        Score a response against ground truth by comparing mentioned conditions.

        A condition counts as mentioned when its name, or one of its synonyms,
        occurs as a lower-case substring. The ground truth text is obtained
        from ``self.get_ground_truth`` using the image id, the first condition
        mentioned in the response, and the query. Scores are appended to
        ``self.eval_metrics`` under "precision", "recall" and "f1".

        Args:
            query: The user query string
            response: The generated response
            image_id: The ID of the image being analyzed (if applicable)

        Returns:
            tuple: (precision, recall, f1); all zeros if evaluation fails
        """
        debug_print(f"Evaluating response for query: {query[:50]}...")

        try:
            conditions = get_mimic_cxr_conditions()
            condition_synonyms = get_condition_synonyms()

            def find_mentions(haystack):
                # Map every condition to whether it (or a synonym) appears.
                found = {}
                for name in conditions:
                    hit = name in haystack
                    if not hit and name in condition_synonyms:
                        hit = any(alias in haystack for alias in condition_synonyms[name])
                    found[name] = hit
                return found

            response_mentions = find_mentions(response.lower())

            # The first mentioned condition steers the ground-truth lookup
            primary_condition = next(
                (name for name, mentioned in response_mentions.items() if mentioned),
                None,
            )

            ground_truth = self.get_ground_truth(
                image_id=image_id,
                condition=primary_condition,
                query=query
            )

            debug_print(f"Using ground truth: '{ground_truth[:100]}...'")

            ground_truth_mentions = find_mentions(ground_truth.lower())

            pred_labels = [int(response_mentions[c]) for c in conditions]
            true_labels = [int(ground_truth_mentions[c]) for c in conditions]

            debug_print("Condition detection results:")
            for condition, predicted, expected in zip(conditions, pred_labels, true_labels):
                debug_print(f"- {condition}: pred={bool(predicted)}, true={bool(expected)}")

            scores = None
            is_differentiation = False

            if sum(true_labels) > 0:
                if "differentiate" in query.lower():
                    diff_conditions = [c for c in conditions if c in query.lower()]
                    if len(diff_conditions) >= 2:
                        # All-or-nothing: response and truth must agree on
                        # whether both compared conditions are present.
                        pair = diff_conditions[:2]
                        resp_has_both = all(response_mentions[c] for c in pair)
                        truth_has_both = all(ground_truth_mentions[c] for c in pair)
                        agreement = 1.0 if resp_has_both == truth_has_both else 0.0
                        scores = (agreement, agreement, agreement)
                        is_differentiation = True

                if scores is None:
                    scores = (
                        precision_score(true_labels, pred_labels, zero_division=0),
                        recall_score(true_labels, pred_labels, zero_division=0),
                        f1_score(true_labels, pred_labels, zero_division=0),
                    )
            elif sum(pred_labels) == 0:
                # Nothing expected and nothing predicted: perfect agreement
                scores = (1.0, 1.0, 1.0)
            else:
                # Nothing expected, but the response named conditions
                scores = (0.0, 0.0, 0.0)

            precision, recall, f1 = scores
            for metric_name, value in zip(("precision", "recall", "f1"), scores):
                self.eval_metrics[metric_name].append(value)

            if is_differentiation:
                debug_print(f"Differentiation query evaluation - Precision: {precision:.2f}, Recall: {recall:.2f}, F1: {f1:.2f}")
            else:
                debug_print(f"Evaluation results - Precision: {precision:.2f}, Recall: {recall:.2f}, F1: {f1:.2f}")
            return precision, recall, f1

        except Exception as e:
            debug_print(f"Error in evaluate_response: {str(e)}")
            debug_print(traceback.format_exc())
            return 0.0, 0.0, 0.0

    def get_evaluation_metrics(self):
        """Summarise the recorded evaluation scores and timing data.

        Returns:
            dict: ``avg_precision``, ``avg_recall`` and ``avg_f1`` (means of
            the recorded scores, rounded to two decimals) and ``evaluations``
            (the number of recorded precision scores). When
            ``self.performance_metrics`` is non-empty, a ``performance`` entry
            maps every function name that has timings to its ``avg_time``,
            ``max_time`` (rounded to three decimals) and ``calls``. If no
            scores have been recorded, an all-zero summary is returned.
        """
        debug_print("Getting evaluation metrics")

        recorded = self.eval_metrics
        if not recorded["precision"]:
            debug_print("No evaluation metrics collected yet")
            return {"avg_precision": 0.0, "avg_recall": 0.0, "avg_f1": 0.0, "evaluations": 0}

        evaluation_count = len(recorded["precision"])
        mean_precision, mean_recall, mean_f1 = (
            np.mean(recorded[score_name]) for score_name in ("precision", "recall", "f1")
        )

        debug_print(f"Metrics collected: {evaluation_count} evaluations")
        debug_print(f"Avg precision: {mean_precision:.2f}, Avg recall: {mean_recall:.2f}, Avg F1: {mean_f1:.2f}")

        summary = {
            "avg_precision": round(float(mean_precision), 2),
            "avg_recall": round(float(mean_recall), 2),
            "avg_f1": round(float(mean_f1), 2),
            "evaluations": evaluation_count,
        }

        # Timing data is optional; functions without any samples are left out
        if self.performance_metrics:
            summary["performance"] = {
                func_name: {
                    "avg_time": round(float(np.mean(durations)), 3),
                    "max_time": round(float(np.max(durations)), 3),
                    "calls": len(durations),
                }
                for func_name, durations in self.performance_metrics.items()
                if durations
            }

        return summary

    def get_synthesized_ground_truth(self, condition=None, query=None):
        """Build a generic ground-truth sentence for use as a fallback.

        Args:
            condition: Optional condition name to mention in the sentence.
            query: Optional query text; when it contains "differentiate" or
                "difference" and names at least two known conditions, a
                comparison sentence for the first two is produced.

        Returns:
            str: A comparison sentence, a condition-specific sentence, or a
            general statement, checked in that order.
        """
        debug_print("Using synthesized ground truth (fallback)")

        # Differentiation queries take priority over a single condition
        if query:
            query_lower = query.lower()
            if "differentiate" in query_lower or "difference" in query_lower:
                named = [c for c in get_mimic_cxr_conditions() if c in query_lower]
                if len(named) >= 2:
                    first, second = named[0], named[1]
                    return f"The key radiographic differences between {first} and {second} relate to their distribution, appearance, and associated findings."

        if not condition:
            # Nothing specific to anchor on
            return "Accurate chest X-ray interpretation requires assessment of all anatomical structures and potential pathological findings based on established radiographic principles."

        return f"The typical radiographic appearance of {condition} includes its characteristic findings as documented in radiology literature."

    def get_ground_truth(self, image_id=None, condition=None, query=None):
        """Look up reference report text in the MIMIC-CXR index.

        The record for ``image_id`` is tried first; if it yields no text, a
        randomly chosen record indexed under ``condition`` is tried. For each
        record the impression is preferred over the findings. Anything else,
        including a missing index or a lookup error, falls back to
        ``get_synthesized_ground_truth``.

        Args:
            image_id: ID of the image if available
            condition: Specific condition to get ground truth for
            query: Query text, forwarded to the synthesized fallback

        Returns:
            str: Anonymized report text, or a synthesized sentence
        """
        try:
            if not self.has_mimic_data:
                return self.get_synthesized_ground_truth(condition, query)

            def candidate_records():
                # Lazy on purpose: the random pick only happens when the
                # image's own record produced no usable text.
                if image_id and image_id in self.mimic_index["by_image_id"]:
                    yield self.mimic_index["by_image_id"][image_id]
                if condition and condition in self.mimic_index["by_condition"]:
                    examples = self.mimic_index["by_condition"][condition]
                    if examples:
                        yield self.mimic_index["by_image_id"][random.choice(examples)]

            for report in candidate_records():
                if report["impression"]:
                    return anonymize_text(report["impression"])
                if report["findings"]:
                    return anonymize_text(report["findings"])

            return self.get_synthesized_ground_truth(condition, query)

        except Exception as exc:
            debug_print(f"Error accessing ground truth: {str(exc)}")
            return self.get_synthesized_ground_truth(condition, query)

    @performance_monitor
    def process_image_query(self, image):
        """Describe a chest X-ray image and return the attention map behind it.

        Steps: embed the image, retrieve context with a fixed "describe"
        prompt, generate findings with the generation model, then append the
        output of ``generate_visual_analysis``.

        The response is deliberately NOT evaluated here. ``process_query``
        already evaluates the returned text once (and knows the image ID);
        evaluating here as well recorded every image query twice in
        ``eval_metrics``, the first time without an image ID.

        Args:
            image: Image forwarded unchanged to ``generate_image_embedding``.

        Returns:
            tuple: ``(response_text, attention_weights)``. ``attention_weights``
            is a copy of ``self.last_attention_weights`` as a NumPy array, or
            a 7x7 array of ones when none were captured. On failure the text
            is an error message and the weights are a 7x7 array of ones.
        """
        # Fixed prompt shared by retrieval and generation for image-only input
        query = "Describe this chest X-ray"

        try:
            debug_print("Starting process_image_query with attention weight focus")
            image_embedding = self.generate_image_embedding(image)
            debug_print(f"Image embedding generated, shape: {image_embedding.shape}")

            if not hasattr(self, 'last_attention_weights'):
                debug_print("Warning: No attention weights were captured during embedding generation!")
                self.last_attention_weights = np.ones((7, 7))
            else:
                debug_print(f"Attention weights captured, shape: {self.last_attention_weights.shape}")

            # hybrid_retrieve is given the embedding as a flat tuple
            # (presumably so it is hashable for caching - TODO confirm)
            image_embedding_tuple = tuple(image_embedding.flatten().tolist())
            debug_print(f"Converted to tuple of length: {len(image_embedding_tuple)}")

            # Retrieval is best-effort: any failure degrades to placeholder context
            try:
                debug_print("Calling hybrid_retrieve")
                hybrid_result = self.hybrid_retrieve(query, image_embedding_tuple)
                if isinstance(hybrid_result, tuple) and len(hybrid_result) == 3:
                    retrieved_texts = hybrid_result[0]
                else:
                    debug_print(f"WARNING: hybrid_retrieve returned unexpected value: {type(hybrid_result)}")
                    retrieved_texts = ["No medical context available."]
            except Exception as e:
                debug_print(f"Error in hybrid_retrieve: {e}")
                retrieved_texts = ["Error retrieving medical context."]

            debug_print("Generating response")

            # Use at most five passages and strip the literal "[REDACTED]"
            # placeholders so they do not leak into the prompt
            cleaned_texts = [re.sub(r'\[REDACTED\]', '', text) for text in retrieved_texts[:5]]

            context = "\n\nRetrieved Context:\n" + "\n".join(cleaned_texts)
            input_text = f"Query: {query}\nContext: {context}\nAnswer:"

            # Generate the medical findings
            try:
                inputs = self.gen_tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True).to(self.device)

                with torch.no_grad():
                    outputs = self.gen_model.generate(
                        **inputs,
                        max_length=250,
                        num_beams=2,
                        early_stopping=True,
                        do_sample=True,
                        temperature=0.7,
                        no_repeat_ngram_size=3
                    )

                response = self.gen_tokenizer.decode(outputs[0], skip_special_tokens=True)

                # Format the response without including visual analysis again
                medical_findings = self.format_response(response, query)
            except Exception as e:
                debug_print(f"Error in text generation: {str(e)}")
                medical_findings = "Unable to generate findings due to an error."

            # The visual analysis is produced separately and appended
            visual_analysis = self.generate_visual_analysis()
            combined_response = f"{medical_findings}\n\n{visual_analysis}"

            debug_print("Capturing attention weights for visualization")
            # Copy so later queries cannot mutate the array handed to the caller
            attention_weights = np.array(self.last_attention_weights).copy()
            debug_print(f"Attention weights shape for return: {attention_weights.shape}")

            debug_print("Returning from process_image_query with response and attention weights")
            return combined_response, attention_weights

        except Exception as e:
            debug_print(f"Error in process_image_query: {str(e)}")
            debug_print(traceback.format_exc())
            return f"Error processing image: {str(e)}", np.ones((7, 7))

    @performance_monitor
    def process_text_query(self, text):
        """Answer a free-text question using retrieved context.

        Args:
            text: The question; it is passed to ``hybrid_retrieve`` and
                ``generate_response``.

        Returns:
            tuple: ``(response, None)``. The second slot is always ``None``
            because a text query has no attention map. On failure the first
            slot carries an error message instead of an answer.
        """
        debug_print(f"process_text_query called with text: {text[:50]}...")
        try:
            debug_print("Retrieving documents for text query")
            retrieval = self.hybrid_retrieve(text)
            passages, passage_indices, passage_distances = retrieval
            debug_print(f"Retrieved {len(passages)} texts")

            debug_print("Generating response")
            answer = self.generate_response(text, passages, passage_indices, passage_distances)
            debug_print(f"Response generated, length: {len(answer)}")

            debug_print("Returning from process_text_query")
            return answer, None

        except Exception as exc:
            failure_message = f"Error processing text query: {str(exc)}"
            debug_print(f"ERROR in process_text_query: {str(exc)}")
            debug_print(traceback.format_exc())
            return (failure_message, None)

    def free_memory(self):
        """Release cached data and GPU memory held by this instance.

        Clears ``self.cache`` if present. When CUDA is available, lazily
        loaded text/generation models are unloaded (``Config.LAZY_LOADING``
        on) or the text, image and generation models are moved to the CPU
        (off), and the CUDA cache is emptied. Per-query attributes such as
        the last attention weights are deleted, then garbage collection runs.
        Any error is logged and swallowed.
        """
        debug_print("Freeing memory...")
        try:
            if hasattr(self, 'cache'):
                self.cache.clear()
                debug_print("Cache cleared")

            if torch.cuda.is_available():
                if not Config.LAZY_LOADING:
                    # Regular models: move each one that exists to the CPU
                    for model_attr in ('text_model', 'image_model', 'gen_model'):
                        if hasattr(self, model_attr):
                            setattr(self, model_attr, getattr(self, model_attr).cpu())
                else:
                    # Lazy wrappers are unloaded; anything else is left alone
                    lazy_targets = (
                        ('text_model', "Unloading model dmis-lab/biobert-v1.1"),
                        ('gen_model', "Unloading model <your_path>"),
                    )
                    for model_attr, unload_message in lazy_targets:
                        if hasattr(self, model_attr) and isinstance(getattr(self, model_attr), LazyModel):
                            getattr(self, model_attr).unload()
                            debug_print(unload_message)

                torch.cuda.empty_cache()
                debug_print("Models moved to CPU and CUDA cache cleared")

            # Drop per-query state left behind by image processing
            transient_attrs = (
                'last_attention_weights',
                'raw_activations',
                'region_attention',
                'detected_regions',
                'primary_regions',
            )
            for attr_name in transient_attrs:
                if hasattr(self, attr_name):
                    delattr(self, attr_name)

            gc.collect()

            debug_print("Memory freed successfully")
        except Exception as exc:
            debug_print(f"Error freeing memory: {str(exc)}")

# === Utility Functions for Gradio Interface ===
def _extract_image_id(image_file):
    """Return the ``pNN/pNNNN/sNNNN/NNNN`` ID found in the image's source path, or None.

    The source path is the string itself when ``image_file`` is a path, or
    its ``filename`` attribute otherwise. The whole path is searched: the ID
    pattern spans several directory levels, so it can never match a basename.
    """
    if isinstance(image_file, str):
        source_path = image_file
    else:
        source_path = getattr(image_file, 'filename', None)

    if not isinstance(source_path, str) or not source_path:
        return None

    # Normalise Windows separators so the "/"-based pattern applies everywhere
    match = re.search(r'p\d+/p\d+/s\d+/\d+', source_path.replace('\\', '/'))
    return match.group(0) if match else None


def process_query(text_query=None, image_file=None):
    """Process user query with improved MIMIC-CXR based evaluation

    An image, when given, takes precedence over the text query. The response
    is evaluated internally through ``rag.evaluate_response`` and the
    confidence/explanation sections are stripped before it is returned.

    Args:
        text_query: Question text; used only when no image is supplied and
            it has at least three non-whitespace-padded characters.
        image_file: PIL image, NumPy array, file path, or any object
            ``np.array`` can convert.

    Returns:
        tuple: ``(response_text, attention_map)``. ``attention_map`` is None
        for text queries, for prompts asking for more input, and on error.
    """
    global rag

    debug_print(f"process_query called with text_query: {'provided' if text_query else 'None'}, image_file: {'provided' if image_file is not None else 'None'}")

    if not text_query and image_file is None:
        return "Please provide a query or image.", None

    try:
        # Must run before the image is converted below: conversion produces a
        # new object that no longer carries the source path.
        image_id = _extract_image_id(image_file) if image_file is not None else None
        if image_id:
            debug_print(f"Extracted image_id: {image_id}")

        if image_file is not None:
            debug_print(f"Processing image query, image_file type: {type(image_file)}")

            # Handle different image input types
            if not isinstance(image_file, Image.Image):
                debug_print(f"Converting {type(image_file)} to PIL Image")
                if isinstance(image_file, np.ndarray):
                    # NumPy array
                    image_file = Image.fromarray(image_file)
                elif isinstance(image_file, str) and os.path.exists(image_file):
                    # File path
                    image_file = Image.open(image_file)
                else:
                    # Try generic conversion
                    image_file = Image.fromarray(np.array(image_file))

            # Ensure image is in correct format (RGB)
            if image_file.mode != 'RGB':
                debug_print(f"Converting image from {image_file.mode} to RGB")
                image_file = image_file.convert('RGB')

            result = rag.process_image_query(image_file)

            if isinstance(result, tuple) and len(result) >= 2:
                response_text = result[0]
                attention_weights = result[1]
            else:
                debug_print(f"Warning: Expected tuple result, got {type(result)}")
                response_text = str(result)
                attention_weights = None

            attention_map = None
            if attention_weights is not None:
                attention_map = visualize_attention_map(image_file, attention_weights)

            # Internal evaluation (not shown to user)
            rag.evaluate_response(
                "Describe this chest X-ray",
                response_text,
                image_id
            )

            # Clean up response - remove confidence explanation
            response_text = remove_confidence_section(response_text)

            return response_text, attention_map

        elif text_query and len(text_query.strip()) >= 3:
            debug_print("Processing text query")
            result = rag.process_text_query(text_query)

            if isinstance(result, tuple) and len(result) >= 1:
                response_text = result[0]
            else:
                response_text = str(result)

            # Internal evaluation (not shown to user)
            rag.evaluate_response(
                text_query,
                response_text,
                None
            )

            # Clean up response - remove confidence explanation
            response_text = remove_confidence_section(response_text)

            return response_text, None

        else:
            return "Please enter a detailed query or upload an image.", None

    except Exception as e:
        error_message = f"Error processing query: {str(e)}"
        debug_print(error_message)
        debug_print(traceback.format_exc())
        return error_message, None

def remove_confidence_section(text):
    """Strip "Confidence Level:" and "Explanation:" sections from a response.

    Each section runs from its label up to the next blank line or the end of
    the text. Runs of three or more newlines left behind are collapsed to a
    single blank line, and the result is trimmed of surrounding whitespace.
    """
    section_matchers = (
        re.compile(r'Confidence Level:.*?(?=\n\n|$)', re.DOTALL),
        re.compile(r'Explanation:.*?(?=\n\n|$)', re.DOTALL),
    )

    cleaned = text
    for matcher in section_matchers:
        cleaned = matcher.sub('', cleaned)

    # Removing a section can leave several blank lines in a row
    return re.sub(r'\n{3,}', '\n\n', cleaned).strip()

def visualize_attention_map(image, attention_weights):
    """Create a visualization of model attention on the image with enhanced anatomical mapping

    The attention weights are reduced to a 2-D map, blurred, resized to the
    image, normalised, coloured with the INFERNO colormap and alpha-blended
    over the image. If the module-level ``rag`` object has ``detected_regions``
    and ``primary_regions``, a marker and label are drawn for each primary
    region that has an entry in the local coordinate table.

    Args:
        image: PIL image, or any object ``np.array`` can convert to one.
        attention_weights: Array with a ``shape`` attribute. A 1-D array is
            reshaped to a square; an array with more than two dimensions is
            averaged over its first axis.

    Returns:
        RGB ``numpy`` array of the overlay; ``None`` when ``attention_weights``
        is ``None``. On any error the input image is returned instead (as an
        array when it is a PIL image, otherwise unchanged).

    Side effects:
        Writes ``debug_visualization.jpg`` to the current working directory on
        every successful call (failures to write are only logged).
    """
    debug_print("visualize_attention_map called with:")
    debug_print(f"- Image type: {type(image)}, size: {image.size if hasattr(image, 'size') else 'unknown'}")
    debug_print(f"- Attention weights type: {type(attention_weights)}")

    try:
        if attention_weights is None:
            debug_print("ERROR: attention_weights is None")
            return None

        if not isinstance(image, Image.Image):
            debug_print(f"Converting {type(image)} to PIL Image")
            image = Image.fromarray(np.array(image))

        img_array = np.array(image)
        debug_print(f"Image array shape: {img_array.shape}")

        # Bring the weights to a 2-D map
        if hasattr(attention_weights, 'shape'):
            debug_print(f"Attention weights shape: {attention_weights.shape}")
            if len(attention_weights.shape) == 1:
                # Flat vector: treated as a square grid. A length that is not
                # a perfect square makes reshape raise, which lands in the
                # outer except and returns the plain image.
                size = int(np.sqrt(attention_weights.shape[0]))
                attention_weights = attention_weights.reshape(size, size)
                debug_print(f"Reshaped attention weights to {attention_weights.shape}")
            elif len(attention_weights.shape) > 2:
                # NOTE(review): only the first axis is averaged, so an input
                # with more than three dimensions is still not 2-D afterwards.
                attention_weights = np.mean(attention_weights, axis=0)
                debug_print(f"Averaged attention weights to shape {attention_weights.shape}")
        else:
            debug_print("ERROR: attention_weights has no shape attribute")
            attention_weights = np.ones((224, 224))

        attention_weights = np.abs(attention_weights)

        # Apply Gaussian smoothing to make heatmap more natural
        attention_weights = cv2.GaussianBlur(attention_weights, (5, 5), 0)

        # cv2.resize takes (width, height), the reverse of the array shape order
        target_size = (img_array.shape[1], img_array.shape[0])
        debug_print(f"Resizing attention map to {target_size}")
        attention_map = cv2.resize(attention_weights, target_size)

        # Min-max normalise to [0, 1], then apply a 0.7 power curve for contrast
        # (no histogram equalization is performed here)
        min_val = np.min(attention_map)
        max_val = np.max(attention_map)
        if max_val > min_val:
            # Apply non-linear transformation to enhance contrast
            attention_map = np.power((attention_map - min_val) / (max_val - min_val), 0.7)
        else:
            # Constant map: nothing to highlight
            attention_map = np.zeros_like(attention_map)

        debug_print(f"Normalized attention map shape: {attention_map.shape}, min: {np.min(attention_map)}, max: {np.max(attention_map)}")

        # Colour the map; cv2.applyColorMap returns a BGR image
        heatmap = cv2.applyColorMap((attention_map * 255).astype(np.uint8), cv2.COLORMAP_INFERNO)
        debug_print(f"Heatmap shape: {heatmap.shape}")

        # Prepare the background image in BGR so it matches the heatmap
        if len(img_array.shape) == 2:
            debug_print("Converting grayscale image to BGR")
            img_bgr = cv2.cvtColor(img_array, cv2.COLOR_GRAY2BGR)
        elif len(img_array.shape) == 3:
            if img_array.shape[2] == 4:
                debug_print("Converting RGBA image to BGR")
                img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGBA2BGR)
            elif img_array.shape[2] == 3:
                debug_print("Converting RGB image to BGR")
                img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
            else:
                debug_print(f"Unexpected image shape: {img_array.shape}")
                img_bgr = img_array
        else:
            debug_print(f"Unexpected image dimensions: {img_array.shape}")
            img_bgr = img_array

        if img_bgr.shape[:2] != heatmap.shape[:2]:
            debug_print(f"Warning: Image shape {img_bgr.shape[:2]} and heatmap shape {heatmap.shape[:2]} mismatch")
            heatmap = cv2.resize(heatmap, (img_bgr.shape[1], img_bgr.shape[0]))

        # Enhanced overlay with adaptive transparency
        # This makes the heatmap more transparent in low-attention areas
        # (per-pixel opacity is clamped to the range 0.2 - 0.7)
        alpha = np.clip(attention_map * 0.7, 0.2, 0.7)  # Adaptive transparency
        alpha = np.expand_dims(alpha, axis=2)
        alpha = np.repeat(alpha, 3, axis=2)  # Repeat for the three colour channels

        # Create the blended image
        overlay = img_bgr.astype(float) * (1 - alpha) + heatmap.astype(float) * alpha
        overlay = np.clip(overlay, 0, 255).astype(np.uint8)

        # Add anatomical region markers if available
        # `rag` is the module-level RAG instance; if it is not defined the
        # NameError is caught by the outer except and the plain image is returned.
        if hasattr(rag, 'detected_regions') and hasattr(rag, 'primary_regions'):
            # Marker positions as fixed fractions of the image width/height
            # NOTE(review): there is no "middle_left_lung" entry, and "heart"
            # and "spine" share the same point - confirm this is intended.
            h, w = img_bgr.shape[:2]
            regions = {
                "upper_right_lung": (int(w*0.75), int(h*0.25)),
                "upper_left_lung": (int(w*0.25), int(h*0.25)),
                "middle_right_lung": (int(w*0.75), int(h*0.5)),
                "lower_right_lung": (int(w*0.75), int(h*0.75)),
                "lower_left_lung": (int(w*0.25), int(h*0.75)),
                "heart": (int(w*0.5), int(h*0.5)),
                "hilar": (int(w*0.5), int(h*0.4)),
                "costophrenic_angles": (int(w*0.5), int(h*0.85)),
                "spine": (int(w*0.5), int(h*0.5)),
                "diaphragm": (int(w*0.5), int(h*0.75))
            }

            # Add markers for primary regions; names missing from the table are skipped
            for region_name in rag.primary_regions:
                if region_name in regions:
                    x, y = regions[region_name]
                    # Filled white dot with a black outline
                    cv2.circle(overlay, (x, y), 10, (255, 255, 255), -1)
                    cv2.circle(overlay, (x, y), 10, (0, 0, 0), 2)

                    # Add label
                    label = region_name.replace('_', ' ').title()
                    cv2.putText(overlay, label,
                              (x - 50, y - 15),
                              cv2.FONT_HERSHEY_SIMPLEX,
                              0.5, (255, 255, 255), 2)

        # Back to RGB for the caller
        result = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
        debug_print(f"Final visualization shape: {result.shape}")

        # Best-effort debug dump; cv2.imwrite expects BGR, hence the conversion back
        try:
            cv2.imwrite('debug_visualization.jpg', cv2.cvtColor(result, cv2.COLOR_RGB2BGR))
            debug_print("Debug visualization saved to debug_visualization.jpg")
        except Exception as save_error:
            debug_print(f"Error saving debug visualization: {save_error}")

        return result

    except Exception as e:
        # Fall back to showing the unmodified image rather than nothing
        debug_print(f"Error in visualize_attention_map: {str(e)}")
        debug_print(traceback.format_exc())
        if isinstance(image, Image.Image):
            return np.array(image)
        else:
            return image

# === Examples for Gradio Interface ===
def get_examples():
    """Return the example queries shown in the Gradio interface.

    Returns:
        list: One single-element list per example (one entry for the text
        input), limited to the first ``Config.EXAMPLES_TO_SHOW`` examples.
    """
    sample_queries = (
        "What does pleural effusion look like on a chest X-ray?",
        "How to differentiate pulmonary edema from pneumonia?",
        "What radiographic findings are typical for tuberculosis?",
        "What are the key features of cardiomegaly on X-ray?",
        "How do atelectasis and pneumothorax differ radiographically?",
        "What should I look for to identify lung nodules?",
        "Describe radiographic signs of COPD on chest X-ray.",
        "What features suggest malignancy in a lung nodule?",
    )

    # Each example is wrapped in its own list: one value per input component
    return [[question] for question in sample_queries[:Config.EXAMPLES_TO_SHOW]]

# === Main execution ===
if __name__ == "__main__":
    # Script entry point: mount Drive, build the RAG system, smoke-test it,
    # serve the Gradio UI, then report metrics and release resources.
    print("\n\n===== STARTING MEDIQUERY RAG SYSTEM (ENHANCED 2025) =====")
    # NOTE(review): `datetime.now()` needs `from datetime import datetime` in
    # scope; a bare `import datetime` is not enough - confirm the import.
    print(f"Current date and time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print("MediQuery version: 2.0.1 (March 9, 2025)")
    print("Developer: Tanishk")

    # Drive is optional: a failed mount is reported and execution continues
    try:
        print("Mounting Google Drive...")
        drive.mount('/content/drive', force_remount=True)
        print("Google Drive mounted successfully")
    except Exception as e:
        print(f"Error mounting drive: {str(e)}")
        print("Continuing without mounted drive. Some features may not be available.")

    # Bound up front so the cleanup below can tell whether initialisation got
    # far enough to have anything to release.
    rag = None
    try:
        print("Initializing RAG system...")
        rag = MediQueryRAG(
            knowledge_base_dir="<your_path>",
            finetuned_model_path="<your_path>"
        )
        print("RAG system initialized successfully")

        # Smoke test before exposing the UI
        print("Testing RAG with simple text query...")
        test_result = rag.process_text_query("What does pneumonia look like?")
        print(f"Test query result type: {type(test_result)}, is tuple: {isinstance(test_result, tuple)}")
        if isinstance(test_result, tuple):
            print(f"Test response length: {len(test_result[0]) if test_result[0] else 'empty'}")

        print("Creating Gradio interface...")
        with gr.Blocks(theme=gr.themes.Base()) as interface:
            gr.Markdown("# MediQuery - Advanced Chest X-Ray Analysis")
            gr.Markdown("""
            Ask about chest X-rays or upload an image for analysis.

            **Developed by:** Tanishk | **Version:** 2.0.1 (March 2025)

            *This tool uses a fine-tuned RAG system on the MIMIC-CXR dataset to provide medical imaging analysis.*
            """)

            with gr.Row():
                # Left column: inputs and example queries
                with gr.Column(scale=1):
                    text_input = gr.Textbox(
                        label="Enter your medical query",
                        placeholder="e.g., 'What does pneumonia look like?'",
                        lines=2
                    )
                    image_input = gr.Image(
                        label="Upload a chest X-ray image",
                        type="pil",
                        height=Config.IMAGE_HEIGHT
                    )
                    submit_btn = gr.Button("Submit", variant="primary")

                    gr.Examples(
                        get_examples(),
                        inputs=[text_input],
                        label="Example Queries"
                    )

                # Right column: generated text and attention overlay
                with gr.Column(scale=2):
                    text_output = gr.Textbox(
                        label="Medical Analysis",
                        lines=15
                    )
                    image_output = gr.Image(
                        label="Regions of Interest Visualization",
                        height=Config.IMAGE_HEIGHT
                    )

            submit_btn.click(
                fn=process_query,
                inputs=[text_input, image_input],
                outputs=[text_output, image_output]
            )

            gr.Markdown("""
            ### Important Notes:
            - This tool is for educational purposes only and is not a substitute for professional medical advice
            - Results should be verified by qualified healthcare professionals
            - For real medical concerns, please consult with a doctor

            *MIMIC-CXR dataset is used under appropriate research permissions.*
            """)

        print("Launching Gradio interface...")
        interface.launch(debug=True)

        metrics = rag.get_evaluation_metrics()
        print(f"Final Evaluation Metrics: {metrics}")

    except Exception as e:
        print(f"Critical error initializing or running RAG system: {str(e)}")
        print(traceback.format_exc())

    finally:
        # Release resources on failure too, so loaded models do not stay
        # resident in the notebook kernel after a crash.
        if rag is not None:
            rag.free_memory()