In [1]:
import os
import glob
import pandas as pd
import ast
import re
from collections import defaultdict
import base64
from PIL import Image
from dotenv import load_dotenv
from google import genai
import json
import io
import traceback
import datetime

In [2]:
class Config:
    MODEL_PREDICTIONS_DIR = os.path.join(os.getcwd(), "outputs")
    OUTPUT_DIR = os.path.join(os.getcwd(), "outputs")
    VAL_DATASET_PATH = os.path.join(OUTPUT_DIR, "val_dataset.csv")
    IMAGES_DIR = os.path.join(os.getcwd(), "2025_dataset", "valid", "images_valid")
    GEMINI_MODEL = "gemini-2.5-flash-preview-04-17"

In [3]:
class DataLoader:
    @staticmethod
    def get_latest_aggregated_files(model_predictions_dir):
        """Get the latest aggregated prediction files for each model."""
        pattern = os.path.join(model_predictions_dir, "aggregated_predictions_*.csv")
#         print(f"Searching for files with pattern: {pattern}")
        
        agg_files = glob.glob(pattern)
#         print(f"Found {len(agg_files)} aggregated prediction files")
        
        if len(agg_files) == 0:
            return []
        
        latest_files = {}
        
        for file_path in agg_files:
            file_name = os.path.basename(file_path)
            
            parts = file_name.split("_base_")
            if len(parts) != 2:
                print(f"Warning: Unexpected filename format: {file_name}")
                continue
            
            model_part = parts[0].replace("aggregated_predictions_", "")
            model_name = model_part
            
            timestamps = re.findall(r'(\d+)', parts[1])
            if len(timestamps) < 2:
                print(f"Warning: Could not find timestamps in {file_name}")
                continue
            
            timestamp = int(timestamps[1])
            
#             print(f"Found model: {model_name}, timestamp: {timestamp}")
            
            if model_name not in latest_files or timestamp > latest_files[model_name]['timestamp']:
                latest_files[model_name] = {
                    'file_path': file_path,
                    'timestamp': timestamp
                }
        
#         print("\nSelected latest file for each model:")
#         for model, info in latest_files.items():
#             print(f"  {model}: {os.path.basename(info['file_path'])}")
        
        return [info['file_path'] for _, info in latest_files.items()]
    
    @staticmethod
    def load_all_model_predictions(model_predictions_dir):
        """Load all model predictions from aggregated files."""
        latest_files = DataLoader.get_latest_aggregated_files(model_predictions_dir)
        
        if not latest_files:
            print("No aggregated prediction files found. Cannot proceed.")
            return {}
        
        model_predictions = {}
        
        for file_path in latest_files:
            file_name = os.path.basename(file_path)
            
            parts = file_name.split("_base_")
            if len(parts) != 2:
                print(f"Warning: Unexpected filename format: {file_name}")
                continue
                
            model_name = parts[0].replace("aggregated_predictions_", "")
            
            try:
                df = pd.read_csv(file_path)
                
                df['model_name'] = model_name
                
                model_predictions[model_name] = df
                
#                 print(f"Successfully loaded {model_name} predictions with {len(df)} rows")
                
            except Exception as e:
                print(f"Error loading {file_path}: {e}")
        
#         print(f"Loaded {len(model_predictions)} model prediction sets")
        return model_predictions

    @staticmethod
    def load_validation_dataset(val_dataset_path):
        """Load the validation dataset."""
#         print(f"Loading validation dataset from {val_dataset_path}")
        val_df = pd.read_csv(val_dataset_path)
#         print(f"Loaded validation dataset with {len(val_df)} rows")
        
        val_df = DataLoader.process_validation_dataset(val_df)
        
        encounter_question_data = defaultdict(lambda: {
            'images': [],
            'data': None
        })
        
        for _, row in val_df.iterrows():
            encounter_id = row['encounter_id']
            base_qid = row['base_qid']
            key = (encounter_id, base_qid)
            
            if 'image_path' in row and row['image_path']:
                encounter_question_data[key]['images'].append(row['image_path'])
            elif 'image_id' in row and row['image_id']:
                image_path = os.path.join(Config.IMAGES_DIR, row['image_id'])
                encounter_question_data[key]['images'].append(image_path)
            
            if encounter_question_data[key]['data'] is None:
                encounter_question_data[key]['data'] = row.to_dict()
        
#         print(f"Created grouped validation dataset with {len(encounter_question_data)} unique encounter-question pairs")
        
        grouped_data = []
        for (encounter_id, base_qid), data in encounter_question_data.items():
            entry = data['data'].copy()
            entry['all_images'] = data['images']
            entry['encounter_id'] = encounter_id
            entry['base_qid'] = base_qid
            grouped_data.append(entry)
        
        return pd.DataFrame(grouped_data)
    
    @staticmethod
    def safe_convert_options(options_str):
        """Safely convert a string representation of a list to an actual list."""
        if not isinstance(options_str, str):
            return options_str
            
        try:
            return ast.literal_eval(options_str)
        except (SyntaxError, ValueError):
            if options_str.startswith('[') and options_str.endswith(']'):
                return [opt.strip().strip("'\"") for opt in options_str[1:-1].split(',')]
            elif ',' in options_str:
                return [opt.strip() for opt in options_str.split(',')]
            else:
                return [options_str]
    
    @staticmethod
    def process_validation_dataset(val_df):
        """Process and clean the validation dataset."""
        if 'options_en' in val_df.columns:
            val_df['options_en'] = val_df['options_en'].apply(DataLoader.safe_convert_options)
            
            def clean_options(options):
                if not isinstance(options, list):
                    return options
                    
                cleaned_options = []
                for opt in options:
                    if isinstance(opt, str):
                        cleaned_opt = opt.strip("'\" ").replace(" (please specify)", "")
                        cleaned_options.append(cleaned_opt)
                    else:
                        cleaned_options.append(str(opt).strip("'\" "))
                return cleaned_options
                
            val_df['options_en_cleaned'] = val_df['options_en'].apply(clean_options)
        
        if 'question_text' in val_df.columns:
            val_df['question_text_cleaned'] = val_df['question_text'].apply(
                lambda q: q.replace(" Please specify which affected area for each selection.", "") 
                          if isinstance(q, str) and "Please specify which affected area for each selection" in q 
                          else q
            )
            
            val_df['question_text_cleaned'] = val_df['question_text_cleaned'].apply(
                lambda q: re.sub(r'^\d+\s+', '', q) if isinstance(q, str) else q
            )
        
        if 'base_qid' not in val_df.columns and 'qid' in val_df.columns:
            val_df['base_qid'] = val_df['qid'].apply(
                lambda q: q.split('-')[0] if isinstance(q, str) and '-' in q else q
            )
        
#         print(val_df)
        return val_df

In [4]:
class DataProcessor:
    @staticmethod
    def create_query_context(row):
        """Create query context from validation data similar to the inference process."""
        question = row.get('question_text_cleaned', row.get('question_text', 'What do you see in this image?'))
        
        metadata = ""
        if 'question_type_en' in row:
            metadata += f"Type: {row['question_type_en']}"
            
        if 'question_category_en' in row:
            metadata += f", Category: {row['question_category_en']}"
        
        query_title = row.get('query_title_en', '')
        query_content = row.get('query_content_en', '')
        
        clinical_context = ""
        if query_title or query_content:
            clinical_context += "Background Clinical Information (to help with your analysis):\n"
            if query_title:
                clinical_context += f"{query_title}\n"
            if query_content:
                clinical_context += f"{query_content}\n"
        
        options = row.get('options_en_cleaned', row.get('options_en', ['Yes', 'No', 'Not mentioned']))
        if isinstance(options, list):
            options_text = ", ".join(options)
        else:
            options_text = str(options)
        
        query_text = (f"MAIN QUESTION TO ANSWER: {question}\n"
                     f"Question Metadata: {metadata}\n"
                     f"{clinical_context}"
                     f"Available Options (choose from these): {options_text}")
        
#         print(query_text)
        return query_text

In [5]:
class AgenticRAGData:
    def __init__(self, all_models_df, validation_df):
        self.all_models_df = all_models_df
        self.validation_df = validation_df
        
        self.model_predictions = {}
        for (encounter_id, base_qid), group in all_models_df.groupby(['encounter_id', 'base_qid']):
            self.model_predictions[(encounter_id, base_qid)] = group
        
        self.validation_data = {}
        for _, row in validation_df.iterrows():
            self.validation_data[(row['encounter_id'], row['base_qid'])] = row
    
    def get_combined_data(self, encounter_id, base_qid):
        """Retrieve combined data for a specific encounter and question."""
        model_preds = self.model_predictions.get((encounter_id, base_qid), None)
        
        val_data = self.validation_data.get((encounter_id, base_qid), None)
        
        if model_preds is None:
            print(f"No model predictions found for encounter {encounter_id}, question {base_qid}")
            return None
            
        if val_data is None:
            print(f"No validation data found for encounter {encounter_id}, question {base_qid}")
            return None
        
        if 'query_context' not in val_data:
            val_data['query_context'] = DataProcessor.create_query_context(val_data)
        
        model_predictions_dict = {}
        for _, row in model_preds.iterrows():
            model_name = row['model_name']
            
            model_predictions_dict[model_name] = self._process_model_predictions(row)
        
        return {
            'encounter_id': encounter_id,
            'base_qid': base_qid,
            'query_context': val_data['query_context'],
            'images': val_data.get('all_images', []),
            'options': val_data.get('options_en_cleaned', val_data.get('options_en', [])),
            'question_type': val_data.get('question_type_en', ''),
            'question_category': val_data.get('question_category_en', ''),
            'model_predictions': model_predictions_dict
        }
    
    def _process_model_predictions(self, row):
        """Process model predictions from row data."""
#         unique_preds = row.get('unique_predictions', [])
#         if isinstance(unique_preds, str):
#             try:
#                 unique_preds = ast.literal_eval(unique_preds)
#             except:
#                 unique_preds = [unique_preds]
                
#         raw_preds = row.get('all_raw_predictions', [])
#         if isinstance(raw_preds, str):
#             try:
#                 raw_preds = ast.literal_eval(raw_preds)
#             except:
#                 raw_preds = [raw_preds]
                
#         sorted_preds = row.get('all_sorted_predictions', [])
#         if isinstance(sorted_preds, str):
#             try:
#                 sorted_preds = ast.literal_eval(sorted_preds)
#             except:
#                 sorted_preds = [(str(raw_preds[0]), 1)] if raw_preds else []
        
        return {
            'model_prediction': row.get('combined_prediction', ''),
#             'unique_predictions': unique_preds,
#             'all_raw_predictions': raw_preds,
#             'all_sorted_predictions': sorted_preds
        }
    
    def get_all_encounter_question_pairs(self):
        """Return a list of all unique encounter_id, base_qid pairs."""
        return list(self.validation_data.keys())
    
    def get_sample_data(self, n=5):
        """Get a sample of combined data for n random encounter-question pairs."""
        import random
        
        all_pairs = self.get_all_encounter_question_pairs()
        sample_pairs = random.sample(all_pairs, min(n, len(all_pairs)))
        
        return [self.get_combined_data(encounter_id, base_qid) for encounter_id, base_qid in sample_pairs]

In [6]:
class AnalysisService:
    def __init__(self, api_key=None):
        if api_key is None:
            load_dotenv()
            api_key = os.getenv("API_KEY")
        
        self.client = genai.Client(api_key=api_key)
    
    def extract_dermatological_analysis(self, sample_data):
        """
        Extract structured analysis of images for an encounter.
        
        Args:
            sample_data: Dictionary containing encounter data with images
            
        Returns:
            Dictionary with structured dermatological analysis
        """
        encounter_id = sample_data['encounter_id']
        image_paths = sample_data['images']
        
        image_analyses = []
        
        structured_prompt = self._create_dermatology_prompt()
        
        for idx, img_path in enumerate(image_paths):
            analysis = self._analyze_single_image(
                img_path, 
                structured_prompt, 
                encounter_id, 
                idx, 
                len(image_paths)
            )
            image_analyses.append(analysis)
        
        aggregated_analysis = self._aggregate_analyses(image_analyses, encounter_id)
        
        return {
            "encounter_id": encounter_id,
            "image_count": len(image_paths),
            "individual_analyses": image_analyses,
            "aggregated_analysis": aggregated_analysis
        }
    
    def _create_dermatology_prompt(self):
        """Create the structured dermatology analysis prompt."""
        return """As dermatology specialist analyzing skin images, extract and structure all clinically relevant information from this dermatological image.

Organize your response in a JSON dictionary:

1. SIZE: Approximate dimensions of lesions/affected areas, size comparison (thumbnail, palm, larger), Relative size comparisons for multiple lesions
2. SITE_LOCATION: Visible body parts in the image, body areas showing lesions/abnormalities, Specific anatomical locations affected
3. SKIN_DESCRIPTION: Lesion morphology (flat, raised, depressed), Texture of affected areas, Surface characteristics (scales, crust, fluid), Appearance of lesion boundaries
4. LESION_COLOR: Predominant color(s) of affected areas, Color variations within lesions, Color comparison to normal skin, Color distribution patterns
5. LESION_COUNT: Number of distinct lesions/affected areas, Single vs multiple presentation, Distribution pattern if multiple, Any counting limitations
6. EXTENT: How widespread the condition appears, Localized vs widespread assessment, Approximate percentage of visible skin affected, Limitations in determining full extent
7. TEXTURE: Expected tactile qualities, Smooth vs rough assessment, Notable textural features, Texture consistency across affected areas
8. ONSET_INDICATORS: Visual clues about condition duration, Acute vs chronic presentation features, Healing/progression/chronicity signs, Note: precise timing cannot be determined from images
9. ITCH_INDICATORS: Scratch marks/excoriations/trauma signs, Features associated with itchy conditions, Pruritic vs non-pruritic visual indicators, Note: sensation cannot be directly observed
10. OVERALL_IMPRESSION: Brief description (1-2 sentences), Key diagnostic features, Potential diagnoses (2-3)

Be concise and use medical terminology where appropriate. If information for a section is cannot be determined, state "Cannot determine from image".
"""
    
    def _analyze_single_image(self, img_path, prompt, encounter_id, idx, total_images):
        """Analyze a single dermatological image."""
        try:
            image = Image.open(img_path)
            
#             print(f"Analyzing image {idx+1}/{total_images} for encounter {encounter_id}")
            
            response = self.client.models.generate_content(
                model=Config.GEMINI_MODEL,
                contents=[prompt, image]
            )
            
            analysis_text = response.text
#             print(f"Analysis text received (length: {len(analysis_text)})")
            
            structured_analysis = self._parse_json_response(analysis_text)
#             print(f"Successfully parsed structured analysis for image {idx+1}")
            
            return {
                "image_index": idx + 1,
                "image_path": os.path.basename(img_path),
                "structured_analysis": structured_analysis
            }
            
        except Exception as e:
            print(f"Error analyzing image {img_path}: {str(e)}")
            return {
                "image_index": idx + 1,
                "image_path": os.path.basename(img_path),
                "error": str(e)
            }
    
    def _parse_json_response(self, text):
        """Parse JSON from LLM response."""
        cleaned_text = text
        if "```json" in cleaned_text:
            cleaned_text = cleaned_text.split("```json")[1]
        if "```" in cleaned_text:
            cleaned_text = cleaned_text.split("```")[0]
        
        try:
            return json.loads(cleaned_text)
        except json.JSONDecodeError:
            print(f"Warning: Could not parse as JSON")
            return {"parse_error": "Could not parse as JSON", "raw_text": text}
    
    def _aggregate_analyses(self, image_analyses, encounter_id):
        """Aggregate structured analyses from multiple images."""
        valid_analyses = [a for a in image_analyses if "error" not in a and "structured_analysis" in a]
#         print(f"Aggregating {len(valid_analyses)} valid structured analyses for encounter {encounter_id}")
        
        if not valid_analyses:
            return {
                "error": "No valid analyses to aggregate",
                "message": "Unable to generate aggregated analysis due to errors in individual analyses."
            }
        
        if len(valid_analyses) == 1:
            return valid_analyses[0]["structured_analysis"]
        
        analysis_jsons = []
        for analysis in valid_analyses:
            analysis_json = json.dumps(analysis["structured_analysis"])
            analysis_jsons.append(f"Image {analysis['image_index']} ({analysis['image_path']}): {analysis_json}")
        
        aggregation_prompt = self._create_aggregation_prompt(analysis_jsons)
        
        try:
            response = self.client.models.generate_content(
                model=Config.GEMINI_MODEL,
                contents=[aggregation_prompt]
            )
            
            aggregation_text = response.text
#             print(f"Aggregated analysis received (length: {len(aggregation_text)})")
            
            aggregated_analysis = self._parse_json_response(aggregation_text)
#             print("Successfully parsed aggregated analysis")
            
#             print(aggregated_analysis)
            
            return aggregated_analysis
            
        except Exception as e:
            print(f"Error creating aggregated analysis for encounter {encounter_id}: {str(e)}")
            return {
                "error": str(e),
                "aggregation_error": "Failed to generate aggregated analysis"
            }
    
    def _create_aggregation_prompt(self, analysis_jsons):
        """Create a prompt for aggregating multiple image analyses."""
        return f"""As dermatology specialist reviewing multiple skin image analyses for the same patient, combine these analyses and organize your response in a JSON dictionary:

1. SIZE: Approximate dimensions of lesions/affected areas, size comparison (thumbnail, palm, larger), Relative size comparisons for multiple lesions
2. SITE_LOCATION: Visible body parts in the image, body areas showing lesions/abnormalities, Specific anatomical locations affected
3. SKIN_DESCRIPTION: Lesion morphology (flat, raised, depressed), Texture of affected areas, Surface characteristics (scales, crust, fluid), Appearance of lesion boundaries
4. LESION_COLOR: Predominant color(s) of affected areas, Color variations within lesions, Color comparison to normal skin, Color distribution patterns
5. LESION_COUNT: Number of distinct lesions/affected areas, Single vs multiple presentation, Distribution pattern if multiple, Any counting limitations
6. EXTENT: How widespread the condition appears, Localized vs widespread assessment, Approximate percentage of visible skin affected, Limitations in determining full extent
7. TEXTURE: Expected tactile qualities, Smooth vs rough assessment, Notable textural features, Texture consistency across affected areas
8. ONSET_INDICATORS: Visual clues about condition duration, Acute vs chronic presentation features, Healing/progression/chronicity signs, Note: precise timing cannot be determined from images
9. ITCH_INDICATORS: Scratch marks/excoriations/trauma signs, Features associated with itchy conditions, Pruritic vs non-pruritic visual indicators, Note: sensation cannot be directly observed
10. OVERALL_IMPRESSION: Brief description (1-2 sentences), Key diagnostic features, Potential diagnoses (2-3)
    
{' '.join(analysis_jsons)}
"""
    
    def extract_clinical_context(self, sample_data):
        """
        Extract structured clinical information from an encounter's query context.
        
        Args:
            sample_data: Dictionary containing encounter data with query_context
            
        Returns:
            Dictionary with structured clinical information
        """
        encounter_id = sample_data['encounter_id']
        
        query_context = sample_data['query_context']
        
        clinical_text = self._extract_clinical_text(query_context)
        
        if not clinical_text:
            return {
                "encounter_id": encounter_id,
                "clinical_summary": "No clinical information available"
            }
        
        prompt = self._create_clinical_context_prompt(clinical_text)
        
        try:
            response = self.client.models.generate_content(
                model=Config.GEMINI_MODEL,
                contents=[prompt]
            )
            
            return {
                "encounter_id": encounter_id,
                "raw_clinical_text": clinical_text,
                "structured_clinical_context": response.text
            }
                
        except Exception as e:
            print(f"Error extracting clinical context for encounter {encounter_id}: {str(e)}")
            return {
                "encounter_id": encounter_id,
                "raw_clinical_text": clinical_text,
                "error": str(e)
            }
    
    def _extract_clinical_text(self, query_context):
        """Extract clinical text from query context."""
        clinical_lines = []
        capturing = False
        for line in query_context.split('\n'):
            if "Background Clinical Information" in line:
                capturing = True
                continue
            elif "Available Options" in line:
                capturing = False
            elif capturing:
                clinical_lines.append(line)
        
        return "\n".join(clinical_lines).strip()
    
    def _create_clinical_context_prompt(self, clinical_text):
        """Create prompt for extracting structured clinical information."""
        return f"""You are a dermatology specialist analyzing patient information. 
Extract and structure all clinically relevant information from this patient description:

{clinical_text}

Organize your response in the following JSON structure:

1. DEMOGRAPHICS: Age, sex, and any other demographic data
2. SITE_LOCATION: Body parts affected by the condition as described in the text
3. SKIN_DESCRIPTION: Any mention of lesion morphology (flat, raised, depressed), texture, surface characteristics (scales, crust, fluid), appearance of lesion boundaries
4. LESION_COLOR: Any description of color(s) of affected areas, color variations, comparison to normal skin
5. LESION_COUNT: Any information about number of lesions, single vs multiple presentation, distribution pattern
6. EXTENT: How widespread the condition appears based on the description, localized vs widespread
7. TEXTURE: Any description of tactile qualities, smooth vs rough, notable textural features
8. ONSET_INDICATORS: Information about onset, duration, progression, or evolution of symptoms
9. ITCH_INDICATORS: Mentions of scratching, itchiness, or other sensory symptoms
10. OTHER_SYMPTOMS: Any additional symptoms mentioned (pain, burning, etc.)
11. TRIGGERS: Identified factors that worsen/improve the condition
12. HISTORY: Relevant past medical history or previous treatments
13. DIAGNOSTIC_CONSIDERATIONS: Any mentioned or suggested diagnoses in the text

Be concise and use medical terminology where appropriate. If information for a section is 
not available, indicate "Not mentioned".
"""
    
    def apply_reasoning_layer(self, encounter_id, base_qid, image_analysis, clinical_context, sample_data):
        """
        Apply a reasoning layer to determine the best answer(s) for a specific encounter-question pair.
        
        Args:
            encounter_id: The encounter ID
            base_qid: The question ID
            image_analysis: Structured image analysis for this encounter
            clinical_context: Structured clinical context for this encounter
            sample_data: Combined data for this encounter-question pair
        
        Returns:
            Dictionary with reasoning and final answer(s)
        """
        question_text = sample_data['query_context'].split("MAIN QUESTION TO ANSWER:")[1].split("\n")[0].strip()
        options = sample_data['options']
        question_type = sample_data['question_type']
        model_predictions = sample_data['model_predictions']
        
        model_prediction_text = self._format_model_predictions(model_predictions)
        
        prompt = self._create_reasoning_prompt(
            question_text, 
            question_type, 
            options, 
            image_analysis, 
            clinical_context, 
            model_prediction_text
        )
        
#         print("\n==== REASONING PROMPT ====")
#         print(prompt)
#         print("==========================\n")
        
        try:
            response = self.client.models.generate_content(
                model=Config.GEMINI_MODEL,
                contents=[prompt]
            )
            
            reasoning_text = response.text
            
#             print("\n==== RAW LLM RESPONSE ====")
#             print(reasoning_text)
#             print("===========================\n")
            
            reasoning_result = self._parse_json_response(reasoning_text)
            
            validated_answer = self._validate_answer(reasoning_result.get('answer', ''), options)
            reasoning_result['validated_answer'] = validated_answer
            
#             print("\n==== PROCESSED REASONING RESULT ====")
#             print(json.dumps(reasoning_result, indent=2))
#             print("====================================\n")
            
            return reasoning_result
            
        except Exception as e:
            print(f"Error applying reasoning layer for {encounter_id}, {base_qid}: {str(e)}")
            return {
                "reasoning": f"Error: {str(e)}",
                "answer": "Not mentioned",
                "validated_answer": "Not mentioned",
                "error": str(e)
            }
    
    def _format_model_predictions(self, model_predictions):
        """Format model predictions for the prompt."""
        model_prediction_text = ""
        for model_name, predictions in model_predictions.items():
            combined_pred = predictions.get('model_prediction', '')
            if isinstance(combined_pred, float) and pd.isna(combined_pred):
                combined_pred = "No prediction"
            model_prediction_text += f"- {model_name}: {combined_pred}\n"
#         print(model_prediction_text)
        return model_prediction_text

    def _create_reasoning_prompt(self, question_text, question_type, options, image_analysis, clinical_context, model_prediction_text):
        """Create a prompt for the reasoning layer."""
#         print("\n--- Question:", question_text)
#         print("--- Question Type:", question_type)
#         print("--- Options:", ", ".join(options))
#         print("--- Image Analysis:", json.dumps(image_analysis.get('aggregated_analysis', {}), indent=2)[:300] + "..." if len(json.dumps(image_analysis.get('aggregated_analysis', {}))) > 300 else json.dumps(image_analysis.get('aggregated_analysis', {})))
#         print("--- Clinical Context:", clinical_context.get('structured_clinical_context', '')[:300] + "..." if len(clinical_context.get('structured_clinical_context', '')) > 300 else clinical_context.get('structured_clinical_context', ''))
#         print("--- Model Predictions:", model_prediction_text)

        specialized_guidance = ""
        include_clinical_context = True

        multiple_answers_allowed = question_type in ["Site Location", "Size", "Skin Description"]

        if multiple_answers_allowed:
            task_description = """Based on all the evidence above, determine the most accurate answer(s) to the question. Your task is to:
    1. Analyze the evidence from the image analysis{0}
    2. Consider the model predictions, noting any consensus or disagreement, but maintain your critical judgment
    3. Provide a brief reasoning for your conclusion
    4. Select the final answer(s) from the available options

    If selecting multiple answers is appropriate, provide them in a comma-separated list. If no answer can be determined, select "Not mentioned".""".format(' and clinical context' if include_clinical_context else '')
        else:
            task_description = """Based on all the evidence above, determine the SINGLE most accurate answer to the question. Your task is to:
    1. Analyze the evidence from the image analysis{0}
    2. Consider the model predictions, noting any consensus or disagreement, but maintain your critical judgment
    3. Provide a brief reasoning for your conclusion
    4. Select ONLY ONE answer option that is most accurate

    For this question type, you must select ONLY ONE option as your answer. If no answer can be determined, select "Not mentioned".""".format(' and clinical context' if include_clinical_context else '')

        if question_type == "Size" and all(option in ", ".join(options) for option in ["size of thumb nail", "size of palm", "larger area"]):
            specialized_guidance = """
    SPECIALIZED GUIDANCE FOR SIZE ASSESSMENT:
    When answering this size-related question, interpret the options as follows:
    - "size of thumb nail": Individual lesions or affected areas approximately 1-2 cm in diameter
    - "size of palm": Affected areas larger than the size of a thumb nail and roughly the size of a palm (approximately 1% of body surface area), which may include multiple smaller lesions across a region
    - "larger area": Widespread involvement significantly larger than a palm, affecting a substantial portion(s) of the body

    IMPORTANT: For cases with multiple small lesions that are visible in the images, but without extensive widespread involvement across large body regions, "size of palm" is likely the most appropriate answer.

    Base your assessment PRIMARILY on the current state shown in the IMAGES and their analysis, not on descriptions of progression or potential future spread mentioned in the clinical context. Prioritize what you can directly observe in the image analysis over clinical descriptions.
    """
            include_clinical_context = False

        elif question_type == "Lesion Color" and "combination" in ", ".join(options):
            specialized_guidance = """
    SPECIALIZED GUIDANCE FOR LESION COLOR:
    When answering color-related questions, pay careful attention to whether there are multiple distinct colors present across the affected areas. "Combination" would be appropriate when different lesions display different colors (e.g., some lesions appear red while others appear white), or when individual lesions show mixed or varied coloration patterns.
    """

        base_prompt = f"""You are a medical expert analyzing dermatological images. Use the provided evidence to determine the most accurate answer(s) for the following question:

    QUESTION: {question_text}
    QUESTION TYPE: {question_type}
    OPTIONS: {", ".join(options)}

    IMAGE ANALYSIS:
    {json.dumps(image_analysis['aggregated_analysis'], indent=2)}
    """

        if include_clinical_context:
            base_prompt += f"""
    CLINICAL CONTEXT:
    {clinical_context['structured_clinical_context']}
    """
        else:
            base_prompt += """
    NOTE: For this question type, the analysis is based primarily on image evidence rather than clinical descriptions.
    """

        return base_prompt + f"""
    MODEL PREDICTIONS:
    {model_prediction_text}

    {specialized_guidance}

    IMPORTANT: While multiple model predictions are provided, be aware that these predictions can be inaccurate or inconsistent. Do not assume majority agreement equals correctness. Evaluate the evidence critically and independently from these predictions. Your job is to determine the correct answer based primarily on the image analysis, treating model predictions as secondary suggestions that may contain errors.

    {task_description}

    Format your response as a JSON object with these fields:
    1. "reasoning": Your step-by-step reasoning process
    2. "answer": Your final answer(s) as a single string or comma-separated list of options

    When providing your answer, strictly adhere to the available options and only select from them.
    """
    
    def _validate_answer(self, answer, options):
        """Validate the answer against available options."""
        answer = answer.lower()
        valid_answers = []
        
        if ',' in answer:
            answer_parts = [part.strip() for part in answer.split(',')]
            for part in answer_parts:
                for option in options:
                    if part == option.lower():
                        valid_answers.append(option)
        else:
            for option in options:
                if answer == option.lower():
                    valid_answers.append(option)
        
        if not valid_answers:
            if "not mentioned" in answer:
                valid_answers = ["Not mentioned"]
            else:
                valid_answers = ["Not mentioned"]
        
#         print(valid_answers)
        return ", ".join(valid_answers)

In [7]:
class DermatologyPipeline:
    def __init__(self, analysis_service):
        self.analysis_service = analysis_service
    
    def process_single_encounter(self, agentic_data, encounter_id):
        """
        Process a single encounter with all its questions using the reasoning layer.
        
        Args:
            agentic_data: AgenticRAGData instance containing all encounter data
            encounter_id: The specific encounter ID to process
            
        Returns:
            Dictionary with all questions processed with reasoning for this encounter
        """
#         print("Printing everything in Dermatology Pipeline now")
        
        all_pairs = agentic_data.get_all_encounter_question_pairs()
        encounter_pairs = [pair for pair in all_pairs if pair[0] == encounter_id]
        
        if not encounter_pairs:
            print(f"No data found for encounter {encounter_id}")
            return None
        
#         print(f"Processing {len(encounter_pairs)} questions for encounter {encounter_id}")
        
        encounter_results = {encounter_id: {}}
        
#         print(f"Computing image analysis for {encounter_id}")
        sample_data = agentic_data.get_combined_data(encounter_pairs[0][0], encounter_pairs[0][1])
#         print(sample_data)
        image_analysis = self.analysis_service.extract_dermatological_analysis(sample_data)
#         print(image_analysis)
        
#         print(f"Computing clinical context for {encounter_id}")
        clinical_context = self.analysis_service.extract_clinical_context(sample_data)
#         print(clinical_context)
        
        for i, (encounter_id, base_qid) in enumerate(encounter_pairs):
#             print(f"Processing question {i+1}/{len(encounter_pairs)}: {base_qid}")
            
            sample_data = agentic_data.get_combined_data(encounter_id, base_qid)
#             print(sample_data)
            if not sample_data:
                print(f"Warning: No data found for {encounter_id}, {base_qid}")
                continue
            
#             print(f"Applying reasoning layer for {encounter_id}, {base_qid}")
            reasoning_result = self.analysis_service.apply_reasoning_layer(
                encounter_id,
                base_qid,
                image_analysis,
                clinical_context,
                sample_data
            )
#             print(reasoning_result)
            
            encounter_results[encounter_id][base_qid] = {
                "query_context": sample_data['query_context'],
                "options": sample_data['options'],
                "model_predictions": sample_data['model_predictions'],
                "reasoning_result": reasoning_result,
                "final_answer": reasoning_result.get('validated_answer', 'Not mentioned')
            }
#             print(encounter_results)
        
        output_file = os.path.join(Config.OUTPUT_DIR, f"reasoning_results_{encounter_id}.json")
        with open(output_file, "w") as f:
            json.dump(encounter_results, f, indent=2)
        
#         print(f"Processed all {len(encounter_pairs)} questions for encounter {encounter_id}")
        return encounter_results
    
    def format_results_for_evaluation(self, encounter_results, output_file):
        """Format results for official evaluation."""
        QIDS = [
            "CQID010-001",
            "CQID011-001", "CQID011-002", "CQID011-003", "CQID011-004", "CQID011-005", "CQID011-006",
            "CQID012-001", "CQID012-002", "CQID012-003", "CQID012-004", "CQID012-005", "CQID012-006",
            "CQID015-001",
            "CQID020-001", "CQID020-002", "CQID020-003", "CQID020-004", "CQID020-005", 
            "CQID020-006", "CQID020-007", "CQID020-008", "CQID020-009",
            "CQID025-001",
            "CQID034-001",
            "CQID035-001",
            "CQID036-001",
        ]
        
        qid_variants = {}
        for qid in QIDS:
            base_qid, variant = qid.split('-')
            if base_qid not in qid_variants:
                qid_variants[base_qid] = []
            qid_variants[base_qid].append(qid)
        
        required_base_qids = set(qid.split('-')[0] for qid in QIDS)
        
        formatted_predictions = []
        for encounter_id, questions in encounter_results.items():
            encounter_base_qids = set(questions.keys())
            if not required_base_qids.issubset(encounter_base_qids):
                print(f"Skipping encounter {encounter_id} - missing required questions")
                continue
            
            pred_entry = {'encounter_id': encounter_id}
            
            for base_qid, question_data in questions.items():
                if base_qid not in qid_variants:
                    continue
                
                final_answer = question_data['final_answer']
                options = question_data['options']
                
                not_mentioned_index = self._find_not_mentioned_index(options)
                
                self._process_answers(
                    pred_entry, 
                    base_qid, 
                    final_answer, 
                    options, 
                    qid_variants, 
                    not_mentioned_index
                )
            
            formatted_predictions.append(pred_entry)
#             print(formatted_predictions)
        
        with open(output_file, 'w') as f:
            json.dump(formatted_predictions, f, indent=2)
        
#         print(f"Formatted predictions saved to {output_file} ({len(formatted_predictions)} complete encounters)")
#         print(formatted_predictions)
        return formatted_predictions
    
    def _find_not_mentioned_index(self, options):
        """Find the index of 'Not mentioned' in options."""
        for i, opt in enumerate(options):
            if opt.lower() == "not mentioned":
                return i
        return len(options) - 1
    
    def _process_answers(self, pred_entry, base_qid, final_answer, options, qid_variants, not_mentioned_index):
        """Process answers and add to prediction entry."""
        if ',' in final_answer:
            answer_parts = [part.strip() for part in final_answer.split(',')]
            answer_indices = []
            
            for part in answer_parts:
                found = False
                for i, opt in enumerate(options):
                    if part.lower() == opt.lower():
                        answer_indices.append(i)
                        found = True
                        break
                
                if not found:
                    answer_indices.append(not_mentioned_index)
            
            available_variants = qid_variants[base_qid]
            
            for i, idx in enumerate(answer_indices):
                if i < len(available_variants):
                    pred_entry[available_variants[i]] = idx
            
            for i in range(len(answer_indices), len(available_variants)):
                pred_entry[available_variants[i]] = not_mentioned_index
            
        else:
            answer_index = not_mentioned_index
            
            for i, opt in enumerate(options):
                if final_answer.lower() == opt.lower():
                    answer_index = i
                    break
            
            pred_entry[qid_variants[base_qid][0]] = answer_index
            
            if len(qid_variants[base_qid]) > 1:
                for i in range(1, len(qid_variants[base_qid])):
                    pred_entry[qid_variants[base_qid][i]] = not_mentioned_index

In [8]:
def run_all_encounters_pipeline():
    """Run the pipeline for all available encounters and combine the results."""
    model_predictions_dict = DataLoader.load_all_model_predictions(Config.MODEL_PREDICTIONS_DIR)
    all_models_df = pd.concat(model_predictions_dict.values(), ignore_index=True)
    validation_df = DataLoader.load_validation_dataset(Config.VAL_DATASET_PATH)
    agentic_data = AgenticRAGData(all_models_df, validation_df)
    
    all_pairs = agentic_data.get_all_encounter_question_pairs()
    unique_encounter_ids = sorted(list(set(pair[0] for pair in all_pairs)))
    print(f"Found {len(unique_encounter_ids)} unique encounters to process")
    
    analysis_service = AnalysisService()
    
    pipeline = DermatologyPipeline(analysis_service)
    
    all_encounter_results = {}
    for i, encounter_id in enumerate(unique_encounter_ids):
        print(f"Processing encounter {i+1}/{len(unique_encounter_ids)}: {encounter_id}...")
        encounter_results = pipeline.process_single_encounter(agentic_data, encounter_id)
        if encounter_results:
            all_encounter_results.update(encounter_results)
        
        if (i+1) % 5 == 0 or (i+1) == len(unique_encounter_ids):
            timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
            intermediate_output_file = os.path.join(
                Config.OUTPUT_DIR, 
                f"intermediate_results_{i+1}_of_{len(unique_encounter_ids)}_{timestamp}.json"
            )
            with open(intermediate_output_file, 'w') as f:
                json.dump(all_encounter_results, f, indent=2)
            print(f"Saved intermediate results after processing {i+1} encounters")
    
    timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    output_file = os.path.join(
        Config.OUTPUT_DIR, 
        f"data_cvqa_sys_reasoned_all_{timestamp}.json"
    )
    
    formatted_predictions = pipeline.format_results_for_evaluation(all_encounter_results, output_file)
    
    print(f"Processed {len(formatted_predictions)} encounters successfully")
    return formatted_predictions

In [9]:
# if __name__ == "__main__":
#     formatted_predictions = run_all_encounters_pipeline()
#     print(f"Total complete encounters processed: {len(formatted_predictions)}")

In [12]:
def run_single_encounter_pipeline(encounter_id):
    """Run the pipeline for a single encounter."""
    model_predictions_dict = DataLoader.load_all_model_predictions(Config.MODEL_PREDICTIONS_DIR)
    print(model_predictions_dict)
    
    all_models_df = pd.concat(model_predictions_dict.values(), ignore_index=True)
    
    validation_df = DataLoader.load_validation_dataset(Config.VAL_DATASET_PATH)
#     print(validation_df.head())
    agentic_data = AgenticRAGData(all_models_df, validation_df)
    
    analysis_service = AnalysisService()
    
    pipeline = DermatologyPipeline(analysis_service)
    encounter_results = pipeline.process_single_encounter(agentic_data, encounter_id)
    
    output_file = os.path.join(
        Config.OUTPUT_DIR, 
        f"data_cvqa_sys_reasoned_{encounter_id}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    )
    formatted_predictions = pipeline.format_results_for_evaluation(encounter_results, output_file)
    
    return formatted_predictions

In [13]:
if __name__ == "__main__":
    encounter_id = "ENC00853"
    formatted_predictions = run_single_encounter_pipeline(encounter_id)
    print(f"Processed encounter {encounter_id} with {len(formatted_predictions)} prediction entries")

No aggregated prediction files found. Cannot proceed.
{}


ValueError: No objects to concatenate

In [None]:
formatted_predictions