# MNIST Transformer

In [50]:
import keras
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torch.optim as optim
import math
import time

In [58]:
# Environment check: report whether this PyTorch install can use a GPU.
# `torch` comes from the notebook's import cell, so it is not re-imported here.
print(torch.cuda.is_available()) # True only when a usable CUDA device is visible
print(torch.version.cuda) # CUDA version PyTorch was built with (e.g., '11.8', '12.1'); None means a CPU-only build
print(torch.backends.cudnn.is_available()) # should also be True if CUDA is working

False
None
False


## Data Loading

In [25]:
# Load the data from the csv files. Each row is one image: the label in
# column 0 followed by the 784 grayscale pixel values.
x_train = np.loadtxt('mnist_train.csv', delimiter=',') # x_train is (60000, 785), 785 bc of label at index 0
x_test = np.loadtxt('mnist_test.csv', delimiter=',')

# Convert to torch tensors. int16 is enough: after the shift below the largest
# token id is 265, far below the signed int16 maximum of 32767.
x_train = torch.from_numpy(x_train).to(torch.int16)
x_test = torch.from_numpy(x_test).to(torch.int16)

# Token ids 0-9 are reserved for the class-label (CLS) tokens, so grayscale
# values are shifted up by this offset and occupy ids 10-265.
PIXEL_TOKEN_OFFSET = 10

def encode(x):
    """Map raw grayscale values to token ids (shifted past the label tokens)."""
    return x + PIXEL_TOKEN_OFFSET

def decode(x):
    """Inverse of `encode`: map pixel token ids back to grayscale values."""
    return x - PIXEL_TOKEN_OFFSET

# Only the pixel columns are shifted; column 0 keeps the raw label 0-9.
x_train[:, 1:] = encode(x_train[:, 1:]) # (60000, 785)
x_test[:, 1:] = encode(x_test[:, 1:])   # (10000, 785)

# Split the training data into training and validation sets.
# A dedicated, seeded generator makes the split identical on every
# Restart & Run All without touching the global RNG state.
SPLIT_SEED = 1337
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_subset, val_subset = torch.utils.data.random_split(
    x_train, [50000, 10000], generator=split_generator
)
x_train_split = x_train[train_subset.indices]
x_val = x_train[val_subset.indices]
x_train = x_train_split
x_train.shape, x_val.shape # (50000, 785) (10000, 785)

(torch.Size([50000, 785]), torch.Size([10000, 785]))

In [46]:
# batching hyperparameters
batch_size = 32                # number of images drawn per batch
seq_len = x_train.size(1) - 1  # 784: row width minus one, since inputs/targets are the row shifted by one position

def get_batch(split):
    """Sample a random batch of next-token prediction pairs.

    Rows are drawn uniformly at random (with replacement) from the chosen
    split. For each row, the input is every token except the last and the
    target is the same row shifted left by one, so ``y[:, i]`` is the token
    that follows ``x[:, i]``.

    Args:
        split: ``'train'`` selects ``x_train``; any other value selects ``x_val``.

    Returns:
        ``(x, y)``: two int64 tensors, each of shape
        ``(batch_size, row_width - 1)``.
    """
    data = x_train if split == 'train' else x_val
    indices = torch.randint(len(data), (batch_size,)) # basically, get `batch_size` photos
    # The data is stored as int16 to save memory, but torch's embedding lookup
    # and cross-entropy loss only accept int32/int64 indices and int64 class
    # targets respectively, so cast at the batch boundary.
    # NOTE(review): assumes downstream cells feed these into nn.Embedding /
    # F.cross_entropy -- confirm; the cast is harmless if they already convert.
    x = data[indices, :-1].long() # (batch_size, 784)
    y = data[indices, 1:].long()  # (batch_size, 784)
    return x, y

xb, yb = get_batch('train')
# A bare `xb.shape, yb.shape` in the middle of a cell is never displayed,
# so verify the expected shapes explicitly instead.
assert xb.shape == yb.shape == (batch_size, seq_len) # (32, 784) (32, 784)

# peace of mind check: the first context token is the label, and each target
# is the token that immediately follows the context in the same row
for i in range(5):
    context = xb[0, :i+1]
    target = yb[0, i]
    print(f"When context is {context.tolist()}, the target is {target.tolist()}")


When context is [3], the target is 10
When context is [3, 10], the target is 10
When context is [3, 10, 10], the target is 10
When context is [3, 10, 10, 10], the target is 10
When context is [3, 10, 10, 10, 10], the target is 10


## Attention

In [56]:
# Toy single-head causal self-attention on random data, used to inspect the
# tensor shapes and the effect of the causal mask.
B, T, C = 4, 8, 32 # batch, time (sequence length), channels
x = torch.randn(B, T, C)


head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x)   # (B, T, head_size)
q = query(x) # (B, T, head_size)
v = value(x) # (B, T, head_size)
# (B, T, head_size) @ (B, head_size, T) -> (B, T, T): for each batch element,
# a (T, head_size) matrix times a (head_size, T) matrix gives a (T, T) matrix
# of query-key scores.
weights = q @ k.transpose(-2, -1)
weights = weights / math.sqrt(head_size) # scale down the scores to prevent them from blowing up

# Causal mask: position t may only attend to positions <= t. Masked scores are
# set to -inf so softmax assigns them exactly zero weight.
tril = torch.tril(torch.ones(T, T))
weights = weights.masked_fill(tril == 0, float("-inf")) # piece de resistance of ye olde 'decoder' transformer (AR)
weights = F.softmax(weights, dim=-1) # each row sums to 1

out = weights @ v # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)

# The head projects C=32 channels down to head_size=16, so the output is
# (4, 8, 16), not (4, 8, 32). A bare `out.shape` mid-cell would not be
# displayed, so assert it instead.
assert out.shape == (B, T, head_size)

weights

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3076, 0.6924, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2689, 0.2543, 0.4769, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0391, 0.8925, 0.0373, 0.0311, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2741, 0.1508, 0.3187, 0.0735, 0.1830, 0.0000, 0.0000, 0.0000],
         [0.1004, 0.1210, 0.2156, 0.1775, 0.0312, 0.3543, 0.0000, 0.0000],
         [0.0443, 0.2066, 0.1084, 0.3758, 0.0536, 0.0799, 0.1313, 0.0000],
         [0.0200, 0.2559, 0.0680, 0.0458, 0.2710, 0.1009, 0.1908, 0.0477]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0632, 0.9368, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.7136, 0.2708, 0.0157, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2481, 0.0882, 0.0241, 0.6396, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0769, 0.1994, 0.4409, 0.1833, 0.0995, 0.0000, 0.0000, 0.0000],
         [0.0173, 0.075