<a href="https://colab.research.google.com/github/clashgamer123/SOC_Pytorch/blob/main/ciphar_resnet18.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In this PyTorch notebook we use torchvision's built-in ResNet architecture, specifically ResNet18, to train and test a neural network on the CIFAR-10 dataset. <br>
CIFAR-10 is a standard dataset of 32x32 colour images, each belonging to one of 10 classes.<br>
It contains 6000 images per class (60000 in total), split into 50000 training images and 10000 test images.
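
As a quick sanity check on these numbers, the sketch below recomputes the dataset size and the number of mini-batches per epoch. The batch size of 64 is the value used later in this notebook; the variable names are illustrative only.

```python
import math

num_classes = 10
images_per_class = 6000
train_size, test_size = 50_000, 10_000

# 10 classes x 6000 images must equal the train + test split.
total_images = num_classes * images_per_class
assert total_images == train_size + test_size

batch_size = 64  # value used later in this notebook
train_batches = math.ceil(train_size / batch_size)
test_batches = math.ceil(test_size / batch_size)

# The final training batch is smaller because 50000 is not a multiple of 64.
last_train_batch = train_size - (train_batches - 1) * batch_size

print(total_images, train_batches, test_batches, last_train_batch)
# prints: 60000 782 157 16
```

The smaller final batch matters later when averaging the loss over an epoch.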

Let us start by importing the modules we need.

In [None]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
from torch.nn import functional as F

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

Training is time-consuming because of the size of the dataset, so we run on the GPU when one is available and fall back to the CPU otherwise.

In [None]:
# Configure the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Let us get our data and load it using data loaders.

In [None]:
# Data transformation using transforms.Compose. Images are resized to 224x224 and then normalised per channel
# so that the data has roughly mean 0 and std 1. This generally gives faster convergence.
# The mean and std values are commonly quoted CIFAR-10 channel statistics (taken from the internet).
transform = transforms.Compose([
            transforms.Resize((224,224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                 std=[0.2023, 0.1994, 0.2010])
                              ])

data_dir='./data'
batch_size=64

# Download the train and test datasets separately
train_dataset = datasets.CIFAR10(root=data_dir, train=True, download=True, transform=transform )

test_dataset = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform )

# Load the data using data loaders (DataLoader is already imported above).
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 48208668.87it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
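
To make the normalisation step concrete, here is a minimal sketch of the per-channel arithmetic that `transforms.Normalize` applies, together with its inverse. The random 8x8 tensor is only a stand-in for an image produced by `ToTensor()`; the names `normalized` and `restored` are illustrative.

```python
import torch

# Per-channel statistics copied from the transform above, shaped to broadcast over (C, H, W).
mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
std = torch.tensor([0.2023, 0.1994, 0.2010]).view(3, 1, 1)

torch.manual_seed(0)
image = torch.rand(3, 8, 8)  # stand-in for a ToTensor() image with values in [0, 1]

normalized = (image - mean) / std   # the per-channel operation Normalize performs
restored = normalized * std + mean  # inverse mapping, useful before plotting an image

print(normalized.shape)
print(torch.allclose(restored, image, atol=1e-5))
```

The inverse mapping is worth keeping in mind: a normalised tensor contains negative values, so it should be mapped back to the [0, 1] range before being displayed.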


Let us now load the pretrained ResNet18 model that ships with torchvision and move it to our device.

In [None]:
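# Note: torchvision 0.13+ deprecates pretrained=True in favour of weights=models.ResNet18_Weights.IMAGENET1K_V1.
model = models.resnet18(pretrained=True)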
model.to(device)
print(model)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 146MB/s]


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

We need to change out_features of the last fully connected layer to 10, since CIFAR-10 has 10 image classes.

In [None]:
model.fc = nn.Linear(model.fc.in_features, 10)
model.to(device)  # The new fc layer is created on the CPU, so move the model to the device again.
print(model)  # Check the model fc layer again.

Now that the model and the dataset are both ready, let us move on to training and testing.

In [None]:
loss_criterion = nn.CrossEntropyLoss()
# CrossEntropyLoss suits multi-class problems with several out_features. BCELoss is meant for binary targets, e.g. a single output in fc.

optimizer = optim.Adam(model.parameters(), lr=0.001)

epochs = 10 # Train for 10 epochs (full passes over the training set)
total_images = len(train_dataset) # Number of training images; len(train_loader) would be the number of batches

model.train() # Make sure layers such as BatchNorm are in training mode
for epoch in range(epochs):
  cum_loss = 0 # Keep track of the cumulative loss instead of just the loss of one batch

  for i, (images, labels) in enumerate(train_loader): # We are keeping a counter too
    images = images.to(device)
    labels = labels.to(device)

    # Standard training step: reset gradients, forward pass, loss, backward pass, parameter update
    optimizer.zero_grad()
    outputs = model(images)
    loss = loss_criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    # loss is the mean over the batch, so weight it by the actual batch size (the last batch may be smaller)
    cum_loss += images.size(0)*loss.item()

  print(f'For epoch = {epoch+1} : Loss = {cum_loss/total_images}')
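
The cumulative loss above weights each batch loss by a batch size because `nn.CrossEntropyLoss` returns the mean over the batch by default. A minimal sketch on synthetic data (6 fake samples split into uneven batches of 4 and 2, all values made up for illustration) suggests why the weighting and the divisor matter:

```python
import torch
from torch import nn

torch.manual_seed(0)
logits = torch.randn(6, 10)          # 6 fake samples, 10 classes
labels = torch.randint(0, 10, (6,))  # fake integer class labels
criterion = nn.CrossEntropyLoss()    # default reduction='mean'

# Reference value: mean loss over all 6 samples at once.
full_loss = criterion(logits, labels).item()

# Uneven batches, as happens with the last batch of an epoch.
weighted_sum, naive_sum, n_batches = 0.0, 0.0, 0
for start, end in [(0, 4), (4, 6)]:
    batch_loss = criterion(logits[start:end], labels[start:end]).item()
    weighted_sum += batch_loss * (end - start)  # weight by the actual batch size
    naive_sum += batch_loss
    n_batches += 1

weighted_avg = weighted_sum / 6      # divide by the number of samples
naive_avg = naive_sum / n_batches    # plain mean of batch means

print(full_loss, weighted_avg, naive_avg)
```

The size-weighted average divided by the number of samples reproduces the full-dataset mean, whereas the plain mean of batch means generally differs when batches are uneven.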

Now to the Testing Part.

In [None]:
model.eval()
# Stop tracking gradients.
with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:

            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1) # Get the max along dimension 1 that is along the 10 out_features.
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Clear the memory.
            del images, labels, outputs

        print(f'Accuracy of the network on the 10000 images : {(correct/total)*100}%')
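
The accuracy bookkeeping above can be traced on a tiny hand-written example. The logits and labels below are invented for illustration (4 samples, 3 classes) and are not taken from the model.

```python
import torch

# Hand-written logits for 4 samples and 3 classes (illustrative values).
outputs = torch.tensor([[2.0, 0.1, 0.3],
                        [0.2, 1.5, 0.1],
                        [0.9, 0.8, 3.0],
                        [1.2, 0.4, 0.3]])
labels = torch.tensor([0, 1, 2, 1])

# torch.max along dimension 1 returns (largest values, their indices);
# the indices are the predicted classes.
max_values, predicted = torch.max(outputs, 1)

correct = (predicted == labels).sum().item()  # number of matching predictions
total = labels.size(0)                        # number of samples in the batch

print(predicted.tolist(), correct / total)
# prints: [0, 1, 2, 0] 0.75
```

The last sample is misclassified (predicted class 0, true class 1), so 3 of 4 predictions are correct.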

Let us display one image, its predicted label and the actual label.

In [None]:
index = 10
classes = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck'
]

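test_image, label = test_dataset[index]

# The model lives on `device`, so the input must be moved there as well.
model.eval()
with torch.no_grad():
    output = model(test_image.unsqueeze(0).to(device))
predicted = output.argmax(dim=1).item()  # Plain int index into `classes`
print(f'Predicted: {classes[predicted]}')

# Undo the normalisation so the image is shown with values in [0, 1].
mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
std = torch.tensor([0.2023, 0.1994, 0.2010]).view(3, 1, 1)
display_image = (test_image * std + mean).clamp(0, 1)

plt.imshow(display_image.permute(1,2,0))  # imshow is provided by pyplot, not matplotlib.image
plt.title(f'Actual: {classes[label]} : Predicted: {classes[predicted]} ')
plt.show()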