In [None]:

import logging
import sys
import warnings

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import time

# Reduce noisy compiler chatter and known transient warnings
try:
    import torch._logging as torch_logging
    torch_logging.set_logs(dynamo=logging.ERROR, inductor=logging.ERROR)
except Exception:
    pass
warnings.filterwarnings("ignore", message="Error detected in MaxPool2DBackward0")

# Seed torch's global RNG so weight initialisation and shuffling are repeatable.
torch.manual_seed(0)

# Per-backend float32 precision knobs:
#   * CUDA -> allow TF32 for matmuls and cuDNN convolutions,
#   * MPS  -> leave backend defaults untouched,
#   * CPU  -> "high" float32 matmul precision.
if torch.cuda.is_available():
    for _tf32_backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
        _tf32_backend.allow_tf32 = True
elif not (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()):
    torch.set_float32_matmul_precision("high")

# Device configuration: prefer Apple MPS, then CUDA, otherwise fall back to CPU.
if torch.backends.mps.is_available():
    _device_name = "mps"
elif torch.cuda.is_available():
    _device_name = "cuda"
else:
    _device_name = "cpu"
device = torch.device(_device_name)
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")

# Denormal flushing is a CPU-side setting (torch.set_flush_denormal); it is
# attempted on a best-effort basis and failures are only reported.
try:
    torch.set_flush_denormal(True)
except AttributeError:
    print("torch.set_flush_denormal is not available in this PyTorch version/build for this platform.")
except Exception as e:
    print(f"Could not set denormal flushing: {e}")
else:
    print("Attempted to enable denormal flushing.")

# Hyperparameters, grouped by what they control.

# -- Input pipeline / batching --
NEW_BATCH_SIZE = 1024
NUM_WORKERS_DATALOADER = 6
PREFETCH_FACTOR = 4
PIN_MEMORY = (device.type == "cuda")                       # pinned host buffers only requested for CUDA
NON_BLOCKING_TRANSFER = device.type in ("cuda", "mps")     # async host->device copies on accelerators

# -- Optimisation: AdamW + OneCycleLR --
NUM_EPOCHS_TARGET = 5
MAX_LR = 0.0218
WEIGHT_DECAY = 3e-5
PCT_START_OCLR = 0.22
DIV_FACTOR_OCLR = 18
FINAL_DIV_FACTOR_OCLR = 1e5

# -- Model, augmentation and extras --
CHANNELS_C1 = 10      # output channels of the first conv layer
AUG_PROB = 0.75       # probability of applying the random affine augmentation
USE_COMPILE = True    # try torch.compile when available
SWA_KEEP = 2          # number of most recent per-epoch snapshots to average

# Transforms
# Evaluation pipeline: tensor in [0, 1], then normalised to [-1, 1].
transform1 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
# Training pipeline: the random affine augmentation runs BEFORE Normalize.
# RandomAffine fills uncovered pixels with 0 by default. On the raw [0, 1]
# tensor that matches MNIST's black background. After Normalize((0.5,), (0.5,))
# a value of 0 corresponds to raw 0.5 (mid-grey), which would paint grey
# borders around every shifted/rotated digit.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomApply([
        transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1))
    ], p=AUG_PROB),
    transforms.Normalize((0.5,), (0.5,)),
])

# DataLoaders
def _dataloader_kwargs(batch_size, shuffle, prefetch_factor):
    """Build the DataLoader kwargs shared by the train and test loaders.

    Worker-only options (persistent workers, prefetching and, on macOS, the
    "fork" start method) are added only when NUM_WORKERS_DATALOADER > 0.
    """
    kwargs = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": NUM_WORKERS_DATALOADER,
        "pin_memory": PIN_MEMORY
    }
    if NUM_WORKERS_DATALOADER > 0:
        kwargs["persistent_workers"] = True
        kwargs["prefetch_factor"] = prefetch_factor
        if sys.platform == "darwin":
            kwargs["multiprocessing_context"] = "fork"
    return kwargs


# Training split: augmented (`transform`) and shuffled.
trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
loader_kwargs = _dataloader_kwargs(NEW_BATCH_SIZE, True, PREFETCH_FACTOR)
trainloader = torch.utils.data.DataLoader(trainset, **loader_kwargs)

# Test split: un-augmented (`transform1`), in order, with double-size batches
# and a smaller (but at least 2) prefetch depth.
testset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=False, transform=transform1)
test_loader_kwargs = _dataloader_kwargs(NEW_BATCH_SIZE * 2, False, max(2, PREFETCH_FACTOR // 2))
testloader = torch.utils.data.DataLoader(testset, **test_loader_kwargs)

# Model Definition
class LeNet5(nn.Module):
    """LeNet-5-style CNN with BatchNorm after every conv and hidden linear layer.

    For a 1x28x28 input the spatial size goes
    28 -conv5-> 24 -pool-> 12 -conv5-> 8 -pool-> 4 -conv3-> 2 -pool-> 1,
    which is why ``fc1`` consumes 32 * 1 * 1 features. Produces 10 logits.

    Args:
        c1_channels: number of output channels of the first convolution.
    """

    def __init__(self, c1_channels=6):
        super().__init__()
        # Submodule registration order is deliberately fixed: it determines
        # both the seeded default-init RNG consumption and state_dict key order.
        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(1, c1_channels, kernel_size=5)
        self.conv2 = nn.Conv2d(c1_channels, 16, kernel_size=5)
        self.conv3 = nn.Conv2d(16, 32, kernel_size=3)
        # Fully connected classifier head.
        self.fc1 = nn.Linear(32 * 1 * 1, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        # Normalisation layers, one per conv / hidden linear layer.
        self.bn1 = nn.BatchNorm2d(c1_channels)
        self.bn2 = nn.BatchNorm2d(16)
        self.bn3 = nn.BatchNorm2d(32)
        self.bn4 = nn.BatchNorm1d(120)
        self.bn5 = nn.BatchNorm1d(84)
        # Shared parameter-free ops.
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(kernel_size=2)

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal convs, Xavier-normal linears, identity BatchNorm; all biases zero."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.ones_(module.weight)
            elif isinstance(module, nn.Linear):
                nn.init.xavier_normal_(module.weight)
            else:
                continue
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Map a batch of images to class logits of shape (batch, 10)."""
        conv_stages = ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3))
        for conv, norm in conv_stages:
            x = self.pool(self.relu(norm(conv(x))))
        x = x.flatten(1)
        for linear, norm in ((self.fc1, self.bn4), (self.fc2, self.bn5)):
            x = self.relu(norm(linear(x)))
        return self.fc3(x)

# Build the model on the target device and optionally wrap it with torch.compile.
# NOTE: torch.compile compiles lazily on the first call, so this try only catches
# wrapping errors; compilation failures surface at the first forward pass.
net = LeNet5(c1_channels=CHANNELS_C1).to(device=device)
compiled_model = False
_can_compile = USE_COMPILE and hasattr(torch, "compile")
if _can_compile:
    try:
        net = torch.compile(net, mode="reduce-overhead")
    except Exception as exc:
        print(f"torch.compile failed; falling back to eager execution: {exc}")
    else:
        compiled_model = True
        print("Model compiled with torch.compile (reduce-overhead mode).")

# Loss, AdamW optimiser and a OneCycle learning-rate schedule.
# steps_per_epoch=len(trainloader) assumes scheduler.step() is called once per batch.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(net.parameters(), lr=MAX_LR, weight_decay=WEIGHT_DECAY)
_one_cycle_kwargs = dict(
    max_lr=MAX_LR,
    epochs=NUM_EPOCHS_TARGET,
    steps_per_epoch=len(trainloader),
    pct_start=PCT_START_OCLR,
    div_factor=DIV_FACTOR_OCLR,
    final_div_factor=FINAL_DIV_FACTOR_OCLR,
)
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, **_one_cycle_kwargs)

# Warm-up for the compiled model: one throw-away forward/backward pass so the
# training graph is traced and compiled before the timed loop starts.
# Gradients are cleared afterwards and no optimizer step is taken. The model is
# presumably still in its default train mode here, so BatchNorm running stats
# do see this one batch.
if compiled_model:
    warmup_batches = iter(trainloader)
    warm_x, warm_y = next(warmup_batches)
    warm_x = warm_x.to(device, non_blocking=NON_BLOCKING_TRANSFER)
    warm_y = warm_y.to(device, non_blocking=NON_BLOCKING_TRANSFER)

    optimizer.zero_grad(set_to_none=True)
    warm_logits = net(warm_x)
    warm_loss = criterion(warm_logits, warm_y)
    warm_loss.backward()
    optimizer.zero_grad(set_to_none=True)

    # Wait for queued device work so compilation cost is not billed to epoch 1.
    if hasattr(torch, 'mps') and torch.backends.mps.is_available():
        torch.mps.synchronize()
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    del warmup_batches, warm_x, warm_y, warm_logits, warm_loss

# Training loop.
# Each epoch runs three steps:
#   1. One optimisation pass over `trainloader`, stepping OneCycleLR once per batch.
#   2. A validation pass over `testloader`.
#   3. A detached CPU snapshot of the weights; only the last SWA_KEEP snapshots are kept.
start_time = time.time()
print(f"Starting training with num_workers={NUM_WORKERS_DATALOADER} for {NUM_EPOCHS_TARGET} epochs.")
print(f"MAX_LR={MAX_LR}, PCT_START={PCT_START_OCLR}, AUG_PROB={AUG_PROB}, C1_CHANNELS={CHANNELS_C1}, DIV_FACTOR={DIV_FACTOR_OCLR}, FINAL_DIV_FACTOR={FINAL_DIV_FACTOR_OCLR}")
cumulative_time = 0
recent_states = []

for epoch in range(NUM_EPOCHS_TARGET):
    net.train()
    # Sum batch losses on the device. Calling loss.item() every step would force
    # a host<->device sync per batch and stall the non_blocking input pipeline.
    running_loss = torch.zeros((), device=device)
    epoch_start_time = time.time()

    for inputs, labels in trainloader:
        inputs = inputs.to(device, non_blocking=NON_BLOCKING_TRANSFER)
        labels = labels.to(device, non_blocking=NON_BLOCKING_TRANSFER)

        optimizer.zero_grad(set_to_none=True)

        outputs = net(inputs)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()
        scheduler.step()

        running_loss += loss.detach()

    # .item() waits for all queued device work, so read it before timing the epoch.
    avg_loss = running_loss.item() / max(len(trainloader), 1)
    epoch_duration = time.time() - epoch_start_time
    cumulative_time += epoch_duration
    current_lr = optimizer.param_groups[0]['lr']
    print(f'Epoch {epoch + 1}/{NUM_EPOCHS_TARGET}, Loss: {avg_loss:.4f}, LR: {current_lr:.6f}, Epoch Time: {epoch_duration:.2f}s, Total Time: {cumulative_time:.2f}s')

    # Validation pass on the test split.
    net.eval()
    correct_val = 0
    total_val = 0
    with torch.no_grad():
        for images_val, labels_val in testloader:
            images_val = images_val.to(device, non_blocking=NON_BLOCKING_TRANSFER)
            labels_val = labels_val.to(device, non_blocking=NON_BLOCKING_TRANSFER)
            outputs_val = net(images_val)
            predicted_val = outputs_val.argmax(dim=1)
            total_val += labels_val.size(0)
            correct_val += (predicted_val == labels_val).sum().item()
    accuracy_val = 100 * correct_val / total_val if total_val else float('nan')
    print(f'Epoch {epoch + 1}/{NUM_EPOCHS_TARGET}, Validation Accuracy: {accuracy_val:.2f}%')

    # Keep a detached CPU copy of this epoch's weights and buffers for later averaging.
    recent_states.append({k: v.detach().cpu().clone() for k, v in net.state_dict().items()})
    if len(recent_states) > SWA_KEEP:
        recent_states.pop(0)

# Report total wall-clock time and whether the compiled or eager model was used.
training_time = time.time() - start_time
print(f"\nTraining finished in {training_time:.2f} seconds.")
print("Trained with a compiled model." if compiled_model else "Trained with an uncompiled (eager) model.")


def evaluate(loader, model):
    """Return the top-1 accuracy (percent) of ``model`` over all batches of ``loader``.

    Each batch is moved to the notebook-level ``device`` (honouring the
    module-level ``NON_BLOCKING_TRANSFER`` flag) and run under ``torch.no_grad()``.
    The caller is responsible for putting ``model`` in eval mode beforehand.

    Args:
        loader: iterable of ``(images, labels)`` batches, e.g. a DataLoader.
        model: callable whose output for ``images`` is presumably
            ``(batch, num_classes)`` scores; the predicted class is the argmax
            over dim 1.

    Returns:
        Accuracy as a float in [0, 100].

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device, non_blocking=NON_BLOCKING_TRANSFER)
            labels = labels.to(device, non_blocking=NON_BLOCKING_TRANSFER)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("evaluate() received a loader that yielded no samples")
    return 100 * correct / total

# Final evaluation with two sets of weights:
#   1. The last-epoch weights.
#   2. A uniform average of the last SWA_KEEP per-epoch snapshots.
# This is plain weight averaging: BatchNorm running stats are averaged too,
# not recomputed. The model is restored to the last-epoch weights afterwards.
net.eval()
base_state = {k: v.detach().cpu().clone() for k, v in net.state_dict().items()}
base_accuracy = evaluate(testloader, net)
print(f"Accuracy of the network on the 10000 test images: {base_accuracy:.2f} %")
if recent_states:
    # Average floating-point tensors. Integer buffers such as BatchNorm's
    # num_batches_tracked are taken from the newest snapshot instead: averaging
    # them would produce a float that load_state_dict then truncates.
    avg_state_cpu = {
        k: (
            sum(state[k] for state in recent_states) / len(recent_states)
            if base_state[k].is_floating_point()
            else recent_states[-1][k].clone()
        )
        for k in base_state
    }
    avg_state = {k: v.to(device) for k, v in avg_state_cpu.items()}
    # All snapshots come from this same module's state_dict, so strict loading
    # succeeds and will flag any unexpected key drift instead of hiding it.
    net.load_state_dict(avg_state)
    swa_accuracy = evaluate(testloader, net)
    print(f"Accuracy of the network on the 10000 test images (SWA {len(recent_states)}): {swa_accuracy:.2f} %")
    # Restore the last-epoch weights so later cells see the trained model.
    net.load_state_dict({k: v.to(device) for k, v in base_state.items()})
else:
    swa_accuracy = base_accuracy


In [None]:
import matplotlib.pyplot as plt

# Visualise predictions on the first test batch. The batch and its predictions
# are computed here explicitly: `images` / `predicted` only ever existed as
# locals inside `evaluate`, so relying on them fails on a fresh kernel.
net.eval()
sample_images, sample_labels = next(iter(testloader))
with torch.no_grad():
    sample_logits = net(sample_images.to(device, non_blocking=NON_BLOCKING_TRANSFER))
sample_preds = sample_logits.argmax(dim=1).cpu()

fig, axes = plt.subplots(4, 4, figsize=(8, 10))
for i, ax in enumerate(axes.flat):
    ax.set_xticks([])
    ax.set_yticks([])
    if i >= len(sample_images):
        # Batch smaller than the grid: leave the remaining panels blank.
        ax.axis('off')
        continue
    ax.imshow(sample_images[i].view(28, 28), cmap='gray')
    ax.set_title('Predicted: {}\nTrue: {}'.format(sample_preds[i].item(), sample_labels[i].item()),
                 fontsize=12, pad=15)

plt.tight_layout()
plt.show()