In [5]:
import copy
import math
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

import time
import wandb

import lovely_tensors as lt

lt.monkey_patch()

# Prefer the Apple-silicon GPU (MPS) when it is available, otherwise fall back
# to the CPU so the notebook still runs on machines without MPS support.
m1 = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
cpu = torch.device("cpu")

# Transformer Model
Paper: *Attention Is All You Need* (Vaswani et al., 2017) — https://arxiv.org/pdf/1706.03762.pdf

A very helpful walkthrough: *The Annotated Transformer* — http://nlp.seas.harvard.edu/annotated-transformer/

In [6]:
# --- Model / data hyper-parameters -------------------------------------------
vocab = 30        # number of characters in the dictionary (embedding rows)
d_model = 12      # model width; the embedding size is kept equal for simplicity
window = 16       # number of previous characters the model sees per example
batch_size = 64   # examples per optimisation step (expected to grow one day)
heads = 3         # attention heads per block; must divide d_model
blocks = 6        # number of stacked transformer blocks

In [7]:
# Random token ids shaped like one real batch: (batch_size, window).
prompts = torch.randint(0, vocab, (batch_size, window))
prompts.shape

torch.Size([5, 16])

In [8]:
class GptAttention(nn.Module):
    """
    Causal (masked) multi-head self-attention, as used in decoder-only,
    GPT-style transformers.

    Queries, keys and values are all linear projections of the same input
    ``x``. Each position may only attend to itself and to earlier positions.

    Args:
        heads: number of attention heads; must divide ``d_model``.
        d_model: width of the input and output features.
        window_size: maximum sequence length the causal mask supports.
            Defaults to the notebook-level ``window`` setting.
    """
    def __init__(self, heads, d_model, window_size=None):
        super(GptAttention, self).__init__()
        assert d_model % heads == 0
        self.heads = heads
        self.d_model = d_model

        if window_size is None:
            window_size = window
        self.window_size = window_size

        # A single projection produces the three d_model-wide chunks (q, v, k).
        self.w_attn = nn.Linear(d_model, 3 * d_model)
        # Output projection that mixes the concatenated heads.
        self.head = nn.Linear(d_model, d_model)

        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(window_size, window_size))
            .view(1, 1, window_size, window_size),
        )

    def forward(self, x):
        """
        Args:
            x: tensor of shape (B, T, d_model) with T <= window_size.

        Returns:
            Tensor of shape (B, T, d_model).

        Raises:
            ValueError: if T exceeds the causal mask size.
        """
        B, T, embs = x.shape
        if T > self.window_size:
            raise ValueError(
                f"sequence length {T} exceeds the causal mask size {self.window_size}"
            )

        # Split by this module's own width rather than a global setting.
        # The chunk order (q, v, k) is kept as-is so existing weights keep their meaning.
        q, v, k = self.w_attn(x).split(self.d_model, dim=2)

        head_dim = embs // self.heads
        # (B, T, d_model) -> (B, heads, T, head_dim)
        q = q.view(B, T, self.heads, head_dim).transpose(1, 2)
        k = k.view(B, T, self.heads, head_dim).transpose(1, 2)
        v = v.view(B, T, self.heads, head_dim).transpose(1, 2)

        # Scaled dot product: (B, heads, T, head_dim) x (B, heads, head_dim, T) -> (B, heads, T, T)
        scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
        # Hide future positions so that softmax gives them zero weight.
        masked_scores = scores.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        probs = F.softmax(masked_scores, dim=-1)
        attn = probs @ v
        # Re-assemble the heads: (B, heads, T, head_dim) -> (B, T, d_model)
        attn = attn.transpose(1, 2).contiguous().view(B, T, embs)

        return self.head(attn)

# gpt_attn = GptAttention(heads, d_model)
# out = gpt_attn(enc_prompt)
# print(out.shape)

In [9]:
class FeedForward(nn.Module):
    """
    Position-wise feed-forward network: Linear -> ReLU -> Linear.

    Args:
        d_model: input and output feature width.
        expansion: hidden-layer width as a multiple of ``d_model``.
            Defaults to 2, the value this notebook has always used.
    """
    def __init__(self, d_model, expansion=2):
        super(FeedForward, self).__init__()
        hidden = expansion * d_model
        self.l1 = nn.Linear(d_model, hidden)
        self.l2 = nn.Linear(hidden, d_model)

    def forward(self, x):
        # Acts on the last dimension only: (..., d_model) -> (..., d_model).
        return self.l2(F.relu(self.l1(x)))

In [10]:
class Block(nn.Module):
    """
    One transformer layer: causal self-attention followed by a feed-forward
    network. Each sub-layer is wrapped in a residual connection, and
    LayerNorm is applied after the addition (post-norm, as in the original paper).
    """
    def __init__(self, d_model, heads):
        super(Block, self).__init__()
        self.attn = GptAttention(heads, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.ff = FeedForward(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        hidden = self.norm1(x + self.attn(x))
        return self.norm2(hidden + self.ff(hidden))

# b = Block(d_model, 3)
# out = b(emb_prompts)
# print(out.shape)

In [21]:
class GPT(nn.Module):
    """
    Decoder-only, character-level transformer.

    Token embeddings plus learned positional embeddings feed a stack of
    ``blocks`` transformer blocks. A final linear layer then produces logits
    over the ``vocab`` characters at every position.

    Args:
        d_model: embedding / model width.
        heads: attention heads per block.
        blocks: number of stacked blocks.
    """
    def __init__(self, d_model, heads, blocks):
        super(GPT, self).__init__()
        self.vocab_emb = nn.Embedding(vocab, d_model)
        self.pos_emb = nn.Embedding(window, d_model)

        self.blocks = nn.ModuleList([Block(d_model, heads) for _ in range(blocks)])
        self.head = nn.Linear(d_model, vocab)

    def forward(self, x):
        """
        Args:
            x: integer token ids of shape (B, T), with T <= window.

        Returns:
            Logits of shape (B, T, vocab).
        """
        # Create position ids on the input's own device instead of a hard-coded
        # one, so the model works on CPU, MPS or CUDA alike.
        positions = torch.arange(0, x.shape[1], device=x.device)
        h = self.vocab_emb(x) + self.pos_emb(positions)

        for block in self.blocks:
            h = block(h)

        return self.head(h)

    @torch.no_grad()
    def sample_char(self, x):
        """
        Sample the next token id after the last position of ``x``.

        ``.item()`` requires a single sample, so ``x`` must contain exactly
        one sequence, i.e. have shape (1, T).

        Returns:
            int: the sampled token id.
        """
        logits = self(x)
        probs = F.softmax(logits[:, -1, :], dim=-1)
        return torch.multinomial(probs, num_samples=1).item()
        

# gpt = GPT(d_model, heads, blocks)

# X, Y = torch.squeeze(Xtr[:5]), torch.squeeze(Ytr[:5])

# logits = gpt(X)
# print(logits[:,-1,:])
# # out = logits[:,-1,:].view(-1, logits.size(-1))
# print(out)
# print(Y)
# # pluck 
# dev_loss = F.cross_entropy(logits[:,-1,:], Y)


# Now let's make it run!

In [13]:
## functions to convert chars to int and inverse

# `names` is not defined by any cell above this one. Load it here from the same
# file NameData uses, so the notebook survives Restart & Run All.
with open('compiled_names.txt', 'r') as names_file:
    names = names_file.read().splitlines()

chars = sorted(set(''.join(names)))
# Real characters get ids 1..len(chars); id 0 is reserved for '.'.
stoi = {s: i + 1 for i, s in enumerate(chars)}

# . is both "before start" in X, and "im done" for Y
stoi['.'] = 0
itos = {i: s for s, i in stoi.items()}

num_char = len(stoi)
# The embedding tables have `vocab` rows, so every id must be smaller than that.
assert max(stoi.values()) < vocab, f"vocab={vocab} is too small for {num_char} characters"

In [14]:
def build_dataset(words, device):
    """
    Turn a list of names into (context, next-character) training pairs.

    For every position of ``word + '.'``, the context holds the preceding
    ``window`` characters of ``word``, left-padded with the '.' id.

    Returns:
        (X, Y): X holds one ``window``-long list of context ids per example,
        and Y the matching target id. Both are tensors placed on ``device``.
    """
    pad_id = stoi['.']
    contexts, targets = [], []

    for word in words:
        ids = [stoi[ch] for ch in word]
        for pos, target in enumerate(word + '.'):
            history = ids[max(0, pos - window):pos]
            contexts.append([pad_id] * (window - len(history)) + history)
            targets.append(stoi[target])

    return torch.tensor(contexts, device=device), torch.tensor(targets, device=device)

In [15]:
import random

# Shuffle a copy rather than `names` itself. Shuffling in place means every
# re-run of this cell reshuffles an already-shuffled list and silently changes
# the train/dev/test split.
random.seed(42)
shuffled_names = list(names)
random.shuffle(shuffled_names)
n1 = int(0.8 * len(shuffled_names))
n2 = int(0.9 * len(shuffled_names))

# 80% train / 10% dev / 10% test
Xtr, Ytr = build_dataset(shuffled_names[:n1], device=cpu)
Xdev, Ydev = build_dataset(shuffled_names[n1:n2], device=cpu)
Xte, Yte = build_dataset(shuffled_names[n2:], device=cpu)

In [16]:
for idx in range(2):
    context = [itos[c.item()] for c in Xtr[idx]]
    print("{} --> {}".format(context, itos[Ytr[idx].item()]))
       

['.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.'] --> c
['.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', 'c'] --> o


In [22]:
network = GPT(d_model, heads, blocks)
network.to(m1)
network.train(mode=True)

opt = torch.optim.Adam(network.parameters(), lr=0.001)

steps = []
losses = []
dev_steps = []
dev_losses = []
batch_size = 64
max_steps = 200000
rec_freq = 2000
start_time = time.perf_counter()

for i in range(max_steps + 1):
    # Sample a minibatch with a 1-D index: X is (batch_size, window) and Y is
    # (batch_size,). No squeeze is needed; squeezing would also drop the batch
    # dimension when batch_size == 1.
    sample_idx = torch.randint(len(Ytr), size=(batch_size,))
    X, Y = Xtr[sample_idx].to(m1), Ytr[sample_idx].to(m1)

    # forward: predict the character that follows the last context position
    logits = network(X)
    loss = F.cross_entropy(logits[:, -1, :], Y)

    # backward + update. No torch.no_grad() is needed here: the graph was
    # already built, and the optimizer step does not track gradients.
    opt.zero_grad()
    loss.backward()
    opt.step()

    ## Record data
    steps.append(i)
    losses.append(loss.item())

    if i % rec_freq == 0:  # print every once in a while
        # Estimate the dev loss on one random dev minibatch (cheap but noisy).
        network.eval()
        with torch.no_grad():
            dev_idx = torch.randint(len(Ydev), size=(batch_size,))
            X_check, Y_check = Xdev[dev_idx].to(m1), Ydev[dev_idx].to(m1)
            dev_loss = F.cross_entropy(network(X_check)[:, -1, :], Y_check)
        network.train()

        current_time = time.perf_counter()
        dt = current_time - start_time
        print(f'{i:7d}/{max_steps:7d}: dt: {dt:.2f} dev_loss: {dev_loss.item():.4f} loss: {loss.item():.4f}')

        dev_losses.append(dev_loss.item())
        dev_steps.append(i)


current_time = time.perf_counter()
dt = current_time - start_time
print("total training time: {}".format(dt))


      0/ 200000: dt: 3.90 dev_loss: 3.5128 loss: 3.6062
   2000/ 200000: dt: 298.98 dev_loss: 2.5893 loss: 2.6829
   4000/ 200000: dt: 586.79 dev_loss: 2.6213 loss: 2.3875
   6000/ 200000: dt: 869.60 dev_loss: 2.9032 loss: 2.6299
   8000/ 200000: dt: 1150.57 dev_loss: 2.4862 loss: 2.3863
  10000/ 200000: dt: 1431.09 dev_loss: 2.6333 loss: 2.3692
  12000/ 200000: dt: 1711.92 dev_loss: 2.3036 loss: 2.3886
  14000/ 200000: dt: 1970.06 dev_loss: 2.1615 loss: 2.2835
  16000/ 200000: dt: 2225.28 dev_loss: 2.4631 loss: 2.3602
  18000/ 200000: dt: 2477.81 dev_loss: 2.2285 loss: 2.2893
  20000/ 200000: dt: 2731.91 dev_loss: 2.3686 loss: 2.6120
  22000/ 200000: dt: 2985.90 dev_loss: 2.6159 loss: 2.6038
  24000/ 200000: dt: 3239.59 dev_loss: 2.3325 loss: 2.4089
  26000/ 200000: dt: 3493.51 dev_loss: 2.6571 loss: 2.4474
  28000/ 200000: dt: 3747.53 dev_loss: 2.1306 loss: 2.3210
  30000/ 200000: dt: 4001.67 dev_loss: 2.1593 loss: 2.2871
  32000/ 200000: dt: 4255.47 dev_loss: 2.4880 loss: 2.3838
  3

In [23]:
from pytorch_lightning import LightningModule

In [29]:
import random
from torch.utils.data import DataLoader

class NameDataLoader():
    """
    Map-style dataset of (context, next-character) pairs built from names.

    Despite its name, this is not a ``DataLoader``; it is meant to be wrapped
    in ``torch.utils.data.DataLoader``. Item ``i`` is ``(X[i], Y[i])``, where
    ``X[i]`` holds the ``window`` preceding character ids (left-padded with
    the '.' id) and ``Y[i]`` is the id of the character that follows.
    """
    def __init__(self, words):
        self.X, self.Y = self._build_dataset(words)

    def __getitem__(self, index: int):
        return self.X[index], self.Y[index]

    def __len__(self) -> int:
        return len(self.Y)

    def _build_dataset(self, words):
        pad = stoi['.']
        xs, ys = [], []
        for word in words:
            encoded = [stoi[ch] for ch in word]
            for pos, nxt in enumerate(word + '.'):
                prev = encoded[max(0, pos - window):pos]
                xs.append([pad] * (window - len(prev)) + prev)
                ys.append(stoi[nxt])
        return torch.tensor(xs), torch.tensor(ys)


class NameData():
    """
    Loads names from a text file (one per line), shuffles them with a fixed
    seed and splits them 80/10/10 into train/dev/test ``NameDataLoader``s.

    Args:
        name_txt_path: path of the newline-separated names file.
    """
    def __init__(self, name_txt_path):
        # Close the file deterministically instead of leaking the handle.
        with open(name_txt_path, 'r') as f:
            self.names = f.read().splitlines()

        # A private RNG seeded with 42 gives the same permutation as
        # `random.seed(42); random.shuffle(...)` without resetting the global
        # `random` state as a side effect.
        random.Random(42).shuffle(self.names)
        n1 = int(0.8 * len(self.names))
        n2 = int(0.9 * len(self.names))

        self.train = NameDataLoader(self.names[:n1])
        self.dev = NameDataLoader(self.names[n1:n2])
        self.test = NameDataLoader(self.names[n2:])

    def train_data_loader(self):
        return self.train

    def test_data_loader(self):
        return self.test

    def val_data_loader(self):
        return self.dev


data = NameData('compiled_names.txt')


In [None]:
class LitSurnames(LightningModule):
    """
    PyTorch Lightning wrapper around ``GPT`` for next-character prediction.

    Args:
        data: a ``NameData`` instance. When ``None``, it is loaded from
            'compiled_names.txt'.
        d_model, heads, blocks: forwarded to ``GPT``. Defaults match the
            notebook's hyper-parameter cell.
        batch_size: minibatch size for all dataloaders.
        lr: Adam learning rate. Defaults to the manual training loop's value.
    """
    def __init__(self, data=None, d_model=12, heads=3, blocks=6, batch_size=64, lr=1e-3):
        super().__init__()

        # save hyper params
        self.d_model = d_model
        self.heads = heads
        self.blocks = blocks
        self.batch_size = batch_size
        self.lr = lr

        # The loaded data must be stored on `self`; previously the None branch
        # assigned it to a local variable, leaving `self.data` unset.
        self.data = data if data is not None else NameData('compiled_names.txt')

        self.model = GPT(
            self.d_model,
            self.heads,
            self.blocks
        )

    def forward(self, x):
        return self.model(x)

    def _shared_step(self, batch):
        """Cross-entropy of the prediction made at the last context position."""
        x, y = batch
        logits = self(x)
        return F.cross_entropy(logits[:, -1, :], y)

    def training_step(self, batch, batch_idx):
        # Lightning's automatic optimization runs zero_grad/backward/step with
        # the optimizer from `configure_optimizers`. Doing it by hand here
        # stepped the unrelated global `opt` instead.
        loss = self._shared_step(batch)
        self.log('tr_loss', loss)
        return loss

    def test_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log('test_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        pass

    def setup(self, stage=None):
        pass

    def train_dataloader(self):
        # Shuffle training examples: consecutive items come from the same name.
        return DataLoader(self.data.train_data_loader(), batch_size=self.batch_size, shuffle=True)

    def test_dataloader(self):
        return DataLoader(self.data.test_data_loader(), batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.data.val_data_loader(), batch_size=self.batch_size)

# WandbLogger was used here without ever being imported.
from pytorch_lightning.loggers import WandbLogger

run = wandb.init(project="surnamerator", reinit=True)
wandb_logger = WandbLogger()

# Reuse the NameData built earlier (`data = NameData('compiled_names.txt')`);
# the previous `data, = None` raised TypeError because None cannot be unpacked.
# The hyper-parameters (vocab, d_model, window, batch_size, heads, blocks) come
# from the configuration cell near the top of the notebook. They were
# re-declared here with identical values, so the duplicates are removed.
lit_surname = LitSurnames(data, d_model, heads, blocks, batch_size)
