# AlexNet iFood2019 - Train Baseline Model

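This notebook trains a single AlexNet model on iFood2019. It trains the baseline variant by default; set `model_name` in the configuration cell to train one of the other variants.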

## Steps:
1. Setup Environment
2. Configure Training
3. Train Model
4. Evaluate Results
5. Visualize Training Curves

In [None]:
# ============================================================
# Cell 1: Setup Environment
# ============================================================

from google.colab import drive
drive.mount('/content/drive')

import os
import sys

PROJECT_PATH = '/content/drive/MyDrive/AlexNet_iFood2019'
REPO_PATH = '/content/alexnet-ifood2019'

# Clone if needed
if not os.path.exists(REPO_PATH):
    !git clone https://github.com/deftorch/alexnet-ifood2019.git {REPO_PATH}

os.chdir(REPO_PATH)
sys.path.insert(0, REPO_PATH)

# Create symlinks
!rm -rf data checkpoints evaluation_results analysis_results
!ln -s {PROJECT_PATH}/dataset data
!ln -s {PROJECT_PATH}/checkpoints checkpoints
!ln -s {PROJECT_PATH}/evaluation_results evaluation_results
!ln -s {PROJECT_PATH}/analysis_results analysis_results

print("‚úì Environment ready")

In [None]:
# ============================================================
# Cell 2: Install Dependencies & Verify GPU
# ============================================================

!pip install -q torch torchvision pandas numpy pillow scikit-learn matplotlib seaborn tqdm wandb

import torch
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("‚ö†Ô∏è  No GPU! Go to Runtime > Change runtime type > GPU")

In [None]:
# ============================================================
# Cell 3: Training Configuration
# ============================================================

# Modify these settings as needed. Every value here is forwarded to the
# training / evaluation commands in the cells below.
CONFIG = dict(
    model_name='alexnet_baseline',  # Options: alexnet_baseline, alexnet_mod1, alexnet_mod2, alexnet_combined
    num_epochs=50,
    batch_size=128,
    learning_rate=0.01,
    momentum=0.9,
    weight_decay=0.0005,
    scheduler='step',  # Options: step, cosine
    num_workers=4,
    use_wandb=False,  # Set True to log to Weights & Biases
)

# Echo the active settings so they are recorded in the cell output.
divider = "=" * 40
print("Training Configuration:")
print(divider)
for setting_name, setting_value in CONFIG.items():
    print(f"  {setting_name}: {setting_value}")
print(divider)

In [None]:
# ============================================================
# Cell 4: Setup WandB (Optional)
# ============================================================

# Authenticate with Weights & Biases only when logging is enabled in CONFIG,
# so the notebook runs without a W&B login by default.
if CONFIG['use_wandb']:
    import wandb  # imported lazily: only needed when W&B logging is on
    wandb.login()
    print("✓ WandB logged in")
else:
    print("ℹ️  WandB disabled - logging to console only")

In [None]:
# ============================================================
# Cell 5: Start Training
# ============================================================

model_name = CONFIG['model_name']
epochs = CONFIG['num_epochs']
batch_size = CONFIG['batch_size']
lr = CONFIG['learning_rate']
momentum = CONFIG['momentum']
wd = CONFIG['weight_decay']
scheduler = CONFIG['scheduler']
num_workers = CONFIG['num_workers']
use_wandb = "--use_wandb" if CONFIG['use_wandb'] else ""

!python src/train.py \
    --data_dir data \
    --model_name {model_name} \
    --num_epochs {epochs} \
    --batch_size {batch_size} \
    --lr {lr} \
    --momentum {momentum} \
    --weight_decay {wd} \
    --scheduler {scheduler} \
    --num_workers {num_workers} \
    --save_dir checkpoints \
    {use_wandb}

In [None]:
# ============================================================
# Cell 6: Evaluate Model
# ============================================================

model_name = CONFIG['model_name']

!python src/evaluate.py \
    --data_dir data \
    --model_path checkpoints/{model_name}_best.pth \
    --model_name {model_name} \
    --split val \
    --output_dir evaluation_results

In [None]:
# ============================================================
# Cell 7: Visualize Training Curves
# ============================================================

# Plot loss and accuracy curves from the per-epoch history JSON in checkpoints/
# and print a short summary of the run.
import json
import os

import matplotlib.pyplot as plt

model_name = CONFIG['model_name']
history_file = f'checkpoints/{model_name}_history.json'
curves_file = f'analysis_results/{model_name}_training_curves.png'
history_keys = ('train_loss', 'val_loss', 'train_acc', 'val_acc')

# Load the history if it exists; None marks "no file".
history = None
if os.path.exists(history_file):
    with open(history_file) as f:
        history = json.load(f)

if history is None:
    print(f"⚠️  History file not found: {history_file}")
elif any(len(history[key]) == 0 for key in history_keys):
    # An empty series would crash the `[-1]` / `max()` summary below.
    print(f"⚠️  History file contains no recorded epochs: {history_file}")
else:
    # 1-based epoch numbers for the x axis. Named `epoch_axis` so it does not
    # overwrite the integer `epochs` defined by the training cell.
    epoch_axis = range(1, len(history['train_loss']) + 1)

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # One panel per metric: (history key suffix, axis label).
    panels = (('loss', 'Loss'), ('acc', 'Accuracy'))
    for ax, (metric, label) in zip(axes, panels):
        ax.plot(epoch_axis, history[f'train_{metric}'], label='Train', linewidth=2)
        ax.plot(epoch_axis, history[f'val_{metric}'], label='Validation', linewidth=2)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(label)
        ax.set_title(f'Training & Validation {label}')
        ax.legend()
        ax.grid(alpha=0.3)

    plt.suptitle(f'{model_name} Training Curves', fontsize=14)
    plt.tight_layout()

    # Make sure the output folder exists. realpath() resolves a symlinked
    # folder to its target, so a dangling link gets its target created.
    os.makedirs(os.path.realpath(os.path.dirname(curves_file)), exist_ok=True)
    plt.savefig(curves_file, dpi=150)
    plt.show()

    # Summary
    best_val_acc = max(history['val_acc'])
    # index() returns the first occurrence, i.e. the earliest best epoch.
    best_val_epoch = history['val_acc'].index(best_val_acc) + 1
    print("\nTraining Summary:")
    print(f"  Final Train Loss: {history['train_loss'][-1]:.4f}")
    print(f"  Final Train Acc: {history['train_acc'][-1]:.4f}")
    print(f"  Best Val Acc: {best_val_acc:.4f}")
    print(f"  Best Val Epoch: {best_val_epoch}")

In [None]:
# ============================================================
# Cell 8: Keep Session Alive (for long training)
# ============================================================

# Inject a browser-side timer that clicks Colab's connect button once a minute.
# Run this cell BEFORE starting training: with "Run All" it only executes after
# the training cell has finished.
# NOTE(review): confirm that "colab-connect-button" is reachable from the
# JavaScript context of a cell output in current Colab; if it is not, the
# lookup returns null and the click is skipped (a warning is logged instead).
from IPython.display import Javascript, display

display(Javascript('''
    function ClickConnect(){
        console.log("Keeping session alive...");
        var connectButton = document.querySelector("colab-connect-button");
        if (connectButton) {
            connectButton.click();
        } else {
            console.warn("colab-connect-button not found; keep-alive click skipped");
        }
    }
    // Replace any timer left by a previous run so re-running does not stack timers.
    if (window._keepAliveTimer) {
        clearInterval(window._keepAliveTimer);
    }
    window._keepAliveTimer = setInterval(ClickConnect, 60000);
'''))

print("✓ Keep-alive script running")
print("💡 This prevents Colab from disconnecting during long training")