In [2]:
import numpy as np
import json
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import math
import re
from transformers import BertTokenizer

In [4]:
context_length = 128
batch_size = 8
d_embed = 64

In [191]:
def load_file(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        text = f.read()
    return text

In [192]:
text = load_file('maze_runner.txt')

In [194]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

In [None]:
print(vocab_size)

In [7]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokenized_jokes = tokenizer(jokes, return_tensors='pt', padding=True, truncation=True, max_length=context_length)

In [8]:
vocab_size = tokenizer.vocab_size
print(vocab_size)

30522


In [189]:
START_TOKEN = tokenizer.cls_token_id
END_TOKEN = tokenizer.sep_token_id
PAD_TOKEN = tokenizer.pad_token_id
t = tokenized_jokes['input_ids'][92]
tokenizer.decode(t)

101 102 0


'[CLS] im an avid supporter of the flat earth society i always have heated debate about it with my friend residing in the other hemisphere [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

In [11]:
data = tokenized_jokes['input_ids']

In [12]:
split_index = int(len(jokes) * 0.85)
train_jokes = data[:split_index]
test_jokes = data[split_index:]

In [179]:
def get_batch(split_type=None):
    jokes = train_jokes if split_type == 'train'or split_type == None else val_jokes
    # get a offset between 0 and len(train_jokes) - batch_size - 1
    random_idx = np.random.randint(0, len(jokes) - batch_size - 1)
    x = torch.stack([jokes[random_idx + i] for i in range(0, batch_size)])
    y = torch.stack([torch.cat((jokes[random_idx + i][1:], torch.tensor([PAD_TOKEN]))) for i in range(0, batch_size)])
    return x, y
    

In [180]:
def softmax(x):
    return np.exp(x) / np.sum(np.exp(x), axis=0)

In [181]:
class InputEmbeddings(nn.Module):
    def __init__(self, d_embed : int, vocab_size : int):
        super().__init__()
        self.d_embed = d_embed
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_embed)
    
    def forward(self, x):
        return self.embedding(x) * math.sqrt(self.d_embed)

class PositionalEncoding(nn.Module):
    def __init__(self, d_embed, context_length):
        super().__init__()
        self.d_embed = d_embed
        self.context_length = context_length
        positional_encoding = torch.zeros(context_length, d_embed)
        position_index = torch.arange(0, context_length, dtype=torch.float).unsqueeze(1) # (seq_len, 1)
        denominator = torch.exp(torch.arange(0, d_embed, 2).float() * (-math.log(10000.0) / d_embed))
        positional_encoding[ : , 0::2] = torch.sin(position_index * denominator)
        positional_encoding[ : , 1::2] = torch.cos(position_index * denominator)
        positional_encoding = positional_encoding.unsqueeze(0)
        self.register_buffer('positional_encoding', positional_encoding)
    
    def forward(self, x):
        return (x + self.positional_encoding[: , :x.shape[1], :])

In [182]:
class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(d_embed, head_size, bias=False)
        self.query = nn.Linear(d_embed, head_size, bias=False)
        self.value = nn.Linear(d_embed, head_size, bias=False)
        self.mask = torch.tril(torch.ones(context_length, context_length))
    
    def forward(self, x):
        batch_size, sequence_length, feature_dimension = x.shape
        K = self.key(x)
        Q = self.query(x)
        q_kt = Q @ K.transpose(-2, -1) / np.sqrt(feature_dimension) 
        q_kt = q_kt.masked_fill(self.mask == 0, float('-inf'))
        scaled_qkt = torch.nn.functional.softmax(q_kt, dim=-1)
        V = self.value(x)
        attention = scaled_qkt @ V
        return attention

In [183]:
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for i in range(num_heads)])
        self.linear_layer = nn.Linear(head_size * num_heads, d_embed) # head_size * num_heads = d_embed (usually)

    def forward(self, x):
        head_outputs = torch.cat([head(x) for head in self.heads], dim=-1)
        return self.linear_layer(head_outputs)

In [184]:
class FeedForward(nn.Module):
    def __init__(self, d_embed):
        super().__init__()
        self.linear_layer = nn.Sequential(
            nn.Linear(d_embed, 4 * d_embed),
            nn.ReLU(),
            nn.Linear(4 * d_embed, d_embed)
        )
    
    def forward(self, x):
        return self.linear_layer(x)

In [185]:
class Block(nn.Module):
    def __init__(self, d_embed, num_heads):
        super().__init__()
        head_size = d_embed // num_heads # head_size is "how" much of the embedding is "seen" by each head
        self.multi_head_attention = MultiHeadAttention(head_size, num_heads)
        self.layer_norm1 = nn.LayerNorm(d_embed)
        self.feed_forward_layer = FeedForward(d_embed)
        self.layer_norm2 = nn.LayerNorm(d_embed)
    
    def forward(self, x):
        attention = self.multi_head_attention(x)
        x = self.layer_norm1(x + attention)
        feed_forward = self.feed_forward_layer(x)
        return self.layer_norm2(x + feed_forward)

In [186]:
class Transformer(nn.Module):
    def __init__(self, d_embed, num_heads, num_blocks):
        super().__init__()
        self.input_embeddings = InputEmbeddings(d_embed, vocab_size)
        self.positional_encoding = PositionalEncoding(d_embed, context_length)
        self.blocks = nn.Sequential(*[Block(d_embed, num_heads) for i in range(num_blocks)])
        self.final_layer_norm = nn.LayerNorm(d_embed)
        self.output_layer = nn.Linear(d_embed, vocab_size)
    
    def forward(self, x, y = None):
        batch_size, sequence_length = x.shape
        x_input_embeddings = self.input_embeddings(x)
        x_positional = self.positional_encoding(x_input_embeddings)
        block_out = self.blocks(x_positional)
        layer_norm_out = self.final_layer_norm(block_out)
        logits = self.output_layer(layer_norm_out)
        loss = None
        if y is None:
            return logits, loss
        else:
            loss = nn.CrossEntropyLoss()(logits.view(-1, vocab_size), y.view(-1))
            return logits, loss

In [187]:
model = Transformer(d_embed, 4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

In [188]:
for epoch in range(5000):
    x, y = get_batch('train')
    logits, loss = model(x, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if epoch % 100 == 0:
        print(loss.item())

9.962092399597168
1.1225227117538452
1.003089427947998
0.9844313263893127


KeyboardInterrupt: 