In [3]:
import torch.nn as nn
import torchvision
from torchvision.transforms import v2
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from random import choices
import numpy as np
import random
from sklearn.model_selection import train_test_split
import warnings

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
seed = 42
warnings.filterwarnings("ignore")

In [5]:
torch.random.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

In [6]:
training_dataset = torchvision.datasets.CIFAR10(root="./", download=True, train=True)
test_dataset = torchvision.datasets.CIFAR10(root="./", download=True, train=False)

100%|██████████| 170M/170M [00:04<00:00, 35.0MB/s] 


In [7]:
training_images = []
test_images = []
transformation = torchvision.transforms.ToTensor()

for img in training_dataset.data:
  training_images.append(transformation(img))

for img in test_dataset.data:
  test_images.append(transformation(img))

training_images = torch.stack(training_images)
test_images = torch.stack(test_images)

In [8]:
per_pixel_mean = torch.sum(training_images, dim=0)/training_images.shape[0]
training_images = training_images - per_pixel_mean
test_images = test_images - per_pixel_mean

In [9]:
train_labels = training_dataset.targets
test_labels = test_dataset.targets
y_train = torch.Tensor([[1 if i == el else 0 for i in range(10)] for el in train_labels])
y_test = torch.Tensor([[1 if i == el else 0 for i in range(10)] for el in test_labels])

In [72]:

class ResidualBlock(nn.Module):

  def __init__(self, in_channels, out_channels, kernel, padding1, padding2, stride):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.upsampling = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride)

    self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel, padding=padding1, stride=stride)
    self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel, padding=padding2, stride=1)
    self.bn1 = nn.BatchNorm2d(num_features=out_channels)
    self.bn2 = nn.BatchNorm2d(num_features=out_channels)

    nn.init.normal_(self.upsampling.weight, 0, (2/20)**0.5)
    nn.init.normal_(self.conv1.weight, 0, (2/20)**0.5)
    nn.init.normal_(self.conv2.weight, 0, (2/20)**0.5)

    nn.init.zeros_(self.upsampling.bias)
    nn.init.zeros_(self.conv1.bias)
    nn.init.zeros_(self.conv2.bias)

  def forward(self, x):
    if self.in_channels != self.out_channels:
      x_skip = self.upsampling(x)
    else:
      x_skip = x

    x = self.conv1(x)
    x = self.bn1(x)
    x = nn.functional.relu(x)
    x = self.conv2(x)
    x = x + x_skip
    x = self.bn2(x)
    x = nn.functional.relu(x)
    return x


class ResNet20(nn.Module):

  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)

    self.res_block_16_1 = ResidualBlock(in_channels=16, out_channels=16, kernel=3, padding1=1,padding2=1,stride=1)
    self.res_block_16_2 = ResidualBlock(in_channels=16, out_channels=16, kernel=3, padding1=1,padding2=1,stride=1)
    self.res_block_16_3 = ResidualBlock(in_channels=16, out_channels=16, kernel=3, padding1=1,padding2=1,stride=1)
    self.res_block_16_4 = ResidualBlock(in_channels=16, out_channels=16, kernel=3, padding1=1,padding2=1,stride=1)
    self.res_block_16_5 = ResidualBlock(in_channels=16, out_channels=16, kernel=3, padding1=1,padding2=1,stride=1)
    self.res_block_16_6 = ResidualBlock(in_channels=16, out_channels=16, kernel=3, padding1=1,padding2=1,stride=1)
    self.res_block_32_1 = ResidualBlock(in_channels=16, out_channels=32, kernel=3, padding1=1, padding2=1, stride=2)
    self.res_block_32_2 = ResidualBlock(in_channels=32, out_channels=32, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_32_3 = ResidualBlock(in_channels=32, out_channels=32, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_32_4 = ResidualBlock(in_channels=32, out_channels=32, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_32_5 = ResidualBlock(in_channels=32, out_channels=32, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_32_6 = ResidualBlock(in_channels=32, out_channels=32, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_64_1 = ResidualBlock(in_channels=32, out_channels=64, kernel=3, padding1=1, padding2=1, stride=2)
    self.res_block_64_2 = ResidualBlock(in_channels=64, out_channels=64, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_64_3 = ResidualBlock(in_channels=64, out_channels=64, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_64_4 = ResidualBlock(in_channels=64, out_channels=64, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_64_5 = ResidualBlock(in_channels=64, out_channels=64, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_64_6 = ResidualBlock(in_channels=64, out_channels=64, kernel=3, padding1=1, padding2=1, stride=1)

    self.global_pooling = nn.AvgPool2d(kernel_size=8)
    self.fc = nn.Linear(64, 10)

    nn.init.normal_(self.conv1.weight, 0, (2/20)**0.5)
    nn.init.normal_(self.fc.weight, 0, (2/20)**0.5)

    nn.init.zeros_(self.conv1.bias)
    nn.init.zeros_(self.fc.bias)


  def forward(self, x):
    x = self.conv1(x)
    x = self.res_block_16_1(x)
    x = self.res_block_16_2(x)
    x = self.res_block_16_3(x)
    x = self.res_block_16_4(x)
    x = self.res_block_16_5(x)
    x = self.res_block_16_6(x)
    x = self.res_block_32_1(x)
    x = self.res_block_32_2(x)
    x = self.res_block_32_3(x)
    x = self.res_block_32_4(x)
    x = self.res_block_32_5(x)
    x = self.res_block_32_6(x)
    x = self.res_block_64_1(x)
    x = self.res_block_64_2(x)
    x = self.res_block_64_3(x)
    x = self.res_block_64_4(x)
    x = self.res_block_64_5(x)
    x = self.res_block_64_6(x)
    x = self.global_pooling(x)
    x = torch.flatten(x, 1, -1)
    x = self.fc(x)
    return x

In [10]:
X_train, X_val, y_train, y_val = train_test_split(training_images, y_train, test_size=0.1, random_state=seed)

In [11]:
transforms = v2.Compose([
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomCrop(size=32, padding=4)
])

temp_x = []
temp_y = []
for img, label in [*zip(X_train, y_train)]:
  for i in range(5):
    temp_x.append(transforms(img))
    temp_y.append(label)
X_train = torch.stack(temp_x)
y_train = torch.stack(temp_y)

In [12]:
print(X_train.shape)

torch.Size([225000, 3, 32, 32])


In [78]:
weight_decay = 0.0001
momentum = 0.9
lr = 0.1
batch_size = 128
iterations = 0
over=False

train_batches = DataLoader([*zip(X_train, y_train)], batch_size=batch_size, shuffle=True)
val_batches = DataLoader([*zip(X_val, y_val)], batch_size=batch_size, shuffle=True)
loss_fn = nn.CrossEntropyLoss()
val_loss_fn = nn.CrossEntropyLoss(reduction="sum")
model = ResNet20().to(device)
optimizer = torch.optim.SGD(params=model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)

while True:

  for batch in train_batches:
    model.train()
    iterations += 1
    optimizer.zero_grad()
    features, target = batch[:-1], batch[-1]
    features = features[0].to(device)
    target = target.to(device)
    outputs = model(features)
    perte = loss_fn(outputs, target)
    perte.backward()
    optimizer.step()

    if iterations % 100 == 0:
      model.eval()
      total_loss = 0
      for batch in val_batches:
        features, target = batch[:-1], batch[-1]
        features = features[0].to(device)
        target = target.to(device)
        outputs = model(features)
        perte = val_loss_fn(outputs, target)
        total_loss += perte.item()
      print(f"Iteration {iterations}: Loss {total_loss/X_val.shape[0]}")

    match iterations:
      case 32000:
        lr /= 10
        optimizer = torch.optim.SGD(params=model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
      case 48000:
        lr /= 10
        optimizer = torch.optim.SGD(params=model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
      case 64000:
        over=True
        break
      case _:
        continue




  if over:
    break


Iteration 100: Loss 2.1981354187011717
Iteration 200: Loss 2.0027960945129393
Iteration 300: Loss 2.002994049835205
Iteration 400: Loss 1.9953965141296386
Iteration 500: Loss 1.858583112335205
Iteration 600: Loss 1.8534700119018555
Iteration 700: Loss 1.7887409023284913
Iteration 800: Loss 1.7238394760131837
Iteration 900: Loss 1.8912242183685302
Iteration 1000: Loss 1.6848548057556152
Iteration 1100: Loss 1.8760100011825562
Iteration 1200: Loss 1.6785180240631103
Iteration 1300: Loss 1.6412097129821777
Iteration 1400: Loss 1.6558067039489746
Iteration 1500: Loss 1.5729736562728882
Iteration 1600: Loss 1.6453614837646484
Iteration 1700: Loss 1.5380710227966308
Iteration 1800: Loss 1.4981574514389038
Iteration 1900: Loss 1.5487183439254761
Iteration 2000: Loss 1.5583930297851563
Iteration 2100: Loss 1.4706424869537353
Iteration 2200: Loss 1.504853060913086
Iteration 2300: Loss 1.3554442335128785
Iteration 2400: Loss 1.4096236949920655
Iteration 2500: Loss 1.5129338760375977
Iteration 26

In [79]:
model.eval()
test_loader = DataLoader([*zip(test_images, test_labels)], batch_size=batch_size, shuffle=False)
correct = 0
for batch in test_loader:
  features, target = batch[:-1], batch[-1]
  features = features[0].to(device)
  target = target.to(device)
  outputs = model(features)
  correct += torch.where(torch.argmax(torch.softmax(outputs, dim=1), dim=1)==target, 1, 0).sum()
print((correct/test_images.shape[0])*100)

tensor(89.4800, device='cuda:0')


In [80]:
del model, features, target

In [13]:

class ResidualBlock(nn.Module):

  def __init__(self, in_channels, out_channels, kernel, padding1, padding2, stride):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.upsampling = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride)

    self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel, padding=padding1, stride=stride)
    self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel, padding=padding2, stride=1)
    self.bn1 = nn.BatchNorm2d(num_features=out_channels)
    self.bn2 = nn.BatchNorm2d(num_features=out_channels)

    nn.init.normal_(self.upsampling.weight, 0, (2/32)**0.5)
    nn.init.normal_(self.conv1.weight, 0, (2/32)**0.5)
    nn.init.normal_(self.conv2.weight, 0, (2/32)**0.5)

    nn.init.zeros_(self.upsampling.bias)
    nn.init.zeros_(self.conv1.bias)
    nn.init.zeros_(self.conv2.bias)

  def forward(self, x):
    if self.in_channels != self.out_channels:
      x_skip = self.upsampling(x)
    else:
      x_skip = x

    x = self.conv1(x)
    x = self.bn1(x)
    x = nn.functional.relu(x)
    x = self.conv2(x)
    x = x + x_skip
    x = self.bn2(x)
    x = nn.functional.relu(x)
    return x



class ResNet32(nn.Module):

  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)

    self.res_block_16_1 = ResidualBlock(in_channels=16, out_channels=16, kernel=3, padding1=1,padding2=1,stride=1)
    self.res_block_16_2 = ResidualBlock(in_channels=16, out_channels=16, kernel=3, padding1=1,padding2=1,stride=1)
    self.res_block_16_3 = ResidualBlock(in_channels=16, out_channels=16, kernel=3, padding1=1,padding2=1,stride=1)
    self.res_block_16_4 = ResidualBlock(in_channels=16, out_channels=16, kernel=3, padding1=1,padding2=1,stride=1)
    self.res_block_16_5 = ResidualBlock(in_channels=16, out_channels=16, kernel=3, padding1=1,padding2=1,stride=1)
    self.res_block_16_6 = ResidualBlock(in_channels=16, out_channels=16, kernel=3, padding1=1,padding2=1,stride=1)
    self.res_block_16_7 = ResidualBlock(in_channels=16, out_channels=16, kernel=3, padding1=1,padding2=1,stride=1)
    self.res_block_16_8 = ResidualBlock(in_channels=16, out_channels=16, kernel=3, padding1=1,padding2=1,stride=1)
    self.res_block_16_9 = ResidualBlock(in_channels=16, out_channels=16, kernel=3, padding1=1,padding2=1,stride=1)
    self.res_block_16_10 = ResidualBlock(in_channels=16, out_channels=16, kernel=3, padding1=1,padding2=1,stride=1)
    self.res_block_32_1 = ResidualBlock(in_channels=16, out_channels=32, kernel=3, padding1=1, padding2=1, stride=2)
    self.res_block_32_2 = ResidualBlock(in_channels=32, out_channels=32, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_32_3 = ResidualBlock(in_channels=32, out_channels=32, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_32_4 = ResidualBlock(in_channels=32, out_channels=32, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_32_5 = ResidualBlock(in_channels=32, out_channels=32, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_32_6 = ResidualBlock(in_channels=32, out_channels=32, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_32_7 = ResidualBlock(in_channels=32, out_channels=32, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_32_8 = ResidualBlock(in_channels=32, out_channels=32, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_32_9 = ResidualBlock(in_channels=32, out_channels=32, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_32_10 = ResidualBlock(in_channels=32, out_channels=32, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_64_1 = ResidualBlock(in_channels=32, out_channels=64, kernel=3, padding1=1, padding2=1, stride=2)
    self.res_block_64_2 = ResidualBlock(in_channels=64, out_channels=64, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_64_3 = ResidualBlock(in_channels=64, out_channels=64, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_64_4 = ResidualBlock(in_channels=64, out_channels=64, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_64_5 = ResidualBlock(in_channels=64, out_channels=64, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_64_6 = ResidualBlock(in_channels=64, out_channels=64, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_64_7 = ResidualBlock(in_channels=64, out_channels=64, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_64_8 = ResidualBlock(in_channels=64, out_channels=64, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_64_9 = ResidualBlock(in_channels=64, out_channels=64, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_64_10 = ResidualBlock(in_channels=64, out_channels=64, kernel=3, padding1=1, padding2=1, stride=1)

    self.global_pooling = nn.AvgPool2d(kernel_size=8)
    self.fc = nn.Linear(64, 10)

    nn.init.normal_(self.conv1.weight, 0, (2/32)**0.5)
    nn.init.normal_(self.fc.weight, 0, (2/32)**0.5)

    nn.init.zeros_(self.conv1.bias)
    nn.init.zeros_(self.fc.bias)


  def forward(self, x):
    x = self.conv1(x)
    x = self.res_block_16_1(x)
    x = self.res_block_16_2(x)
    x = self.res_block_16_3(x)
    x = self.res_block_16_4(x)
    x = self.res_block_16_5(x)
    x = self.res_block_16_6(x)
    x = self.res_block_16_7(x)
    x = self.res_block_16_8(x)
    x = self.res_block_16_9(x)
    x = self.res_block_16_10(x)
    x = self.res_block_32_1(x)
    x = self.res_block_32_2(x)
    x = self.res_block_32_3(x)
    x = self.res_block_32_4(x)
    x = self.res_block_32_5(x)
    x = self.res_block_32_6(x)
    x = self.res_block_32_7(x)
    x = self.res_block_32_8(x)
    x = self.res_block_32_9(x)
    x = self.res_block_32_10(x)
    x = self.res_block_64_1(x)
    x = self.res_block_64_2(x)
    x = self.res_block_64_3(x)
    x = self.res_block_64_4(x)
    x = self.res_block_64_5(x)
    x = self.res_block_64_6(x)
    x = self.res_block_64_7(x)
    x = self.res_block_64_8(x)
    x = self.res_block_64_9(x)
    x = self.res_block_64_10(x)
    x = self.global_pooling(x)
    x = torch.flatten(x, 1, -1)
    x = self.fc(x)
    return x

In [14]:
weight_decay = 0.0001
momentum = 0.9
lr = 0.1
batch_size = 128
iterations = 0
over=False

train_batches = DataLoader([*zip(X_train, y_train)], batch_size=batch_size, shuffle=True)
val_batches = DataLoader([*zip(X_val, y_val)], batch_size=batch_size, shuffle=True)
loss_fn = nn.CrossEntropyLoss()
val_loss_fn = nn.CrossEntropyLoss(reduction="sum")
model = ResNet32().to(device)
optimizer = torch.optim.SGD(params=model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)

while True:

  for batch in train_batches:
    model.train()
    iterations += 1
    optimizer.zero_grad()
    features, target = batch[:-1], batch[-1]
    features = features[0].to(device)
    target = target.to(device)
    outputs = model(features)
    perte = loss_fn(outputs, target)
    perte.backward()
    optimizer.step()

    if iterations % 100 == 0:
      model.eval()
      total_loss = 0
      for batch in val_batches:
        features, target = batch[:-1], batch[-1]
        features = features[0].to(device)
        target = target.to(device)
        outputs = model(features)
        perte = val_loss_fn(outputs, target)
        total_loss += perte.item()
      print(f"Iteration {iterations}: Loss {total_loss/X_val.shape[0]}")

    match iterations:
      case 32000:
        lr /= 10
        optimizer = torch.optim.SGD(params=model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
      case 48000:
        lr /= 10
        optimizer = torch.optim.SGD(params=model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
      case 64000:
        over=True
        break
      case _:
        continue




  if over:
    break


Iteration 100: Loss 2.3226703525543213
Iteration 200: Loss 2.3050545318603515
Iteration 300: Loss 2.2973716629028322
Iteration 400: Loss 2.290064970397949
Iteration 500: Loss 2.2853515075683593
Iteration 600: Loss 2.2823936752319334
Iteration 700: Loss 2.279480103302002
Iteration 800: Loss 2.262299355697632
Iteration 900: Loss 2.2617228031158447
Iteration 1000: Loss 2.294735417938232
Iteration 1100: Loss 2.1927947784423827
Iteration 1200: Loss 2.222826135635376
Iteration 1300: Loss 2.1830554988861084
Iteration 1400: Loss 2.273524047470093
Iteration 1500: Loss 2.0837497257232664
Iteration 1600: Loss 2.150034883880615
Iteration 1700: Loss 2.033199182891846
Iteration 1800: Loss 1.9862348648071289
Iteration 1900: Loss 1.9213703172683716
Iteration 2000: Loss 2.100653741455078
Iteration 2100: Loss 2.0878962371826173
Iteration 2200: Loss 1.8215091533660888
Iteration 2300: Loss 1.7649912139892578
Iteration 2400: Loss 2.0041395950317384
Iteration 2500: Loss 1.793077935028076
Iteration 2600: Los

In [15]:
model.eval()
test_loader = DataLoader([*zip(test_images, test_labels)], batch_size=batch_size, shuffle=False)
correct = 0
for batch in test_loader:
  features, target = batch[:-1], batch[-1]
  features = features[0].to(device)
  target = target.to(device)
  outputs = model(features)
  correct += torch.where(torch.argmax(torch.softmax(outputs, dim=1), dim=1)==target, 1, 0).sum()
print((correct/test_images.shape[0])*100)

tensor(82.7400, device='cuda:0')


In [17]:
del model, features, target

In [16]:

class ResidualBlock(nn.Module):

  def __init__(self, in_channels, out_channels, kernel, padding1, padding2, stride):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.upsampling = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride)

    self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel, padding=padding1, stride=stride)
    self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel, padding=padding2, stride=1)
    self.bn1 = nn.BatchNorm2d(num_features=out_channels)
    self.bn2 = nn.BatchNorm2d(num_features=out_channels)

    nn.init.normal_(self.upsampling.weight, 0, (2/44)**0.5)
    nn.init.normal_(self.conv1.weight, 0, (2/44)**0.5)
    nn.init.normal_(self.conv2.weight, 0, (2/44)**0.5)

    nn.init.zeros_(self.upsampling.bias)
    nn.init.zeros_(self.conv1.bias)
    nn.init.zeros_(self.conv2.bias)

  def forward(self, x):
    if self.in_channels != self.out_channels:
      x_skip = self.upsampling(x)
    else:
      x_skip = x

    x = self.conv1(x)
    x = self.bn1(x)
    x = nn.functional.relu(x)
    x = self.conv2(x)
    x = x + x_skip
    x = self.bn2(x)
    x = nn.functional.relu(x)
    return x



class ResNet44(nn.Module):

  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)

    self.res_block_16_1 = ResidualBlock(in_channels=16, out_channels=16, kernel=3, padding1=1,padding2=1,stride=1)
    self.res_block_16_2 = ResidualBlock(in_channels=16, out_channels=16, kernel=3, padding1=1,padding2=1,stride=1)
    self.res_block_16_3 = ResidualBlock(in_channels=16, out_channels=16, kernel=3, padding1=1,padding2=1,stride=1)
    self.res_block_16_4 = ResidualBlock(in_channels=16, out_channels=16, kernel=3, padding1=1,padding2=1,stride=1)
    self.res_block_16_5 = ResidualBlock(in_channels=16, out_channels=16, kernel=3, padding1=1,padding2=1,stride=1)
    self.res_block_16_6 = ResidualBlock(in_channels=16, out_channels=16, kernel=3, padding1=1,padding2=1,stride=1)
    self.res_block_16_7 = ResidualBlock(in_channels=16, out_channels=16, kernel=3, padding1=1,padding2=1,stride=1)
    self.res_block_16_8 = ResidualBlock(in_channels=16, out_channels=16, kernel=3, padding1=1,padding2=1,stride=1)
    self.res_block_16_9 = ResidualBlock(in_channels=16, out_channels=16, kernel=3, padding1=1,padding2=1,stride=1)
    self.res_block_16_10 = ResidualBlock(in_channels=16, out_channels=16, kernel=3, padding1=1,padding2=1,stride=1)
    self.res_block_16_11 = ResidualBlock(in_channels=16, out_channels=16, kernel=3, padding1=1,padding2=1,stride=1)
    self.res_block_16_12 = ResidualBlock(in_channels=16, out_channels=16, kernel=3, padding1=1,padding2=1,stride=1)
    self.res_block_16_13 = ResidualBlock(in_channels=16, out_channels=16, kernel=3, padding1=1,padding2=1,stride=1)
    self.res_block_16_14 = ResidualBlock(in_channels=16, out_channels=16, kernel=3, padding1=1,padding2=1,stride=1)  
    self.res_block_32_1 = ResidualBlock(in_channels=16, out_channels=32, kernel=3, padding1=1, padding2=1, stride=2)
    self.res_block_32_2 = ResidualBlock(in_channels=32, out_channels=32, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_32_3 = ResidualBlock(in_channels=32, out_channels=32, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_32_4 = ResidualBlock(in_channels=32, out_channels=32, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_32_5 = ResidualBlock(in_channels=32, out_channels=32, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_32_6 = ResidualBlock(in_channels=32, out_channels=32, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_32_7 = ResidualBlock(in_channels=32, out_channels=32, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_32_8 = ResidualBlock(in_channels=32, out_channels=32, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_32_9 = ResidualBlock(in_channels=32, out_channels=32, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_32_10 = ResidualBlock(in_channels=32, out_channels=32, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_32_11 = ResidualBlock(in_channels=32, out_channels=32, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_32_12 = ResidualBlock(in_channels=32, out_channels=32, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_32_13 = ResidualBlock(in_channels=32, out_channels=32, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_32_14 = ResidualBlock(in_channels=32, out_channels=32, kernel=3, padding1=1, padding2=1, stride=1)  
    self.res_block_64_1 = ResidualBlock(in_channels=32, out_channels=64, kernel=3, padding1=1, padding2=1, stride=2)
    self.res_block_64_2 = ResidualBlock(in_channels=64, out_channels=64, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_64_3 = ResidualBlock(in_channels=64, out_channels=64, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_64_4 = ResidualBlock(in_channels=64, out_channels=64, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_64_5 = ResidualBlock(in_channels=64, out_channels=64, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_64_6 = ResidualBlock(in_channels=64, out_channels=64, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_64_7 = ResidualBlock(in_channels=64, out_channels=64, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_64_8 = ResidualBlock(in_channels=64, out_channels=64, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_64_9 = ResidualBlock(in_channels=64, out_channels=64, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_64_10 = ResidualBlock(in_channels=64, out_channels=64, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_64_11 = ResidualBlock(in_channels=64, out_channels=64, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_64_12 = ResidualBlock(in_channels=64, out_channels=64, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_64_13 = ResidualBlock(in_channels=64, out_channels=64, kernel=3, padding1=1, padding2=1, stride=1)
    self.res_block_64_14 = ResidualBlock(in_channels=64, out_channels=64, kernel=3, padding1=1, padding2=1, stride=1)

    self.global_pooling = nn.AvgPool2d(kernel_size=8)
    self.fc = nn.Linear(64, 10)

    nn.init.normal_(self.conv1.weight, 0, (2/44)**0.5)
    nn.init.normal_(self.fc.weight, 0, (2/44)**0.5)

    nn.init.zeros_(self.conv1.bias)
    nn.init.zeros_(self.fc.bias)


  def forward(self, x):
    x = self.conv1(x)
    x = self.res_block_16_1(x)
    x = self.res_block_16_2(x)
    x = self.res_block_16_3(x)
    x = self.res_block_16_4(x)
    x = self.res_block_16_5(x)
    x = self.res_block_16_6(x)
    x = self.res_block_16_7(x)
    x = self.res_block_16_8(x)
    x = self.res_block_16_9(x)
    x = self.res_block_16_10(x)
    x = self.res_block_16_11(x)
    x = self.res_block_16_12(x)
    x = self.res_block_16_13(x)
    x = self.res_block_16_14(x)
    x = self.res_block_32_1(x)
    x = self.res_block_32_2(x)
    x = self.res_block_32_3(x)
    x = self.res_block_32_4(x)
    x = self.res_block_32_5(x)
    x = self.res_block_32_6(x)
    x = self.res_block_32_7(x)
    x = self.res_block_32_8(x)
    x = self.res_block_32_9(x)
    x = self.res_block_32_10(x)
    x = self.res_block_32_11(x)
    x = self.res_block_32_12(x)
    x = self.res_block_32_13(x)
    x = self.res_block_32_14(x)
    x = self.res_block_64_1(x)
    x = self.res_block_64_2(x)
    x = self.res_block_64_3(x)
    x = self.res_block_64_4(x)
    x = self.res_block_64_5(x)
    x = self.res_block_64_6(x)
    x = self.res_block_64_7(x)
    x = self.res_block_64_8(x)
    x = self.res_block_64_9(x)
    x = self.res_block_64_10(x)
    x = self.res_block_64_11(x)
    x = self.res_block_64_12(x)
    x = self.res_block_64_13(x)
    x = self.res_block_64_14(x)
    x = self.global_pooling(x)
    x = torch.flatten(x, 1, -1)
    x = self.fc(x)
    return x

In [18]:
weight_decay = 0.0001
momentum = 0.9
lr = 0.1
batch_size = 128
iterations = 0
over=False

train_batches = DataLoader([*zip(X_train, y_train)], batch_size=batch_size, shuffle=True)
val_batches = DataLoader([*zip(X_val, y_val)], batch_size=batch_size, shuffle=True)
loss_fn = nn.CrossEntropyLoss()
val_loss_fn = nn.CrossEntropyLoss(reduction="sum")
model = ResNet44().to(device)
optimizer = torch.optim.SGD(params=model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)

while True:

  for batch in train_batches:
    model.train()
    iterations += 1
    optimizer.zero_grad()
    features, target = batch[:-1], batch[-1]
    features = features[0].to(device)
    target = target.to(device)
    outputs = model(features)
    perte = loss_fn(outputs, target)
    perte.backward()
    optimizer.step()

    if iterations % 100 == 0:
      model.eval()
      total_loss = 0
      for batch in val_batches:
        features, target = batch[:-1], batch[-1]
        features = features[0].to(device)
        target = target.to(device)
        outputs = model(features)
        perte = val_loss_fn(outputs, target)
        total_loss += perte.item()
      print(f"Iteration {iterations}: Loss {total_loss/X_val.shape[0]}")

    match iterations:
      case 32000:
        lr /= 10
        optimizer = torch.optim.SGD(params=model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
      case 48000:
        lr /= 10
        optimizer = torch.optim.SGD(params=model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
      case 64000:
        over=True
        break
      case _:
        continue




  if over:
    break


Iteration 100: Loss 2.316779393386841
Iteration 200: Loss 2.305696852493286
Iteration 300: Loss 2.3121582534790037
Iteration 400: Loss 2.301698243713379
Iteration 500: Loss 2.300693458557129
Iteration 600: Loss 2.2969364921569824
Iteration 700: Loss 2.2968462966918946
Iteration 800: Loss 2.2976775859832763
Iteration 900: Loss 2.3028351707458494
Iteration 1000: Loss 2.2900081520080566
Iteration 1100: Loss 2.3550639694213866
Iteration 1200: Loss 2.285883277130127
Iteration 1300: Loss 2.3302780265808107
Iteration 1400: Loss 2.276705685806274
Iteration 1500: Loss 2.2733742637634275
Iteration 1600: Loss 2.2746206394195556
Iteration 1700: Loss 2.2775066928863525
Iteration 1800: Loss 2.27186792678833
Iteration 1900: Loss 2.2749424171447754
Iteration 2000: Loss 2.2688622138977053
Iteration 2100: Loss 2.266777786254883
Iteration 2200: Loss 2.2907829906463624
Iteration 2300: Loss 2.2770531482696534
Iteration 2400: Loss 2.272322298049927
Iteration 2500: Loss 2.266738856124878
Iteration 2600: Loss

In [19]:
model.eval()
test_loader = DataLoader([*zip(test_images, test_labels)], batch_size=batch_size, shuffle=False)
correct = 0
for batch in test_loader:
  features, target = batch[:-1], batch[-1]
  features = features[0].to(device)
  target = target.to(device)
  outputs = model(features)
  correct += torch.where(torch.argmax(torch.softmax(outputs, dim=1), dim=1)==target, 1, 0).sum()
print((correct/test_images.shape[0])*100)

tensor(76.2400, device='cuda:0')
