In [None]:
# Download the MNIST archive, extract it into ./mnist/, delete the zip, then
# decompress the extracted .gz files in place (load_mnist_data reads from mnist/).
!curl https://data.deepai.org/mnist.zip -o mnist.zip
!unzip mnist.zip -d mnist/
!rm mnist.zip
# -r: recurse into the mnist/ directory and gunzip every file found there.
!gunzip mnist -r

In [55]:
import numpy as np
import torch

# Print tensors in plain decimal notation instead of scientific notation.
torch.set_printoptions(sci_mode=False)
# Module-level device setting for this notebook (CPU).
device = 'cpu'

def load_mnist_data(test=False, root='mnist'):
    """Load one MNIST split from the decompressed IDX files under ``root``.

    Args:
        test: if True, load the ``t10k-*`` test files; otherwise the ``train-*`` files.
        root: directory containing the decompressed IDX files.

    Returns:
        images: float32 array of shape (N, 1, rows, cols) (28x28 for MNIST),
            pixel values divided by 256.
        labels: uint8 array of shape (N,) holding each image's class index.

    Raises:
        ValueError: if a file has the wrong IDX magic number, or the image and
            label files disagree on the number of items.
    """
    import os

    prefix = 't10k' if test else 'train'
    images_path = os.path.join(root, prefix + '-images-idx3-ubyte')
    labels_path = os.path.join(root, prefix + '-labels-idx1-ubyte')

    # Read each file completely and close it straight away.
    with open(images_path, 'rb') as f_images:
        buf_images = f_images.read()
    with open(labels_path, 'rb') as f_labels:
        buf_labels = f_labels.read()

    # IDX headers are big-endian int32s: a magic number followed by one size per
    # dimension. Images: magic 2051, count, rows, cols (16 bytes).
    # Labels: magic 2049, count (8 bytes).
    img_magic, n_images, rows, cols = np.frombuffer(buf_images, dtype='>i4', count=4)
    lbl_magic, n_labels = np.frombuffer(buf_labels, dtype='>i4', count=2)
    if img_magic != 2051:
        raise ValueError('%s: bad IDX image magic number %d' % (images_path, int(img_magic)))
    if lbl_magic != 2049:
        raise ValueError('%s: bad IDX label magic number %d' % (labels_path, int(lbl_magic)))
    if n_images != n_labels:
        raise ValueError('image/label count mismatch: %d vs %d' % (int(n_images), int(n_labels)))

    images = np.frombuffer(buf_images, dtype=np.uint8, offset=16).astype(np.float32)
    # Keep /256 (not /255) so inputs match the scale the model was trained on.
    images = images.reshape(int(n_images), 1, int(rows), int(cols)) / 256

    # .copy(): frombuffer views are read-only, which makes torch.from_numpy warn.
    labels = np.frombuffer(buf_labels, dtype=np.uint8, offset=8).copy()

    return images, labels

def sample_batch(X, Y, batch_size=32):
    """Pick ``batch_size`` distinct random rows of ``X`` together with the matching rows of ``Y``.

    Rows are drawn without replacement from NumPy's global RNG, so a batch never
    contains the same example twice. NumPy raises ValueError if ``batch_size``
    exceeds ``len(Y)``.
    """
    population = np.arange(len(Y))
    picks = np.random.choice(population, size=batch_size, replace=False)
    return X[picks], Y[picks]

In [None]:
def checker_board(d_model):
    """Return an alternating float mask ``[1, 0, 1, 0, ...]`` with exactly ``d_model`` entries.

    Even positions are 1 and odd positions are 0, so for odd ``d_model`` the
    trailing entry is 1. pos_embedding uses it to choose sin (1) vs cos (0)
    per channel.
    """
    return (torch.arange(d_model) % 2 == 0).float()

def pos_embedding(x):
    """Add fixed sinusoidal positional encodings to a sequence-first tensor.

    Args:
        x: tensor of shape (length, batch_size, d_model).

    Returns:
        ``x + pe`` with the same shape as ``x``, where for position ``pos`` and
        channel ``k``: ``pe = sin(pos / 10000 ** (k / d_model))`` for even ``k``
        and ``cos(pos / 10000 ** (k / d_model))`` for odd ``k``. The same
        encoding is added to every batch element. Works for odd ``d_model``.
    """
    length, _, d_model = x.shape

    # Build the table on x's device rather than a global setting, so the module
    # keeps working after being moved to another device.
    k = torch.arange(d_model, device=x.device, dtype=torch.float32).view(1, 1, -1)
    pos = torch.arange(length, device=x.device, dtype=torch.float32).view(-1, 1, 1)

    # NOTE(review): the exponent uses k / d_model per channel, whereas Vaswani et al.
    # share 2*(k//2)/d_model across each sin/cos pair. Kept as-is so previously
    # trained models see the same encoding.
    z = pos / 10000 ** (k / d_model)  # (length, 1, d_model)

    # Even channels take sin, odd channels take cos.
    even_channel = torch.arange(d_model, device=x.device) % 2 == 0
    pe = torch.where(even_channel, torch.sin(z), torch.cos(z))

    # pe is (length, 1, d_model) and broadcasts across the batch dimension.
    return x + pe

In [103]:
import torch
from torch import nn, optim
import torch.nn.functional as F
from matplotlib.pyplot import imshow

# Seed torch's global RNG so parameter initialisation is reproducible.
# NOTE(review): sample_batch draws from NumPy's global RNG, which is not seeded here.
torch.manual_seed(0)

class AttentionBlock(nn.Module):
    def __init__(self, d_model, heads, dropout):
        super(AttentionBlock, self).__init__()
        
        self.layer_norm_x = nn.LayerNorm([d_model])
        self.layer_norm_1 = nn.LayerNorm([d_model])
        self.attention = nn.MultiheadAttention(
            d_model,
            heads,
            dropout=0.0,
            bias=True,
            add_bias_kv=True,
        )
        self.dropout = nn.Dropout(p=dropout)
        self.linear1 = nn.Linear(d_model, d_model)
        self.layer_norm_2 = nn.LayerNorm([d_model])
        self.linear2 = nn.Linear(d_model, d_model)
        self.linear3 = nn.Linear(d_model, d_model)
        
    def forward(self, x, z_input):
        x = self.layer_norm_x(x)
        z = self.layer_norm_1(z_input)
        z, _ = self.attention(z, x, x)
        
        z = self.dropout(z)
        z = self.linear1(z)
        
        z = self.layer_norm_2(z)
        z = self.linear2(z)
        z = F.gelu(z)
        z = self.dropout(z)
        z = self.linear3(z)
        
        return z + z_input

class PerceiverBlock(nn.Module):
    """One Perceiver stage: a cross-attention read from the input, then latent self-attention.

    Args:
        d_model: channel width shared by the input and the latents.
        latent_blocks: number of latent self-attention blocks after the cross-attention.
        dropout: dropout probability passed to every AttentionBlock.
        heads: head count for the latent self-attention blocks. The cross-attention
            is always built with a single head.
    """

    def __init__(self, d_model, latent_blocks, dropout, heads):
        super().__init__()

        self.cross_attention = AttentionBlock(d_model, heads=1, dropout=dropout)
        stack = [AttentionBlock(d_model, heads=heads, dropout=dropout) for _ in range(latent_blocks)]
        self.latent_attentions = nn.ModuleList(stack)

    def forward(self, x, z):
        latent = self.cross_attention(x, z)
        for block in self.latent_attentions:
            # Self-attention: the latent serves as both queries and keys/values.
            latent = block(latent, latent)
        return latent

class Repeater(nn.Module):
    """Apply one wrapped module ``repeats`` times, threading the latent through each pass.

    The same module instance, and therefore the same weights, is reused on every
    pass: ``z <- module(x, z)``. With ``repeats <= 0`` the latent is returned unchanged.
    """

    def __init__(self, module, repeats=1):
        super().__init__()

        self.module = module
        self.repeats = repeats

    def forward(self, x, z):
        latent = z
        for _step in range(self.repeats):
            latent = self.module(x, latent)
        return latent
    
class Perceiver(nn.Module):
    """Perceiver-style classifier for images given as (batch, channels, height, width).

    Pixels are embedded with a 1x1 convolution, tagged with positional encodings
    and read into a small learned latent array through cross-attention blocks.
    The latents are averaged and projected to ``output_size`` logits.

    Args:
        output_size: number of output logits.
        latents: number of latent vectors.
        d_model: channel width of the embedded input and of the latents.
        input_channels: channels of the input image.
        heads: attention heads in the latent self-attention blocks.
        dropout: dropout probability used throughout.
        latent_blocks: latent self-attention blocks inside each PerceiverBlock.
        repeats: how many times the weight-shared second PerceiverBlock is applied.
    """

    def __init__(self, output_size, latents=8, d_model=16, input_channels=1, heads=4, dropout=0.1,
                 latent_blocks=3, repeats=3):
        super(Perceiver, self).__init__()

        # Learned initial latent array, (latents, d_model), uniform in [0, 1).
        self.init_latent = nn.Parameter(torch.rand((latents, d_model)))
        # Kernel-size-1 conv = per-pixel linear map from input_channels to d_model.
        self.embedding = nn.Conv1d(input_channels, d_model, 1)

        # block1 runs once with its own weights; block2's weights are reused `repeats` times.
        self.block1 = Repeater(PerceiverBlock(d_model, latent_blocks=latent_blocks, heads=heads, dropout=dropout), repeats=1)
        self.block2 = Repeater(PerceiverBlock(d_model, latent_blocks=latent_blocks, heads=heads, dropout=dropout), repeats=repeats)

        self.linear1 = nn.Linear(d_model, d_model)
        self.linear2 = nn.Linear(d_model, output_size)

    def forward(self, x):
        """Map images (batch, channels, height, width) to logits (batch, output_size)."""
        # (batch, channels, H, W) -> (batch, channels, pixels).
        # reshape (unlike view) also accepts non-contiguous input.
        x = x.reshape(x.shape[0], x.shape[1], -1)
        # -> (batch, d_model, pixels)
        x = self.embedding(x)
        # -> (pixels, batch, d_model): the sequence-first layout nn.MultiheadAttention uses by default.
        x = x.permute(2, 0, 1)
        x = pos_embedding(x)

        # (latents, d_model) -> (latents, 1, d_model) -> (latents, batch, d_model).
        z = self.init_latent.unsqueeze(1).expand(-1, x.shape[1], -1)

        z = self.block1(x, z)
        z = self.block2(x, z)

        # NOTE(review): linear1 -> mean -> linear2 has no nonlinearity in between, so
        # linear1 folds into linear2 as a single affine map. Kept for checkpoint compatibility.
        z = self.linear1(z)
        z = z.mean(dim=0)
        return self.linear2(z)
    
# Build the classifier with 10 outputs (one per MNIST digit class) and default sizes.
model = Perceiver(output_size=10)

In [106]:
from tqdm import trange

def test(model, BATCH_SIZE=500, DEVICE='cpu'):
    """Return the classification accuracy of ``model`` on the MNIST test split.

    The model is put in eval mode (dropout off) for the evaluation. Afterwards it
    is restored to the mode it had before the call, so calling this from inside a
    training loop does not leave dropout disabled.

    Args:
        model: classifier taking image batches as produced by load_mnist_data
            and returning one score per class.
        BATCH_SIZE: images per forward pass; a final partial batch is included.
        DEVICE: device the input batches are moved to.

    Returns:
        Fraction of test images whose argmax prediction equals the label.

    Raises:
        ValueError: if the test split contains no images.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            X_test, Y_test = load_mnist_data(test=True)

            correct = 0
            total = 0
            # Step over the whole set, so a trailing partial batch is not dropped.
            for start in range(0, len(X_test), BATCH_SIZE):
                end = start + BATCH_SIZE
                x = torch.from_numpy(X_test[start:end]).float().to(DEVICE)
                y = torch.from_numpy(Y_test[start:end]).long().to(DEVICE)

                y_ = model(x).argmax(dim=-1)

                total += len(y_)
                correct += (y_ == y).sum().item()
    finally:
        # Restore the caller's mode; train() calls this at the start of every epoch.
        model.train(was_training)

    if total == 0:
        raise ValueError('test set is empty')
    return correct / total
    
def train(model, SKIP_EPOCHS=-1, EPOCHS=10, BATCH_SIZE=32, DEVICE='cpu', LR=3e-3, GAMMA=0.1 ** 0.5, STEP_SIZE=4):
    """Train ``model`` on MNIST with Adam and a step-decay learning-rate schedule.

    At the start of every epoch, the current learning rate and test accuracy are
    printed. Unless the epoch is skipped, ``len(train) // BATCH_SIZE`` optimiser
    steps are then run on randomly sampled batches.

    Args:
        model: classifier to train; its parameters are updated in place.
        SKIP_EPOCHS: epochs with index <= SKIP_EPOCHS only advance the scheduler.
            This lets you resume a partially finished schedule.
        EPOCHS: total number of epochs, including skipped ones.
        BATCH_SIZE: training batch size.
        DEVICE: device the batches are moved to. The model itself is not moved here.
        LR: initial Adam learning rate.
        GAMMA: multiplicative learning-rate decay (default ~0.316).
        STEP_SIZE: number of epochs between learning-rate decays.
    """
    # The learning rate and the decay factor are separate knobs. Using GAMMA as the
    # Adam lr (~0.316) is far too large.
    optimizer = optim.Adam(model.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)
    criterion = nn.CrossEntropyLoss()

    X_train, Y_train = load_mnist_data(test=False)
    steps_per_epoch = len(X_train) // BATCH_SIZE

    for epoch in range(EPOCHS):
        print('EPOCH', epoch, '[LEARNING RATE: ' + str(optimizer.param_groups[0]['lr']) + '; ACCURACY: ' + str(test(model)) + ']')
        if epoch <= SKIP_EPOCHS:
            scheduler.step()
            continue

        # test() switches the model to eval mode; re-enable training mode (dropout on).
        model.train()

        t = trange(steps_per_epoch)
        for _ in t:
            optimizer.zero_grad()

            x, y = sample_batch(X_train, Y_train, BATCH_SIZE)
            x = torch.from_numpy(x).float().to(DEVICE)
            y = torch.from_numpy(y).long().to(DEVICE)

            loss = criterion(model(x), y)

            loss.backward()
            optimizer.step()

            t.set_description(str(loss.item())[0:5])
        scheduler.step()
        
# Run the full training schedule; each epoch takes a few minutes on CPU (see the log below).
train(model)


2.119:   0%|          | 1/1875 [00:00<03:56,  7.92it/s]

EPOCH 0 [LEARNING RATE: 0.003; ACCURACY: 0.0958]


1.011: 100%|██████████| 1875/1875 [03:25<00:00,  9.13it/s]
1.322:   0%|          | 1/1875 [00:00<04:02,  7.73it/s]

EPOCH 1 [LEARNING RATE: 0.003; ACCURACY: 0.5819]


0.656: 100%|██████████| 1875/1875 [03:33<00:00,  8.80it/s]
0.981:   0%|          | 1/1875 [00:00<03:55,  7.97it/s]

EPOCH 2 [LEARNING RATE: 0.003; ACCURACY: 0.7224]


0.385: 100%|██████████| 1875/1875 [03:26<00:00,  9.06it/s]
0.522:   0%|          | 1/1875 [00:00<03:59,  7.81it/s]

EPOCH 3 [LEARNING RATE: 0.003; ACCURACY: 0.7533]


0.658: 100%|██████████| 1875/1875 [03:27<00:00,  9.02it/s]
0.940:   0%|          | 1/1875 [00:00<03:57,  7.89it/s]

EPOCH 4 [LEARNING RATE: 0.0009; ACCURACY: 0.7748]


0.399: 100%|██████████| 1875/1875 [03:28<00:00,  9.00it/s]
0.611:   0%|          | 1/1875 [00:00<03:44,  8.35it/s]

EPOCH 5 [LEARNING RATE: 0.0009; ACCURACY: 0.8252]


0.641:   3%|▎         | 47/1875 [00:05<03:20,  9.10it/s]


KeyboardInterrupt: 

In [108]:
# Pickles the whole module object. Loading it requires the class definitions above
# to be importable. The filename presumably encodes latents=8, d_model=16,
# latent_blocks=3, heads=4.
# NOTE(review): PyTorch recommends saving model.state_dict() instead. Recent torch.load
# defaults to weights_only=True, which rejects full-module pickles; confirm how this is loaded.
torch.save(model, 'model_8_16_3_4')

In [None]:
from matplotlib.pyplot import imshow

# Display one test digit. load_mnist_data returns images as (N, 1, 28, 28),
# so index out the channel axis to give imshow a 2-D (H, W) array.
X_test, Y_test = load_mnist_data(test=True)
imshow(X_test[1, 0], cmap='gray')