<a href="https://colab.research.google.com/github/dhirenmalik/comp3710-demo-2/blob/main/COMP37103_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import time
from sklearn.datasets import fetch_lfw_people
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import torch
import torch.nn as nn

# Task
Implement a classification neural network for the CIFAR10 dataset.

## Load and preprocess cifar10 data

### Subtask:
Load the CIFAR10 dataset and perform any necessary preprocessing steps, such as normalization and reshaping.


In [None]:
from sklearn.datasets import fetch_openml

# Download the CIFAR-10 dataset from OpenML. With return_X_y=True the call
# yields a (features, labels) pair, unpacked here into X and y.
X, y = fetch_openml("CIFAR_10", version=1, return_X_y=True)
# Sanity check: report the dimensions of the feature table and label vector.
print("Shape of X:", X.shape)
print("Shape of y:", y.shape)

Shape of X: (60000, 3072)
Shape of y: (60000,)


**Reasoning**:
Load the CIFAR10 dataset, print the shapes, preprocess the data for CNN, and split into training and testing sets.



In [None]:
# Turn the flat OpenML table into image tensors and build a stratified
# train/test split.

# Reshape each 3072-value row into (channels, height, width) = (3, 32, 32) and
# cast to float32.
# NOTE(review): assumes every row is stored channel-major (all of one colour
# plane, then the next) — TODO confirm against the OpenML dataset description.
X_reshaped = X.values.reshape(-1, 3, 32, 32).astype(np.float32)

# Scale by 1/255 (maps to [0, 1] assuming raw pixel values lie in 0..255).
X_normalized = X_reshaped / 255.0

# Convert to PyTorch tensors
X_tensor = torch.tensor(X_normalized, dtype=torch.float32)
# Labels are converted to a pandas categorical and replaced by their integer
# category codes, stored as int64 (torch.long).
y_tensor = torch.tensor(y.astype('category').cat.codes, dtype=torch.long) # Convert string labels to long tensors

# Hold out 20% of the samples for testing; stratify on the labels so the split
# keeps the class proportions, and fix random_state for reproducibility.
X_train_cifar, X_test_cifar, y_train_cifar, y_test_cifar = train_test_split(
    X_tensor, y_tensor, test_size=0.2, random_state=42, stratify=y_tensor
)

# Report the resulting split sizes.
print("X_train_cifar shape:", X_train_cifar.shape)
print("y_train_cifar shape:", y_train_cifar.shape)
print("X_test_cifar shape:", X_test_cifar.shape)
print("y_test_cifar shape:", y_test_cifar.shape)

X_train_cifar shape: torch.Size([48000, 3, 32, 32])
y_train_cifar shape: torch.Size([48000])
X_test_cifar shape: torch.Size([12000, 3, 32, 32])
y_test_cifar shape: torch.Size([12000])


## Define the neural network model, loss function, and optimizer

### Subtask:
Define a ResNet-18 model for CIFAR10 and choose an appropriate loss function and optimizer.

**Reasoning**:
Define a ResNet-18-style model with a 10-class output layer for CIFAR10, and set up the cross-entropy loss and an SGD optimizer.

In [None]:
from __future__ import annotations
from typing import Type, List, Optional
import torch.nn as nn
import torch.nn.functional as F
import torch

# ---------------------------
# Building block: BasicBlock (modified to match the second implementation)
# ---------------------------
class BasicBlock(nn.Module):
    """Two-convolution residual block with an always-on projection shortcut.

    The main branch is conv -> norm -> ReLU -> conv -> norm. The shortcut
    branch is a 1x1 convolution followed by a norm layer and is applied in
    every block, even when the stride is 1 and the channel counts match.
    The two branches are summed and passed through a final ReLU.

    Args:
        in_planes: Number of input channels.
        planes: Number of channels produced by each convolution.
        stride: Stride of the first convolution and of the shortcut projection.
        norm_layer: Normalisation layer class; ``nn.BatchNorm2d`` when ``None``.
        bn_decay_rate: Running-statistics decay; passed to the norm layer as
            ``momentum = 1.0 - bn_decay_rate``.
        bn_eps: Epsilon passed to the norm layer.
        kernel_size: Kernel size of the two main-branch convolutions.
        rate: Dilation of the two main-branch convolutions.
    """
    expansion: int = 1  # output-channel multiplier used by the enclosing network

    def __init__(
        self,
        in_planes: int,
        planes: int,
        stride: int = 1,
        norm_layer: Optional[Type[nn.Module]] = None,
        bn_decay_rate: float = 0.9,
        bn_eps: float = 1e-5,
        kernel_size: int = 3,
        rate: int = 1,
    ):
        super().__init__()
        norm_cls = nn.BatchNorm2d if norm_layer is None else norm_layer
        norm_kwargs = {"eps": bn_eps, "momentum": 1.0 - bn_decay_rate}

        # Padding that keeps the spatial size unchanged at stride 1.
        pad = (kernel_size // 2) * rate
        shortcut_planes = planes * self.expansion

        # Shortcut branch: 1x1 projection + norm, present in every block.
        self.downsample = nn.Sequential(
            nn.Conv2d(in_planes, shortcut_planes, kernel_size=1, stride=stride, bias=False),
            norm_cls(shortcut_planes, **norm_kwargs),
        )

        # Main branch, first convolution (carries the stride).
        self.conv1 = nn.Conv2d(
            in_planes, planes,
            kernel_size=kernel_size, stride=stride,
            padding=pad, dilation=rate, bias=False,
        )
        self.bn1 = norm_cls(planes, **norm_kwargs)
        self.relu = nn.ReLU(inplace=True)

        # Main branch, second convolution (always stride 1).
        self.conv2 = nn.Conv2d(
            planes, planes,
            kernel_size=kernel_size, stride=1,
            padding=pad, dilation=rate, bias=False,
        )
        self.bn2 = norm_cls(planes, **norm_kwargs)

        # Re-initialise all three convolutions with Kaiming-normal weights.
        for conv in (self.downsample[0], self.conv1, self.conv2):
            nn.init.kaiming_normal_(conv.weight, mode='fan_in', nonlinearity='relu')

    def forward(self, x: torch.Tensor, is_training: bool = True, test_local_stats: bool = False) -> torch.Tensor:
        """Apply the block to ``x``.

        ``is_training`` and ``test_local_stats`` are accepted only for
        signature parity and are not used; batch-norm behaviour follows the
        module's ``train()`` / ``eval()`` state.
        """
        shortcut = self.downsample(x)

        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))

        # Residual sum followed by the final activation.
        return self.relu(branch + shortcut)


# ---------------------------
# ResNet backbone (modified to match the second implementation)
# ---------------------------
class ResNet(nn.Module):
    """ResNet-style classifier for small (e.g. 32x32) images.

    Structure:
      * Stem: 3x3 convolution (stride 1) + norm + ReLU, with no pooling.
      * Four residual stages producing ``channels``, ``2 * channels``,
        ``4 * channels`` and ``4 * channels`` feature maps, with strides
        1, 2, 2 and 2 respectively.
      * Head: global average pooling followed by a linear layer whose weight
        and bias are initialised to zero.

    Args:
        block: Residual block class. It is constructed positionally as
            ``block(in_planes, planes, stride, norm_layer, bn_decay_rate, bn_eps)``
            and must expose an integer ``expansion`` class attribute.
        layers: Number of residual blocks in each of the four stages
            (``[2, 2, 2, 2]`` for the ResNet-18 layout).
        num_classes: Size of the output logit vector.
        in_channels: Number of channels in the input images.
        norm_layer: Normalisation layer class; ``nn.BatchNorm2d`` when ``None``.
        bn_decay_rate: Running-statistics decay; the norm layers receive
            ``momentum = 1.0 - bn_decay_rate``.
        bn_eps: Epsilon passed to the norm layers.
        channels: Width of the stem and of the first stage.

    Raises:
        ValueError: If ``layers`` does not contain exactly four entries.
    """
    def __init__(
        self,
        block: Type[BasicBlock],
        layers: List[int],
        num_classes: int = 10,
        in_channels: int = 3,
        norm_layer: Optional[Type[nn.Module]] = None,
        bn_decay_rate: float = 0.9,
        bn_eps: float = 1e-5,
        channels: int = 64,
    ):
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if len(layers) != 4:
            raise ValueError(
                f"layers must list the block count for exactly 4 stages, got {len(layers)}"
            )

        # Running record of the channel count feeding the next block.
        self.inplanes = channels

        # Stem: 3x3 conv, stride 1, no bias + norm + ReLU. The image is not
        # downsampled here; `maxpool` is an identity kept as a placeholder.
        self.conv1 = nn.Conv2d(in_channels, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = norm_layer(self.inplanes, eps=bn_eps, momentum=1.0 - bn_decay_rate)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.Identity()

        # Residual stages. Depth per stage comes from `layers`; width and
        # stride follow the fixed progression described in the class docstring.
        stage_kwargs = dict(norm_layer=norm_layer, bn_decay_rate=bn_decay_rate, bn_eps=bn_eps)
        self.layer1 = self._make_layer(block, channels, layers[0], stride=1, **stage_kwargs)
        self.layer2 = self._make_layer(block, channels * 2, layers[1], stride=2, **stage_kwargs)
        self.layer3 = self._make_layer(block, channels * 4, layers[2], stride=2, **stage_kwargs)
        # The last stage keeps 4 * channels instead of doubling again.
        self.layer4 = self._make_layer(block, channels * 4, layers[3], stride=2, **stage_kwargs)

        # Head: global average pool + linear classifier, zero-initialised so
        # the untrained network outputs all-zero logits.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(channels * 4 * block.expansion, num_classes)
        nn.init.zeros_(self.fc.weight)
        nn.init.zeros_(self.fc.bias)

    def _make_layer(
        self,
        block: Type[BasicBlock],
        planes: int,
        blocks: int,
        stride: int,
        norm_layer: Type[nn.Module],
        bn_decay_rate: float,
        bn_eps: float,
    ) -> nn.Sequential:
        """Build one residual stage.

        The first block receives ``stride`` and maps ``self.inplanes`` to
        ``planes``; the remaining ``blocks - 1`` blocks use stride 1. Channel
        and stride matching on the shortcut is left to the block itself.
        Updates ``self.inplanes`` to the stage's output width.
        """
        stage = [block(self.inplanes, planes, stride, norm_layer, bn_decay_rate, bn_eps)]
        self.inplanes = planes * block.expansion

        for _ in range(1, blocks):
            stage.append(block(self.inplanes, planes, 1, norm_layer, bn_decay_rate, bn_eps))

        return nn.Sequential(*stage)

    def forward(self, x: torch.Tensor, is_training: bool = True, test_local_stats: bool = False) -> torch.Tensor:
        """Compute class logits of shape ``[B, num_classes]`` for a batch ``x``.

        ``is_training`` and ``test_local_stats`` are accepted for signature
        parity only and are ignored: batch-norm behaviour is selected with
        ``module.train()`` / ``module.eval()``. They are deliberately NOT
        forwarded to the stages, because ``nn.Sequential.forward`` takes a
        single input and rejects extra keyword arguments with a ``TypeError``.
        """
        # Stem (spatial size unchanged).
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)  # identity

        # Residual stages; stages 2-4 each halve the spatial size.
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)       # [B, C, 1, 1]
        x = torch.flatten(x, 1)   # [B, C]
        x = self.fc(x)            # [B, num_classes]
        return x


def resnet18(num_classes: int = 10, in_channels: int = 3, dropout_prob: float = 0.0) -> ResNet:
    """Build the ResNet-18-style network (two ``BasicBlock``s per stage, base width 64).

    Args:
        num_classes: Number of output logits.
        in_channels: Number of channels in the input images.
        dropout_prob: Accepted for signature compatibility only; it is not
            used by this factory.
    """
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(
        BasicBlock,
        blocks_per_stage,
        num_classes=num_classes,
        in_channels=in_channels,
        channels=64,
    )

In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset
import torchvision.transforms as transforms
from torchvision import datasets
import torch.optim.lr_scheduler as lr_scheduler

# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the custom ResNet-18-style network and move it to the device.
# No pre-trained weights are loaded: the model is trained from scratch.
# dropout_prob is accepted by the factory but ignored.
model = resnet18(num_classes=10, dropout_prob=0).to(device)

# Alternative loss/optimizer settings tried earlier (kept for reference):
# criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

# Training hyperparameters.
epochs = 35
batch_size = 128
decay = 5e-4  # not used in this cell: the optimizer below has weight_decay=0.0
learning_rate = 0.1

# SGD with momentum and no optimizer-side weight decay.
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.0)

# Training pipeline: tensor -> PIL image -> tensor -> per-channel normalisation,
# then random horizontal flip and random 32x32 crop with 4 pixels of padding.
# NOTE(review): the flip/crop run AFTER Normalize, so the crop padding is
# filled on already-normalised data — confirm this ordering is intended.
transform_train = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize((0.4917, 0.4824, 0.4469), (0.2024, 0.1995, 0.2011)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
])

# Test pipeline: same conversion and normalisation, no random augmentation.
transform_test = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize((0.4917, 0.4824, 0.4469), (0.2024, 0.1995, 0.2011)),
])

# Apply the transforms eagerly, one image at a time, and stack the results
# back into a single tensor per split.
# NOTE(review): because the training transform is evaluated once here, each
# training image receives one fixed random flip/crop that is then reused in
# every epoch; the augmentation is not re-sampled per epoch. Re-sampling would
# require applying the transform inside a Dataset's __getitem__ (for example
# via torchvision.datasets.CIFAR10 with transform=...).
X_train_cifar_transformed = torch.stack([transform_train(img) for img in X_train_cifar])
X_test_cifar_transformed = torch.stack([transform_test(img) for img in X_test_cifar])


# Wrap the pre-transformed tensors and their labels as datasets.
train_dataset = TensorDataset(X_train_cifar_transformed, y_train_cifar)
test_dataset = TensorDataset(X_test_cifar_transformed, y_test_cifar)


# Batch the datasets. Only the training loader shuffles; pinned host memory is
# requested only when running on CUDA. No transform is applied at load time,
# since the tensors were already transformed above.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          num_workers=2, pin_memory=(device.type=="cuda"))
val_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                          num_workers=2, pin_memory=(device.type=="cuda"))

print("Model defined and moved to:", device)

Model defined and moved to: cuda


## Train the model

### Subtask:
Train the defined ResNet-18 model on the CIFAR10 training data.

In [None]:
import time
import torch.optim.lr_scheduler as lr_scheduler
from torch.optim import SGD # Import SGD optimizer

# Training run: SGD + OneCycleLR, cross-entropy loss plus an explicit L2
# penalty, periodic single-batch progress logging, and a full validation pass
# at the end of every epoch.
# Relies on `model`, `device`, `train_loader`, `val_loader`,
# `l2_regularization`, `accuracy_from_logits` and `evaluate` being defined by
# other cells of this notebook before this cell is run.

# Hyperparameters.
epochs = 35
batch_size = 128         # NOTE: the loaders already exist; re-assigning this does not change their batch size
decay = 5e-4             # L2 regularization coefficient (added to the loss, not given to the optimizer)
learning_rate = 0.1      # peak learning rate for OneCycleLR
log_interval = 100       # log batch-level accuracy every `log_interval` global steps

# OneCycle schedule: linear warm-up over the first 15 epochs, then linear
# decay. Start LR = max_lr / div_factor; final LR = start LR / final_div_factor.
pct_start = 15.0 / epochs
div_factor = 20.0
final_div_factor = 200.0

# The scheduler is stepped once per batch, so it needs the total batch count.
total_images = len(train_loader.dataset)
steps_per_epoch = len(train_loader)
total_steps = steps_per_epoch * epochs

# Mean cross-entropy over the batch.
ce_loss_fn = nn.CrossEntropyLoss(reduction='mean')

# SGD with momentum; weight_decay stays 0.0 because the L2 term is added to
# the loss explicitly inside the loop.
optimizer = SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.0, nesterov=False)

scheduler = lr_scheduler.OneCycleLR(optimizer,
                                    max_lr=learning_rate,
                                    total_steps=total_steps,
                                    pct_start=pct_start,
                                    anneal_strategy='linear',
                                    div_factor=div_factor,
                                    final_div_factor=final_div_factor)

start_time = time.time()

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    # Fresh iterator over the validation loader for the periodic one-batch check.
    test_iter = iter(val_loader)

    for i, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        outputs = model(inputs, is_training=True)
        cross_entropy_loss = ce_loss_fn(outputs, labels)

        # Explicit L2 penalty: decay * 0.5 * sum of squared trainable parameters.
        reg_loss = decay * l2_regularization(model)
        loss = cross_entropy_loss + reg_loss

        loss.backward()
        optimizer.step()

        # OneCycleLR is defined per optimizer step, so advance it every batch.
        scheduler.step()

        running_loss += loss.item()

        global_step = epoch * len(train_loader) + i
        if global_step % log_interval == 0:
            # Training accuracy on the batch that was just used for the update.
            train_acc = accuracy_from_logits(outputs, labels)

            # Approximate test accuracy from a single validation batch.
            try:
                test_images, test_labels = next(test_iter)
            except StopIteration:
                test_iter = iter(val_loader)
                test_images, test_labels = next(test_iter)
            test_images = test_images.to(device, non_blocking=True)
            test_labels = test_labels.to(device, non_blocking=True)

            # The `is_training` flag does not change BatchNorm behaviour in
            # PyTorch, so switch modes explicitly: in train mode this forward
            # pass would normalise with the test batch's own statistics AND
            # fold them into the running statistics.
            model.eval()
            with torch.no_grad():
                test_logits = model(test_images, is_training=False)
                test_acc = accuracy_from_logits(test_logits, test_labels)
            model.train()

            print(f"[Step {global_step}, Loss {loss.item():.5f}] Train / Test accuracy: {train_acc:.3f} / {test_acc:.3f}")

    avg_loss = running_loss / len(train_loader)
    print(f'Epoch [{epoch+1}/{epochs}], Training Loss: {avg_loss:.4f}')

    # Full pass over the validation set. `evaluate` returns a fraction in
    # [0, 1], so scale it by 100 before printing it with a percent sign.
    val_acc = evaluate(model, val_loader)
    print(f'Epoch [{epoch+1}/{epochs}], Validation Accuracy: {val_acc * 100:.4f}%')


end_time = time.time()
total_training_time = end_time - start_time
print(f"Total training time: {total_training_time:.2f} seconds")

TypeError: Sequential.forward() got an unexpected keyword argument 'is_training'

## Evaluate the model

### Subtask:
Evaluate the trained ResNet-18 model on the CIFAR10 test data and report the classification metrics.

In [None]:
# Score the trained model on the held-out split and print per-class metrics.
model.eval()  # batch norm uses its running statistics from here on
all_preds, all_labels = [], []
with torch.inference_mode():
    for images, labels in val_loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        preds = model(images)

        # Keep only the winning class index per sample, moved back to the host.
        all_preds.append(torch.argmax(preds, dim=1).cpu())
        all_labels.append(labels.cpu())

# Merge the per-batch results into one vector each.
all_preds = torch.cat(all_preds, dim=0)
all_labels = torch.cat(all_labels, dim=0)

# First argument is the ground truth, second the predictions.
print(classification_report(all_labels, all_preds))

# Shakes code adapted using ChatGPT

In [None]:
import os
import time
import math
import random
from typing import Tuple, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as T

# Reproducibility
# Fix every random number generator used in this notebook section to the same
# seed so repeated runs start from the same state.
seed = 42
torch.cuda.manual_seed_all(seed)
torch.manual_seed(seed)
random.seed(seed)

# Prefer the GPU when one is visible to PyTorch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

Using device: cuda


# Cell 2: Model Definition (ResNet + ResNetBlock)

In [None]:
def same_padding(kernel_size: int, dilation: int = 1) -> int:
    """
    Return the per-side padding that gives SAME-style output for an odd kernel.

    The result is ``dilation * (kernel_size // 2)``; for example a 3x3 kernel
    needs a padding equal to its dilation. Even kernel sizes are rejected by
    an assertion.
    """
    is_odd = kernel_size % 2 == 1
    assert is_odd, "SAME padding formula here assumes odd kernel size."
    half_width = kernel_size // 2
    return half_width * dilation

class ResNetBlock(nn.Module):
    """
    Residual block with a projection shortcut.

    - Shortcut path: 1x1 convolution + batch norm, applied in every block
      (also when stride is 1 and the channel counts match).
    - Main path: two convolutions, each followed by batch norm, with a ReLU
      after the first batch norm.
    - The two paths are added and a final ReLU is applied.
    - `rate` sets the dilation of the main-path convolutions.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        rate: int = 1,
        kernel_size: int = 3,
        bn_decay_rate: float = 0.9,
        bn_eps: float = 1e-5,
        name: str = None
    ):
        super().__init__()
        self.name = name
        pad = same_padding(kernel_size, rate)

        def make_bn() -> nn.BatchNorm2d:
            # PyTorch's momentum is (1 - decay_rate) of the running average.
            return nn.BatchNorm2d(
                num_features=out_channels,
                eps=bn_eps,
                momentum=1.0 - bn_decay_rate,
                affine=True,
                track_running_stats=True,
            )

        def make_main_conv(conv_in: int, conv_stride: int) -> nn.Conv2d:
            # Bias-free convolution used on the main path.
            return nn.Conv2d(
                in_channels=conv_in,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=conv_stride,
                padding=pad,
                dilation=rate,
                bias=False,
            )

        # Shortcut projection: 1x1, carries the block's stride, no bias.
        self.proj_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=stride,
            padding=0,
            bias=False,
        )
        self.proj_bn = make_bn()

        # Main path: the first convolution carries the stride, the second is stride 1.
        self.conv1 = make_main_conv(in_channels, stride)
        self.bn1 = make_bn()
        self.conv2 = make_main_conv(out_channels, 1)
        self.bn2 = make_bn()

        # Kaiming-normal weights for every convolution; batch-norm affine
        # parameters are left at their defaults.
        for conv in (self.proj_conv, self.conv1, self.conv2):
            nn.init.kaiming_normal_(conv.weight, mode='fan_in', nonlinearity='relu')

    def forward(self, x: torch.Tensor, is_training: bool = True, test_local_stats: bool = False) -> torch.Tensor:
        # `is_training` / `test_local_stats` exist only for signature parity
        # and are not used: batch norm follows module.train() / module.eval().
        residual = self.proj_bn(self.proj_conv(x))

        y = F.relu(self.bn1(self.conv1(x)), inplace=True)
        y = self.bn2(self.conv2(y))

        return F.relu(y + residual, inplace=True)

class ResNet(nn.Module):
    """
    Eight-block residual network for 3-channel images.

    Layout (with the default channels=64):
    - Stem: 3x3 convolution, stride 1, followed by batch norm and ReLU.
    - Residual blocks, listed as input->output channels (stride):
        res1a 64->64 (1),    res1b 64->64 (1)
        res2a 64->128 (2),   res2b 128->128 (1)
        res3a 128->256 (2),  res3b 256->256 (1)
        res4a 256->256 (2),  res4b 256->256 (1)   [width stays at 256]
    - Global average pooling.
    - Linear layer to `output_size` with weight and bias initialised to zero.
    """
    def __init__(
        self,
        output_size: int,
        channels: int = 64,
        bn_decay_rate: float = 0.9,
        bn_eps: float = 1e-5
    ):
        super().__init__()
        self.output_size = output_size
        self.channels = channels

        # Stem convolution (bias-free) and its batch norm.
        self.initial_conv = nn.Conv2d(
            in_channels=3,
            out_channels=channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.initial_bn = nn.BatchNorm2d(
            num_features=channels,
            eps=bn_eps,
            momentum=1.0 - bn_decay_rate,
            affine=True,
            track_running_stats=True,
        )
        nn.init.kaiming_normal_(self.initial_conv.weight, mode='fan_in', nonlinearity='relu')

        # (attribute name, input width, output width, stride) for each block,
        # in execution order.
        c1, c2, c4 = channels, channels * 2, channels * 4
        block_plan = (
            ("res1a", c1, c1, 1),
            ("res1b", c1, c1, 1),
            ("res2a", c1, c2, 2),
            ("res2b", c2, c2, 1),
            ("res3a", c2, c4, 2),
            ("res3b", c4, c4, 1),
            ("res4a", c4, c4, 2),
            ("res4b", c4, c4, 1),
        )
        for attr, width_in, width_out, block_stride in block_plan:
            block = ResNetBlock(
                width_in,
                width_out,
                stride=block_stride,
                rate=1,
                bn_decay_rate=bn_decay_rate,
                bn_eps=bn_eps,
                name=f"resblock_{attr[3:]}",
            )
            setattr(self, attr, block)

        # Classifier head, zero-initialised.
        self.fc = nn.Linear(c4, output_size, bias=True)
        nn.init.zeros_(self.fc.weight)
        nn.init.zeros_(self.fc.bias)

    def forward(self, x: torch.Tensor, is_training: bool = True, test_local_stats: bool = False) -> torch.Tensor:
        x = F.relu(self.initial_bn(self.initial_conv(x)), inplace=True)

        # Run the residual blocks in order, passing the flags straight through.
        for attr in ("res1a", "res1b", "res2a", "res2b", "res3a", "res3b", "res4a", "res4b"):
            x = getattr(self, attr)(x, is_training=is_training, test_local_stats=test_local_stats)

        pooled = F.adaptive_avg_pool2d(x, output_size=1)  # [N, C, 1, 1]
        features = torch.flatten(pooled, 1)               # [N, C]
        return self.fc(features)                          # [N, output_size]

# Cell 3: Data - CIFAR-10 Datasets and DataLoaders

In [None]:
# Hyperparameters mirroring the JAX script
epochs = 35
batch_size = 128
decay = 5e-4             # L2 regularization coefficient (will be added to loss)
learning_rate = 0.1

# OneCycle parameters (approximate optax.linear_onecycle)
pct_start = 15.0 / epochs   # fraction of the run spent warming up (15 of 35 epochs)
div_factor = 20.0
final_div_factor = 200.0

# Per-channel mean and standard deviation used to normalise the images.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

# Training pipeline: random 32x32 crop with 4 pixels of padding and a random
# horizontal flip, then tensor conversion and normalisation. Because the
# transform is attached to the dataset, it runs every time a sample is loaded.
train_transform = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Test pipeline: tensor conversion and normalisation only.
test_transform = T.Compose([
    T.ToTensor(),
    T.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Download (if needed) the torchvision CIFAR-10 train and test splits into
# data_root and attach the matching transform to each.
data_root = "./data"
train_dataset = torchvision.datasets.CIFAR10(root=data_root, train=True, download=True, transform=train_transform)
test_dataset = torchvision.datasets.CIFAR10(root=data_root, train=False, download=True, transform=test_transform)

# Shuffled training batches; the final partial batch is kept.
# NOTE(review): pin_memory is hard-coded to True here, independent of whether
# a CUDA device is actually in use.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    drop_last=False
)

# Test batches in a fixed order.
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    drop_last=False
)

num_classes = 10
total_images = len(train_dataset)
# drop_last=False above, so the last partial batch counts as a step: round up.
steps_per_epoch = math.ceil(total_images / batch_size)
total_steps = steps_per_epoch * epochs

print(f"Found {total_images} training images")
print(f"No. of Classes: {num_classes}")
print(f"Steps per epoch: {steps_per_epoch}, Total steps: {total_steps}")

100%|██████████| 170M/170M [00:13<00:00, 12.8MB/s]


Found 50000 training images
No. of Classes: 10
Steps per epoch: 391, Total steps: 13685


# Cell 4: Training Utilities

In [None]:
def l2_regularization(model: nn.Module) -> torch.Tensor:
    """
    Compute 0.5 * sum(||p||^2) over all trainable parameters of `model`.

    Every parameter with `requires_grad=True` contributes, including biases
    and normalisation scales/offsets. The penalty is meant to be added to the
    loss (loss-based L2), as opposed to decoupled optimizer weight decay.

    The result lives on the parameters' own device, so the function does not
    depend on any module-level `device` variable and keeps working when the
    model is moved. A model with no trainable parameters yields a scalar zero.

    Returns:
        A 0-dim tensor that is differentiable w.r.t. the trainable parameters.
    """
    squared_norms = [
        torch.sum(p.pow(2)) for p in model.parameters() if p.requires_grad
    ]
    if not squared_norms:
        # Nothing to regularise; return a plain scalar zero.
        return torch.zeros(())
    return 0.5 * sum(squared_norms)

@torch.no_grad()
def accuracy_from_logits(logits: torch.Tensor, targets: torch.Tensor) -> float:
    """Return the fraction of rows in `logits` whose arg-max along dim 1 equals `targets`."""
    is_hit = torch.eq(torch.argmax(logits, dim=1), targets)
    return is_hit.float().mean().item()

@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader) -> float:
    """Return the top-1 accuracy of `model` over `loader` as a fraction in [0, 1].

    Switches the model to eval mode (and leaves it there), moves each batch to
    the module-level `device`, and counts arg-max predictions that match the
    labels.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    for batch_images, batch_labels in loader:
        batch_images = batch_images.to(device, non_blocking=True)
        batch_labels = batch_labels.to(device, non_blocking=True)
        predicted = model(batch_images, is_training=False).argmax(dim=1)
        n_correct += predicted.eq(batch_labels).sum().item()
        n_seen += batch_labels.size(0)
    return n_correct / n_seen

def train_one_epoch(
    model: nn.Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scheduler: OneCycleLR,
    epoch: int,
    log_interval: int = 100
):
    """
    Train `model` for one pass over `loader` and return the mean per-batch loss.

    Each step minimises cross-entropy plus `decay * l2_regularization(model)`,
    then steps the optimizer and the scheduler (the scheduler advances once per
    batch, not once per epoch). Every `log_interval` global steps a line is
    printed with the accuracy on the current training batch and on a single
    batch drawn from the module-level `test_loader`.

    Args:
        model: Network to train; its forward is called with `is_training=...`.
        loader: Training batches of `(images, labels)`.
        optimizer: Optimizer updating `model`'s parameters.
        scheduler: Learning-rate scheduler stepped after every optimizer step.
        epoch: Zero-based epoch index, used to compute the global step number.
        log_interval: Print progress whenever the global step is a multiple of
            this value.

    Returns:
        The sum of per-batch total losses (cross-entropy + L2) divided by
        `len(loader)`.

    Relies on the module-level names `device`, `decay` and `test_loader`.
    """
    model.train()
    ce_loss_fn = nn.CrossEntropyLoss(reduction='mean')
    running_loss = 0.0
    global_step_start = epoch * len(loader)

    # Create a test iterator for periodic evaluation similar to the JAX script
    test_iter = iter(test_loader)

    for batch_idx, (images, labels) in enumerate(loader):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        logits = model(images, is_training=True)
        ce_loss = ce_loss_fn(logits, labels)
        reg_loss = decay * l2_regularization(model)
        loss = ce_loss + reg_loss
        loss.backward()
        optimizer.step()
        scheduler.step()

        running_loss += loss.item()

        global_step = global_step_start + batch_idx
        if global_step % log_interval == 0:
            # Compute train accuracy on current batch
            train_acc = accuracy_from_logits(logits, labels)

            # Compute test accuracy on one batch (approximate periodic eval);
            # restart the iterator once the test loader is exhausted.
            try:
                test_images, test_labels = next(test_iter)
            except StopIteration:
                test_iter = iter(test_loader)
                test_images, test_labels = next(test_iter)
            test_images = test_images.to(device, non_blocking=True)
            test_labels = test_labels.to(device, non_blocking=True)

            # Probe in eval mode, exactly as `evaluate` does, so layers whose
            # behaviour depends on the module mode (BatchNorm, Dropout) do not
            # treat the test batch as training data. Train mode is always
            # restored before the next optimisation step.
            model.eval()
            try:
                with torch.no_grad():
                    test_logits = model(test_images, is_training=False)
                    test_acc = accuracy_from_logits(test_logits, test_labels)
            finally:
                model.train()

            print(f"[Step {global_step}, Loss {loss.item():.5f}] Train / Test accuracy: {train_acc:.3f} / {test_acc:.3f}")

    avg_loss = running_loss / len(loader)
    return avg_loss

# Cell 5: Instantiate model, optimizer, scheduler, and run training

In [None]:
# Build the network and place it on the selected device.
model = ResNet(
    output_size=num_classes,
    channels=64,
    bn_decay_rate=0.9,
    bn_eps=1e-5,
)
model = model.to(device)

# Plain SGD with momentum. weight_decay stays at 0 because the L2 penalty is
# added to the loss explicitly instead.
optimizer = SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=0.9,
    weight_decay=0.0,
    nesterov=False,
)

# One-cycle schedule with linear annealing, approximating
# optax.linear_onecycle_schedule.
_onecycle_options = {
    "max_lr": learning_rate,
    "total_steps": total_steps,
    "pct_start": pct_start,                # warmup fraction ~ 15/epochs
    "anneal_strategy": 'linear',
    "div_factor": div_factor,              # initial lr = max_lr / div_factor
    "final_div_factor": final_div_factor,
}
scheduler = OneCycleLR(optimizer, **_onecycle_options)

print(model)

# Train for the configured number of epochs, scoring the full test set after
# each one, and time the whole run.
start_time = time.time()
for epoch in range(epochs):
    avg_loss = train_one_epoch(
        model=model,
        loader=train_loader,
        optimizer=optimizer,
        scheduler=scheduler,
        epoch=epoch,
        log_interval=100,
    )
    val_acc = evaluate(model, test_loader)
    print(f"Epoch {epoch+1}/{epochs}, Training Loss: {avg_loss:.5f}, Validation Acc: {val_acc:.4f}")

elapsed = time.time() - start_time
elapsed_minutes = elapsed / 60
print(f"Training took {elapsed:.2f} secs or {elapsed_minutes:.2f} mins in total")

ResNet(
  (initial_conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (initial_bn): BatchNorm2d(64, eps=1e-05, momentum=0.09999999999999998, affine=True, track_running_stats=True)
  (res1a): ResNetBlock(
    (proj_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (proj_bn): BatchNorm2d(64, eps=1e-05, momentum=0.09999999999999998, affine=True, track_running_stats=True)
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.09999999999999998, affine=True, track_running_stats=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.09999999999999998, affine=True, track_running_stats=True)
  )
  (res1b): ResNetBlock(
    (proj_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (proj_bn): BatchNorm2d(64, eps=1e-05, momentum=0.099999999999999

# Cell 6: Final evaluation on the test set

In [None]:
# Overall top-1 accuracy on the test loader.
final_test_acc = evaluate(model, test_loader)
print("top_1_acc:", final_test_acc)

# Per-class precision/recall over the SAME test loader, so the report agrees
# with top_1_acc above (iterating a different loader here would describe a
# different dataset than the one this model is being scored on).
model.eval()
all_preds = []
all_labels = []
with torch.inference_mode():
    for images, labels in test_loader:
        images = images.to(device, non_blocking=True)
        # Same forward call as `evaluate`: inference behaviour is requested
        # explicitly rather than relying on the argument's default.
        logits = model(images, is_training=False)

        all_preds.append(logits.argmax(dim=1).cpu())
        # Labels are only needed on the CPU for the report, so they are not
        # transferred to the device first.
        all_labels.append(labels.cpu())

all_preds = torch.cat(all_preds)
all_labels = torch.cat(all_labels)

print(classification_report(all_labels, all_preds))
print("END")

top_1_acc: 0.9397
              precision    recall  f1-score   support

           0       0.98      0.98      0.98      1200
           1       0.99      1.00      0.99      1200
           2       0.98      0.98      0.98      1200
           3       0.96      0.96      0.96      1200
           4       0.98      0.97      0.98      1200
           5       0.97      0.96      0.96      1200
           6       0.98      1.00      0.99      1200
           7       0.99      0.99      0.99      1200
           8       0.99      0.99      0.99      1200
           9       0.99      0.99      0.99      1200

    accuracy                           0.98     12000
   macro avg       0.98      0.98      0.98     12000
weighted avg       0.98      0.98      0.98     12000

END
