In [152]:
import numpy as np
import json
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import math
import re
from transformers import BertTokenizer

In [399]:
# Notebook-level hyperparameters read by the cells below.
context_length = 128  # passed as max_length to the tokenizer; also the side of the causal mask in Head
batch_size = 8  # number of jokes stacked into one batch by get_batch
d_embed = 64  # embedding width used by the embedding and attention layers

In [400]:
def load_jokes(file_path):
    """Read a text file that stores one joke per line.

    Args:
        file_path: Path of the text file to read.

    Returns:
        list[str]: One entry per line of the file, in file order, with the
        newline characters removed. Blank lines are kept as empty strings so
        that list positions still match line numbers.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(file_path, 'r', encoding='utf-8') as file:
        # A plain str.replace is enough here; no regular expression is needed.
        return [line.replace('\n', '') for line in file]

In [401]:
# Read the corpus: load_jokes returns one string per line of jokes.txt.
jokes = load_jokes('jokes.txt')

In [None]:
# Tokenize every joke with the pretrained 'bert-base-uncased' vocabulary and
# return PyTorch tensors. Shorter jokes are padded (see the [PAD] tokens in the
# decoded example further down) and max_length is capped at context_length.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokenized_jokes = tokenizer(jokes, return_tensors='pt', padding=True, truncation=True, max_length=context_length)

In [422]:
# Size of the tokenizer vocabulary; passed to InputEmbeddings as the number of embedding rows.
vocab_size = tokenizer.vocab_size
print(vocab_size)

30522


In [423]:
# Special-token ids taken from the tokenizer ([CLS], [SEP] and [PAD]).
START_TOKEN = tokenizer.cls_token_id
END_TOKEN = tokenizer.sep_token_id
PAD_TOKEN = tokenizer.pad_token_id
# Sanity check: decode one tokenized joke back to text (index 92 is presumably an arbitrary sample).
t = tokenized_jokes['input_ids'][92]
tokenizer.decode(t)

'[CLS] im an avid supporter of the flat earth society i always have heated debate about it with my friend residing in the other hemisphere [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

In [424]:
# Keep only the token-id tensor; the other tokenizer outputs are not used below.
data = tokenized_jokes['input_ids']

In [425]:
# Hold out the last 15% of the jokes. Rows are split in file order (no shuffling).
split_index = int(len(jokes) * 0.85)
train_jokes = data[:split_index]
test_jokes = data[split_index:]

In [426]:
def get_batch(split_type=None):
    """Sample a batch of consecutive jokes together with next-token targets.

    Reads the notebook-level ``train_jokes``, ``test_jokes``, ``batch_size``
    and ``PAD_TOKEN``.

    Args:
        split_type: ``'train'`` or ``None`` selects ``train_jokes``; any other
            value selects the held-out ``test_jokes``.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(x, y)``, each holding
        ``batch_size`` rows. ``y`` is ``x`` shifted one position to the left,
        with ``PAD_TOKEN`` filling the last column.

    Raises:
        ValueError: If the selected split has fewer than ``batch_size`` rows.
    """
    # The held-out split is called ``test_jokes`` where the data is split; the
    # name ``val_jokes`` used here before was never defined.
    jokes = train_jokes if split_type == 'train' or split_type is None else test_jokes
    if len(jokes) < batch_size:
        raise ValueError(
            f"split has {len(jokes)} rows, fewer than batch_size={batch_size}"
        )
    # randint's upper bound is exclusive, so ``+ 1`` lets the window of
    # batch_size consecutive jokes reach the very last joke of the split.
    random_idx = np.random.randint(0, len(jokes) - batch_size + 1)
    x = torch.stack([jokes[random_idx + i] for i in range(batch_size)])
    # Targets are the inputs shifted left by one token; pad the freed last slot.
    pad_column = torch.full((x.shape[0], 1), PAD_TOKEN, dtype=x.dtype, device=x.device)
    y = torch.cat((x[:, 1:], pad_column), dim=1)
    return x, y
    

In [429]:
# Draw one training batch and, for its first two rows, print the input ids and
# the target ids with padding filtered out, to eyeball the one-token shift.
example_x, example_y = get_batch('train')
print(example_x.shape, example_y.shape)
for i in range(2):
    print(example_x[i][example_x[i] != PAD_TOKEN])
    print(example_y[i][example_y[i] != PAD_TOKEN])
    print()

torch.Size([8, 128]) torch.Size([8, 128])
tensor([  101,  2026,  3144,  6090, 17955,  2449,  2001,  3728,  3844,  2091,
         2619, 11182,  2125,  1996,  2610,  2008,  1045,  2001,  4855,  2068,
         2980,   102])
tensor([ 2026,  3144,  6090, 17955,  2449,  2001,  3728,  3844,  2091,  2619,
        11182,  2125,  1996,  2610,  2008,  1045,  2001,  4855,  2068,  2980,
          102])

tensor([  101,  2065,  3607, 18445,  2015,  4977,  2013,  1996,  4373,  2052,
         5483,  2393,   102])
tensor([ 2065,  3607, 18445,  2015,  4977,  2013,  1996,  4373,  2052,  5483,
         2393,   102])



In [275]:
def softmax(x, axis=0):
    """Numerically stable softmax of a NumPy array.

    Args:
        x: Array-like of scores.
        axis: Axis along which the outputs sum to one. Defaults to 0, the
            axis this function has always normalised over.

    Returns:
        numpy.ndarray: Array of the same shape as ``x`` whose entries along
        ``axis`` are non-negative and sum to one. An empty input yields an
        empty array.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalise; np.max would raise on an empty reduction.
        return np.exp(x)
    # Subtracting the maximum leaves the result unchanged mathematically but
    # keeps exp() from overflowing to inf (which turned the ratio into nan).
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=axis, keepdims=True)

In [439]:
class InputEmbeddings(nn.Module):
    """Token-id lookup table whose output is scaled by sqrt(d_embed)."""

    def __init__(self, d_embed : int, vocab_size: int):
        """Create a ``vocab_size`` x ``d_embed`` embedding table.

        Args:
            d_embed: Width of each embedding vector.
            vocab_size: Number of distinct token ids.
        """
        super().__init__()
        self.d_embed, self.vocab_size = d_embed, vocab_size
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_embed)

    def forward(self, x):
        """Return the embeddings of the ids in ``x`` multiplied by sqrt(d_embed)."""
        scale = math.sqrt(self.d_embed)
        token_vectors = self.embedding(x)
        return token_vectors * scale

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a batch of embeddings.

    For position ``p`` and even feature index ``2i`` the table holds
    ``sin(p / 10000**(2i / d_embed))``; the following odd index holds the
    matching cosine.
    """

    def __init__(self, d_embed: int, seq_len: int, dropout: float):
        """Precompute the encoding table.

        Args:
            d_embed: Width of the embeddings the encoding is added to.
            seq_len: Largest sequence length the table covers.
            dropout: Dropout probability applied after the addition.
        """
        super().__init__()
        self.d_embed = d_embed
        self.seq_len = seq_len
        # Keep the raw probability under its original attribute name and put
        # the actual layer next to it.
        self.dropout = dropout
        self.dropout_layer = nn.Dropout(dropout)

        # One row per position, one column per feature, so the table lines up
        # with (batch, seq_len, d_embed) inputs.
        positional_encoding = torch.zeros(seq_len, d_embed) # (seq_len, d_embed)
        position_index = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1) # (seq_len, 1)
        # 1 / 10000**(2i / d_embed), evaluated in log space.
        denominator = torch.exp(
            torch.arange(0, d_embed, 2).float() * (-math.log(10000.0) / d_embed)
        )
        angles = position_index * denominator # (seq_len, ceil(d_embed / 2))
        positional_encoding[:, 0::2] = torch.sin(angles)
        # With an odd d_embed there is one odd column fewer than even columns.
        positional_encoding[:, 1::2] = torch.cos(angles[:, : d_embed // 2])
        # A non-persistent buffer follows .to(device) without adding a key to
        # the state_dict.
        self.register_buffer('pe', positional_encoding.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return ``dropout(x + encoding)``.

        Args:
            x: Tensor whose dimension 1 is the sequence position and whose
                last dimension has size ``d_embed``.

        Raises:
            ValueError: If the sequence is longer than ``seq_len``.
        """
        sequence_length = x.shape[1]
        if sequence_length > self.seq_len:
            raise ValueError(
                f"sequence length {sequence_length} exceeds seq_len={self.seq_len}"
            )
        return self.dropout_layer(x + self.pe[:, :sequence_length, :])

In [441]:
# Scratch check: the even indices 0, 2, ..., d_embed - 2 used by the positional-encoding denominator.
# NOTE(review): this rebinds `t`, which an earlier cell used for a tokenized joke.
t = torch.arange(0, d_embed, 2)
t

tensor([ 0,  2,  4,  6,  8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34,
        36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62])

In [440]:
# Instantiate the embedding layers with the notebook-level hyperparameters.
input_embedding = InputEmbeddings(d_embed, vocab_size)
positional_encoding = PositionalEncoding(d_embed, context_length, 0.1)  # 0.1 is the dropout argument

# Embed the example batch; positional_encoding is constructed above but not applied in this cell.
x_example_embed = input_embedding(example_x)
print(x_example_embed.shape)

torch.Size([8, 128, 64])


In [16]:
# Random feature tensor for trying out the attention layers below.
# The last dimension uses d_embed: the bare name `d` used here before is not
# defined anywhere in the notebook and fails on a fresh kernel.
x = torch.rand(16, 10, d_embed)

In [46]:
class Head(nn.Module):
    """Single causal (masked) self-attention head.

    Reads the notebook-level ``d_embed`` for the input width and
    ``context_length`` for the largest sequence the causal mask covers.
    """

    def __init__(self, head_size):
        """Create the key/query/value projections and the causal mask.

        Args:
            head_size: Output width of the key, query and value projections.
        """
        super().__init__()
        self.key = nn.Linear(d_embed, head_size, bias=False)
        self.query = nn.Linear(d_embed, head_size, bias=False)
        self.value = nn.Linear(d_embed, head_size, bias=False)
        # Registered as a buffer so the mask follows the module across
        # .to(device); non-persistent so the state_dict keys are unchanged.
        self.register_buffer(
            'mask',
            torch.tril(torch.ones(context_length, context_length)),
            persistent=False,
        )

    def forward(self, x):
        """Apply masked scaled dot-product attention.

        Args:
            x: Tensor of shape (batch, sequence_length, d_embed).

        Returns:
            torch.Tensor: Shape (batch, sequence_length, head_size). Position
            ``i`` only attends to positions ``<= i``.

        Raises:
            ValueError: If the sequence is longer than the mask.
        """
        _, sequence_length, _ = x.shape
        if sequence_length > self.mask.shape[0]:
            raise ValueError(
                f"sequence length {sequence_length} exceeds the mask size {self.mask.shape[0]}"
            )
        K = self.key(x)
        Q = self.query(x)
        # Scale by the key dimension (head_size), not by the input width.
        q_kt = Q @ K.transpose(-2, -1) / (K.shape[-1] ** 0.5)
        # Crop the mask so inputs shorter than context_length are accepted.
        causal_mask = self.mask[:sequence_length, :sequence_length]
        q_kt = q_kt.masked_fill(causal_mask == 0, float('-inf'))
        similarity = torch.nn.functional.softmax(q_kt, dim=-1)
        V = self.value(x)
        attention = similarity @ V
        return attention

In [47]:
class MultiHeadAttention(nn.Module):
    """Run several ``Head`` modules side by side and project their concatenation to ``d_embed``."""

    def __init__(self, head_size, num_heads):
        """Build ``num_heads`` heads of width ``head_size`` plus the output projection.

        Args:
            head_size: Width passed to every ``Head``.
            num_heads: Number of heads to create.
        """
        super().__init__()
        parallel_heads = []
        for _ in range(num_heads):
            parallel_heads.append(Head(head_size))
        self.heads = nn.ModuleList(parallel_heads)
        concat_width = head_size * num_heads
        self.linear_layer = nn.Linear(concat_width, d_embed)

    def forward(self, x):
        """Concatenate every head's output along the last dimension, then project it."""
        per_head = tuple(attend(x) for attend in self.heads)
        merged = torch.cat(per_head, dim=-1)
        return self.linear_layer(merged)

In [None]:
class Block(nn.Module):
    """Transformer block: multi-head attention and a feed-forward network,
    each followed by a residual connection and layer normalisation.

    Note: ``MultiHeadAttention`` takes its input and output width from the
    notebook-level ``d_embed``, so the ``d_embed`` argument given here should
    match that value.
    """

    def __init__(self, d_embed, num_heads):
        """Build the attention and feed-forward sub-layers.

        Args:
            d_embed: Width of the block's input and output features.
            num_heads: Number of attention heads; each head gets
                ``d_embed // num_heads`` features.
        """
        super().__init__()
        head_size = d_embed // num_heads
        self.multi_head_attention = MultiHeadAttention(head_size, num_heads)
        self.layer_norm1 = nn.LayerNorm(d_embed)
        # nn.Sequential expects modules; it was given the bare integer d_embed
        # before, which fails at construction. Use the usual position-wise
        # network: expand to 4 * d_embed, apply ReLU, project back.
        self.feed_forward_layer = nn.Sequential(
            nn.Linear(d_embed, 4 * d_embed),
            nn.ReLU(),
            nn.Linear(4 * d_embed, d_embed),
        )
        self.layer_norm2 = nn.LayerNorm(d_embed)

    def forward(self, x):
        """Return the block output; it has the same shape as ``x``."""
        attention = self.multi_head_attention(x)
        x = self.layer_norm1(x + attention)
        feed_forward = self.feed_forward_layer(x)
        return self.layer_norm2(x + feed_forward)