## 📦 Setup and Imports

In [None]:
import sys
import os
from pathlib import Path
import torch
import torch.nn.functional as F
from PIL import Image
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import yaml
import warnings
from torchvision import transforms
from tqdm.notebook import tqdm

# Treat the current working directory as the project root and place it at the
# front of the module search path.
project_root = Path.cwd()
sys.path[:0] = [str(project_root)]

# Import your models
from src.models.text_model import create_text_model
from src.models.multimodal_model import create_multimodal_model

# Notebook-wide plotting defaults. All warnings are silenced from here on.
sns.set_style('whitegrid')
plt.rcParams.update({'figure.figsize': (12, 6)})
warnings.filterwarnings('ignore')

# Short environment report.
print("  Setup complete!")
print("  Project root:", project_root)
print("  Device available:", 'CUDA' if torch.cuda.is_available() else 'CPU')

## 🤖 Load Your Trained Models

In [None]:
class VQAPredictor:
    """Inference wrapper running a text-only and a multimodal VQA model side by side.

    Construction reads ``config.yaml`` and ``answers.txt`` from the current
    working directory, builds both models through the project factories and
    tries to load their checkpoints from ``checkpoints/``.

    Attributes:
        device: ``cuda`` when available, otherwise ``cpu``.
        config: Parsed contents of ``config.yaml``.
        answers: Sorted, de-duplicated answer strings from ``answers.txt``.
        answer_to_idx / idx_to_answer: Mappings between answers and class ids.
        num_classes: Number of distinct answers.
        vocab / vocab_size: Word-to-index table used to encode questions.
        text_model / multimodal_model: The two networks, in eval mode.
        models_loaded: True only when both checkpoints were loaded.
        image_transform: Resize to 224x224, tensor conversion, normalization.
    """

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load config
        with open('config.yaml', 'r', encoding='utf-8') as f:
            self.config = yaml.safe_load(f)

        # Load answer vocabulary (one answer per line)
        with open('answers.txt', 'r', encoding='utf-8') as f:
            answers = [line.strip() for line in f]

        # Sorting the unique answers gives a deterministic answer <-> index mapping.
        # NOTE(review): this mapping must match the one used at training time - confirm.
        self.answers = sorted(set(answers))
        self.answer_to_idx = {ans: idx for idx, ans in enumerate(self.answers)}
        self.idx_to_answer = {idx: ans for ans, idx in self.answer_to_idx.items()}
        self.num_classes = len(self.answers)

        # Simple vocab for inference (you can expand this)
        # NOTE(review): the checkpoints presumably expect the training vocabulary
        # and its size; verify that this hand-written table matches it.
        self.vocab = {
            '<PAD>': 0, '<UNK>': 1, '<SOS>': 2, '<EOS>': 3,
            'what': 4, 'is': 5, 'the': 6, 'in': 7, 'of': 8, 'a': 9, 'and': 10,
            'to': 11, 'this': 12, 'image': 13, 'shown': 14, 'visible': 15,
            'organ': 16, 'tissue': 17, 'cell': 18, 'structure': 19, 'abnormal': 20,
            'normal': 21, 'pathology': 22, 'medical': 23, 'diagnosis': 24, 'disease': 25
        }
        self.vocab_size = len(self.vocab)

        # Load models
        self._load_models()

        # Image preprocessing
        self.image_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])
        ])

        print(f"  VQA Predictor initialized!")
        print(f"  Device: {self.device}")
        print(f"  Answer classes: {self.num_classes}")
        print(f"  Vocabulary size: {self.vocab_size}")

    def _load_models(self):
        """Build both models and load their checkpoints.

        A missing checkpoint file is reported on stdout and leaves the models
        with their freshly initialised weights; ``models_loaded`` stays False
        in that case so callers can tell the difference.
        """
        self.models_loaded = False

        # Text-only model
        self.text_model = create_text_model(
            vocab_size=self.vocab_size,
            embedding_dim=self.config['text']['embedding_dim'],
            hidden_dim=self.config['model']['baseline']['hidden_dim'],
            num_classes=self.num_classes,
            dropout=self.config['model']['baseline']['dropout']
        ).to(self.device)

        # Multimodal model
        self.multimodal_model = create_multimodal_model(
            model_type='concat',
            vocab_size=self.vocab_size,
            num_classes=self.num_classes,
            embedding_dim=self.config['text']['embedding_dim'],
            text_hidden_dim=self.config['model']['baseline']['hidden_dim'],
            fusion_hidden_dim=self.config['model']['baseline']['hidden_dim'],
            dropout=self.config['model']['baseline']['dropout']
        ).to(self.device)

        # Switch to inference mode before touching the checkpoints so the
        # models are never left in training mode when loading fails below.
        self.text_model.eval()
        self.multimodal_model.eval()

        # Load checkpoints
        # NOTE(review): assumes each .pth file holds a bare state_dict - confirm
        # against the training script.
        try:
            text_checkpoint = torch.load(
                'checkpoints/text_baseline_lstm_notebook/best_model.pth',
                map_location=self.device
            )
            multimodal_checkpoint = torch.load(
                'checkpoints/multimodal_concat/best_model.pth',
                map_location=self.device
            )
        except FileNotFoundError as e:
            print(f"    Error loading models: {e}")
            print("  Make sure you have trained both models first!")
            return

        self.text_model.load_state_dict(text_checkpoint)
        self.multimodal_model.load_state_dict(multimodal_checkpoint)
        self.models_loaded = True

        print("    Text-only model loaded")
        print("    Multimodal model loaded")

    def encode_question(self, question: str, max_length: int = 32) -> torch.Tensor:
        """Encode a question as a ``(1, max_length)`` tensor of vocabulary ids.

        The question is lower-cased and split on whitespace; words missing from
        ``self.vocab`` map to ``<UNK>``. Shorter questions are right-padded with
        ``<PAD>``, longer ones are truncated.
        """
        unk_id = self.vocab.get('<UNK>', 1)
        pad_id = self.vocab.get('<PAD>', 0)

        indices = [self.vocab.get(word, unk_id) for word in question.lower().split()]

        # Pad or truncate to exactly max_length entries
        indices = indices[:max_length]
        indices += [pad_id] * (max_length - len(indices))

        return torch.tensor(indices, dtype=torch.long).unsqueeze(0).to(self.device)

    def _top_k_answers(self, logits, top_k):
        """Turn a ``(1, num_classes)`` logit tensor into ranked answer dicts.

        ``top_k`` is capped at the number of classes because ``torch.topk``
        raises when asked for more entries than the dimension holds.
        """
        probs = F.softmax(logits, dim=1)
        k = min(top_k, probs.size(1))
        top = torch.topk(probs, k, dim=1)
        return [
            {'answer': self.idx_to_answer[idx.item()], 'confidence': prob.item()}
            for prob, idx in zip(top.values[0], top.indices[0])
        ]

    def predict(self, image_path=None, question: str = "", top_k: int = 5) -> dict:
        """Make prediction with both models.

        Args:
            image_path: Optional path to an image. The multimodal model only
                runs when the path is given and exists on disk.
            question: Question text.
            top_k: Number of ranked answers to return per model (capped at the
                number of answer classes).

        Returns:
            A dict with ``'question'`` and ``'text_only'`` (list of
            ``{'answer', 'confidence'}`` dicts). When the multimodal branch
            succeeds it also contains ``'multimodal'``, ``'image'`` (the RGB
            PIL image) and ``'image_path'``. Failures in the multimodal branch
            are printed and leave those keys out.
        """
        with torch.no_grad():
            # Encode question
            question_tensor = self.encode_question(question)

            # Text-only prediction
            text_logits = self.text_model(question_tensor)
            results = {
                'question': question,
                'text_only': self._top_k_answers(text_logits, top_k)
            }

            # Multimodal prediction (if image provided)
            if image_path and Path(image_path).exists():
                try:
                    # convert() returns an independent copy, so the file handle
                    # can be released as soon as the pixels are decoded.
                    with Image.open(image_path) as raw_image:
                        image = raw_image.convert('RGB')
                    image_tensor = self.image_transform(image).unsqueeze(0).to(self.device)

                    # Make prediction
                    mm_logits = self.multimodal_model(question_tensor, image_tensor)

                    results['multimodal'] = self._top_k_answers(mm_logits, top_k)
                    results['image'] = image
                    results['image_path'] = image_path

                except Exception as e:
                    # Best effort: keep the text-only result and report the failure.
                    print(f"Error loading image {image_path}: {e}")

            return results

# Initialize the shared predictor instance that quick_test() and
# evaluate_batch() call; construction reads config.yaml and answers.txt and
# attempts to load both checkpoints.
predictor = VQAPredictor()

## 🎯 Quick Test - Single Prediction

In [None]:
def quick_test(image_path, question, top_k=3):
    """Predict an answer for one question and render the outcome as a figure.

    Calls the module-level ``predictor`` and draws a 16x6 figure. When the
    prediction result carries an image, three panels are drawn: the image, the
    text-only ranking, and the multimodal ranking with an agree/disagree note.
    Otherwise a single panel lists the text-only ranking.

    Args:
        image_path: Path of the image to use, or None for text-only mode.
        question: Question text passed to the predictor.
        top_k: Number of ranked answers requested and displayed.

    Returns:
        The result dict produced by ``predictor.predict``.
    """
    result = predictor.predict(image_path, question, top_k=top_k)

    def ranked(predictions, indent=""):
        # One "<rank>. <answer> (<confidence>)" line per prediction.
        return "".join(
            f"{indent}{rank}. {item['answer']} ({item['confidence']:.3f})\n"
            for rank, item in enumerate(predictions[:top_k], start=1)
        )

    def boxed(color):
        # Shared rounded text-box style for the prediction panels.
        return dict(boxstyle="round,pad=0.3", facecolor=color, alpha=0.5)

    plt.figure(figsize=(16, 6))

    if 'image' not in result:
        # Text-only mode: a single axis holding the ranking.
        plt.axis('off')
        content = (
            f"Question: {question}\n\n"
            + "  Text-only predictions:\n"
            + ranked(result['text_only'], indent="  ")
            + "\n  No image provided"
        )
        plt.text(0.1, 0.5, content, fontsize=12, verticalalignment='center')
    else:
        # Panel 1: the image itself.
        image_ax = plt.subplot(1, 3, 1)
        image_ax.imshow(result['image'])
        image_ax.axis('off')
        image_ax.set_title(f"Image: {Path(image_path).name}", fontsize=12, fontweight='bold')

        # Panel 2: question and text-only ranking.
        text_ax = plt.subplot(1, 3, 2)
        text_ax.axis('off')
        text_content = (
            f"Question:\n{question}\n\n"
            + "  Text-only predictions:\n"
            + ranked(result['text_only'])
        )
        text_ax.text(0.05, 0.95, text_content, fontsize=11, verticalalignment='top',
                     transform=text_ax.transAxes, bbox=boxed("lightblue"))

        # Panel 3: multimodal ranking, coloured by agreement with the text model.
        mm_ax = plt.subplot(1, 3, 3)
        mm_ax.axis('off')

        if 'multimodal' not in result:
            mm_content = "  No image provided\nOnly text prediction shown"
            box_color = "lightyellow"
        else:
            mm_content = "   Multimodal predictions:\n" + ranked(result['multimodal'])

            text_answer = result['text_only'][0]['answer']
            mm_answer = result['multimodal'][0]['answer']

            if text_answer == mm_answer:
                mm_content += f"\n  Both agree: {text_answer}"
                box_color = "lightgreen"
            else:
                mm_content += "\n  Models disagree!\n"
                mm_content += f"Text: {text_answer}\n"
                mm_content += f"Visual: {mm_answer}"
                box_color = "salmon"

        mm_ax.text(0.05, 0.95, mm_content, fontsize=11, verticalalignment='top',
                   transform=mm_ax.transAxes, bbox=boxed(box_color))

    plt.tight_layout()
    plt.show()

    return result

# Example usage - modify paths as needed.
# With image_path=None only the text-only ranking is plotted.
print("  Testing your models!")
print("Modify the paths below to test with your images:")

# Test with a sample question; the returned result dict is kept in sample_result.
sample_result = quick_test(
    image_path=None,  # Change to your image path, e.g., "data/train/image1.png"
    question="What organ is shown in the image?"
)

## 🖼️ Test with Real Image

**Replace the path below with one of your PathVQA images!**

In [None]:
# List up to ten PNG images from the training directory and run quick_test on
# the first one. The matches are sorted because glob order is not guaranteed,
# so an unsorted "first" image would not be reproducible between runs.
image_dir = Path("data/train")
if image_dir.exists():
    images = sorted(image_dir.glob("*.png"))[:10]
    print(f"Available images in {image_dir}:")
    for i, img in enumerate(images):
        print(f"  {i+1}. {img.name}")

    if images:
        # Test with first available image
        sample_image = str(images[0])
        print(f"\nTesting with: {sample_image}")

        result = quick_test(
            image_path=sample_image,
            question="What is the primary structure visible in this pathology image?"
        )
    else:
        # The directory exists but holds no PNG files; say so instead of
        # printing an empty listing.
        print("  No .png files found in this directory.")
else:
    print("  Image directory not found. Make sure 'data/train' exists with .png files.")
    print("You can still test with text-only predictions!")

## 🔬 Interactive Testing

**Try different questions and images!**

In [None]:
# Modify these variables and re-run to test different combinations

# Example questions you can try (printed below as a numbered menu):
questions = [
    "What organ is shown in the image?",
    "What type of tissue is visible?",
    "Is there any abnormality present?",
    "What is the primary pathological finding?",
    "What structures are visible in this sample?"
]

print("  Try different questions:")
for i, q in enumerate(questions):
    print(f"{i+1}. {q}")

# Customize your test here:
your_question = "What organ is shown in the image?"  # <- Change this
your_image = None  # <- Change to your image path, e.g., "data/train/your_image.png"

# Run the prediction and plot it; the result dict is kept in custom_result.
print(f"\n  Testing with: '{your_question}'")
custom_result = quick_test(your_image, your_question)

## 📊 Batch Evaluation

**Evaluate your models on the test set**

In [None]:
def evaluate_batch(test_csv="testrenamed.csv", image_dir="data/train", num_samples=100):
    """Evaluate both models on the first ``num_samples`` rows of a test CSV.

    The CSV must provide ``image``, ``question`` and ``answer`` columns. Each
    row is run through the module-level ``predictor``; the multimodal model
    only contributes for rows whose image file exists under ``image_dir``.
    Answers are compared case-insensitively after stripping whitespace, and a
    missing ground-truth answer never counts as correct.

    Args:
        test_csv: Path to the CSV file with the test samples.
        image_dir: Directory containing the ``.png`` images.
        num_samples: Maximum number of rows to evaluate.

    Returns:
        A DataFrame with one row per sample (question, true answer, predictions,
        confidences and correctness flags), or None when the CSV file is
        missing or contains no rows.
    """

    def _same_answer(predicted, expected):
        # CSV cells can be NaN or numeric, so never call .lower() on them directly.
        if pd.isna(expected):
            return False
        return str(predicted).strip().lower() == str(expected).strip().lower()

    if not Path(test_csv).exists():
        print(f"  Test file {test_csv} not found")
        return None

    df = pd.read_csv(test_csv).head(num_samples)
    results = []

    print(f"  Evaluating on {len(df)} test samples...")

    if df.empty:
        # An empty result frame has no columns, so the metrics below would fail.
        print("  No samples to evaluate")
        return None

    for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing"):
        # Construct image path
        image_name = row['image']
        if not str(image_name).endswith('.png'):
            image_name = f"{image_name}.png"

        image_path = Path(image_dir) / image_name
        # Checked once so the prediction and the stored path always agree.
        existing_image = str(image_path) if image_path.exists() else None

        question = "" if pd.isna(row['question']) else str(row['question'])
        true_answer = row['answer']

        # Make prediction
        pred = predictor.predict(existing_image, question, top_k=1)

        text_pred = pred['text_only'][0]['answer']
        text_conf = pred['text_only'][0]['confidence']

        result = {
            'question': question,
            'true_answer': true_answer,
            'text_prediction': text_pred,
            'text_confidence': text_conf,
            'text_correct': _same_answer(text_pred, true_answer),
            'image_path': existing_image
        }

        if 'multimodal' in pred:
            mm_pred = pred['multimodal'][0]['answer']
            mm_conf = pred['multimodal'][0]['confidence']
            result.update({
                'multimodal_prediction': mm_pred,
                'multimodal_confidence': mm_conf,
                'multimodal_correct': _same_answer(mm_pred, true_answer),
                'models_agree': _same_answer(text_pred, mm_pred)
            })

        results.append(result)

    results_df = pd.DataFrame(results)

    # Calculate metrics
    text_acc = results_df['text_correct'].mean()
    print(f"\n  Results:")
    print(f"  Text-only accuracy: {text_acc:.4f} ({text_acc*100:.2f}%)")

    if 'multimodal_correct' in results_df.columns:
        # Rows without an image hold NaN in the multimodal columns; drop them
        # so the means cover only samples the multimodal model actually saw.
        mm_flags = results_df['multimodal_correct'].dropna().astype(float)
        agree_flags = results_df['models_agree'].dropna().astype(float)

        mm_acc = mm_flags.mean()
        improvement = mm_acc - text_acc
        agreement = agree_flags.mean()

        print(f"  Multimodal accuracy: {mm_acc:.4f} ({mm_acc*100:.2f}%)")
        print(f"  Improvement: {improvement:.4f} ({improvement*100:.2f} pp)")
        print(f"  Model agreement: {agreement:.4f} ({agreement*100:.1f}%)")

        if len(mm_flags) < len(results_df):
            print(f"  Note: multimodal metrics cover {len(mm_flags)} of {len(results_df)} samples")

        if improvement > 0:
            print(f"    Multimodal model is better!")
        elif improvement < 0:
            print(f"    Text-only model performed better on this sample")
        else:
            # A tie is not a win for the text-only model.
            print(f"    Both models performed equally on this sample")

    return results_df

# Run evaluation; eval_results is a DataFrame, or None when the test CSV is
# not found. The later visualization, example and save cells read it.
print("  Starting batch evaluation...")
eval_results = evaluate_batch(num_samples=50)  # Start with 50 samples

## 📈 Visualize Results

In [None]:
# Four-panel summary of the batch evaluation: accuracy bars, confidence
# histograms, model agreement pie, and text-only accuracy per confidence bin.
if eval_results is not None and len(eval_results) > 0:
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    # 1. Accuracy comparison
    text_acc = eval_results['text_correct'].mean()

    if 'multimodal_correct' in eval_results.columns:
        # Rows without an image hold NaN here; average only the evaluated rows.
        mm_acc = eval_results['multimodal_correct'].dropna().astype(float).mean()
        accuracies = [text_acc, mm_acc]
        labels = ['Text-only', 'Multimodal']
        colors = ['skyblue', 'lightcoral']
    else:
        accuracies = [text_acc]
        labels = ['Text-only']
        colors = ['skyblue']

    bars = axes[0,0].bar(labels, accuracies, color=colors)
    axes[0,0].set_ylabel('Accuracy')
    axes[0,0].set_title('Model Accuracy Comparison')
    axes[0,0].set_ylim(0, 1)

    # Add value labels on bars
    for bar, acc in zip(bars, accuracies):
        axes[0,0].text(bar.get_x() + bar.get_width()/2, acc + 0.02,
                      f'{acc:.3f}', ha='center', fontweight='bold')

    # 2. Confidence distribution
    axes[0,1].hist(eval_results['text_confidence'], alpha=0.6, label='Text-only',
                   bins=20, color='skyblue', edgecolor='black')
    if 'multimodal_confidence' in eval_results.columns:
        # Drop NaN first: a histogram range cannot be derived from NaN values.
        axes[0,1].hist(eval_results['multimodal_confidence'].dropna(), alpha=0.6, label='Multimodal',
                       bins=20, color='lightcoral', edgecolor='black')
    axes[0,1].set_xlabel('Confidence Score')
    axes[0,1].set_ylabel('Frequency')
    axes[0,1].set_title('Prediction Confidence Distribution')
    axes[0,1].legend()

    # 3. Agreement analysis (if multimodal available)
    if 'models_agree' in eval_results.columns:
        agreement = eval_results['models_agree'].dropna().astype(float).mean()
        sizes = [agreement, 1-agreement]
        labels_pie = [f'Agree\n({agreement:.1%})', f'Disagree\n({1-agreement:.1%})']
        colors_pie = ['lightgreen', 'salmon']

        axes[1,0].pie(sizes, labels=labels_pie, colors=colors_pie, autopct='', startangle=90)
        axes[1,0].set_title('Model Agreement Analysis')
    else:
        axes[1,0].text(0.5, 0.5, 'Multimodal results\nnot available',
                       ha='center', va='center', fontsize=12)
        axes[1,0].set_title('Model Agreement')

    # 4. Performance by confidence
    # Bin predictions by confidence and show accuracy. The bins live in a local
    # Series so that plotting does not add a column to eval_results (which
    # would otherwise end up in the saved CSV).
    conf_bins = pd.cut(eval_results['text_confidence'], bins=5).rename('conf_bin')
    # observed=False keeps empty bins on the axis, independent of the pandas default.
    conf_acc = eval_results['text_correct'].groupby(conf_bins, observed=False).mean()

    x_pos = range(len(conf_acc))
    axes[1,1].bar(x_pos, conf_acc.values, color='steelblue', alpha=0.7)
    axes[1,1].set_xlabel('Confidence Bins')
    axes[1,1].set_ylabel('Accuracy')
    axes[1,1].set_title('Accuracy vs Confidence (Text-only)')
    axes[1,1].set_xticks(x_pos)
    axes[1,1].set_xticklabels([f'{interval.left:.2f}-{interval.right:.2f}'
                               for interval in conf_acc.index], rotation=45)

    plt.tight_layout()
    plt.show()

    print("  Analysis complete! See visualizations above.")
else:
    print("  No evaluation results to visualize")

## 🔍 Analyze Specific Examples

**Look at interesting cases where models agree/disagree**

In [None]:
def show_interesting_examples(results_df, example_type='disagreement', n_examples=3):
    """Print a random sample of evaluation rows of one kind.

    Args:
        results_df: DataFrame as returned by ``evaluate_batch`` (may be None).
        example_type: ``'disagreement'`` (models predicted different answers),
            ``'correct'`` (at least one model was right) or ``'errors'``
            (no model was right).
        n_examples: Maximum number of rows to print.

    Returns:
        None. Everything is written to stdout. ASCII markers are used in the
        output so that it survives any console or file encoding.
    """

    if results_df is None or len(results_df) == 0:
        print(" No results available")
        return

    print(f"  Showing {example_type} examples:\n")

    has_multimodal = 'multimodal_correct' in results_df.columns
    text_ok = results_df['text_correct'].astype(bool)
    # eq(True) treats NaN (rows evaluated without an image) as "not correct".
    mm_ok = results_df['multimodal_correct'].eq(True) if has_multimodal else None

    def _mark(flag):
        # Visible marker; NaN counts as not correct.
        return '[OK]' if pd.notna(flag) and bool(flag) else '[X]'

    def _print_cases(cases, answer_suffix, prediction_suffix):
        # Shared printer for the 'correct' and 'errors' listings.
        if len(cases) == 0:
            print(f" No {example_type} examples found")
            return
        sample = cases.sample(min(n_examples, len(cases)))
        for number, (_, row) in enumerate(sample.iterrows(), start=1):
            print(f"Example {number}:")
            print(f"  Question: {row['question']}")
            print(f"  Correct answer: {row['true_answer']}{answer_suffix}")
            print(f"  Text prediction: {row['text_prediction']}{prediction_suffix}")
            # Skip rows that were evaluated without an image.
            if 'multimodal_prediction' in row and pd.notna(row['multimodal_prediction']):
                print(f"  Multimodal prediction: {row['multimodal_prediction']}{prediction_suffix}")
            print()

    if example_type == 'disagreement':
        if 'multimodal_prediction' not in results_df.columns:
            print(" Multimodal results not available - cannot show disagreements")
            return

        # Cases where models disagree
        disagreements = results_df[results_df['models_agree'].eq(False)]

        if len(disagreements) == 0:
            print(" No disagreements found - models always agree!")
            return

        sample = disagreements.sample(min(n_examples, len(disagreements)))

        for number, (_, row) in enumerate(sample.iterrows(), start=1):
            text_right = bool(row['text_correct'])
            mm_right = bool(row['multimodal_correct'])

            print(f"Example {number}:")
            print(f"  Question: {row['question']}")
            print(f"  True answer: {row['true_answer']}")
            print(f"  Text-only: {row['text_prediction']} {_mark(text_right)}")
            print(f"  Multimodal: {row['multimodal_prediction']} {_mark(mm_right)}")

            # Determine which is correct. Two models that disagree cannot both
            # match the true answer, so the remaining case is "both wrong".
            if text_right and not mm_right:
                print(f"  -> Text-only was right!  ")
            elif mm_right and not text_right:
                print(f"  -> Multimodal was right!   ")
            else:
                print(f"  -> Both were wrong  ")

            print()

    elif example_type == 'correct':
        # Cases where model(s) are correct
        correct_mask = (text_ok | mm_ok) if has_multimodal else text_ok
        _print_cases(results_df[correct_mask], answer_suffix="  ", prediction_suffix="")

    elif example_type == 'errors':
        # Cases where no model is correct; rows without a multimodal result
        # count as errors when the text-only prediction is wrong.
        error_mask = (~text_ok & ~mm_ok) if has_multimodal else ~text_ok
        _print_cases(results_df[error_mask], answer_suffix="", prediction_suffix="  ")

    else:
        print(f" Unknown example type: {example_type!r} "
              "(expected 'disagreement', 'correct' or 'errors')")

# Show different types of examples (disagreements, correct cases, errors),
# separated by a ruler line. Nothing is shown when eval_results is None.
if eval_results is not None:
    show_interesting_examples(eval_results, 'disagreement', 3)
    print("\n" + "="*50 + "\n")
    show_interesting_examples(eval_results, 'correct', 2)
    print("\n" + "="*50 + "\n")
    show_interesting_examples(eval_results, 'errors', 2)
else:
    print("Run the batch evaluation first to see examples!")

##  Save Your Results

In [None]:
# Persist the evaluation table to a timestamped CSV and print a short summary.
if eval_results is None:
    print("No results to save. Run the evaluation first!")
else:
    # The timestamp in the file name gives each run its own output file.
    output_file = "vqa_evaluation_results_{}.csv".format(
        pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')
    )
    eval_results.to_csv(output_file, index=False)

    print(f"  Results saved to: {output_file}")
    print(f"   Samples evaluated: {len(eval_results)}")

    # Headline numbers; multimodal entries are added only when that model ran.
    summary = {
        'total_samples': len(eval_results),
        'text_accuracy': eval_results['text_correct'].mean(),
        'text_avg_confidence': eval_results['text_confidence'].mean(),
    }

    if 'multimodal_correct' in eval_results.columns:
        summary['multimodal_accuracy'] = eval_results['multimodal_correct'].mean()
        summary['multimodal_avg_confidence'] = eval_results['multimodal_confidence'].mean()
        summary['improvement'] = summary['multimodal_accuracy'] - summary['text_accuracy']
        summary['agreement_rate'] = eval_results['models_agree'].mean()

    print("\n  Final Summary:")
    for key, value in summary.items():
        if any(tag in key for tag in ('accuracy', 'improvement', 'agreement')):
            # Ratio metrics: raw value plus percentage.
            print(f"  {key.replace('_', ' ').title()}: {value:.4f} ({value*100:.2f}%)")
        elif 'confidence' in key:
            print(f"  {key.replace('_', ' ').title()}: {value:.4f}")
        else:
            print(f"  {key.replace('_', ' ').title()}: {value}")