In [17]:
import copy
import math
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

import time
import wandb

m1 = torch.device("mps")
cpu = torch.device("cpu")

# Transformer Model
https://arxiv.org/pdf/1706.03762.pdf

Super helpful walkthrough: http://nlp.seas.harvard.edu/annotated-transformer/

In [32]:
# chars in dictionary
vocab = 30

# d_model is the same as embedding same for simplicity.
# embs = 12
d_model = 12

# number of chars to see in one window
window = 16

# This will increase one day
batch_size = 5

heads = 3

blocks = 6

In [7]:
prompts = torch.randint(vocab, (batch_size, window))
prompts.shape

torch.Size([5, 16])

In [8]:
class Embeddings(nn.Module):
    def __init__(self, d_model, vocab):
        super(Embeddings, self).__init__()
        self.emb = nn.Embedding(vocab, d_model)
        self.sqrt_d_model = math.sqrt(d_model)

    def forward(self, x):
        return self.emb(x) * self.sqrt_d_model

emb_prompts = Embeddings(d_model, vocab)(prompts)
emb_prompts.shape


torch.Size([5, 16, 12])

In [41]:
class PositionalEncoding(nn.Module):
    ## Thank you! https://jalammar.github.io/illustrated-transformer/
    # for evens:  sin(pos/10000**(2i/embs))
    # for odds:   cos(pos/10000**(2i/embs))
    def __init__(self, d_model, window):
        super(PositionalEncoding, self).__init__()
        
        pos_idxs = torch.arange(0, window).view(-1, 1)
        emb_idxs = torch.arange(0, d_model).view(1, -1)
        
        
        angles = pos_idxs / (10000**(2*emb_idxs / d_model))

        self.pos_enc_tmp = torch.zeros(window, d_model)
        self.pos_enc_tmp[:, 0::2] += torch.sin(angles[:, 0::2])
        self.pos_enc_tmp[:, 1::2] += torch.cos(angles[:, 1::2])

        self.register_buffer("pos_enc", self.pos_enc_tmp)

    def forward(self, x):
        N = x.shape[0]
        # For some reason I thought you concatenated these, but you add instead. Interesting.
        # return torch.cat((x, self.pos_enc.unsqueeze(0).repeat(N, 1, 1)), dim=2)
        return x + self.pos_enc

# pos_enc = PositionalEncoding(embs, window)

# # Plot the positional embeddings
# plt.pcolormesh(pos_enc.pos_enc, cmap='viridis')
# plt.xlabel('Embedding Dimensions')
# plt.xlim((0, embs))
# plt.ylim((window,0))
# plt.ylabel('Char Position')
# plt.colorbar()
# plt.show()


enc_prompt = PositionalEncoding(d_model, window)(emb_prompts)
enc_prompt.shape

torch.Size([5, 16, 12])

In [11]:
def clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

def attention(q, k, v):
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    p_attn = F.softmax(scores, dim=-1)
    return torch.matmul(p_attn, v), p_attn

In [61]:
class GptAttention(nn.Module):
    """
    For this attention module k = v = q are all the same.
    It's for encoder only transfomers.
    """
    def __init__(self, heads, d_model):
        super(GptAttention, self).__init__()
        assert d_model % heads == 0

        self.heads = heads

        ## I've seen combining these into a single linear layer
        # This seems weird to me because k,v,q will cross and change eachother
        #
        self.W = clones(nn.Linear(d_model, 3*d_model), heads)
        # self.Wq = clones(nn.Linear(d_model, d_model), heads)
        # self.Wv = clones(nn.Linear(d_model, d_model), heads)
        # self.Wk = clones(nn.Linear(d_model, d_model), heads)

        self.linear = nn.Linear(d_model * heads, d_model)
    
    def forward(self, x):
        B, window, embs = x.shape
        att_out = []
        for i in range(self.heads):
            q, v, k = self.W[i](x).split(d_model, dim=2)
            # q = self.Wq[i](x)
            # v = self.Wv[i](x)
            # k = self.Wk[i](x)
            att, _ = attention(q, k, v)
            att_out.append(att)
        out = torch.cat(att_out, dim=2)
        
        return self.linear(out)

gpt_attn = GptAttention(6, d_model)
out = gpt_attn(enc_prompt)
print(out.shape)

torch.Size([5, 16, 12])


In [50]:
class FeedForward(nn.Module):
    def __init__(self, d_model):
        super(FeedForward, self).__init__()
        self.l1 = nn.Linear(d_model, 2*d_model)
        self.l2 = nn.Linear(2*d_model, d_model)

    def forward(self, x):
        x = F.relu(self.l1(x))
        return self.l2(x)

In [51]:
class Norm(nn.Module):
    def __init__(self, features, eps=1e-6):
        super(Norm, self).__init__()
        self.a = nn.Parameter(torch.ones(features))
        self.b = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a * (x - mean) / (std + self.eps) + self.b

In [52]:
class Block(nn.Module):
    def __init__(self, d_model, heads):
        super(Block, self).__init__()
        self.attn = GptAttention(heads, d_model)
        self.norm1 = Norm(d_model)
        self.ff = FeedForward(d_model)
        self.norm2 = Norm(d_model)

    def forward(self, x):
        attn = self.attn(x)
        x = self.norm1(x + attn)
        ff = self.ff(x)
        x = self.norm2(x + ff)
        return x

b = Block(d_model, 3)
out = b(emb_prompts)
print(out.shape)

torch.Size([5, 16, 12])


In [67]:
class GPT(nn.Module):
    def __init__(self, d_model, heads, blocks):
        super(GPT, self).__init__()
        self.vocab_emb = Embeddings(d_model, vocab)
        self.pos_emb = PositionalEncoding(d_model, window)
        self.blocks = clones(Block(d_model, heads), blocks)
        self.l_out = nn.Linear(d_model, vocab)

    def forward(self, x):
        x = self.vocab_emb(x)
        x = self.pos_emb(x)
        for b in self.blocks:
            x = b(x)
        x = self.l_out(x)

        return x

    def sample_char(self, x):
        logits = self(x)
        probs = F.softmax(logits, dim=1)
        return torch.multinomial(probs, num_samples=1).item()
        

gpt = GPT(d_model, heads, blocks)

X, Y = torch.squeeze(Xtr[:5]), torch.squeeze(Ytr[:5])

logits = gpt(X)
print(logits.shape)
print(Y.shape)
dev_loss = F.cross_entropy(logits, Y)


torch.Size([5, 16, 30])
torch.Size([5])


RuntimeError: Expected target size [5, 30], got [5]

# Now let's make it run!

In [18]:
names = open('compiled_names.txt', 'r').read().splitlines()

In [19]:
## functions to convert chars to int and inverse

chars = sorted(list(set(''.join(names))))
stoi = {s:i+1 for i,s in enumerate(chars)}

# . is both "before start" in X, and "im done" for Y
stoi['.'] = 0
itos = {s:i for i,s in stoi.items()}

num_char = len(stoi)

In [20]:
def build_dataset(words, device):
    x, y = [], []

    for word in words:
        for i, c in enumerate(word + '.'):
            mini_x = []
            for w in reversed(range(1, window+1)):
                if i - w >= 0:
                    mini_x.append(stoi[word[i-w]])
                else:
                    mini_x.append(stoi['.'])

            x.append(mini_x)
            y.append(stoi[c])
            
    return torch.tensor(x, device=device), torch.tensor(y, device=device)

In [21]:
import random
random.seed(42)
random.shuffle(names)
n1 = int(0.8*len(names))
n2 = int(0.9*len(names))

Xtr, Ytr = build_dataset(names[:n1], device=cpu)
Xdev, Ydev = build_dataset(names[n1:n2], device=cpu)
Xte, Yte = build_dataset(names[n2:], device=cpu)

In [None]:
for i in range(50): 
    print("{} --> {}".format([itos[c.item()] for c in Xtr[i]], itos[Ytr[i].item()]))
       

In [63]:
network = GPT(d_model, heads, blocks)
network.to(m1)
network.train(mode=True)

opt = torch.optim.Adam(network.parameters(), lr=0.001)

steps = []
losses = []
dev_steps = []
dev_losses = []
batch_size = 64
max_steps = 200000
rec_freq = 2000
start_time = time.perf_counter()

for i in range(max_steps+1):
    # sample from training set
    sample_idx = torch.randint(len(Ytr), size=(batch_size,1))
    X, Y = torch.squeeze(Xtr[sample_idx].to(m1)), torch.squeeze(Ytr[sample_idx].to(m1))
    
    # forward
    logits = network(X)
    loss = F.cross_entropy(logits, Y)

    with torch.no_grad():
        opt.zero_grad()
        loss.backward()
        opt.step()

    ## Record data
    steps.append(i)
    losses.append(loss.item())
    
    if i % rec_freq == 0: # print every once in a while
        dev_loss = 0
        with torch.no_grad():
            dev_idx = torch.randint(len(Ydev), size=(batch_size,1))
            X_check, Y_check = torch.squeeze(Xdev[dev_idx].to(m1)), torch.squeeze(Ydev[dev_idx].to(m1))

            dev_loss = F.cross_entropy(network(X_check), Y_check)
            
        current_time = time.perf_counter()
        dt = current_time - start_time
        print(f'{i:7d}/{max_steps:7d}: dt: {dt:.2f} dev_loss: {dev_loss.item():.4f} loss: {loss.item():.4f}')
        
        dev_losses.append(dev_loss.item())
        dev_steps.append(i)


current_time = time.perf_counter()
dt = current_time - start_time
print("total training time: {}".format(dt))


RuntimeError: Expected target size [64, 30], got [64]