In [1]:
import numpy as np
import matplotlib.pyplot as plt
from IPython import embed

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

torch.manual_seed(666)

<torch._C.Generator at 0x7f8259b56050>

In [41]:
def load_data():
    # Read text as one giant string
    with open("./input.txt", 'r') as f:
        text = f.read()
        f.close()
    text = text[:int(0.1*len(text))]
    text_len = len(text)
    print(text_len)

    # Get all characters in data, (extra, not required->) rearrange so letters come first
    vocab = sorted(list(set(text)))
    vocab = vocab[:1] + vocab[13:] + vocab[1:13]
    vocab_size = len(vocab)

    # Convert between chars and ints
    int2char = {i: vocab[i] for i in range(vocab_size)} # ints -> chars
    char2int = {vocab[i]:i for i in range(vocab_size)}   # chars -> ints
    text_ints = [char2int[text[i]] for i in range(len(text))]
    train = text_ints[:int(0.8*text_len)]
    val = text_ints[int(0.8*text_len):int(0.9*text_len)]
    test = text_ints[int(0.9*text_len):]
    assert len(train)+len(val)+len(test)==text_len
    
    return train,val,test,vocab_size,int2char

    # Check char-int conversion:
    # for i in range(10):
    #     assert int2char[text_ints[i]]==text[i]

    # magic_nums = {}
    # magic_nums["text_len"] = text_len
    # magic_nums["vocab_size"] = vocab_size

In [42]:
class Chunks(Dataset):
    def __init__(self, data, chunk_size):
        super().__init__()
        self.data = data
        self.chunk_size = chunk_size
    def __getitem__(self, indx):
        x = self.data[indx:indx+self.chunk_size]
        y = self.data[indx+1:indx+self.chunk_size+1]
        return torch.tensor(x), torch.tensor(y)
    def __len__(self):
        return len(self.data) - self.chunk_size


# Check for off-by-one index issues:  
# train_set = Chunks(train[:8], chunk_size=4)
# train_loader = DataLoader(train_set, batch_size=1, shuffle=False, drop_last=False)
# print(train_set.data)
# for x,y in train_loader:
#     print(x)
#     print("    " + str(y))
#     print()

In [57]:
from torch import einsum
from einops import rearrange, reduce, repeat

class Projections(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.d = d
        self.Q = nn.Linear(d, d, bias=False)
        self.K = nn.Linear(d, d, bias=False)
        self.V = nn.Linear(d, d, bias=False)
        
    def forward(self, x):
        q = self.Q(x)
        k = self.K(x)
        v = self.V(x)
        return q,k,v
    

class MultiHead(nn.Module):
    def __init__(self, d, n_heads):
        super().__init__()
        self.d = torch.tensor(d)
        self.n_heads = n_heads
        
    def forward(self, q, k, v):
        # Reshape to create heads dimension.
        q = rearrange(q, "b c (n_heads d_head) -> b c n_heads d_head", n_heads=self.n_heads)
        k = rearrange(k, "b c (n_heads d_head) -> b c n_heads d_head", n_heads=self.n_heads)
        v = rearrange(v, "b c (n_heads d_head) -> b n_heads c d_head", n_heads=self.n_heads)
        # Scaled dot product
        qk = einsum("b Q n d, b K n d -> b n Q K", q, k) # capital Q,K refer to sequence length axis
        qk = qk / torch.sqrt(self.d)
        # Future masking every (cq, ck)-shaped square matrix
        mask = torch.ones(qk.shape, device='cuda')
        mask = torch.tril(mask, diagonal=0)
        qk.masked_fill_(mask==0, value=float('-inf'))
        # Softmax to get attention weights
        attention = F.softmax(qk, dim=-1)
        # Apply attention*values
        output = einsum("b n Q K, b n K d -> b n Q d", attention, v)
        # Reshape to remove heads dimension
        output = rearrange(output, "b n cq d -> b cq (n d)")
        return output

class DecoderBlock(nn.Module):
    def __init__(self, d, n_heads, mlp_width=6):
        super().__init__()
        self.d = d
        self.n_heads = n_heads
        self.mlp_width = mlp_width
        
        self.LN1 = nn.LayerNorm(d)
        self.PROJ = Projections(d)
        self.MHA = MultiHead(d, n_heads)
        self.LN2 = nn.LayerNorm(d)
        self.MLP = nn.Sequential(
            nn.Linear(d, d*mlp_width),
            nn.ReLU(),
            nn.Linear(d*mlp_width, d)
        )
        
    def forward(self, x0):
        # Attention
        q,k,v = self.PROJ(self.LN1(x0))
        x1 = self.MHA(q,k,v)
        x1 += x0
        # MLP
        x2 = self.MLP(self.LN2(x1))
        x2 += x1
        return x2
    
class LanguageModel(nn.Module):
    def __init__(self, d, vocab_size, chunk_size, n_heads=6, n_blocks=6):
        super().__init__()
        # Hyperparams
        self.vocab_size = vocab_size
        self.d = d
        self.chunk_size = chunk_size
        self.n_heads = n_heads
        self.n_blocks = n_blocks
        # Initial embeddings
        self.E_pos = nn.Embedding(chunk_size, d)
        self.E_token = nn.Embedding(vocab_size, d)
        # Decoder blocks
        self.DBlocks = nn.ModuleList([DecoderBlock(d, n_heads) for i in range(n_blocks)])
        # Linear map to logits of distribution over vocab
        self.LN_final = nn.LayerNorm(d)
        self.logits = nn.Linear(d, vocab_size)
    def forward(self, x_ints):
        # Embeddings
        x_pos = self.E_pos(torch.arange(x_ints.shape[1], device='cuda'))
        x_pos = rearrange(x_pos, "c d -> 1 c d")
        x_token = self.E_token(x_ints)
        x1 = x_token + x_pos
        # Decoder blocks
        for db in self.DBlocks:
            x1 = db(x1)
        # Linear to logits
        x1 = self.LN_final(x1)
        x2 = self.logits(x1) 
        return x2



### Create model

In [58]:
chunk_size = 256
batch_size = 64
d = 384

train, val, test, vocab_size, int2char = load_data()
train_set = Chunks(train, chunk_size=chunk_size)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=False)
model = LanguageModel(d, vocab_size, chunk_size)
print(model)

lr = 3e-4
epochs = 1
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

111539
LanguageModel(
  (E_pos): Embedding(256, 384)
  (E_token): Embedding(61, 384)
  (DBlocks): ModuleList(
    (0): DecoderBlock(
      (LN1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (PROJ): Projections(
        (Q): Linear(in_features=384, out_features=384, bias=False)
        (K): Linear(in_features=384, out_features=384, bias=False)
        (V): Linear(in_features=384, out_features=384, bias=False)
      )
      (MHA): MultiHead()
      (LN2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (MLP): Sequential(
        (0): Linear(in_features=384, out_features=2304, bias=True)
        (1): ReLU()
        (2): Linear(in_features=2304, out_features=384, bias=True)
      )
    )
    (1): DecoderBlock(
      (LN1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (PROJ): Projections(
        (Q): Linear(in_features=384, out_features=384, bias=False)
        (K): Linear(in_features=384, out_features=384, bias=False)
        (V): Linear(in_featu

### Training loop

In [59]:
def val_eval(model, val, batch_size):
    val_set = Chunks(val, chunk_size=model.chunk_size)
    val_loader = DataLoader(val_set, batch_size=batch_size, drop_last=True)
    val_loss = 0.0
    model.eval()
    with torch.no_grad():
        for x,y in tqdm(val_loader):
            x = x.to('cuda')
            y = y.to('cuda')
            preds = model(x)
            preds = rearrange(preds, "b c vocab_size -> (b c) vocab_size")
            y = rearrange(y, "b c -> (b c)")
            loss = F.cross_entropy(preds, y)
            val_loss += loss.item()
    model.train()
    return val_loss / len(val_loader)
            

In [60]:
from tqdm import tqdm
# Move to GPU
model = model.to('cuda')
model.train()
epoch_losses = []
for epoch in range(epochs):
    train_loss = 0.0
    for x,y in tqdm(train_loader):
        optimizer.zero_grad()
        x = x.to('cuda')
        y = y.to('cuda')
        preds = model(x)
        preds = rearrange(preds, "b c vocab_size -> (b c) vocab_size")
        y = rearrange(y, "b c -> (b c)")
        loss = F.cross_entropy(preds, y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        
    train_loss /= len(train_loader)
    val_loss = val_eval(model, val, batch_size)
    epoch_losses.append((train_loss, val_loss))
    print(f"Epoch #{epoch}:")
    print(f"        Train loss: {train_loss}")
    print(f"        Val   loss: {val_loss}")
    print()

100%|██████████| 1391/1391 [02:22<00:00,  9.76it/s]
100%|██████████| 170/170 [00:06<00:00, 25.84it/s]

Epoch #0:
        Train loss: 1.4098656814596278
        Val   loss: 3.188582059916328






### Generate text

In [63]:
model.eval()
# Generate text
temp = 1.0
context = [0]
context_char = []
with torch.no_grad():
    for i in range(1000):
        if len(context) > model.chunk_size:
            x = torch.tensor(context[-model.chunk_size:])
        else:
            x = torch.tensor(context)
        x = torch.unsqueeze(x, dim=0)
        x = x.to('cuda')
        logits = model(x)
        logits = torch.squeeze(logits[:,-1:,:]) # (B=1, c, d). Take last index along c, remove batch dimension.
        probs = F.softmax(logits/temp)
        # Sampling: show where probability mass is concentrated,
        # take argmax or sample
#         print(torch.topk(probs,k=5))
#         next_int = int(torch.argmax(probs))
        next_int = int(torch.multinomial(probs, num_samples=1))
        context.append(next_int)
    for i in range(len(context)):
        next_char = int2char[context[i]]
        context_char.append(next_char)
print(''.join(context_char))

  probs = F.softmax(logits/temp)




MENENIUS:
Has shall straight his lip and abhugh?

First Citizen:
To lose it forming the to bear abbarbaced!' But,
What corrrouse the may disgrace as whiles
That flatters in his naturers, you change he fight
Wither have but lived more than you necesar
He lead the pand what if hor not ine, and who her
Their vowsest of their complaninings; and their stong
Before Coriolanus he draily is nature, and his blood
To the people, who it show'd upon state,
He weds mine honour and of angatin
He honour in's enreches like about friend
Than Now the foolsh of them.

COMINIUS:
Though there's a lefter
The beggarn of the senate, who could
But field and city 
And braile stinctly: if my were statnd to be with the fault,
But my daughtines bad with a liess lift them
Breal and cortable the people, who ne'er show'd
The dust of aughter, even still was like fatte
To least the city is proud. Here come begin
With here he senaters: and shall his farst,
We care a shim friers; fand gener bed,
I could be like the fla