# CogniRad++: Demo Notebook

This notebook demonstrates how to use CogniRad++ for chest X-ray report generation.

**Features:**
- Load pretrained model
- Generate structured radiology reports
- Predict diseases with confidence scores
- Visualize attention maps
- Interactive clinician interface

## 1. Setup and Imports

In [None]:
import os
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
from pathlib import Path

# Add project root to path
sys.path.append(str(Path.cwd().parent))

from models.cognirad import CogniRadPlusPlus
from models.classifier import CheXpertLabelEncoder

# Report the runtime environment so later cells fail for an obvious reason.
# (The status glyph was mis-encoded; restored to the intended check mark.)
print("✅ Imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## 2. Load Pretrained Model

In [None]:
# Configuration
CHECKPOINT_PATH = '../checkpoints/best_model.pt'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f"Using device: {DEVICE}")
print(f"Loading model from: {CHECKPOINT_PATH}")

# Load checkpoint; map_location places its tensors on DEVICE.
# NOTE(review): torch.load deserializes with pickle, which can execute
# arbitrary code unless weights-only loading is in effect. Only point
# CHECKPOINT_PATH at checkpoints from a trusted source.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)

# Create model. pretrained=False because every weight is overwritten by the
# checkpoint below; presumably the architecture arguments must match the ones
# used at training time -- verify against the training config.
model = CogniRadPlusPlus(
    visual_backbone='resnet50',
    num_diseases=14,
    pretrained=False
).to(DEVICE)

# Load weights and switch to inference mode.
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# The status glyph was mis-encoded; restored to the intended check mark.
print("✅ Model loaded successfully!")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")

## 3. Image Preprocessing

In [None]:
# Preprocessing pipeline applied to every image before it reaches the model:
# resize to 224x224, convert to a tensor, then normalize each channel.
# The mean/std values match the commonly used ImageNet channel statistics.
_pipeline_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(_pipeline_steps)

def load_and_preprocess_image(image_path: str):
    """Load a chest X-ray from disk and prepare it for the model.

    Args:
        image_path: Path to an image file that PIL can open.

    Returns:
        A ``(img, img_tensor)`` tuple: the image converted to RGB, and the
        output of the module-level ``transform`` with a leading batch
        dimension added via ``unsqueeze(0)``.
    """
    # Image.open is lazy and keeps the underlying file open. convert() loads
    # the pixel data and returns a separate image, so the source file can be
    # closed immediately instead of lingering until garbage collection.
    with Image.open(image_path) as source:
        img = source.convert('RGB')
    img_tensor = transform(img).unsqueeze(0)  # Add batch dimension
    return img, img_tensor

# The status glyph was mis-encoded; restored to the intended check mark.
print("✅ Image preprocessing ready!")

## 4. Generate Report for Single Image

In [None]:
# Example image path (replace with your image)
IMAGE_PATH = '../data/example_cxr.jpg'
CLINICAL_INDICATION = "55M with fever and cough"

# Load image
if os.path.exists(IMAGE_PATH):
    original_img, img_tensor = load_and_preprocess_image(IMAGE_PATH)
    img_tensor = img_tensor.to(DEVICE)

    # Display image
    plt.figure(figsize=(8, 8))
    plt.imshow(original_img, cmap='gray')
    plt.title('Chest X-ray')
    plt.axis('off')
    plt.show()

    # Generate report (no_grad: inference only, no autograd bookkeeping)
    print("\n🔄 Generating report...")

    with torch.no_grad():
        report = model.generate_report(
            images=img_tensor,
            clinical_indication=CLINICAL_INDICATION,
            confidence_threshold=0.7,
            include_evidence=True
        )

    # Display results
    print("\n" + "="*70)
    print("RADIOLOGY REPORT")
    print("="*70)
    print(f"\nClinical Indication: {report['clinical_indication']}")
    print(f"\nFINDINGS:\n{report['findings']}")
    print(f"\nIMPRESSION:\n{report['impression']}")

    # Disease predictions (section skipped when the list is empty)
    if report['predicted_diseases']:
        print("\n" + "="*70)
        print("PREDICTED PATHOLOGIES")
        print("="*70)
        for disease in report['predicted_diseases']:
            print(f"\n• {disease['label']}")
            print(f"  Probability: {disease['probability']:.2%}")
            print(f"  Confidence:  {disease['confidence']:.2%}")

    # Warnings: .get() lets a missing key and an empty list both skip the
    # section, so no bare WARNINGS header is printed with nothing under it.
    report_warnings = report.get('warnings')
    if report_warnings:
        print("\n" + "="*70)
        print("⚠️  WARNINGS")
        print("="*70)
        for warning in report_warnings:
            print(f"• {warning}")

    print("\n" + "="*70)

else:
    print(f"❌ Image not found: {IMAGE_PATH}")
    print("Please provide a chest X-ray image")

## 5. Batch Processing

In [None]:
# Process multiple images
IMAGE_DIR = '../data/examples'
MAX_BATCH_IMAGES = 5    # how many of the discovered images to process
FINDINGS_PREVIEW_CHARS = 200  # length of the findings excerpt printed per image

if os.path.exists(IMAGE_DIR):
    # glob() yields entries in arbitrary order; sort so the subset that gets
    # processed is the same on every run and every machine.
    image_files = sorted(
        list(Path(IMAGE_DIR).glob('*.jpg')) + list(Path(IMAGE_DIR).glob('*.png'))
    )

    print(f"Found {len(image_files)} images")

    for img_path in image_files[:MAX_BATCH_IMAGES]:
        print(f"\n{'='*70}")
        print(f"Processing: {img_path.name}")
        print(f"{'='*70}")

        original_img, img_tensor = load_and_preprocess_image(str(img_path))
        img_tensor = img_tensor.to(DEVICE)

        with torch.no_grad():
            report = model.generate_report(
                images=img_tensor,
                clinical_indication="Chest X-ray"
            )

        # Display image and report
        fig, ax = plt.subplots(1, 1, figsize=(6, 6))
        ax.imshow(original_img, cmap='gray')
        ax.set_title(img_path.name)
        ax.axis('off')
        plt.show()

        # Only mark the excerpt as truncated when text was actually cut off.
        findings = report['findings']
        preview = findings[:FINDINGS_PREVIEW_CHARS]
        if len(findings) > FINDINGS_PREVIEW_CHARS:
            preview += "..."
        print(f"\nFindings: {preview}")
        print(f"\nDiseases: {len(report['predicted_diseases'])} detected")
else:
    print(f"Directory not found: {IMAGE_DIR}")

## 6. Visualize Attention Maps

In [None]:
import cv2

def visualize_attention(image, attention_map, alpha=0.5):
    """Overlay an attention heatmap on an image.

    Args:
        image: Anything ``np.array`` can convert to an image array. A 2-D
            (grayscale) array is expanded to three identical channels.
            NOTE(review): the blend assumes a 3-channel uint8 image so that it
            matches the heatmap -- confirm for non-RGB inputs.
        attention_map: 2-D attention weights as a ``torch.Tensor`` or an
            array-like; any value range is accepted.
        alpha: Blend weight of the heatmap; the image gets ``1 - alpha``.

    Returns:
        The blended array produced by ``cv2.addWeighted``, with the heatmap
        in RGB channel order.
    """
    # Convert image to numpy
    img_np = np.array(image)

    # Bring the attention weights to a float32 NumPy array. detach() keeps
    # this working for tensors that still carry autograd history.
    if isinstance(attention_map, torch.Tensor):
        attention_map = attention_map.detach().cpu().numpy()
    attention_map = np.asarray(attention_map, dtype=np.float32)

    # Min-max normalize to [0, 1]. A constant map has zero range; dividing by
    # it would fill the map with NaN, so fall back to an all-zero map instead.
    lowest = attention_map.min()
    value_range = attention_map.max() - lowest
    if value_range > 0:
        attention_map = (attention_map - lowest) / value_range
    else:
        attention_map = np.zeros_like(attention_map)

    # Resize to image size (cv2 expects the target as (width, height))
    attention_resized = cv2.resize(attention_map, (img_np.shape[1], img_np.shape[0]))

    # Apply colormap; OpenCV returns BGR, so convert to RGB
    heatmap = cv2.applyColorMap(np.uint8(255 * attention_resized), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    # Overlay: a grayscale image needs three channels to match the heatmap
    if len(img_np.shape) == 2:
        img_np = np.stack([img_np] * 3, axis=-1)

    overlay = cv2.addWeighted(img_np, 1-alpha, heatmap, alpha, 0)

    return overlay

# Generate a report for the example image and, when the model returns concept
# attention, show it next to the original image.
if os.path.exists(IMAGE_PATH):
    original_img, img_tensor = load_and_preprocess_image(IMAGE_PATH)
    img_tensor = img_tensor.to(DEVICE)

    with torch.no_grad():
        report = model.generate_report(
            images=img_tensor,
            clinical_indication=CLINICAL_INDICATION,
            include_evidence=True,
        )

    has_concept_attention = (
        'attention_maps' in report
        and report['attention_maps']['concept_attention'] is not None
    )

    if not has_concept_attention:
        print("Attention maps not available")
    else:
        # First batch item, first row of the attention tensor.
        # NOTE(review): the 7x7 reshape assumes this slice holds exactly 49
        # values -- confirm against the model's attention output.
        attention = report['attention_maps']['concept_attention'][0, 0, :]

        # Simplified heatmap; in production, use actual spatial attention.
        attention_2d = attention.view(7, 7).cpu().numpy()

        fig, axes = plt.subplots(1, 2, figsize=(12, 6))

        # Left panel: the untouched input image.
        axes[0].imshow(original_img, cmap='gray')
        axes[0].set_title('Original Image')
        axes[0].axis('off')

        # Right panel: the attention heatmap blended over the image.
        overlay = visualize_attention(original_img, attention_2d)
        axes[1].imshow(overlay)
        axes[1].set_title('Attention Map')
        axes[1].axis('off')

        plt.tight_layout()
        plt.show()

## 7. Interactive Clinician Interface

In [None]:
# Simple interactive interface
def interactive_diagnosis(image_path: str, clinical_indication: str):
    """Generate a report for one image and print it in a clinician-facing layout.

    Args:
        image_path: Path to the chest X-ray image to analyse.
        clinical_indication: Free-text indication, passed to the model and
            echoed in the printed report.

    Returns:
        The report returned by ``model.generate_report``. This function reads
        its ``'findings'``, ``'impression'`` and ``'predicted_diseases'``
        entries, plus ``'uncertain_findings'`` when present.
    """
    # Load and process (the PIL image is not needed here, only the tensor)
    _, img_tensor = load_and_preprocess_image(image_path)
    img_tensor = img_tensor.to(DEVICE)

    # Generate report
    with torch.no_grad():
        report = model.generate_report(
            images=img_tensor,
            clinical_indication=clinical_indication,
            confidence_threshold=0.7
        )

    # Display
    print("\n" + "="*70)
    print("AUTOMATED RADIOLOGY REPORT")
    print("="*70)
    print(f"\nIndication: {clinical_indication}")
    print(f"\nFINDINGS:\n{report['findings']}")
    print(f"\nIMPRESSION:\n{report['impression']}")

    # Show predictions
    # NOTE(review): the header says ">50%" while confidence_threshold above is
    # 0.7 -- confirm which cutoff the model actually applies and align them.
    print("\n" + "="*70)
    print("PREDICTED PATHOLOGIES (>50% confidence)")
    print("="*70)

    for disease in report['predicted_diseases']:
        # Green marker above 0.8 confidence, yellow otherwise.
        confidence_emoji = "🟢" if disease['confidence'] > 0.8 else "🟡"
        print(f"{confidence_emoji} {disease['label']}: {disease['probability']:.1%}")

    # Uncertain findings. export_report treats this key as optional, so read
    # it with .get() rather than raising KeyError when it is absent.
    uncertain_findings = report.get('uncertain_findings')
    if uncertain_findings:
        print("\n" + "="*70)
        print("⚠️  UNCERTAIN FINDINGS (require clinical correlation)")
        print("="*70)
        for disease in uncertain_findings:
            print(f"• {disease['label']}: {disease['probability']:.1%} (low confidence)")

    return report

# Example usage: runs only when the sample image is present on disk.
if os.path.exists(IMAGE_PATH):
    report = interactive_diagnosis(
        image_path=IMAGE_PATH,
        clinical_indication="72F presenting with shortness of breath",
    )

## 8. Export Report

In [None]:
import json
from datetime import datetime

def export_report(report: dict, output_path: str) -> None:
    """Write the key fields of a generated report to a JSON file.

    Args:
        report: Report dictionary. Must contain ``'clinical_indication'``,
            ``'findings'``, ``'impression'`` and ``'predicted_diseases'``;
            ``'uncertain_findings'`` and ``'warnings'`` default to empty lists.
        output_path: Destination file. Missing parent directories are created.

    Raises:
        KeyError: If a required report field is missing.
        TypeError: If a field is not JSON-serializable; nothing is written in
            that case.
    """
    report_data = {
        'timestamp': datetime.now().isoformat(),
        'clinical_indication': report['clinical_indication'],
        'findings': report['findings'],
        'impression': report['impression'],
        'predicted_diseases': report['predicted_diseases'],
        'uncertain_findings': report.get('uncertain_findings', []),
        'warnings': report.get('warnings', [])
    }

    # Serialize before touching the filesystem so a non-serializable field
    # cannot leave a truncated file behind.
    payload = json.dumps(report_data, indent=2)

    # open() does not create directories; make sure the target folder exists.
    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)

    with open(destination, 'w', encoding='utf-8') as f:
        f.write(payload)

    # The status glyph was mis-encoded; restored to the intended check mark.
    print(f"✅ Report exported to {output_path}")

# Example: export whichever report an earlier cell left behind, if any.
if 'report' in locals():
    export_report(report=report, output_path='../outputs/example_report.json')

## 9. Model Statistics

In [None]:
# Print parameter counts for the whole model and its three sub-modules,
# followed by the list of disease labels it supports.
def _count_parameters(module, trainable_only=False):
    """Sum element counts over ``module.parameters()``.

    With ``trainable_only`` set, parameters whose ``requires_grad`` is false
    are left out.
    """
    return sum(
        param.numel()
        for param in module.parameters()
        if param.requires_grad or not trainable_only
    )


print("CogniRad++ Model Statistics")
print("=" * 70)

# Whole-model totals
total_params = _count_parameters(model)
trainable_params = _count_parameters(model, trainable_only=True)

print(f"Total Parameters:     {total_params:,} ({total_params/1e6:.2f}M)")
print(f"Trainable Parameters: {trainable_params:,} ({trainable_params/1e6:.2f}M)")

# Per-component totals, all computed before anything is printed
encoder_params = _count_parameters(model.visual_encoder)
classifier_params = _count_parameters(model.disease_classifier)
decoder_params = _count_parameters(model.report_generator)

print(f"\nComponent Breakdown:")
print(f"  Visual Encoder:     {encoder_params:,} ({encoder_params/1e6:.2f}M)")
print(f"  Disease Classifier: {classifier_params:,} ({classifier_params/1e6:.2f}M)")
print(f"  Report Generator:   {decoder_params:,} ({decoder_params/1e6:.2f}M)")

# Supported diseases, numbered from 1
print(f"\nSupported Disease Labels ({model.num_diseases}):")
for position, label in enumerate(model.label_encoder.label_names, start=1):
    print(f"  {position:2d}. {label}")

## 10. Summary

This notebook demonstrated:

✅ Loading pretrained CogniRad++ model  
✅ Generating structured radiology reports  
✅ Disease prediction with confidence scores  
✅ Attention visualization  
✅ Interactive clinician interface  
✅ Batch processing  
✅ Report export  

**Next Steps:**
- Fine-tune on your own dataset
- Integrate with PACS systems
- Deploy as web service
- Conduct clinical validation studies