# WaferScan AI: Phase 2 - Model Training

This notebook covers loading the pre-split dataset, defining the ViT-Small model, and executing the training loop on Google Colab (Tesla T4).

In [None]:
# CELL 1: Setup
import os
import sys
import yaml
import json
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from google.colab import drive
import timm
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.utils.class_weight import compute_class_weight

# Mount Google Drive at /content/drive; PROJECT_ROOT (defined in the config
# cell) points to a folder under this mount.
drive.mount('/content/drive')

# Deterministic behavior: single seed value handed to set_seed() below.
SEED = 42
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random generators with ``seed``.

    When CUDA is available, all GPU generators are seeded as well and cuDNN
    is put into deterministic mode with benchmark autotuning turned off.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Nothing GPU-specific to configure on a CPU-only runtime
    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(SEED)

# Check GPU: use CUDA when PyTorch can see a device, otherwise run on CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Device: {device}")
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))

In [None]:
# CELL 2: Load Config
PROJECT_ROOT = "/content/drive/MyDrive/wafer-hackathon"
CONFIG_PATH = os.path.join(PROJECT_ROOT, "configs/colab_training_config.yaml")

print(f"Loading config from {CONFIG_PATH}...")
with open(CONFIG_PATH, 'r') as config_file:
    config = yaml.safe_load(config_file)

# Extract critical paths: both config entries are joined onto PROJECT_ROOT
PROCESSED_DATA_PATH = os.path.join(PROJECT_ROOT, config['data']['processed_data'])
CHECKPOINT_DIR = os.path.join(PROJECT_ROOT, config['checkpoint']['checkpoint_path'])
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

print("Configuration loaded.")

# Sanity checks: echo the settings that the later cells read from the config
for caption, value in (
    ("Processed data path:", PROCESSED_DATA_PATH),
    ("Checkpoint directory:", CHECKPOINT_DIR),
    ("Batch size:", config['training']['batch_size']),
    ("Epochs:", config['training']['epochs']),
    ("Backbone:", config['model']['backbone']),
):
    print(caption, value)


In [None]:
# CELL 3: Load Data (No Re-splitting)
print(f"Loading dataset from {PROCESSED_DATA_PATH}...")

# NOTE(review): pickle.load can execute code embedded in the file; only load
# pickles that were produced by this project.
with open(PROCESSED_DATA_PATH, 'rb') as pkl_file:
    dataset_pkl = pickle.load(pkl_file)

all_images, all_labels = dataset_pkl['images'], np.array(dataset_pkl['labels'])

# CRITICAL: reuse the split indices stored in the pickle (no re-splitting here)
train_idx, val_idx, test_idx = (
    dataset_pkl[split_key]
    for split_key in ('train_indices', 'val_indices', 'test_indices')
)

print("Data loaded successfully.")
print(f"Train size: {len(train_idx)}")
print(f"Val size:   {len(val_idx)}")
print(f"Test size:  {len(test_idx)}")

# Verify Train Class Distribution: count how often each label occurs in train
from collections import Counter
train_dist = Counter(all_labels[train_idx])
print("Train Distribution:", dict(train_dist))

In [None]:
# CELL 4: Dataset Class
class WaferDataset(Dataset):
    """Split-specific view over shared image and label containers.

    No data is copied: the dataset keeps references to the full ``images`` and
    ``labels`` containers and the ``indices`` that belong to this split.
    ``transform`` (optional) is applied to the PIL image of each sample.
    """

    def __init__(self, images, labels, indices, transform=None):
        self.images, self.labels = images, labels
        self.indices = indices
        self.transform = transform

    def __len__(self):
        """Return the number of samples in this split (one per index)."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_tensor)`` for split position ``idx``."""
        source_pos = self.indices[idx]
        wafer_map = self.images[source_pos]

        # Grayscale -> RGB: replicate the map along a new last axis, giving
        # (H, W, 3) float32 (assumes the stored image is 2-D (H, W) — TODO confirm).
        rgb_array = np.stack((wafer_map, wafer_map, wafer_map), axis=-1).astype(np.float32)

        # NOTE(review): ToPILImage receives a float32 array here; how float
        # input is scaled/converted depends on the torchvision version —
        # confirm the stored value range survives this conversion.
        pil_image = transforms.ToPILImage()(rgb_array)

        # Fall back to a plain tensor conversion when no transform was given
        to_tensor = self.transform or transforms.ToTensor()
        image_tensor = to_tensor(pil_image)

        label_tensor = torch.tensor(self.labels[source_pos], dtype=torch.long)
        return image_tensor, label_tensor

# Define Transforms
img_size = 224
# ImageNet channel statistics.
# NOTE(review): some pretrained checkpoints expect different normalization
# constants — confirm these match the backbone selected in the config.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

target_size = (img_size, img_size)

# Augmentations used for training only: random flips plus a random rotation
augmentations = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
]

train_transform = transforms.Compose([
    transforms.Resize(target_size),
    *augmentations,
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Validation pipeline: same resize/normalization, no random augmentation
val_transform = transforms.Compose([
    transforms.Resize(target_size),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

In [None]:
# CELL 4.5: Dataset Sanity Check
# Build a throwaway training dataset and pull its first sample
sanity_set = WaferDataset(all_images, all_labels, train_idx, train_transform)
sample_img, sample_label = sanity_set[0]

# Report tensor shape, value range after the transform, and label dtype
sample_min, sample_max = sample_img.min().item(), sample_img.max().item()
print("Sample Shape:", sample_img.shape)
print("Sample Range:", sample_min, sample_max)
print("Label Type:", sample_label.dtype)

In [None]:
# CELL 5: DataLoaders
batch_size = config['training']['batch_size']

train_set = WaferDataset(all_images, all_labels, train_idx, transform=train_transform)
val_set = WaferDataset(all_images, all_labels, val_idx, transform=val_transform)

# Settings shared by both loaders
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=2,
    pin_memory=True,
    persistent_workers=True,
)

# Shuffle train only; the last incomplete training batch is dropped
train_loader = DataLoader(train_set, shuffle=True, drop_last=True, **loader_kwargs)
val_loader = DataLoader(val_set, shuffle=False, **loader_kwargs)

print(f"Train Batches: {len(train_loader)}")
print(f"Val Batches: {len(val_loader)}")

In [None]:
# CELL 5.5: DataLoader Sanity Check
# Fetch a single training batch and report the tensor shapes
first_batch = next(iter(train_loader))
images_batch, labels_batch = first_batch

for caption, batch_tensor in (
    ("Images Batch:", images_batch),
    ("Labels Batch:", labels_batch),
):
    print(caption, batch_tensor.shape)

In [None]:
# CELL 6: Model Definition
# Model settings read from the 'model' section of the config
model_name, num_classes, dropout_rate = (
    config['model'][setting] for setting in ('backbone', 'num_classes', 'dropout')
)

print(f"Creating model: {model_name}")

# 1. Load backbone without classifier (num_classes=0 is passed to timm)
backbone = timm.create_model(model_name, pretrained=True, num_classes=0)

# 2. Get feature dimension
in_features = backbone.num_features

# 3. Define classifier head: dropout followed by one linear layer
head_layers = [
    nn.Dropout(dropout_rate),
    nn.Linear(in_features, num_classes),
]
classifier = nn.Sequential(*head_layers)

# 4. Wrap into full model
class WaferClassifier(nn.Module):
    """Feature-extracting backbone followed by a classification head."""

    def __init__(self, backbone, classifier):
        super().__init__()
        # Registered as submodules 'backbone' and 'classifier', in that order
        self.backbone = backbone
        self.classifier = classifier

    def forward(self, x):
        """Feed ``x`` through the backbone, then the classifier head."""
        return self.classifier(self.backbone(x))

# Assemble the full model and move it to the selected device
model = WaferClassifier(backbone, classifier).to(device)

# Count only the parameters that will receive gradients
param_count = 0
for param in model.parameters():
    if param.requires_grad:
        param_count += param.numel()
print(f"Trainable Parameters: {param_count:,}")

In [None]:
# CELL 7: Loss & Optimizer

# Handle Imbalance using Class Weights
if config['training']['use_class_weights']:
    # 'balanced' weights are computed from the training split only
    train_labels = all_labels[train_idx]
    class_weights = compute_class_weight('balanced', classes=np.arange(num_classes), y=train_labels)

    # Normalize weights so that they average to 1
    class_weights = class_weights / class_weights.mean()

    class_weights_reg = torch.tensor(class_weights, dtype=torch.float32).to(device)
    print("Class Weights enabled:", class_weights_reg)
else:
    class_weights_reg = None

# PyYAML loads exponent literals written without a decimal point (e.g. 1e-4)
# as strings, so numeric hyper-parameters are coerced explicitly before they
# are handed to torch. Values that are already numeric are unaffected.

# Loss function with optional label smoothing
criterion = nn.CrossEntropyLoss(
    weight=class_weights_reg,
    label_smoothing=float(config['training'].get('label_smoothing', 0.0))
)

lr = float(config['training']['learning_rate'])
wd = float(config['training']['weight_decay'])

optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)

# Cosine Scheduler: one annealing cycle spanning the whole run
num_epochs = int(config['training']['epochs'])
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)

In [None]:
# CELL 8: Training Loop
# Trains for num_epochs epochs. Each epoch: one pass over train_loader with
# optimizer updates, one evaluation pass over val_loader, one scheduler step.
# The weights with the best validation accuracy go to models/final/best_model.pth;
# a full resumable checkpoint is written every `checkpoint_every` epochs.

# Mixed precision is only meaningful on CUDA; on CPU it is switched off so the
# scaler and autocast are cleanly disabled.
use_amp = bool(config['training'].get('use_amp', False)) and device.type == 'cuda'
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

best_acc = 0.0
history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': [], 'val_top3_acc': []}
checkpoint_every = config['checkpoint']['checkpoint_every']

# Resolve and create the best-model location once, not on every improvement
best_path = os.path.join(PROJECT_ROOT, "models/final/best_model.pth")
best_dir = os.path.dirname(best_path)
os.makedirs(best_dir, exist_ok=True)

# An empty loader would otherwise surface later as a ZeroDivisionError
if len(train_loader) == 0 or len(val_loader) == 0:
    raise ValueError(
        "train_loader and val_loader must each yield at least one batch "
        "(check the split sizes against batch_size / drop_last)."
    )

print(f"Starting training (AMP={use_amp})...")

for epoch in range(num_epochs):
    # --- TRAIN ---
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")
    for images, labels in loop:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        with torch.cuda.amp.autocast(enabled=use_amp):
            outputs = model(images)
            loss = criterion(outputs, labels)

        if use_amp:
            # Scale the loss to avoid fp16 gradient underflow; the scaler
            # unscales before the optimizer step and adapts its factor.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        # Read the scalar once per batch (each .item() synchronizes with the GPU)
        batch_loss = loss.item()
        running_loss += batch_loss * images.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        loop.set_postfix(loss=batch_loss)

    epoch_loss = running_loss / total
    epoch_acc = correct / total

    # --- VALIDATE ---
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_top3_correct = 0
    val_total = 0

    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            val_loss += loss.item() * images.size(0)
            _, predicted = outputs.max(1)
            val_total += labels.size(0)
            val_correct += predicted.eq(labels).sum().item()

            # Top-3 (capped at the number of classes so topk cannot fail when
            # the model has fewer than three outputs)
            top_k = min(3, outputs.size(1))
            _, top3_preds = outputs.topk(top_k, dim=1)
            val_top3_correct += top3_preds.eq(labels.view(-1, 1).expand_as(top3_preds)).sum().item()

    val_epoch_loss = val_loss / val_total
    val_epoch_acc = val_correct / val_total
    val_epoch_top3 = val_top3_correct / val_total

    # Update Schedule (once per epoch)
    scheduler.step()

    # Store History
    history['train_loss'].append(epoch_loss)
    history['train_acc'].append(epoch_acc)
    history['val_loss'].append(val_epoch_loss)
    history['val_acc'].append(val_epoch_acc)
    history['val_top3_acc'].append(val_epoch_top3)

    print(f"Results: Train Loss={epoch_loss:.4f}, Acc={epoch_acc:.4f} | Val Loss={val_epoch_loss:.4f}, Acc={val_epoch_acc:.4f}, Top3={val_epoch_top3:.4f}")

    # Save Best Model (weights only), selected by validation accuracy
    if val_epoch_acc > best_acc:
        best_acc = val_epoch_acc
        torch.save(model.state_dict(), best_path)
        print(f"  New best model saved! ({best_acc:.4f})")

    # Regular Checkpoint; a falsy checkpoint_every (0 / None) disables it
    if checkpoint_every and (epoch + 1) % checkpoint_every == 0:
        ckpt_path = os.path.join(CHECKPOINT_DIR, f"ckpt_epoch_{epoch+1}.pth")
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            # Extra state needed to resume an AMP run and keep best-model tracking
            'scaler_state_dict': scaler.state_dict(),
            'best_acc': best_acc,
            'loss': val_epoch_loss,
        }, ckpt_path)
        print(f"  Checkpoint saved: {ckpt_path}")

In [None]:
# CELL 9: Plot Curves
plot_path = os.path.join(PROJECT_ROOT, "models/metrics/training_curves.png")
plot_dir = os.path.dirname(plot_path)
os.makedirs(plot_dir, exist_ok=True)

plt.figure(figsize=(12, 5))

# One panel per metric:
# (title, train history key, train legend label, val history key, val legend label)
panels = [
    ('Accuracy', 'train_acc', 'Train Acc', 'val_acc', 'Val Acc'),
    ('Loss', 'train_loss', 'Train Loss', 'val_loss', 'Val Loss'),
]
for panel_no, (title, train_key, train_label, val_key, val_label) in enumerate(panels, start=1):
    plt.subplot(1, 2, panel_no)
    plt.plot(history[train_key], label=train_label)
    plt.plot(history[val_key], label=val_label)
    plt.title(title)
    plt.xlabel('Epoch')
    plt.legend()
    plt.grid(True)

# Save before show() so the file is written from the fully drawn figure
plt.savefig(plot_path)
print(f"Curves saved to {plot_path}")
plt.show()

In [None]:
# CELL 10: Save Artifacts
# Persist the per-epoch metric history (JSON) and the weights of the model as
# it stands after the last epoch.
history_path = os.path.join(PROJECT_ROOT, "models/metrics/training_history.json")
final_model_path = os.path.join(PROJECT_ROOT, "models/final/final_model.pth")

# Create both output folders here so this cell does not depend on earlier
# cells having created them.
for artifact_path in (history_path, final_model_path):
    os.makedirs(os.path.dirname(artifact_path), exist_ok=True)

with open(history_path, 'w') as f:
    json.dump(history, f)

torch.save(model.state_dict(), final_model_path)

print("Training Complete. Artifacts saved.")