# Imbalanced Datasets Experiments - CNN vs Vision Transformer

This notebook supports:
- **Existing imbalanced datasets**: CIFAR-10-LT, CIFAR-100-LT (from Hugging Face)
- **Custom imbalanced datasets**: Create your own imbalanced CIFAR-10/100

Compare CNN (ResNet) and Vision Transformer (ViT) performance on various imbalanced datasets.

## Step 1: Setup and Install Dependencies

In [None]:
# Install required packages
!pip install torch torchvision torchmetrics pyyaml matplotlib seaborn datasets huggingface_hub -q

## Step 2: Upload Project Files

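Upload the following files:
- `CNN.py`, `VisionTransormer.py`, `config.yaml` (to root). Keep the spelling `VisionTransormer.py` exactly as shown — the notebook imports the module by that name.
- `create_imbalanced_dataset.py`, `evaluate.py` (to the `imbalanced-exp/` folder)
- `cka/metrics.py` (to `cka/` folder)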

In [None]:
import os

# Create necessary directories
!mkdir -p imbalanced-exp cka data

# Project files the later cells import or read, relative to the working directory.
required_files = [
    'CNN.py',
    'VisionTransormer.py',
    'config.yaml',
    'imbalanced-exp/create_imbalanced_dataset.py',
    'imbalanced-exp/evaluate.py',
    'cka/metrics.py'
]

print("Checking required files...")
# Report every file in order and remember the ones that are absent.
missing_files = []
for path in required_files:
    found = os.path.exists(path)
    print(f"✓ {path}" if found else f"✗ {path} - MISSING!")
    if not found:
        missing_files.append(path)

# One closing line: either a prompt to upload, or an all-clear.
print(
    "\n⚠ Please upload the missing files using the file browser on the left."
    if missing_files
    else "\n✓ All required files are present!"
)

## Step 3: Load Existing Imbalanced Datasets

This section loads pre-existing long-tailed imbalanced datasets from Hugging Face.

In [None]:
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
import numpy as np
from PIL import Image

class HuggingFaceCIFAR(Dataset):
    """Expose a Hugging Face CIFAR-style split as a ``torch.utils.data.Dataset``.

    Each row is read with ``hf_dataset[idx]``. The image is taken from the
    ``'img'`` key (or ``'image'`` when ``'img'`` is absent), converted to a PIL
    image if it is not one already, passed through ``transform`` when one was
    given, and returned together with the row's label.

    Args:
        hf_dataset: Indexable collection of dict-like rows that supports ``len()``.
        transform: Optional callable applied to the PIL image of every item.
    """

    # Row keys tried, in order, when looking up the class label. 'label' is
    # what the original code read; 'fine_label' is a fallback for CIFAR-100
    # style splits. NOTE(review): confirm the column name of the dataset
    # actually loaded.
    _LABEL_KEYS = ('label', 'fine_label')

    def __init__(self, hf_dataset, transform=None):
        self.dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def _to_pil(image):
        """Return ``image`` as a PIL image, decoding it when necessary."""
        if isinstance(image, Image.Image):
            return image
        if isinstance(image, dict):
            # An undecoded Hub image is a dict holding encoded file bytes
            # and/or a file path. np.array() on such a dict yields a 0-d
            # object array that Image.fromarray() cannot handle, so decode
            # it explicitly instead.
            import io

            encoded = image.get('bytes')
            if encoded is not None:
                return Image.open(io.BytesIO(encoded))
            if image.get('path'):
                return Image.open(image['path'])
            raise ValueError(
                "Image dict has neither 'bytes' nor 'path'; cannot decode it"
            )
        if not isinstance(image, np.ndarray):
            # Nested lists and other array-likes: assumes 0-255 pixel values
            # (TODO confirm for datasets that store pixels this way).
            image = np.asarray(image, dtype=np.uint8)
        return Image.fromarray(image)

    @classmethod
    def _label_of(cls, item):
        """Return the label stored in ``item`` under the first known label key."""
        for key in cls._LABEL_KEYS:
            if key in item:
                return item[key]
        raise KeyError(f"None of the label keys {cls._LABEL_KEYS} found in item")

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = self._to_pil(item['img'] if 'img' in item else item['image'])

        if self.transform:
            image = self.transform(image)

        return image, self._label_of(item)

def load_cifar10_lt(imbalance_factor=100, split='train', transform=None):
    """
    Fetch one split of CIFAR-10-LT from the Hugging Face Hub.

    The repository id is built as ``tomas-gajarsky/cifar10-lt-<imbalance_factor>``.

    Args:
        imbalance_factor: Value interpolated into the repository id
            (the notebook suggests 10, 50, 100, 200)
        split: Split name passed to ``load_dataset`` ('train' or 'test')
        transform: Torchvision transforms forwarded to ``HuggingFaceCIFAR``

    Returns:
        A ``HuggingFaceCIFAR`` wrapping the split, or ``None`` if anything
        raised while loading. Despite the printed "Falling back" message, no
        replacement dataset is built here; callers must handle ``None``.
    """
    wrapped = None
    try:
        repo_id = f"tomas-gajarsky/cifar10-lt-{imbalance_factor}"
        print(f"Loading {repo_id} ({split})...")
        raw_split = load_dataset(repo_id, split=split)
        wrapped = HuggingFaceCIFAR(raw_split, transform=transform)
    except Exception as err:
        print(f"Error loading CIFAR-10-LT: {err}")
        print("Falling back to creating custom imbalanced dataset...")
    return wrapped

def load_cifar100_lt(imbalance_factor=100, split='train', transform=None):
    """
    Fetch one split of CIFAR-100-LT from the Hugging Face Hub.

    The repository id is built as ``tomas-gajarsky/cifar100-lt-<imbalance_factor>``.

    Args:
        imbalance_factor: Value interpolated into the repository id
            (the notebook suggests 10, 50, 100, 200)
        split: Split name passed to ``load_dataset`` ('train' or 'test')
        transform: Torchvision transforms forwarded to ``HuggingFaceCIFAR``

    Returns:
        A ``HuggingFaceCIFAR`` wrapping the split, or ``None`` if anything
        raised while loading. Despite the printed "Falling back" message, no
        replacement dataset is built here; callers must handle ``None``.
    """
    wrapped = None
    try:
        repo_id = f"tomas-gajarsky/cifar100-lt-{imbalance_factor}"
        print(f"Loading {repo_id} ({split})...")
        raw_split = load_dataset(repo_id, split=split)
        wrapped = HuggingFaceCIFAR(raw_split, transform=transform)
    except Exception as err:
        print(f"Error loading CIFAR-100-LT: {err}")
        print("Falling back to creating custom imbalanced dataset...")
    return wrapped

print("Dataset loading functions ready!")

## Step 4: Experiment Runner

Run experiments with existing imbalanced datasets or custom ones.

In [None]:
import sys
import yaml
import json
from datetime import datetime
import matplotlib.pyplot as plt
from torchvision import datasets

# Add paths
sys.path.append('.')
sys.path.append('imbalanced-exp')

from CNN import CNN
from VisionTransormer import VisionTransformer
from create_imbalanced_dataset import ImbalancedCIFAR10, create_long_tail_imbalance, create_step_imbalance
from evaluate import evaluate_model, print_metrics, plot_confusion_matrix, plot_per_class_metrics

def train_model(model, train_loader, num_epochs, device, print_freq=10):
    """Run ``num_epochs`` passes over ``train_loader`` and record the mean loss per epoch.

    Every batch is handed to ``model.training_step(batch)``. This loop does not
    run a backward pass or optimizer step itself, so the model is presumably
    expected to do that inside ``training_step`` (NOTE(review): confirm in
    CNN.py / VisionTransormer.py).

    Args:
        model: Object exposing ``training_step(batch)``; the returned loss must
            provide ``.item()``.
        train_loader: Iterable of batches; it is iterated once per epoch.
        num_epochs: Number of passes over ``train_loader``.
        device: Not used by this function; kept so existing callers keep working.
        print_freq: Print the loss every ``print_freq``-th batch. ``0`` or
            ``None`` turns per-batch printing off.

    Returns:
        dict with two equally long lists: ``'epoch'`` (0-based epoch indices)
        and ``'loss'`` (mean batch loss of that epoch, ``nan`` when the loader
        produced no batches).
    """
    history = {'loss': [], 'epoch': []}

    for epoch in range(num_epochs):
        epoch_losses = []
        for batch_idx, batch in enumerate(train_loader):
            # Read the scalar once; the value is reused for logging below.
            loss_value = model.training_step(batch).item()
            epoch_losses.append(loss_value)

            # A falsy print_freq disables logging and avoids a modulo by zero.
            if print_freq and batch_idx % print_freq == 0:
                print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss_value:.4f}")

        if epoch_losses:
            avg_loss = sum(epoch_losses) / len(epoch_losses)
        else:
            # An empty loader used to crash with ZeroDivisionError here.
            print(f"Warning: epoch {epoch} produced no batches; recording NaN loss.")
            avg_loss = float('nan')

        history['loss'].append(avg_loss)
        history['epoch'].append(epoch)
        print(f"Epoch {epoch} completed. Average Loss: {avg_loss:.4f}")

    return history

def plot_training_curves(history_cnn, history_vit, save_path):
    """Draw both models' per-epoch training loss on one chart and save it.

    Args:
        history_cnn: dict with 'epoch' and 'loss' sequences for the CNN.
        history_vit: dict with 'epoch' and 'loss' sequences for the ViT.
        save_path: Destination image path handed to ``savefig``.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    # One line per model, told apart by label and marker shape.
    series = (
        (history_cnn, 'CNN', 'o'),
        (history_vit, 'ViT', 's'),
    )
    for history, label, marker in series:
        ax.plot(history['epoch'], history['loss'], label=label, marker=marker)

    ax.set_xlabel('Epoch')
    ax.set_ylabel('Training Loss')
    ax.set_title('Training Loss Curves - Imbalanced Dataset')
    ax.legend()
    ax.grid(alpha=0.3)

    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    # Close the figure so it is neither shown inline nor kept in memory.
    plt.close(fig)
    print(f"Training curves saved to {save_path}")

print("Experiment functions ready!")

## Step 5: Run Experiment with Existing Imbalanced Dataset

### Option A: CIFAR-10-LT (Long-Tailed CIFAR-10)

In [None]:
# Experiment settings (read by the later cells as notebook globals)
DATASET_TYPE = 'cifar10_lt'  # Options: 'cifar10_lt', 'cifar100_lt', 'custom'
IMBALANCE_FACTOR = 100  # For LT datasets: 10, 50, 100, 200
NUM_EPOCHS = 10
NUM_CLASSES = 10  # 10 for CIFAR-10, 100 for CIFAR-100
BATCH_SIZE = 64

# Prefer the GPU when one is visible to PyTorch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Model hyper-parameters come from the uploaded YAML file
with open('config.yaml', 'r') as f:
    config = yaml.safe_load(f)

# Preprocessing applied to every image: resize to 224, convert to a tensor,
# then normalize each channel with the mean/std values below
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

# Only the two long-tailed variants are handled by this cell
if DATASET_TYPE not in ('cifar10_lt', 'cifar100_lt'):
    raise ValueError(f"Unknown dataset type: {DATASET_TYPE}")

# Pick the loader and the class metadata that belong to the chosen variant
if DATASET_TYPE == 'cifar100_lt':
    load_lt_split = load_cifar100_lt
    NUM_CLASSES = 100
    class_names = [f'class_{i}' for i in range(100)]
else:
    load_lt_split = load_cifar10_lt
    class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Either value is None when the corresponding download failed
train_dataset = load_lt_split(imbalance_factor=IMBALANCE_FACTOR, split='train', transform=transform)
test_dataset = load_lt_split(imbalance_factor=IMBALANCE_FACTOR, split='test', transform=transform)

if train_dataset is None or test_dataset is None:
    print("Failed to load dataset. Please check the dataset name or use custom dataset option.")
else:
    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # Analyze class distribution
    from collections import Counter

    # Indexing train_dataset[i] decodes and transforms (resize, tensor
    # conversion, normalization) every image just to read its label. When the
    # wrapped Hugging Face split exposes a 'label' column, read that column
    # directly instead.
    hf_split = getattr(train_dataset, 'dataset', None)
    if 'label' in (getattr(hf_split, 'column_names', None) or []):
        train_labels = list(hf_split['label'])
    else:
        # Fallback: per-item scan through the wrapper (slow, but it works for
        # any dataset that returns (image, label) pairs).
        train_labels = [train_dataset[i][1] for i in range(len(train_dataset))]

    class_distribution = Counter(train_labels)
    print("\nClass distribution in training set:")
    for cls in sorted(class_distribution):
        print(f"  Class {cls}: {class_distribution[cls]} samples")

    print(f"\nTotal training samples: {len(train_dataset)}")
    print(f"Total test samples: {len(test_dataset)}")

### Option B: CIFAR-100-LT (Long-Tailed CIFAR-100)

In [None]:
# Uncomment and modify to use CIFAR-100-LT.
# Run the Option A cell first: `transform`, `BATCH_SIZE` and `DataLoader`
# are defined there / in Step 3. (Alternatively, set
# DATASET_TYPE = 'cifar100_lt' in the Option A cell, which already has a
# branch for it.)
# DATASET_TYPE = 'cifar100_lt'
# IMBALANCE_FACTOR = 100
# NUM_CLASSES = 100
# class_names = [f'class_{i}' for i in range(100)]  # Step 6 passes class_names to the evaluation helpers
#
# train_dataset = load_cifar100_lt(imbalance_factor=IMBALANCE_FACTOR, split='train', transform=transform)
# test_dataset = load_cifar100_lt(imbalance_factor=IMBALANCE_FACTOR, split='test', transform=transform)
#
# # Step 6 trains and evaluates through these loaders
# train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
#
# # Update CNN and ViT to output 100 classes
# # (You'll need to modify the final layer in CNN.py and VisionTransormer.py)

print("To use CIFAR-100-LT, uncomment the code above and ensure models support 100 classes.")

### Option C: Custom Imbalanced Dataset (Create from CIFAR-10)

In [None]:
# Uncomment to create custom imbalanced dataset.
# Run the Option A cell first: `transform`, `BATCH_SIZE` and `class_names`
# are defined there and are reused below and in Step 6.
# NOTE: Steps 6-7 still read DATASET_TYPE, IMBALANCE_FACTOR and
# class_distribution as left by the Option A cell; update them here if the
# output folder name and results.json should describe the custom dataset.
# from create_imbalanced_dataset import create_long_tail_imbalance, create_step_imbalance
#
# # Load full CIFAR-10
# full_train = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
# full_test = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)
#
# # Create imbalanced version (only the training set is subsampled;
# # the test set is used in full)
# train_indices, class_dist = create_long_tail_imbalance(full_train, imbalance_ratio=0.1)
# train_dataset = ImbalancedCIFAR10(full_train, train_indices)
# test_dataset = full_test
#
# train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print("To use custom imbalanced dataset, uncomment the code above.")

## Step 6: Train and Evaluate Models

In [None]:
# Every run writes into its own timestamped folder
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = f"imbalanced-exp/results/{DATASET_TYPE}_if{IMBALANCE_FACTOR}_{timestamp}"
os.makedirs(output_dir, exist_ok=True)


def _banner(title):
    """Print ``title`` between two 60-character rules, preceded by a blank line."""
    rule = "=" * 60
    print("\n" + rule)
    print(title)
    print(rule)


# Build both models from their sections of config.yaml
_banner("Initializing CNN (ResNet)")
cnn_config = config['cnn'].copy()
cnn_model = CNN(config=cnn_config, device=device)

_banner("Initializing Vision Transformer")
vit_config = config['vision_transformer'].copy()
vit_model = VisionTransformer(config=vit_config, device=device)

# Both training runs share the same logging interval
batch_print_freq = config.get('print_batch_frequency', 10)

# Train the CNN first, then the ViT, on the same loader
print("\nTraining CNN...")
cnn_history = train_model(
    cnn_model, train_loader, NUM_EPOCHS, device, print_freq=batch_print_freq
)

print("\nTraining ViT...")
vit_history = train_model(
    vit_model, train_loader, NUM_EPOCHS, device, print_freq=batch_print_freq
)

# Score each trained model on the test loader
_banner("Evaluating CNN")
cnn_metrics = evaluate_model(cnn_model, test_loader, device, num_classes=NUM_CLASSES, class_names=class_names)
print_metrics(cnn_metrics, "CNN (ResNet)", class_names)

_banner("Evaluating Vision Transformer")
vit_metrics = evaluate_model(vit_model, test_loader, device, num_classes=NUM_CLASSES, class_names=class_names)
print_metrics(vit_metrics, "Vision Transformer", class_names)

## Step 7: Save Results and Generate Visualizations

In [None]:
def _metrics_for_json(metrics):
    """Copy the reported metrics into a dict of plain floats and float lists."""
    scalar_keys = ('test_loss', 'accuracy', 'f1_macro', 'f1_weighted')
    vector_keys = ('f1_per_class', 'per_class_recall', 'per_class_precision')
    payload = {key: float(metrics[key]) for key in scalar_keys}
    for key in vector_keys:
        payload[key] = [float(value) for value in metrics[key]]
    return payload


# Collect everything worth persisting for this run
results = {
    'experiment_config': {
        'dataset_type': DATASET_TYPE,
        'imbalance_factor': IMBALANCE_FACTOR,
        'num_epochs': NUM_EPOCHS,
        'num_classes': NUM_CLASSES,
        # Only present when the class-distribution analysis cell was run
        'class_distribution': dict(class_distribution) if 'class_distribution' in locals() else {}
    },
    'cnn_metrics': _metrics_for_json(cnn_metrics),
    'vit_metrics': _metrics_for_json(vit_metrics),
    'training_history': {
        'cnn_loss': [float(x) for x in cnn_history['loss']],
        'vit_loss': [float(x) for x in vit_history['loss']]
    }
}

with open(f'{output_dir}/results.json', 'w') as f:
    json.dump(results, f, indent=2)

# Figures
print("\nGenerating visualizations...")

# One confusion matrix per model
for model_metrics, file_tag, model_title in (
    (cnn_metrics, 'cnn', 'CNN (ResNet)'),
    (vit_metrics, 'vit', 'Vision Transformer'),
):
    plot_confusion_matrix(
        model_metrics['confusion_matrix'],
        class_names,
        f'{output_dir}/confusion_matrix_{file_tag}.png',
        model_title
    )

# Side-by-side per-class metrics
plot_per_class_metrics(
    cnn_metrics, vit_metrics, class_names,
    f'{output_dir}/per_class_metrics_comparison.png'
)

# Loss curves of both training runs
plot_training_curves(
    cnn_history, vit_history,
    f'{output_dir}/training_curves.png'
)

# Console summary table
print("\n" + "="*60)
print("Summary Comparison")
print("="*60)
print(f"{'Metric':<25} {'CNN':<15} {'ViT':<15}")
print("-" * 60)
# Accuracy is shown as a percentage; the remaining rows share one format
print(f"{'Accuracy':<25} {cnn_metrics['accuracy']*100:>6.2f}%      {vit_metrics['accuracy']*100:>6.2f}%")
for row_label, metric_key in (
    ('F1-Score (Macro)', 'f1_macro'),
    ('F1-Score (Weighted)', 'f1_weighted'),
    ('Test Loss', 'test_loss'),
):
    print(f"{row_label:<25} {cnn_metrics[metric_key]:>6.4f}      {vit_metrics[metric_key]:>6.4f}")

print(f"\nAll results saved to: {output_dir}")

## Step 8: View Results

In [None]:
# Import the notebook image class under an alias. A bare `Image` import here
# would rebind the notebook-global `Image` that HuggingFaceCIFAR.__getitem__
# relies on (PIL's Image module, imported in Step 3), so any data loading
# after this cell would fail.
from IPython.display import Image as IPyImage, display
import json

# Read the saved results back from disk
with open(f'{output_dir}/results.json', 'r') as f:
    results = json.load(f)

print("Experiment Results:")
print(f"Dataset: {results['experiment_config']['dataset_type']}")
print(f"Imbalance Factor: {results['experiment_config']['imbalance_factor']}")
print(f"\nCNN Accuracy: {results['cnn_metrics']['accuracy']*100:.2f}%")
print(f"ViT Accuracy: {results['vit_metrics']['accuracy']*100:.2f}%")
print(f"CNN F1-Macro: {results['cnn_metrics']['f1_macro']:.4f}")
print(f"ViT F1-Macro: {results['vit_metrics']['f1_macro']:.4f}")

# Figures written by Step 7, in display order
plot_files = [
    'confusion_matrix_cnn.png',
    'confusion_matrix_vit.png',
    'per_class_metrics_comparison.png',
    'training_curves.png'
]

for plot_file in plot_files:
    plot_path = f"{output_dir}/{plot_file}"
    # Skip figures that were not generated instead of raising
    if os.path.exists(plot_path):
        print(f"\nDisplaying {plot_file}:")
        display(IPyImage(filename=plot_path))

## Step 9: Download Results

In [None]:
from google.colab import files
import zipfile
import shutil

# make_archive appends ".zip" to its base name, so the archive is written
# next to the results folder rather than inside it.
archive_path = f"{output_dir}.zip"
zip_filename = os.path.basename(archive_path)
shutil.make_archive(base_name=output_dir, format='zip', root_dir=output_dir)

# Hand the archive to the browser
files.download(archive_path)
print(f"Downloaded {zip_filename}")

## Step 10: Save to Google Drive (Optional)

In [None]:
# Mount Google Drive at /content/drive (requires the google.colab package,
# i.e. a Colab runtime)
from google.colab import drive
drive.mount('/content/drive')

# Copy results to Drive.
# The `!` lines are IPython shell escapes; {drive_path} and {output_dir} are
# substituted from the Python variables of the same name. `cp -r` places the
# whole results folder inside drive_path.
drive_path = '/content/drive/MyDrive/imbalanced_experiments'
!mkdir -p {drive_path}
!cp -r {output_dir} {drive_path}/
print(f"Results saved to {drive_path}")