In [1]:
# Reload imported modules automatically before each cell executes, so edits
# to the local model code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from model import *

### Input and Output Shapes

In [3]:
# Input dimensions (deliberately odd, pairwise-distinct values, which makes
# every dimension easy to tell apart in the shapes printed below)
batch_size = 27
c = 11  # second dimension of every input tensor below (positions per sequence)

# Model parameters (same idea: no two sizes coincide)
d_model = 64
d_head = 17
n_heads = 5
d_inner = 123
vocab_size = 23
n_blocks = 6

In [4]:
# Attention Head
# Shape check: (batch_size, c, d_model) -> (batch_size, c, d_head).

attn_head = AttentionHead(d_model=d_model, d_head=d_head)

x = torch.randn(batch_size, c, d_model)
o = attn_head(x)

print(f"{x.shape} -> {o.shape}")

# Fail loudly if the model code regresses, instead of relying on someone
# eyeballing the printed shapes.
assert o.shape == (batch_size, c, d_head), f"unexpected output shape {tuple(o.shape)}"

torch.Size([27, 11, 64]) -> torch.Size([27, 11, 17])


In [5]:
# Multi-Head Attention
# Shape check: the output is expected to have the same shape as the input,
# (batch_size, c, d_model).

mha = MultiHeadAttention(d_model=d_model, n_heads=n_heads, d_head=d_head)

x = torch.randn(batch_size, c, d_model)
x_new = mha(x)

print(f"{x.shape} -> {x_new.shape}")

# Fail loudly if the model code regresses, instead of relying on someone
# eyeballing the printed shapes.
assert x_new.shape == x.shape, f"unexpected output shape {tuple(x_new.shape)}"

torch.Size([27, 11, 64]) -> torch.Size([27, 11, 64])


In [6]:
# Feed Forward
# Shape check: the output is expected to have the same shape as the input,
# (batch_size, c, d_model); d_inner only affects the hidden width.

ff = FeedForward(d_model=d_model, d_inner=d_inner)

x = torch.randn(batch_size, c, d_model)
x_new = ff(x)

print(f"{x.shape} -> {x_new.shape}")

# Fail loudly if the model code regresses, instead of relying on someone
# eyeballing the printed shapes.
assert x_new.shape == x.shape, f"unexpected output shape {tuple(x_new.shape)}"

torch.Size([27, 11, 64]) -> torch.Size([27, 11, 64])


In [7]:
# Basic Building Block
# Shape check: the output is expected to have the same shape as the input,
# (batch_size, c, d_model).

b = Block(d_model=d_model, n_heads=n_heads, d_head=d_head, d_inner=d_inner)

x = torch.randn(batch_size, c, d_model)
x_new = b(x)

print(f"{x.shape} -> {x_new.shape}")

# Fail loudly if the model code regresses, instead of relying on someone
# eyeballing the printed shapes.
assert x_new.shape == x.shape, f"unexpected output shape {tuple(x_new.shape)}"

torch.Size([27, 11, 64]) -> torch.Size([27, 11, 64])


In [8]:
# Language Model
# Shape check: token indices (batch_size, c) -> logits (batch_size, c, vocab_size).

lm = LanguageModel(vocab_size=vocab_size, d_model=d_model, n_heads=n_heads, n_blocks=n_blocks)

idx = torch.randint(0, vocab_size, (batch_size, c))
x_new, _ = lm(idx)  # the second return value is not needed for the shape check

# The intermediate (embedded) representation is not returned by the model, so
# its expected shape is derived from the inputs here. The previous version
# printed `x.shape`, where `x` was a leftover tensor from an earlier cell and
# had nothing to do with this model.
hidden_shape = torch.Size((*idx.shape, d_model))

print(f"{idx.shape} -> {hidden_shape} -> {x_new.shape}")

# Fail loudly if the model code regresses, instead of relying on someone
# eyeballing the printed shapes.
assert x_new.shape == (batch_size, c, vocab_size), f"unexpected logits shape {tuple(x_new.shape)}"

torch.Size([27, 11]) -> torch.Size([27, 11, 64]) -> torch.Size([27, 11, 23])


### Side note on nn.Embedding

In [9]:
# Walk-through of how token and position embeddings are looked up and summed.
#
# NOTE: this cell reassigns vocab_size, d_model and batch_size, which are also
# set in the configuration cell at the top of the notebook. Re-run that cell
# before re-running any of the shape checks above.

# model parameters
vocab_size = 5
block_size = 7
d_model = 4

# input parameters
sequence_length = 3
batch_size = 2

# The position table only has block_size rows, so longer sequences cannot be
# embedded.
assert sequence_length <= block_size, "sequence_length must not exceed block_size"

# embedding tables
# torch.nn is referenced explicitly: a bare `nn` would only be defined if the
# star import from the model module happened to leak it into this namespace.
token_embedding_table = torch.nn.Embedding(vocab_size, d_model)
position_embedding_table = torch.nn.Embedding(block_size, d_model)


print("--- WEIGHTS OF THE TOKEN EMBEDDING TABLE ---\n")
print(token_embedding_table.weight) # 5 different tokens, each one is 4-dimensional
print(token_embedding_table.weight.shape)


print("\n\n--- WEIGHTS OF THE POSITION EMBEDDING TABLE ---\n")
print(position_embedding_table.weight) # 7 different positions, each one is 4-dimensional
print(position_embedding_table.weight.shape)

print("\n\n--- INPUT INDICES ---\n")
idx = torch.randint(0, vocab_size, (batch_size, sequence_length)) # 2 sequences of 3 random token indices, each between 0 and 4
print(idx)
print(idx.shape)

print("\n\n--- TOKEN EMBEDDINGS ---\n")
token_embeddings = token_embedding_table(idx) # 2 sequences of 3 tokens, each one is 4-dimensional
print(token_embeddings)
print(token_embeddings.shape)

print("\n\n--- POSITION EMBEDDINGS ---\n")
t = torch.arange(sequence_length) # 3 positions
print(t)
position_embeddings = position_embedding_table(t) # 3 positions, each one is 4-dimensional
print(position_embeddings)
print(position_embeddings.shape)

print("\n\n--- ADDING TOKEN EMBEDDINGS AND POSITION EMBEDDINGS ---\n")
# Broadcasting: the (3, 4) position embeddings are added to each of the 2
# sequences in the (2, 3, 4) token embeddings. Computed once, printed twice.
input_embeddings = token_embeddings + position_embeddings
print(input_embeddings)
print(input_embeddings.shape)

--- WEIGHTS OF THE TOKEN EMBEDDING TABLE ---

Parameter containing:
tensor([[-0.1978, -0.9415,  0.1220, -2.1716],
        [ 0.5547, -0.8293, -1.8911,  0.1926],
        [ 1.5990, -1.7100, -0.9501,  1.2014],
        [-0.3939, -0.9876,  0.8104,  0.3245],
        [-0.0311,  1.4183,  1.6611, -0.6586]], requires_grad=True)
torch.Size([5, 4])


--- WEIGHTS OF THE POSITION EMBEDDING TABLE ---

Parameter containing:
tensor([[-0.1639, -2.5262,  0.0626, -0.4400],
        [-0.5536, -0.4538, -0.8721,  1.1158],
        [ 0.3666, -0.1443,  1.7843, -0.2331],
        [ 0.3175,  1.2170,  0.2160, -0.4171],
        [-0.5369, -0.8636,  0.8445,  0.6123],
        [-1.0339, -0.1021,  0.1612, -0.5991],
        [ 0.6750,  1.4297, -1.3731,  0.8217]], requires_grad=True)
torch.Size([7, 4])


--- INPUT INDICES ---

tensor([[3, 0, 0],
        [1, 0, 1]])
torch.Size([2, 3])


--- TOKEN EMBEDDINGS ---

tensor([[[-0.3939, -0.9876,  0.8104,  0.3245],
         [-0.1978, -0.9415,  0.1220, -2.1716],
         [-0.1978, -0.