In [1]:
import numpy as np
from tqdm import trange
import torch
from torch import nn
import torch.nn.functional as F
from datasets import fetch_mnist
from extra.training import train, evaluate
np.random.seed(1337)
torch.manual_seed(1337)

<torch._C.Generator at 0x10bc77cf0>

In [2]:
class BasicBlock(nn.Module):
    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
    
    def __call__(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.downsample is not None:
            x = self.downsample(x)
        out = (out + x).relu()
        return out

In [3]:
class ResNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(1, 64, 7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
        self.layer1 = self._make_layer(64, 2, stride=1)
        self.layer2 = self._make_layer(128, 2, stride=2)
        self.layer3 = self._make_layer(256, 2, stride=2)
        self.layer4 = self._make_layer(512, 2, stride=2)
        self.fc = nn.Linear(512, 10)
    
    def _make_layer(self, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes, 1, stride=stride, bias=False), nn.BatchNorm2d(planes)
            )
        layers = [BasicBlock(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes
        for _ in range(1, blocks):
            layers.append(BasicBlock(self.inplanes, planes))
        return nn.Sequential(*layers)
    
    def __call__(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)
        for layer in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = layer(x)
        x = x.reshape(x.shape[0], -1)
        return x

In [4]:
def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=F.cross_entropy):
    model.train()
    for i in (t := trange(steps)):
        sample = np.random.randint(0, high=X_train.shape[0], size=(BS))
        x = torch.tensor(X_train[sample])
        y = torch.tensor(Y_train[sample])
        out = model(x)
        loss = lossfn(out, y)
        optim.zero_grad()
        loss.backward()
        optim.step()
        cat = torch.argmax(out, dim=1)
        accuracy = (cat == y).float().mean()
        loss = loss.item()
        t.set_description(f"loss {loss:.2f} accuracy {accuracy:.2f}")


def evaluate(model, X_test, Y_test, num_classes=10, BS=128):
    model.eval()
    out = model(torch.tensor(X_test))
    preds = torch.argmax(out, dim=1).numpy()
    accuracy = (Y_test == preds).mean()
    print(f"test set accuracy is {accuracy}")

In [5]:
X_train, Y_train, X_test, Y_test = fetch_mnist()
X_train = X_train.reshape(-1, 1, 28, 28) / 255.0
X_test = X_test.reshape(-1, 1, 28, 28) / 255.0

model = ResNet()
learning_rate = 0.001

for _ in range(5):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    train(model, X_train, Y_train, optimizer, 1000, BS=128)
    evaluate(model, X_test, Y_test)
    learning_rate /= 3

loss 0.02 accuracy 1.00: 100%|██████████| 1000/1000 [14:51<00:00,  1.12it/s]


test set accuracy is 0.989


loss 0.01 accuracy 1.00: 100%|██████████| 1000/1000 [15:44<00:00,  1.06it/s]


test set accuracy is 0.9884


loss 0.01 accuracy 1.00: 100%|██████████| 1000/1000 [16:13<00:00,  1.03it/s]


test set accuracy is 0.9934


loss 0.00 accuracy 1.00: 100%|██████████| 1000/1000 [16:16<00:00,  1.02it/s]


test set accuracy is 0.9943


loss 0.02 accuracy 0.99: 100%|██████████| 1000/1000 [15:19<00:00,  1.09it/s]


test set accuracy is 0.9947
