# Deeper Networks for Image Classification


## ResNet50 Model for Image Classification

- Code by: Kaviraj Gosaye
- Student ID: 220575371

### 0. Imports

In [None]:
# import libraries
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
from torchinfo import summary
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_score, recall_score
import seaborn as sns

### 1. Data Loading and Preprocessing

In [None]:
# Preprocessing applied to every MNIST sample: PIL image -> normalized tensor
transform = transforms.Compose([
    # replicate the grey channel so the network receives 3 input channels
    transforms.Grayscale(num_output_channels=3),
    # resize to the 224 x 224 resolution fed to the network
    transforms.Resize(size=(224, 224)),
    # PIL image -> float tensor
    transforms.ToTensor(),
    # per channel: (x - 0.5) / 0.5
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# MNIST train / test splits (downloaded into ./datasets on first use)
train_set = torchvision.datasets.MNIST(
    root='./datasets', train=True, download=True, transform=transform)
test_set = torchvision.datasets.MNIST(
    root='./datasets', train=False, download=True, transform=transform)

# mini-batch iterators: only the training data is shuffled
train_loader = torch.utils.data.DataLoader(
    dataset=train_set, batch_size=64, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(
    dataset=test_set, batch_size=64, shuffle=False, num_workers=8)

In [None]:
# pull a single mini-batch out of the training loader
dataiter = iter(train_loader)
images, labels = next(dataiter)


def imshow(img):
    """Draw a normalized C x H x W image tensor with matplotlib."""
    # undo Normalize(mean=0.5, std=0.5): x * 0.5 + 0.5
    restored = 0.5 + img / 2
    # torch stores images as C x H x W, matplotlib expects H x W x C
    channels_last = np.transpose(restored.numpy(), axes=(1, 2, 0))
    plt.imshow(channels_last)


# tile the whole batch into one grid image and display it
imshow(torchvision.utils.make_grid(images))

### 2. Model Building

In [None]:
class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand, plus a shortcut.

    Args:
        in_channels: number of channels of the incoming feature map.
        intermediate_channels: number of channels used by the first 1x1 and the
            3x3 convolution.
        expansion: the block outputs ``intermediate_channels * expansion`` channels.
        stride: stride of the 3x3 convolution (and of the projection shortcut).

    The shortcut is the plain input (identity) only when the residual branch
    keeps both the channel count and the spatial size; otherwise a strided
    1x1 convolution + batch norm projects the input to the output shape.
    """

    # self, input channels, number of channels for 3x3 conv, expansion factor, stride
    def __init__(self, in_channels, intermediate_channels, expansion, stride):

        super().__init__()

        self.expansion = expansion
        self.in_channels = in_channels
        self.intermediate_channels = intermediate_channels

        out_channels = self.intermediate_channels * self.expansion

        # identity mapping is only possible when the residual branch changes
        # neither the channel count nor the spatial resolution; a strided block
        # with matching channels would otherwise add tensors of different sizes
        if self.in_channels == out_channels and stride == 1:
            self.identity = True
        # else projection mapping is required
        else:
            self.identity = False
            self.projection = nn.Sequential(
                nn.Conv2d(in_channels=self.in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=0, bias=False),
                nn.BatchNorm2d(out_channels),
            )

        # stateless, so one instance is shared by all three activations
        self.relu = nn.ReLU()

        # 1x1: reduce channels
        self.conv1_1x1 = nn.Conv2d(in_channels=self.in_channels, out_channels=self.intermediate_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.batchnorm1 = nn.BatchNorm2d(self.intermediate_channels)

        # 3x3: carries the stride of the block
        self.conv2_3x3 = nn.Conv2d(in_channels=self.intermediate_channels, out_channels=self.intermediate_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.batchnorm2 = nn.BatchNorm2d(self.intermediate_channels)

        # 1x1: expand channels
        self.conv3_1x1 = nn.Conv2d(in_channels=self.intermediate_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.batchnorm3 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Return relu(residual(x) + shortcut(x))."""
        # store input for the shortcut connection
        in_x = x

        x = self.relu(self.batchnorm1(self.conv1_1x1(x)))

        x = self.relu(self.batchnorm2(self.conv2_3x3(x)))

        # no activation here: it is applied after the shortcut is added
        x = self.batchnorm3(self.conv3_1x1(x))

        # identity or projected mapping
        shortcut = in_x if self.identity else self.projection(in_x)
        x += shortcut

        # final relu
        x = self.relu(x)

        return x

In [None]:
class ResNet50(nn.Module):
    """Bottleneck ResNet: 7x7 stem, four bottleneck stages, pooled linear head.

    Args:
        resnet_channels: ``(channels_list, repeatition_list, expansion)`` --
            intermediate channels per stage, number of bottlenecks per stage,
            and the expansion factor shared by all bottlenecks.
        in_channels: number of channels of the input images.
        num_classes: size of the output layer.
    """

    def __init__(self, resnet_channels, in_channels, num_classes):
        super().__init__()

        self.channels_list = resnet_channels[0]
        self.repeatition_list = resnet_channels[1]
        self.expansion = resnet_channels[2]

        widths = self.channels_list
        repeats = self.repeatition_list
        exp = self.expansion

        # stem: strided 7x7 convolution followed by a strided max-pool
        self.conv1 = nn.Conv2d(
            in_channels=in_channels, out_channels=64,
            kernel_size=7, stride=2, padding=3, bias=False,
        )
        self.batchnorm1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # four stages; each one starts from the expanded output of the previous
        # stage, and every stage but the first uses stride 2 in its first block
        self.block1 = self.create_blocks(64, widths[0], repeats[0], exp, stride=1)
        self.block2 = self.create_blocks(widths[0] * exp, widths[1], repeats[1], exp, stride=2)
        self.block3 = self.create_blocks(widths[1] * exp, widths[2], repeats[2], exp, stride=2)
        self.block4 = self.create_blocks(widths[2] * exp, widths[3], repeats[3], exp, stride=2)

        # head: global average pooling + linear classifier
        self.average_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(widths[3] * exp, num_classes)

    def forward(self, x):
        """Run the stem, the four stages and the classifier head on ``x``."""
        x = self.maxpool(self.relu(self.batchnorm1(self.conv1(x))))

        for stage in (self.block1, self.block2, self.block3, self.block4):
            x = stage(x)

        # pool each feature map to 1x1, then drop the spatial dimensions
        x = torch.flatten(self.average_pool(x), start_dim=1)

        return self.fc1(x)

    def create_blocks(self, in_channels, intermediate_channels, num_repeat, expansion, stride):
        """Build one stage: a first bottleneck with ``stride``, then ``num_repeat - 1`` stride-1 ones."""
        stage_width = intermediate_channels * expansion

        # only the first bottleneck sees the stage input and the stage stride
        first = Bottleneck(in_channels, intermediate_channels, expansion, stride=stride)
        # the remaining ones operate on the expanded width at stride 1
        rest = [
            Bottleneck(stage_width, intermediate_channels, expansion, stride=1)
            for _ in range(1, num_repeat)
        ]

        return nn.Sequential(first, *rest)

In [None]:
# run on the GPU when torch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# creating instance of model and setting it to the device
# (intermediate channels per stage, bottlenecks per stage, expansion factor)
architecture = ([64,128,256,512], [3,4,6,3], 4)
# MNIST has 10 digit classes, so the classifier needs 10 outputs; a 1000-way
# head wastes parameters and lets the model predict classes that do not exist
resnet_50 = ResNet50(architecture , in_channels=3, num_classes=10).to(device)

In [None]:
# print a per-layer overview of the network for a dummy input of size (3, 3, 224, 224)
info = summary(
    resnet_50,
    input_size=(3, 3, 224, 224),
    col_names=('input_size', 'output_size', 'num_params', 'kernel_size'),
)
print(info)

### 3. Model Training

In [None]:
# loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(resnet_50.parameters(), lr=0.001)

# training the model
start = time.time()

num_epochs = 1
losses = []      # mean training loss per sample, one entry per epoch
train_accs = []  # training accuracy in percent, one entry per epoch

for epoch in range(num_epochs):
    # make sure BatchNorm uses batch statistics, even when this cell is
    # re-run after the evaluation cell switched the model to eval mode
    resnet_50.train()

    running_loss = 0.0  # loss summed over the current window of 100 mini-batches
    epoch_loss = 0.0    # per-sample loss summed over the whole epoch
    correct = 0         # correctly classified training samples in this epoch
    total = 0           # training samples seen in this epoch

    for i, data in enumerate(train_loader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()

        outputs = resnet_50(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        batch_loss = loss.item()

        # cumulative loss
        running_loss += batch_loss
        # criterion returns the batch mean, so weight it by the batch size
        epoch_loss += batch_loss * batch_size

        # count correct predictions over every batch, not only the last one
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += batch_size

        # printing the average loss every 100 mini-batches
        if i % 100 == 99:
            print(f"[{epoch + 1}, {i + 1}] loss: {running_loss / 100}")
            running_loss = 0.0

    # epoch-level metrics; skipped when the loader yielded no data
    if total:
        train_accs.append(100. * correct / total)
        losses.append(epoch_loss / total)

print(f"Finished Training after {time.time() - start} seconds")

In [None]:
# save model after training
import os

# torch.save does not create missing parent directories
os.makedirs("./Models", exist_ok=True)
# NOTE(review): this pickles the whole module object; loading it requires the
# same class definitions to be importable. Saving resnet_50.state_dict() is the
# more portable option if the loading code can be adapted.
torch.save(resnet_50, "./Models/resnet_50_mnist.pth")

### 4. Model Evaluation

In [None]:
# Evaluation mode: BatchNorm layers use their running statistics
resnet_50.eval()

correct = 0
total = 0
predicted_labels = []
true_labels = []
test_loss = 0.0

# Disable gradient calculation
with torch.no_grad():
    # Using test set
    for data in test_loader:
        images, labels = data[0].to(device), data[1].to(device)
        batch_size = labels.size(0)

        # Forward pass
        outputs = resnet_50(images)

        # criterion returns the batch mean; weight it by the batch size so a
        # shorter final batch does not count as much as a full one
        loss = criterion(outputs, labels)
        test_loss += loss.item() * batch_size

        # Get the predicted labels (index of the largest logit)
        predicted = outputs.argmax(dim=1)

        # Update the total and correct predictions
        total += batch_size
        correct += (predicted == labels).sum().item()

        # Append the predicted and true labels as plain Python ints
        predicted_labels.extend(predicted.tolist())
        true_labels.extend(labels.tolist())

# Calculate the accuracy and the mean loss per sample
# (guarded so an empty test loader does not divide by zero)
accuracy = 100 * correct / total if total else 0.0
test_loss = test_loss / total if total else 0.0

# Print the accuracy and test loss
print(f"Accuracy on the test data: {accuracy}%")
print(f"Test Loss: {test_loss}")

# Create the confusion matrix
cm = confusion_matrix(true_labels, predicted_labels)

# Plot the confusion matrix
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
plt.xlabel("Predicted Labels")
plt.ylabel("True Labels")
plt.title("Confusion Matrix")
plt.show()

In [None]:
# Draw the recorded training loss, one point per epoch
plt.plot(losses)
plt.gca().set(xlabel='Epoch', ylabel='Loss', title='Training Loss')
plt.show()

In [None]:
# Draw the recorded training accuracy, one point per epoch
plt.plot(train_accs)
plt.gca().set(xlabel='Epoch', ylabel='Accuracy', title='Training Accuracy')
plt.show()

In [None]:
# Tensor copies of the label lists collected during evaluation
true_labels_tensor = torch.tensor(true_labels)
pred_labels_tensor = torch.tensor(predicted_labels)

# Precision and recall, averaged over classes with average='weighted'
precision = precision_score(
    true_labels_tensor, pred_labels_tensor, average='weighted')
recall = recall_score(
    true_labels_tensor, pred_labels_tensor, average='weighted')

print("Precision:", precision)
print("Recall:", recall)