In [8]:
import numpy as np
import pandas as pd
import torch
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import time
import datetime

In [9]:
class ResBlock(nn.Module):
  def __init__(self, n_chans):
    super(ResBlock, self).__init__()
    self.conv = nn.Conv2d(n_chans, n_chans, kernel_size=3, padding=1, bias=False)
    self.batch_norm = nn.BatchNorm2d(num_features=n_chans)
    torch.nn.init.kaiming_normal_(self.conv.weight, nonlinearity='relu')
    torch.nn.init.constant_(self.batch_norm.weight, 0.5)
    torch.nn.init.zeros_(self.batch_norm.bias)

  def forward(self, x):
    out = self.conv(x)
    out = self.batch_norm(out)
    out = torch.relu(out)
    return out + x

class NetResDeep(nn.Module):
  def __init__(self, n_chans1=32, n_blocks=10):
    super(NetResDeep, self).__init__()
    self.n_chans1 = n_chans1
    self.conv1 = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)
    self.resblocks = nn.Sequential(*(n_blocks * [ResBlock(n_chans=n_chans1)]))
    self.fc1 = nn.Linear(8 * 8 * n_chans1, 32)
    self.fc2 = nn.Linear(32, 10)

  def forward(self, x):
    out = nn.functional.max_pool2d(torch.relu(self.conv1(x)), 2)
    out = self.resblocks(out)
    out = nn.functional.max_pool2d(out, 2)
    out = out.view(-1, 8 * 8 * self.n_chans1)
    out = torch.relu(self.fc1(out))
    out = self.fc2(out)
    return out

In [10]:
def training_loop(n_epochs, model, optimizer, loss_fn, train_loader):
  start_time = time.time()

  for epoch in range(1, n_epochs + 1):
    model.train()
    loss_train = 0.0
    for imgs, labels in train_loader:
      outputs = model(imgs)
      loss = loss_fn(outputs, labels)

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      loss_train += loss.item()

    print(f'{datetime.datetime.now()} Epoch {epoch}, Training loss {(loss_train / len(train_loader)):.4f}')

  end_time = time.time()
  elapsed_time = end_time - start_time
  print(f"Training completed in {elapsed_time:.2f} seconds.")

# Validation function
def validate(model, train_loader, val_loader):
  model.eval()
  for name, loader in [('Training', train_loader), ('Validation', val_loader)]:
    correct = 0
    total = 0

    with torch.no_grad():
      for imgs, labels in loader:
        outputs = model(imgs)
        _, predicted = torch.max(outputs, dim=1)
        total += labels.shape[0]
        correct += int((predicted == labels).sum())

    print(f'{name} accuracy: {correct / total:.4f}')

In [11]:
data_path = '../data-unversioned/p1ch7/'
cifar10 = datasets.CIFAR10(data_path, train=True, download=True,
                             transform=transforms.Compose([
                             transforms.ToTensor(),
                             transforms.Normalize((0.4914, 0.4822, 0.4465),
                                                  (0.2470, 0.2435, 0.2616))
                             ]))
cifar10_val = datasets.CIFAR10(data_path, train=False, download=True,
                             transform=transforms.Compose([
                             transforms.ToTensor(),
                             transforms.Normalize((0.4942, 0.4851, 0.4504),
                                                  (0.2467, 0.2429, 0.2616))
                             ]))

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../data-unversioned/p1ch7/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:05<00:00, 29.8MB/s]


Extracting ../data-unversioned/p1ch7/cifar-10-python.tar.gz to ../data-unversioned/p1ch7/
Files already downloaded and verified


In [12]:
train_loader = torch.utils.data.DataLoader(cifar10, batch_size=64, shuffle=True)
val_loader = torch.utils.data.DataLoader(cifar10_val, batch_size=64, shuffle=True)

model = NetResDeep()
optimizer = optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

training_loop(
    n_epochs = 20,
    model = model,
    optimizer = optimizer,
    loss_fn = loss_fn,
    train_loader = train_loader
)

validate(model, train_loader, val_loader)

2024-12-09 19:29:21.514271 Epoch 1, Training loss 2.0203
2024-12-09 19:33:30.141019 Epoch 2, Training loss 1.6908
2024-12-09 19:37:35.150030 Epoch 3, Training loss 1.4289
2024-12-09 19:41:41.874685 Epoch 4, Training loss 1.2555
2024-12-09 19:45:48.929531 Epoch 5, Training loss 1.1693
2024-12-09 19:49:54.809520 Epoch 6, Training loss 1.1084
2024-12-09 19:54:03.329151 Epoch 7, Training loss 1.0515
2024-12-09 19:58:12.556815 Epoch 8, Training loss 0.9894
2024-12-09 20:02:24.494461 Epoch 9, Training loss 0.9335
2024-12-09 20:06:35.615963 Epoch 10, Training loss 0.8841
2024-12-09 20:10:44.680837 Epoch 11, Training loss 0.8504
2024-12-09 20:14:57.515843 Epoch 12, Training loss 0.8210
2024-12-09 20:19:08.897550 Epoch 13, Training loss 0.7954
2024-12-09 20:23:18.484789 Epoch 14, Training loss 0.7708
2024-12-09 20:27:26.107398 Epoch 15, Training loss 0.7454
2024-12-09 20:31:32.027186 Epoch 16, Training loss 0.7284
2024-12-09 20:35:39.073138 Epoch 17, Training loss 0.7071
2024-12-09 20:39:44.150