In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch.nn as nn
import numpy as np
from torch import Tensor
from typing import Type

In [2]:
# Data transforms (normalization & data augmentation)
# stats = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
# train_tfms = transforms.Compose([transforms.Resize((64, 64)),
#                          transforms.RandomCrop(64, padding=4, padding_mode='reflect'),
#                          transforms.RandomHorizontalFlip(),
#                          transforms.ToTensor(),
#                          transforms.Normalize(*stats,inplace=True)])
# valid_tfms = transforms.Compose([transforms.Resize((64, 64)), transforms.ToTensor(), transforms.Normalize(*stats)])

In [3]:
train_tfms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Define transformations for test data
valid_tfms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

In [4]:
traindata = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_tfms)
validationdata = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=valid_tfms)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:05<00:00, 29184104.95it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [5]:
len(traindata)

50000

In [6]:
len(validationdata)

10000

In [7]:
image, label = traindata[0]
print(image.shape, label)

torch.Size([3, 32, 32]) 6


In [8]:
trainDataLoder = torch.utils.data.DataLoader(traindata, batch_size=128, shuffle=True)
validationDataLoder = torch.utils.data.DataLoader(validationdata, batch_size=128, shuffle=False)

##Custom ResNet

In [9]:
def conv_block(in_channels, out_channels, pool=False):
    layers = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
              nn.BatchNorm2d(out_channels),
              nn.ReLU(inplace=True)]
    if pool: layers.append(nn.MaxPool2d(2))
    return nn.Sequential(*layers)

class ResNet9(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()

        self.conv1 = conv_block(in_channels, 64)
        self.res1 = nn.Sequential(conv_block(64, 64), conv_block(64, 64))

        self.conv2 = conv_block(64, 128, pool=True)
        self.res2 = nn.Sequential(conv_block(128, 128), conv_block(128, 128))
        self.res2_1 = nn.Sequential(conv_block(128, 128), conv_block(128, 128))
        
        self.conv3 = conv_block(128, 256, pool=True)
        self.res3 = nn.Sequential(conv_block(256, 256), conv_block(256, 256))
        self.res3_1 = nn.Sequential(conv_block(256, 256), conv_block(256, 256))
        
        self.conv4 = conv_block(256, 512, pool=True)
        
        self.classifier = nn.Sequential(nn.MaxPool2d(4),
                                        nn.Flatten(),
#                                         nn.Dropout(0.5),
                                        nn.Linear(512, num_classes)
                                       )

    def forward(self, xb):
        out = self.conv1(xb)
        out = self.res1(out) + out
        out = self.conv2(out)
        out = self.res2(out) + out
        out = self.res2_1(out) + out
        out = self.conv3(out)
        out = self.res3(out) + out
        out = self.res3_1(out) + out
        out = self.conv4(out)
        out = self.classifier(out)
        return out

In [10]:
model = ResNet9(3,10).cuda()
print(model)

ResNet9(
  (conv1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (res1): Sequential(
    (0): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (1): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
  )
  (conv2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=Fal

In [11]:
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} training parameters.")

4,585,866 total parameters.
4,585,866 training parameters.


In [12]:
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [13]:
def calculate_accuracy(outputs, labels):
    _, preds = torch.max(outputs, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))

In [14]:
train_loss_history = []
train_acc_history = []
validation_loss_history = []
validation_acc_history = []

for epoch in range(60):
  train_loss = 0.0
  train_acc = 0.0
  validation_loss = 0.0
  validation_acc = 0.0

  model.train()
  for i, data in enumerate(trainDataLoder):
    images, labels = data
    images = images.cuda()
    labels = labels.cuda()
    optimizer.zero_grad()
    predicted_labels = model(images) ## forward pass
    fit = loss(predicted_labels, labels)  ## calculate loss
    fit.backward() ## backprop
    
#     nn.utils.clip_grad_value_(model.parameters(), 0.1)
    
    optimizer.step()  ## calculate gradients and updates weights
    train_loss += fit.item()
    train_acc += calculate_accuracy(predicted_labels, labels)

  model.eval()
  for i, data in enumerate(validationDataLoder):
    with torch.no_grad():
      images, labels = data
      images = images.cuda()
      labels = labels.cuda()
      predicted_labels = model(images)
      fit = loss(predicted_labels, labels)
      validation_loss += fit.item()
      validation_acc += calculate_accuracy(predicted_labels, labels)

  train_loss = train_loss/len(trainDataLoder)
  validation_loss = validation_loss/len(validationDataLoder)
  train_acc = train_acc/len(trainDataLoder)
  validation_acc = validation_acc/len(validationDataLoder)

  train_loss_history.append(train_loss)
  validation_loss_history.append(validation_loss)
  train_acc_history.append(train_acc)
  validation_acc_history.append(validation_acc)

  print(f'Epochs {epoch}:')
  print(f'        Train loss: {train_loss}, Validation loss: {validation_loss}')
  print(f'        Train Acc: {train_acc}, Validation Acc: {validation_acc}')

  if (epoch+1)%20==0:
     torch.save(model.state_dict(), f'model_{epoch+1}.pt')

Epochs 0:
        Train loss: 1.3565686293270276, Validation loss: 1.1431117955642411
        Train Acc: 0.5070252418518066, Validation Acc: 0.6129351258277893
Epochs 1:
        Train loss: 0.8398220162562398, Validation loss: 0.7864032199111166
        Train Acc: 0.7048273682594299, Validation Acc: 0.7257713675498962
Epochs 2:
        Train loss: 0.655611978162585, Validation loss: 0.6199331196803081
        Train Acc: 0.7715272903442383, Validation Acc: 0.787282407283783
Epochs 3:
        Train loss: 0.5421166122721894, Validation loss: 0.5942599837538562
        Train Acc: 0.8116728067398071, Validation Acc: 0.798061728477478
Epochs 4:
        Train loss: 0.4848165983891548, Validation loss: 0.4519854278504094
        Train Acc: 0.8323649168014526, Validation Acc: 0.8450356125831604
Epochs 5:
        Train loss: 0.42571524181939147, Validation loss: 0.48802637638924995
        Train Acc: 0.8512507677078247, Validation Acc: 0.8354430198669434
Epochs 6:
        Train loss: 0.384131200

In [15]:
torch.save(model.state_dict(), 'model.pt')