In [9]:
import random
import numpy as np
from PIL import Image
import torch
from torch import nn
from datasets import load_dataset
from helpers import get_device, train, evaluate

In [10]:
torch.manual_seed(0)
device = get_device()

In [11]:
def preprocess(x, eps=1e-7):
    mean = np.array([0.4914, 0.4822, 0.4465]).reshape(1, 1, 1, -1)
    std = np.array([0.2023, 0.1994, 0.2010]).reshape(1, 1, 1, -1)
    x = x.astype(np.float32)
    x = (x - mean) / (std + eps)
    return x


def random_crop(image, crop_size=30):
    h, w, _ = image.shape
    x = random.randint(0, w - crop_size)
    y = random.randint(0, h - crop_size)
    return image[x:x+crop_size, y:y+crop_size, :]


def random_flip(image, flip_prob=0.5):
    if random.random() < flip_prob:
        image = np.fliplr(image)
    return image


def transform(x):
    x = np.array([random_crop(y) for y in x]).transpose(0, 3, 1, 2)
    x = ((Image.fromarray(z).resize((224, 224)) for z in y) for y in x)
    x = np.stack([random_flip(np.stack([np.asarray(z) for z in y], axis=0)) for y in x], axis=0)
    return x


def target_transform(x):
    x = x.transpose(0, 3, 1, 2)
    x = ((Image.fromarray(z).resize((224, 224)) for z in y) for y in x)
    x = np.stack([np.stack([np.asarray(z) for z in y], axis=0) for y in x], axis=0)
    return x

In [12]:
dataset = load_dataset("cifar10")

X_train = preprocess(np.array([np.array(image) for image in dataset["train"]["img"]]))
Y_train = np.array(dataset["train"]["label"], dtype=np.int32)

X_test = preprocess(np.array([np.array(image) for image in dataset["test"]["img"]]))
Y_test = np.array(dataset["test"]["label"], dtype=np.int32)

In [13]:
class BasicBlock(nn.Module):
    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample

    def __call__(self, x):
        out = self.bn1(self.conv1(x)).relu()
        out = self.bn2(self.conv2(out))
        if self.downsample is not None:
            x = self.downsample(x)
        out = (out + x).relu()
        return out

In [14]:
class ResNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
        self.layer1 = self._make_layer(64, 3, stride=1)
        self.layer2 = self._make_layer(128, 4, stride=2)
        self.layer3 = self._make_layer(256, 6, stride=2)
        self.layer4 = self._make_layer(512, 3, stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512, 10)

    def _make_layer(self, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes, 1, stride=stride, bias=False),
                nn.BatchNorm2d(planes),
            )
        layers = [BasicBlock(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes
        for _ in range(1, blocks):
            layers.append(BasicBlock(self.inplanes, planes))
        return nn.Sequential(*layers)

    def __call__(self, x):
        x = self.bn1(self.conv1(x)).relu()
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fc(x)
        return x

In [15]:
epochs = 10
batch_size = 64
learning_rate = 0.005
momentum = 0.9
steps = len(X_train) // batch_size

In [16]:
model = ResNet().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=learning_rate/epochs, momentum=momentum)
for _ in range(epochs):
    train(model, X_train, Y_train, optimizer, steps, batch_size=batch_size, transform=transform, device=device)
    evaluate(model, X_test, Y_test, target_transform=target_transform, device=device)

loss 1.21 accuracy 0.58: 100%|██████████| 781/781 [08:44<00:00,  1.49it/s]
100%|██████████| 79/79 [00:42<00:00,  1.84it/s]


test set accuracy is 0.5589


loss 0.88 accuracy 0.69: 100%|██████████| 781/781 [09:14<00:00,  1.41it/s]
100%|██████████| 79/79 [00:48<00:00,  1.63it/s]


test set accuracy is 0.5689


loss 0.77 accuracy 0.73: 100%|██████████| 781/781 [09:18<00:00,  1.40it/s]
100%|██████████| 79/79 [00:48<00:00,  1.62it/s]


test set accuracy is 0.6955


loss 0.70 accuracy 0.78: 100%|██████████| 781/781 [09:12<00:00,  1.41it/s]
100%|██████████| 79/79 [01:02<00:00,  1.25it/s]


test set accuracy is 0.7501


loss 0.41 accuracy 0.83: 100%|██████████| 781/781 [09:16<00:00,  1.40it/s]
100%|██████████| 79/79 [00:52<00:00,  1.51it/s]


test set accuracy is 0.7183


loss 0.47 accuracy 0.84: 100%|██████████| 781/781 [09:11<00:00,  1.42it/s]
100%|██████████| 79/79 [00:51<00:00,  1.52it/s]


test set accuracy is 0.756


loss 0.38 accuracy 0.83: 100%|██████████| 781/781 [09:17<00:00,  1.40it/s]
100%|██████████| 79/79 [00:50<00:00,  1.57it/s]


test set accuracy is 0.7712


loss 0.64 accuracy 0.77: 100%|██████████| 781/781 [09:09<00:00,  1.42it/s]
100%|██████████| 79/79 [00:51<00:00,  1.55it/s]


test set accuracy is 0.7738


loss 0.48 accuracy 0.84: 100%|██████████| 781/781 [09:16<00:00,  1.40it/s]
100%|██████████| 79/79 [00:48<00:00,  1.64it/s]


test set accuracy is 0.7839


loss 0.36 accuracy 0.91: 100%|██████████| 781/781 [09:16<00:00,  1.40it/s]
100%|██████████| 79/79 [01:05<00:00,  1.20it/s]

test set accuracy is 0.8261



