# SCAN PyTorch Training Example

This notebook demonstrates how to use the SCAN library with PyTorch.

SCAN (Self-Confidence Attention Network) is a library designed to extract visual explanations from deep learning models.

## 1. Install Dependencies

Make sure you have the following packages installed (`pillow` and `requests` are imported in the next section):
- torch
- torchvision
- numpy
- matplotlib
- tqdm
- pillow
- requests

In [1]:
# Install dependencies if needed
# !pip install torch torchvision numpy matplotlib tqdm pillow requests

## 2. Import Libraries

In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet50, ResNet50_Weights
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import requests
from io import BytesIO

from SCAN import SCAN

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


## 3. Load Target Model

We'll use a pretrained ResNet50 model from torchvision.

In [2]:
# Load pretrained ResNet50
target_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
target_model.eval()

print("Model loaded successfully!")
print(f"Number of parameters: {sum(p.numel() for p in target_model.parameters()):,}")

Model loaded successfully!
Number of parameters: 25,557,032


## 4. Define Preprocessing Function

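The preprocessing function must match the normalization the target model was trained with. The datasets below return images in the [0, 255] range, so the function first rescales them to [0, 1] and then applies the ImageNet per-channel mean and standard deviation.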

In [3]:
# ImageNet normalization
def preprocess_input(x):
    """
    Preprocess input images for ResNet.
    Input: torch.Tensor with values in [0, 255], shape (N, C, H, W)
    Output: Normalized tensor
    """
    # Normalize to [0, 1]
    x = x / 255.0
    
    # ImageNet mean and std
    mean = torch.tensor([0.485, 0.456, 0.406], device=x.device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=x.device).view(1, 3, 1, 1)
    
    return (x - mean) / std
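To make the per-channel arithmetic concrete, here is a stdlib-only sketch that normalizes a single RGB pixel exactly as `preprocess_input` does for every pixel in the batch. `normalize_pixel` is a hypothetical helper for illustration only; it is not part of SCAN.

```python
# ImageNet channel statistics, as used in preprocess_input above
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Normalize one (R, G, B) pixel given in [0, 255], channel by channel."""
    return tuple((value / 255.0 - m) / s for value, m, s in zip(rgb, MEAN, STD))


# A black pixel maps to -mean/std in each channel
black = normalize_pixel((0, 0, 0))
print([round(v, 3) for v in black])

# A pixel equal to the channel means (scaled to [0, 255]) maps to roughly zero
mean_pixel = normalize_pixel(tuple(255 * m for m in MEAN))
print(mean_pixel)
```

Forgetting the `/ 255.0` step (or applying it twice) shifts every input far outside the range ResNet50 was trained on, which is why the preprocessing must match the dataset's value range.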

## 5. Prepare Dataset

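We'll use the ImageNet (ILSVRC-2012) dataset for training and validation. All images are returned as float tensors in the [0, 255] range that `preprocess_input` expects.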

In [4]:
import os
import xml.etree.ElementTree as ET
from torchvision.datasets import ImageFolder

# ImageNet dataset paths
IMAGENET_PATH = '/root/jupyter/SCAN2/ImageNet'
TRAIN_PATH = os.path.join(IMAGENET_PATH, 'train')
VAL_PATH = os.path.join(IMAGENET_PATH, 'val')
VAL_ANN_PATH = os.path.join(IMAGENET_PATH, 'Annotations/CLS-LOC/val')

# Build synset to index mapping from train folder (alphabetical order)
synset_list = sorted(os.listdir(TRAIN_PATH))
synset_to_idx = {synset: idx for idx, synset in enumerate(synset_list)}
print(f"Number of classes: {len(synset_to_idx)}")
print(f"First 5 synsets: {synset_list[:5]}")

class ImageNetTrainDataset(Dataset):
    """
    ImageNet training dataset wrapper that returns images in [0, 255] range.
    """
    def __init__(self, root, image_size=(224, 224), max_samples=None):
        self.image_size = image_size
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(image_size[0]),
            transforms.ToTensor(),  # Converts to [0, 1]
        ])
        self.dataset = ImageFolder(root, transform=self.transform)
        self.max_samples = max_samples
        
    def __len__(self):
        if self.max_samples:
            return min(len(self.dataset), self.max_samples)
        return len(self.dataset)
    
    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        # Convert from [0, 1] to [0, 255]
        image = image * 255.0
        return image, label


class ImageNetValDataset(Dataset):
    """
    ImageNet validation dataset with correct labels from XML annotations.
    Returns images in [0, 255] range.
    """
    def __init__(self, img_dir, ann_dir, synset_to_idx, image_size=(224, 224), max_samples=None):
        self.img_dir = img_dir
        self.synset_to_idx = synset_to_idx
        self.samples = []
        
        # Get all validation images
        img_files = sorted([f for f in os.listdir(img_dir) if f.endswith('.JPEG')])
        
        for img_file in img_files:
            # Get synset from XML annotation
            xml_file = img_file.replace('.JPEG', '.xml')
            xml_path = os.path.join(ann_dir, xml_file)
            
            if os.path.exists(xml_path):
                tree = ET.parse(xml_path)
                root = tree.getroot()
                synset = root.find('.//object/name').text
                
                if synset in synset_to_idx:
                    label = synset_to_idx[synset]
                    self.samples.append((img_file, label))
        
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(image_size[0]),
            transforms.ToTensor(),
        ])
        
        self.max_samples = max_samples
        print(f"Loaded {len(self.samples)} validation samples")
    
    def __len__(self):
        if self.max_samples:
            return min(len(self.samples), self.max_samples)
        return len(self.samples)
    
    def __getitem__(self, idx):
        img_name, label = self.samples[idx]
        img_path = os.path.join(self.img_dir, img_name)
        img = Image.open(img_path).convert('RGB')
        img_tensor = self.transform(img) * 255.0  # Convert to [0, 255]
        return img_tensor, label


def load_sample_image(image_path=None):
    """Load a sample image from ImageNet for testing."""
    if image_path is None:
        # Default to the first image of the first class (alphabetical order)
        first_class = sorted(os.listdir(TRAIN_PATH))[0]
        class_path = os.path.join(TRAIN_PATH, first_class)
        image_name = sorted(os.listdir(class_path))[0]
        image_path = os.path.join(class_path, image_name)
    
    # Load and transform image
    img = Image.open(image_path).convert('RGB')
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    img_tensor = transform(img) * 255.0  # Convert to [0, 255]
    return img_tensor, image_path


# Test loading a sample image
test_img, test_path = load_sample_image()
print(f"Sample image loaded from: {test_path}")
print(f"Image shape: {test_img.shape}, min: {test_img.min():.1f}, max: {test_img.max():.1f}")

Number of classes: 1000
First 5 synsets: ['n01440764', 'n01443537', 'n01484850', 'n01491361', 'n01494475']
Sample image loaded from: /root/jupyter/SCAN2/ImageNet/train/n01440764/n01440764_10026.JPEG
Image shape: torch.Size([3, 224, 224]), min: 0.0, max: 255.0
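`ImageNetValDataset` reads each label from the first `<object>/<name>` element of the image's XML annotation and maps that synset through `synset_to_idx`. The sketch below runs the same `ElementTree` lookup on a hypothetical, trimmed-down inline annotation (real CLS-LOC files contain more fields) and a toy three-class mapping built the same way as `synset_to_idx`.

```python
import xml.etree.ElementTree as ET

# Hypothetical, minimal annotation in the CLS-LOC layout (real files have more fields)
ANNOTATION = """
<annotation>
  <filename>ILSVRC2012_val_00000001</filename>
  <object>
    <name>n01751748</name>
  </object>
</annotation>
"""

# Toy mapping built like synset_to_idx: sorted synsets -> 0, 1, 2, ...
synsets = sorted(["n01751748", "n01440764", "n01443537"])
toy_synset_to_idx = {synset: idx for idx, synset in enumerate(synsets)}


def label_from_annotation(xml_text, mapping):
    """Return the class index of the first annotated object, or None if unknown."""
    root = ET.fromstring(xml_text.strip())
    synset = root.find('.//object/name').text
    return mapping.get(synset)


label = label_from_annotation(ANNOTATION, toy_synset_to_idx)
print(label)  # n01751748 sorts last among the three toy synsets -> 2
```

Images whose synset is missing from the mapping are skipped, which is why `ImageNetValDataset` checks `synset in synset_to_idx` before appending a sample.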


In [5]:
# Create ImageNet train and validation datasets
# For quick testing, limit the number of samples (set to None for full dataset)
MAX_TRAIN_SAMPLES = None   # Use None for full dataset (~1.2M images)
MAX_VAL_SAMPLES = None     # Use None for full validation (50K images)

# Training dataset from train folder
train_dataset = ImageNetTrainDataset(TRAIN_PATH, image_size=(224, 224), max_samples=MAX_TRAIN_SAMPLES)

# Validation dataset from val folder with correct labels from XML annotations
valid_dataset = ImageNetValDataset(
    img_dir=VAL_PATH,
    ann_dir=VAL_ANN_PATH,
    synset_to_idx=synset_to_idx,
    image_size=(224, 224),
    max_samples=MAX_VAL_SAMPLES
)

# Create dataloaders
BATCH_SIZE = 128  # Increased from 16 for better training efficiency
NUM_WORKERS = 4

train_loader = DataLoader(
    train_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
    num_workers=NUM_WORKERS,
    pin_memory=True
)
valid_loader = DataLoader(
    valid_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=False, 
    num_workers=NUM_WORKERS,
    pin_memory=True
)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(valid_dataset)}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(valid_loader)}")

Loaded 50000 validation samples
Training samples: 1281167
Validation samples: 50000
Batch size: 128
Training batches: 10010
Validation batches: 391
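The batch counts above follow from `DataLoader`'s default `drop_last=False`: the final partial batch is kept, so the number of batches is the ceiling of samples divided by batch size. `num_batches` below is a hypothetical helper that reproduces the printed numbers.

```python
import math


def num_batches(num_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields for a map-style dataset."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


print(num_batches(1_281_167, 128))  # training: 10010
print(num_batches(50_000, 128))     # validation: 391
```

For the training set, 10,009 full batches cover 1,281,152 images, so the last batch holds only the remaining 15.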


## 6. Find Target Layer

Let's explore the model architecture to find the appropriate target layer.

In [6]:
# Print model layer names
print("Available layers in ResNet50:")
print("=" * 50)
for name, module in target_model.named_modules():
    if name:  # Skip empty name (the model itself)
        print(f"{name}: {module.__class__.__name__}")

Available layers in ResNet50:
conv1: Conv2d
bn1: BatchNorm2d
relu: ReLU
maxpool: MaxPool2d
layer1: Sequential
layer1.0: Bottleneck
layer1.0.conv1: Conv2d
layer1.0.bn1: BatchNorm2d
layer1.0.conv2: Conv2d
layer1.0.bn2: BatchNorm2d
layer1.0.conv3: Conv2d
layer1.0.bn3: BatchNorm2d
layer1.0.relu: ReLU
layer1.0.downsample: Sequential
layer1.0.downsample.0: Conv2d
layer1.0.downsample.1: BatchNorm2d
layer1.1: Bottleneck
layer1.1.conv1: Conv2d
layer1.1.bn1: BatchNorm2d
layer1.1.conv2: Conv2d
layer1.1.bn2: BatchNorm2d
layer1.1.conv3: Conv2d
layer1.1.bn3: BatchNorm2d
layer1.1.relu: ReLU
layer1.2: Bottleneck
layer1.2.conv1: Conv2d
layer1.2.bn1: BatchNorm2d
layer1.2.conv2: Conv2d
layer1.2.bn2: BatchNorm2d
layer1.2.conv3: Conv2d
layer1.2.bn3: BatchNorm2d
layer1.2.relu: ReLU
layer2: Sequential
layer2.0: Bottleneck
layer2.0.conv1: Conv2d
layer2.0.bn1: BatchNorm2d
layer2.0.conv2: Conv2d
layer2.0.bn2: BatchNorm2d
layer2.0.conv3: Conv2d
layer2.0.bn3: BatchNorm2d
layer2.0.relu: ReLU
layer2.0.downsample: Sequential
...

In [7]:
# Common target layers for ResNet50:
# - 'layer4': Last convolutional block (recommended)
# - 'layer4.2.conv3': Last conv layer before avgpool
# - 'layer3': Earlier features (more detailed but noisier)

target_layer_name = 'layer4'  # Recommended for most use cases
print(f"Using target layer: {target_layer_name}")

Using target layer: layer4


## 7. Initialize SCAN

Create the SCAN object with the target model and layer.

In [8]:
# Initialize SCAN
scanner = SCAN(
    target_model=target_model,
    target_layer=target_layer_name,
    image_size=(224, 224),
    use_gradient_mask=True,
    device=device,
    num_classes=1000  # ImageNet classes
)

# Set preprocessing function
scanner.set_preprocess(preprocess_input)

print("SCAN initialized successfully!")

SCAN initialized successfully!


## 8. Setup Training

Attach the datasets, build the decoder, and configure the optimizer, loss, and learning-rate scheduler.

In [9]:
# Set datasets (use_augmentation=(70, 100) is the paper default for training)
scanner.set_dataset(train_loader, use_augmentation=(70, 100))
scanner.set_validation_dataset(valid_loader)

# Generate decoder (convolutional decoder for CNN features)
scanner.generate_decoder(is_Transformer=False)

# Training configuration
EPOCHS = 2
LEARNING_RATE = 1e-3

# Compile with optimizer, loss, and Cosine Annealing LR scheduler
scanner.compile(
    loss_alpha=4.0,
    optimizer_class=torch.optim.Adam,
    learning_rate=LEARNING_RATE,
    scheduler_class=torch.optim.lr_scheduler.CosineAnnealingLR,
    scheduler_kwargs={'T_max': EPOCHS, 'eta_min': 1e-6}
)

print("Training setup complete!")
print(f"Decoder parameters: {sum(p.numel() for p in scanner.decoder.parameters()):,}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Learning rate: {LEARNING_RATE} with Cosine Annealing (eta_min=1e-6)")
print(f"Epochs: {EPOCHS}")

Training setup complete!
Decoder parameters: 8,257,668
Batch size: 128
Learning rate: 0.001 with Cosine Annealing (eta_min=1e-6)
Epochs: 2
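In the custom training loop below, `scanner.scheduler.step()` is called once per epoch, so with `T_max=EPOCHS` the learning rate follows the closed-form cosine annealing expression given in the PyTorch `CosineAnnealingLR` documentation. This stdlib sketch evaluates that formula for the current settings; it mirrors the documented formula, not PyTorch's internal implementation.

```python
import math


def cosine_annealing_lr(epoch, base_lr=1e-3, eta_min=1e-6, t_max=2):
    """Closed form: eta_min + (base_lr - eta_min) * (1 + cos(pi * epoch / t_max)) / 2."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


for epoch in range(3):
    print(epoch, f"{cosine_annealing_lr(epoch):.4e}")
```

With `EPOCHS = 2`, the first epoch trains at 1e-3 and the second at about 5.0e-4. The rate only reaches `eta_min` after the final `scheduler.step()`, so no training batch actually uses 1e-6. When training longer, raise `EPOCHS` (and therefore `T_max`) together.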


## 9. Train SCAN Decoder

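Train the decoder with a custom loop that saves a snapshot of SCAN results on random validation images every `visualize_every` batches and runs validation at the end of each epoch.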

In [None]:
import random
import os
from tqdm import tqdm
from IPython.display import display, clear_output, Image as IPImage, HTML

# Create directory for saving visualization snapshots
VIS_DIR = './training_snapshots'
os.makedirs(VIS_DIR, exist_ok=True)

def create_validation_figure(scanner, valid_loader, num_samples=4, percentile=95, save_path=None):
    """Create and save a figure with SCAN results on random validation samples."""
    scanner.decoder.eval()
    
    # Get random validation samples
    indices = random.sample(range(len(valid_loader.dataset)), num_samples)
    images = []
    labels = []
    for idx in indices:
        img, label = valid_loader.dataset[idx]
        images.append(img)
        labels.append(label)
    
    batch_images = torch.stack(images).to(scanner.device)
    
    # Generate SCAN explanations
    with torch.no_grad():
        confidence_maps, reconstructed_images = scanner(batch_images, percentile=percentile)
    
    # Create visualization
    fig, axes = plt.subplots(num_samples, 4, figsize=(16, 4 * num_samples))
    
    for i in range(num_samples):
        # Original image
        original = batch_images[i].cpu().permute(1, 2, 0).numpy().astype(np.uint8)
        axes[i, 0].imshow(original)
        axes[i, 0].set_title(f'Original (class: {labels[i]})')
        axes[i, 0].axis('off')
        
        # Reconstructed image
        axes[i, 1].imshow(reconstructed_images[i].cpu().numpy())
        axes[i, 1].set_title('Reconstructed')
        axes[i, 1].axis('off')
        
        # Confidence map
        conf_map = confidence_maps[i].cpu().numpy()
        im = axes[i, 2].imshow(conf_map, cmap='jet', vmin=0, vmax=1)
        axes[i, 2].set_title(f'Confidence (min:{conf_map.min():.2f}, max:{conf_map.max():.2f})')
        axes[i, 2].axis('off')
        
        # Overlay
        axes[i, 3].imshow(original)
        axes[i, 3].imshow(conf_map, cmap='jet', alpha=0.5, vmin=0, vmax=1)
        axes[i, 3].set_title('Overlay')
        axes[i, 3].axis('off')
    
    plt.tight_layout()
    
    if save_path:
        fig.savefig(save_path, dpi=100, bbox_inches='tight')
        plt.close(fig)
        return save_path
    else:
        plt.show()
        plt.close(fig)
        return None


def train_with_visualization(scanner, train_loader, valid_loader, epochs=2, 
                             visualize_every=1000, num_vis_samples=4):
    """
    Custom training loop with periodic visualization of validation samples.
    Saves snapshots to files and displays the latest one.
    
    Args:
        scanner: SCAN instance
        train_loader: Training DataLoader
        valid_loader: Validation DataLoader  
        epochs: Number of training epochs
        visualize_every: Visualize every N batches
        num_vis_samples: Number of samples to visualize
    
    Returns:
        history: Training history dict
        snapshot_paths: List of saved snapshot file paths
    """
    history = {'loss': [], 'val_loss': [], 'ConfMAE_Metric': [], 'NoConfMAE_Metric': [], 'lr': []}
    snapshot_paths = []
    
    print(f"Training with visualization every {visualize_every} batches")
    print(f"Snapshots will be saved to: {VIS_DIR}/")
    print("=" * 60)
    
    global_batch = 0
    
    for epoch in range(epochs):
        scanner.decoder.train()
        epoch_loss = 0.0
        num_batches = 0
        
        # Reset metrics
        for metric in scanner.metrics:
            metric.reset()
        
        # Get current learning rate
        current_lr = scanner.optimizer.param_groups[0]['lr']
        
        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs} (LR: {current_lr:.2e})')
        
        for batch_idx, (images, labels) in enumerate(pbar):
            images = images.to(scanner.device).float()
            labels = labels.to(scanner.device)
            
            # Use SCAN's internal _process_batch method with labels for correct gradient computation
            feature_maps, target_images = scanner._process_batch(images, labels=labels, use_augmentation=True)
            
            # Forward pass
            scanner.optimizer.zero_grad()
            outputs = scanner.decoder(feature_maps)
            loss = scanner.criterion(outputs, target_images)
            
            # Backward pass
            loss.backward()
            scanner.optimizer.step()
            
            # Update metrics
            epoch_loss += loss.item()
            num_batches += 1
            global_batch += 1

            with torch.no_grad():
                for metric in scanner.metrics:
                    metric.update(target_images, outputs)

            # Build postfix with loss, metrics, and lr
            postfix = {'loss': f'{loss.item():.4f}'}
            for metric in scanner.metrics:
                metric_name = metric.__class__.__name__
                # Short labels for the progress bar: ConfMAE -> CMAE, NoConfMAE -> NCMAE
                if 'NoConf' in metric_name:
                    display_name = 'NCMAE'
                elif 'ConfMAE' in metric_name:
                    display_name = 'CMAE'
                else:
                    display_name = metric_name
                postfix[display_name] = f'{metric.compute():.4f}'
            postfix['lr'] = f'{current_lr:.2e}'
            pbar.set_postfix(postfix)
            
            # Visualize every N batches
            if global_batch % visualize_every == 0:
                save_path = os.path.join(VIS_DIR, f'snapshot_batch_{global_batch:06d}.png')
                create_validation_figure(scanner, valid_loader, 
                                        num_samples=num_vis_samples,
                                        save_path=save_path)
                snapshot_paths.append(save_path)
                print(f"\n[Batch {global_batch}] Snapshot saved: {save_path}")
                
                # Display the latest snapshot
                clear_output(wait=True)
                display(HTML(f"<h3>Latest Snapshot (Batch {global_batch})</h3>"))
                display(IPImage(filename=save_path))
                
                # Redraw the existing progress bar after clear_output; the loop keeps
                # iterating this same bar, so a new tqdm instance is not needed
                pbar.refresh()
                scanner.decoder.train()
        
        # Epoch summary
        avg_loss = epoch_loss / max(num_batches, 1)
        history['loss'].append(avg_loss)
        history['lr'].append(current_lr)
        
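        for metric in scanner.metrics:
            # setdefault avoids a KeyError for metric names not pre-registered in history;
            # float() converts a 0-d tensor to a plain number so it can be plotted
            history.setdefault(metric.__class__.__name__, []).append(float(metric.compute()))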
        
        print(f"\nEpoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}", end="")
        for metric in scanner.metrics:
            print(f" - {metric.__class__.__name__}: {history[metric.__class__.__name__][-1]:.4f}", end="")
        print(f" - LR: {current_lr:.2e}")
        
        # Validation at end of epoch
        if valid_loader is not None:
            print("Running validation...")
            val_loss, val_metrics = scanner._validate()
            history['val_loss'].append(val_loss)
            print(f"Validation Loss: {val_loss:.4f}")
        
        # Step the learning rate scheduler after each epoch
        if scanner.scheduler is not None:
            scanner.scheduler.step()
        
        scanner.decoder.train()
    
    print(f"\nTraining complete! {len(snapshot_paths)} snapshots saved to {VIS_DIR}/")
    return history, snapshot_paths


# Train with visualization every 1000 batches
history, snapshot_paths = train_with_visualization(
    scanner, 
    train_loader, 
    valid_loader, 
    epochs=EPOCHS, 
    visualize_every=1000,
    num_vis_samples=4
)
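Because `global_batch` is not reset between epochs, snapshot files are numbered across the whole run. The sketch below uses a hypothetical helper (not part of SCAN) to list the file names the call above should produce, given the 10,010 training batches per epoch reported earlier.

```python
def expected_snapshots(batches_per_epoch, epochs, visualize_every):
    """File names train_with_visualization saves, in order."""
    total_batches = batches_per_epoch * epochs
    return [
        f'snapshot_batch_{step:06d}.png'
        for step in range(visualize_every, total_batches + 1, visualize_every)
    ]


names = expected_snapshots(batches_per_epoch=10_010, epochs=2, visualize_every=1000)
print(len(names), names[0], names[-1])
```

Each snapshot pauses training to run `num_vis_samples` validation images through SCAN and render a figure, so a smaller `visualize_every` adds overhead.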

In [None]:
# Plot training history
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Loss plot
axes[0].plot(history['loss'], label='Train Loss')
if history['val_loss']:
    axes[0].plot(history['val_loss'], label='Val Loss')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training Loss')
axes[0].legend()

# Metrics plot
if 'ConfMAE_Metric' in history:
    axes[1].plot(history['ConfMAE_Metric'], label='Confident MAE')
if 'NoConfMAE_Metric' in history:
    axes[1].plot(history['NoConfMAE_Metric'], label='Not Confident MAE')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('MAE')
axes[1].set_title('Training Metrics')
axes[1].legend()

# Learning rate plot
if 'lr' in history and history['lr']:
    axes[2].plot(history['lr'], label='Learning Rate', color='green')
    axes[2].set_xlabel('Epoch')
    axes[2].set_ylabel('Learning Rate')
    axes[2].set_title('Learning Rate Schedule (Cosine Annealing)')
    axes[2].set_yscale('log')
    axes[2].legend()

plt.tight_layout()
plt.show()

# Training snapshots are saved as PNG files; re-display any of them with, e.g.:
# display(IPImage(filename=snapshot_paths[-1]))
print(f"\nTotal visualization snapshots saved: {len(snapshot_paths)}")
print(f"Snapshots are stored in {VIS_DIR}/")

## 10. Generate Visual Explanations

Use the trained SCAN to generate visual explanations for images.

In [None]:
# Load sample images from validation set for testing
import random

def load_random_val_images(valid_dataset, num_images=5):
    """Load random images from ImageNet validation set."""
    indices = random.sample(range(len(valid_dataset)), min(num_images, len(valid_dataset)))
    
    images = []
    labels = []
    for idx in indices:
        img, label = valid_dataset[idx]
        images.append(img)
        labels.append(label)
    
    return images, labels

# Load test images from validation set
test_images, test_labels = load_random_val_images(valid_dataset, 5)
test_image = test_images[0]
print(f"Loaded {len(test_images)} test images from validation set")
print(f"Test image shape: {test_image.shape}")
print(f"Test labels: {test_labels}")

In [None]:
# Generate visual explanation
confidence_map, reconstructed_image = scanner(test_image, percentile=95)

print(f"Confidence map shape: {confidence_map.shape}")
print(f"Reconstructed image shape: {reconstructed_image.shape}")

In [None]:
# Visualize results
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Original image
axes[0].imshow(test_image.permute(1, 2, 0).numpy().astype(np.uint8))
axes[0].set_title('Original Image')
axes[0].axis('off')

# Reconstructed image
# .cpu() is a no-op for CPU tensors and required before .numpy() for GPU tensors
axes[1].imshow(reconstructed_image.cpu().numpy())
axes[1].set_title('Reconstructed Image')
axes[1].axis('off')

# Confidence map
im = axes[2].imshow(confidence_map.cpu().numpy(), cmap='jet')
axes[2].set_title('Confidence Map')
axes[2].axis('off')
plt.colorbar(im, ax=axes[2], fraction=0.046, pad=0.04)

plt.tight_layout()
plt.show()

In [None]:
# Overlay confidence map on original image
fig, ax = plt.subplots(1, 1, figsize=(8, 8))

# Show original image
original = test_image.permute(1, 2, 0).numpy().astype(np.uint8)
ax.imshow(original)

# Overlay confidence map
conf_map = confidence_map.cpu().numpy()
ax.imshow(conf_map, cmap='jet', alpha=0.5)

ax.set_title('Visual Explanation Overlay')
ax.axis('off')

plt.tight_layout()
plt.show()
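Conceptually, drawing the confidence map with `alpha=0.5` over the opaque original is standard "over" alpha blending, `out = alpha * overlay + (1 - alpha) * base`, applied per channel. A stdlib-only sketch for a single pixel with channels in [0, 1]; `blend_pixel` is a hypothetical helper for illustration.

```python
def blend_pixel(base_rgb, overlay_rgb, alpha=0.5):
    """Alpha-blend an overlay color onto an opaque base color (channels in [0, 1])."""
    return tuple(alpha * o + (1 - alpha) * b for b, o in zip(base_rgb, overlay_rgb))


gray = (0.5, 0.5, 0.5)
jet_top = (0.5, 0.0, 0.0)  # dark red: the top end of the 'jet' colormap
print(blend_pixel(gray, jet_top))
```

Unlike the training snapshots, this cell does not pass `vmin=0, vmax=1`, so the colors are scaled to each map's own minimum and maximum rather than to a fixed [0, 1] range.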

## 11. Save and Load Decoder

Save the trained decoder for later use.

In [None]:
# Save decoder
scanner.save_decoder('scan_decoder_resnet50.pt')
print("Decoder saved to 'scan_decoder_resnet50.pt'")

In [None]:
# Load decoder (for demonstration)
# scanner.load_decoder('scan_decoder_resnet50.pt')
# print("Decoder loaded successfully!")

## 12. Batch Processing

Generate explanations for multiple images at once.

In [None]:
# Create batch of images from ImageNet
batch_images = torch.stack(test_images)
print(f"Batch shape: {batch_images.shape}")

# Generate explanations for batch
confidence_maps, reconstructed_images = scanner(batch_images, percentile=95)

print(f"Confidence maps shape: {confidence_maps.shape}")
print(f"Reconstructed images shape: {reconstructed_images.shape}")

In [None]:
# Visualize batch results
num_images = len(test_images)
fig, axes = plt.subplots(num_images, 3, figsize=(12, 4 * num_images))

for i in range(num_images):
    # Original
    axes[i, 0].imshow(batch_images[i].permute(1, 2, 0).numpy().astype(np.uint8))
    axes[i, 0].set_title(f'Original {i+1}')
    axes[i, 0].axis('off')
    
    # Reconstructed
    axes[i, 1].imshow(reconstructed_images[i].cpu().numpy())
    axes[i, 1].set_title(f'Reconstructed {i+1}')
    axes[i, 1].axis('off')
    
    # Confidence map
    axes[i, 2].imshow(confidence_maps[i].cpu().numpy(), cmap='jet')
    axes[i, 2].set_title(f'Confidence {i+1}')
    axes[i, 2].axis('off')

plt.tight_layout()
plt.show()

## Notes

### Dataset Configuration
This notebook uses the ImageNet dataset:
- **Training**: `/root/jupyter/SCAN2/ImageNet/train` (ImageFolder with synset-based labels)
- **Validation**: `/root/jupyter/SCAN2/ImageNet/val` with XML annotations from `/root/jupyter/SCAN2/ImageNet/Annotations/CLS-LOC/val`

Validation labels are read from the XML annotation files and mapped to indices using the same alphabetical synset order that `ImageFolder` uses for the training set, so train and validation labels agree.

To adjust the amount of data used:
```python
MAX_TRAIN_SAMPLES = 10000  # Set to None for full dataset (~1.2M images)
MAX_VAL_SAMPLES = 5000     # Set to None for full validation (~50K images)
BATCH_SIZE = 128           # Default: 128 (increase if GPU memory allows)
```
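One caveat when setting `MAX_TRAIN_SAMPLES`: `ImageNetTrainDataset.__len__` simply caps the length, and `ImageFolder` stores samples class by class, so even with `shuffle=True` the loader only sees the first `MAX_TRAIN_SAMPLES` images, which cover just a handful of classes. The stdlib sketch below illustrates this under a simplifying assumption of 1000 classes with 1300 images each (real ImageNet classes vary in size) and compares it with a random subset.

```python
import random

NUM_CLASSES = 1000
IMAGES_PER_CLASS = 1300  # simplifying assumption; real ImageNet classes vary in size
SUBSET_SIZE = 10_000

total = NUM_CLASSES * IMAGES_PER_CLASS


def class_of(index):
    """Class of the index-th sample when samples are stored class by class."""
    return index // IMAGES_PER_CLASS


# Capping the length keeps only indices 0 .. SUBSET_SIZE - 1
first_n_classes = {class_of(i) for i in range(SUBSET_SIZE)}

# A random subset of indices spreads across (almost) all classes
rng = random.Random(0)
random_classes = {class_of(i) for i in rng.sample(range(total), SUBSET_SIZE)}

print(len(first_n_classes), len(random_classes))
```

To train on a random subset instead, one option is to build the dataset with `max_samples=None` and wrap it in `torch.utils.data.Subset` with randomly sampled indices.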

### Training Configuration
Current settings:
- **Batch Size**: 128
- **Learning Rate**: 1e-3 with Cosine Annealing scheduler
- **LR Scheduler**: CosineAnnealingLR (eta_min=1e-6)
- **Epochs**: 2 (increase to 5-10 for better results)

```python
scanner.compile(
    loss_alpha=4.0,
    optimizer_class=torch.optim.Adam,
    learning_rate=1e-3,
    scheduler_class=torch.optim.lr_scheduler.CosineAnnealingLR,
    scheduler_kwargs={'T_max': EPOCHS, 'eta_min': 1e-6}
)
```

### Training Snapshots
Every `visualize_every` batches, the training loop renders SCAN results on random validation images, saves the figure as a PNG in `VIS_DIR` (`./training_snapshots`), and displays the most recent snapshot in the cell output:

```python
history, snapshot_paths = train_with_visualization(
    scanner, 
    train_loader, 
    valid_loader, 
    epochs=2, 
    visualize_every=1000,  # Save a snapshot every 1000 batches
    num_vis_samples=4      # Number of validation samples per snapshot
)

# Re-display the latest snapshot after training:
from IPython.display import Image as IPImage, display
if snapshot_paths:
    display(IPImage(filename=snapshot_paths[-1]))
```

Each snapshot shows 4 columns:
1. **Original**: Input image with its class index
2. **Reconstructed**: Decoder output (blurred target)
3. **Confidence Map**: Self-confidence map (jet colormap, fixed to [0, 1])
4. **Overlay**: Confidence map overlaid on original

### For Better Results:
1. **Full Dataset**: Keep `MAX_TRAIN_SAMPLES = None` (the current setting) to train on the entire ImageNet training set
2. **More Epochs**: Train for more epochs (5-10 recommended)
3. **Larger Batch Size**: Use larger batches if GPU memory allows (256, 512)
4. **Learning Rate Schedule**: Cosine Annealing is already applied

### Choosing Target Layer:
- **layer4**: Best for high-level semantic features (recommended)
- **layer3**: More detailed but noisier
- **layer2**: Very detailed, may be too noisy
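One way to see the detail/noise trade-off is spatial resolution. In torchvision's ResNet50, `layer1` through `layer4` have output strides of 4, 8, 16, and 32 with 256, 512, 1024, and 2048 channels. The stdlib sketch below computes the resulting feature-map shapes for a 224×224 input; the stride and channel table is written out by hand, not read from the model.

```python
# Output stride and channel count of each ResNet50 stage (torchvision layout)
RESNET50_STAGES = {
    'layer1': (4, 256),
    'layer2': (8, 512),
    'layer3': (16, 1024),
    'layer4': (32, 2048),
}


def feature_map_shape(layer_name, image_size=224):
    """(channels, height, width) of a stage's output; exact for sizes divisible by 32."""
    stride, channels = RESNET50_STAGES[layer_name]
    side = image_size // stride
    return channels, side, side


for name in RESNET50_STAGES:
    print(name, feature_map_shape(name))
```

`layer4` yields a 7×7 grid, so each cell summarizes a 32×32-pixel patch of the input, while `layer2` yields a 28×28 grid of 8×8-pixel patches: finer spatial detail, but from less semantic features.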

### For Transformer Models:
If using Vision Transformer (ViT) or similar:
```python
scanner.generate_decoder(is_Transformer=True)
```

### Memory Tips:
- If you run out of GPU memory, reduce `BATCH_SIZE`
- Use `NUM_WORKERS=0` if you encounter multiprocessing issues