<a href="https://colab.research.google.com/github/maclandrol/cours-ia-med/blob/master/08_Custom_Dataset_Integration.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Instructor:** Emmanuel Noutahi, PhD

# Tutorial 8: Custom Dataset Integration

## Medical Context

### For Medical Students
Custom dataset integration is essential for:
- **Medical research**: Analyzing your own clinical data
- **Specialized studies**: Focusing on specific pathologies
- **Local validation**: Adapting AI to your patient population
- **Student projects**: Conducting original research

### For Practitioners
- **Surgeons**: Validation on specific surgical cases
- **General practitioners**: Adaptation to local population characteristics
- **Medical educators**: Creating personalized teaching cases

## Learning Objectives

By the end of this tutorial, you will be able to:
1. **Efficiently load** your own radiological images
2. **Organize** a custom dataset for analysis
3. **Apply** TorchXRayVision models to your data
4. **Perform batch analysis** of multiple images
5. **Generate** comparative reports for your dataset
6. **Export** results for clinical or research use

## Prerequisites

This tutorial builds on **Tutorials 1-7**. You should be familiar with:
- Basic TorchXRayVision usage
- Pathology classification and detection
- Results interpretation

## Setup and Installation

In [None]:
# Install required libraries
!pip install torchxrayvision
!pip install torch torchvision
!pip install matplotlib seaborn
!pip install numpy pandas
!pip install scikit-image opencv-python
!pip install tqdm

print("Installation completed successfully")

In [None]:
# Import libraries
import torch
import torch.nn as nn
import torchxrayvision as xrv
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from PIL import Image
import cv2
import os
import glob
import zipfile
import io
import json
from google.colab import files
from tqdm import tqdm
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

# Display configuration
plt.rcParams['figure.figsize'] = (15, 10)
plt.rcParams['font.size'] = 12
sns.set_style("whitegrid")

print(f"Libraries imported successfully")
print(f"PyTorch version: {torch.__version__}")
print(f"TorchXRayVision version: {xrv.__version__}")

# GPU check
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device used: {device}")

# Create working directories
os.makedirs('custom_dataset', exist_ok=True)
os.makedirs('results', exist_ok=True)
print("Working directories created")

## Model Loading

In [None]:
# Load TorchXRayVision models for custom dataset analysis
print("Loading models for custom dataset analysis...")

models = {}
model_info = {}

# 1. Main model for classification
try:
    models['densenet'] = xrv.models.DenseNet(weights="densenet121-res224-all")
    models['densenet'].to(device)
    models['densenet'].eval()
    model_info['densenet'] = {
        'name': 'DenseNet121-All',
        'pathologies': models['densenet'].pathologies,
        'description': 'General model for all pathologies'
    }
    print("DenseNet121-All loaded")
except Exception as e:
    print(f"Error loading DenseNet: {e}")

# 2. CheXpert model
try:
    # The CheXpert-trained weights are published under the name "densenet121-res224-chex"
    models['chexpert'] = xrv.models.DenseNet(weights="densenet121-res224-chex")
    models['chexpert'].to(device)
    models['chexpert'].eval()
    model_info['chexpert'] = {
        'name': 'CheXpert',
        'pathologies': models['chexpert'].pathologies,
        'description': 'Specialized for CheXpert data'
    }
    print("CheXpert loaded")
except Exception as e:
    print(f"Error loading CheXpert: {e}")

if not models:
    raise RuntimeError("No models could be loaded")

# Select main model
main_model_key = list(models.keys())[0]
main_model = models[main_model_key]

print(f"\nAvailable models: {list(models.keys())}")
print(f"Main model: {model_info[main_model_key]['name']}")
print(f"Detectable pathologies: {len(main_model.pathologies)}")
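
Because `models` may hold more than one model, it can help to know which pathology labels they share before comparing their outputs. The following is a minimal sketch, assuming each model exposes a list of label strings (as `model.pathologies` does above) and that an empty string marks an output a model does not cover. The helper `shared_pathologies` and the example label lists are hypothetical and not part of TorchXRayVision.

```python
from typing import Dict, List


def shared_pathologies(labels_by_model: Dict[str, List[str]]) -> List[str]:
    """Return the labels present in every model, in the first model's order."""
    label_lists = list(labels_by_model.values())
    if not label_lists:
        return []
    # Assumption: an empty string stands for an output the model does not cover.
    common = set(label for label in label_lists[0] if label)
    for labels in label_lists[1:]:
        common &= set(label for label in labels if label)
    return [label for label in label_lists[0] if label in common]


# Hypothetical label lists, for illustration only
example_labels = {
    "densenet": ["Atelectasis", "Cardiomegaly", "Pneumonia", "Fracture"],
    "chexpert": ["Atelectasis", "", "Pneumonia", "Cardiomegaly"],
}
print(shared_pathologies(example_labels))
```

With the objects loaded above, the equivalent call would be `shared_pathologies({name: m.pathologies for name, m in models.items()})`.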

## Dataset Upload Methods

Choose your preferred method for uploading your dataset:

### Method 1: Batch Upload (Recommended)
Upload multiple images at once or use a ZIP archive for large datasets.

In [None]:
def upload_dataset():
    """
    Unified dataset upload function supporting multiple formats
    """
    print("BATCH DATASET UPLOAD")
    print("-" * 40)
    print("Supported formats: .jpg, .jpeg, .png, .tif, .tiff, .zip")
    print("You can select multiple images or a single ZIP archive")
    print("")
    
    # Upload interface
    uploaded_files = files.upload()
    
    images = []
    filenames = []
    
    if not uploaded_files:
        print("No files uploaded")
        return images, filenames
    
    print(f"\n{len(uploaded_files)} file(s) uploaded")
    
    for filename, file_content in uploaded_files.items():
        try:
            # Handle ZIP archives
            if filename.lower().endswith('.zip'):
                print(f"Extracting ZIP archive: {filename}...")
                
                # Save and extract ZIP
                with open('temp_dataset.zip', 'wb') as f:
                    f.write(file_content)
                
                with zipfile.ZipFile('temp_dataset.zip', 'r') as zip_ref:
                    zip_ref.extractall('custom_dataset')
                
                # Find all images in extracted content
                image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.tiff', '*.tif']
                image_files = []
                
                for ext in image_extensions:
                    image_files.extend(glob.glob(os.path.join('custom_dataset', '**', ext), recursive=True))
                    image_files.extend(glob.glob(os.path.join('custom_dataset', '**', ext.upper()), recursive=True))
                
                print(f"Found {len(image_files)} image(s) in archive")
                
                # Load images from archive
                for img_path in tqdm(image_files, desc="Loading images from archive"):
                    try:
                        image = Image.open(img_path)
                        if image.mode != 'L':
                            image = image.convert('L')
                        
                        img_array = np.array(image)
                        images.append(img_array)
                        filenames.append(os.path.basename(img_path))
                    except Exception as e:
                        print(f"Error loading {img_path}: {e}")
                
                # Cleanup
                os.remove('temp_dataset.zip')
                
            # Handle individual image files
            elif filename.lower().endswith(('.jpg', '.jpeg', '.png', '.tiff', '.tif')):
                image = Image.open(io.BytesIO(file_content))
                
                # Convert to grayscale
                if image.mode != 'L':
                    image = image.convert('L')
                
                img_array = np.array(image)
                images.append(img_array)
                filenames.append(filename)
                
                print(f"Loaded: {filename} - {img_array.shape}")
            else:
                print(f"Unsupported format: {filename}")
                
        except Exception as e:
            print(f"Error processing {filename}: {e}")
    
    print(f"\nTotal: {len(images)} image(s) loaded successfully")
    
    # Display preview
    if images:
        n_preview = min(6, len(images))
        fig, axes = plt.subplots(2, 3, figsize=(15, 10))
        axes = axes.flatten()
        
        sample_indices = np.random.choice(len(images), n_preview, replace=False) if len(images) > n_preview else range(len(images))
        
        for i, idx in enumerate(sample_indices):
            if i < len(axes):
                axes[i].imshow(images[idx], cmap='gray')
                axes[i].set_title(f'{filenames[idx][:20]}...\n{images[idx].shape}')
                axes[i].axis('off')
        
        # Hide unused subplots
        for i in range(len(sample_indices), len(axes)):
            axes[i].axis('off')
        
        plt.suptitle(f'Dataset Preview - {len(images)} Images Loaded')
        plt.tight_layout()
        plt.show()
    
    return images, filenames

# Upload dataset
dataset_images, dataset_filenames = upload_dataset()
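
Before extracting a large archive, it can be useful to check which of its entries are actually images. The sketch below lists image members of a ZIP held in memory, using only the standard library and the same extensions that `upload_dataset` accepts. The helper name `list_image_members` and the rule for skipping hidden files are assumptions for illustration, not part of the tutorial's pipeline.

```python
import io
import os
import zipfile

IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.tiff', '.tif')


def list_image_members(zip_bytes: bytes) -> list:
    """Return archive paths of image files inside an in-memory ZIP."""
    members = []
    with zipfile.ZipFile(io.BytesIO(zip_bytes)) as archive:
        for info in archive.infolist():
            if info.is_dir():
                continue
            name = os.path.basename(info.filename)
            # Skip hidden files such as macOS resource forks ("._scan.png")
            if name.startswith('.'):
                continue
            # Case-insensitive extension match
            if name.lower().endswith(IMAGE_EXTENSIONS):
                members.append(info.filename)
    return sorted(members)


# Build a tiny archive in memory to demonstrate the filter
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, 'w') as archive:
    archive.writestr('study/scan_01.PNG', b'placeholder')
    archive.writestr('study/notes.txt', b'placeholder')
    archive.writestr('__MACOSX/study/._scan_01.PNG', b'placeholder')

print(list_image_members(buffer.getvalue()))
```

In Colab, the bytes returned by `files.upload()` for a `.zip` entry could be passed to this function directly.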

### Method 2: Sample Dataset (For Testing)
Generate synthetic X-ray images for testing without your own data.

In [None]:
def create_sample_dataset(n_images=6):
    """
    Create sample dataset with simulated pathology patterns
    """
    sample_images = []
    sample_filenames = []
    
    case_types = [
        ("normal", "Normal chest X-ray"),
        ("cardiomegaly", "Simulated cardiomegaly"),
        ("pneumonia", "Simulated pneumonia"),
        ("pneumothorax", "Simulated pneumothorax"),
        ("infiltration", "Simulated infiltration"),
        ("atelectasis", "Simulated atelectasis")
    ]
    
    np.random.seed(42)  # For reproducibility
    
    for i in range(n_images):
        case_type, description = case_types[i % len(case_types)]
        
        # Base chest X-ray structure
        img = np.random.rand(224, 224) * 0.3 + 0.4
        
        # Basic anatomical structures
        img[50:180, 30:100] *= 0.7   # Left lung
        img[50:180, 124:194] *= 0.7  # Right lung
        img[120:180, 90:134] *= 1.2  # Heart
        
        # Add pathology-specific patterns
        if case_type == "cardiomegaly":
            img[110:190, 80:144] *= 1.4  # Enlarged heart
        elif case_type == "pneumonia":
            img[80:140, 140:180] *= 1.6  # Consolidation
        elif case_type == "pneumothorax":
            img[60:120, 150:190] *= 0.3  # Hyperlucent area
        elif case_type == "infiltration":
            img[70:150, 40:90] *= 1.3    # Left infiltrate
            img[90:160, 130:170] *= 1.2  # Right infiltrate
        elif case_type == "atelectasis":
            img[80:130, 35:85] *= 1.7    # Partial collapse
        
        # Normalize and smooth
        img = np.clip(img, 0, 1)
        img = cv2.GaussianBlur(img, (3, 3), 0)
        
        sample_images.append(img)
        sample_filenames.append(f"sample_{i+1:02d}_{case_type}.png")
    
    return sample_images, sample_filenames

# Option to create sample dataset
use_sample = input("Create sample dataset for testing? (y/n): ").lower().strip()

if use_sample in ['y', 'yes']:
    print("\nCreating sample dataset...")
    sample_images, sample_filenames = create_sample_dataset(6)
    
    # Use as the main dataset if it is empty, otherwise offer to append
    if not dataset_images:
        dataset_images = sample_images
        dataset_filenames = sample_filenames
        print(f"Sample dataset created: {len(sample_images)} images")
    else:
        choice = input(f"Add to existing {len(dataset_images)} images? (y/n): ").lower().strip()
        if choice in ['y', 'yes']:
            dataset_images.extend(sample_images)
            dataset_filenames.extend(sample_filenames)
            print(f"Added {len(sample_images)} sample images to dataset")
    
    # Display sample dataset
    if sample_images:
        fig, axes = plt.subplots(2, 3, figsize=(15, 10))
        axes = axes.flatten()
        
        for i, (img, filename) in enumerate(zip(sample_images, sample_filenames)):
            axes[i].imshow(img, cmap='gray')
            axes[i].set_title(filename.replace('_', ' '))
            axes[i].axis('off')
        
        plt.suptitle('Sample Dataset Generated')
        plt.tight_layout()
        plt.show()
else:
    print("Sample dataset not created")

## Dataset Consolidation and Validation

In [None]:
# Validate and consolidate dataset
print("DATASET CONSOLIDATION")
print("=" * 30)

if not dataset_images:
    print("No images available for analysis!")
    print("Please upload images using one of the methods above")
else:
    print(f"Total dataset: {len(dataset_images)} images")
    
    # Create dataset information DataFrame
    dataset_info = pd.DataFrame({
        'ID': range(len(dataset_images)),
        'Filename': dataset_filenames,
        'Shape': [img.shape for img in dataset_images],
        'Size_MB': [img.nbytes / (1024*1024) for img in dataset_images],
        'Min_Pixel': [img.min() for img in dataset_images],
        'Max_Pixel': [img.max() for img in dataset_images],
        'Mean_Pixel': [img.mean() for img in dataset_images]
    })
    
    print("\nDATASET INFORMATION:")
    print(dataset_info.to_string(index=False))
    
    # Dataset statistics
    print(f"\nDATASET STATISTICS:")
    print(f"   Total images: {len(dataset_images)}")
    print(f"   Unique image shapes: {len(set([img.shape for img in dataset_images]))}")
    print(f"   Total dataset size: {sum(dataset_info['Size_MB']):.2f} MB")
    print(f"   Average image size: {np.mean(dataset_info['Size_MB']):.2f} MB")
    
    # Validation checks
    print(f"\nVALIDATION CHECKS:")
    
    # Check for very small images
    small_images = [i for i, img in enumerate(dataset_images) if min(img.shape) < 64]
    if small_images:
        print(f"   Warning: {len(small_images)} image(s) very small (< 64px)")
    else:
        print(f"   Image sizes: OK")
    
    # Check pixel value ranges
    unusual_ranges = [i for i, img in enumerate(dataset_images) if img.max() > 255 or img.min() < 0]
    if unusual_ranges:
        print(f"   Warning: {len(unusual_ranges)} image(s) with unusual pixel ranges")
    else:
        print(f"   Pixel ranges: OK")
    
    # Summary visualization
    if len(dataset_images) > 1:
        fig, axes = plt.subplots(1, 3, figsize=(18, 5))
        
        # Image size distribution
        image_sizes = [img.shape[0] * img.shape[1] for img in dataset_images]
        axes[0].hist(image_sizes, bins=min(10, len(set(image_sizes))), alpha=0.7, color='skyblue', edgecolor='black')
        axes[0].set_xlabel('Number of Pixels')
        axes[0].set_ylabel('Number of Images')
        axes[0].set_title('Image Size Distribution')
        axes[0].grid(True, alpha=0.3)
        
        # Pixel intensity distribution
        mean_intensities = [img.mean() for img in dataset_images]
        axes[1].hist(mean_intensities, bins=15, alpha=0.7, color='lightcoral', edgecolor='black')
        axes[1].set_xlabel('Mean Pixel Intensity')
        axes[1].set_ylabel('Number of Images')
        axes[1].set_title('Mean Intensity Distribution')
        axes[1].grid(True, alpha=0.3)
        
        # Contrast distribution (std dev)
        contrasts = [img.std() for img in dataset_images]
        axes[2].hist(contrasts, bins=15, alpha=0.7, color='lightgreen', edgecolor='black')
        axes[2].set_xlabel('Pixel Standard Deviation')
        axes[2].set_ylabel('Number of Images')
        axes[2].set_title('Contrast Distribution')
        axes[2].grid(True, alpha=0.3)
        
        plt.suptitle('Dataset Quality Analysis', fontsize=14, fontweight='bold')
        plt.tight_layout()
        plt.show()
    
    print("\nDataset validated and ready for analysis!")
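
One check not covered above concerns filenames. `upload_dataset` keeps only the base name of each file extracted from a ZIP, so two images from different subfolders can end up with the same entry in `dataset_filenames`. Here is a small sketch for flagging such collisions; `find_duplicate_filenames` is a hypothetical helper and the example names are made up.

```python
from collections import Counter


def find_duplicate_filenames(filenames):
    """Return a dict mapping each repeated filename to its list of indices."""
    counts = Counter(filenames)
    duplicates = {}
    for index, name in enumerate(filenames):
        if counts[name] > 1:
            duplicates.setdefault(name, []).append(index)
    return duplicates


example = ['patient_a.png', 'scan.png', 'scan.png', 'patient_b.png']
print(find_duplicate_filenames(example))
```

Calling it on `dataset_filenames` returns an empty dict when every name is unique.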

## Batch Preprocessing

In [None]:
def preprocess_dataset_batch(images, target_size=(224, 224)):
    """
    Batch preprocessing of all images in the dataset
    """
    print("BATCH PREPROCESSING")
    print("=" * 30)
    
    processed_images = []
    processed_tensors = []
    preprocessing_stats = []
    
    print(f"Target size: {target_size}")
    print(f"Images to process: {len(images)}")
    print("")
    
    for i, img in enumerate(tqdm(images, desc="Processing images")):
        try:
            # Store original info
            original_shape = img.shape
            original_range = (img.min(), img.max())
            
            # Resize if necessary
            if img.shape != target_size:
                img_resized = cv2.resize(img, target_size)
            else:
                img_resized = img.copy()
            
            # Normalize to [0, 1]
            if img_resized.max() > 1:
                img_normalized = img_resized.astype(np.float32) / 255.0
            else:
                img_normalized = img_resized.astype(np.float32)
            
            # Scale to [-1024, 1024], the input range TorchXRayVision models expect
            # (same linear mapping as xrv.datasets.normalize for data in [0, 1])
            mean_val = np.mean(img_normalized)
            std_val = np.std(img_normalized)
            img_standardized = (2.0 * img_normalized - 1.0) * 1024.0
            
            # Convert to PyTorch tensor
            img_tensor = torch.FloatTensor(img_standardized)
            img_tensor = img_tensor.unsqueeze(0).unsqueeze(0)  # [1, 1, H, W]
            img_tensor = img_tensor.to(device)
            
            # Store results
            processed_images.append(img_normalized)
            processed_tensors.append(img_tensor)
            
            # Store stats
            stats = {
                'original_shape': original_shape,
                'original_range': original_range,
                'processed_shape': img_normalized.shape,
                'mean': mean_val,
                'std': std_val,
                'final_range': (img_standardized.min(), img_standardized.max())
            }
            preprocessing_stats.append(stats)
            
        except Exception as e:
            print(f"Error processing image {i}: {e}")
            continue
    
    print(f"\n{len(processed_images)}/{len(images)} images processed successfully")
    
    # Display preprocessing statistics
    if preprocessing_stats:
        stats_df = pd.DataFrame(preprocessing_stats)
        
        print(f"\nPREPROCESSING STATISTICS:")
        print(f"   Final shape: {target_size}")
        print(f"   Average mean: {stats_df['mean'].mean():.3f}")
        print(f"   Average std: {stats_df['std'].mean():.3f}")
        print(f"   Standardized range: [{stats_df['final_range'].apply(lambda x: x[0]).mean():.3f}, {stats_df['final_range'].apply(lambda x: x[1]).mean():.3f}]")
    
    return processed_images, processed_tensors, preprocessing_stats

# Apply preprocessing if we have images
if dataset_images:
    processed_images, processed_tensors, prep_stats = preprocess_dataset_batch(dataset_images)
    print("\nDataset ready for TorchXRayVision analysis!")
else:
    print("No images to preprocess")
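
TorchXRayVision's pretrained models are documented as expecting pixel values scaled to roughly [-1024, 1024], a mapping the library provides through `xrv.datasets.normalize`. The sketch below reproduces that linear mapping in plain NumPy so the arithmetic is visible. `scale_to_xrv_range` is a hypothetical helper name, and the library's own implementation may differ in details.

```python
import numpy as np


def scale_to_xrv_range(img, maxval):
    """Map intensities from [0, maxval] to [-1024, 1024] as float32."""
    img = np.asarray(img, dtype=np.float32)
    if img.max() > maxval:
        raise ValueError(f"max value {img.max()} exceeds maxval {maxval}")
    # [0, maxval] -> [0, 1] -> [-1, 1] -> [-1024, 1024]
    return (2.0 * (img / maxval) - 1.0) * 1024.0


pixels = np.array([[0, 255], [51, 204]], dtype=np.uint8)
scaled = scale_to_xrv_range(pixels, 255)
print(scaled.min(), scaled.max())
```

For 8-bit images `maxval` is 255; a 16-bit image would use 65535 instead.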

## Batch Analysis

In [None]:
def analyze_dataset_batch(tensors, filenames, models_dict, threshold=0.3):
    """
    Batch analysis of entire dataset with available models
    """
    print("BATCH DATASET ANALYSIS")
    print("=" * 30)
    
    print(f"Images to analyze: {len(tensors)}")
    print(f"Models available: {list(models_dict.keys())}")
    print(f"Detection threshold: {threshold}")
    print("")
    
    all_results = {}
    analysis_summary = []
    
    # Analyze with each model
    for model_name, model in models_dict.items():
        print(f"Analyzing with {model_name} model...")
        
        model_results = []
        pathologies = model.pathologies
        
        # Analyze each image
        for i, (tensor, filename) in enumerate(tqdm(zip(tensors, filenames), 
                                                   desc=f"Analysis {model_name}",
                                                   total=len(tensors))):
            try:
                with torch.no_grad():
                    outputs = model(tensor)
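                    # Pretrained TorchXRayVision models already return calibrated
                    # scores in [0, 1]; a second sigmoid would compress them
                    probabilities = outputs.cpu().numpy().flatten()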
                
                # Create result for this image
                image_result = {
                    'image_id': i,
                    'filename': filename,
                    'model': model_name
                }
                
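                # Keep only outputs with a named pathology (some weights label
                # untrained outputs with an empty string); cast to float for export
                named = [(p, float(prob)) for p, prob in zip(pathologies, probabilities) if p]
                
                # Add probabilities for each pathology
                for pathology, prob in named:
                    image_result[pathology] = prob
                    image_result[f'{pathology}_detected'] = prob > threshold
                
                # Summary statistics
                named_probs = [prob for _, prob in named]
                image_result['total_detections'] = sum(prob > threshold for prob in named_probs)
                image_result['max_probability'] = max(named_probs)
                image_result['avg_probability'] = float(np.mean(named_probs))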
                
                model_results.append(image_result)
                
            except Exception as e:
                print(f"Error with {filename} on {model_name}: {e}")
                continue
        
        all_results[model_name] = model_results
        
        # Model summary
        if model_results:
            total_detections = sum(result['total_detections'] for result in model_results)
            images_with_findings = sum(1 for result in model_results if result['total_detections'] > 0)
            
            summary = {
                'model': model_name,
                'images_analyzed': len(model_results),
                'total_detections': total_detections,
                'avg_detections_per_image': total_detections / len(model_results),
                'images_with_findings': images_with_findings,
                'percentage_with_findings': (images_with_findings / len(model_results)) * 100
            }
            analysis_summary.append(summary)
    
    print(f"\nAnalysis completed for {len(tensors)} images")
    
    # Display summary
    if analysis_summary:
        summary_df = pd.DataFrame(analysis_summary)
        
        print("\nANALYSIS SUMMARY:")
        print(summary_df.to_string(index=False))
        
        # Summary visualization
        if len(analysis_summary) > 1:
            fig, axes = plt.subplots(1, 2, figsize=(15, 6))
            
            # Total detections by model
            axes[0].bar(summary_df['model'], summary_df['total_detections'], alpha=0.7)
            axes[0].set_title('Total Detections by Model')
            axes[0].set_ylabel('Number of Detections')
            axes[0].tick_params(axis='x', rotation=45)
            axes[0].grid(True, alpha=0.3)
            
            # Percentage of images with findings
            axes[1].bar(summary_df['model'], summary_df['percentage_with_findings'], 
                       alpha=0.7, color='orange')
            axes[1].set_title('Percentage of Images with Findings')
            axes[1].set_ylabel('Percentage (%)')
            axes[1].tick_params(axis='x', rotation=45)
            axes[1].grid(True, alpha=0.3)
            
            plt.suptitle('Model Comparison Summary')
            plt.tight_layout()
            plt.show()
    
    return all_results, analysis_summary

# Perform analysis if we have processed tensors
if 'processed_tensors' in locals() and processed_tensors:
    batch_results, batch_summary = analyze_dataset_batch(processed_tensors, dataset_filenames, models)
    print("\nBatch analysis completed successfully!")
else:
    print("No processed tensors available for analysis")
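
The per-image bookkeeping in `analyze_dataset_batch` comes down to comparing each score with the threshold and summarizing the result. The sketch below isolates that step in plain Python, assuming the scores are already available as a mapping from label to probability. `summarize_detections` and the example scores are hypothetical.

```python
def summarize_detections(probabilities, threshold=0.3):
    """Return detected labels and summary statistics for one image."""
    # A label counts as detected only when its score is strictly above the threshold
    detected = [label for label, prob in probabilities.items() if prob > threshold]
    values = list(probabilities.values())
    return {
        'detected': detected,
        'total_detections': len(detected),
        'max_probability': max(values),
        'avg_probability': sum(values) / len(values),
    }


# Hypothetical scores, for illustration only
scores = {'Cardiomegaly': 0.62, 'Pneumonia': 0.30, 'Effusion': 0.11}
summary = summarize_detections(scores)
print(summary['detected'], summary['total_detections'])
```

Because the comparison is strict, a score exactly equal to the threshold (here `Pneumonia` at 0.30) is not flagged. Lowering the threshold raises the number of detections and, with it, the share of images reported as having findings.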

## Detailed Results Analysis

In [None]:
def create_detailed_analysis(batch_results, filenames, threshold=0.3):
    """
    Detailed analysis with comprehensive visualizations
    """
    print("DETAILED RESULTS ANALYSIS")
    print("=" * 30)
    
    if not batch_results:
        print("No results available for analysis")
        return [], []  # Keep the two-value return shape expected by callers
    
    # Use first model for main analysis
    main_model = list(batch_results.keys())[0]
    main_results = batch_results[main_model]
    
    print(f"Main analysis based on: {main_model}")
    print(f"Images analyzed: {len(main_results)}")
    
    # Extract pathologies
    pathologies = [col for col in main_results[0].keys() 
                  if col not in ['image_id', 'filename', 'model', 'total_detections', 
                               'max_probability', 'avg_probability'] 
                  and not col.endswith('_detected')]
    
    # Pathology frequency analysis
    pathology_counts = {}
    pathology_avg_probs = {}
    
    for pathology in pathologies:
        detections = sum(1 for result in main_results if result.get(f'{pathology}_detected', False))
        avg_prob = np.mean([result.get(pathology, 0) for result in main_results])
        
        pathology_counts[pathology] = detections
        pathology_avg_probs[pathology] = avg_prob
    
    # Sort by frequency
    sorted_pathologies = sorted(pathology_counts.items(), key=lambda x: x[1], reverse=True)
    
    print(f"\nTOP 10 DETECTED PATHOLOGIES:")
    for i, (pathology, count) in enumerate(sorted_pathologies[:10]):
        percentage = (count / len(main_results)) * 100
        avg_prob = pathology_avg_probs[pathology]
        print(f"   {i+1:2d}. {pathology:25s}: {count:3d}/{len(main_results)} ({percentage:5.1f}%) - Avg prob: {avg_prob:.3f}")
    
    # Images with most findings
    print(f"\nIMAGES WITH MOST PATHOLOGIES:")
    sorted_images = sorted(main_results, key=lambda x: x['total_detections'], reverse=True)
    
    for i, result in enumerate(sorted_images[:5]):
        detected_paths = [path for path in pathologies 
                         if result.get(f'{path}_detected', False)]
        print(f"   {i+1}. {result['filename']:30s}: {result['total_detections']} detections")
        if detected_paths:
            print(f"      → {', '.join(detected_paths[:3])}{'...' if len(detected_paths) > 3 else ''}")
    
    # Comprehensive visualization
    create_analysis_visualizations(batch_results, pathologies, threshold)
    
    return sorted_pathologies, sorted_images

def create_analysis_visualizations(batch_results, pathologies, threshold):
    """
    Create comprehensive visualizations for the analysis
    """
    main_model = list(batch_results.keys())[0]
    results = batch_results[main_model]
    
    # Main analysis dashboard
    fig, axes = plt.subplots(2, 3, figsize=(20, 12))
    
    # 1. Detection frequency by pathology
    top_pathologies = sorted(pathologies, 
                           key=lambda p: sum(1 for r in results if r.get(f'{p}_detected', False)), 
                           reverse=True)[:10]
    
    counts = [sum(1 for r in results if r.get(f'{path}_detected', False)) for path in top_pathologies]
    
    axes[0, 0].barh(range(len(top_pathologies)), counts, alpha=0.7)
    axes[0, 0].set_yticks(range(len(top_pathologies)))
    axes[0, 0].set_yticklabels([p[:15] for p in top_pathologies])
    axes[0, 0].set_xlabel('Number of Detections')
    axes[0, 0].set_title('Top 10 Most Frequent Pathologies')
    axes[0, 0].grid(True, alpha=0.3)
    
    # 2. Distribution of detections per image
    detections_per_image = [result['total_detections'] for result in results]
    axes[0, 1].hist(detections_per_image, bins=range(max(detections_per_image)+2), 
                   alpha=0.7, color='skyblue', edgecolor='black')
    axes[0, 1].set_xlabel('Number of Detections')
    axes[0, 1].set_ylabel('Number of Images')
    axes[0, 1].set_title('Distribution of Detections per Image')
    axes[0, 1].grid(True, alpha=0.3)
    
    # 3. Average probability distribution
    avg_probs = [result['avg_probability'] for result in results]
    axes[0, 2].hist(avg_probs, bins=20, alpha=0.7, color='lightgreen', edgecolor='black')
    axes[0, 2].axvline(x=threshold, color='red', linestyle='--', linewidth=2, label=f'Threshold ({threshold})')
    axes[0, 2].set_xlabel('Average Probability')
    axes[0, 2].set_ylabel('Number of Images')
    axes[0, 2].set_title('Average Probability Distribution')
    axes[0, 2].legend()
    axes[0, 2].grid(True, alpha=0.3)
    
    # 4. Severity categorization
    normal_count = sum(1 for r in results if r['total_detections'] == 0)
    mild_count = sum(1 for r in results if 1 <= r['total_detections'] <= 2)
    moderate_count = sum(1 for r in results if 3 <= r['total_detections'] <= 5)
    severe_count = sum(1 for r in results if r['total_detections'] > 5)
    
    categories = ['Normal\n(0)', 'Mild\n(1-2)', 'Moderate\n(3-5)', 'Severe\n(6+)']
    counts_sev = [normal_count, mild_count, moderate_count, severe_count]
    colors_sev = ['green', 'yellow', 'orange', 'red']
    
    axes[1, 0].pie(counts_sev, labels=categories, colors=colors_sev,
                  autopct='%1.1f%%', startangle=90)
    axes[1, 0].set_title('Severity Distribution')
    
    # 5. Correlation: detections vs max probability
    max_probs = [result['max_probability'] for result in results]
    axes[1, 1].scatter(detections_per_image, max_probs, alpha=0.6, s=50)
    axes[1, 1].set_xlabel('Number of Detections')
    axes[1, 1].set_ylabel('Maximum Probability')
    axes[1, 1].set_title('Detections vs Maximum Probability')
    axes[1, 1].grid(True, alpha=0.3)
    
    # 6. Pathology probability vs frequency scatter
    path_freq = []
    path_avg_prob = []
    
    for pathology in top_pathologies:
        freq = sum(1 for r in results if r.get(f'{pathology}_detected', False))
        avg_prob = np.mean([r.get(pathology, 0) for r in results])
        path_freq.append(freq)
        path_avg_prob.append(avg_prob)
    
    axes[1, 2].scatter(path_freq, path_avg_prob, s=100, alpha=0.7)
    axes[1, 2].set_xlabel('Detection Frequency')
    axes[1, 2].set_ylabel('Average Probability')
    axes[1, 2].set_title('Pathology: Frequency vs Probability')
    axes[1, 2].grid(True, alpha=0.3)
    
    plt.suptitle(f'COMPREHENSIVE DATASET ANALYSIS - {len(results)} Images', 
                fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()

# Perform detailed analysis if we have results
if 'batch_results' in locals() and batch_results:
    detailed_pathologies, detailed_images = create_detailed_analysis(batch_results, dataset_filenames)
else:
    print("No results available for detailed analysis")
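
The pie chart in `create_analysis_visualizations` groups images by how many pathologies were flagged: 0, 1-2, 3-5, or 6 and more. The sketch below expresses that grouping as a standalone function so it can be reused, for example to add a column to an exported table. `categorize_by_detection_count` is a hypothetical helper. The buckets reflect a count of flagged findings only, not a clinical severity grade.

```python
def categorize_by_detection_count(total_detections):
    """Map a detection count to the bucket names used in the dashboard."""
    if total_detections < 0:
        raise ValueError("total_detections must be non-negative")
    if total_detections == 0:
        return 'Normal'
    if total_detections <= 2:
        return 'Mild'
    if total_detections <= 5:
        return 'Moderate'
    return 'Severe'


counts = [0, 1, 2, 3, 5, 6]
print([categorize_by_detection_count(n) for n in counts])
```

Applied to the batch results, this could be written as `[categorize_by_detection_count(r['total_detections']) for r in results]`.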

## Export and Download Results

In [None]:
def export_analysis_results(batch_results, filenames, analysis_summary):
    """
    Export analysis results in multiple formats for clinical and research use
    """
    print("EXPORTING RESULTS")
    print("=" * 20)
    
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    exported_files = []
    
    # 1. Export detailed results as CSV
    print("Exporting detailed results (CSV)...")
    
    for model_name, results in batch_results.items():
        df = pd.DataFrame(results)
        csv_filename = f'results/analysis_{model_name}_{timestamp}.csv'
        df.to_csv(csv_filename, index=False)
        exported_files.append(csv_filename)
        print(f"   {csv_filename}")
    
    # 2. Export summary report
    print("\nExporting summary report (CSV)...")
    
    summary_df = pd.DataFrame(analysis_summary)
    summary_filename = f'results/summary_{timestamp}.csv'
    summary_df.to_csv(summary_filename, index=False)
    exported_files.append(summary_filename)
    print(f"   {summary_filename}")
    
    # 3. Export complete data as JSON
    print("\nExporting complete data (JSON)...")
    
    json_data = {
        'metadata': {
            'timestamp': timestamp,
            'total_images': len(filenames),
            'models_used': list(batch_results.keys()),
            'analysis_date': datetime.now().isoformat()
        },
        'results': batch_results,
        'summary': analysis_summary,
        'filenames': filenames
    }
    
    json_filename = f'results/complete_analysis_{timestamp}.json'
    with open(json_filename, 'w', encoding='utf-8') as f:
        json.dump(json_data, f, indent=2, ensure_ascii=False, default=str)
    exported_files.append(json_filename)
    print(f"   {json_filename}")
    
    # 4. Generate clinical report
    print("\nGenerating clinical report (TXT)...")
    
    report_filename = f'results/clinical_report_{timestamp}.txt'
    generate_clinical_report(batch_results, analysis_summary, report_filename, timestamp)
    exported_files.append(report_filename)
    print(f"   {report_filename}")
    
    # 5. Create download package
    print("\nCreating download package (ZIP)...")
    
    zip_filename = f'analysis_package_{timestamp}.zip'
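    # ZIP_DEFLATED compresses the archive; the default mode only stores files uncompressed
    with zipfile.ZipFile(zip_filename, 'w', compression=zipfile.ZIP_DEFLATED) as zipf: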
        for filepath in exported_files:
            zipf.write(filepath, os.path.basename(filepath))
    
    print(f"   {zip_filename}")
    
    # 6. Download interface
    print("\nDOWNLOAD OPTIONS:")
    print("1. Complete package (ZIP)")
    print("2. Summary report only (CSV)")
    print("3. Clinical report only (TXT)")
    print("4. Raw data (JSON)")
    
    choice = input("\nSelect download option (1-4) or 'all': ").strip().lower()
    
    # Map each menu option to the file it downloads
    download_targets = {
        '1': zip_filename,
        'all': zip_filename,
        '2': summary_filename,
        '3': report_filename,
        '4': json_filename
    }
    
    target = download_targets.get(choice)
    if target is None:
        print("No download selected")
    else:
        try:
            files.download(target)
            print(f"Downloaded: {target}")
        except Exception as e:
            print(f"Download error: {e}")
    
    return {
        'package': zip_filename,
        'summary': summary_filename,
        'report': report_filename,
        'data': json_filename
    }

def generate_clinical_report(batch_results, analysis_summary, filename, timestamp):
    """
    Generate a clinical-style report
    """
    with open(filename, 'w', encoding='utf-8') as f:
        f.write("=" * 80 + "\n")
        f.write("CLINICAL AI ANALYSIS REPORT\n")
        f.write("Custom Dataset Analysis with TorchXRayVision\n")
        f.write("=" * 80 + "\n")
        f.write(f"Analysis Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Report ID: {timestamp}\n")
        f.write("\n")
        
        # Executive summary
        f.write("EXECUTIVE SUMMARY\n")
        f.write("-" * 40 + "\n")
        
        # The first model in batch_results serves as the reference for the summary
        main_model = list(batch_results.keys())[0]
        main_results = batch_results[main_model]
        total_images = len(main_results)
        
        images_with_findings = sum(1 for result in main_results if result['total_detections'] > 0)
        normal_images = total_images - images_with_findings
        
        # Avoid a ZeroDivisionError when the result list is empty
        normal_pct = normal_images / total_images * 100 if total_images else 0.0
        findings_pct = images_with_findings / total_images * 100 if total_images else 0.0
        
        f.write(f"Total Images Analyzed: {total_images}\n")
        f.write(f"Normal Images: {normal_images} ({normal_pct:.1f}%)\n")
        f.write(f"Images with Findings: {images_with_findings} ({findings_pct:.1f}%)\n")
        f.write(f"AI Models Used: {', '.join(batch_results.keys())}\n")
        
        # Detailed findings
        f.write("\n\nDETAILED FINDINGS\n")
        f.write("-" * 40 + "\n")
        
        for summary in analysis_summary:
            f.write(f"\nModel: {summary['model']}\n")
            f.write(f"  - Images analyzed: {summary['images_analyzed']}\n")
            f.write(f"  - Total detections: {summary['total_detections']}\n")
            f.write(f"  - Average detections per image: {summary['avg_detections_per_image']:.2f}\n")
            f.write(f"  - Images with findings: {summary['images_with_findings']} ({summary['percentage_with_findings']:.1f}%)\n")
        
        # Priority cases
        high_severity = [r for r in main_results if r['total_detections'] > 5]
        if high_severity:
            f.write(f"\n\nPRIORITY CASES (6+ findings)\n")
            f.write("-" * 40 + "\n")
            for case in high_severity:
                f.write(f"  - {case['filename']}: {case['total_detections']} findings\n")
        
        # Recommendations
        f.write(f"\n\nRECOMMENDATIONS\n")
        f.write("-" * 40 + "\n")
        
        # Build the list first so numbering starts at 1 even without priority cases
        recommendations = []
        if high_severity:
            recommendations.append(f"Priority review recommended for {len(high_severity)} case(s) with multiple findings")
        recommendations.extend([
            "Clinical correlation recommended for all positive findings",
            "AI analysis serves as screening tool - final diagnosis requires radiologist review",
            "Consider follow-up imaging for unclear or borderline cases"
        ])
        for number, text in enumerate(recommendations, start=1):
            f.write(f"{number}. {text}\n")
        
        f.write("\n" + "=" * 80 + "\n")
        f.write("End of Report\n")
        f.write("Note: This AI analysis is for screening purposes only. \n")
        f.write("Clinical correlation and radiologist review are essential.\n")
        f.write("=" * 80 + "\n")

# Export results if available
if 'batch_results' in locals() and batch_results:
    export_info = export_analysis_results(batch_results, dataset_filenames, batch_summary)
    print("\nExport completed successfully!")
    print("\nGenerated files:")
    for key, filename in export_info.items():
        print(f"   {key}: {filename}")
else:
    print("No results available for export")
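
The JSON export above relies on `default=str`, which writes any value the `json` module cannot serialize as its string form. If the result dictionaries hold NumPy values such as `np.float32` scores or arrays (an assumption, since their exact contents depend on the earlier batch-analysis cells), those values would be saved as text rather than numbers. The sketch below shows one possible alternative fallback. The name `to_builtin` is hypothetical, and a standard-library `array` stands in for a NumPy array so the example runs without extra dependencies.

```python
import json
from array import array


def to_builtin(obj):
    """Fallback for json.dump: convert array-like objects to plain Python types."""
    # NumPy scalars, NumPy arrays and PyTorch tensors all expose tolist()
    if hasattr(obj, "tolist"):
        return obj.tolist()
    # Anything else (for example a datetime) falls back to its string form,
    # which is what default=str would have produced
    return str(obj)


# array.array is only a stand-in for a NumPy array in this sketch
record = {"filename": "case_001.png", "scores": array("d", [0.25, 0.75])}

as_text = json.dumps(record, default=str)            # scores saved as one string
as_numbers = json.dumps(record, default=to_builtin)  # scores saved as a list of floats

print(as_text)
print(as_numbers)
```

Passing `default=to_builtin` instead of `default=str` to `json.dump` keeps numeric fields numeric, which makes the file easier to reload into pandas or other tools.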

## Tutorial Summary

### What You Have Accomplished:

**Data Management:**
- Loaded custom medical images efficiently
- Processed datasets in batch for optimal performance
- Validated data quality and consistency

**AI Analysis:**
- Applied pretrained TorchXRayVision models to your own images
- Performed multi-model comparison analysis
- Generated comprehensive pathology assessments

**Clinical Integration:**
- Created professional clinical reports
- Exported results in multiple formats
- Established workflows for research and practice

### Skills Acquired:

- **Dataset Integration**: Efficient loading and organization of custom medical datasets
- **Batch Processing**: Automated analysis of multiple images simultaneously
- **Model Application**: Practical use of multiple AI models for comprehensive analysis
- **Results Interpretation**: Understanding and visualizing AI predictions
- **Clinical Reporting**: Generating professional reports for medical use
- **Data Export**: Preparing results for integration with clinical workflows

### Practical Applications:

**Medical Research:**
- Analyze research cohorts and clinical studies
- Validate AI models on local populations
- Generate data for publication and presentation

**Clinical Practice:**
- Screen large volumes of radiological images
- Assist in diagnostic workflow optimization
- Support quality assurance programs

**Educational Use:**
- Create teaching datasets for medical education
- Demonstrate AI capabilities to students
- Support research projects and theses
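
For the local validation use listed under Medical Research, a first check is to compare the model's per-image flags with reference labels from your own institution. The sketch below is a minimal illustration under stated assumptions: the result entries carry the `filename` and `total_detections` keys used in the export code, an image counts as flagged when it has at least one detection, and `reference_labels` is a hypothetical dictionary of radiologist-confirmed labels that you would supply yourself.

```python
def sensitivity_specificity(results, reference_labels):
    """Compare per-image AI flags with reference labels.

    results: list of dicts with 'filename' and 'total_detections'
    reference_labels: dict mapping filename -> True if an abnormality is confirmed
    Returns (sensitivity, specificity); a value is None when it is undefined.
    """
    tp = fp = tn = fn = 0
    for result in results:
        name = result["filename"]
        if name not in reference_labels:
            continue  # skip images without a reference label
        flagged = result["total_detections"] > 0
        abnormal = reference_labels[name]
        if flagged and abnormal:
            tp += 1
        elif flagged and not abnormal:
            fp += 1
        elif not flagged and abnormal:
            fn += 1
        else:
            tn += 1
    sensitivity = tp / (tp + fn) if (tp + fn) else None
    specificity = tn / (tn + fp) if (tn + fp) else None
    return sensitivity, specificity


# Illustrative data only, not real results
example_results = [
    {"filename": "a.png", "total_detections": 2},
    {"filename": "b.png", "total_detections": 0},
    {"filename": "c.png", "total_detections": 1},
    {"filename": "d.png", "total_detections": 0},
]
reference_labels = {"a.png": True, "b.png": False, "c.png": False, "d.png": True}

sens, spec = sensitivity_specificity(example_results, reference_labels)
print(f"Sensitivity: {sens:.2f}, Specificity: {spec:.2f}")
```

In practice you would pass one model's list from `batch_results` instead of `example_results`. Treating any detection as a positive is a coarse simplification, and a per-pathology comparison is usually more informative.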

### Important Reminders:

- **Data Privacy**: Always ensure patient confidentiality and data security
- **Clinical Validation**: AI results require expert medical review
- **Limitations**: Understand model constraints and appropriate use cases
- **Documentation**: Maintain detailed records of methods and findings
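
As one concrete step toward the data privacy reminder, filenames that embed patient identifiers can be replaced with keyed hashes before results are exported or shared. This is a sketch only: `pseudonymize_filename`, the example filenames, and the key are hypothetical, and renaming files does not remove identifiers stored in image metadata or burned into the pixels.

```python
import hashlib
import hmac
import os


def pseudonymize_filename(filename, secret_key):
    """Replace a filename with a keyed hash, keeping only the extension."""
    stem, extension = os.path.splitext(os.path.basename(filename))
    # HMAC-SHA256 with a private key, so the mapping cannot be rebuilt
    # by simply hashing guessed names
    digest = hmac.new(secret_key, stem.encode("utf-8"), hashlib.sha256).hexdigest()
    return f"img_{digest[:16]}{extension.lower()}"


# Hypothetical key: store the real one outside the notebook
secret_key = b"replace-with-a-private-key"

# Illustrative filenames only
original_names = [
    "DOE_JOHN_1975-03-02_chest.png",
    "SMITH_ANNA_1980-11-23_chest.png",
]
name_map = {name: pseudonymize_filename(name, secret_key) for name in original_names}

for original, pseudonym in name_map.items():
    print(f"{original} -> {pseudonym}")
```

The same key always yields the same pseudonym, so results from repeated analyses stay linkable. Keep `name_map` in a secure location if you need to trace a result back to the original image.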

### Next Steps:

1. **Practice**: Test the workflow with different types of medical images
2. **Integrate**: Incorporate into your research or clinical practice
3. **Collaborate**: Work with radiologists and clinicians for validation
4. **Advance**: Explore additional AI models and techniques

You can now load a custom dataset, analyze it in batch with TorchXRayVision models, and export the results in formats suited to research and clinical review.