In [32]:
from torch import nn, optim
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np

import matplotlib.pyplot as plt

In [33]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [34]:
class ResidualBlock(nn.Module):

    '''
    A residual block: two convolutions with batch normalization on the main
    path, a convolutional shortcut on the skip path, and a ReLU applied to
    their sum.

    Args:
    - in_channels (int): Number of input channels.
    - out_channels (int): Number of output channels.
    - kernel_size (int or tuple of int, optional): Size of the convolutional kernel.
      Every dimension must be odd so that the main path and the shortcut
      produce tensors of the same spatial size. Default is 3.
    - stride (int, optional): Stride of the first convolution and of the shortcut
      convolution. Default is 1.

    Attributes:
    - layers (nn.Sequential): Conv -> BatchNorm -> ReLU -> Conv -> BatchNorm (main path).
    - relu (nn.ReLU): ReLU applied after the main path and the shortcut are added.
    - adjust_conv_1 (nn.Conv2d): Shortcut convolution mapping the input to the channel
      count and spatial size of the main path output.

    Raises:
    - ValueError: If any dimension of kernel_size is even.
    '''

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):

        super().__init__()

        # Padding follows the kernel size (k // 2). A fixed padding of 1 is only
        # correct for 3x3 kernels; for any other size the second convolution would
        # change the spatial size and the residual addition would fail.
        padding = self._same_padding(kernel_size)

        self.layers = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
            nn.BatchNorm2d(num_features=out_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, padding=padding),
            nn.BatchNorm2d(num_features=out_channels),
        )

        self.relu = nn.ReLU()

        # Same kernel size, stride and padding as the first main-path convolution,
        # so the shortcut output has exactly the shape of the main-path output.
        self.adjust_conv_1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, stride=stride, kernel_size=kernel_size, padding=padding)

    @staticmethod
    def _same_padding(kernel_size):

        """
        Compute the padding that makes a stride-1 convolution size-preserving.

        Args:
        - kernel_size (int or tuple of int): Convolution kernel size.

        Returns:
        - int or tuple of int: kernel_size // 2, in the same form as the input.

        Raises:
        - ValueError: If any kernel dimension is even.
        """

        is_scalar = isinstance(kernel_size, int)
        sizes = (kernel_size,) if is_scalar else tuple(kernel_size)

        if any(size % 2 == 0 for size in sizes):
            raise ValueError(
                f"ResidualBlock requires odd kernel sizes, got kernel_size={kernel_size!r}"
            )

        halves = tuple(size // 2 for size in sizes)
        return halves[0] if is_scalar else halves

    def forward(self, X):

        """
        Forward pass through the residual block.

        Args:
        - X (torch.Tensor): Input tensor.

        Returns:
        - torch.Tensor: ReLU of the main-path output plus the shortcut output.
        """

        out = self.layers(X)
        shortcut = self.adjust_conv_1(X)

        return self.relu(out + shortcut)
    
        

In [39]:
class ResNet(nn.Module):

    '''
    A small Residual Neural Network (ResNet) for image classification, built from
    a convolutional stem, four residual blocks, average pooling and a linear head.

    The network takes 1-channel images. The linear head is sized for 224x224
    inputs: 224 -> 112 (stem) -> 56 -> 56 -> 28 -> 28 (blocks) -> 4 (7x7 average
    pooling), and 128 channels * 4 * 4 = 2048 features.

    Args:
    - num_classes (int, optional): Number of classes in the classification task. Default is 10.

    Attributes:
    - loss_log (list): Mean training loss of every epoch run by fit(), appended across calls.
    - accuracy_log (list): Training accuracy (fraction in [0, 1]) of every epoch run by fit().
    - residual_layers (nn.Sequential): Stem, residual blocks, pooling and classification layer.

    Methods:
    - forward(X): Forward pass through the ResNet.
    - fit(data, loss_func, optimizer, epochs, device, lr): Train the ResNet on the provided data.
    - evaluate(dataloader, device): Accuracy of the ResNet on the provided dataloader.
    '''

    def __init__(self, num_classes=10):

        super().__init__()

        self.loss_log = []
        self.accuracy_log = []

        self.residual_layers = nn.Sequential(
            # Stem: 1 input channel, halves the spatial size.
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=2, padding=1),

            ResidualBlock(in_channels=64, out_channels=64, kernel_size=3, stride=2),
            ResidualBlock(in_channels=64, out_channels=64, kernel_size=3, stride=1),

            ResidualBlock(in_channels=64, out_channels=128, kernel_size=3, stride=2),
            ResidualBlock(in_channels=128, out_channels=128, kernel_size=3, stride=1),

            # NOTE: deeper 256- and 512-channel stages are intentionally left out.
            # Adding them changes the size of the feature map that reaches the
            # pooling layer, so in_features of the linear layer must be updated too.

            nn.AvgPool2d(kernel_size=7),
            nn.Flatten(),
            # 2048 = 128 channels * 4 * 4 (valid for 224x224 inputs, see class docstring).
            nn.Linear(in_features=2048, out_features=num_classes),
        )

    def forward(self, X):

        """
        Forward pass through the ResNet.

        Args:
        - X (torch.Tensor): Input tensor.

        Returns:
        - torch.Tensor: Output of the final linear layer (one score per class).
        """

        return self.residual_layers(X)

    def fit(self, data, loss_func=nn.CrossEntropyLoss, optimizer=optim.Adam, epochs=10, device="cpu", lr=0.001):

        """
        Train the ResNet on the provided data.

        Args:
        - data (iterable): Yields (images, labels) batches, e.g. a torch.utils.data.DataLoader.
          It is iterated once per epoch.
        - loss_func (nn.Module or callable, optional): Either a ready loss module, or a
          zero-argument factory (such as a loss class) that is called to create one.
          Default is nn.CrossEntropyLoss.
        - optimizer (optim.Optimizer or callable, optional): Either a ready optimizer built
          from this model's parameters, or a factory called as
          optimizer(self.parameters(), lr=lr). Default is optim.Adam.
        - epochs (int, optional): Number of epochs for training. Default is 10.
        - device (str, optional): Device to use for training, 'cpu' or 'cuda'. Default is 'cpu'.
        - lr (float, optional): Learning rate passed to the optimizer factory. Ignored when
          a ready optimizer instance is given. Default is 0.001.

        Raises:
        - ValueError: If an epoch yields no batches.
        """

        self.to(device=device)

        # Accept both instances and factories; factories are the original behaviour.
        if not isinstance(loss_func, nn.Module):
            loss_func = loss_func()
        if not isinstance(optimizer, optim.Optimizer):
            optimizer = optimizer(self.parameters(), lr=lr)

        # Training loop
        for epoch in range(epochs):
            self.train()
            running_loss = 0.0
            correct_predictions = 0
            total_samples = 0
            num_batches = 0

            for images, labels in data:
                images, labels = images.to(device), labels.to(device)

                # Clear gradients before the backward pass so that gradients left
                # on the parameters from earlier work cannot leak into this step.
                optimizer.zero_grad()

                outputs = self(images)
                loss = loss_func(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                predicted = outputs.detach().argmax(dim=1)
                total_samples += labels.size(0)
                correct_predictions += (predicted == labels).sum().item()
                num_batches += 1

            if num_batches == 0:
                raise ValueError("fit() received no batches: `data` is empty")

            # Mean of the per-batch losses, and the fraction of correct predictions.
            epoch_loss = running_loss / num_batches
            epoch_accuracy = correct_predictions / total_samples

            self.loss_log.append(epoch_loss)
            self.accuracy_log.append(epoch_accuracy)

            print(f'Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss:.4f}, Accuracy: {100 * epoch_accuracy:.2f}%')

        print('Finished Training')

    def evaluate(self, dataloader, device=None):

        """
        Evaluate the accuracy of the ResNet on the provided dataloader.

        The model is moved to `device` and left in evaluation mode.

        Args:
        - dataloader (iterable): Yields (images, labels) batches, e.g. a
          torch.utils.data.DataLoader.
        - device (str, optional): Device to evaluate on. When None (default), 'cuda' is
          used if available, otherwise 'cpu'.

        Returns:
        - float: Fraction of correctly classified samples, in [0, 1].

        Raises:
        - ValueError: If the dataloader yields no samples.
        """

        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'

        correct_predictions = 0
        total_samples = 0

        self.to(device=device)
        self.eval()

        with torch.no_grad():
            for images, labels in dataloader:
                images, labels = images.to(device), labels.to(device)

                predicted = self(images).argmax(dim=1)

                total_samples += labels.shape[0]
                correct_predictions += (predicted == labels).sum().item()

        if total_samples == 0:
            raise ValueError("evaluate() received no samples: `dataloader` is empty")

        return correct_predictions / total_samples
        


In [40]:
# Preprocessing: upscale to 224 px, convert the PIL image to a tensor,
# then normalize the pixel values to the range [-1, 1]
transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# FashionMNIST training and test splits (downloaded into ./data when missing)
fashion_mnist_train, fashion_mnist_test = (
    datasets.FashionMNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Mini-batches of 128 samples; only the training split is shuffled
train_loader, test_loader = (
    torch.utils.data.DataLoader(dataset=split, batch_size=128, shuffle=reshuffle)
    for split, reshuffle in ((fashion_mnist_train, True), (fashion_mnist_test, False))
)

In [43]:
# Build a fresh, untrained network whose final linear layer outputs 10 class scores.
m = ResNet(num_classes=10)

In [44]:
# Train for 10 epochs on `device`; fit() prints the loss and accuracy of every epoch
# and appends them to m.loss_log / m.accuracy_log.
m.fit(train_loader, device=device, epochs=10)

Epoch [1/10], Loss: 0.4161, Accuracy: 84.63%
Epoch [2/10], Loss: 0.2593, Accuracy: 90.67%
Epoch [3/10], Loss: 0.2205, Accuracy: 92.09%
Epoch [4/10], Loss: 0.1947, Accuracy: 93.00%
Epoch [5/10], Loss: 0.1777, Accuracy: 93.45%
Epoch [6/10], Loss: 0.1610, Accuracy: 94.14%
Epoch [7/10], Loss: 0.1491, Accuracy: 94.62%
Epoch [8/10], Loss: 0.1339, Accuracy: 95.22%
Epoch [9/10], Loss: 0.1215, Accuracy: 95.70%
Epoch [10/10], Loss: 0.1060, Accuracy: 96.21%
Finished Training


In [47]:
# Uncomment to write the checkpoint file that the next cell loads with torch.load:
# torch.save(m, "./model_resnet.pth")

In [48]:
# "./model_resnet.pth" is expected to hold a fully pickled model object (as written
# by torch.save(m, ...)), not a state_dict. Weights-only loading, the default since
# PyTorch 2.6, rejects such files, so it is disabled explicitly here.
# SECURITY: unpickling can execute arbitrary code - only load checkpoint files
# that you created yourself.
model = torch.load("./model_resnet.pth", weights_only=False)

In [49]:
# Accuracy (fraction in [0, 1]) of the loaded checkpoint `model` - not the in-memory `m` -
# on the training split.
model.evaluate(train_loader)

0.9611166666666666

In [50]:
# Accuracy (fraction in [0, 1]) of the loaded checkpoint `model` on the held-out test split.
model.evaluate(test_loader)

0.9221