In [87]:
import os
import random
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

In [88]:
class ResidualBlock(nn.Module):
    """Basic residual block: conv3x3-BN-ReLU-conv3x3-BN, add shortcut, then ReLU.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by both 3x3 convolutions.
        stride: stride of the first 3x3 convolution and of the projection
            shortcut (a stride of 2 halves the spatial resolution).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # Main path. bias=False: every convolution feeds straight into a
        # BatchNorm layer, which applies its own learned shift.
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=(3, 3), stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=(3, 3), stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Shortcut: identity (empty Sequential) when the main path keeps the
        # shape; otherwise a strided 1x1 conv + BN projection so both
        # branches can be summed element-wise.
        shape_changes = stride != 1 or in_channels != out_channels
        if shape_changes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels,
                          kernel_size=(1, 1), stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ReLU(main_path(x) + shortcut(x))."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(hidden))
        return F.relu(main + self.shortcut(x))

In [89]:
class ResNet(nn.Module):
    """ResNet-18-style classifier built from ``ResidualBlock`` stages.

    Layout: a 3x3 stride-1 stem conv + BN + ReLU, four stages of two
    residual blocks each (64, 128, 256, 512 channels; stages 2-4 use
    stride 2), global average pooling, and a linear classifier.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Initial input conv: stride 1 and no pooling, so the stem keeps the
        # input's spatial resolution.
        self.conv1 = nn.Conv2d(in_channels=3,
                               out_channels=64,
                               kernel_size=(3, 3),
                               stride=1,
                               padding=1,
                               bias=False)

        self.bn1 = nn.BatchNorm2d(64)

        # Create blocks
        self.block1 = self._create_block(64, 64, stride=1)
        self.block2 = self._create_block(64, 128, stride=2)
        self.block3 = self._create_block(128, 256, stride=2)
        self.block4 = self._create_block(256, 512, stride=2)
        self.linear = nn.Linear(512, num_classes)

    # A block is just two residual blocks for ResNet18
    def _create_block(self, in_channels, out_channels, stride):
        """Return a stage of two ResidualBlocks; only the first one may downsample."""
        return nn.Sequential(ResidualBlock(in_channels, out_channels, stride),
                             ResidualBlock(out_channels, out_channels, 1))

    def forward(self, x):
        """Map a batch ``x`` of shape [N, 3, H, W] to logits of shape [N, num_classes]."""
        # Functional ReLU/pooling avoids constructing new nn.Modules on every call.
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.block4(out)
        # Global average pool to [N, 512, 1, 1]. For 32x32 inputs the final map
        # is 4x4, so this matches the former AvgPool2d(4); unlike a fixed 4x4
        # kernel it also works for other input sizes.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        return self.linear(out)

In [91]:
# `labels` holds the whitespace-separated class names from labels.txt in file
# order; `label_mapping` sends each name to its position in that list.
with open("../data/cifar/labels.txt") as label_file:
    labels = label_file.read().split()
label_mapping = {name: index for index, name in enumerate(labels)}

In [92]:
def preprocess(image):
    """Convert an image to a channels-first float32 NumPy array.

    Args:
        image: a ``PIL.Image`` or any array-like accepted by ``np.array``,
            laid out as [H, W, C].

    Returns:
        A new ``np.ndarray`` of dtype float32 and shape [C, H, W]. Values
        are only cast, not rescaled or normalized; any normalization is
        left to the caller.
    """
    # np.array (not asarray) always copies, so the result never aliases the input.
    array = np.array(image, dtype=np.float32)

    # PIL/NumPy images are [H, W, C]; PyTorch convolutions expect [C, H, W].
    return array.transpose(2, 0, 1)

In [93]:
class Cifar10Dataset(Dataset):
    """Image-classification dataset backed by a flat directory of image files.

    Each sample's class name is the text after the last ``_`` in the file's
    base name (extension removed). It is mapped to an integer index through
    the module-level ``label_mapping``.

    Args:
        data_dir: directory whose entries (everything ``os.listdir`` returns)
            are the sample image files.
        transform: optional torchvision-style callable. If given, it receives
            the RGB ``PIL.Image`` (the torchvision convention, so e.g.
            ``ToTensor`` yields a [C, H, W] tensor scaled to [0, 1]) and its
            result is returned as the image. If ``None``, the image is
            returned as the [C, H, W] float32 array from ``preprocess``.
    """

    def __init__(self, data_dir, transform=None):
        # Sorted so the index -> file mapping is reproducible across runs;
        # shuffling is the DataLoader's job (shuffle=True).
        self.files = sorted(os.path.join(data_dir, name)
                            for name in os.listdir(data_dir))
        self.data_size = len(self.files)
        self.transform = transform

    def __len__(self):
        return self.data_size

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th file."""
        image_path = self.files[idx]

        # convert() returns a fully loaded in-memory copy, so the file handle
        # can be closed immediately; "RGB" guarantees three channels.
        with Image.open(image_path) as opened:
            pil_image = opened.convert("RGB")

        if self.transform is not None:
            # Hand the PIL image to the transform: ToTensor expects PIL/HWC
            # input, and feeding it the CHW array from preprocess made it
            # transpose twice (and skip 0-255 -> 0-1 scaling).
            image = self.transform(pil_image)
        else:
            image = preprocess(pil_image)

        # Label = text after the last "_" in the file name, extension removed.
        stem = os.path.splitext(os.path.basename(image_path))[0]
        label_name = stem.rsplit("_", 1)[-1]
        label = label_mapping[label_name]

        return image, label

In [95]:
EPOCHS = 1
BATCH_SIZE = 512
LR = 0.01
WEIGHT_DECAY = 5e-4
SEED = 0
DATA_DIR = "../data/cifar"

# Seed Python's `random`, NumPy and torch so dataset ordering, weight
# initialisation and DataLoader shuffling are reproducible across runs.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# ToTensor expects a PIL image or an HWC array (uint8 input is scaled to
# [0, 1]); Normalize with mean = std = 0.5 then maps [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

training_dataset = Cifar10Dataset(data_dir=os.path.join(DATA_DIR, "train"),
                                  transform=transform)

trainloader = DataLoader(training_dataset,
                         batch_size=BATCH_SIZE,
                         shuffle=True,
                         num_workers=0)

test_dataset = Cifar10Dataset(data_dir=os.path.join(DATA_DIR, "test"),
                              transform=transform)

# Evaluation does not need shuffling; a fixed order keeps runs comparable.
testloader = DataLoader(test_dataset,
                        batch_size=BATCH_SIZE,
                        shuffle=False,
                        num_workers=0)

device = "cuda" if torch.cuda.is_available() else "cpu"
net = ResNet().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(),
                      lr=LR,
                      momentum=0.9,
                      weight_decay=WEIGHT_DECAY)

50000
