## Step 1: Setup and Installation

In [None]:
# Check GPU availability
!nvidia-smi

In [None]:
# Clone the repository
!git clone https://github.com/automl/trivialaugment.git
%cd trivialaugment

In [None]:
# Verify the PyTorch installation and CUDA support (Colab ships with PyTorch preinstalled)
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

In [None]:
# Install dependencies
!pip install -q git+https://github.com/wbaek/theconf
!pip install -q git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git
!pip install -q git+https://github.com/ildoonet/pystopwatch2.git
!pip install -q colored pretrainedmodels tqdm tensorboardX scikit-learn matplotlib psutil requests tensorboard
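Since the `-q` flag hides most pip output, a quick import check can confirm that the installs worked. The following is a minimal sketch, not part of the repository: `missing_modules` is a hypothetical helper, and the module names are assumptions about how each package is imported (for example, `scikit-learn` imports as `sklearn`, and `warmup_scheduler` is assumed to be the module provided by `pytorch-gradual-warmup-lr`).

```python
import importlib.util

# Assumption: import names for the packages installed in the cell above.
# These can differ from the pip names, so adjust the list if a check fails
# even though the install succeeded.
REQUIRED_MODULES = [
    "theconf", "warmup_scheduler", "pystopwatch2", "colored",
    "pretrainedmodels", "tqdm", "tensorboardX", "sklearn",
    "matplotlib", "psutil", "requests", "tensorboard",
]

def missing_modules(names):
    """Return the top-level module names that cannot be found on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]

missing = missing_modules(REQUIRED_MODULES)
print("Missing modules:", missing if missing else "none")
```

If the list is not empty, rerunning the install cell without `-q` should show the underlying pip error.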

## Step 2: Mount Google Drive (Optional - for saving checkpoints)

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Create directory for saving results
import os
save_dir = '/content/drive/MyDrive/trivialaugment_results'
os.makedirs(save_dir, exist_ok=True)
print(f"Results will be saved to: {save_dir}")
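Because mounting Drive is optional, it may help to fall back to a local directory when Drive is not available. The sketch below assumes the default Colab mount point `/content/drive/MyDrive`; the helper name `resolve_save_dir` and the local fallback path are hypothetical.

```python
import os

def resolve_save_dir(drive_root="/content/drive/MyDrive",
                     subdir="trivialaugment_results",
                     local_fallback="./trivialaugment_results"):
    """Return a results directory on Drive if it is mounted, else a local one."""
    if os.path.isdir(drive_root):
        # Drive is mounted: keep results there so they survive a session reset
        target = os.path.join(drive_root, subdir)
    else:
        # Drive is not mounted: results live on the ephemeral Colab disk
        target = local_fallback
    os.makedirs(target, exist_ok=True)
    return target

save_dir = resolve_save_dir()
print(f"Results will be saved to: {save_dir}")
```

Note that anything stored in the local fallback is lost when the Colab runtime is recycled.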

## Step 3: Verify Setup and Download CIFAR-10

In [None]:
# Download CIFAR-10 dataset
from torchvision import datasets

print("Downloading CIFAR-10...")
datasets.CIFAR10(root='./data', train=True, download=True)
datasets.CIFAR10(root='./data', train=False, download=True)
print("✓ CIFAR-10 downloaded successfully!")

## Step 4: Check Available Config Files

In [None]:
# List WRN-40-2 CIFAR-10 configs (for Table 4a)
!ls -lh confs/wresnet40x2_cifar10*.yaml

## Step 5: Run Experiments

### Table 4a: WRN-40-2 on CIFAR-10 (200 epochs)

We'll run 5 experiments corresponding to Table 4a in the paper.

**Expected Results:**
1. FAA (RA space): 4.72 ± 0.13% error
2. UA (UA space): 4.98 ± 0.23% error
3. UA (RA space): 4.07 ± 0.11% error
4. **TA (RA space): 3.94 ± 0.11% error**
5. **TA (Wide space): 3.82 ± 0.13% error** ⭐ Best result

⚠️ **Note:** Each experiment takes roughly 4-6 hours on a T4 GPU, so choose which ones to run!
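To judge a finished run against these numbers, a small comparison helper can be useful. This is a rough sketch: the values are copied from the list above, `compare_to_paper` is a hypothetical helper, and treating the ± value as an acceptance band is an assumption. A single run can reasonably fall outside that band, so read the result as a sanity check rather than a pass/fail test.

```python
# Table 4a values as listed above: (mean test error in %, reported +/- margin)
PAPER_TABLE_4A = {
    "FAA (RA)": (4.72, 0.13),
    "UA (UA)": (4.98, 0.23),
    "UA (RA)": (4.07, 0.11),
    "TA (RA)": (3.94, 0.11),
    "TA (Wide)": (3.82, 0.13),
}

def compare_to_paper(name, measured_error, tolerance=1.0):
    """Return (deviation, within_band) for a measured test error in percent.

    tolerance scales the reported margin; 1.0 means "within the +/- value".
    """
    mean, margin = PAPER_TABLE_4A[name]
    deviation = measured_error - mean
    within_band = abs(deviation) <= tolerance * margin
    return deviation, within_band

# Example with a made-up measurement of 3.90% test error
deviation, ok = compare_to_paper("TA (Wide)", 3.90)
print(f"Deviation from paper: {deviation:+.2f} points, within band: {ok}")
```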

### Experiment 1: Fast AutoAugment (FAA)

In [None]:
# Experiment 1: FAA with RA search space
!python -m TrivialAugment.train \
    -c confs/wresnet40x2_cifar10_b128_maxlr.1_faa_nowarmup_200epochs.yaml \
    --dataroot data \
    --tag expFAA \
    --save /content/drive/MyDrive/trivialaugment_results/faa_model.pth

### Experiment 2: UniformAugment with UA search space

In [None]:
# Experiment 2: UA with UA search space
!python -m TrivialAugment.train \
    -c confs/wresnet40x2_cifar10_b128_maxlr.1_ua_uasesp_nowarmup_200epochs.yaml \
    --dataroot data \
    --tag expUAua \
    --save /content/drive/MyDrive/trivialaugment_results/ua_ua_model.pth

### Experiment 3: UniformAugment with RA search space

In [None]:
# Experiment 3: UA with RA search space
!python -m TrivialAugment.train \
    -c confs/wresnet40x2_cifar10_b128_maxlr.1_ua_fixedsesp_nowarmup_200epochs.yaml \
    --dataroot data \
    --tag expUAra \
    --save /content/drive/MyDrive/trivialaugment_results/ua_ra_model.pth

### Experiment 4: TrivialAugment with RA search space ⭐

In [None]:
# Experiment 4: TrivialAugment with RA search space
!python -m TrivialAugment.train \
    -c confs/wresnet40x2_cifar10_b128_maxlr.1_ta_fixedsesp_nowarmup_200epochs.yaml \
    --dataroot data \
    --tag expTAra \
    --save /content/drive/MyDrive/trivialaugment_results/ta_ra_model.pth

### Experiment 5: TrivialAugment with Wide search space ⭐⭐ (Best)

In [None]:
# Experiment 5: TrivialAugment with Wide search space (BEST RESULT)
!python -m TrivialAugment.train \
    -c confs/wresnet40x2_cifar10_b128_maxlr.1_ta_widesesp_nowarmup_200epochs.yaml \
    --dataroot data \
    --tag expTAwide \
    --save /content/drive/MyDrive/trivialaugment_results/ta_wide_model.pth

## Step 6: Monitor Training (in separate cell)

In [None]:
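# Load TensorBoard. Colab executes one cell at a time, so run this cell
# before launching a training cell if you want to watch progress live.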
%load_ext tensorboard
%tensorboard --logdir logs

## Step 7: View Results

In [None]:
# Aggregate and display results
!python aggregate_results.py

## Step 8: Extract Final Test Accuracies

In [None]:
import glob
import torch
import pandas as pd

# Find all saved checkpoints
checkpoint_dir = '/content/drive/MyDrive/trivialaugment_results'
local_checkpoints = glob.glob('*.pth')
drive_checkpoints = glob.glob(f'{checkpoint_dir}/*.pth')

results = []

for ckpt_path in local_checkpoints + drive_checkpoints:
    try:
        ckpt = torch.load(ckpt_path, map_location='cpu')
        if 'log' in ckpt and 'test' in ckpt['log']:
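            top1 = float(ckpt['log']['test']['top1'])
            # Assumption: top1 may be stored as a fraction in [0, 1]; convert to percent if so
            test_acc = top1 * 100 if top1 <= 1 else top1
            test_error = 100 - test_acc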
            epoch = ckpt.get('epoch', 'N/A')
            
            results.append({
                'Checkpoint': ckpt_path.split('/')[-1],
                'Epoch': epoch,
                'Test Accuracy (%)': f'{test_acc:.2f}',
                'Test Error (%)': f'{test_error:.2f}'
            })
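    except Exception as exc:
        # Report unreadable or unexpected checkpoints instead of silently skipping them
        print(f"Skipping {ckpt_path}: {exc}")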

if results:
    df = pd.DataFrame(results)
    print("\n" + "="*80)
    print("FINAL RESULTS - Table 4a: WRN-40-2 on CIFAR-10")
    print("="*80)
    print(df.to_string(index=False))
    print("\n" + "="*80)
    print("Expected Results from Paper:")
    print("  FAA (RA):      4.72 ± 0.13% error")
    print("  UA (UA):       4.98 ± 0.23% error")
    print("  UA (RA):       4.07 ± 0.11% error")
    print("  TA (RA):       3.94 ± 0.11% error")
    print("  TA (Wide):     3.82 ± 0.13% error (BEST)")
    print("="*80)
else:
    print("No checkpoints found yet. Keep training!")

## Tips for Colab Usage

### Avoiding Session Timeouts:
1. **Save to Google Drive**: Use `--save` flag with Drive path
2. **Resume from checkpoint**: If interrupted, the code will auto-resume
3. **Keep browser tab active**: Prevents idle timeout
4. **Use Colab Pro**: For longer runtimes (up to 24 hours)
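Checkpoints written with `--save` already land on Drive, but the TensorBoard logs do not. A minimal sketch for copying them over is shown below. It assumes the logs are in the local `logs` directory used by the TensorBoard cell; `backup_logs` is a hypothetical helper, and `dirs_exist_ok` requires Python 3.8 or newer.

```python
import os
import shutil

def backup_logs(src="logs",
                dst="/content/drive/MyDrive/trivialaugment_results/logs"):
    """Copy the local log directory to dst, overwriting files that already exist.

    Returns True if a copy was made, False if src does not exist yet.
    """
    if not os.path.isdir(src):
        return False
    # dirs_exist_ok lets repeated backups update the same destination folder
    shutil.copytree(src, dst, dirs_exist_ok=True)
    return True

if backup_logs():
    print("Logs copied to Drive.")
else:
    print("No local logs directory found yet.")
```

Running this cell between experiments keeps a copy of the training curves even if the runtime is reset.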

### Running Multiple Experiments:
- **Sequential**: Run one after another (recommended for free tier)
- **Parallel**: Open multiple Colab notebooks with different Google accounts
- **Priority**: Start with Experiments 4 & 5 (TrivialAugment - main contribution)
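For sequential runs, the five training commands can be generated from one table instead of being edited by hand. The sketch below reuses the config paths, tags, and flags from the experiment cells; `EXPERIMENTS`, `build_command`, and `run_experiments` are hypothetical names, and the dictionary order (TrivialAugment first) reflects the priority suggested above.

```python
import subprocess
import sys

# tag -> (config file, checkpoint file name), TrivialAugment runs listed first
EXPERIMENTS = {
    "expTAra": ("confs/wresnet40x2_cifar10_b128_maxlr.1_ta_fixedsesp_nowarmup_200epochs.yaml", "ta_ra_model.pth"),
    "expTAwide": ("confs/wresnet40x2_cifar10_b128_maxlr.1_ta_widesesp_nowarmup_200epochs.yaml", "ta_wide_model.pth"),
    "expUAra": ("confs/wresnet40x2_cifar10_b128_maxlr.1_ua_fixedsesp_nowarmup_200epochs.yaml", "ua_ra_model.pth"),
    "expUAua": ("confs/wresnet40x2_cifar10_b128_maxlr.1_ua_uasesp_nowarmup_200epochs.yaml", "ua_ua_model.pth"),
    "expFAA": ("confs/wresnet40x2_cifar10_b128_maxlr.1_faa_nowarmup_200epochs.yaml", "faa_model.pth"),
}

def build_command(tag, save_dir, dataroot="data"):
    """Build the training command for one experiment as an argument list."""
    config, ckpt_name = EXPERIMENTS[tag]
    save_path = f"{save_dir.rstrip('/')}/{ckpt_name}"
    return [sys.executable, "-m", "TrivialAugment.train",
            "-c", config, "--dataroot", dataroot,
            "--tag", tag, "--save", save_path]

def run_experiments(tags, save_dir, dry_run=True):
    """Print each command and, unless dry_run is set, run them one after another."""
    commands = [build_command(tag, save_dir) for tag in tags]
    for cmd in commands:
        print(" ".join(cmd))
        if not dry_run:
            # check=True stops the queue if a run exits with an error
            subprocess.run(cmd, check=True)
    return commands

# Dry run: only prints the commands for the two TrivialAugment experiments
run_experiments(["expTAra", "expTAwide"],
                "/content/drive/MyDrive/trivialaugment_results")
```

Pass `dry_run=False` from inside the `trivialaugment` directory to actually start training. Keep in mind that two 200-epoch runs may not fit into a single free-tier session.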

### Monitoring:
- Check tensorboard cell periodically
- Look for test accuracy in logs (printed every 20 epochs)
- Checkpoints saved automatically every 20 epochs

### Memory Issues:
If you get OOM errors:
```python
# Edit config file before running
!sed -i 's/batch: 128/batch: 64/' confs/wresnet40x2*.yaml
```
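The `sed` one-liner rewrites every matching config in place without showing what changed. As an alternative, the sketch below does the same substitution in Python with a dry-run mode. It assumes the configs contain a `batch:` key on its own line, as the `sed` pattern implies; `set_batch_size` is a hypothetical helper. A smaller batch size changes the training setup, so results may deviate from the numbers reported in the paper.

```python
import glob
import re

def set_batch_size(pattern, new_batch, dry_run=True):
    """Set `batch: <n>` in every config matching pattern.

    Returns the list of files whose content would change (or did change).
    """
    changed = []
    for path in sorted(glob.glob(pattern)):
        with open(path) as f:
            text = f.read()
        # Replace only the number after a line-initial "batch:" key
        new_text = re.sub(r"^([ \t]*batch:[ \t]*)\d+",
                          rf"\g<1>{new_batch}", text, flags=re.MULTILINE)
        if new_text != text:
            changed.append(path)
            if not dry_run:
                with open(path, "w") as f:
                    f.write(new_text)
    return changed

# Preview which configs would be modified; nothing is written in dry-run mode
print(set_batch_size("confs/wresnet40x2*.yaml", 64))
```

Call it again with `dry_run=False` once the preview lists the expected files.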

## Citation

If you use this code or reproduce these results, please cite the TrivialAugment paper:

```bibtex
@InProceedings{Muller_2021_ICCV,
    author    = {M\"uller, Samuel G. and Hutter, Frank},
    title     = {TrivialAugment: Tuning-Free Yet State-of-the-Art Data Augmentation},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2021},
    pages     = {774-782}
}
```