In [1]:
import random
import numpy as np
from PIL import Image
import torch
from torch import nn
from datasets import load_dataset
from helpers import get_device, train, evaluate

In [2]:
torch.manual_seed(0)
device = get_device()

In [3]:
def preprocess(x):
    x = x.reshape(-1, 3, 32, 32).astype(np.float32)
    std = x.std(axis=(0, 2, 3))
    mean = x.mean(axis=(0, 2, 3))
    x = ((x - mean.reshape(1, -1, 1, 1)) / std.reshape(1, -1, 1, 1))
    return x


def random_flip(image, flip_prob=0.5):
    if random.random() < flip_prob:
        image = np.flip(image, axis=1)
    return image


def transform(x):
    x = ((Image.fromarray(z).resize((227, 227)) for z in y) for y in x)
    x = np.stack([random_flip(np.stack([np.asarray(z) for z in y], axis=0)) for y in x], axis=0)
    return x


def target_transform(x):
    x = ((Image.fromarray(z).resize((227, 227)) for z in y) for y in x)
    x = np.stack([np.stack([np.asarray(z) for z in y], axis=0) for y in x], axis=0)
    return

In [4]:
dataset = load_dataset("cifar10")

X_train = preprocess(np.array([np.array(image) for image in dataset["train"]["img"]]))
Y_train = np.array(dataset["train"]["label"], dtype=np.int32)

X_test = preprocess(np.array([np.array(image) for image in dataset["test"]["img"]]))
Y_test = np.array(dataset["test"]["label"], dtype=np.int32)

In [5]:
class BasicBlock(nn.Module):
    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample

    def __call__(self, x):
        out = self.bn1(self.conv1(x)).relu()
        out = self.bn2(self.conv2(out))
        if self.downsample is not None:
            x = self.downsample(x)
        out = (out + x).relu()
        return out

In [6]:
class ResNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
        self.layer1 = self._make_layer(64, 3, stride=1)
        self.layer2 = self._make_layer(128, 4, stride=2)
        self.layer3 = self._make_layer(256, 6, stride=2)
        self.layer4 = self._make_layer(512, 3, stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512, 10)

    def _make_layer(self, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes, 1, stride=stride, bias=False),
                nn.BatchNorm2d(planes),
            )
        layers = [BasicBlock(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes
        for _ in range(1, blocks):
            layers.append(BasicBlock(self.inplanes, planes))
        return nn.Sequential(*layers)

    def __call__(self, x):
        x = self.bn1(self.conv1(x)).relu()
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fc(x)
        return x

In [7]:
def transform(x):
    x = [[Image.fromarray(z).resize((224, 224)) for z in y] for y in x]
    x = np.stack([np.stack([np.asarray(z) for z in y], axis=0) for y in x], axis=0)
    x = x.reshape(-1, 3, 224, 224)
    return x

In [8]:
epochs = 10
batch_size = 128
steps = len(X_train) // batch_size

In [9]:
model = ResNet().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=0.001, momentum=0.9)
for _ in range(epochs):
    train(model, X_train, Y_train, optimizer, steps, batch_size=batch_size, transform=transform, device=device)
    evaluate(model, X_test, Y_test, target_transform=transform, device=device)

loss 1.59 accuracy 0.47: 100%|██████████| 390/390 [08:35<00:00,  1.32s/it]
100%|██████████| 79/79 [00:38<00:00,  2.03it/s]


test set accuracy is 0.3928


loss 1.47 accuracy 0.46: 100%|██████████| 390/390 [08:32<00:00,  1.31s/it]
100%|██████████| 79/79 [00:29<00:00,  2.66it/s]


test set accuracy is 0.4666


loss 1.28 accuracy 0.55: 100%|██████████| 390/390 [08:11<00:00,  1.26s/it]
100%|██████████| 79/79 [00:28<00:00,  2.72it/s]


test set accuracy is 0.4427


loss 1.09 accuracy 0.61: 100%|██████████| 390/390 [08:12<00:00,  1.26s/it]
100%|██████████| 79/79 [00:29<00:00,  2.70it/s]


test set accuracy is 0.5375


loss 0.77 accuracy 0.71: 100%|██████████| 390/390 [08:19<00:00,  1.28s/it]
100%|██████████| 79/79 [00:32<00:00,  2.46it/s]


test set accuracy is 0.5341


loss 0.92 accuracy 0.65: 100%|██████████| 390/390 [08:20<00:00,  1.28s/it]
100%|██████████| 79/79 [00:29<00:00,  2.65it/s]


test set accuracy is 0.5642


loss 0.73 accuracy 0.74: 100%|██████████| 390/390 [08:41<00:00,  1.34s/it]
100%|██████████| 79/79 [00:29<00:00,  2.70it/s]


test set accuracy is 0.578


loss 0.60 accuracy 0.83: 100%|██████████| 390/390 [08:10<00:00,  1.26s/it]
100%|██████████| 79/79 [00:29<00:00,  2.72it/s]


test set accuracy is 0.565


loss 0.52 accuracy 0.84: 100%|██████████| 390/390 [08:08<00:00,  1.25s/it]
100%|██████████| 79/79 [00:28<00:00,  2.73it/s]


test set accuracy is 0.5245


loss 0.57 accuracy 0.81: 100%|██████████| 390/390 [08:16<00:00,  1.27s/it]
100%|██████████| 79/79 [00:29<00:00,  2.70it/s]

test set accuracy is 0.5177



