In [1]:
import os
import glob
import pandas as pd
import ast
import re
from collections import defaultdict
import json
import datetime
import time
import traceback
from PIL import Image
from dotenv import load_dotenv
from google import genai
import numpy as np
import lancedb
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, CrossEncoder
from rank_bm25 import BM25Okapi
import torch
import random



In [2]:
class Args:
    """Run configuration: dataset split, prediction-file naming, reflection and RAG settings."""

    def __init__(self, use_finetuning=True, use_test_dataset=True):
        """
        Build every path and tunable from the two switches and print a short summary.

        Parameters:
        - use_finetuning: True selects the "finetuned" prediction files, False the "base" ones
        - use_test_dataset: True selects the test split, False the validation split
        """
        self.use_finetuning = use_finetuning
        self.use_test_dataset = use_test_dataset

        # Directory layout, rooted at the current working directory
        cwd = os.getcwd()
        self.base_dir = cwd
        self.output_dir = os.path.join(cwd, "outputs")
        self.model_predictions_dir = os.path.join(self.output_dir, "05022025")

        # Per-split values: (dataset name, csv file, split folder, image folder, prediction prefix)
        if use_test_dataset:
            split = ("test", "test_dataset.csv", "test", "images_test", "aggregated_test_predictions_")
        else:
            split = ("validation", "val_dataset.csv", "valid", "images_valid", "aggregated_predictions_")
        name, csv_file, split_folder, image_folder, prefix = split

        self.dataset_name = name
        self.dataset_path = os.path.join(self.output_dir, csv_file)
        self.images_dir = os.path.join(cwd, "2025_dataset", split_folder, image_folder)
        self.prediction_prefix = prefix

        # Tag embedded in prediction file names
        if use_finetuning:
            self.model_type = "finetuned"
        else:
            self.model_type = "base"

        # LLM used downstream
        self.gemini_model = "gemini-2.5-flash-preview-04-17"

        # Reflection loop settings
        self.max_reflection_cycles = 2
        self.confidence_threshold = 0.75  # answers at/above this are accepted without reflection

        # Retrieval (RAG) settings
        self.knowledge_db_path = os.path.join(cwd, "knowledge_db")
        self.embedding_model = "pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb"
        self.cross_encoder_model = "cross-encoder/ms-marco-MiniLM-L-6-v2"
        self.vector_dimension = 768
        self.top_k_semantic = 7
        self.top_k_keyword = 7
        self.top_k_hybrid = 10
        self.top_k_rerank = 5

        # Hugging Face dataset backing the knowledge base
        self.dataset_name_huggingface = "brucewayne0459/Skin_diseases_and_care"

        # Per-question-type retrieval policy: (question type, use_rag, weight)
        rag_policy = (
            ("Site Location", False, 0.2),
            ("Lesion Color", False, 0.2),
            ("Size", False, 0.1),
            ("Skin Description", True, 0.3),
            ("Onset", True, 0.4),
            ("Itch", True, 0.4),
            ("Extent", False, 0.2),
            ("Treatment", True, 0.7),
            ("Lesion Evolution", True, 0.5),
            ("Texture", True, 0.3),
            ("Lesion Count", False, 0.1),
            ("Differential", True, 0.8),
            ("Specific Diagnosis", True, 0.8),
        )
        self.question_type_retrieval_config = {
            question_type: {"use_rag": enabled, "weight": weight}
            for question_type, enabled, weight in rag_policy
        }

        # Policy applied to question types missing from the table above
        self.default_rag_config = {"use_rag": True, "weight": 0.4}

        summary = [
            "\nConfiguration initialized:",
            f"- Using {self.dataset_name} dataset",
            f"- Looking for {self.model_type} model predictions",
            f"- Dataset path: {self.dataset_path}",
            f"- Images directory: {self.images_dir}",
            f"- Prediction file prefix: {self.prediction_prefix}",
        ]
        for line in summary:
            print(line)

In [3]:
class DataLoader:
    """Static helpers that locate prediction CSVs and load/clean the evaluation dataset."""

    @staticmethod
    def _has_value(value):
        """Return True when a cell holds a usable value (not None, pd.NA, NaN or empty).

        pandas represents empty CSV cells as float NaN, which is truthy, so a plain
        truthiness test would wrongly treat a blank cell as present.
        """
        if value is None or value is pd.NA:
            return False
        if isinstance(value, float) and value != value:  # NaN is the only float unequal to itself
            return False
        return bool(value)

    @staticmethod
    def _parse_prediction_filename(file_name, args):
        """Split a prediction file name into (model_name, suffix_after_model_type).

        The name is split on ``_<model_type>_``; exactly two parts are expected.
        Prints a warning and returns None for any other shape.
        """
        parts = file_name.split(f"_{args.model_type}_")
        if len(parts) != 2:
            print(f"Warning: Unexpected filename format: {file_name}")
            return None
        model_name = parts[0].replace(args.prediction_prefix, "")
        return model_name, parts[1]

    @staticmethod
    def get_latest_aggregated_files(args):
        """Get the latest aggregated prediction file for each model.

        Args:
            args: object exposing ``model_predictions_dir``, ``prediction_prefix``
                and ``model_type``.

        Returns:
            List of file paths, one per model name; empty when nothing matches.
        """
        pattern = os.path.join(
            args.model_predictions_dir,
            f"{args.prediction_prefix}*_{args.model_type}_*.csv",
        )
        agg_files = glob.glob(pattern)
        if not agg_files:
            return []

        latest_files = {}
        for file_path in agg_files:
            file_name = os.path.basename(file_path)

            parsed = DataLoader._parse_prediction_filename(file_name, args)
            if parsed is None:
                continue
            model_name, suffix = parsed

            timestamps = re.findall(r'(\d+)', suffix)
            if len(timestamps) < 2:
                print(f"Warning: Could not find timestamps in {file_name}")
                continue

            # NOTE(review): only the second numeric group is compared; if the first
            # group is a date, files from different days may be mis-ordered — confirm
            # against the naming scheme of the script that writes these files.
            timestamp = int(timestamps[1])

            current = latest_files.get(model_name)
            if current is None or timestamp > current['timestamp']:
                latest_files[model_name] = {
                    'file_path': file_path,
                    'timestamp': timestamp
                }

        return [info['file_path'] for info in latest_files.values()]

    @staticmethod
    def load_all_model_predictions(args):
        """Load the latest prediction CSV of every model.

        Returns:
            Dict mapping model name to its DataFrame (with an added ``model_name``
            column); empty dict when no files are found. Files that fail to load
            are reported and skipped.
        """
        latest_files = DataLoader.get_latest_aggregated_files(args)

        if not latest_files:
            print("No aggregated prediction files found. Cannot proceed.")
            return {}

        model_predictions = {}
        for file_path in latest_files:
            file_name = os.path.basename(file_path)

            parsed = DataLoader._parse_prediction_filename(file_name, args)
            if parsed is None:
                continue
            model_name = parsed[0]

            try:
                df = pd.read_csv(file_path)
                df['model_name'] = model_name
                model_predictions[model_name] = df
            except Exception as e:
                # Best effort: one unreadable file must not abort the others.
                print(f"Error loading {file_path}: {e}")

        return model_predictions

    @staticmethod
    def load_validation_dataset(args):
        """Load the dataset CSV and collapse it to one row per (encounter_id, base_qid).

        Each output row keeps the data of the first input row seen for that key and
        gains an ``all_images`` list. An image comes from ``image_path`` when that
        cell has a value, otherwise from ``args.images_dir`` joined with ``image_id``.

        Args:
            args: object exposing ``dataset_path`` and ``images_dir``.

        Returns:
            DataFrame with one row per (encounter_id, base_qid).
        """
        val_df = pd.read_csv(args.dataset_path)
        val_df = DataLoader.process_validation_dataset(val_df)

        encounter_question_data = defaultdict(lambda: {
            'images': [],
            'data': None
        })

        for _, row in val_df.iterrows():
            key = (row['encounter_id'], row['base_qid'])
            bucket = encounter_question_data[key]

            # Blank cells are NaN (truthy), so use an explicit presence check;
            # otherwise NaN would be appended and the image_id fallback skipped.
            if 'image_path' in row and DataLoader._has_value(row['image_path']):
                bucket['images'].append(row['image_path'])
            elif 'image_id' in row and DataLoader._has_value(row['image_id']):
                bucket['images'].append(os.path.join(args.images_dir, row['image_id']))

            if bucket['data'] is None:
                bucket['data'] = row.to_dict()

        grouped_data = []
        for (encounter_id, base_qid), data in encounter_question_data.items():
            entry = data['data'].copy()
            entry['all_images'] = data['images']
            entry['encounter_id'] = encounter_id
            entry['base_qid'] = base_qid
            grouped_data.append(entry)

        return pd.DataFrame(grouped_data)

    @staticmethod
    def safe_convert_options(options_str):
        """Safely convert a string representation of a list to an actual list.

        Non-string input is returned unchanged. Strings are first parsed with
        ``ast.literal_eval``; when that fails, a bracketed or comma-separated string
        is split manually, and anything else becomes a one-element list.
        """
        if not isinstance(options_str, str):
            return options_str

        try:
            return ast.literal_eval(options_str)
        except (SyntaxError, ValueError):
            if options_str.startswith('[') and options_str.endswith(']'):
                return [opt.strip().strip("'\"") for opt in options_str[1:-1].split(',')]
            elif ',' in options_str:
                return [opt.strip() for opt in options_str.split(',')]
            else:
                return [options_str]

    @staticmethod
    def _clean_options(options):
        """Strip quotes/spaces and the " (please specify)" suffix from each option.

        Non-list input is returned unchanged; non-string items are stringified.
        """
        if not isinstance(options, list):
            return options

        cleaned_options = []
        for opt in options:
            if isinstance(opt, str):
                cleaned_options.append(opt.strip("'\" ").replace(" (please specify)", ""))
            else:
                cleaned_options.append(str(opt).strip("'\" "))
        return cleaned_options

    @staticmethod
    def process_validation_dataset(val_df):
        """Process and clean the validation dataset.

        Adds, when the source columns exist: ``options_en`` parsed into lists,
        ``options_en_cleaned``, ``question_text_cleaned`` (boilerplate sentence and
        leading number removed) and ``base_qid`` (the part of ``qid`` before the
        first '-'). The frame is modified in place and also returned.
        """
        if 'options_en' in val_df.columns:
            val_df['options_en'] = val_df['options_en'].apply(DataLoader.safe_convert_options)
            val_df['options_en_cleaned'] = val_df['options_en'].apply(DataLoader._clean_options)

        if 'question_text' in val_df.columns:
            val_df['question_text_cleaned'] = val_df['question_text'].apply(
                lambda q: q.replace(" Please specify which affected area for each selection.", "")
                          if isinstance(q, str) and "Please specify which affected area for each selection" in q
                          else q
            )

            # Drop a leading question number such as "1 ".
            val_df['question_text_cleaned'] = val_df['question_text_cleaned'].apply(
                lambda q: re.sub(r'^\d+\s+', '', q) if isinstance(q, str) else q
            )

        if 'base_qid' not in val_df.columns and 'qid' in val_df.columns:
            val_df['base_qid'] = val_df['qid'].apply(
                lambda q: q.split('-')[0] if isinstance(q, str) and '-' in q else q
            )

        return val_df

In [4]:
class DataProcessor:
    """Builds the textual query context for a dataset row."""

    _DEFAULT_QUESTION = 'What do you see in this image?'
    _DEFAULT_OPTIONS = ('Yes', 'No', 'Not mentioned')

    @staticmethod
    def _is_missing(value):
        """Return True for None, float NaN (how pandas stores blank cells) or blank strings."""
        if value is None:
            return True
        if isinstance(value, float) and value != value:  # NaN is unequal to itself
            return True
        return isinstance(value, str) and not value.strip()

    @staticmethod
    def _first_present(row, keys, default):
        """Return the first non-missing ``row`` value among ``keys``, else ``default``."""
        for key in keys:
            value = row.get(key)
            if not DataProcessor._is_missing(value):
                return value
        return default

    @staticmethod
    def create_query_context(row, args=None):
        """Create query context from validation data similar to the inference process.

        Args:
            row: mapping with a ``.get`` method (dict or pandas Series). Keys read:
                question_text_cleaned / question_text, question_type_en,
                question_category_en, query_title_en, query_content_en,
                options_en_cleaned / options_en.
            args: unused; kept for interface compatibility.

        Returns:
            A multi-line string with the question, its metadata, optional clinical
            background and the answer options. Missing or NaN fields are omitted
            (or replaced by defaults) instead of being rendered as "nan".
        """
        question = DataProcessor._first_present(
            row, ('question_text_cleaned', 'question_text'), DataProcessor._DEFAULT_QUESTION
        )

        # Join only the parts that exist so the line never starts with ", ".
        metadata_parts = []
        question_type = row.get('question_type_en')
        if not DataProcessor._is_missing(question_type):
            metadata_parts.append(f"Type: {question_type}")
        question_category = row.get('question_category_en')
        if not DataProcessor._is_missing(question_category):
            metadata_parts.append(f"Category: {question_category}")
        metadata = ", ".join(metadata_parts)

        query_title = row.get('query_title_en', '')
        query_content = row.get('query_content_en', '')
        has_title = not DataProcessor._is_missing(query_title)
        has_content = not DataProcessor._is_missing(query_content)

        clinical_context = ""
        if has_title or has_content:
            clinical_context += "Background Clinical Information (to help with your analysis):\n"
            if has_title:
                clinical_context += f"{query_title}\n"
            if has_content:
                clinical_context += f"{query_content}\n"

        options = DataProcessor._first_present(
            row, ('options_en_cleaned', 'options_en'), list(DataProcessor._DEFAULT_OPTIONS)
        )
        if isinstance(options, (list, tuple)):
            # str() each item: parsed option lists may contain non-string literals.
            options_text = ", ".join(str(opt) for opt in options)
        else:
            options_text = str(options)

        query_text = (f"MAIN QUESTION TO ANSWER: {question}\n"
                      f"Question Metadata: {metadata}\n"
                      f"{clinical_context}"
                      f"Available Options (choose from these): {options_text}")

        return query_text

In [5]:
class AgenticRAGData:
    """Joins per-model predictions with dataset rows, keyed by (encounter_id, base_qid)."""

    def __init__(self, all_models_df, validation_df):
        """Index both frames by (encounter_id, base_qid) for constant-time lookup."""
        self.all_models_df = all_models_df
        self.validation_df = validation_df

        # key -> sub-frame holding every model's prediction row for that key
        self.model_predictions = {
            (encounter_id, base_qid): group
            for (encounter_id, base_qid), group in all_models_df.groupby(['encounter_id', 'base_qid'])
        }

        # key -> dataset row (a later duplicate key overwrites an earlier one)
        self.validation_data = {
            (row['encounter_id'], row['base_qid']): row
            for _, row in validation_df.iterrows()
        }

    def get_combined_data(self, encounter_id, base_qid):
        """Return the merged record for one encounter/question, or None if either side is missing."""
        key = (encounter_id, base_qid)
        model_preds = self.model_predictions.get(key)
        val_data = self.validation_data.get(key)

        if model_preds is None:
            print(f"No model predictions found for encounter {encounter_id}, question {base_qid}")
            return None

        if val_data is None:
            print(f"No validation data found for encounter {encounter_id}, question {base_qid}")
            return None

        # Build the query context once and keep it on the stored row.
        if 'query_context' not in val_data:
            val_data['query_context'] = DataProcessor.create_query_context(val_data)

        per_model = {
            pred_row['model_name']: self._process_model_predictions(pred_row)
            for _, pred_row in model_preds.iterrows()
        }
        options = val_data.get('options_en_cleaned', val_data.get('options_en', []))

        return {
            'encounter_id': encounter_id,
            'base_qid': base_qid,
            'query_context': val_data['query_context'],
            'images': val_data.get('all_images', []),
            'options': options,
            'question_type': val_data.get('question_type_en', ''),
            'question_category': val_data.get('question_category_en', ''),
            'model_predictions': per_model
        }

    def _process_model_predictions(self, row):
        """Wrap a prediction row's 'combined_prediction' (default '') in a dict."""
        prediction = row.get('combined_prediction', '')
        return {'model_prediction': prediction}

    def get_all_encounter_question_pairs(self):
        """List every (encounter_id, base_qid) key present in the dataset rows."""
        return [pair for pair in self.validation_data]

    def get_sample_data(self, n=5):
        """Return combined data for up to n randomly chosen encounter-question pairs."""
        all_pairs = self.get_all_encounter_question_pairs()
        sample_size = min(n, len(all_pairs))

        sampled = []
        for encounter_id, base_qid in random.sample(all_pairs, sample_size):
            sampled.append(self.get_combined_data(encounter_id, base_qid))
        return sampled

In [6]:
def parse_json_response(text):
    """Parse JSON from LLM response.

    Candidates are tried in this order and the first one that decodes wins:
      1. the body of the first fenced code block (with or without a "json" tag),
      2. the text after the first "```json" marker up to the next fence,
      3. the whole text,
      4. the span from the first "{" to the last "}" (JSON wrapped in prose).

    Args:
        text: raw model output.

    Returns:
        The decoded JSON value, or
        ``{"parse_error": "Could not parse as JSON", "raw_text": text}`` when
        nothing decodes or ``text`` is not a string.
    """
    failure = {"parse_error": "Could not parse as JSON", "raw_text": text}

    if not isinstance(text, str):
        print(f"Warning: Could not parse as JSON")
        return failure

    candidates = []

    # A plain ``` fence must not be handled by split("```")[0]: that yields the
    # (empty) text *before* the fence instead of the fenced body.
    fenced = re.search(r"```(?:json)?\s*(.*?)```", text, re.DOTALL | re.IGNORECASE)
    if fenced:
        candidates.append(fenced.group(1))

    # Tolerate an opening ```json marker without a closing fence.
    if "```json" in text:
        candidates.append(text.split("```json")[1].split("```")[0])

    candidates.append(text)

    start, end = text.find("{"), text.rfind("}")
    if 0 <= start < end:
        candidates.append(text[start:end + 1])

    for candidate in candidates:
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            continue

    print(f"Warning: Could not parse as JSON")
    return failure

In [7]:
class KnowledgeBaseManager:
    """Manages the dermatology knowledge base for RAG.

    Documents are stored in a LanceDB table (searched by embedding) and mirrored in
    an in-memory BM25 index (searched by keyword); merged hits are reranked with a
    cross-encoder.
    """

    # Common English stopwords - hardcoded to avoid NLTK dependency.
    # Shared by indexing and querying so both sides always tokenize identically.
    _STOPWORDS = frozenset({
        "a", "an", "the", "and", "or", "but", "in", "on", "at", "to", "for", "with",
        "by", "about", "from", "as", "of", "is", "are", "was", "were", "be", "been",
        "being", "have", "has", "had", "do", "does", "did", "can", "could", "will",
        "would", "shall", "should", "may", "might", "must", "this", "that", "these",
        "those", "it", "its", "they", "them", "their", "he", "him", "his", "she", "her"
    })

    def __init__(self, args=None):
        """Load the models, open (or build) the LanceDB table and build the BM25 index.

        Args:
            args: optional configuration object; when None, built-in defaults are
                used for model names and the database path.
        """
        self.args = args
        self.embedding_model = SentenceTransformer(args.embedding_model if args else "pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb")
        self.cross_encoder = CrossEncoder(args.cross_encoder_model if args else "cross-encoder/ms-marco-MiniLM-L-6-v2")

        # Initialize LanceDB
        self.db_path = args.knowledge_db_path if args else os.path.join(os.getcwd(), "knowledge_db")
        os.makedirs(self.db_path, exist_ok=True)
        self.db = lancedb.connect(self.db_path)

        # Check if the table exists, create it if not
        self.table_name = "dermatology_knowledge"

        if self.table_name not in self.db.table_names():
            print(f"Knowledge base not found. Creating new knowledge base at {self.db_path}")
            self._initialize_knowledge_base()
        else:
            print(f"Using existing knowledge base at {self.db_path}")
            self.table = self.db.open_table(self.table_name)

        # BM25 index for keyword search
        self.tokenized_corpus = []
        self.doc_ids = []
        self._initialize_bm25_index()

    @classmethod
    def _tokenize(cls, text):
        """Lowercase, split on whitespace, keep alphanumeric characters, drop stopwords."""
        tokens = []
        for raw_token in text.lower().split():
            token = ''.join(c for c in raw_token if c.isalnum())
            if token and token not in cls._STOPWORDS:
                tokens.append(token)
        return tokens

    def _initialize_knowledge_base(self):
        """Initialize the knowledge base with the skin diseases dataset.

        Embeds every item of the dataset's 'train' split (fields 'Topic' and
        'Information') and stores the rows in a new LanceDB table.
        """
        print("Loading dermatology dataset...")
        dataset_name = self.args.dataset_name_huggingface if self.args else "brucewayne0459/Skin_diseases_and_care"
        dataset = load_dataset(dataset_name)

        # Prepare data for LanceDB
        data = []

        print("Processing dataset and creating embeddings...")
        for i, item in enumerate(dataset['train']):
            topic = item['Topic']
            information = item['Information']

            # Text that is both embedded and later indexed for BM25
            combined_text = f"Topic: {topic}\n\nInformation: {information}"
            embedding = self.embedding_model.encode(combined_text)

            data.append({
                "id": i,
                "topic": topic,
                "information": information,
                "combined_text": combined_text,
                "vector": embedding.tolist()
            })

            if (i + 1) % 100 == 0:
                print(f"Processed {i + 1} documents")

        data_df = pd.DataFrame(data)

        # Create LanceDB table
        print("Creating vector database...")
        self.table = self.db.create_table(
            self.table_name,
            data=data_df
        )
        print("Knowledge base initialization complete.")

    def _initialize_bm25_index(self):
        """Initialize the BM25 index for keyword search without NLTK dependencies.

        Rebuilds ``tokenized_corpus``, ``doc_ids`` and ``bm25`` from scratch, so
        calling it again does not duplicate entries.
        """
        print("Initializing BM25 index...")

        # Query all documents from LanceDB.
        # NOTE: the request is capped at 10000 rows; documents beyond that are not indexed.
        results = self.table.search().limit(10000).to_pandas()

        self.tokenized_corpus = []
        self.doc_ids = []
        for _, row in results.iterrows():
            self.doc_ids.append(row['id'])
            self.tokenized_corpus.append(self._tokenize(row['combined_text']))

        # Create BM25 index
        self.bm25 = BM25Okapi(self.tokenized_corpus)
        print("BM25 index initialization complete.")

    def semantic_search(self, query, top_k=None):
        """Perform semantic search using embeddings.

        Returns the LanceDB result frame for the ``top_k`` nearest documents
        (default: ``args.top_k_semantic`` or 7).
        """
        if top_k is None:
            top_k = self.args.top_k_semantic if self.args else 7

        # Create query embedding
        query_embedding = self.embedding_model.encode(query)

        # Search LanceDB
        results = self.table.search(query_embedding.tolist()).limit(top_k).to_pandas()

        return results

    def keyword_search(self, query, top_k=None):
        """Perform keyword search using BM25.

        Returns a DataFrame with columns id, topic, information, combined_text and
        ``_distance`` for up to ``top_k`` documents (default: ``args.top_k_keyword``
        or 7) whose BM25 score is positive; empty frame when nothing matches.
        """
        if top_k is None:
            top_k = self.args.top_k_keyword if self.args else 7

        query_tokens = self._tokenize(query)

        # Get BM25 scores
        doc_scores = self.bm25.get_scores(query_tokens)

        # Get top-k results, best score first
        top_indices = np.argsort(doc_scores)[::-1][:top_k]

        # Convert to document IDs and scores
        results = []
        for idx in top_indices:
            score = doc_scores[idx]
            if score <= 0:  # Only include if score is positive
                continue

            doc_id = self.doc_ids[idx]

            # Get document from LanceDB
            doc = self.table.search().where(f"id = {doc_id}").limit(1).to_pandas()
            if doc.empty:
                continue

            results.append({
                "id": doc_id,
                "topic": doc['topic'].iloc[0],
                "information": doc['information'].iloc[0],
                "combined_text": doc['combined_text'].iloc[0],
                # Map the score to a 0..1 pseudo-distance; scores of 10 or more map to 0.
                "_distance": 1.0 - min(score / 10.0, 1.0)
            })

        return pd.DataFrame(results)

    def hybrid_search(self, query, top_k=None):
        """Perform hybrid search combining semantic and keyword search.

        Both searches are run with the same ``top_k`` (default:
        ``args.top_k_hybrid`` or 10); results are merged, de-duplicated by ``id``
        and reranked. Returns an empty DataFrame when neither search finds anything.
        """
        if top_k is None:
            top_k = self.args.top_k_hybrid if self.args else 10

        # Perform both search types
        semantic_results = self.semantic_search(query, top_k=top_k)
        keyword_results = self.keyword_search(query, top_k=top_k)

        # Skip empty frames: an empty keyword result has no columns at all, which
        # would make the 'id' de-duplication fail if it were the only input.
        frames = [frame for frame in (semantic_results, keyword_results) if len(frame) > 0]
        if not frames:
            return pd.DataFrame()

        # Merge results and remove duplicates
        combined_results = pd.concat(frames).drop_duplicates(subset=['id'])

        # Rerank the top results
        return self.rerank_results(combined_results, query, top_k=min(top_k, len(combined_results)))

    def rerank_results(self, results, query, top_k=None):
        """Rerank search results using a cross-encoder.

        Returns a copy of ``results`` with a ``cross_score`` column, sorted by that
        score descending and truncated to ``top_k`` rows (default:
        ``args.top_k_rerank`` or 5).
        """
        if top_k is None:
            top_k = self.args.top_k_rerank if self.args else 5

        if len(results) == 0:
            return pd.DataFrame()

        # Prepare input for cross-encoder
        pairs = [(query, doc) for doc in results['combined_text'].tolist()]

        # Get scores from cross-encoder
        cross_scores = self.cross_encoder.predict(pairs)

        # Add scores to a copy so the caller's frame is left untouched
        results = results.copy()
        results['cross_score'] = cross_scores

        # Sort by cross-encoder score
        results = results.sort_values(by='cross_score', ascending=False).head(top_k)

        return results

In [8]:
class DiagnosisExtractor:
    """Extracts potential diagnoses from image analysis and clinical context."""

    # Common diagnostic terms in dermatology
    _DIAGNOSTIC_TERMS = (
        "eczema", "dermatitis", "psoriasis", "acne", "rosacea", "urticaria",
        "melanoma", "carcinoma", "pemphigus", "pemphigoid", "lupus", "scleroderma",
        "folliculitis", "cellulitis", "impetigo", "tinea", "herpes", "wart",
        "vitiligo", "alopecia", "lichen", "keratosis", "prurigo", "rash"
    )

    # Lead-in phrases; the capture runs up to the next ',', '.' or ';'
    _PHRASE_PATTERNS = (
        r'consistent with\s+([^,.;]+)',
        r'suggestive of\s+([^,.;]+)',
        r'indicative of\s+([^,.;]+)',
        r'compatible with\s+([^,.;]+)',
        r'diagnostic of\s+([^,.;]+)',
        r'likely\s+([^,.;]+)',
        r'probable\s+([^,.;]+)',
        r'possible\s+([^,.;]+)',
        r'suspected\s+([^,.;]+)',
        r'diagnosis of\s+([^,.;]+)',
        r'impression:\s+([^,.;]+)'
    )

    @staticmethod
    def _section(container, key):
        """Return ``container[key]`` when both are dicts, otherwise an empty dict."""
        section = container.get(key) if isinstance(container, dict) else None
        return section if isinstance(section, dict) else {}

    @staticmethod
    def extract_diagnoses(image_analysis, clinical_context, args=None):
        """
        Extract potential diagnoses from image analysis and clinical context.

        Args:
            image_analysis: dict whose "aggregated_analysis" entry may hold an
                "OVERALL_IMPRESSION" string
            clinical_context: dict whose "structured_clinical_context" entry may
                hold a "DIAGNOSTIC_CONSIDERATIONS" string
            args: unused; kept for interface compatibility

        Returns:
            List of dictionaries with keys "diagnosis", "confidence" and "source".
            Never empty: feature rules and finally a generic fallback are used when
            no diagnosis is found in the text fields.
        """
        diagnoses = []

        # Extract from image analysis
        impression = DiagnosisExtractor._section(image_analysis, "aggregated_analysis").get("OVERALL_IMPRESSION")
        if isinstance(impression, str):
            diagnoses.extend(DiagnosisExtractor._extract_from_text(impression, source="image_analysis", confidence=0.7))

        # Extract from clinical context
        diagnostic_info = DiagnosisExtractor._section(clinical_context, "structured_clinical_context").get("DIAGNOSTIC_CONSIDERATIONS")
        if isinstance(diagnostic_info, str):
            diagnoses.extend(DiagnosisExtractor._extract_from_text(diagnostic_info, source="clinical_context", confidence=0.6))

        # If no diagnoses found, use extracted features to suggest potential diagnoses
        if not diagnoses:
            diagnoses = DiagnosisExtractor._suggest_from_features(image_analysis, clinical_context)

        return diagnoses

    @staticmethod
    def _extract_from_text(text, source, confidence):
        """Extract diagnoses from text.

        Two passes over the lowercased text: known diagnostic terms (optionally
        pluralised with "s") at the given confidence, then lead-in phrases such as
        "consistent with ..." at 0.9 x confidence. Duplicates (case-insensitive)
        are dropped, keeping the first occurrence.
        """
        if not isinstance(text, str) or not text:
            return []

        lowered = text.lower()
        diagnoses = []

        # Pattern 1: Diagnostic terms directly mentioned.
        # "s?" makes the plural optional; the previous "[s\s]" required one more
        # character after the term, so a term at the end of the text or followed
        # by punctuation ("eczema.") was never matched.
        for term in DiagnosisExtractor._DIAGNOSTIC_TERMS:
            pattern = fr'\b({re.escape(term)})s?\b'
            for match in re.finditer(pattern, lowered):
                diagnoses.append({
                    "diagnosis": match.group(1).capitalize(),
                    "confidence": confidence,
                    "source": source
                })

        # Pattern 2: "Consistent with", "suggestive of", "indicative of" phrases
        for pattern in DiagnosisExtractor._PHRASE_PATTERNS:
            for match in re.finditer(pattern, lowered):
                diagnoses.append({
                    "diagnosis": match.group(1).strip().capitalize(),
                    "confidence": confidence * 0.9,  # Slightly lower confidence
                    "source": source
                })

        # Remove duplicates
        unique_diagnoses = []
        seen = set()
        for diag in diagnoses:
            key = diag["diagnosis"].lower()
            if key not in seen:
                seen.add(key)
                unique_diagnoses.append(diag)

        return unique_diagnoses

    @staticmethod
    def _suggest_from_features(image_analysis, clinical_context):
        """Suggest potential diagnoses based on extracted features.

        Applies two simple keyword rules and falls back to a generic
        "Dermatosis" entry when neither rule fires.
        """
        features = {}

        # Extract features from image analysis
        analysis = DiagnosisExtractor._section(image_analysis, "aggregated_analysis")
        for source_key, feature_key in (
            ("SKIN_DESCRIPTION", "skin_description"),
            ("LESION_COLOR", "lesion_color"),
            ("SITE_LOCATION", "site_location"),
        ):
            if source_key in analysis:
                features[feature_key] = analysis[source_key]

        # Extract features from clinical context
        context = DiagnosisExtractor._section(clinical_context, "structured_clinical_context")
        for source_key, feature_key in (("SYMPTOMS", "symptoms"), ("HISTORY", "history")):
            if source_key in context:
                features[feature_key] = context[source_key]

        def mentions(feature_key, needle):
            """Case-insensitive substring test on the stringified feature value."""
            return needle in str(features.get(feature_key, "")).lower()

        # Rule-based diagnosis suggestions based on features (simplified examples)
        diagnoses = []
        if mentions("site_location", "hand") and mentions("skin_description", "scaling"):
            diagnoses.append({
                "diagnosis": "Hand eczema",
                "confidence": 0.5,
                "source": "feature_based"
            })

        if mentions("lesion_color", "red") and mentions("symptoms", "itchy"):
            diagnoses.append({
                "diagnosis": "Contact dermatitis",
                "confidence": 0.4,
                "source": "feature_based"
            })

        # Always include a generic term if no specific diagnoses were found
        if not diagnoses:
            diagnoses.append({
                "diagnosis": "Dermatosis",
                "confidence": 0.3,
                "source": "fallback"
            })

        return diagnoses

In [9]:
class DiagnosisBasedQueryGenerator:
    """Builds knowledge-base search queries from candidate diagnoses.

    Queries come from two template families, in priority order:
    templates keyed on the question type (classification, diagnostic,
    treatment) and templates that pair each of the top diagnoses with the
    question type.
    """

    # Question types whose answer is a label from a fixed classification scheme.
    # NOTE(review): EvidenceIntegrator._get_weights_for_question keys its table
    # on "Count", not "Lesion Count" -- confirm which label the dataset uses.
    _CLASSIFICATION_TYPES = ("Site Location", "Lesion Color", "Size", "Extent", "Lesion Count")
    _DIAGNOSTIC_TYPES = ("Differential", "Specific Diagnosis")

    # Checked in this order; the first entry found anywhere in the text wins.
    _BODY_PARTS = (
        "hand", "foot", "arm", "leg", "face", "back", "chest", "abdomen",
        "scalp", "neck", "finger", "toe", "elbow", "knee", "shoulder",
        "palm", "sole", "trunk", "extremity", "head",
    )

    # Plural forms that the generic optional-"s" suffix cannot match.
    _IRREGULAR_PLURALS = {"foot": "feet", "extremity": "extremities"}

    def __init__(self, client, args=None):
        """Initialize the query generator.

        Args:
            client: Stored on the instance; no method of this class reads it.
            args: Optional configuration object; stored but not read here.
        """
        self.client = client
        self.args = args

    def generate_queries(self, question_text, question_type, options, integrated_evidence, diagnoses, num_queries=4):
        """
        Generate search queries based on diagnoses and question type.

        Args:
            question_text: The question text
            question_type: Type of question being asked
            options: Available answer options (None is treated as no options)
            integrated_evidence: Accepted for interface compatibility; not used
            diagnoses: List of diagnosis dicts with a "diagnosis" name and an
                optional "confidence" (None is treated as no diagnoses; entries
                without a usable name are ignored)
            num_queries: Maximum number of queries to return

        Returns:
            List of unique search queries, question-specific ones first
        """
        options = options or []

        # Drop entries that cannot contribute a diagnosis name, then rank by
        # confidence. A missing or None confidence sorts as 0.
        usable = [d for d in (diagnoses or []) if isinstance(d, dict) and d.get("diagnosis")]
        sorted_diagnoses = sorted(usable, key=lambda d: d.get("confidence") or 0, reverse=True)

        question_specific_queries = self._generate_question_specific_queries(
            question_text,
            question_type,
            options,
            sorted_diagnoses
        )

        diagnosis_specific_queries = self._generate_diagnosis_specific_queries(
            question_type,
            sorted_diagnoses
        )

        # Case-insensitive de-duplication that keeps the first occurrence, so
        # the question-specific queries retain their priority.
        unique_queries = []
        seen = set()
        for query in question_specific_queries + diagnosis_specific_queries:
            key = query.lower()
            if query.strip() and key not in seen:
                seen.add(key)
                unique_queries.append(query)

        # A negative limit would otherwise slice results off the end.
        if num_queries is not None:
            num_queries = max(num_queries, 0)
        return unique_queries[:num_queries]

    def _generate_question_specific_queries(self, question_text, question_type, options, diagnoses):
        """Generate queries specific to the question type.

        Args:
            question_text: The question text (only searched for a body site on
                treatment questions)
            question_type: Type of question being asked
            options: Available answer options
            diagnoses: Diagnosis dicts, already sorted by confidence

        Returns:
            List of query strings (possibly empty)
        """
        queries = []
        options = options or []
        diagnoses = diagnoses or []

        # For classification/terminology questions, focus on the classification system
        if question_type in self._CLASSIFICATION_TYPES:
            terms = [str(opt) for opt in options if str(opt).lower() != "not mentioned"]
            classification_terms = ", ".join(terms)

            # Without any usable option, fall back to the bare topic instead of
            # emitting a query with a dangling trailing space.
            base_query = f"dermatology {question_type.lower()} classification"
            queries.append(f"{base_query} {classification_terms}" if classification_terms else base_query)

            # Query about how to distinguish between options
            if len(options) > 2 and classification_terms:
                queries.append(f"how to distinguish between {classification_terms} in dermatology")

            # For extent questions specifically
            if question_type == "Extent":
                queries.append("definition of widespread vs limited area skin condition dermatology")

        # For diagnostic questions, use diagnoses
        if question_type in self._DIAGNOSTIC_TYPES and diagnoses:
            top_diagnosis = diagnoses[0]["diagnosis"]
            queries.append(f"{top_diagnosis} diagnostic criteria dermatology")

            diagnoses_list = ", ".join(d["diagnosis"] for d in diagnoses[:3])
            queries.append(f"differential diagnosis {diagnoses_list}")

        # For treatment questions
        if question_type == "Treatment" and diagnoses:
            top_diagnosis = diagnoses[0]["diagnosis"]
            queries.append(f"{top_diagnosis} treatment options dermatology")

            # Add body site if available
            body_site = self._extract_body_site(question_text)
            if body_site:
                queries.append(f"{top_diagnosis} {body_site} treatment guidelines")

        return queries

    def _generate_diagnosis_specific_queries(self, question_type, diagnoses):
        """Generate queries that connect the top two diagnoses with the question type.

        Args:
            question_type: Type of question being asked
            diagnoses: Diagnosis dicts, already sorted by confidence

        Returns:
            List of query strings (empty when there are no diagnoses)
        """
        queries = []

        if not diagnoses:
            return queries

        for diagnosis in diagnoses[:2]:  # Use top 2 diagnoses
            diag_name = diagnosis["diagnosis"]

            if question_type in ["Site Location", "Extent"]:
                queries.append(f"{diag_name} typical distribution pattern dermatology")
                queries.append(f"{diag_name} localized versus widespread presentation")

            elif question_type == "Lesion Color":
                queries.append(f"{diag_name} typical color appearance dermatology")

            elif question_type == "Texture":
                queries.append(f"{diag_name} texture characteristics dermatology")

            elif question_type == "Itch":
                queries.append(f"is {diag_name} itchy dermatology")

            elif question_type == "Onset":
                queries.append(f"{diag_name} typical onset and progression")

            else:
                # General query connecting diagnosis and question type
                queries.append(f"{diag_name} {question_type.lower()} dermatology")

        return queries

    def _extract_body_site(self, question_text):
        """Extract a body site from question text.

        Matches whole words only, in singular or plural form (including the
        irregular plurals "feet" and "extremities").

        Args:
            question_text: The question text; non-string or empty input yields None

        Returns:
            The singular body-part name of the first entry of the body-part
            list that occurs in the text, or None if none occurs
        """
        if not isinstance(question_text, str) or not question_text:
            return None

        text = question_text.lower()
        for part in self._BODY_PARTS:
            forms = [part + "s?"]
            irregular = self._IRREGULAR_PLURALS.get(part)
            if irregular:
                forms.append(irregular)
            if re.search(r"\b(?:" + "|".join(forms) + r")\b", text):
                return part

        return None

In [10]:
class DiagnosisBasedKnowledgeRetriever:
    """Diagnosis-driven retrieval front end for the dermatology knowledge base.

    Orchestrates three collaborators: a diagnosis extractor, a query generator
    and a knowledge-base manager that exposes ``hybrid_search``.
    """

    def __init__(self, kb_manager, query_generator, diagnosis_extractor, args=None):
        """
        Store the collaborators used during retrieval.

        Args:
            kb_manager: Object whose ``hybrid_search(query)`` returns a DataFrame of hits
            query_generator: Object providing ``generate_queries(...)``
            diagnosis_extractor: Object providing ``extract_diagnoses(...)``
            args: Optional configuration; when given it supplies
                ``question_type_retrieval_config``, ``default_rag_config`` and ``top_k_rerank``
        """
        self.kb_manager = kb_manager
        self.query_generator = query_generator
        self.diagnosis_extractor = diagnosis_extractor
        self.args = args

    def _rag_config_for(self, question_type):
        """Look up the retrieval settings for a question type (args first, built-in fallback otherwise)."""
        if self.args:
            return self.args.question_type_retrieval_config.get(
                question_type, self.args.default_rag_config
            )

        # No args supplied: retrieval is on by default, off for purely visual question types.
        builtin = {
            "Site Location": {"use_rag": False, "weight": 0.2},
            "Lesion Color": {"use_rag": False, "weight": 0.2},
            "Size": {"use_rag": False, "weight": 0.1},
        }
        return builtin.get(question_type, {"use_rag": True, "weight": 0.4})

    def retrieve_knowledge(self, question_text, question_type, options, image_analysis, clinical_context, integrated_evidence):
        """
        Fetch knowledge-base entries relevant to a dermatological question.

        Args:
            question_text: The question text
            question_type: Type of question being asked
            options: Available answer options
            image_analysis: Structured image analysis
            clinical_context: Structured clinical context
            integrated_evidence: Integrated evidence from images and clinical context

        Returns:
            Dictionary with a ``retrieved`` flag and a ``results`` list. On the
            normal path it also carries ``queries`` and ``diagnoses``; on the
            early exits it carries a ``reason`` string instead.
        """
        settings = self._rag_config_for(question_type)
        if not settings["use_rag"]:
            return {
                "retrieved": False,
                "reason": f"RAG not enabled for question type: {question_type}",
                "results": []
            }

        diagnoses = self.diagnosis_extractor.extract_diagnoses(image_analysis, clinical_context)
        queries = self.query_generator.generate_queries(
            question_text, question_type, options, integrated_evidence, diagnoses
        )
        if not queries:
            return {
                "retrieved": False,
                "reason": "Failed to generate search queries",
                "results": []
            }

        # Gather every positively scored hit across all queries.
        candidates = []
        for query in queries:
            hits = self.kb_manager.hybrid_search(query)
            if hits.empty:
                continue

            for _, hit in hits.iterrows():
                # Prefer the cross-encoder score; otherwise derive one from the distance.
                distance_based = 1.0 - hit.get('_distance', 0.5)
                score = float(hit.get('cross_score', distance_based))

                # "not score > 0" (rather than "score <= 0") also rejects NaN scores.
                if not score > 0:
                    continue

                candidates.append({
                    "query": query,
                    "topic": hit['topic'],
                    "information": hit['information'],
                    "relevance_score": score,
                    "diagnoses": [d["diagnosis"] for d in diagnoses[:3]]
                })

        # Walk hits best-first and keep only the first (highest scoring) one per topic.
        best_by_topic = {}
        ranked = sorted(candidates, key=lambda c: c['relevance_score'], reverse=True)
        for candidate in ranked:
            best_by_topic.setdefault(candidate['topic'], candidate)
        unique_results = list(best_by_topic.values())

        limit = self.args.top_k_rerank if self.args else 5

        return {
            "retrieved": len(unique_results) > 0,
            "queries": queries,
            "diagnoses": diagnoses,
            "results": unique_results[:limit]
        }

In [11]:
class ImageAnalysisService:
    """Service for analyzing dermatological images.

    Each image is sent to the generative model together with a structured
    prompt; the per-image results are then merged into one aggregated
    analysis for the encounter.
    """

    # Model used when no args object is supplied.
    _DEFAULT_MODEL = "gemini-2.5-flash-preview-04-17"

    def __init__(self, client, args=None):
        """
        Args:
            client: Client exposing ``models.generate_content(model=..., contents=...)``
            args: Optional configuration; when given, ``args.gemini_model`` selects the model
        """
        self.client = client
        self.args = args

    def _model_name(self):
        """Return the configured model name, or the built-in default when args is absent."""
        return self.args.gemini_model if self.args else self._DEFAULT_MODEL

    def analyze_images(self, image_paths, encounter_id):
        """
        Analyze multiple dermatological images for an encounter.

        Args:
            image_paths: List of paths to images
            encounter_id: Encounter identifier

        Returns:
            Dictionary with individual and aggregated analyses
        """
        structured_prompt = self._create_dermatology_prompt()
        total_images = len(image_paths)

        image_analyses = [
            self._analyze_single_image(img_path, structured_prompt, encounter_id, idx, total_images)
            for idx, img_path in enumerate(image_paths)
        ]

        aggregated_analysis = self._aggregate_analyses(image_analyses, encounter_id)

        return {
            "encounter_id": encounter_id,
            "image_count": total_images,
            "individual_analyses": image_analyses,
            "aggregated_analysis": aggregated_analysis
        }

    def _create_dermatology_prompt(self):
        """Create the structured dermatology analysis prompt."""
        return """As dermatology specialist analyzing skin images, extract and structure all clinically relevant information from this dermatological image.

Organize your response in a JSON dictionary:

1. SIZE: Approximate dimensions of lesions/affected areas, size comparison (thumbnail, palm, larger), Relative size comparisons for multiple lesions
2. SITE_LOCATION: Visible body parts in the image, body areas showing lesions/abnormalities, Specific anatomical locations affected
3. SKIN_DESCRIPTION: Lesion morphology (flat, raised, depressed), Texture of affected areas, Surface characteristics (scales, crust, fluid), Appearance of lesion boundaries
4. LESION_COLOR: Predominant color(s) of affected areas, Color variations within lesions, Color comparison to normal skin, Color distribution patterns
5. LESION_COUNT: Number of distinct lesions/affected areas, Single vs multiple presentation, Distribution pattern if multiple, Any counting limitations
6. EXTENT: How widespread the condition appears, Localized vs widespread assessment, Approximate percentage of visible skin affected, Limitations in determining full extent
7. TEXTURE: Expected tactile qualities, Smooth vs rough assessment, Notable textural features, Texture consistency across affected areas
8. ONSET_INDICATORS: Visual clues about condition duration, Acute vs chronic presentation features, Healing/progression/chronicity signs, Note: precise timing cannot be determined from images
9. ITCH_INDICATORS: Scratch marks/excoriations/trauma signs, Features associated with itchy conditions, Pruritic vs non-pruritic visual indicators, Note: sensation cannot be directly observed
10. OVERALL_IMPRESSION: Brief description (1-2 sentences), Key diagnostic features, Potential diagnoses (2-3)

Be concise and use medical terminology where appropriate. If information for a section is cannot be determined, state "Cannot determine from image".
"""

    def _analyze_single_image(self, img_path, prompt, encounter_id, idx, total_images):
        """Analyze a single dermatological image.

        Args:
            img_path: Path of the image file
            prompt: Prompt sent alongside the image
            encounter_id: Encounter identifier (used for progress/error messages)
            idx: Zero-based position of the image within the encounter
            total_images: Number of images in the encounter

        Returns:
            Dictionary with ``image_index`` (one-based) and ``image_path``
            (file name only), plus ``structured_analysis`` on success or
            ``error`` when anything in the analysis raised
        """
        try:
            # Image.open keeps the file handle open until the image is closed;
            # the context manager releases it once the model call has returned.
            with Image.open(img_path) as image:
                print(f"Analyzing image {idx+1}/{total_images} for encounter {encounter_id}")

                response = self.client.models.generate_content(
                    model=self._model_name(),
                    contents=[prompt, image]
                )

            analysis_text = response.text

            # parse_json_response is expected to be defined elsewhere in the notebook.
            structured_analysis = parse_json_response(analysis_text)

            return {
                "image_index": idx + 1,
                "image_path": os.path.basename(img_path),
                "structured_analysis": structured_analysis
            }

        except Exception as e:
            # Best effort: a failing image is reported but does not abort the encounter.
            print(f"Error analyzing image {img_path}: {str(e)}")
            return {
                "image_index": idx + 1,
                "image_path": os.path.basename(img_path),
                "error": str(e)
            }

    def _aggregate_analyses(self, image_analyses, encounter_id):
        """Aggregate structured analyses from multiple images.

        Args:
            image_analyses: Per-image result dictionaries from ``_analyze_single_image``
            encounter_id: Encounter identifier (used for error messages)

        Returns:
            The single valid analysis unchanged when there is exactly one, the
            model-generated merge when there are several, or an error
            dictionary when none is valid or the merge fails
        """
        valid_analyses = [a for a in image_analyses if "error" not in a and "structured_analysis" in a]

        if not valid_analyses:
            return {
                "error": "No valid analyses to aggregate",
                "message": "Unable to generate aggregated analysis due to errors in individual analyses."
            }

        # Nothing to merge: skip the extra model call.
        if len(valid_analyses) == 1:
            return valid_analyses[0]["structured_analysis"]

        analysis_jsons = []
        for analysis in valid_analyses:
            analysis_json = json.dumps(analysis["structured_analysis"])
            analysis_jsons.append(f"Image {analysis['image_index']} ({analysis['image_path']}): {analysis_json}")

        aggregation_prompt = self._create_aggregation_prompt(analysis_jsons)

        try:
            response = self.client.models.generate_content(
                model=self._model_name(),
                contents=[aggregation_prompt]
            )

            aggregation_text = response.text

            aggregated_analysis = parse_json_response(aggregation_text)

            return aggregated_analysis

        except Exception as e:
            print(f"Error creating aggregated analysis for encounter {encounter_id}: {str(e)}")
            return {
                "error": str(e),
                "aggregation_error": "Failed to generate aggregated analysis"
            }

    def _create_aggregation_prompt(self, analysis_jsons):
        """Create a prompt for aggregating multiple image analyses.

        Args:
            analysis_jsons: One pre-formatted string per image; they are
                appended to the prompt separated by single spaces
        """
        return f"""As dermatology specialist reviewing multiple skin image analyses for the same patient, combine these analyses and organize your response in a JSON dictionary:

1. SIZE: Approximate dimensions of lesions/affected areas, size comparison (thumbnail, palm, larger), Relative size comparisons for multiple lesions
2. SITE_LOCATION: Visible body parts in the image, body areas showing lesions/abnormalities, Specific anatomical locations affected
3. SKIN_DESCRIPTION: Lesion morphology (flat, raised, depressed), Texture of affected areas, Surface characteristics (scales, crust, fluid), Appearance of lesion boundaries
4. LESION_COLOR: Predominant color(s) of affected areas, Color variations within lesions, Color comparison to normal skin, Color distribution patterns
5. LESION_COUNT: Number of distinct lesions/affected areas, Single vs multiple presentation, Distribution pattern if multiple, Any counting limitations
6. EXTENT: How widespread the condition appears, Localized vs widespread assessment, Approximate percentage of visible skin affected, Limitations in determining full extent
7. TEXTURE: Expected tactile qualities, Smooth vs rough assessment, Notable textural features, Texture consistency across affected areas
8. ONSET_INDICATORS: Visual clues about condition duration, Acute vs chronic presentation features, Healing/progression/chronicity signs, Note: precise timing cannot be determined from images
9. ITCH_INDICATORS: Scratch marks/excoriations/trauma signs, Features associated with itchy conditions, Pruritic vs non-pruritic visual indicators, Note: sensation cannot be directly observed
10. OVERALL_IMPRESSION: Brief description (1-2 sentences), Key diagnostic features, Potential diagnoses (2-3)
    
{' '.join(analysis_jsons)}
"""

In [12]:
class ClinicalContextAnalyzer:
    """Service for analyzing clinical context.

    Pulls the free-text clinical background out of an encounter's query
    context and asks the generative model to structure it.
    """

    def __init__(self, client, args=None):
        """
        Args:
            client: Client exposing ``models.generate_content(model=..., contents=...)``
            args: Optional configuration; when given, ``args.gemini_model`` selects the model
        """
        self.client = client
        self.args = args

    def extract_clinical_context(self, query_context, encounter_id):
        """
        Extract structured clinical information from an encounter's query context.

        Args:
            query_context: The query context text (a missing/non-string value
                is treated as containing no clinical information)
            encounter_id: Encounter identifier

        Returns:
            Dictionary with ``encounter_id`` and one of:
            ``clinical_summary`` when no clinical text was found,
            ``raw_clinical_text`` + ``structured_clinical_context`` on success,
            ``raw_clinical_text`` + ``error`` when the model call or parsing raised
        """
        clinical_text = self._extract_clinical_text(query_context)

        if not clinical_text:
            return {
                "encounter_id": encounter_id,
                "clinical_summary": "No clinical information available"
            }

        prompt = self._create_clinical_context_prompt(clinical_text)

        try:
            response = self.client.models.generate_content(
                model=self.args.gemini_model if self.args else "gemini-2.5-flash-preview-04-17",
                contents=[prompt]
            )

            # parse_json_response is expected to be defined elsewhere in the notebook.
            structured_context = parse_json_response(response.text)

            return {
                "encounter_id": encounter_id,
                "raw_clinical_text": clinical_text,
                "structured_clinical_context": structured_context
            }

        except Exception as e:
            print(f"Error extracting clinical context for encounter {encounter_id}: {str(e)}")
            return {
                "encounter_id": encounter_id,
                "raw_clinical_text": clinical_text,
                "error": str(e)
            }

    def _extract_clinical_text(self, query_context):
        """Extract clinical text from query context.

        Collects the lines that follow a line containing
        "Background Clinical Information" up to (not including) the next line
        containing "Available Options". The header line itself is discarded.

        Args:
            query_context: The query context text

        Returns:
            The captured lines joined by newlines and stripped; an empty string
            when nothing was captured or the input is not a string
        """
        # A missing value (None, or the float NaN pandas uses for empty CSV
        # cells) has no text to scan; report "no clinical text" instead of
        # failing on .split().
        if not isinstance(query_context, str):
            return ""

        clinical_lines = []
        capturing = False
        for line in query_context.split('\n'):
            if "Background Clinical Information" in line:
                capturing = True
                continue
            elif "Available Options" in line:
                capturing = False
            elif capturing:
                clinical_lines.append(line)

        return "\n".join(clinical_lines).strip()

    def _create_clinical_context_prompt(self, clinical_text):
        """Create prompt for extracting structured clinical information."""
        return f"""You are a dermatology specialist analyzing patient information. 
Extract and structure all clinically relevant information from this patient description:

{clinical_text}

Organize your response in the following JSON structure:

1. DEMOGRAPHICS: Age, sex, and any other demographic data
2. SITE_LOCATION: Body parts affected by the condition as described in the text
3. SKIN_DESCRIPTION: Any mention of lesion morphology (flat, raised, depressed), texture, surface characteristics (scales, crust, fluid), appearance of lesion boundaries
4. LESION_COLOR: Any description of color(s) of affected areas, color variations, comparison to normal skin
5. LESION_COUNT: Any information about number of lesions, single vs multiple presentation, distribution pattern
6. EXTENT: How widespread the condition appears based on the description, localized vs widespread
7. TEXTURE: Any description of tactile qualities, smooth vs rough, notable textural features
8. ONSET_INDICATORS: Information about onset, duration, progression, or evolution of symptoms
9. ITCH_INDICATORS: Mentions of scratching, itchiness, or other sensory symptoms
10. OTHER_SYMPTOMS: Any additional symptoms mentioned (pain, burning, etc.)
11. TRIGGERS: Identified factors that worsen/improve the condition
12. HISTORY: Relevant past medical history or previous treatments
13. DIAGNOSTIC_CONSIDERATIONS: Any mentioned or suggested diagnoses in the text

Be concise and use medical terminology where appropriate. If information for a section is 
not available, indicate "Not mentioned".
"""

In [13]:
class EvidenceIntegrator:
    """Integrates visual, clinical, and knowledge-based evidence.

    Builds one prompt from the aggregated image analysis, the structured
    clinical context and (optionally) retrieved knowledge, and asks the
    generative model for an integrated, weighted evidence profile.
    """

    def __init__(self, client, args=None):
        """
        Args:
            client: Client exposing ``models.generate_content(model=..., contents=...)``
            args: Optional configuration; when given it supplies ``gemini_model``,
                ``question_type_retrieval_config`` and ``default_rag_config``
        """
        self.client = client
        self.args = args

    def integrate_evidence(self, image_analysis, clinical_context, question_type, retrieved_knowledge=None):
        """
        Integrate image analysis with clinical context and retrieved knowledge.

        Args:
            image_analysis: Structured image analysis
            clinical_context: Structured clinical context
            question_type: Type of question being asked
            retrieved_knowledge: Retrieved knowledge from the knowledge base (optional)

        Returns:
            Dictionary with integrated evidence, or a dictionary with ``error``
            and ``message`` keys when the model call or parsing raised
        """
        # Determine weighting based on question type
        weights = self._get_weights_for_question(question_type)

        # Create prompt for integration
        prompt = self._create_integration_prompt(
            image_analysis,
            clinical_context,
            question_type,
            weights,
            retrieved_knowledge
        )

        try:
            response = self.client.models.generate_content(
                model=self.args.gemini_model if self.args else "gemini-2.5-flash-preview-04-17",
                contents=[prompt]
            )

            integration_text = response.text

            # parse_json_response is expected to be defined elsewhere in the notebook.
            integrated_evidence = parse_json_response(integration_text)

            return integrated_evidence

        except Exception as e:
            print(f"Error integrating evidence: {str(e)}")
            return {
                "error": str(e),
                "message": "Failed to integrate evidence"
            }

    def _get_weights_for_question(self, question_type):
        """
        Determine evidence weighting based on question type.

        The image/clinical weights come from the table below (0.5/0.5 for an
        unknown question type). The knowledge weight always comes from the RAG
        configuration: its ``weight`` when ``use_rag`` is true, otherwise 0.0.
        The ``knowledge`` values in the table are therefore always overridden.

        Returns:
            Dictionary with weights for each evidence type
        """
        weights = {
            "Site Location": {"image": 0.8, "clinical": 0.2, "knowledge": 0.0},
            "Lesion Color": {"image": 0.9, "clinical": 0.1, "knowledge": 0.0},
            "Size": {"image": 0.8, "clinical": 0.2, "knowledge": 0.0},
            "Skin Description": {"image": 0.7, "clinical": 0.3, "knowledge": 0.2},
            "Duration of Symptoms": {"image": 0.3, "clinical": 0.7, "knowledge": 0.2},
            "Itch": {"image": 0.4, "clinical": 0.6, "knowledge": 0.3},
            "Extent": {"image": 0.7, "clinical": 0.3, "knowledge": 0.1},
            "Treatment": {"image": 0.1, "clinical": 0.9, "knowledge": 0.7},
            "Lesion Evolution": {"image": 0.3, "clinical": 0.7, "knowledge": 0.4},
            "Texture": {"image": 0.6, "clinical": 0.4, "knowledge": 0.2},
            "Specific Diagnosis": {"image": 0.5, "clinical": 0.5, "knowledge": 0.8},
            "Count": {"image": 0.8, "clinical": 0.2, "knowledge": 0.0},
            "Differential": {"image": 0.5, "clinical": 0.5, "knowledge": 0.8},
        }

        # Get weights from args or fallback to default
        if self.args:
            rag_config = self.args.question_type_retrieval_config.get(
                question_type, self.args.default_rag_config
            )
        else:
            # Fallback if args not provided. Mirrors the no-args fallback in
            # DiagnosisBasedKnowledgeRetriever.retrieve_knowledge, so a question
            # type with retrieval disabled there also gets zero knowledge weight.
            default_config = {"use_rag": True, "weight": 0.4}
            rag_config = {
                "Site Location": {"use_rag": False, "weight": 0.2},
                "Lesion Color": {"use_rag": False, "weight": 0.2},
                "Size": {"use_rag": False, "weight": 0.1},
            }.get(question_type, default_config)

        knowledge_weight = rag_config["weight"] if rag_config["use_rag"] else 0.0

        # Default weights if question type not found
        default = {"image": 0.5, "clinical": 0.5, "knowledge": knowledge_weight}
        type_weights = weights.get(question_type, default)

        # Override knowledge weight if specified in config
        type_weights["knowledge"] = knowledge_weight

        return type_weights

    def _create_integration_prompt(self, image_analysis, clinical_context, question_type, weights, retrieved_knowledge=None):
        """Create prompt for evidence integration.

        Args:
            image_analysis: Dictionary whose ``aggregated_analysis`` entry is
                embedded in the prompt (None is treated as empty)
            clinical_context: Dictionary whose ``structured_clinical_context``
                entry is embedded in the prompt (None is treated as empty)
            question_type: Type of question being asked
            weights: Dictionary with ``image``, ``clinical`` and ``knowledge`` weights
            retrieved_knowledge: Optional retrieval result; its top five
                ``results`` are embedded when its ``retrieved`` flag is true

        Returns:
            The prompt text
        """
        has_knowledge = retrieved_knowledge is not None and retrieved_knowledge.get('retrieved', False)
        results = retrieved_knowledge.get('results', []) if has_knowledge else []

        if results:
            knowledge_texts = []
            for i, result in enumerate(results[:5]):  # Limit to top 5 results
                knowledge_texts.append(f"RESULT {i+1}:\nTopic: {result['topic']}\nInformation: {result['information']}")

            knowledge_section = f"""
RETRIEVED MEDICAL KNOWLEDGE:
{json.dumps(knowledge_texts, indent=2)}

For this {question_type} question, image evidence has {weights['image']*100}% weight, clinical evidence has {weights['clinical']*100}% weight, and medical knowledge has {weights['knowledge']*100}% weight.
"""
        else:
            # The instructions below ask the model to "apply the provided
            # weights", so state them even when no knowledge was retrieved.
            knowledge_section = f"""
For this {question_type} question, image evidence has {weights['image']*100}% weight and clinical evidence has {weights['clinical']*100}% weight.
"""

        # Tolerate a missing analysis/context instead of failing on .get().
        image_json = json.dumps((image_analysis or {}).get("aggregated_analysis", {}), indent=2)
        clinical_json = json.dumps((clinical_context or {}).get("structured_clinical_context", {}), indent=2)

        # Base prompt
        prompt = f"""As a dermatology specialist, integrate the visual findings from images with the clinical history.

IMAGE ANALYSIS:
{image_json}

CLINICAL CONTEXT:
{clinical_json}

{knowledge_section}

Pay special attention to potential contradictions between visual findings and clinical history. Even minor inconsistencies should be noted as contradictions. Look for cases where clinical context suggests features not visible in images or where visual findings seem to contradict patient-reported symptoms or history.

Organize your response in a JSON structure with the following elements:

1. INTEGRATED_FINDINGS: For each key dermatological feature, combine visual and clinical evidence
   - SIZE
   - SITE_LOCATION
   - SKIN_DESCRIPTION
   - LESION_COLOR
   - LESION_COUNT
   - EXTENT
   - TEXTURE
   - ONSET_DURATION
   - SYMPTOMS

2. CONCORDANCE_ASSESSMENT: For each feature, assess if visual and clinical evidence are:
   - CONCORDANT: Visual and clinical evidence agree
   - DISCORDANT: Visual and clinical evidence conflict (explain the conflict)
   - COMPLEMENTARY: Evidence sources provide different but non-conflicting information
   - MISSING_VISUAL: Clinical description present but not visible in images
   - MISSING_CLINICAL: Visible in images but not mentioned in clinical context

3. CONTRADICTIONS: List any specific contradictions between visual and clinical evidence
   - For each contradiction, explain what the conflict is and assess which source is more reliable

4. WEIGHTED_EVIDENCE_PROFILE: Synthesize the most reliable information for each category
   - Apply the provided weights to determine the most reliable facts for each feature
   - Explain where you've prioritized one source over another

5. CONFIDENCE_SCORES: Score the confidence (0.0-1.0) in the integrated evidence for each feature

Be specific, concise, and use medical terminology where appropriate.
"""

        # Add knowledge summary section if present
        if has_knowledge:
            prompt += """

6. MEDICAL_KNOWLEDGE_INSIGHTS: Summarize key insights from retrieved medical knowledge
   - How the retrieved knowledge confirms or challenges the observed findings
   - Additional relevant diagnostic or management considerations
   - Typical clinical patterns or expected features that align with observations
"""

        return prompt

In [14]:
class ReasoningEngine:
    """Applies reasoning to determine the best answer.

    Builds a prompt from the integrated evidence, model predictions and
    (optionally) retrieved knowledge, sends it to the LLM client, then
    validates the returned answer against the allowed options and
    calibrates the reported confidence.
    """

    def __init__(self, client, args=None):
        self.client = client
        self.args = args

    def apply_initial_reasoning(self, question_text, question_type, options, integrated_evidence, model_predictions, retrieved_knowledge=None):
        """
        Apply initial reasoning to determine the most likely answer.

        Args:
            question_text: The question text
            question_type: The type of question
            options: Available answer options
            integrated_evidence: Integrated evidence from images and clinical context
            model_predictions: Model predictions to consider
            retrieved_knowledge: Retrieved knowledge from the knowledge base (optional)

        Returns:
            Dictionary with reasoning and answer. On any failure of the LLM
            call or of response parsing, a fallback dictionary answering
            "Not mentioned" with confidence 0.0 and an "error" entry.
        """
        model_prediction_text = self._format_model_predictions(model_predictions)

        multiple_answers_allowed = question_type in ["Site Location", "Size", "Skin Description"]

        prompt = self._create_reasoning_prompt(
            question_text,
            question_type,
            options,
            integrated_evidence,
            model_prediction_text,
            multiple_answers_allowed,
            retrieved_knowledge
        )

        try:
            response = self.client.models.generate_content(
                model=self.args.gemini_model if self.args else "gemini-2.5-flash-preview-04-17",
                contents=[prompt]
            )

            reasoning_result = parse_json_response(response.text)

            reasoning_result['validated_answer'] = self._validate_answer(
                reasoning_result.get('answer', ''), options
            )
            reasoning_result['confidence'] = self._calibrate_confidence(
                reasoning_result.get('confidence', 0.0)
            )

            return reasoning_result

        except Exception as e:
            print(f"Error applying initial reasoning: {str(e)}")
            return {
                "reasoning": f"Error: {str(e)}",
                "answer": "Not mentioned",
                "validated_answer": "Not mentioned",
                "confidence": 0.0,
                "error": str(e)
            }

    @staticmethod
    def _coerce_confidence(value):
        """Convert a model-reported confidence to a float in [0.0, 1.0].

        Anything that cannot be read as a number (None, free text, NaN)
        becomes 0.0 so that a malformed confidence never discards an
        otherwise usable answer.
        """
        try:
            number = float(value)
        except (TypeError, ValueError):
            return 0.0
        if number != number:  # NaN is the only value not equal to itself
            return 0.0
        return min(max(number, 0.0), 1.0)

    def _calibrate_confidence(self, raw_confidence):
        """Dampen the model-reported confidence to counter overconfidence.

        Steps, in order:
        1. coerce the raw value to a float in [0, 1];
        2. scale it by a random factor drawn from [0.9, 1.0];
        3. if the raw confidence was strictly between 0.95 and 1.0, use 0.95;
        4. cap the result at the configured confidence threshold
           (0.75 when no args object is available).
        """
        confidence = self._coerce_confidence(raw_confidence)

        randomized_confidence = confidence * random.uniform(0.9, 1.0)

        if 0.95 < confidence < 1.0:
            randomized_confidence = 0.95

        confidence_threshold = self.args.confidence_threshold if self.args else 0.75
        if randomized_confidence > confidence_threshold:
            randomized_confidence = confidence_threshold

        return randomized_confidence

    def _format_model_predictions(self, model_predictions):
        """Format model predictions as one "- name: prediction" line per model."""
        model_prediction_text = ""
        for model_name, predictions in model_predictions.items():
            combined_pred = predictions.get('model_prediction', '')
            # A missing prediction loaded through pandas shows up as a float NaN.
            if isinstance(combined_pred, float) and pd.isna(combined_pred):
                combined_pred = "No prediction"
            model_prediction_text += f"- {model_name}: {combined_pred}\n"
        return model_prediction_text

    def _create_reasoning_prompt(self, question_text, question_type, options, integrated_evidence, model_prediction_text, multiple_answers_allowed, retrieved_knowledge=None):
        """Create a prompt for the reasoning layer.

        The prompt contains the question, its options, the integrated
        evidence, the formatted model predictions, optional retrieved
        knowledge (top 5 results), optional question-type-specific guidance,
        the task description (single vs. multiple answers) and the expected
        JSON response format.
        """
        specialized_guidance = ""
        options_text = ", ".join(options)

        if question_type == "Size" and all(option in options_text for option in ["size of thumb nail", "size of palm", "larger area"]):
            specialized_guidance = """
SPECIALIZED GUIDANCE FOR SIZE ASSESSMENT:
When answering this size-related question, interpret the options as follows:
- "size of thumb nail": Individual lesions or affected areas approximately 1-2 cm in diameter
- "size of palm": Affected areas larger than the size of a thumb nail and roughly the size of a palm (approximately 1% of body surface area), which may include multiple smaller lesions across a region
- "larger area": Widespread involvement significantly larger than a palm, affecting a substantial portion(s) of the body

IMPORTANT: For cases with multiple small lesions that are visible in the images, but without extensive widespread involvement across large body regions, "size of palm" is likely the most appropriate answer.
"""
        elif question_type == "Lesion Color" and "combination" in options_text:
            specialized_guidance = """
SPECIALIZED GUIDANCE FOR LESION COLOR:
When answering color-related questions, pay careful attention to whether there are multiple distinct colors present across the affected areas. "Combination" would be appropriate when different lesions display different colors (e.g., some lesions appear red while others appear white), or when individual lesions show mixed or varied coloration patterns.
"""

        has_knowledge = retrieved_knowledge is not None and retrieved_knowledge.get('retrieved', False)

        knowledge_section = ""
        if has_knowledge:
            results = retrieved_knowledge.get('results', [])
            if results:
                knowledge_texts = []
                for i, result in enumerate(results[:5]):  # Limit to top 5 results
                    knowledge_texts.append(f"RESULT {i+1}:\nTopic: {result['topic']}\nInformation: {result['information']}")

                knowledge_section = f"""
RETRIEVED MEDICAL KNOWLEDGE:
{json.dumps(knowledge_texts, indent=2)}
"""

        if multiple_answers_allowed:
            task_description = """
Based on all the evidence above, determine the most accurate answer(s) to the question. Your task is to:
1. Analyze the integrated evidence
2. Consider the model predictions, noting any consensus or disagreement, but maintain your critical judgment
3. Provide a detailed reasoning for your conclusion
4. Select the final answer(s) from the available options
5. Provide a confidence score from 0.0 to 1.0 for your answer. Be conservative in your confidence assessment. Consider all possible sources of uncertainty, including image quality limitations, interpretation ambiguity, and potential contradictions. Confidence scores should rarely exceed 0.8 unless evidence is absolutely conclusive and unambiguous.

If selecting multiple answers is appropriate, provide them in a comma-separated list. If no answer can be determined, select "Not mentioned".
"""
        else:
            task_description = """
Based on all the evidence above, determine the SINGLE most accurate answer to the question. Your task is to:
1. Analyze the integrated evidence
2. Consider the model predictions, noting any consensus or disagreement, but maintain your critical judgment
3. Provide a detailed reasoning for your conclusion
4. Select ONLY ONE answer option that is most accurate
5. Provide a confidence score from 0.0 to 1.0 for your answer. Be conservative in your confidence assessment. Consider all possible sources of uncertainty, including image quality limitations, interpretation ambiguity, and potential contradictions. Confidence scores should rarely exceed 0.8 unless evidence is absolutely conclusive and unambiguous.

For this question type, you must select ONLY ONE option as your answer. If no answer can be determined, select "Not mentioned".
"""

        response_format = """
Format your response as a JSON object with these fields:
1. "reasoning": Your step-by-step reasoning process
2. "answer": Your final answer(s) as a single string or comma-separated list of options
3. "confidence": A score from 0.0 to 1.0 representing your confidence level in this answer
4. "evidence_used": The key evidence that supports your answer
5. "uncertainty_factors": Any factors that reduce your confidence
6. "counterfactual": What evidence would make you choose a different answer
"""

        if has_knowledge:
            response_format += """
7. "knowledge_contribution": How the retrieved medical knowledge influenced your reasoning and answer
"""

        base_prompt = f"""You are a medical expert analyzing dermatological findings. Use the provided evidence to determine the most accurate answer(s) for the following question:

QUESTION: {question_text}
QUESTION TYPE: {question_type}
OPTIONS: {options_text}

INTEGRATED EVIDENCE:
{json.dumps(integrated_evidence, indent=2)}

MODEL PREDICTIONS:
{model_prediction_text}

{knowledge_section}

{specialized_guidance}

IMPORTANT: While multiple model predictions are provided, be aware that these predictions can be inaccurate or inconsistent. Do not assume majority agreement equals correctness. Evaluate the evidence critically and independently from these predictions. Your job is to determine the correct answer based primarily on the integrated evidence, treating model predictions as secondary suggestions that may contain errors.

{task_description}

{response_format}

When providing your answer, strictly adhere to the available options and only select from them.
"""

        return base_prompt

    def _validate_answer(self, answer, options):
        """Validate the answer against available options.

        Args:
            answer: The raw answer from the model. Usually a string (a single
                option or a comma-separated list), but a JSON list of options
                is accepted as well.
            options: The allowed answer options.

        Returns:
            The matched options (in their original spelling, without
            duplicates) joined by ", ", or "Not mentioned" when nothing
            matches. Matching is case-insensitive and ignores surrounding
            whitespace.
        """
        if not answer:
            return "Not mentioned"

        # Normalised spelling -> original option; the first occurrence wins.
        lookup = {}
        for option in options:
            lookup.setdefault(str(option).strip().lower(), option)

        if isinstance(answer, (list, tuple)):
            candidates = [str(item) for item in answer]
        else:
            answer_text = str(answer)
            # Try the whole string first so options that themselves contain
            # a comma can still be matched; only then split on commas.
            if answer_text.strip().lower() in lookup:
                candidates = [answer_text]
            else:
                candidates = answer_text.split(',')

        valid_answers = []
        for candidate in candidates:
            matched = lookup.get(candidate.strip().lower())
            if matched is not None and matched not in valid_answers:
                valid_answers.append(matched)

        if not valid_answers:
            return "Not mentioned"

        return ", ".join(valid_answers)

In [15]:
class SelfReflectionEngine:
    """Applies self-reflection to the reasoning process.

    Asks the LLM to critique an initial reasoning result and, when the
    reflection proposes a revised answer, validates that answer against the
    allowed options.
    """

    def __init__(self, client, args=None):
        self.client = client
        self.args = args

    def apply_reflection(self, question_text, question_type, options, integrated_evidence, reasoning_result, retrieved_knowledge=None):
        """
        Apply self-reflection to the initial reasoning result.

        Args:
            question_text: The question text
            question_type: The type of question
            options: Available answer options
            integrated_evidence: Integrated evidence
            reasoning_result: Initial reasoning result
            retrieved_knowledge: Retrieved knowledge from the knowledge base (optional)

        Returns:
            Dictionary with reflection results. "requires_revision", when
            present, is normalised to a real boolean. On any failure a
            fallback dictionary with "requires_revision" False, the original
            confidence and an "error" entry is returned.
        """
        prompt = self._create_reflection_prompt(
            question_text,
            question_type,
            options,
            integrated_evidence,
            reasoning_result,
            retrieved_knowledge
        )

        try:
            response = self.client.models.generate_content(
                model=self.args.gemini_model if self.args else "gemini-2.5-flash-preview-04-17",
                contents=[prompt]
            )

            reflection_result = parse_json_response(response.text)

            if 'revised_answer' in reflection_result:
                reflection_result['validated_revised_answer'] = self._validate_answer(
                    reflection_result.get('revised_answer', ''), options
                )

            # The model sometimes emits the flag as a string; "false" would
            # otherwise be truthy for any caller testing it directly.
            if 'requires_revision' in reflection_result:
                reflection_result['requires_revision'] = self._coerce_flag(
                    reflection_result['requires_revision']
                )

            return reflection_result

        except Exception as e:
            print(f"Error applying reflection: {str(e)}")
            return {
                "reflection": f"Error: {str(e)}",
                "requires_revision": False,
                "confidence": reasoning_result.get('confidence', 0.0),
                "error": str(e)
            }

    @staticmethod
    def _coerce_flag(value):
        """Interpret a JSON-ish truth value ("true"/"false", 1/0, bool) as a bool."""
        if isinstance(value, str):
            return value.strip().lower() in ("true", "yes", "1")
        return bool(value)

    def _create_reflection_prompt(self, question_text, question_type, options, integrated_evidence, reasoning_result, retrieved_knowledge=None):
        """Create a prompt for the self-reflection layer.

        The prompt embeds the question, options, integrated evidence, the
        initial reasoning result and (if retrieved) the top 5 knowledge
        results, followed by the review tasks and the expected JSON fields.
        """
        has_knowledge = retrieved_knowledge is not None and retrieved_knowledge.get('retrieved', False)

        knowledge_section = ""
        if has_knowledge:
            results = retrieved_knowledge.get('results', [])
            if results:
                knowledge_texts = []
                for i, result in enumerate(results[:5]):  # Limit to top 5 results
                    knowledge_texts.append(f"RESULT {i+1}:\nTopic: {result['topic']}\nInformation: {result['information']}")

                knowledge_section = f"""
RETRIEVED MEDICAL KNOWLEDGE:
{json.dumps(knowledge_texts, indent=2)}
"""

        base_prompt = f"""You are a medical expert critically reviewing your own reasoning about a dermatological question. 
Carefully examine the initial reasoning and check for errors, biases, and inconsistencies:

QUESTION: {question_text}
QUESTION TYPE: {question_type}
OPTIONS: {", ".join(options)}

INTEGRATED EVIDENCE:
{json.dumps(integrated_evidence, indent=2)}

INITIAL REASONING:
{json.dumps(reasoning_result, indent=2)}

{knowledge_section}

Your task is to:
1. Critically examine the initial reasoning for errors, biases, or incomplete analysis
2. Identify any evidence that was overlooked or misinterpreted
3. Evaluate whether the confidence level was appropriate
4. Determine if a different answer would be more accurate
5. Check if the evidence truly supports the chosen answer

Format your response as a JSON object with these fields:
1. "reflection": Your critical review of the initial reasoning
2. "overlooked_evidence": Any important evidence that was missed or undervalued
3. "misinterpreted_evidence": Any evidence that was incorrectly interpreted
4. "reasoning_gaps": Logical gaps or assumptions in the initial reasoning
5. "confidence_assessment": Was the confidence level appropriate? Why or why not?
6. "requires_revision": Boolean indicating if the answer needs to be revised (true/false)
7. "revised_answer": If revision is needed, the corrected answer
8. "revised_confidence": If revision is needed, the corrected confidence level (0.0-1.0)
9. "revision_explanation": If revision is needed, the explanation for the change
"""

        if has_knowledge:
            base_prompt += """
10. "knowledge_utilization_assessment": Assessment of how well the initial reasoning utilized the available medical knowledge
"""

        base_prompt += """
Be particularly careful to identify:
- Cherry-picking: Did the initial reasoning focus only on evidence supporting its conclusion?
- Overconfidence: Was the confidence level too high given the available evidence?
- Alternative explanations: Are there valid alternative interpretations of the evidence?
- Implicit assumptions: Were there unstated assumptions in the reasoning process?

Be honest and thorough in your self-reflection, even if it means acknowledging errors in the initial reasoning.
"""

        return base_prompt

    def _validate_answer(self, answer, options):
        """Validate the answer against available options.

        Args:
            answer: The raw answer from the model. Usually a string (a single
                option or a comma-separated list), but a JSON list of options
                is accepted as well.
            options: The allowed answer options.

        Returns:
            The matched options (in their original spelling, without
            duplicates) joined by ", ", or "Not mentioned" when nothing
            matches. Matching is case-insensitive and ignores surrounding
            whitespace.
        """
        if not answer:
            return "Not mentioned"

        # Normalised spelling -> original option; the first occurrence wins.
        lookup = {}
        for option in options:
            lookup.setdefault(str(option).strip().lower(), option)

        if isinstance(answer, (list, tuple)):
            candidates = [str(item) for item in answer]
        else:
            answer_text = str(answer)
            # Try the whole string first so options that themselves contain
            # a comma can still be matched; only then split on commas.
            if answer_text.strip().lower() in lookup:
                candidates = [answer_text]
            else:
                candidates = answer_text.split(',')

        valid_answers = []
        for candidate in candidates:
            matched = lookup.get(candidate.strip().lower())
            if matched is not None and matched not in valid_answers:
                valid_answers.append(matched)

        if not valid_answers:
            return "Not mentioned"

        return ", ".join(valid_answers)

In [16]:
class ReAnalysisEngine:
    """Handles re-analysis when initial reasoning is insufficient.

    Given the initial reasoning and the reflection on it, asks the LLM for a
    deeper analysis and validates the resulting final answer against the
    allowed options.
    """

    def __init__(self, client, args=None):
        self.client = client
        self.args = args

    def deep_analysis(self, question_text, question_type, options, integrated_evidence, reasoning_result, reflection_result, retrieved_knowledge=None):
        """
        Perform a deeper analysis based on reflection results.

        Args:
            question_text: The question text
            question_type: The type of question
            options: Available answer options
            integrated_evidence: Integrated evidence
            reasoning_result: Initial reasoning result
            reflection_result: Self-reflection result
            retrieved_knowledge: Retrieved knowledge from the knowledge base (optional)

        Returns:
            Dictionary with deep analysis result, including
            "validated_final_answer". On any failure the initial reasoning's
            validated answer and confidence are carried over, together with
            an "error" entry.
        """
        prompt = self._create_deep_analysis_prompt(
            question_text,
            question_type,
            options,
            integrated_evidence,
            reasoning_result,
            reflection_result,
            retrieved_knowledge
        )

        try:
            response = self.client.models.generate_content(
                model=self.args.gemini_model if self.args else "gemini-2.5-flash-preview-04-17",
                contents=[prompt]
            )

            deep_analysis = parse_json_response(response.text)

            deep_analysis['validated_final_answer'] = self._validate_answer(
                deep_analysis.get('final_answer', ''), options
            )

            return deep_analysis

        except Exception as e:
            print(f"Error performing deep analysis: {str(e)}")
            # Fall back to the initial answer rather than losing it.
            initial_answer = reasoning_result.get('validated_answer', 'Not mentioned')
            return {
                "deep_reasoning": f"Error: {str(e)}",
                "final_answer": initial_answer,
                "validated_final_answer": initial_answer,
                "final_confidence": reasoning_result.get('confidence', 0.0),
                "error": str(e)
            }

    def _create_deep_analysis_prompt(self, question_text, question_type, options, integrated_evidence, reasoning_result, reflection_result, retrieved_knowledge=None):
        """Create a prompt for deep analysis.

        The prompt embeds the question, options, integrated evidence, the
        initial reasoning, the reflection and (if retrieved) the top 5
        knowledge results, and calls out the issues the reflection listed.
        """
        has_knowledge = retrieved_knowledge is not None and retrieved_knowledge.get('retrieved', False)

        knowledge_section = ""
        if has_knowledge:
            results = retrieved_knowledge.get('results', [])
            if results:
                knowledge_texts = []
                for i, result in enumerate(results[:5]):  # Limit to top 5 results
                    knowledge_texts.append(f"RESULT {i+1}:\nTopic: {result['topic']}\nInformation: {result['information']}")

                knowledge_section = f"""
RETRIEVED MEDICAL KNOWLEDGE:
{json.dumps(knowledge_texts, indent=2)}
"""

        base_prompt = f"""You are a medical expert performing a deep analysis for a dermatological question after identifying issues with initial reasoning.
Review all evidence and reasoning paths comprehensively:

QUESTION: {question_text}
QUESTION TYPE: {question_type}
OPTIONS: {", ".join(options)}

INTEGRATED EVIDENCE:
{json.dumps(integrated_evidence, indent=2)}

INITIAL REASONING:
{json.dumps(reasoning_result, indent=2)}

REFLECTION:
{json.dumps(reflection_result, indent=2)}

{knowledge_section}

Your task is to:
1. Re-examine ALL available evidence with fresh eyes
2. Address the specific issues highlighted in the reflection
3. Consider each answer option systematically
4. Weigh evidence for and against each potential answer
5. Determine the most accurate answer based on comprehensive analysis

For issues identified in reflection:
- Overlooked evidence: {reflection_result.get('overlooked_evidence', 'None identified')}
- Misinterpreted evidence: {reflection_result.get('misinterpreted_evidence', 'None identified')}
- Reasoning gaps: {reflection_result.get('reasoning_gaps', 'None identified')}

Format your response as a JSON object with these fields:
1. "deep_reasoning": Your comprehensive analysis considering all evidence and perspectives
2. "systematic_assessment": Assessment of evidence for EACH possible answer option
3. "final_answer": Your conclusion after deep analysis
4. "final_confidence": Your confidence level after deep analysis (0.0-1.0)
5. "key_determinants": The most important factors that determined your final answer
6. "remaining_uncertainties": Any unresolved questions or limitations
"""

        if has_knowledge:
            base_prompt += """
7. "knowledge_integration": How you've incorporated medical knowledge into your final analysis
"""

        base_prompt += """
Be thorough, balanced, and precise in your analysis. Consider the evidence holistically and avoid the pitfalls identified in the reflection phase.
"""

        return base_prompt

    def _validate_answer(self, answer, options):
        """Validate the answer against available options.

        Args:
            answer: The raw answer from the model. Usually a string (a single
                option or a comma-separated list), but a JSON list of options
                is accepted as well.
            options: The allowed answer options.

        Returns:
            The matched options (in their original spelling, without
            duplicates) joined by ", ", or "Not mentioned" when nothing
            matches. Matching is case-insensitive and ignores surrounding
            whitespace.
        """
        if not answer:
            return "Not mentioned"

        # Normalised spelling -> original option; the first occurrence wins.
        lookup = {}
        for option in options:
            lookup.setdefault(str(option).strip().lower(), option)

        if isinstance(answer, (list, tuple)):
            candidates = [str(item) for item in answer]
        else:
            answer_text = str(answer)
            # Try the whole string first so options that themselves contain
            # a comma can still be matched; only then split on commas.
            if answer_text.strip().lower() in lookup:
                candidates = [answer_text]
            else:
                candidates = answer_text.split(',')

        valid_answers = []
        for candidate in candidates:
            matched = lookup.get(candidate.strip().lower())
            if matched is not None and matched not in valid_answers:
                valid_answers.append(matched)

        if not valid_answers:
            return "Not mentioned"

        return ", ".join(valid_answers)

In [17]:
class AgenticDermatologyPipeline:
    """Main pipeline for agentic dermatology analysis with diagnosis-based retrieval."""
    
    def __init__(self, api_key=None, args=None):
        if api_key is None:
            api_key = "AIzaSyCCb63iuGCupIS_EDZ8S0qb2-38DA7mUbM"
        
        self.client = genai.Client(api_key=api_key)
        self.args = args
        
        # Initialize knowledge base and retrieval components
        print("Initializing knowledge base...")
        self.kb_manager = KnowledgeBaseManager()
        
        # Initialize diagnosis extractor
        self.diagnosis_extractor = DiagnosisExtractor()
        
        # Initialize query generator
        self.query_generator = DiagnosisBasedQueryGenerator(self.client)
        
        # Initialize knowledge retriever
        self.knowledge_retriever = DiagnosisBasedKnowledgeRetriever(
            self.kb_manager,
            self.query_generator,
            self.diagnosis_extractor
        )
        
        # Initialize analysis components
        self.image_analyzer = ImageAnalysisService(self.client, args=args)
        self.clinical_analyzer = ClinicalContextAnalyzer(self.client, args=args)
        self.evidence_integrator = EvidenceIntegrator(self.client, args=args)
        self.reasoning_engine = ReasoningEngine(self.client, args=args)
        self.reflection_engine = SelfReflectionEngine(self.client, args=args)
        self.reanalysis_engine = ReAnalysisEngine(self.client, args=args)
    
    def process_single_encounter(self, agentic_data, encounter_id):
        """
        Process a single encounter with all its questions using the agentic pipeline.

        Args:
            agentic_data: AgenticRAGData instance containing all encounter data
            encounter_id: The specific encounter ID to process

        Returns:
            Dictionary with all questions processed with agentic reasoning for this encounter
        """
        all_pairs = agentic_data.get_all_encounter_question_pairs()
        encounter_pairs = [pair for pair in all_pairs if pair[0] == encounter_id]

        if not encounter_pairs:
            print(f"No data found for encounter {encounter_id}")
            return None

        print(f"Processing {len(encounter_pairs)} questions for encounter {encounter_id}")

        encounter_results = {encounter_id: {}}

        # Extract image analysis once per encounter
        print(f"Computing image analysis for {encounter_id}")
        sample_data = agentic_data.get_combined_data(encounter_pairs[0][0], encounter_pairs[0][1])
        image_analysis = self.image_analyzer.analyze_images(sample_data['images'], encounter_id)

        # Extract clinical context once per encounter
        print(f"Extracting clinical context for {encounter_id}")
        clinical_context = self.clinical_analyzer.extract_clinical_context(
            sample_data['query_context'], 
            encounter_id
        )

        for i, (encounter_id, base_qid) in enumerate(encounter_pairs):
            print(f"Processing question {i+1}/{len(encounter_pairs)}: {base_qid}")

            sample_data = agentic_data.get_combined_data(encounter_id, base_qid)
            if not sample_data:
                print(f"Warning: No data found for {encounter_id}, {base_qid}")
                continue

            # Extract question details
            question_text = sample_data['query_context'].split("MAIN QUESTION TO ANSWER:")[1].split("\n")[0].strip()
            question_type = sample_data['question_type']
            options = sample_data['options']
            model_predictions = sample_data['model_predictions']
            
            # First, do initial evidence integration without knowledge retrieval
            print(f"Initial evidence integration for {encounter_id}, {base_qid}")
            initial_integrated_evidence = self.evidence_integrator.integrate_evidence(
                image_analysis,
                clinical_context,
                question_type
            )
            
            # Now use diagnosis-based knowledge retrieval
            print(f"Retrieving knowledge based on diagnoses for {encounter_id}, {base_qid}")
            retrieved_knowledge = self.knowledge_retriever.retrieve_knowledge(
                question_text,
                question_type,
                options,
                image_analysis,
                clinical_context,
                initial_integrated_evidence
            )
            
            # Integrate evidence with retrieved knowledge
            print(f"Integrating all evidence for {encounter_id}, {base_qid}")
            integrated_evidence = self.evidence_integrator.integrate_evidence(
                image_analysis,
                clinical_context,
                question_type,
                retrieved_knowledge
            )

            # Initial reasoning
            print(f"Initial reasoning for {encounter_id}, {base_qid}")
            reasoning_result = self.reasoning_engine.apply_initial_reasoning(
                question_text,
                question_type,
                options,
                integrated_evidence,
                model_predictions,
                retrieved_knowledge
            )

            # Determine if self-reflection is needed based on confidence
            confidence = reasoning_result.get('confidence', 0.0)
            if isinstance(confidence, str):
                try:
                    confidence = float(confidence)
                except:
                    confidence = 0.0

            final_result = reasoning_result
            reflection_path = []

            # Apply self-reflection if confidence is below threshold
            confidence_threshold = self.args.confidence_threshold if self.args else 0.75  # Default to 0.75 if args not available
            if confidence < confidence_threshold:
                print(f"Confidence {confidence} below threshold. Applying self-reflection.")

                reflection_result = self.reflection_engine.apply_reflection(
                    question_text,
                    question_type,
                    options,
                    integrated_evidence,
                    reasoning_result,
                    retrieved_knowledge
                )
                reflection_path.append(reflection_result)

                # Determine if re-analysis is needed based on reflection
                requires_revision = reflection_result.get('requires_revision', False)
                if requires_revision:
                    print(f"Reflection indicates revision needed. Performing deep analysis.")

                    deep_analysis = self.reanalysis_engine.deep_analysis(
                        question_text,
                        question_type,
                        options,
                        integrated_evidence,
                        reasoning_result,
                        reflection_result,
                        retrieved_knowledge
                    )
                    reflection_path.append(deep_analysis)

                    final_result = {
                        "reasoning": deep_analysis.get('deep_reasoning', ''),
                        "answer": deep_analysis.get('final_answer', 'Not mentioned'),
                        "validated_answer": deep_analysis.get('validated_final_answer', 'Not mentioned'),
                        "confidence": deep_analysis.get('final_confidence', 0.0)
                    }
                else:
                    # Use original answer but with updated confidence if available
                    revised_confidence = reflection_result.get('revised_confidence', reasoning_result.get('confidence', 0.0))
                    final_result = {
                        "reasoning": reasoning_result.get('reasoning', ''),
                        "answer": reasoning_result.get('answer', 'Not mentioned'),
                        "validated_answer": reasoning_result.get('validated_answer', 'Not mentioned'),
                        "confidence": revised_confidence
                    }

            encounter_results[encounter_id][base_qid] = {
                "query_context": sample_data['query_context'],
                "options": sample_data['options'],
                "model_predictions": sample_data['model_predictions'],
                "retrieved_knowledge": retrieved_knowledge,
                "integrated_evidence": integrated_evidence,
                "reasoning_result": reasoning_result,
                "reflection_path": reflection_path,
                "final_result": final_result,
                "final_answer": final_result.get('validated_answer', 'Not mentioned')
            }

        output_file = os.path.join(self.args.output_dir if self.args else Config.OUTPUT_DIR, f"diagnosis_based_rag_results_{encounter_id}.json")
        
        with open(output_file, "w") as f:
            json.dump(encounter_results, f, indent=2)

        print(f"Processed all {len(encounter_pairs)} questions for encounter {encounter_id}")
        return encounter_results
    
    def format_results_for_evaluation(self, encounter_results, output_file):
        """Convert per-encounter results into the flat submission layout and save it.

        Parameters:
        - encounter_results: mapping of encounter id -> {base question id -> question data},
          where each question data dict provides 'final_answer' and 'options'.
        - output_file: path of the JSON file the formatted predictions are written to.

        Returns:
        - A list with one dict per complete encounter. Each dict holds 'encounter_id'
          plus one option index per question slot (e.g. "CQID011-003").

        Encounters that lack any of the required base question ids are reported and
        left out of the result.
        """
        # Number of answer slots each base question exposes; slots are numbered from 001.
        slots_per_question = {
            "CQID010": 1,
            "CQID011": 6,
            "CQID012": 6,
            "CQID015": 1,
            "CQID020": 9,
            "CQID025": 1,
            "CQID034": 1,
            "CQID035": 1,
            "CQID036": 1,
        }
        qid_variants = {
            base: [f"{base}-{slot:03d}" for slot in range(1, count + 1)]
            for base, count in slots_per_question.items()
        }
        required_base_qids = set(qid_variants)

        formatted_predictions = []
        for enc_id, answered in encounter_results.items():
            # Any required question absent from this encounter makes it incomplete.
            if required_base_qids.difference(answered.keys()):
                print(f"Skipping encounter {enc_id} - missing required questions")
                continue

            row = {'encounter_id': enc_id}
            for question_id, payload in answered.items():
                # Questions outside the evaluation set are simply ignored.
                if question_id in qid_variants:
                    chosen = payload['final_answer']
                    choices = payload['options']
                    fallback_index = self._find_not_mentioned_index(choices)
                    self._process_answers(
                        row, question_id, chosen, choices, qid_variants, fallback_index
                    )
            formatted_predictions.append(row)

        with open(output_file, 'w') as handle:
            json.dump(formatted_predictions, handle, indent=2)

        print(f"Formatted predictions saved to {output_file} ({len(formatted_predictions)} complete encounters)")
        return formatted_predictions
    
    def _find_not_mentioned_index(self, options):
        """Find the index of 'Not mentioned' in options."""
        for i, opt in enumerate(options):
            if opt.lower() == "not mentioned":
                return i
        return len(options) - 1
    
    def _process_answers(self, pred_entry, base_qid, final_answer, options, qid_variants, not_mentioned_index):
        """Process answers and add to prediction entry."""
        if ',' in final_answer:
            answer_parts = [part.strip() for part in final_answer.split(',')]
            answer_indices = []
            
            for part in answer_parts:
                found = False
                for i, opt in enumerate(options):
                    if part.lower() == opt.lower():
                        answer_indices.append(i)
                        found = True
                        break
                
                if not found:
                    answer_indices.append(not_mentioned_index)
            
            available_variants = qid_variants[base_qid]
            
            for i, idx in enumerate(answer_indices):
                if i < len(available_variants):
                    pred_entry[available_variants[i]] = idx
            
            for i in range(len(answer_indices), len(available_variants)):
                pred_entry[available_variants[i]] = not_mentioned_index
            
        else:
            answer_index = not_mentioned_index
            
            for i, opt in enumerate(options):
                if final_answer.lower() == opt.lower():
                    answer_index = i
                    break
            
            pred_entry[qid_variants[base_qid][0]] = answer_index
            
            if len(qid_variants[base_qid]) > 1:
                for i in range(1, len(qid_variants[base_qid])):
                    pred_entry[qid_variants[base_qid][i]] = not_mentioned_index

In [18]:
def run_diagnosis_based_pipeline_all_encounters(args=None):
    """Run the diagnosis-based pipeline for every encounter and save the predictions.

    Parameters:
    - args: configuration object; when None, Args(use_finetuning=True,
      use_test_dataset=True) is created.

    Returns:
    - The value returned by ``pipeline.format_results_for_evaluation`` for the
      accumulated results of all encounters that were processed without error.

    A failure in one encounter is logged to an ``error_encounter_*.txt`` file in
    ``args.output_dir`` and does not stop the run. The accumulated results are
    checkpointed to a timestamped JSON file after every fifth encounter and after
    the last one, whether or not that particular encounter succeeded.
    """
    if args is None:
        args = Args(use_finetuning=True, use_test_dataset=True)

    # Load model predictions and the dataset selected by `args`
    model_predictions_dict = DataLoader.load_all_model_predictions(args)
    all_models_df = pd.concat(model_predictions_dict.values(), ignore_index=True)
    validation_df = DataLoader.load_validation_dataset(args)

    # Create agentic data and pipeline
    agentic_data = AgenticRAGData(all_models_df, validation_df)
    pipeline = AgenticDermatologyPipeline(args=args)

    # Get all unique encounter IDs
    all_pairs = agentic_data.get_all_encounter_question_pairs()
    unique_encounter_ids = sorted({pair[0] for pair in all_pairs})
    total_encounters = len(unique_encounter_ids)
    print(f"Found {total_encounters} unique encounters to process")

    all_encounter_results = {}
    for position, encounter_id in enumerate(unique_encounter_ids, start=1):
        print(f"Processing encounter {position}/{total_encounters}: {encounter_id}...")

        try:
            encounter_results = pipeline.process_single_encounter(agentic_data, encounter_id)
            if encounter_results:
                all_encounter_results.update(encounter_results)
        except Exception as e:
            print(f"Error processing encounter {encounter_id}: {str(e)}")
            # Save error information
            error_file = os.path.join(
                args.output_dir,
                f"error_encounter_{encounter_id}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"
            )
            with open(error_file, 'w') as f:
                f.write(f"Error processing encounter {encounter_id}: {str(e)}\n")
                f.write(f"Traceback:\n{traceback.format_exc()}")

        # Checkpoint outside the try block above: a failing encounter at a
        # checkpoint boundary must not prevent earlier results from being saved.
        if position % 5 == 0 or position == total_encounters:
            timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
            intermediate_output_file = os.path.join(
                args.output_dir,
                f"intermediate_diagnosis_based_results_{position}_of_{total_encounters}_{timestamp}.json"
            )
            try:
                with open(intermediate_output_file, 'w') as f:
                    json.dump(all_encounter_results, f, indent=2)
                print(f"Saved intermediate results after processing {position} encounters")
            except (OSError, TypeError, ValueError) as checkpoint_error:
                # Checkpoints are best-effort; keep processing the remaining encounters.
                print(f"Could not save intermediate results after {position} encounters: {checkpoint_error}")

    # Format and save final predictions
    timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    output_file = os.path.join(
        args.output_dir,
        f"{args.dataset_name}_data_cvqa_sys_diagnosis_based_all_{timestamp}.json"
    )

    formatted_predictions = pipeline.format_results_for_evaluation(all_encounter_results, output_file)

    print(f"Processed {len(formatted_predictions)} encounters successfully")
    return formatted_predictions

In [19]:
def run_diagnosis_based_pipeline(encounter_id, args=None):
    """Process one encounter end to end and write its formatted predictions.

    Parameters:
    - encounter_id: identifier of the encounter to process.
    - args: configuration object; when None, Args(use_finetuning=True,
      use_test_dataset=True) is created.

    Returns:
    - Whatever ``format_results_for_evaluation`` returns for this encounter's results.
    """
    run_args = Args(use_finetuning=True, use_test_dataset=True) if args is None else args

    # Inputs: predictions from every model merged into one frame, plus the dataset.
    predictions_by_model = DataLoader.load_all_model_predictions(run_args)
    combined_predictions = pd.concat(predictions_by_model.values(), ignore_index=True)
    dataset_frame = DataLoader.load_validation_dataset(run_args)

    rag_data = AgenticRAGData(combined_predictions, dataset_frame)
    rag_pipeline = AgenticDermatologyPipeline(args=run_args)

    per_question_results = rag_pipeline.process_single_encounter(rag_data, encounter_id)

    # The file name carries the dataset name, encounter id and a run timestamp.
    run_stamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    results_name = f"{run_args.dataset_name}_data_cvqa_sys_diagnosis_based_{encounter_id}_{run_stamp}.json"
    return rag_pipeline.format_results_for_evaluation(
        per_question_results,
        os.path.join(run_args.output_dir, results_name),
    )

In [20]:
# def run_diagnosis_based_pipeline(encounter_id):
#     """Run the diagnosis-based pipeline for a single encounter."""
#     # Load model predictions and validation dataset
#     model_predictions_dict = DataLoader.load_all_model_predictions(Config.MODEL_PREDICTIONS_DIR)
#     all_models_df = pd.concat(model_predictions_dict.values(), ignore_index=True)
#     validation_df = DataLoader.load_validation_dataset(Config.VAL_DATASET_PATH)
    
#     # Create agentic data and pipeline
#     agentic_data = AgenticRAGData(all_models_df, validation_df)
#     pipeline = AgenticDermatologyPipeline()
    
#     # Process the encounter
#     encounter_results = pipeline.process_single_encounter(agentic_data, encounter_id)
    
#     # Format and save predictions
#     output_file = os.path.join(
#         Config.OUTPUT_DIR, 
#         f"test_data_cvqa_sys_diagnosis_based_{encounter_id}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
#     )
#     formatted_predictions = pipeline.format_results_for_evaluation(encounter_results, output_file)
    
#     return formatted_predictions

In [21]:
# Driver cell: runs the pipeline for one hard-coded encounter. The names assigned
# here (args, encounter_id, formatted_predictions) become notebook globals, and
# formatted_predictions is displayed by the following cell.
if __name__ == "__main__":
    # Configuration: fine-tuned model predictions on the test dataset
    args = Args(use_finetuning=True, use_test_dataset=True)
    
    # Initialize the knowledge base first (if needed).
    # NOTE(review): kb_manager is not used again in this cell, so the call appears
    # to be made only for its setup side effects; confirm against KnowledgeBaseManager.
    kb_manager = KnowledgeBaseManager(args=args)
    
    # Run for a single encounter
    encounter_id = "ENC00908"
    formatted_predictions = run_diagnosis_based_pipeline(encounter_id, args)
    # formatted_predictions is sized with len(); the next cell shows it as a list
    # with one dict for this encounter.
    print(f"Processed encounter {encounter_id} with {len(formatted_predictions)} prediction entries")
    
    # Or alternatively run for all encounters
    # formatted_predictions = run_diagnosis_based_pipeline_all_encounters(args)
    # print(f"Total complete encounters processed: {len(formatted_predictions)}")


Configuration initialized:
- Using test dataset
- Looking for finetuned model predictions
- Dataset path: /storage/scratch1/2/kthakrar3/mediqa-magic-v2/outputs/test_dataset.csv
- Images directory: /storage/scratch1/2/kthakrar3/mediqa-magic-v2/2025_dataset/test/images_test
- Prediction file prefix: aggregated_test_predictions_
Using existing knowledge base at /storage/scratch1/2/kthakrar3/mediqa-magic-v2/knowledge_db
Initializing BM25 index...
BM25 index initialization complete.
Initializing knowledge base...
Using existing knowledge base at /storage/scratch1/2/kthakrar3/mediqa-magic-v2/knowledge_db
Initializing BM25 index...
BM25 index initialization complete.
Processing 9 questions for encounter ENC00908
Computing image analysis for ENC00908
Analyzing image 1/2 for encounter ENC00908
Analyzing image 2/2 for encounter ENC00908
Extracting clinical context for ENC00908
Processing question 1/9: CQID010
Initial evidence integration for ENC00908, CQID010
Retrieving knowledge based on diagn

In [22]:
formatted_predictions

[{'encounter_id': 'ENC00908',
  'CQID010-001': 1,
  'CQID011-001': 3,
  'CQID011-002': 2,
  'CQID011-003': 7,
  'CQID011-004': 7,
  'CQID011-005': 7,
  'CQID011-006': 7,
  'CQID012-001': 0,
  'CQID012-002': 1,
  'CQID012-003': 3,
  'CQID012-004': 3,
  'CQID012-005': 3,
  'CQID012-006': 3,
  'CQID015-001': 6,
  'CQID020-001': 0,
  'CQID020-002': 1,
  'CQID020-003': 2,
  'CQID020-004': 3,
  'CQID020-005': 6,
  'CQID020-006': 9,
  'CQID020-007': 9,
  'CQID020-008': 9,
  'CQID020-009': 9,
  'CQID025-001': 2,
  'CQID034-001': 8,
  'CQID035-001': 1,
  'CQID036-001': 1}]