# Colab Training Notebook: Generic PyTorch Pipeline

This notebook sets up a full training workflow (data, model, training, validation, logging, export) suitable for Google Colab GPU runtimes.

In [None]:
# 1. Check Runtime Hardware (GPU/TPU Availability)
# Prints the interpreter and PyTorch versions, then either the visible CUDA
# device details or a hint on how to switch the Colab runtime to a GPU.
import torch, platform, os

cuda_ok = torch.cuda.is_available()
print(f"Python: {platform.python_version()}")
print(f"Torch: {torch.__version__}")
print("CUDA available:", cuda_ok)
if not cuda_ok:
    print("If you need a GPU: Runtime > Change runtime type > GPU")
else:
    print("Device count:", torch.cuda.device_count())
    print("Current device:", torch.cuda.current_device())
    print("Device name:", torch.cuda.get_device_name(0))

In [None]:
# 2. Install and Upgrade Dependencies
!pip install -q --upgrade pip
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
!pip install -q numpy matplotlib scikit-learn tensorboard tqdm
import torch, torchvision, numpy as np, matplotlib
print("Torch version:", torch.__version__)
print("Torchvision version:", torchvision.__version__)
## 3. Mount Google Drive (Optional Persistence)
# BASE_DIR is the root for every file this notebook writes. It defaults to the
# ephemeral Colab disk and is redirected to Google Drive when USE_DRIVE is True.
USE_DRIVE = False  # set True to persist
BASE_DIR = '/content/training_run'
if USE_DRIVE:
    from google.colab import drive
    drive.mount('/content/drive')
    BASE_DIR = '/content/drive/MyDrive/colab_training_run'

os.makedirs(BASE_DIR, exist_ok=True)
print("Base directory:", BASE_DIR)

In [None]:
# 4. Set Global Configuration (Paths, Hyperparameters, Seeds)
import random, json, math, numpy as np, torch
# Single source of truth for the run: optimisation hyperparameters, the RNG
# seed, DataLoader worker count, and the data/output folders under BASE_DIR.
cfg = dict(
    batch_size=128,
    lr=3e-4,
    epochs=15,
    seed=42,
    data_dir=f"{BASE_DIR}/data",
    output_dir=f"{BASE_DIR}/outputs",
    num_workers=2,
)

def set_seed(seed:int):
    """Seed Python's `random`, NumPy and PyTorch (CPU and all CUDA devices).

    Also switches cuDNN to deterministic mode and turns off its autotuner
    (`benchmark`), as set on the two `torch.backends.cudnn` flags below.

    Args:
        seed: Integer seed handed unchanged to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed all RNGs first, then make sure both configured folders exist, and
# finally echo the configuration so it is captured in the cell output.
set_seed(cfg['seed'])
for dir_key in ('data_dir', 'output_dir'):
    os.makedirs(cfg[dir_key], exist_ok=True)
print(cfg)
# 5. Import Libraries
import os, datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
from torch.utils.tensorboard import SummaryWriter
print('Imports complete.')
# 6. Download or Ingest Dataset (CIFAR10)
# Both official splits are downloaded into cfg['data_dir'] (if not already
# there). No transform is attached here.
cifar_root = cfg['data_dir']
train_val = datasets.CIFAR10(root=cifar_root, train=True, download=True)
TEST = datasets.CIFAR10(root=cifar_root, train=False, download=True)
classes = train_val.classes
print('Classes:', classes)
# 7. Inspect and Visualize Sample Data
# Shows a 2x5 grid of randomly drawn training images, each titled with its
# class name and with the axes hidden.
fig, axes = plt.subplots(2,5, figsize=(10,4))
last_index = len(train_val) - 1
for ax in axes.ravel():
    idx = random.randint(0, last_index)
    img, label = train_val[idx]
    ax.imshow(img)
    ax.set_title(classes[label])
    ax.axis('off')
plt.tight_layout()
plt.show()

In [None]:
# 8. Create Train/Validation/Test Splits
# Holds out `val_ratio` of the training split for validation. The split uses a
# dedicated generator seeded from cfg['seed'], independent of the global RNG.
val_ratio = 0.1
n_total = len(train_val)
val_len = int(n_total*val_ratio)
train_len = n_total - val_len
split_rng = torch.Generator().manual_seed(cfg['seed'])
train_ds, val_ds = random_split(train_val, [train_len, val_len], generator=split_rng)
print(f'Train: {len(train_ds)}  Val: {len(val_ds)}  Test: {len(TEST)}')
# 9. Build Dataset Pipeline (DataLoaders)
from torch.utils.data import Dataset

# Per-channel mean / std handed to transforms.Normalize for every split.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: random flip + padded random crop, then tensor + normalize.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)
])
# Evaluation pipeline (validation and test): tensor + normalize only.
common_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)
])


class TransformedSubset(Dataset):
    """Apply a transform to the images of an existing (untransformed) dataset view.

    Args:
        subset: Any indexable dataset yielding ``(image, label)`` pairs.
        transform: Callable applied to ``image`` on every access.
    """

    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, index):
        image, label = self.subset[index]
        return self.transform(image), label


# train_ds and val_ds are both views onto the SAME underlying dataset object
# (random_split was called on train_val). Assigning `.dataset.transform` through
# each of them therefore wrote to one shared attribute: the second assignment
# won, and training ran with the evaluation transform (no augmentation).
# Keep the shared base dataset transform-free and attach the transform per split.
train_ds.dataset.transform = None  # also clears any transform left by an earlier run
train_data = TransformedSubset(train_ds, train_transform)
val_data = TransformedSubset(val_ds, common_transform)
TEST.transform = common_transform

train_loader = DataLoader(train_data, batch_size=cfg['batch_size'], shuffle=True, num_workers=cfg['num_workers'], pin_memory=True)
val_loader = DataLoader(val_data, batch_size=cfg['batch_size'], shuffle=False, num_workers=cfg['num_workers'], pin_memory=True)
test_loader = DataLoader(TEST, batch_size=cfg['batch_size'], shuffle=False, num_workers=cfg['num_workers'], pin_memory=True)
print('DataLoaders ready.')
# 10. Define Model Architecture
# A ResNet-18 with randomly initialised weights (weights=None) and one output
# per dataset class.
from torchvision.models import resnet18
num_classes = len(classes)
model = resnet18(weights=None, num_classes=num_classes)
param_total = sum(tensor.numel() for tensor in model.parameters())
print(param_total, 'total params')

# 11. Loss Function, Metrics, Optimizer
# Cross-entropy on the logits; Adam over all model parameters at cfg['lr'].
learning_rate = cfg['lr']
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

def accuracy(outputs, targets):
    """Fraction of rows whose highest-scoring column matches the target.

    Args:
        outputs: Tensor of per-class scores; the arg-max is taken along dim 1.
        targets: Tensor of class indices compared element-wise to the arg-max.

    Returns:
        The mean of the element-wise matches as a Python float.
    """
    predicted = outputs.argmax(dim=1)
    hits = predicted.eq(targets)
    return hits.float().mean().item()
# 12. Mixed Precision (Automatic)
# AMP is switched on exactly when CUDA is available; the same flag is passed to
# GradScaler here and to autocast() inside the training loop.
from torch.cuda.amp import autocast, GradScaler
use_amp = torch.cuda.is_available()
scaler = GradScaler(enabled=use_amp)
print('AMP enabled:', use_amp)

# 13. LR Scheduler
# Cosine annealing with its period (T_max) set to the configured epoch count.
from torch.optim.lr_scheduler import CosineAnnealingLR
total_epochs = cfg['epochs']
scheduler = CosineAnnealingLR(optimizer, T_max=total_epochs)

# 14 & 15. Training + Validation with Early Stopping
# Each epoch: one optimisation pass over train_loader, one evaluation pass over
# val_loader, metrics appended to `history` and written to TensorBoard, and the
# weights saved to best_model.pt whenever validation accuracy improves.
# Training stops early after `patience` consecutive epochs without improvement.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
writer = SummaryWriter(log_dir=os.path.join(cfg['output_dir'], 'tb'))

history = {k:[] for k in ['train_loss','val_loss','train_acc','val_acc','lr']}
best_val_acc = 0.0
patience = 5  # epochs without a new best val accuracy tolerated before stopping
wait = 0      # consecutive epochs since the last improvement
for epoch in range(1, cfg['epochs']+1):
    # ---- training pass ----
    model.train()
    running_loss = 0.0
    running_acc = 0.0
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad(set_to_none=True)
        with autocast(enabled=use_amp):
            outputs = model(imgs)
            loss = criterion(outputs, labels)
        # Backward pass and optimizer step both go through the GradScaler.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # Weight per-batch means by batch size so the epoch averages stay
        # correct when the last batch is smaller.
        running_loss += loss.item()*imgs.size(0)
        running_acc += accuracy(outputs.detach(), labels)*imgs.size(0)
    train_loss = running_loss/len(train_loader.dataset)
    train_acc = running_acc/len(train_loader.dataset)

    # ---- validation pass ----
    model.eval()
    val_loss = 0.0
    val_acc = 0.0
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            with autocast(enabled=use_amp):
                outputs = model(imgs)
                loss = criterion(outputs, labels)
            val_loss += loss.item()*imgs.size(0)
            val_acc += accuracy(outputs, labels)*imgs.size(0)
    val_loss /= len(val_loader.dataset)
    val_acc /= len(val_loader.dataset)

    # Read the learning rate BEFORE stepping the scheduler: after step() the
    # scheduler reports the value for the next epoch, so the log/history would
    # be shifted by one epoch.
    lr_current = scheduler.get_last_lr()[0]
    scheduler.step()

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_acc'].append(train_acc)
    history['val_acc'].append(val_acc)
    history['lr'].append(lr_current)

    writer.add_scalar('Loss/train', train_loss, epoch)
    writer.add_scalar('Loss/val', val_loss, epoch)
    writer.add_scalar('Acc/train', train_acc, epoch)
    writer.add_scalar('Acc/val', val_acc, epoch)
    writer.add_scalar('LR', lr_current, epoch)

    print(f"Epoch {epoch}: train_loss={train_loss:.4f} val_loss={val_loss:.4f} train_acc={train_acc:.3f} val_acc={val_acc:.3f} lr={lr_current:.2e}")

    if val_acc > best_val_acc:
        # New best: checkpoint the weights and reset the patience counter.
        best_val_acc = val_acc
        wait = 0
        torch.save(model.state_dict(), os.path.join(cfg['output_dir'], 'best_model.pt'))
    else:
        wait += 1
        if wait >= patience:
            print('Early stopping triggered.')
            break

# close() flushes pending events and releases the event file; the writer is
# not used again after this cell.
writer.close()
# 16 & 17. Plot Training History
# Two side-by-side panels: loss on the left, accuracy on the right, each with a
# train and a validation curve taken from `history`.
import matplotlib.pyplot as plt
fig, axs = plt.subplots(1,2, figsize=(12,4))
panels = (
    (axs[0], 'train_loss', 'val_loss', 'Loss'),
    (axs[1], 'train_acc', 'val_acc', 'Accuracy'),
)
for panel, train_key, val_key, panel_title in panels:
    panel.plot(history[train_key], label='train')
    panel.plot(history[val_key], label='val')
    panel.set_title(panel_title)
    panel.legend()
plt.show()
# 18. Evaluate on Test Set
# Restores the best checkpoint written during training, collects predictions
# over test_loader, and prints the confusion-matrix shape and a per-class report.
best_path = os.path.join(cfg['output_dir'], 'best_model.pt')
best_state = torch.load(best_path, map_location=device)
model.load_state_dict(best_state)
model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for imgs, labels in test_loader:
        imgs = imgs.to(device)
        outputs = model(imgs)
        preds = outputs.argmax(dim=1)
        all_preds += preds.cpu().tolist()
        all_labels += labels.tolist()

cm = confusion_matrix(all_labels, all_preds)
print('Confusion matrix shape:', cm.shape)
report = classification_report(all_labels, all_preds, target_names=classes)
print(report)
# 19. Save and Export Trained Model
# Writes the current weights of `model` plus the run configuration and class
# names into cfg['output_dir'].
# NOTE: the test-evaluation step loads best_model.pt into `model`, so when the
# cells are run in order this file holds those same weights.
out_dir = cfg['output_dir']
final_path = os.path.join(out_dir, 'final_model.pt')
torch.save(model.state_dict(), final_path)

config_file = os.path.join(out_dir, 'config.json')
with open(config_file, 'w') as f:
    json.dump(cfg, f, indent=2)

classes_file = os.path.join(out_dir, 'classes.json')
with open(classes_file, 'w') as f:
    json.dump(classes, f)

print('Saved:', final_path)
# 20. Load Saved Model and Run Inference
# Rebuilds the architecture, restores the best checkpoint, and shows five random
# test images titled with predicted (P) and true (T) class names.
loaded = resnet18(weights=None, num_classes=len(classes))
loaded.load_state_dict(torch.load(best_path, map_location=device))
# The input below is moved to `device`, so the model must live there too;
# otherwise the forward pass fails with a device mismatch on GPU runtimes.
loaded.to(device)
loaded.eval()
idxs = [random.randint(0, len(TEST)-1) for _ in range(5)]
fig, axes = plt.subplots(1,5, figsize=(15,3))
for ax, idx in zip(axes, idxs):
    # Read the raw pixels and label straight from the dataset's storage.
    # TEST[idx] cannot be used here: TEST.transform was set for the DataLoader,
    # so it already returns a normalized tensor. Applying common_transform to
    # that again raises in ToTensor, and the tensor is not displayable as-is.
    # (torchvision's CIFAR10 keeps images in `.data` as HxWxC uint8 arrays and
    # labels in `.targets`.)
    raw_img, lbl = TEST.data[idx], TEST.targets[idx]
    img = common_transform(raw_img)
    with torch.no_grad():
        out = loaded(img.unsqueeze(0).to(device))
        pred = out.argmax(1).item()
    ax.imshow(raw_img)
    ax.set_title(f"P:{classes[pred]}\nT:{classes[lbl]}")
    ax.axis('off')
plt.show()
# 21. Optional: Convert Model to ONNX
# Best-effort export with a dynamic batch dimension on both the input and the
# output. Any failure is reported and swallowed because this step is optional.
sample_input = torch.randn(1,3,32,32, device=device)
onxx_path = os.path.join(cfg['output_dir'], 'model.onnx')
export_options = dict(
    input_names=['input'],
    output_names=['logits'],
    dynamic_axes={'input': {0: 'batch'}, 'logits': {0: 'batch'}},
    opset_version=17,
)
try:
    torch.onnx.export(model, sample_input, onxx_path, **export_options)
    print('Exported ONNX to', onxx_path)
except Exception as e:
    print('ONNX export failed (optional step):', e)
# 22. Clean Up Session / Free GPU Memory
# Drops the ONNX sample tensor and, on CUDA runtimes, asks PyTorch to empty
# its CUDA cache.
del sample_input
if torch.cuda.is_available():
    torch.cuda.empty_cache()
print('Cleanup done.')