# Transfomer model

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import random
import matplotlib.pyplot as plt

In [6]:
# read the dataset
with open('./data/input.txt', 'r') as f:
    text = f.read()

# print out the first 100 characters
print(text[:200])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you


In [8]:
# create a list of all characters in the text
chars = sorted(list(set(text)))
vocab_size = len(chars)
# print out the number of unique characters
print('Number of unique characters: {}'.format(vocab_size))
# print out the characters
print(''.join(chars))

Number of unique characters: 65

 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [10]:
# create a dictionary that maps integers to characters and vice versa
int2char = dict(enumerate(chars))
char2int = {ch: ii for ii, ch in int2char.items()}

print(int2char)
print(char2int)

{0: '\n', 1: ' ', 2: '!', 3: '$', 4: '&', 5: "'", 6: ',', 7: '-', 8: '.', 9: '3', 10: ':', 11: ';', 12: '?', 13: 'A', 14: 'B', 15: 'C', 16: 'D', 17: 'E', 18: 'F', 19: 'G', 20: 'H', 21: 'I', 22: 'J', 23: 'K', 24: 'L', 25: 'M', 26: 'N', 27: 'O', 28: 'P', 29: 'Q', 30: 'R', 31: 'S', 32: 'T', 33: 'U', 34: 'V', 35: 'W', 36: 'X', 37: 'Y', 38: 'Z', 39: 'a', 40: 'b', 41: 'c', 42: 'd', 43: 'e', 44: 'f', 45: 'g', 46: 'h', 47: 'i', 48: 'j', 49: 'k', 50: 'l', 51: 'm', 52: 'n', 53: 'o', 54: 'p', 55: 'q', 56: 'r', 57: 's', 58: 't', 59: 'u', 60: 'v', 61: 'w', 62: 'x', 63: 'y', 64: 'z'}
{'\n': 0, ' ': 1, '!': 2, '$': 3, '&': 4, "'": 5, ',': 6, '-': 7, '.': 8, '3': 9, ':': 10, ';': 11, '?': 12, 'A': 13, 'B': 14, 'C': 15, 'D': 16, 'E': 17, 'F': 18, 'G': 19, 'H': 20, 'I': 21, 'J': 22, 'K': 23, 'L': 24, 'M': 25, 'N': 26, 'O': 27, 'P': 28, 'Q': 29, 'R': 30, 'S': 31, 'T': 32, 'U': 33, 'V': 34, 'W': 35, 'X': 36, 'Y': 37, 'Z': 38, 'a': 39, 'b': 40, 'c': 41, 'd': 42, 'e': 43, 'f': 44, 'g': 45, 'h': 46, 'i': 47,

In [16]:
# create encode and decode functions
encode = lambda text: [char2int[ch] for ch in text]
decode = lambda int_arr: ''.join([int2char[ii] for ii in int_arr])

# encode the text
print(encode('hello'))
print(decode(encode('hello')))

[46, 43, 50, 50, 53]
hello


In [18]:
# encode the whole text
text_data = torch.tensor(encode(text), dtype=torch.long)
print(text_data, text_data.shape)
print(text_data[:30])

tensor([18, 47, 56,  ..., 45,  8,  0]) torch.Size([1115394])
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43])


In [19]:
# split the dataset into train and test sets
train_n = int(text_data.shape[0] * 0.9)
train_data = text_data[:train_n]
test_data = text_data[train_n:]
print(train_data.shape, test_data.shape)

torch.Size([1003854]) torch.Size([111540])


In [21]:
block_size = 8

# show context and target
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f'context: {context}, target: {target}')

context: tensor([18]), target: 47
context: tensor([18, 47]), target: 56
context: tensor([18, 47, 56]), target: 57
context: tensor([18, 47, 56, 57]), target: 58
context: tensor([18, 47, 56, 57, 58]), target: 1
context: tensor([18, 47, 56, 57, 58,  1]), target: 15
context: tensor([18, 47, 56, 57, 58,  1, 15]), target: 47
context: tensor([18, 47, 56, 57, 58,  1, 15, 47]), target: 58


In [31]:
# create a batch generator
torch.manual_seed(1337)
batch_size = 4
block_size = 8

def get_batch(split):
    # generate random starting indices for the batch data
    data = train_data if split == 'train' else test_data
    # get the starting indices for the batch data
    starts = torch.randint(high=data.shape[0] - block_size, size=(batch_size,))
    # get the batch data
    batch_x = [data[start:start+block_size] for start in starts]
    batch_y = [data[start+1:start+block_size+1] for start in starts]
    # convert the list to tensors
    batch_x, batch_y = torch.stack(batch_x), torch.stack(batch_y)

    return batch_x, batch_y

In [32]:
# take a look at the batch data
xb, yb = get_batch('train')
print("inputs: \n", xb.shape)
print(xb)
print("targets: \n", yb.shape)
print(yb)

inputs: 
 torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
targets: 
 torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])


In [35]:
for b in range(batch_size):
    print(f'batch {b}:')
    for t in range(block_size):
        context = xb[b, :t+1]
        target = yb[b, t]
        print(f'context: {context}, target: {target}')

batch 0:
context: tensor([24]), target: 43
context: tensor([24, 43]), target: 58
context: tensor([24, 43, 58]), target: 5
context: tensor([24, 43, 58,  5]), target: 57
context: tensor([24, 43, 58,  5, 57]), target: 1
context: tensor([24, 43, 58,  5, 57,  1]), target: 46
context: tensor([24, 43, 58,  5, 57,  1, 46]), target: 43
context: tensor([24, 43, 58,  5, 57,  1, 46, 43]), target: 39
batch 1:
context: tensor([44]), target: 53
context: tensor([44, 53]), target: 56
context: tensor([44, 53, 56]), target: 1
context: tensor([44, 53, 56,  1]), target: 58
context: tensor([44, 53, 56,  1, 58]), target: 46
context: tensor([44, 53, 56,  1, 58, 46]), target: 39
context: tensor([44, 53, 56,  1, 58, 46, 39]), target: 58
context: tensor([44, 53, 56,  1, 58, 46, 39, 58]), target: 1
batch 2:
context: tensor([52]), target: 58
context: tensor([52, 58]), target: 1
context: tensor([52, 58,  1]), target: 58
context: tensor([52, 58,  1, 58]), target: 46
context: tensor([52, 58,  1, 58, 46]), target: 39


In [36]:
print(xb)

tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])


In [66]:
# bigram model

torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()

        # embedding layer
        # it is a matrix of size vocab_size x vocab_size
        # which serves as a lookup table for the token embeddings
        # what is lookup table?
        # it is a table that maps integers to embeddings
        # right now, it is initialized randomly
        # but it will be trained later and learned from the data
        # here the vector size is the same as the vocab size
        self.token_embeddings = nn.Embedding(vocab_size, vocab_size)

    
    def forward(self, idx, target=None):

        # get the token embeddings  (batch_size, block_size, channel_size)
        # block_size is the number of tokens in the context
        # which is also called the time steps, that's why it is T
        logits = self.token_embeddings(idx) #(B, T, C)

        if target is None:
            loss = None
        else:
            # reshape the embeddings to (batch_size * block_size, channel_size)
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            target = target.view(B * T)
            # compute the loss
            loss = F.cross_entropy(logits, target)
    
        return logits, loss 
    
    def generate(self, idx, max_new_tokens):
        # idx is (batch_size, block_size)  which is (B, T)
        # which means the context for each batch
        for _ in range(max_new_tokens):
            # get the predictions
            # self(idx, target) is the same as self.forward(idx, target)
            logits, loss = self(idx)
            # get the logit for the last token in the context
            # which is the token we want to predict
            # here logits is (batch_size, block_size, channel_size)
            # beacause loss=None, we did not reshape the logits
            logits = logits[:, -1, :]
            # the shape now is (batch_size, channel_size)
            # dim=-1 means the last dimension, which is the channel_size
            # we want to normalize the logits to get the probabilities
            # in the dimension of channel_size
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # append the new token to the context
            idx = torch.cat([idx, idx_next], dim=1)

        return idx

    

# create the model
m1 = BigramLanguageModel(vocab_size)
logits, loss = m1(xb, yb)
print(logits.shape)
print(loss)

# we are trying to generate 10 new tokens
foo = m1.generate(idx=torch.zeros((1, 1), dtype=torch.long), max_new_tokens=10)
print(foo.shape)
print(foo[0])

torch.Size([32, 65])
tensor(4.8786, grad_fn=<NllLossBackward0>)
torch.Size([1, 11])
tensor([ 0, 31, 56, 12, 55, 28,  7, 29, 35, 49, 58])


In [54]:
# let's calculate the log likelihood of the model
# when we randomly initialize the model, the log likelihood is
- np.log(1/len(chars))

4.174387269895637

In [67]:
# create an optimizer
optimizer = torch.optim.Adam(m1.parameters(), lr=1e-3)

In [70]:
# batch size = 32
batch_size = 32

for i in range(5000):

    xb, yb = get_batch('train')

    # forward pass
    logits, loss = m1(xb, yb)
    # clear the gradients
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    # update the parameters
    optimizer.step()

    if i % 300 == 0:
        print(f'iteration {i}, loss {loss:.4f}')

iteration 0, loss 2.6934
iteration 300, loss 2.8176
iteration 600, loss 2.6193
iteration 900, loss 2.5633
iteration 1200, loss 2.4729
iteration 1500, loss 2.5656
iteration 1800, loss 2.5707
iteration 2100, loss 2.5902
iteration 2400, loss 2.5247
iteration 2700, loss 2.5106
iteration 3000, loss 2.4144
iteration 3300, loss 2.4463
iteration 3600, loss 2.5506
iteration 3900, loss 2.4637
iteration 4200, loss 2.4787
iteration 4500, loss 2.4445
iteration 4800, loss 2.4464


In [76]:
foo = m1.generate(idx=torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)
print(decode(foo[0].tolist()))


Bl s my tln an My fe res angor:


WAmell ssty t g sh niene himp pl hte sept pe o: lea, f her ger ft 


## The matematical tricks of self-attention

In [79]:
# simulate the self-attention mechanism
torch.manual_seed
B, T, C = 4, 8, 2  # batch_size, block_size(time), channel_size
foox = torch.randn((B, T, C))
print(foox.shape)

torch.Size([4, 8, 2])


In [80]:
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = foox[b, :t+1]
        xbow[b, t] = xprev.mean(dim=0)


In [82]:
foox[0]

tensor([[-0.3271,  0.5617],
        [-1.7239, -0.7233],
        [ 1.8522,  1.6077],
        [-0.1108, -1.9461],
        [ 0.4432, -0.6504],
        [ 1.7156, -0.4671],
        [-1.4578,  0.9734],
        [ 1.9970, -0.6166]])

In [83]:
xbow[0]

tensor([[-0.3271,  0.5617],
        [-1.0255, -0.0808],
        [-0.0663,  0.4820],
        [-0.0774, -0.1250],
        [ 0.0267, -0.2301],
        [ 0.3082, -0.2696],
        [ 0.0559, -0.0920],
        [ 0.2986, -0.1576]])

In [89]:
torch.manual_seed(1337)
a = torch.tril(torch.ones((3, 3)))
print(a)
b = torch.randint(0, 10, (3, 2)).float()
print(b)
a @ b

a = a/ a.sum(dim=1, keepdim=True)
print(a)
a @ b

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
tensor([[5., 7.],
        [2., 0.],
        [5., 3.]])
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])


tensor([[5.0000, 7.0000],
        [3.5000, 3.5000],
        [4.0000, 3.3333]])

In [91]:
wei = torch.tril(torch.ones((T, T)))
wei = wei / wei.sum(dim=1, keepdim=True)
print(wei)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


In [93]:
xbow2 = wei @ foox
print(xbow2.shape)

torch.Size([4, 8, 2])


In [94]:
torch.allclose(xbow, xbow2)

True

In [98]:
# version 3 using softmax
tril3 = torch.tril(torch.ones((T, T)))
print(tril3)
wei3 = torch.zeros((T, T))
print(wei3)
wei3 = wei3.masked_fill(tril3 == 0, float('-inf'))
print(wei3)
# call softmax
wei3 = F.softmax(wei3, dim=1)
print(wei3)

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0.,

In [99]:
torch.allclose(wei, wei3)

True

In [121]:
# self-attention demo
torch.manual_seed(1337)
B, T, C = 4, 8, 32 # batch_size, block_size(time), channel_size
# channel_size is the embedding size
x = torch.randn((B, T, C))


# let's add query, key, value
# set head_size = 16, which means we have 2 heads
# because 32 / 16 = 2
head_size = 16
query = nn.Linear(C, head_size, bias=False)
key = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

x_q = query(x)  # (B, T, head_size)
x_k = key(x)  # (B, T, head_size)

wei = x_q @ x_k.transpose(1, 2)  # (B, T, head_size) @ (B, head_size, T) = (B, T, T)


# average of the previous tokens
tril = torch.tril(torch.ones((T, T)))
# # wei = torch.zeros((T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
# print(wei.shape)
wei = F.softmax(wei, dim=-1)
# print(wei.shape)
v = value(x)  # (B, T, head_size)
out = wei @ v  # (B, T, T) @ (B, T, head_size) = (B, T, head_size)


In [122]:
wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5877, 0.4123, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4457, 0.2810, 0.2733, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2220, 0.7496, 0.0175, 0.0109, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0379, 0.0124, 0.0412, 0.0630, 0.8454, 0.0000, 0.0000, 0.0000],
        [0.5497, 0.2187, 0.0185, 0.0239, 0.1831, 0.0062, 0.0000, 0.0000],
        [0.2576, 0.0830, 0.0946, 0.0241, 0.1273, 0.3627, 0.0507, 0.0000],
        [0.0499, 0.1052, 0.0302, 0.0281, 0.1980, 0.2657, 0.1755, 0.1474]],
       grad_fn=<SelectBackward0>)