# Large Dataset Experiments - CNN vs Vision Transformer

This notebook compares a CNN (ResNet) and a Vision Transformer (ViT) on large datasets: CIFAR-100 and Coyo-labeled-300m.

## Step 1: Setup and Install Dependencies

In [None]:
# Install required packages
!pip install torch torchvision torchmetrics pyyaml matplotlib seaborn -q
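
Because the install runs quietly (`-q`), a failed install can go unnoticed. A minimal sketch of a follow-up check, assuming the import names below match the packages installed above (`pyyaml` is imported as `yaml`); the helper `find_missing_modules` is a hypothetical name, not part of the project:

```python
import importlib.util

# Import names for the packages installed above (pyyaml is imported as `yaml`).
REQUIRED_MODULES = ["torch", "torchvision", "torchmetrics", "yaml", "matplotlib", "seaborn"]


def find_missing_modules(module_names):
    """Return the module names that cannot be located in the current environment."""
    return [name for name in module_names if importlib.util.find_spec(name) is None]


missing = find_missing_modules(REQUIRED_MODULES)
if missing:
    print(f"Missing packages: {', '.join(missing)}")
else:
    print("All required packages are importable.")
```

The check only locates the modules without importing them, so it stays fast even for heavy packages such as `torch`.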

## Step 2: Upload Project Files

You have two options:

### Option A: Upload files directly (if you have the files locally)
1. Upload `CNN.py`, `VisionTransormer.py`, and `config.yaml` to the project root (the parent of `large-dataset/`)
2. Upload all files from the `large-dataset/` folder
3. Upload `cka/metrics.py` and `cka/hook_manager.py` (if needed)

### Option B: Clone from GitHub (if your project is on GitHub)
Uncomment and modify the cell below:

In [None]:
# Option B: Clone from GitHub (uncomment and modify if needed)
# !git clone https://github.com/yourusername/your-repo.git
# %cd your-repo

## Step 3: Verify File Structure

Check that all necessary files are present:

In [None]:
import os

# Check for required files (notebook is inside `large-dataset`)
# So model/config files are one level up, and local experiment files are in this folder.
required_files = [
    '../CNN.py',
    '../VisionTransormer.py',
    '../config.yaml',
    'evaluate.py',
    'load_coyo_dataset.py',
    'run_large_dataset_experiment.py',
    '../cka/metrics.py'
]

print("Checking required files...")
for file in required_files:
    if os.path.exists(file):
        print(f"✓ {file}")
    else:
        print(f"✗ {file} - MISSING!")

print("\nFile check complete!")
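
The check above only prints its findings. If later cells should act on the outcome, the same test can be wrapped in a small function that returns the missing paths. This is a sketch under assumptions: `missing_files` and `example_files` are hypothetical names, and the list is only a subset of the one above.

```python
from pathlib import Path


def missing_files(paths, base_dir="."):
    """Return the entries of `paths` that do not exist relative to `base_dir`."""
    base = Path(base_dir)
    return [p for p in paths if not (base / p).exists()]


# Subset of the list from the cell above, for illustration.
example_files = ["../CNN.py", "../config.yaml", "evaluate.py"]
missing = missing_files(example_files)
if missing:
    print("Missing files:", ", ".join(missing))
else:
    print("All listed files are present.")
```

Returning a list makes it straightforward to stop early, for example by raising `FileNotFoundError` when the list is non-empty.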

## Step 4: Run Large Dataset Experiment

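Run the experiment for 10 epochs. As written, the cell below uses the Coyo-labeled-300m dataset; to run on CIFAR-100 instead, uncomment the first command and comment out the second: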

In [None]:
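import os

# Switch to the large-dataset directory unless the notebook is already running inside it
# (Step 3 assumes the notebook lives in `large-dataset`, where a bare chdir would fail).
if os.path.basename(os.getcwd()) != 'large-dataset':
    os.chdir('large-dataset')

# Run experiment with CIFAR-100 (uncomment to use)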
# !python run_large_dataset_experiment.py \
#     --dataset cifar100 \
#     --num_epochs 10 \
#     --config ../config.yaml \
#     --output_dir ./results

# Or run with Coyo-labeled-300m dataset
!python run_large_dataset_experiment.py \
    --dataset coyo \
    --num_epochs 10 \
    --max_samples 50000 \
    --num_classes 500 \
    --config ../config.yaml \
    --output_dir ./results

## Step 5: Run with Custom Parameters

Uncomment one of the examples below to run the experiment with different parameters:

In [None]:
# Example: More epochs for better convergence
# !python run_large_dataset_experiment.py \
#     --dataset cifar100 \
#     --num_epochs 20 \
#     --config ../config.yaml \
#     --output_dir ./results

# Example: Coyo-labeled-300m with custom parameters
# !python run_large_dataset_experiment.py \
#     --dataset coyo \
#     --num_epochs 20 \
#     --max_samples 100000 \
#     --num_classes 1000 \
#     --config ../config.yaml \
#     --output_dir ./results
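
When several parameter combinations are tried, editing commented-out shell commands by hand gets error-prone. The sketch below assembles the same command line from Python arguments. It uses only the flags that appear in the cells above; `build_command` is a hypothetical helper, and its defaults are assumptions taken from those examples rather than from the script itself.

```python
import shlex


def build_command(dataset, num_epochs, config="../config.yaml", output_dir="./results",
                  max_samples=None, num_classes=None):
    """Assemble the argument list for run_large_dataset_experiment.py.

    Only flags that appear in the cells above are used.
    """
    cmd = ["python", "run_large_dataset_experiment.py",
           "--dataset", dataset,
           "--num_epochs", str(num_epochs)]
    # The Coyo examples pass these two flags; the CIFAR-100 examples omit them.
    if max_samples is not None:
        cmd += ["--max_samples", str(max_samples)]
    if num_classes is not None:
        cmd += ["--num_classes", str(num_classes)]
    cmd += ["--config", config, "--output_dir", output_dir]
    return cmd


cmd = build_command("coyo", 20, max_samples=100000, num_classes=1000)
print(shlex.join(cmd))
# To launch it from Python: subprocess.run(cmd, check=True)
```

Passing the list to `subprocess.run(cmd, check=True)` raises an exception if the script exits with an error, which the `!python` form does not do.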

## Step 6: View Results

Display the generated plots and results:

In [None]:
import glob
import json
import os

from IPython.display import Image, display

# Find the most recent experiment directory
result_dirs = sorted(glob.glob('results/large_dataset_*'), key=os.path.getmtime, reverse=True)
if result_dirs:
    latest_dir = result_dirs[0]
    print(f"Latest experiment: {latest_dir}")
    
    # Display results JSON
    if os.path.exists(f"{latest_dir}/results.json"):
        with open(f"{latest_dir}/results.json", 'r') as f:
            results = json.load(f)
        print("\nExperiment Results:")
        print(f"CNN Accuracy: {results['cnn_metrics']['accuracy']*100:.2f}%")
        print(f"ViT Accuracy: {results['vit_metrics']['accuracy']*100:.2f}%")
        print(f"CNN F1-Macro: {results['cnn_metrics']['f1_macro']:.4f}")
        print(f"ViT F1-Macro: {results['vit_metrics']['f1_macro']:.4f}")
    
    # Display plots from charts folder
    charts_dir = os.path.join(latest_dir, 'charts')
    plot_files = [
        'confusion_matrix_cnn.png',
        'confusion_matrix_vit.png',
        'per_class_metrics_comparison.png',
        'training_curves.png'
    ]
    
    for plot_file in plot_files:
        plot_path = os.path.join(charts_dir, plot_file)
        if os.path.exists(plot_path):
            print(f"\nDisplaying {plot_file}:")
            display(Image(plot_path))
        else:
            print(f"Chart not found: {plot_path}")
else:
    print("No experiment results found. Run the experiment first.")
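
To see the two models side by side, the metrics printed above can be arranged into a small table with the difference per metric. This sketch assumes only the keys already read above (`cnn_metrics`, `vit_metrics`, `accuracy`, `f1_macro`); `summarize` is a hypothetical helper, and the numbers are placeholders rather than experiment results.

```python
def summarize(results, metric_keys=("accuracy", "f1_macro")):
    """Return rows of (metric, cnn_value, vit_value, vit_minus_cnn)."""
    rows = []
    for key in metric_keys:
        cnn = results["cnn_metrics"][key]
        vit = results["vit_metrics"][key]
        rows.append((key, cnn, vit, vit - cnn))
    return rows


# Placeholder values for illustration only; these are not experiment results.
example_results = {
    "cnn_metrics": {"accuracy": 0.5, "f1_macro": 0.25},
    "vit_metrics": {"accuracy": 0.75, "f1_macro": 0.5},
}

print(f"{'metric':<10} {'CNN':>8} {'ViT':>8} {'ViT-CNN':>8}")
for name, cnn, vit, diff in summarize(example_results):
    print(f"{name:<10} {cnn:>8.4f} {vit:>8.4f} {diff:>+8.4f}")
```

In the notebook, pass the `results` dictionary loaded from `results.json` in place of `example_results`.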