My code along with https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

# Neural Networks

## Data

In [1]:
import sys
sys.path.append("..")

In [2]:
import torch
import torchvision
import torchvision.transforms as transforms

from helpers.stats import model_mem_size, mem_size, mem_summary

In [3]:
# Have a look at some images

import matplotlib.pyplot as plt
import numpy as np

def imshow(img):
    img = img / 2 + 0.5
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1,2,0)))
    plt.show()
    


# Data Loaders

In [4]:
# define trainload and testload?

_batch_size = 2000

tf = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        (0.5, 0.5, 0.5),
        (0.5, 0.5, 0.5),
    )
])


train_dataset = torchvision.datasets.CIFAR10(
    "./data",
    train=True,
    download=True,
    transform=tf,
)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=_batch_size,
    shuffle=True,
    num_workers=4,
)

test_dataset = torchvision.datasets.CIFAR10(
    "./data",
    train=False,
    transform=tf,
    download=True,
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    _batch_size,
    shuffle=True,
    num_workers=4,
)

# Not sure where these classes and their order actually came from
classes = (
    'plane',
    'car',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
)

Files already downloaded and verified
Files already downloaded and verified


# Model Definitions

In [5]:
from typing import List
import torch.nn as nn
import torch.nn.functional as F

_N_IMAGE_CHANNELS = 3


class ConvStack(nn.Module):
    def __init__(
        self, 
        out_channels: List[int] = None,
        kernel_sizes: List[int] = None,
    ):
        super().__init__()
        self._out_channels = out_channels or [50, 50]
        self._kernel_sizes = kernel_sizes or [5, 5]
        
        # # not including max pool kernel as we're dealing with tiny images
        
        # self._pool_kernel = 1
        # self._pool_stride = 1
        
        # self.pool = nn.MaxPool2d(
        #     kernel_size=self._pool_kernel,
        #     stride=self._pool_stride,
        # )
        
        modules = [
            nn.Conv2d(
                in_channels=_N_IMAGE_CHANNELS,
                out_channels=self._out_channels[0],
                kernel_size=self._kernel_sizes[0],
            ),
            # self.pool,
        ]
        
        for i in range(1, len(self._out_channels)):
            modules.extend([
                nn.Conv2d(
                    in_channels=self._out_channels[i - 1],
                    out_channels=self._out_channels[i],
                    kernel_size=self._kernel_sizes[i],
                ),
                # self.pool,
            ])
        self.sequential = nn.Sequential(*modules)
        
        
    def get_output_size(self, image_size: int):
        # note that it's x for x in .. as opposed to x-1 due to the max pooling
        # converted back to x-1 as I'm no longer using max pooling
        output_size = image_size - sum(x - 1 for x in self._kernel_sizes)
        return output_size * output_size * self._out_channels[-1]
        
    def forward(self, x: torch.Tensor):
        return self.sequential(x)


class Net(nn.Module):
    def __init__(self, image_size: int):
        super().__init__()
        self._image_size = image_size
        self.conv_stack = ConvStack(
            out_channels=[1000, 600, 300, 100],
            kernel_sizes=[8, 7, 6, 5],
        )

        self.fc1 = nn.Linear(
            self.conv_stack.get_output_size(self._image_size),
            10_000,
        )
        self.fc2 = nn.Linear(10_000, 10_000)
        self.fc3 = nn.Linear(10_000, len(classes))

    def forward(self, x):
        x = self.conv_stack(x)
        # import ipdb
        # ipdb.set_trace()
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

## Helper

In [6]:
from contextlib import contextmanager
import time

@contextmanager
def _log(message: str, min_time=None):
    t_start = time.time()
    if min_time is None:
        print(f"{message}")
    yield
    duration = time.time() - t_start
    if min_time is None or min_time < duration:
        mem_summary()
        print(f"{message} complete ({duration:0.2f}s)")

# Train loop

In [None]:
import torch.optim as optim
from tqdm import tqdm
import time


min_log_time = 20
lr = 0.001

device = "cuda" if torch.cuda.is_available() else "cpu"

# define net
net = Net(image_size=32)
net.to(device)

model_mem_size(net)
mem_summary()

# define optimiser
optimiser = optim.SGD(net.parameters(), lr=lr)

# define loss function
loss_function = nn.CrossEntropyLoss()


epoch = 0
n_epochs = 20

for epoch in range(n_epochs):
    t_start = time.time()
    cumulative_loss = 0
    for i, (images, labels) in tqdm(enumerate(train_loader), total=len(train_loader)):
        
        with _log(f"loading data to {device}", min_time=min_log_time):
            images = images.to(device)
            labels = labels.to(device)

        with _log("making predictions", min_time=min_log_time):
            # make predictions
            y_pred = net(images)

        with _log("calculating loss", min_time=min_log_time):
            # calculate loss
            loss = loss_function(y_pred, labels)

        with _log("zeroing grad", min_time=min_log_time):
            # zero gradients
            optimiser.zero_grad()

        with _log("running backward pass", min_time=min_log_time):
            # calculate gradients from loss
            loss.backward()

        with _log("stepping optimiser", min_time=min_log_time):
            # step the optim
            optimiser.step()

        # print(f"epoch {epoch} iter {i}: {loss}")
        cumulative_loss += loss.item()
        torch.cuda.empty_cache()

    duration = time.time() - t_start
    print(f"epoch: {epoch}, duration: {duration}, loss: {cumulative_loss / (i + 1): 0.2f}")

903.9MiB
cuda:0: 926.0MiB (5.66%)
cuda:1: 0.0B (0.00%)


100%|██████████| 25/25 [05:21<00:00, 12.85s/it]

epoch: 0, duration: 321.3566553592682, loss:  2.30



100%|██████████| 25/25 [05:21<00:00, 12.84s/it]

epoch: 1, duration: 321.06624579429626, loss:  2.30



100%|██████████| 25/25 [05:21<00:00, 12.85s/it]

epoch: 2, duration: 321.2367694377899, loss:  2.30



 28%|██▊       | 7/25 [01:30<03:51, 12.84s/it]

In [None]:
PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)

In [None]:
images = None
net = None
optimiser = None
cumulative_loss = None
loss = None

import gc
gc.collect()

torch.cuda.ipc_collect()
torch.cuda.empty_cache()

In [211]:
from dataclasses import dataclass


@dataclass
class Accuracy:
    total: int
    correct: int
    # TODO: per class

    def __repr__(self):
        return f"{self.correct/self.total*100:0.3f}%"


def _calc_accuracy(
    model: Net, 
    dataset: torch.utils.data.DataLoader,
) -> Accuracy:
    # test accuracy
    
    with torch.no_grad():
        total = 0
        correct = 0
        for images, labels in dataset:
            images = images.to(device)
            output = model(images)
            predictions = torch.max(output, 1)[1].to("cpu")
            total += labels.size(0)
            correct += (predictions == labels).sum().item()

    return Accuracy(
        total=total,
        correct=correct,
    )


In [212]:
print(_calc_accuracy(net, train_loader))
print(_calc_accuracy(net, test_loader))

98.468%
54.510%


TODO: figure out how to effectively scale a conv nn

In [154]:
net.conv1.weight.size()

torch.Size([50, 3, 5, 5])

In [156]:
conv_out = net.conv1(images)

In [157]:
images.size()

torch.Size([20, 3, 32, 32])

In [158]:
conv_out.size()

torch.Size([20, 50, 28, 28])

In [160]:
pool_out = net.pool(conv_out)
pool_out.size()

torch.Size([20, 50, 14, 14])

In [164]:
conv_out_2 = F.relu(net.conv2(pool_out))

In [165]:
pool_out_2 = net.pool(conv_out_2)

In [167]:
pool_out_2.size()

torch.Size([20, 16, 5, 5])

In [None]:

correct = 0
total = 0
for images, labels in trainloader:
    output = net(images)
    _, predicted = torch.max(output, 1)
    mask = predicted == labels
    correct += mask.sum().item()
    total += mask.size(0)

# prepare to count predictions for each class
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

# again no gradients needed
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predictions = torch.max(outputs, 1)
        # collect the correct predictions for each class
        for label, prediction in zip(labels, predictions):
            if label == prediction:
                correct_pred[classes[label]] += 1
            total_pred[classes[label]] += 1


# print accuracy for each class
for classname, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')