# Data Exploration for APD Verification

This notebook explores the three datasets used in the experiments:
- CIFAR-10
- Fashion-MNIST
- SVHN

In [None]:
import sys
sys.path.append('..')

import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from datasets import get_cifar10_loaders, get_fashion_mnist_loaders, get_svhn_loaders, get_dataset_info
from utils.visualization import plot_class_distribution

## CIFAR-10 Dataset

In [None]:
# Build the CIFAR-10 loaders and look up the dataset's info entry.
train_loader, test_loader, class_names = get_cifar10_loaders(batch_size=64)
dataset_info = get_dataset_info('cifar10')

# One print call; sep="\n" puts each field on its own line.
print(
    "CIFAR-10 Dataset Information:",
    f"Number of classes: {dataset_info['num_classes']}",
    f"Class names: {dataset_info['classes']}",
    f"Input shape: {dataset_info['input_shape']}",
    f"Training samples: {dataset_info['num_train']}",
    f"Test samples: {dataset_info['num_test']}",
    sep="\n",
)

In [None]:
# Display sample images from CIFAR-10
def show_samples(loader, class_names, num_samples=8, title="Sample Images",
                 mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)):
    """Plot the leading images of one batch from ``loader`` in a two-row grid.

    Args:
        loader: Iterable yielding ``(images, labels)`` batches. Only the first
            batch is consumed; images are indexed as ``(N, C, H, W)``.
        class_names: Maps an integer label to the text used as the panel title.
        num_samples: Maximum number of images to draw. Capped at the batch size.
        title: Figure-level title.
        mean: Per-channel mean used to undo normalization. Defaults to the
            CIFAR-10 values.
        std: Per-channel standard deviation used to undo normalization.

    Raises:
        ValueError: If ``num_samples`` is below 1 or the batch is empty.
    """
    images, labels = next(iter(loader))

    # Cap at the batch size so a short batch cannot index out of range.
    count = min(num_samples, len(images))
    if count < 1:
        raise ValueError("show_samples needs num_samples >= 1 and a non-empty batch")

    # Round the column count up so an odd count still gets enough axes.
    cols = (count + 1) // 2
    fig, axes = plt.subplots(2, cols, figsize=(12, 6), squeeze=False)
    axes = axes.flatten()

    # Shape the statistics as (C, 1, 1) so they broadcast over (C, H, W).
    mean_t = torch.as_tensor(mean, dtype=images.dtype, device=images.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=images.dtype, device=images.device).view(-1, 1, 1)

    for i in range(count):
        img = torch.clamp(images[i] * std_t + mean_t, 0, 1)

        if img.shape[0] == 1:
            # Single-channel image: drop the channel axis and use a gray colormap.
            axes[i].imshow(img.squeeze(0), cmap='gray')
        else:
            # imshow expects channels last.
            axes[i].imshow(img.permute(1, 2, 0))
        axes[i].set_title(f'{class_names[int(labels[i])]}')
        axes[i].axis('off')

    # Hide the frame of any grid cell that received no image.
    for ax in axes[count:]:
        ax.axis('off')

    plt.suptitle(title)
    plt.tight_layout()
    plt.show()

# Preview the first eight images of one batch from the CIFAR-10 training loader.
show_samples(train_loader, class_names, num_samples=8, title="CIFAR-10 Sample Images")

In [None]:
# Show the class distribution for the CIFAR-10 training loader.
# plot_class_distribution is imported from utils.visualization; how it counts
# and draws the classes is defined there, not in this notebook.
plot_class_distribution(train_loader, "CIFAR-10", class_names)

## Fashion-MNIST Dataset

In [None]:
# Build the Fashion-MNIST loaders and look up the dataset's info entry.
train_loader_fashion, test_loader_fashion, class_names_fashion = get_fashion_mnist_loaders(batch_size=64)
dataset_info_fashion = get_dataset_info('fashion_mnist')

# One print call; sep="\n" puts each field on its own line.
print(
    "Fashion-MNIST Dataset Information:",
    f"Number of classes: {dataset_info_fashion['num_classes']}",
    f"Class names: {dataset_info_fashion['classes']}",
    f"Input shape: {dataset_info_fashion['input_shape']}",
    f"Training samples: {dataset_info_fashion['num_train']}",
    f"Test samples: {dataset_info_fashion['num_test']}",
    sep="\n",
)

In [None]:
# Display sample images from Fashion-MNIST
def show_fashion_samples(loader, class_names, num_samples=8):
    """Plot the leading grayscale images of one batch in a two-row grid.

    Args:
        loader: Iterable yielding ``(images, labels)`` batches. Only the first
            batch is consumed.
        class_names: Maps an integer label to the text used as the panel title.
        num_samples: Maximum number of images to draw. Capped at the batch size.

    Raises:
        ValueError: If ``num_samples`` is below 1 or the batch is empty.
    """
    images, labels = next(iter(loader))

    # Cap at the batch size so a short batch cannot index out of range.
    count = min(num_samples, len(images))
    if count < 1:
        raise ValueError("show_fashion_samples needs num_samples >= 1 and a non-empty batch")

    # Round the column count up so an odd count still gets enough axes.
    cols = (count + 1) // 2
    fig, axes = plt.subplots(2, cols, figsize=(12, 6), squeeze=False)
    axes = axes.flatten()

    for i in range(count):
        # Denormalize; assumes the loader normalized with mean 0.5 / std 0.5
        # -- TODO confirm against get_fashion_mnist_loaders.
        img = images[i] * 0.5 + 0.5

        # squeeze() drops the singleton channel axis so imshow receives a 2-D array.
        axes[i].imshow(img.squeeze(), cmap='gray')
        axes[i].set_title(f'{class_names[int(labels[i])]}')
        axes[i].axis('off')

    # Hide the frame of any grid cell that received no image.
    for ax in axes[count:]:
        ax.axis('off')

    plt.suptitle("Fashion-MNIST Sample Images")
    plt.tight_layout()
    plt.show()

# Preview Fashion-MNIST training images; num_samples is left at its default of 8.
show_fashion_samples(train_loader_fashion, class_names_fashion)

## SVHN Dataset

In [None]:
# Build the SVHN loaders and look up the dataset's info entry.
train_loader_svhn, test_loader_svhn, class_names_svhn = get_svhn_loaders(batch_size=64)
dataset_info_svhn = get_dataset_info('svhn')

# One print call; sep="\n" puts each field on its own line.
print(
    "SVHN Dataset Information:",
    f"Number of classes: {dataset_info_svhn['num_classes']}",
    f"Class names: {dataset_info_svhn['classes']}",
    f"Input shape: {dataset_info_svhn['input_shape']}",
    f"Training samples: {dataset_info_svhn['num_train']}",
    f"Test samples: {dataset_info_svhn['num_test']}",
    sep="\n",
)

In [None]:
# Display SVHN samples
def show_svhn_samples(loader, class_names, num_samples=8):
    """Plot the leading RGB images of one batch in a two-row grid.

    Args:
        loader: Iterable yielding ``(images, labels)`` batches. Only the first
            batch is consumed; images are indexed as ``(N, C, H, W)``.
        class_names: Maps an integer label to the text shown after ``Digit:``.
        num_samples: Maximum number of images to draw. Capped at the batch size.

    Raises:
        ValueError: If ``num_samples`` is below 1 or the batch is empty.
    """
    images, labels = next(iter(loader))

    # Cap at the batch size so a short batch cannot index out of range.
    count = min(num_samples, len(images))
    if count < 1:
        raise ValueError("show_svhn_samples needs num_samples >= 1 and a non-empty batch")

    # Round the column count up so an odd count still gets enough axes.
    cols = (count + 1) // 2
    fig, axes = plt.subplots(2, cols, figsize=(12, 6), squeeze=False)
    axes = axes.flatten()

    for i in range(count):
        # Denormalize; assumes the loader normalized with mean 0.5 / std 0.5
        # -- TODO confirm against get_svhn_loaders.
        img = torch.clamp(images[i] * 0.5 + 0.5, 0, 1)

        # imshow expects channels last.
        axes[i].imshow(img.permute(1, 2, 0))
        axes[i].set_title(f'Digit: {class_names[int(labels[i])]}')
        axes[i].axis('off')

    # Hide the frame of any grid cell that received no image.
    for ax in axes[count:]:
        ax.axis('off')

    plt.suptitle("SVHN Sample Images")
    plt.tight_layout()
    plt.show()

# Preview SVHN training images; num_samples is left at its default of 8.
show_svhn_samples(train_loader_svhn, class_names_svhn)

## Data Augmentation Verification

In [None]:
# Show effect of data augmentation on CIFAR-10
def show_augmentation_effects(loader, idx=0):
    """Show one sample under eight independent passes of the augmentation pipeline.

    When ``loader`` exposes a ``dataset`` attribute (as
    ``torch.utils.data.DataLoader`` does), ``loader.dataset[idx]`` is read once
    per panel, so every panel starts from the same underlying sample. Without
    that attribute the function falls back to element ``idx`` of the first
    batch of a fresh iterator for each panel.

    Args:
        loader: Loader over the normalized CIFAR-10 training images.
        idx: Dataset index of the sample to display (batch position in the
            fallback path).
    """
    fig, axes = plt.subplots(2, 4, figsize=(12, 6))
    axes = axes.flatten()

    # CIFAR-10 normalization statistics, built once instead of once per panel.
    mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
    std = torch.tensor([0.2023, 0.1994, 0.2010]).view(3, 1, 1)

    dataset = getattr(loader, "dataset", None)

    for i in range(8):
        if dataset is not None:
            # Calling next(iter(loader)) per panel would start a new pass each
            # time and, with a shuffling loader, show unrelated images.
            # Re-indexing the dataset keeps the sample fixed. Assumes random
            # transforms run inside __getitem__, as torchvision datasets do
            # -- TODO confirm for get_cifar10_loaders.
            sample = dataset[idx][0]
        else:
            images, labels = next(iter(loader))
            sample = images[idx]

        # Denormalize and clip to the displayable [0, 1] range.
        img = torch.clamp(sample * std + mean, 0, 1)

        axes[i].imshow(img.permute(1, 2, 0))
        axes[i].set_title(f'Aug {i+1}')
        axes[i].axis('off')

    plt.suptitle("Data Augmentation Effects (Random Crop + Horizontal Flip)")
    plt.tight_layout()
    plt.show()

print("Note: Each image shows a different random augmentation of the same sample")
show_augmentation_effects(train_loader)

## Batch Statistics

In [None]:
# Verify batch shapes and statistics
def check_batch_stats(loader, dataset_name):
    """Print shape, dtype, value-range and label statistics for one batch.

    Only the first batch obtained by iterating ``loader`` is inspected.
    """
    batch_iter = iter(loader)
    images, labels = next(batch_iter)

    lowest, highest = images.min(), images.max()
    seen_labels = torch.unique(labels).tolist()

    # Collect the report first, then emit it in one write.
    report = [
        f"\n{dataset_name} Batch Statistics:",
        f"Batch shape: {images.shape}",
        f"Labels shape: {labels.shape}",
        f"Image data type: {images.dtype}",
        f"Label data type: {labels.dtype}",
        f"Image value range: [{lowest:.3f}, {highest:.3f}]",
        f"Unique labels in batch: {seen_labels}",
        f"Mean: {images.mean():.3f}, Std: {images.std():.3f}",
    ]
    print("\n".join(report))

# Run the same one-batch report for each training loader.
check_batch_stats(loader=train_loader, dataset_name="CIFAR-10")
check_batch_stats(loader=train_loader_fashion, dataset_name="Fashion-MNIST")
check_batch_stats(loader=train_loader_svhn, dataset_name="SVHN")

## Summary

This notebook has explored the three datasets that will be used for APD verification:

1. **CIFAR-10**: 32x32 RGB images of 10 object categories
2. **Fashion-MNIST**: 28x28 grayscale images of 10 fashion items
3. **SVHN**: 32x32 RGB images of house number digits

Key observations:
- All datasets have 10 classes but different image properties
- CIFAR-10 and SVHN use RGB images while Fashion-MNIST is grayscale
- Data augmentation (random crop and horizontal flip) is applied to the CIFAR-10 training set
- All images are normalized for better training stability