In [1]:
# Reload edited local modules (e.g. model.py) automatically before each cell runs,
# so changes to the model code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import *

In [3]:
# Seed torch's global RNG so the random inputs and initial weights used below
# (and the embedding tables printed in the side note) are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Input dimensions (strange numbers taken on purpose — presumably so that a
# mixed-up dimension is easy to spot in the printed shapes)
batch_size = 27
c = 11  # sequence length: number of positions per example

# Model parameters (strange numbers taken on purpose)
d_model = 64
d_head = 17
n_heads = 5
d_inner = 123
vocab_size = 23
n_blocks = 6

In [4]:
# Attention Head
# A single head maps (batch_size, c, d_model) -> (batch_size, c, d_head).

attn_head = AttentionHead(d_model=d_model, d_head=d_head)

x = torch.randn(batch_size, c, d_model)
o = attn_head(x)

print(f"{x.shape} -> {o.shape}")
# Turn the printed shape into a checked invariant so a regression in model.py fails loudly
assert o.shape == (batch_size, c, d_head), f"unexpected AttentionHead output shape {tuple(o.shape)}"

torch.Size([27, 11, 64]) -> torch.Size([27, 11, 17])


In [5]:
# Multi-Head Attention
# n_heads heads of size d_head; the output keeps the input's (batch_size, c, d_model)
# shape even though n_heads * d_head (85) differs from d_model (64).

mha = MultiHeadAttention(d_model=d_model, n_heads=n_heads, d_head=d_head)

x = torch.randn(batch_size, c, d_model)
x_new = mha(x)

print(f"{x.shape} -> {x_new.shape}")
# Shape-preservation is what this cell is checking, so assert it rather than eyeball it
assert x_new.shape == x.shape, f"unexpected MultiHeadAttention output shape {tuple(x_new.shape)}"

torch.Size([27, 11, 64]) -> torch.Size([27, 11, 64])


In [6]:
# Feed Forward
# Hidden width d_inner; the output keeps the input's (batch_size, c, d_model) shape.

ff = FeedForward(d_model=d_model, d_inner=d_inner)

x = torch.randn(batch_size, c, d_model)
x_new = ff(x)

print(f"{x.shape} -> {x_new.shape}")
# Shape-preservation is what this cell is checking, so assert it rather than eyeball it
assert x_new.shape == x.shape, f"unexpected FeedForward output shape {tuple(x_new.shape)}"

torch.Size([27, 11, 64]) -> torch.Size([27, 11, 64])


In [7]:
# Basic Building Block
# A Block is shape-preserving: (batch_size, c, d_model) -> (batch_size, c, d_model).

b = Block(d_model=d_model, n_heads=n_heads, d_head=d_head, d_inner=d_inner)

x = torch.randn(batch_size, c, d_model)
x_new = b(x)

print(f"{x.shape} -> {x_new.shape}")
# Shape-preservation is what this cell is checking, so assert it rather than eyeball it
assert x_new.shape == x.shape, f"unexpected Block output shape {tuple(x_new.shape)}"

torch.Size([27, 11, 64]) -> torch.Size([27, 11, 64])


In [8]:
# Language Model
# Token indices (batch_size, c) -> per-position scores over the vocabulary
# (batch_size, c, vocab_size), presumably logits.

# NOTE(review): unlike the Block cell, d_head and d_inner are not passed here;
# confirm that LanguageModel's defaults for them are what you intend.
lm = LanguageModel(vocab_size=vocab_size, d_model=d_model, n_heads=n_heads, n_blocks=n_blocks)

idx = torch.randint(0, vocab_size, (batch_size, c))
x_new, _ = lm(idx)  # second return value unused here (presumably a loss — check model.py)

# lm() does not return its internal embeddings, so the middle entry is the *expected*
# (batch_size, c, d_model) embedding shape, built from the config values instead of
# reading `x`, which in this cell would be an unrelated tensor left over from an earlier cell.
emb_shape = torch.Size([batch_size, c, d_model])
print(f"{idx.shape} -> {emb_shape} -> {x_new.shape}")
assert x_new.shape == (batch_size, c, vocab_size), f"unexpected LanguageModel output shape {tuple(x_new.shape)}"

torch.Size([27, 11]) -> torch.Size([27, 11, 64]) -> torch.Size([27, 11, 23])


### Side note on nn.Embedding

In [9]:
# Model parameters (tiny on purpose, so the tables below are readable).
# NOTE: this cell reassigns vocab_size, d_model and batch_size, overwriting the values
# from the configuration cell. Re-run the configuration cell before re-running any of
# the model cells above, or they will silently use these side-note values.
vocab_size = 5
block_size = 7  # rows in the position table = number of distinct positions it can embed
d_model = 4

# Input parameters
sequence_length = 3  # must be <= block_size, otherwise the position lookup below is out of range
batch_size = 2

# Embedding tables: each holds a learnable (num_embeddings, d_model) weight matrix
token_embedding_table = nn.Embedding(vocab_size, d_model)
position_embedding_table = nn.Embedding(block_size, d_model)


print("--- WEIGHTS OF THE TOKEN EMBEDDING TABLE ---\n")
print(token_embedding_table.weight)  # 5 rows, one 4-dimensional vector per token id
print(token_embedding_table.weight.shape)


print("\n\n--- WEIGHTS OF THE POSITION EMBEDDING TABLE ---\n")
print(position_embedding_table.weight)  # 7 rows, one 4-dimensional vector per position
print(position_embedding_table.weight.shape)

print("\n\n--- INPUT INDICES ---\n")
# 2 sequences of 3 token ids each, every id drawn uniformly from [0, vocab_size) = [0, 4]
idx = torch.randint(0, vocab_size, (batch_size, sequence_length))
print(idx)
print(idx.shape)

print("\n\n--- TOKEN EMBEDDINGS ---\n")
token_embeddings = token_embedding_table(idx)  # row lookup per id: (2, 3) -> (2, 3, 4)
print(token_embeddings)
print(token_embeddings.shape)

print("\n\n--- POSITION EMBEDDINGS ---\n")
t = torch.arange(sequence_length)  # positions 0, 1, 2
print(t)
position_embeddings = position_embedding_table(t)  # (3,) -> (3, 4), shared by every sequence
print(position_embeddings)
print(position_embeddings.shape)

print("\n\n--- ADDING TOKEN EMBEDDINGS AND POSITION EMBEDDINGS ---\n")
# (2, 3, 4) + (3, 4): the position embeddings broadcast across the batch dimension
embeddings = token_embeddings + position_embeddings
print(embeddings)
print(embeddings.shape)

--- WEIGHTS OF THE TOKEN EMBEDDING TABLE ---

Parameter containing:
tensor([[ 1.0325,  0.7264,  1.0161,  2.0160],
        [ 0.3157,  0.0622, -1.7561, -0.1970],
        [-0.9901, -0.7923,  1.2296, -0.0728],
        [ 0.5098, -0.2969,  0.9309, -1.7869],
        [ 1.6245,  0.9987,  0.5837,  0.7846]], requires_grad=True)
torch.Size([5, 4])


--- WEIGHTS OF THE POSITION EMBEDDING TABLE ---

Parameter containing:
tensor([[-0.8765, -0.1246,  0.4092, -1.0750],
        [ 1.7209,  0.4917,  1.2667, -0.1124],
        [ 1.6285, -1.5780,  1.2353,  0.1783],
        [ 0.5356,  0.7281,  0.6711, -0.0864],
        [-1.4852, -0.0774,  1.0756,  0.0436],
        [-0.2980,  1.7133, -0.7907, -1.8390],
        [ 0.3887,  0.2321, -0.9094,  0.0071]], requires_grad=True)
torch.Size([7, 4])


--- INPUT INDICES ---

tensor([[2, 2, 3],
        [3, 1, 2]])
torch.Size([2, 3])


--- TOKEN EMBEDDINGS ---

tensor([[[-0.9901, -0.7923,  1.2296, -0.0728],
         [-0.9901, -0.7923,  1.2296, -0.0728],
         [ 0.5098, -0.