<a href="https://colab.research.google.com/github/ell-hol/stonks-wid-codex/blob/main/simple_classification_CIFAR10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
"""
A simple classification model based on a pretrained ResNet18 backbone.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision.models import resnet18


class ResNet18(nn.Module):
    """ResNet18 classifier whose final layer emits ``n_classes`` scores.

    The backbone is built with ``resnet18(pretrained=True)``; only its ``fc``
    layer is replaced by a freshly initialised ``nn.Linear``.
    """

    def __init__(self, n_classes):
        """Build the backbone and swap its head for an ``n_classes``-way layer."""
        super().__init__()
        backbone = resnet18(pretrained=True)
        # The new head keeps the input width of the layer it replaces; only
        # the number of outputs changes.
        backbone.fc = nn.Linear(backbone.fc.in_features, n_classes)
        # NOTE: a single-channel variant was sketched here previously by
        # replacing conv1 with
        # nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False).
        self.resnet = backbone

    def forward(self, x):
        """Return the backbone's output for the input batch ``x``."""
        return self.resnet(x)

"""
Train the defined ResNet18 model on the CIFAR10 dataset and evaluate its precision, recall and accuracy
"""
from torchvision import datasets, transforms

import torch.optim as optim
from torch.autograd import Variable

from sklearn.metrics import precision_score, recall_score, accuracy_score

import os

# Build the 10-class classifier and place it on the GPU.
model = ResNet18(n_classes=10)
model.cuda()

# Training hyperparameters.
epochs = 20
batch_size = 100
learning_rate = 1e-3
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Per-channel mean / std handed to Normalize. Defined once so the train and
# test pipelines are guaranteed to normalise identically.
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

# Training pipeline: random 32x32 crop after 4-pixel padding, random
# horizontal flip, then tensor conversion and normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

# Test pipeline: no augmentation, only tensor conversion and normalisation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

# Report the real dataset sizes. (number of batches * batch_size) over-counts
# whenever the dataset size is not an exact multiple of batch_size.
print("Number of train samples: ", len(trainset))
print("Number of test samples: ", len(testset))

criterion = nn.CrossEntropyLoss()

# Counters used by the training loop for its periodic progress report.
overall_step = 0
running_loss = 0.0
running_corrects = 0

# Batches are moved to whichever device currently holds the model parameters.
device = next(model.parameters()).device

log_every = 50        # number of mini-batches between progress reports
running_samples = 0   # samples seen since the last progress report

for epoch in range(epochs):
    # Report epochs 1-based so the header agrees with the per-batch lines below.
    print("Epoch {}/{}".format(epoch + 1, epochs))
    # Make sure training-mode behaviour is active for this epoch.
    model.train()
    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # accumulate statistics for the current reporting window
        running_loss += loss.item()
        preds = outputs.argmax(dim=1)
        running_corrects += (preds == labels).sum().item()
        running_samples += labels.size(0)

        if overall_step % log_every == log_every - 1:
            # Accuracy is divided by the samples actually seen, which stays
            # correct even if a batch is smaller than batch_size.
            print('[%d, %5d] loss: %.3f, accuracy: %.3f' %
                  (epoch + 1, i + 1, running_loss / log_every, running_corrects / running_samples))
            running_loss = 0.0
            running_corrects = 0
            running_samples = 0

        overall_step += 1

print('Finished Training')

# Test model on test data.
# Batches are moved to whichever device currently holds the model parameters.
device = next(model.parameters()).device

# Switch to inference behaviour so layers such as batch normalisation use
# their stored statistics instead of per-batch ones (and stop updating them).
model.eval()

correct = 0
total = 0
labels_list = []
predicted_list = []

# A single gradient-free pass collects everything needed for both the plain
# accuracy figure and the sklearn metrics.
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        labels_list.extend(labels.cpu().tolist())
        predicted_list.extend(predicted.cpu().tolist())

print('Accuracy of the network on the %d test images: %d %%' % (
    total, 100 * correct / total))

# Macro-averaged precision and recall (unweighted mean over classes) plus
# overall accuracy.
precision = precision_score(labels_list, predicted_list, average='macro')
recall = recall_score(labels_list, predicted_list, average='macro')
accuracy = accuracy_score(labels_list, predicted_list)

print("Precision: %.3f" % precision)
print("Recall: %.3f" % recall)
print("Accuracy: %.3f" % accuracy)


# Write only the model's state dict (not the whole module object) to disk.
checkpoint_path = './cifar10_resnet18.pth'
torch.save(model.state_dict(), checkpoint_path)

Files already downloaded and verified
Files already downloaded and verified
Number of train samples:  50000
Number of test samples:  10000
Epoch 0/20


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


[1,    50] loss: 1.649, accuracy: 0.422
[1,   100] loss: 1.224, accuracy: 0.578
[1,   150] loss: 1.093, accuracy: 0.634
[1,   200] loss: 1.072, accuracy: 0.636
[1,   250] loss: 0.970, accuracy: 0.670


KeyboardInterrupt: ignored