# Download data

In [None]:
!curl https://data.deepai.org/mnist.zip -o mnist.zip
!mkdir data
!unzip mnist.zip -d data/mnist/
!rm mnist.zip
!gunzip data/mnist -r

# Data utilities

In [5]:
import numpy as np
import torch


def load_mnist_data(test=False, data_dir='data/mnist'):
    """Load one MNIST split from the raw ``*-ubyte`` files.

    Args:
        test: when true, read the ``t10k-*`` files; otherwise the ``train-*``
            files.
        data_dir: directory containing the decompressed ``*-idx3-ubyte`` /
            ``*-idx1-ubyte`` files. Defaults to the location used by the
            download cell.

    Returns:
        A tuple ``(images, labels)``:
          * ``images`` - float32 array of shape ``(N, 1, 28, 28)`` holding the
            raw bytes divided by 256 (so values lie in ``[0, 1)``).
          * ``labels`` - writable uint8 array of shape ``(N,)`` with the raw
            label bytes.

    Raises:
        OSError: if either file cannot be opened.
        ValueError: if the image payload is not a multiple of 28 * 28 bytes.
    """
    prefix = 't10k' if test else 'train'
    images_path = data_dir + '/' + prefix + '-images-idx3-ubyte'
    labels_path = data_dir + '/' + prefix + '-labels-idx1-ubyte'

    # Context managers guarantee both handles are closed, even on error.
    with open(images_path, 'rb') as f_images:
        # Skip the 16-byte header that precedes the pixel bytes.
        f_images.seek(16)
        buf_images = f_images.read()
    with open(labels_path, 'rb') as f_labels:
        # Skip the 8-byte header that precedes the label bytes.
        f_labels.seek(8)
        buf_labels = f_labels.read()

    # astype() already returns a fresh, writable array, so no extra copy is
    # needed. The divisor stays 256 (not 255) to keep inputs identical to what
    # existing checkpoints were trained on.
    images = np.frombuffer(buf_images, dtype=np.uint8).astype(np.float32)
    images = images.reshape(-1, 1, 28, 28) / 256

    # frombuffer() yields a read-only view over the bytes object; copy it so
    # callers (e.g. torch.from_numpy) get a writable array.
    labels = np.frombuffer(buf_labels, dtype=np.uint8).copy()

    return images, labels


def sample_batch(X, Y, batch_size=32):
    """Draw a random mini-batch of paired rows from ``X`` and ``Y``.

    ``batch_size`` distinct positions in ``range(len(Y))`` are drawn from
    NumPy's global random state (no position repeats within one batch), and
    the rows of both arrays at those positions are returned.

    Args:
        X: array indexable by an integer index array (e.g. the images).
        Y: array of the same leading length as ``X`` (e.g. the labels).
        batch_size: number of rows to draw; must not exceed ``len(Y)``.

    Returns:
        ``(X[idx], Y[idx])`` for the sampled index array ``idx``.
    """
    # An integer population is sampled as if it were np.arange(len(Y)).
    picked = np.random.choice(len(Y), size=batch_size, replace=False)
    batch_x = X[picked]
    batch_y = Y[picked]
    return batch_x, batch_y


# Training script

In [9]:
from tqdm import trange
from perceiver import PerceiverLogits, load_mnist_model
import torch
from torch import nn, optim
import random
import numpy as np
import os


def set_random_seed(value):
    """Seed PyTorch, NumPy's global RNG and Python's ``random`` with ``value``."""
    # Same order as before: torch first, then numpy, then the stdlib.
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(value)


# Seed every RNG up front so sampling and weight initialisation are repeatable.
set_random_seed(10)


# Print tensors in plain decimal notation rather than scientific notation.
torch.set_printoptions(sci_mode=False)
# NOTE(review): `model` is not assigned anywhere in the visible code, yet
# `test(model)` is called at the bottom of this cell. The constructor call
# below is wrapped in a string literal (a no-op expression) and the checkpoint
# load is commented out, so on a fresh kernel the cell would presumably fail
# with NameError unless one of the two is re-enabled or `model` is defined
# elsewhere - TODO confirm how the model is meant to be obtained.
"""model = PerceiverLogits(
    input_channels=1,
    input_shape=(28, 28),
    fourier_bands=4,
    output_features=10,
    latents=8,
    d_model=16,
    heads=8,
    latent_blocks=6,
    dropout=0.1,
    layers=6
)"""
#model = torch.load('./checkpoints/epoch9')


def test(model, DEVICE='cpu', BATCH_SIZE=500):
    """Return the accuracy of ``model`` on the MNIST test split.

    The model is moved to ``DEVICE`` and evaluated in eval mode with gradients
    disabled. Every test sample is scored, including a final partial batch
    when the split size is not a multiple of ``BATCH_SIZE``. The model's
    previous train/eval mode is restored before returning, so calling this
    from inside a training loop does not silently switch off dropout for the
    remaining training steps.

    Args:
        model: module mapping a ``(B, 1, 28, 28)`` float tensor to per-class
            scores on its last dimension (presumably an ``nn.Module`` -
            it must provide ``training``, ``eval()``, ``train()``, ``to()``).
        DEVICE: device the evaluation runs on.
        BATCH_SIZE: number of samples per forward pass.

    Returns:
        Fraction of correctly classified samples as a float in ``[0, 1]``;
        ``0.0`` if the test split is empty.
    """
    was_training = model.training
    model.eval()
    model = model.to(DEVICE)
    try:
        X_test, Y_test = load_mnist_data(test=True)
        n_samples = len(X_test)

        correct = 0
        total = 0

        with torch.no_grad():
            # Stepping by BATCH_SIZE (instead of iterating over
            # n_samples // BATCH_SIZE) keeps the trailing partial batch.
            for start in range(0, n_samples, BATCH_SIZE):
                stop = start + BATCH_SIZE
                x = torch.from_numpy(X_test[start:stop]).float().to(DEVICE)
                y = torch.from_numpy(Y_test[start:stop]).long().to(DEVICE)

                # Predicted class = index of the highest score.
                y_ = model(x).argmax(dim=-1)

                total += len(y_)
                correct += (y_ == y).sum().item()

        return correct / total if total else 0.0
    finally:
        # Hand the model back in the mode the caller had it in.
        model.train(was_training)


def train(model, SKIP_EPOCHS=-1, EPOCHS=24, BATCH_SIZE=32, DEVICE='cpu'):
    """Train ``model`` on MNIST and save a checkpoint after every epoch.

    Uses Adam with an initial learning rate of ``sqrt(0.1) * 0.01`` that is
    multiplied by ``sqrt(0.1)`` every 3 epochs (i.e. divided by 10 every 6
    epochs). Before each trained epoch the current learning rate and test
    accuracy are printed. Each "epoch" consists of
    ``len(train_set) // BATCH_SIZE`` independently sampled random batches, and
    afterwards the whole model object is saved to ``checkpoints/epoch<N>``.

    Args:
        model: module to train; its output is fed to ``nn.NLLLoss``, so it is
            presumably expected to emit log-probabilities - TODO confirm.
        SKIP_EPOCHS: epochs with index ``<= SKIP_EPOCHS`` are not trained;
            only the LR schedule is advanced for them (used to resume from a
            checkpoint). ``-1`` trains every epoch.
        EPOCHS: total number of epochs, counting skipped ones.
        BATCH_SIZE: samples per optimisation step.
        DEVICE: device to train on.
    """
    model.train()
    model = model.to(DEVICE)
    gamma = 0.1 ** 0.5  # ~0.316; applied twice it gives a 10x decay
    optimizer = optim.Adam(model.parameters(), lr=gamma * 0.01)
    # `last_epoch=-1` is the default; `verbose` is deprecated in recent
    # PyTorch releases, so it is no longer passed.
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=3, gamma=gamma)
    # The loss module is stateless; build it once instead of once per step.
    criterion = nn.NLLLoss()

    X_train, Y_train = load_mnist_data(test=False)
    steps_per_epoch = len(X_train) // BATCH_SIZE

    for epoch in range(EPOCHS):
        if epoch <= SKIP_EPOCHS:
            # Fast-forward the LR schedule for epochs that were already run.
            scheduler.step()
            continue

        current_lr = optimizer.param_groups[0]['lr']
        accuracy = test(model, DEVICE=DEVICE)
        print('EPOCH', epoch, '[LEARNING RATE: ' + str(current_lr)
              + '; ACCURACY: ' + str(accuracy) + ']')

        # test() evaluates in eval mode; switch back explicitly so dropout is
        # active for the optimisation steps below regardless of how test()
        # leaves the model.
        model.train()

        t = trange(steps_per_epoch)
        for _ in t:
            optimizer.zero_grad()

            x, y = sample_batch(X_train, Y_train, BATCH_SIZE)
            x = torch.from_numpy(x).float().to(DEVICE)
            y = torch.from_numpy(y).long().to(DEVICE)

            y_ = model(x)
            loss = criterion(y_, y)

            loss.backward()
            optimizer.step()

            # Show the first five characters of the loss in the progress bar.
            t.set_description(str(loss.item())[0:5])
        scheduler.step()

        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs('checkpoints', exist_ok=True)
        torch.save(model, 'checkpoints/epoch' + str(epoch))


# Training is disabled in this run; uncomment to train every epoch
# (SKIP_EPOCHS=-1 skips none).
#train(model, SKIP_EPOCHS=-1)
# Last expression of the cell: the notebook displays the returned test accuracy.
test(model)

0.9602