# Imbalanced Dataset Experiments - CNN vs Vision Transformer

This notebook compares a CNN (ResNet) and a Vision Transformer (ViT) trained on imbalanced versions of the CIFAR-10 dataset.

## Step 1: Setup and Install Dependencies

In [None]:
# Install required packages
!pip install torch torchvision torchmetrics pyyaml matplotlib seaborn -q
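
Before moving on, it can help to confirm that the packages installed above are importable in the current kernel. The following is a minimal sketch using only the standard library. The `missing_packages` helper is a hypothetical name introduced here; note that `pyyaml` is imported under the module name `yaml`.

```python
import importlib.util

# Import names for the packages installed above (pyyaml is imported as "yaml").
REQUIRED_MODULES = ["torch", "torchvision", "torchmetrics", "yaml", "matplotlib", "seaborn"]


def missing_packages(modules):
    """Return the module names that cannot be found in the current environment."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


missing = missing_packages(REQUIRED_MODULES)
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All required packages are importable.")
```

If anything is reported as missing, re-running the install cell (or restarting the runtime) is usually the first thing to try.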

## Step 2: Upload Project Files

You have two options:

### Option A: Upload files directly (if you have the files locally)
1. Upload `CNN.py`, `VisionTransormer.py`, and `config.yaml` to the root directory
2. Upload all files from the `imbalanced-exp/` folder, keeping them inside an `imbalanced-exp/` folder
3. Upload `cka/metrics.py` and `cka/hook_manager.py` into a `cka/` folder (if needed)

### Option B: Clone from GitHub (if your project is on GitHub)
Uncomment and modify the cell below:

In [None]:
# Option B: Clone from GitHub (uncomment and modify if needed)
# !git clone https://github.com/yourusername/your-repo.git
# %cd your-repo

## Step 3: Verify File Structure

Check that all necessary files are present:

In [None]:
import os

# Check for required files
required_files = [
    'CNN.py',
    'VisionTransormer.py',
    'config.yaml',
    'imbalanced-exp/create_imbalanced_dataset.py',
    'imbalanced-exp/evaluate.py',
    'imbalanced-exp/run_imbalanced_experiment.py',
    'cka/metrics.py'
]

print("Checking required files...")
for file in required_files:
    if os.path.exists(file):
        print(f"✓ {file}")
    else:
        print(f"✗ {file} - MISSING!")

print("\nFile check complete!")

## Step 4: Run Imbalanced Dataset Experiment

Run the experiment with default settings (long-tail imbalance, ratio=0.1, 10 epochs):

In [None]:
import os

# Change to the imbalanced-exp directory (skipped if this cell is re-run)
if os.path.basename(os.getcwd()) != 'imbalanced-exp':
    os.chdir('imbalanced-exp')

# Run experiment
!python run_imbalanced_experiment.py \
    --imbalance_type long_tail \
    --imbalance_ratio 0.1 \
    --num_epochs 10 \
    --config ../config.yaml \
    --output_dir ./results
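
For intuition about the `--imbalance_type` and `--imbalance_ratio` flags, here is a sketch of how per-class training counts are often derived in long-tailed CIFAR experiments. It assumes that the ratio is the size of the rarest class divided by the size of the most frequent class, that long-tail counts decay exponentially across classes, and that step imbalance shrinks the second half of the classes to the minority size. The actual logic lives in `create_imbalanced_dataset.py` and may differ; `class_counts` is a hypothetical helper. The default of 5000 reflects the number of training images per class in CIFAR-10.

```python
def class_counts(imbalance_type, ratio, num_classes=10, max_per_class=5000):
    """Illustrative per-class sample counts (assumed convention, not the project's code)."""
    if imbalance_type == "long_tail":
        # Exponential decay from max_per_class down to max_per_class * ratio.
        return [
            round(max_per_class * ratio ** (i / (num_classes - 1)))
            for i in range(num_classes)
        ]
    if imbalance_type == "step":
        # First half of the classes stay full; the rest shrink to the minority size.
        half = num_classes // 2
        minority = round(max_per_class * ratio)
        return [max_per_class] * half + [minority] * (num_classes - half)
    raise ValueError(f"unknown imbalance type: {imbalance_type}")


print("long_tail:", class_counts("long_tail", 0.1))
print("step:     ", class_counts("step", 0.1))
```

Under these assumptions, a smaller ratio means a more severe imbalance, which is why the two schedules are worth comparing at the same ratio.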

## Step 5: Run with Custom Parameters

You can customize the experiment parameters:

In [None]:
# Example: Step imbalance with different ratio
# !python run_imbalanced_experiment.py \
#     --imbalance_type step \
#     --imbalance_ratio 0.05 \
#     --num_epochs 15 \
#     --config ../config.yaml \
#     --output_dir ./results
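
To compare several settings in one go, the runs can be scripted. The sketch below builds one command per combination using only the flags shown above. The grid of values and the `build_command` helper are illustrative assumptions, and it presumes the working directory is already `imbalanced-exp`.

```python
import itertools
import shlex


def build_command(imbalance_type, imbalance_ratio, num_epochs=10,
                  config="../config.yaml", output_dir="./results"):
    """Return the argument list for one experiment run."""
    return [
        "python", "run_imbalanced_experiment.py",
        "--imbalance_type", imbalance_type,
        "--imbalance_ratio", str(imbalance_ratio),
        "--num_epochs", str(num_epochs),
        "--config", config,
        "--output_dir", output_dir,
    ]


# Hypothetical grid; adjust to the settings you want to compare.
grid = itertools.product(["long_tail", "step"], [0.1, 0.05])
commands = [build_command(kind, ratio) for kind, ratio in grid]

for cmd in commands:
    # Only prints the commands for review; nothing is executed here.
    print(shlex.join(cmd))
```

To actually launch the runs, each list can be passed to `subprocess.run(cmd, check=True)`. Keep in mind that total runtime grows with the number of combinations and with `--num_epochs`.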

## Step 6: View Results

Display the generated plots and results:

In [None]:
import glob
import json
import os

from IPython.display import Image, display

# Find the most recent experiment directory (paths are relative to imbalanced-exp/)
result_dirs = sorted(glob.glob('results/imbalanced_*'), key=os.path.getmtime, reverse=True)
if result_dirs:
    latest_dir = result_dirs[0]
    print(f"Latest experiment: {latest_dir}")
    
    # Display results JSON
    if os.path.exists(f"{latest_dir}/results.json"):
        with open(f"{latest_dir}/results.json", 'r') as f:
            results = json.load(f)
        print("\nExperiment Results:")
        print(f"CNN Accuracy: {results['cnn_metrics']['accuracy']*100:.2f}%")
        print(f"ViT Accuracy: {results['vit_metrics']['accuracy']*100:.2f}%")
        print(f"CNN F1-Macro: {results['cnn_metrics']['f1_macro']:.4f}")
        print(f"ViT F1-Macro: {results['vit_metrics']['f1_macro']:.4f}")
    
    # Display plots
    plot_files = [
        'confusion_matrix_cnn.png',
        'confusion_matrix_vit.png',
        'per_class_metrics_comparison.png',
        'training_curves.png'
    ]
    
    for plot_file in plot_files:
        plot_path = f"{latest_dir}/{plot_file}"
        if os.path.exists(plot_path):
            print(f"\nDisplaying {plot_file}:")
            display(Image(plot_path))
else:
    print("No experiment results found. Run the experiment first.")
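
Beyond printing individual numbers, it can be useful to see the two models side by side. On imbalanced data, macro F1 weights every class equally, so it tends to say more about minority-class performance than overall accuracy does. The sketch below assumes `results.json` contains the `cnn_metrics` and `vit_metrics` keys used above; `compare_models` and `print_comparison` are hypothetical helpers.

```python
def compare_models(results, metrics=("accuracy", "f1_macro")):
    """Return (metric, cnn_value, vit_value, vit_minus_cnn) rows for the given metrics."""
    rows = []
    for name in metrics:
        cnn = results["cnn_metrics"][name]
        vit = results["vit_metrics"][name]
        rows.append((name, cnn, vit, vit - cnn))
    return rows


def print_comparison(results):
    """Print a small fixed-width table of CNN vs ViT metrics."""
    print(f"{'metric':<10} {'CNN':>8} {'ViT':>8} {'ViT-CNN':>8}")
    for name, cnn, vit, diff in compare_models(results):
        print(f"{name:<10} {cnn:>8.4f} {vit:>8.4f} {diff:>+8.4f}")
```

With `results` loaded from `results.json` as in the cell above, calling `print_comparison(results)` prints one row per metric; a positive value in the last column means the ViT scored higher.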

## Step 7: Download Results

Zip the latest results folder and download it to your local machine (this cell uses `google.colab`, so it only works in Colab):

In [None]:
import glob
import os
import shutil

from google.colab import files

# Find latest results directory
result_dirs = sorted(glob.glob('results/imbalanced_*'), key=os.path.getmtime, reverse=True)
if result_dirs:
    latest_dir = result_dirs[0]

    # Create the zip file; make_archive returns the path it wrote (results/<name>.zip)
    zip_path = shutil.make_archive(latest_dir, 'zip', latest_dir)

    # Download the archive from the path that was actually written
    files.download(zip_path)
    print(f"Downloaded {os.path.basename(zip_path)}")
else:
    print("No results to download.")
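
The `google.colab` module exists only inside Colab, so the cell above raises an `ImportError` elsewhere (for example, in a local Jupyter session). One possible fallback, sketched below, archives the results either way and only triggers a browser download when Colab is detected. `archive_results` and `deliver` are hypothetical helper names.

```python
import os
import shutil


def archive_results(result_dir):
    """Zip result_dir next to itself and return the path of the archive."""
    result_dir = os.path.normpath(result_dir)
    # make_archive appends ".zip" to the base name and returns the archive path.
    return shutil.make_archive(result_dir, "zip", root_dir=result_dir)


def deliver(zip_path):
    """Download via Colab when available; otherwise report where the file is."""
    try:
        from google.colab import files  # only importable inside Colab
    except ImportError:
        print(f"Archive written to {os.path.abspath(zip_path)}")
    else:
        files.download(zip_path)
```

A typical call would be `deliver(archive_results(latest_dir))`, with `latest_dir` chosen as in the cell above.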