In [None]:
# the torch version works better

In [210]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import random
import math
from tqdm import tqdm

In [211]:
# dataset idea from https://github.com/karpathy/minGPT/blob/master/play_math.ipynb

def make_dataset():
  """Enumerate every 2-digit + 2-digit addition problem as a 7-digit sequence.

  Each row is [a_tens, a_ones, b_tens, b_ones, sum_hundreds, sum_tens, sum_ones]
  for a, b in 0..99, giving 100 * 100 = 10000 rows in (a, b) order.
  """
  ret = []
  for i in range(100):
    for j in range(100):
      s = i+j
      ret.append([i//10, i%10, j//10, j%10, s//100, (s//10)%10, s%10])
  return ret

SEED = 1337      # fixes the train/test split and the per-epoch batch order
N_TRAIN = 8000   # remaining 2000 rows are the held-out test set

ds = make_dataset()
# Use a private RNG so the split is reproducible on Restart & Run All
# without touching the global `random` state.
random.Random(SEED).shuffle(ds)
# Force int64: np.array on Python ints defaults to int32 on some platforms
# (e.g. Windows with NumPy < 2), and F.nll_loss requires int64 (Long) targets.
ds = np.array(ds, dtype=np.int64)
# Next-token setup: the input is the first 6 digits and the target is the same
# sequence shifted left by one, so position t is trained to predict digit t+1.
ds_X = ds[:, 0:6]
ds_Y = np.copy(ds[:, 1:])
ds_X_train, ds_X_test = ds_X[:N_TRAIN], ds_X[N_TRAIN:]
ds_Y_train, ds_Y_test = ds_Y[:N_TRAIN], ds_Y[N_TRAIN:]

# Batches yield [X_batch, Y_batch]; reshuffled every epoch with a seeded generator.
train_loader = torch.utils.data.DataLoader(
  torch.utils.data.TensorDataset(torch.from_numpy(ds_X_train), torch.from_numpy(ds_Y_train)),
  batch_size=32,
  shuffle=True,
  generator=torch.Generator().manual_seed(SEED))

In [252]:
def attention(queries, keys, values, mask=None):
  """Scaled dot-product attention: softmax(Q K^T / sqrt(d)) V.

  Args:
    queries: (..., n_q, d) tensor.
    keys: (..., n_k, d) tensor.
    values: (..., n_k, d_v) tensor.
    mask: optional boolean tensor broadcastable to (..., n_q, n_k). True marks
      a query/key pair that may attend; False pairs receive zero attention
      weight (e.g. torch.tril(...) for causal attention). A query row with no
      True entry produces NaN. None (the default) attends everywhere, which is
      the original behavior.

  Returns:
    (..., n_q, d_v) tensor of attention-weighted values.
  """
  d = queries.shape[-1]
  scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(d)
  if mask is not None:
    # -inf becomes exactly 0 after softmax, removing those keys from the mix.
    scores = scores.masked_fill(~mask, float("-inf"))
  attention_weights = F.softmax(scores, dim=-1)
  return torch.matmul(attention_weights, values)

class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  Queries, keys and values are linearly projected, split into `num_heads`
  heads of width `projection_dim`, attended per head with `attention`, then
  merged back and passed through an output projection.

  Args:
    embed_dim: width of the inputs and of the output; must be divisible by num_heads.
    num_heads: number of attention heads.
  """
  def __init__(self, embed_dim, num_heads):
    super().__init__()
    self.embed_dim = embed_dim
    self.num_heads = num_heads
    assert embed_dim % num_heads == 0
    self.projection_dim = embed_dim // num_heads

    # Created in q, k, v, o order so seeded initialization matches layer for layer.
    self.W_q, self.W_k, self.W_v, self.W_o = [nn.Linear(embed_dim, embed_dim) for _ in range(4)]

  def transpose(self, x):
    """Split heads: (batch, seq, embed_dim) -> (batch, num_heads, seq, projection_dim)."""
    batch, seq = x.shape[:2]
    split = x.reshape(batch, seq, self.num_heads, self.projection_dim)
    return split.permute(0, 2, 1, 3)

  def transpose_output(self, x):
    """Merge heads: (batch, num_heads, seq, projection_dim) -> (batch, seq, embed_dim)."""
    batch, _, seq, _ = x.shape
    return x.permute(0, 2, 1, 3).reshape(batch, seq, self.embed_dim)

  def forward(self, q, k, v):
    """Attend from q (batch, len_q, embed_dim) over k/v (batch, len_kv, embed_dim).

    Returns a (batch, len_q, embed_dim) tensor. No attention mask is applied.
    """
    heads_q = self.transpose(self.W_q(q))
    heads_k = self.transpose(self.W_k(k))
    heads_v = self.transpose(self.W_v(v))
    merged = self.transpose_output(attention(heads_q, heads_k, heads_v))
    return self.W_o(merged)
  
class TransformerBlock(nn.Module):
  """Post-norm transformer block: self-attention followed by a position-wise MLP.

  Each sub-layer's output goes through dropout, is added back to its input
  (residual connection), and the sum is layer-normalized.

  Args:
    embed_dim: model width.
    num_heads: number of attention heads (must divide embed_dim).
    ff_dim: hidden width of the feed-forward MLP.
    rate: dropout probability applied to each sub-layer output.
  """
  def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
    super().__init__()
    self.att = MultiHeadAttention(embed_dim, num_heads)
    # Build the expand layer before the shrink layer, as before, so init draws line up.
    expand, shrink = nn.Linear(embed_dim, ff_dim), nn.Linear(ff_dim, embed_dim)
    self.ffn = nn.Sequential(expand, nn.ReLU(), shrink)
    self.layernorm1, self.layernorm2 = nn.LayerNorm(embed_dim), nn.LayerNorm(embed_dim)
    self.dropout = nn.Dropout(rate)

  def forward(self, x):
    # Self-attention without a mask: every position can attend to every other one.
    attended = self.att(x, x, x)
    x = self.layernorm1(x + self.dropout(attended))
    return self.layernorm2(x + self.dropout(self.ffn(x)))
  
class TokenAndPositionEmbedding(nn.Module):
  """Sum of a learned token embedding and a learned absolute-position embedding.

  Args:
    maxlen: number of positions that have an embedding; longer inputs raise an
      error inside pos_emb.
    vocab_size: number of distinct token ids.
    embed_dim: width of each embedding vector.
  """
  def __init__(self, maxlen, vocab_size, embed_dim):
    super(TokenAndPositionEmbedding, self).__init__()
    self.token_emb = nn.Embedding(vocab_size, embed_dim)
    self.pos_emb = nn.Embedding(maxlen, embed_dim)

  def forward(self, x):
    """Embed integer tokens x of shape (batch, seq) into (batch, seq, embed_dim)."""
    # Create position ids on x's device so the module still works after
    # model.to(device); a bare torch.arange always lands on the CPU.
    pos = torch.arange(x.size(1), dtype=torch.long, device=x.device)
    # (1, seq, embed_dim) broadcasts across the batch dimension.
    return self.token_emb(x) + self.pos_emb(pos).unsqueeze(0)
  
# Model hyperparameters (previously inline magic numbers).
SEQ_LEN = 6        # each ds_X row holds 6 input digits
VOCAB_SIZE = 10    # every token is a decimal digit 0-9
EMBED_DIM = 128
NUM_HEADS = 4
FF_DIM = 32
LEARNING_RATE = 3e-4
MODEL_SEED = 1337

# Seed before building the model so weight initialization is reproducible.
torch.manual_seed(MODEL_SEED)
# NOTE: the attention layers apply no causal mask, so position t can read input
# token t+1, which is exactly its target for t < 5. Only the final position
# (the ones digit of the sum) has a target that is not visible in the input.
m = nn.Sequential(
  TokenAndPositionEmbedding(SEQ_LEN, VOCAB_SIZE, EMBED_DIM),
  TransformerBlock(EMBED_DIM, NUM_HEADS, FF_DIM),
  TransformerBlock(EMBED_DIM, NUM_HEADS, FF_DIM),
  nn.Linear(EMBED_DIM, VOCAB_SIZE),   # per-position scores over the 10 digits
  nn.LogSoftmax(dim=-1))              # log-probabilities, as F.nll_loss expects
opt = torch.optim.Adam(m.parameters(), lr=LEARNING_RATE)

In [253]:
def num_correct(model=None, X=None, Y=None):
  """Count correct next-digit predictions on held-out data.

  Args:
    model: module mapping integer digit sequences to per-position class scores
      (argmax is taken over dim 2); defaults to the global `m`.
    X, Y: numpy integer arrays of inputs / targets; default to the globals
      `ds_X_test` / `ds_Y_test`.

  Returns:
    (last_correct, rest_correct): 0-dim tensors counting matches at the final
    position (the ones digit of the sum) and at all earlier positions.

  Leaves `model` in eval mode. With the unmasked model built in this notebook,
  every earlier position can see its target as the next input token, so
  rest_correct is expected to saturate and is not a meaningful accuracy.
  """
  model = m if model is None else model
  X = ds_X_test if X is None else X
  Y = ds_Y_test if Y is None else Y
  model.eval()
  # Evaluation needs no autograd graph; skipping it saves memory and time.
  with torch.no_grad():
    pred = model(torch.from_numpy(X)).argmax(dim=2)
  gt = torch.from_numpy(Y)
  return (pred[:, -1] == gt[:, -1]).sum(), (pred[:, :-1] == gt[:, :-1]).sum()

NUM_EPOCHS = 5

for epoch in range(NUM_EPOCHS):
  m.train()
  # Sum (not mean) of the per-batch losses for this epoch. Starting from the
  # float 0.0 turns it into a fresh tensor on the first add, on loss's device,
  # instead of aliasing (and then mutating in place) the first batch's loss.
  total_loss = 0.0
  for X_batch, Y_batch in tqdm(train_loader, desc=f"epoch {epoch}"):
    output = m(X_batch)  # log-probabilities, last dim = 10 digit classes
    loss = F.nll_loss(output.reshape(-1, 10), Y_batch.reshape(-1))
    opt.zero_grad()
    loss.backward()
    opt.step()
    total_loss += loss.detach()
  print(num_correct(), total_loss)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:01<00:00, 135.71it/s]


(tensor(446), tensor(10000)) tensor(176.6125)


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:01<00:00, 139.38it/s]


(tensor(1801), tensor(10000)) tensor(62.5285)


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:01<00:00, 139.10it/s]


(tensor(2000), tensor(10000)) tensor(8.7012)


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:01<00:00, 138.94it/s]


(tensor(2000), tensor(10000)) tensor(1.9596)


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:01<00:00, 138.86it/s]


(tensor(2000), tensor(10000)) tensor(0.7682)
