**import libraries**

In [27]:
import torch
import torch.nn as nn
import torch.optim as optim
import math
from torchvision import datasets, transforms
from torch.autograd import Function
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
import sklearn.cluster

In [28]:
use_gpu = True
device = torch.device("cuda:0" if torch.cuda.is_available() and use_gpu else "cpu")
device

device(type='cuda', index=0)

**custom linear layer**

In [29]:
class LinearFunction(Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t()) + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        #grad_output -> dLoss/dy_hat
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias

In [30]:
class MyLinearLayer(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size

        self.weight = nn.Parameter(torch.randn(output_size, input_size))
        self.bias = nn.Parameter(torch.randn(output_size))

    def forward(self, input):
        return LinearFunction.apply(input, self.weight, self.bias)


**MSE Loss**

In [31]:
class MSELossFunction(Function):
    @staticmethod
    def forward(ctx, y_pred, y):
        y = y.view(y.shape[0], -1)
        ctx.save_for_backward(y_pred, y)
        loss = ( (y - y_pred)**2 ).mean()

        return  loss

    @staticmethod
    def backward(ctx, grad_output):
        y_pred, y = ctx.saved_tensors
        grad_input = 2 * (y_pred - y) / y_pred.shape[0]
        return grad_input, None


In [32]:
class MSELoss(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, input, target):
        return MSELossFunction.apply(input, target)


**custom cross entropy loss**

In [33]:
class CrossEntropyLossFunction(Function):
    @staticmethod
    def forward(ctx, output, target):
        output_softmax = F.log_softmax(output, dim=1)

        one_hot_labels = torch.zeros_like(output_softmax)
        one_hot_labels.scatter_(1, target.view(-1, 1), 1)

        ctx.save_for_backward(output_softmax, one_hot_labels)

        loss = torch.sum(-one_hot_labels * output_softmax, dim=1).mean()

        return loss

    @staticmethod
    def backward(ctx, grad_output):
        output, target = ctx.saved_tensors

        grad_input = (F.softmax(output, dim=1) - target)/output.shape[0]

        return grad_input, None


In [34]:
class CrossEntropyLoss(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, input, target):
        return CrossEntropyLossFunction.apply(input, target)


**model with one layer**

In [35]:
class BasicModel(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.layer = nn.Linear(input_size, output_size)

    def forward(self, x):
        x = x.view(-1, 784)
        return self.layer(x)


**dataLoader**

In [36]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

**model and loss function**

In [37]:
input_size = 784
output_size = 10

model = BasicModel(input_size, output_size)
criterion = CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

model = model.to(device)

**train function**

In [38]:
def train(model, dataloader, loss_fn, optimizer, num_epochs=50):
    for epoch in range(num_epochs):
        model.train()
        num_samples = len(dataloader.dataset)
        num_batches = len(dataloader)
        running_corrects = 0
        running_loss = 0.0
        for index, (inputs, labels) in enumerate(dataloader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            _, preds = torch.max(outputs, dim=1)
            running_corrects += torch.sum(preds == labels)
            running_loss += loss.item()

        epoch_loss = (running_loss / num_batches)
        epoch_acc = (running_corrects / num_samples) * 100
        print(f"epoch {epoch+1} -> Loss: {epoch_loss}, accuracy: {epoch_acc}")

In [39]:
train(model, train_loader, criterion, optimizer)

epoch 1 -> Loss: 1.3391249592997816, accuracy: 85.67166900634766
epoch 2 -> Loss: 1.244700111146929, accuracy: 87.70000457763672
epoch 3 -> Loss: 1.2359495929190154, accuracy: 87.99833679199219
epoch 4 -> Loss: 1.2293428018657384, accuracy: 88.1500015258789
epoch 5 -> Loss: 1.2511449682353528, accuracy: 88.31666564941406
epoch 6 -> Loss: 1.2114849072466018, accuracy: 88.60000610351562
epoch 7 -> Loss: 1.2187365383378415, accuracy: 88.59833526611328
epoch 8 -> Loss: 1.1917533148040396, accuracy: 88.72833251953125
epoch 9 -> Loss: 1.1643823875563104, accuracy: 88.84500122070312
epoch 10 -> Loss: 1.1817265916774586, accuracy: 88.70166778564453
epoch 11 -> Loss: 1.1429401946061455, accuracy: 89.05833435058594
epoch 12 -> Loss: 1.1819109857908443, accuracy: 88.91666412353516
epoch 13 -> Loss: 1.157093072929648, accuracy: 89.01167297363281
epoch 14 -> Loss: 1.153664670836951, accuracy: 88.99166870117188
epoch 15 -> Loss: 1.2006137848361087, accuracy: 88.92166900634766
epoch 16 -> Loss: 1.214

**test function**

In [40]:
def test(model, dataloader):
    model.eval()
    correct = 0
    total_correct = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs.data, 1)
            total_correct += labels.size(0)
            correct += (preds == labels).sum().item()

    print(f"Accuracy: {(correct / total_correct) * 100}%")


In [41]:
test(model, test_loader)

Accuracy: 88.92999999999999%
