# ImgAE-Dx T4 GPU Optimized Training

This notebook implements T4-optimized training for medical image anomaly detection.

## Features:
- Tesla T4 GPU optimization (16GB VRAM)
- Kaggle dataset integration, HuggingFace streaming, or synthetic data generation
- Mixed precision training (AMP)
- Google Drive checkpointing
- U-Net and Reversed Autoencoder models

## 1. Set Up the Colab Environment

In [None]:
# Section 1: report the attached GPU (if any) and mount Google Drive, where
# CONFIG['checkpoint_dir'] stores checkpoints later in the notebook.
import torch
import subprocess  # NOTE(review): not used anywhere in this notebook as shown
import os

# Check GPU
if torch.cuda.is_available():
    # IPython shell capture: gpu_info holds the output lines of nvidia-smi; with
    # csv,noheader the first line is "<name>, <total memory>".
    gpu_info = !nvidia-smi --query-gpu=name,memory.total --format=csv,noheader
    print(f"GPU: {gpu_info[0]}")
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"PyTorch Version: {torch.__version__}")
else:
    print("⚠️ No GPU detected! Please enable GPU in Runtime > Change runtime type")

# Mount Google Drive at /content/drive (force_remount=True remounts even if a
# previous mount exists).
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
print("✅ Google Drive mounted")

## 2. Install ImgAE-Dx

In [None]:
# Section 2: fetch (or update) the ImgAE-Dx repository and install it plus the
# extra libraries this notebook uses. `!` runs a shell command and `%cd` changes
# the notebook's working directory (IPython magics).
# Clone repository if not exists
if not os.path.exists('/content/ImgAE-Dx'):
    !git clone https://github.com/luanbhk/ImgAE-Dx.git /content/ImgAE-Dx
    %cd /content/ImgAE-Dx
else:
    # Existing checkout: move into it and pull the latest changes.
    %cd /content/ImgAE-Dx
    !git pull

# Install dependencies: the repository itself in editable mode (-e, so source
# changes are used in place), then the HuggingFace libraries and W&B.
!pip install -e .
!pip install datasets transformers accelerate
!pip install wandb --upgrade

print("✅ ImgAE-Dx installed")

## 3. Configuration

In [None]:
# T4-optimized configuration.
# Credentials are read from environment variables (empty/None when unset) so they
# never have to be pasted into, and saved with, the notebook.
CONFIG = {
    # Model settings
    'model_type': 'unet',  # 'unet', 'reversed_ae', or 'both'

    # Dataset settings
    #   'local'       -> generate synthetic chest X-ray-like images (demo only)
    #   'kaggle'      -> use images (and an optional metadata CSV) already under /content/data
    #   'huggingface' -> stream images from the HuggingFace Hub (section 6)
    'dataset_source': 'local',
    # NOTE: this notebook does not download the Kaggle dataset; place its files under /content/data.
    'kaggle_dataset': 'nih-chest-xrays/data',  # NIH Chest X-ray dataset
    'kaggle_username': os.environ.get('KAGGLE_USERNAME', ''),  # Your Kaggle username
    'kaggle_key': os.environ.get('KAGGLE_KEY', ''),  # Your Kaggle API key (keep secret)
    'hf_dataset': None,  # HuggingFace dataset id; required when dataset_source == 'huggingface'
    'hf_split': 'train',
    'hf_token': os.environ.get('HF_TOKEN'),  # Only needed for gated/private datasets

    # Training settings (T4-optimized)
    'samples': 3000,
    'epochs': 20,
    'batch_size': 48,  # T4-optimized with mixed precision
    'learning_rate': 1e-4,
    'image_size': 128,

    # T4 optimizations
    'mixed_precision': True,
    'memory_limit': 14,  # GB (leave 2GB for system)
    'gradient_accumulation_steps': 1,
    'num_workers': 4,

    # Checkpointing
    'checkpoint_dir': '/content/drive/MyDrive/imgae_dx_checkpoints',
    'save_frequency': 5,  # Save every N epochs
    'resume_from_checkpoint': None,  # Path to checkpoint to resume from

    # Logging
    'use_wandb': True,
    'wandb_project': 'imgae-dx-t4-colab',
    'wandb_run_name': None,  # Auto-generated if None
}

# Fail fast on typos instead of silently training nothing or using the wrong data.
_VALID_MODEL_TYPES = ('unet', 'reversed_ae', 'both')
_VALID_DATASET_SOURCES = ('local', 'kaggle', 'huggingface')
if CONFIG['model_type'] not in _VALID_MODEL_TYPES:
    raise ValueError(
        f"model_type must be one of {_VALID_MODEL_TYPES}, got {CONFIG['model_type']!r}"
    )
if CONFIG['dataset_source'] not in _VALID_DATASET_SOURCES:
    raise ValueError(
        f"dataset_source must be one of {_VALID_DATASET_SOURCES}, got {CONFIG['dataset_source']!r}"
    )

# Create directories
os.makedirs(CONFIG['checkpoint_dir'], exist_ok=True)
os.makedirs('/content/outputs/logs', exist_ok=True)
os.makedirs('/content/outputs/checkpoints', exist_ok=True)
os.makedirs('/content/data', exist_ok=True)

print("Configuration set!")
print(f"Model: {CONFIG['model_type']}")
print(f"Dataset source: {CONFIG['dataset_source']}")
print(f"Batch size: {CONFIG['batch_size']} (T4-optimized)")
print(f"Mixed precision: {CONFIG['mixed_precision']}")

## 4. Set Up Weights & Biases (Optional)

In [None]:
if CONFIG['use_wandb']:
    import wandb

    # Login to W&B (you may be prompted to paste your API key).
    wandb.login()

    # Initialize run
    run_name = CONFIG['wandb_run_name'] or f"{CONFIG['model_type']}_t4_{CONFIG['samples']}samples"
    # The run config is uploaded to W&B: drop credentials such as kaggle_key / hf_token.
    public_config = {
        key: value
        for key, value in CONFIG.items()
        if not key.endswith(('_key', '_token'))
    }
    wandb.init(
        project=CONFIG['wandb_project'],
        name=run_name,
        config=public_config
    )
    print(f"✅ W&B initialized: {run_name}")
else:
    print("W&B logging disabled")

## 5. Load Kaggle or Synthetic Dataset

In [None]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import os
import glob
from pathlib import Path

# Preprocessing applied to every image: square resize to CONFIG['image_size'],
# collapse to one channel, convert to a [0, 1] tensor, then rescale to [-1, 1]
# via (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((CONFIG['image_size'],) * 2),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# Custom dataset for chest X-rays
class ChestXrayDataset(Dataset):
    """Map-style dataset that loads images from a list of file paths.

    Item ``i`` is ``image_paths[i]`` opened with PIL and converted to RGB, then
    passed through ``transform`` when one is given (otherwise the PIL image is
    returned as-is).

    Args:
        image_paths: Sequence of image file paths.
        transform: Optional callable applied to each PIL image.
    """

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        # Open in a context manager so the file handle is released as soon as the
        # pixels are decoded; convert() returns an independent in-memory image.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image

# Load dataset based on source ('kaggle' files on disk, or synthetic demo images).
import itertools

import numpy as np
from scipy import ndimage

_IMAGE_SUFFIXES = {'.png', '.jpg', '.jpeg'}


def _list_files(root):
    """Return every file under ``root`` (recursively), sorted for a deterministic order."""
    return sorted(path for path in Path(root).rglob('*') if path.is_file())


def _pick_metadata_csv(data_dir):
    """Return the top-level CSV to use as metadata, or None when there is none.

    Prefers a CSV that has a 'Finding Labels' column. Glob order is arbitrary, and
    an unlabeled CSV would otherwise make every listed image count as "normal".
    """
    csv_files = sorted(glob.glob(f"{data_dir}/*.csv"))
    for csv_path in csv_files:
        if 'Finding Labels' in pd.read_csv(csv_path, nrows=0).columns:
            return csv_path
    return csv_files[0] if csv_files else None


def _kaggle_image_paths(data_dir, limit):
    """Collect up to ``limit`` image paths under ``data_dir`` for the 'kaggle' source."""
    files = _list_files(data_dir)
    csv_path = _pick_metadata_csv(data_dir)

    if csv_path is None:
        # No CSV, directly load images
        print("No CSV metadata found, loading images directly")
        return [str(p) for p in files if p.suffix.lower() in _IMAGE_SUFFIXES][:limit]

    df = pd.read_csv(csv_path)
    print(f"Loaded metadata from: {csv_path}")
    print(f"Total images in metadata: {len(df)}")

    # Filter for normal images (No Finding)
    if 'Finding Labels' in df.columns:
        normal_df = df[df['Finding Labels'] == 'No Finding']
    else:
        # If no labels, use all images
        normal_df = df
        print("Warning: No labels found, using all images")

    if 'Image Index' in normal_df.columns:
        image_names = normal_df['Image Index']
    else:
        image_names = normal_df.iloc[:, 0]  # First column

    # Index the tree once (file name -> first path in sorted order) instead of
    # running a recursive glob for every CSV row.
    by_name = {}
    for path in files:
        by_name.setdefault(path.name, str(path))

    # Keep scanning past rows whose file is missing until `limit` images are found.
    found = (by_name[name] for name in map(str, image_names) if name in by_name)
    return list(itertools.islice(found, limit))


def _synthetic_image_paths(out_dir, count, size=256, seed=0):
    """Write ``count`` synthetic chest X-ray-like PNGs into ``out_dir``; return their paths."""
    os.makedirs(out_dir, exist_ok=True)

    # The "anatomy" is the same for every image, so build it once.
    axis = np.linspace(-1, 1, size)
    X, Y = np.meshgrid(axis, axis)
    # Simulate lung fields (two dark regions)
    left_lung = np.exp(-((X + 0.3) ** 2 + Y ** 2) / 0.3)
    right_lung = np.exp(-((X - 0.3) ** 2 + Y ** 2) / 0.3)
    lungs = 1 - 0.7 * (left_lung + right_lung)
    ribs = 0.1 * np.sin(10 * Y) * np.exp(-X ** 2)
    spine = 0.2 * np.exp(-10 * X ** 2)
    base = np.clip(lungs + ribs + spine, 0, 1)

    # Seeded generator so reruns produce the same images.
    rng = np.random.default_rng(seed)
    paths = []
    for i in range(count):
        noisy = np.clip(base + rng.normal(0, 0.02, base.shape), 0, 1)
        smooth = ndimage.gaussian_filter(noisy, sigma=1)
        pixels = (smooth * 255).astype(np.uint8)

        img_path = f"{out_dir}/synthetic_chest_{i:04d}.png"
        # A 2-D uint8 array is stored as 8-bit grayscale ('L'); the explicit
        # `mode=` argument is deprecated in recent Pillow releases.
        Image.fromarray(pixels).save(img_path)
        paths.append(img_path)
    return paths


if CONFIG['dataset_source'] == 'kaggle':
    image_paths = _kaggle_image_paths('/content/data', CONFIG['samples'])
    print(f"Found {len(image_paths)} images")
elif CONFIG['dataset_source'] == 'huggingface':
    image_paths = None  # streamed instead of loaded from disk
else:
    # For demo: create synthetic data
    print("Creating synthetic chest X-ray data for demo...")
    image_paths = _synthetic_image_paths(
        '/content/data/synthetic',
        min(CONFIG['samples'], 100),  # Limit to 100 for demo
        seed=CONFIG.get('seed', 0),
    )
    print(f"✅ Generated {len(image_paths)} synthetic chest X-ray images")

# Create dataset and dataloader
if image_paths is None:
    print("dataset_source is 'huggingface': the DataLoader is created in section 6.")
elif image_paths:
    train_dataset = ChestXrayDataset(image_paths, transform=transform)
    train_loader = DataLoader(
        train_dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=True,
        num_workers=CONFIG['num_workers'],
        # Pinned host memory only helps (and only avoids a warning) with a GPU.
        pin_memory=torch.cuda.is_available()
    )
    print(f"✅ DataLoader created with {len(train_dataset)} images")
    print(f"Batch size: {CONFIG['batch_size']}")
else:
    # Stop here: every later cell needs train_loader.
    raise RuntimeError("❌ No images found! Please check your dataset configuration.")

## 6. Load HuggingFace Dataset (Streaming)

In [None]:
from datasets import load_dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from PIL import Image
import numpy as np

# Custom dataset wrapper: adapts a (streaming) HuggingFace dataset to a torch IterableDataset.
class HFImageDataset(torch.utils.data.IterableDataset):
    """Yield transformed image tensors from an iterable of dict-like records.

    Args:
        hf_dataset: Iterable of records, e.g. a streaming ``datasets`` split.
        transform: Callable applied to each PIL image.
        samples: Optional cap on the number of images yielded per iteration.
    """

    # Record keys tried first, in order, before searching for any image-like key.
    _PREFERRED_KEYS = ('image', 'img')

    def __init__(self, hf_dataset, transform, samples=None):
        self.dataset = hf_dataset
        self.transform = transform
        self.samples = samples

    def _extract_image(self, item):
        """Return the image stored in ``item``, or None when it has no usable image field."""
        for key in self._PREFERRED_KEYS:
            if key in item:
                return item[key]
        for key in item.keys():
            lowered = key.lower()
            if 'image' in lowered or 'img' in lowered:
                return item[key]
        return None

    def __iter__(self):
        count = 0
        for item in self.dataset:
            if self.samples and count >= self.samples:
                break

            image = self._extract_image(item)
            if image is None:
                continue  # record without an image; skip it

            # Convert to PIL if needed
            if not isinstance(image, Image.Image):
                image = Image.fromarray(np.array(image))

            yield self.transform(image)
            count += 1


if CONFIG.get('dataset_source') != 'huggingface':
    # Only runs for the HuggingFace source; otherwise keep section 5's DataLoader.
    print(f"Skipping HuggingFace streaming (dataset_source={CONFIG.get('dataset_source')!r}); "
          "keeping the DataLoader from section 5.")
else:
    hf_name = CONFIG.get('hf_dataset')
    if not hf_name:
        raise ValueError(
            "Set CONFIG['hf_dataset'] to a HuggingFace dataset id to use dataset_source='huggingface'."
        )
    print(f"Loading dataset: {hf_name}...")

    # Authentication if needed; `token` replaces the removed `use_auth_token` argument.
    hf_token = CONFIG.get('hf_token')
    auth_kwargs = {'token': hf_token} if hf_token else {}

    # streaming=True iterates records lazily instead of materializing the split.
    dataset = load_dataset(
        hf_name,
        split=CONFIG.get('hf_split', 'train'),
        streaming=True,
        **auth_kwargs
    )

    # Take only the specified number of samples
    dataset = dataset.take(CONFIG['samples'])
    print("✅ Dataset loaded (streaming mode)")

    # Same preprocessing as section 5: resize, grayscale, tensor, scale to [-1, 1].
    transform = transforms.Compose([
        transforms.Resize((CONFIG['image_size'], CONFIG['image_size'])),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])

    train_dataset = HFImageDataset(dataset, transform, CONFIG['samples'])

    # NOTE(review): with num_workers > 1, splitting the stream between workers is
    # left to the `datasets` iterator; confirm workers see disjoint shards (a
    # single-shard dataset may leave some workers idle).
    train_loader = DataLoader(
        train_dataset,
        batch_size=CONFIG['batch_size'],
        num_workers=CONFIG['num_workers'],
        pin_memory=torch.cuda.is_available()
    )

    print(f"✅ DataLoader created with batch size {CONFIG['batch_size']}")

## 7. Initialize Models

In [None]:
from imgae_dx.models import UNet, ReversedAutoencoder
from imgae_dx.training import AnomalyDetectionTrainer
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# An unknown model_type would otherwise leave models_to_train empty and skip training silently.
if CONFIG['model_type'] not in ('unet', 'reversed_ae', 'both'):
    raise ValueError(
        f"Unknown model_type {CONFIG['model_type']!r}; expected 'unet', 'reversed_ae', or 'both'"
    )

# Model initialization based on config: list of (name, model) pairs to train in order.
models_to_train = []

if CONFIG['model_type'] in ['unet', 'both']:
    unet = UNet(
        in_channels=1,
        out_channels=1,
        features=[64, 128, 256, 512]
    ).to(device)
    models_to_train.append(('unet', unet))
    print("✅ U-Net initialized")

if CONFIG['model_type'] in ['reversed_ae', 'both']:
    reversed_ae = ReversedAutoencoder(
        in_channels=1,
        latent_dim=128,
        image_size=CONFIG['image_size']
    ).to(device)
    models_to_train.append(('reversed_ae', reversed_ae))
    print("✅ Reversed Autoencoder initialized")

# Count parameters
for name, model in models_to_train:
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"{name}: {total_params:,} total / {trainable_params:,} trainable parameters")

## 8. T4-Optimized Training

In [None]:
import time
from torch.cuda.amp import GradScaler, autocast
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

# T4 optimizations
if torch.cuda.is_available():
    # Enable cuDNN benchmarking for T4
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    
    # Set memory fraction
    torch.cuda.set_per_process_memory_fraction(CONFIG['memory_limit'] / 16.0)
    
    print(f"✅ T4 optimizations enabled")
    print(f"Memory limit: {CONFIG['memory_limit']}GB")

def train_model(model_name, model, train_loader, config):
    """Train ``model`` to reconstruct its inputs and checkpoint it to ``config['checkpoint_dir']``.

    Each batch from ``train_loader`` is an image tensor; the loss is the MSE between
    the model output and that same batch. Uses the notebook-level ``device`` (and
    ``wandb`` when ``config['use_wandb']`` is true).

    Args:
        model_name: Prefix for checkpoint file names and W&B metric keys.
        model: Module already moved to ``device``.
        train_loader: Iterable of image batches.
        config: Reads learning_rate, epochs, mixed_precision,
            gradient_accumulation_steps, save_frequency, checkpoint_dir, use_wandb.

    Returns:
        ``(model, history)`` where ``history['train_loss']`` is the mean loss per epoch.

    Raises:
        ValueError: If ``train_loader`` yields no batches in an epoch.
    """
    print(f"\n🚀 Training {model_name}...")

    # Setup optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'])
    criterion = nn.MSELoss()

    # AMP only applies on CUDA. With enabled=False, GradScaler and autocast are
    # pass-throughs, so a single code path serves both precisions.
    use_amp = bool(config['mixed_precision']) and device.type == 'cuda'
    scaler = GradScaler(enabled=use_amp)
    accum_steps = max(1, int(config.get('gradient_accumulation_steps') or 1))
    save_every = config.get('save_frequency') or 0
    use_wandb = config.get('use_wandb', False)
    # Checkpoints go to Drive: never persist credentials (e.g. kaggle_key, hf_token).
    safe_config = {k: v for k, v in config.items() if not k.endswith(('_key', '_token'))}

    def apply_gradients():
        """Step the optimizer through the scaler and clear gradients."""
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

    # Training history
    history = {'train_loss': []}
    best_loss = float('inf')

    for epoch in range(config['epochs']):
        model.train()
        epoch_loss = 0.0
        batch_count = 0
        optimizer.zero_grad(set_to_none=True)

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}")
        for batch_idx, images in enumerate(pbar):
            images = images.to(device, non_blocking=True)

            with autocast(enabled=use_amp):
                reconstructed = model(images)
                loss = criterion(reconstructed, images)

            # Divide so gradients summed over accum_steps batches average like one big batch.
            scaler.scale(loss / accum_steps).backward()
            if (batch_idx + 1) % accum_steps == 0:
                apply_gradients()

            # Single host sync per batch.
            loss_value = loss.item()
            epoch_loss += loss_value
            batch_count += 1
            pbar.set_postfix({'loss': f'{loss_value:.4f}'})

            # Log to W&B (1-based epoch, matching the per-epoch log below)
            if use_wandb and batch_idx % 10 == 0:
                wandb.log({
                    f'{model_name}_batch_loss': loss_value,
                    'epoch': epoch + 1,
                    'batch': batch_idx
                })

        if batch_count == 0:
            raise ValueError(f"train_loader yielded no batches while training {model_name}")
        # Apply gradients left over from a final, partial accumulation window.
        if batch_count % accum_steps != 0:
            apply_gradients()

        avg_loss = epoch_loss / batch_count
        history['train_loss'].append(avg_loss)
        print(f"Epoch {epoch+1}: Loss = {avg_loss:.4f}")

        if use_wandb:
            wandb.log({
                f'{model_name}_epoch_loss': avg_loss,
                'epoch': epoch + 1
            })

        # Full checkpoint (model + optimizer) only every `save_every` epochs, so
        # early, rapidly-improving epochs do not each write one to Drive.
        if save_every and (epoch + 1) % save_every == 0:
            checkpoint_path = f"{config['checkpoint_dir']}/{model_name}_epoch_{epoch+1}.pth"
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss,
                'config': safe_config
            }, checkpoint_path)
            print(f"💾 Checkpoint saved: {checkpoint_path}")

        # Best weights are tracked separately, overwriting a single file.
        if avg_loss < best_loss:
            best_loss = avg_loss
            best_path = f"{config['checkpoint_dir']}/{model_name}_best.pth"
            torch.save(model.state_dict(), best_path)
            print(f"✅ Best model saved: {best_path}")

    return model, history

# Train each configured model in turn, keeping its weights, loss history and
# wall-clock duration.
all_histories = {}
trained_models = {}

for model_name, model in models_to_train:
    t0 = time.time()

    trained_models[model_name], all_histories[model_name] = train_model(
        model_name, model, train_loader, CONFIG
    )

    elapsed_min = (time.time() - t0) / 60
    print(f"\n✅ {model_name} training completed in {elapsed_min:.1f} minutes")

    # Hand cached allocator blocks back before the next model starts.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

## 9. Visualize Training Results

In [None]:
# Plot the per-epoch training loss of every trained model and save it with the checkpoints.
fig, ax = plt.subplots(figsize=(10, 6))

for model_name, history in all_histories.items():
    losses = history['train_loss']
    # Epochs are reported 1-based during training, so start the x-axis at 1.
    ax.plot(range(1, len(losses) + 1), losses, label=f'{model_name} train loss')

ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training Loss Comparison')
ax.legend()
ax.grid(True)

# Save BEFORE showing: plt.show() may close the figure (the notebook inline backend
# does), after which plt.savefig would write a new, empty figure.
fig.savefig(f'{CONFIG["checkpoint_dir"]}/training_curves.png', dpi=300, bbox_inches='tight')
plt.show()
print("✅ Training curves saved")

## 10. Test Reconstruction Quality

In [None]:
# Test reconstruction on a few samples
def visualize_reconstructions(models, data_loader, num_samples=5):
    """Plot the first images of one batch above each model's reconstruction of them.

    Row 0 shows the inputs and row ``k`` the output of the ``k``-th model in
    ``models``. The figure is saved to
    ``CONFIG["checkpoint_dir"]/reconstruction_comparison.png``.

    Args:
        models: Mapping of display name -> model on ``device``; each is put in eval mode.
        data_loader: Iterable whose first element is a batch tensor of images.
        num_samples: Number of columns; capped at the size of that batch.
    """
    batch = next(iter(data_loader))
    num_samples = min(num_samples, len(batch))
    sample_batch = batch[:num_samples].to(device)

    n_rows = len(models) + 1
    # squeeze=False keeps `axes` 2-D even for a single row or column.
    fig, axes = plt.subplots(n_rows, num_samples, figsize=(15, 3 * n_rows), squeeze=False)

    rows = [('Original', sample_batch)]
    with torch.no_grad():
        for model_name, model in models.items():
            model.eval()
            rows.append((model_name, model(sample_batch)))

    for row, (label, images) in enumerate(rows):
        for col in range(num_samples):
            ax = axes[row, col]
            ax.imshow(images[col].cpu().squeeze().numpy(), cmap='gray')
            # Hide ticks rather than the whole axis so the row label stays visible.
            ax.set_xticks([])
            ax.set_yticks([])
        axes[row, 0].set_ylabel(label, fontsize=12)

    fig.suptitle('Reconstruction Comparison', fontsize=16)
    fig.tight_layout()

    # Save before show(): showing may close the figure under the inline backend.
    fig.savefig(f'{CONFIG["checkpoint_dir"]}/reconstruction_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()
    print("✅ Reconstruction comparison saved")


# Visualize reconstructions
if trained_models:
    visualize_reconstructions(trained_models, train_loader)

## 11. Save Final Results

In [None]:
import json
from datetime import datetime

# Prepare summary. The JSON is written to Drive, so credentials (*_key / *_token
# entries such as kaggle_key or hf_token) are left out of the stored config.
public_config = {k: v for k, v in CONFIG.items() if not k.endswith(('_key', '_token'))}
summary = {
    'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
    'config': public_config,
    'results': {}
}

for model_name, history in all_histories.items():
    losses = history['train_loss']
    if not losses:
        continue  # nothing was trained (e.g. epochs == 0)
    summary['results'][model_name] = {
        'final_loss': losses[-1],
        'best_loss': min(losses),
        'epochs_trained': len(losses)
    }

# Save summary
summary_path = f"{CONFIG['checkpoint_dir']}/training_summary.json"
with open(summary_path, 'w', encoding='utf-8') as f:
    json.dump(summary, f, indent=2)

print("\n🎯 Training Summary:")
print("=" * 50)
for model_name, results in summary['results'].items():
    print(f"\n{model_name.upper()}:")
    print(f"  Final Loss: {results['final_loss']:.4f}")
    print(f"  Best Loss: {results['best_loss']:.4f}")
    print(f"  Epochs: {results['epochs_trained']}")

print(f"\n✅ Summary saved to: {summary_path}")
print(f"✅ Checkpoints saved to: {CONFIG['checkpoint_dir']}")

# Finish W&B run
if CONFIG['use_wandb']:
    wandb.finish()
    print("✅ W&B run finished")

## 12. Next Steps

Your models are now trained! Here's what you can do next:

1. **Evaluate on test data**: Load a test dataset and compute anomaly scores
2. **Fine-tune**: Adjust hyperparameters and retrain
3. **Deploy**: Use the saved checkpoints for inference
4. **Compare models**: Run the evaluation notebook to compare U-Net vs Reversed AE

### Quick Evaluation

In [None]:
# Quick evaluation pointers: print the remaining evaluation steps, then where
# each trained model's best weights were written.
steps = [
    "To evaluate your models:",
    "\n1. Load test dataset with abnormal images",
    "2. Compute reconstruction errors",
    "3. Calculate AUC-ROC scores",
    "\nBest model paths:",
]
for line in steps:
    print(line)
for name in trained_models:
    print(f"  {name}: {CONFIG['checkpoint_dir']}/{name}_best.pth")