# This notebook is designed for teaching purposes to help you visualize the tensor shapes that go through each module. Read along with 'model.py'

In [1]:
# my virtual environments are rarely properly connected to jupyter so this fixes that
import sys
import os
current_dir = os.getcwd()  # Get the current working directory
venv_dir = os.path.join(current_dir, '../venv') 
python_version = str(sys.version_info.major) + '.' + str(sys.version_info.minor)
site_packages_path = os.path.join(venv_dir, 'lib', 'python' + python_version, 'site-packages')
sys.path.append(site_packages_path) 

In [2]:
# tokenizer
sys.path.append("..")  # Adds the parent directory to the path so we can see the tokenizer
from tokenizer_TinyStories import *
size = 512 # size options are 128, 256, 512 and 1024
path = f'../tokenizers/tiny_stories_tokenizer_{size}.model'
tokenizer = get_tokenizer(path) 

In [3]:
# config file
from config import *
cfg = Config()
print(cfg, cfg.context_chunk)

# model modules
from model import *

Config(dim=128, num_layers=8, vocab_size=None, device='cpu', mlp_hidden_mult=4, mlp_bias=False, gated=True, nonlinearity='GeLU', num_q_heads=4, num_kv_heads=1, theta=10000, max_seq_len=512, norm_type='RMSNorm', norm_affine=True, norm_bias=False, eps=1e-05, max_batch_size=32, memory_saver_div=8) <bound method Config.context_chunk of Config(dim=128, num_layers=8, vocab_size=None, device='cpu', mlp_hidden_mult=4, mlp_bias=False, gated=True, nonlinearity='GeLU', num_q_heads=4, num_kv_heads=1, theta=10000, max_seq_len=512, norm_type='RMSNorm', norm_affine=True, norm_bias=False, eps=1e-05, max_batch_size=32, memory_saver_div=8)>


# Norms

In [5]:
### RMSNorm

# Create an instance of RMSNorm
hold = cfg.norm_type
cfg.norm_type = 'rmsnorm' # purposely mis-typing it
module = Norm(cfg.dim, cfg)

# Initially, logging is disabled by default
module.enable_logging()

### Optionally disabling printing for sub-functions
#module.disable_function_logging('CosineNorm')
#module.disable_function_logging('LayerNorm')
#module.disable_function_logging('RMSNorm')

x = torch.randn(32,cfg.max_seq_len,cfg.dim)

# Call the forward method - logging will occur
output = module(x)

# Disable logging. 
# This isn't actually necessary since we won't be using this object again but that's how you'd do it
module.disable_logging()

# clearing up ram jic we're training later
cfg.norm_type = hold
del hold, module, x, output

norm type rmsnorm not found. defaulting to RMSNorm

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 128])


In [6]:
# LayerNorm
hold = cfg.norm_type
cfg.norm_type = 'LayerNorm'
module = Norm(cfg.dim, cfg)
module.enable_logging()

### Optionally disabling printing for sub-functions
#module.disable_function_logging('CosineNorm')
#module.disable_function_logging('LayerNorm')
#module.disable_function_logging('RMSNorm')

x = torch.randn(32,cfg.max_seq_len,cfg.dim)
output = module(x)
module.disable_logging()
cfg.norm_type = hold
del hold, module, x, output


Inputs:
Tensor 'x' shape: torch.Size([32, 512, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 128])


In [7]:
# CosineNorm
hold = cfg.norm_type
cfg.norm_type = 'CosineNorm'
module = Norm(cfg.dim, cfg)
module.enable_logging()

### Optionally disabling printing for sub-functions
#module.disable_function_logging('CosineNorm')
#module.disable_function_logging('LayerNorm')
#module.disable_function_logging('RMSNorm')

x = torch.randn(32,cfg.max_seq_len,cfg.dim)
output = module(x)
module.disable_logging()
cfg.norm_type = hold
del hold, module, x, output


Inputs:
Tensor 'x' shape: torch.Size([32, 512, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 128])


# Attention

In [8]:
# Create an instance of multi-head self-attention
module = MQSA(cfg)

# Initially, logging is disabled by default
module.enable_logging()

### Optionally disabling printing for sub-functions
#module.disable_function_logging('apply_rotary_emb')
#module.disable_function_logging('reshape_for_broadcast')
#module.disable_function_logging('match_headcount')
#module.disable_function_logging('attend')
#module.disable_function_logging('calc_output')

# precompute RoPE frequencies, causal mask, and dummy input data
freqs_cis = precompute_freqs_cis(
    cfg.dim // cfg.num_q_heads,
    cfg.max_seq_len,
    cfg.theta
)
mask = torch.full(
    (cfg.max_seq_len, cfg.max_seq_len),
    float("-inf"),
    device=cfg.device
)
mask = torch.triu(mask, diagonal=1)
x = torch.randn(32,cfg.max_seq_len,cfg.dim)

# Call the forward method - logging will occur
output = module(x, freqs_cis, mask, training=True)

# Disable logging. 
# This isn't actually necessary since we won't be using this object again but that's how you'd do it
module.disable_logging()

# clearing up ram jic we're training later
del module, freqs_cis, mask, x, output


Inputs:
Tensor 'x' shape: torch.Size([32, 512, 128])
Tensor 'freqs_cis' shape: torch.Size([512, 16])
Tensor 'mask' shape: torch.Size([512, 512])
Integer 'cache_len': Value=True

Inputs:
Tensor 'xq' shape: torch.Size([32, 512, 4, 32])
Tensor 'xk' shape: torch.Size([32, 512, 1, 32])
Tensor 'freqs_cis' shape: torch.Size([512, 16])

Inputs:
Tensor 'freqs_cis' shape: torch.Size([512, 16])
Tensor 'x' shape: torch.Size([32, 512, 4, 16])

Outputs:
Tensor 'output' shape: torch.Size([1, 512, 1, 16])

Outputs:
Tensor 'output[0]' shape: torch.Size([32, 512, 4, 32])
Tensor 'output[1]' shape: torch.Size([32, 512, 1, 32])

Inputs:
Tensor 'keys' shape: torch.Size([32, 512, 1, 32])
Tensor 'values' shape: torch.Size([32, 512, 1, 32])

Outputs:
Tensor 'output[0]' shape: torch.Size([32, 512, 4, 32])
Tensor 'output[1]' shape: torch.Size([32, 512, 4, 32])

Inputs:
Tensor 'queries' shape: torch.Size([32, 4, 512, 32])
Tensor 'keys' shape: torch.Size([32, 4, 512, 32])
Integer 'training': Value=True

Outputs:


In [9]:
# now let's do it for inference

module = MQSA(cfg)
module.enable_logging()
#module.disable_function_logging('apply_rotary_emb')
#module.disable_function_logging('reshape_for_broadcast')
#module.disable_function_logging('match_headcount')
#module.disable_function_logging('attend')
#module.disable_function_logging('calc_output')

# precompute RoPE frequencies, causal mask, and dummy input data
freqs_cis = precompute_freqs_cis(
    cfg.dim // cfg.num_q_heads,
    cfg.max_seq_len,
    cfg.theta
)
mask = torch.full(
    (cfg.max_seq_len, cfg.max_seq_len),
    float("-inf"),
    device=cfg.device
)
mask = torch.triu(mask, diagonal=1)

# setting up for kv caching
cache_len = 420
seqlen = cache_len + cfg.context_chunk()
# need to extend the mask with zeros for the cached values
mask = mask[:cfg.context_chunk(), :cfg.context_chunk()]
mask = torch.hstack(
            [torch.zeros((cfg.context_chunk(), cache_len)), mask]
        )

# these don't use seqlen because those entries should already be in the kv cache
freqs_cis = freqs_cis[:cfg.context_chunk()]
x = torch.randn(32,cfg.context_chunk(),cfg.dim)

# Call the forward method - logging will occur
output = module(x, freqs_cis, mask, cache_len)

# Disable logging. 
# This isn't actually necessary since we won't be using this object again but that's how you'd do it
module.disable_logging()

# clearing up ram jic we're training later
del module, freqs_cis, mask, cache_len, seqlen, x, output


Inputs:
Tensor 'x' shape: torch.Size([32, 64, 128])
Tensor 'freqs_cis' shape: torch.Size([64, 16])
Tensor 'mask' shape: torch.Size([64, 484])
Integer 'cache_len': Value=420

Inputs:
Tensor 'xq' shape: torch.Size([32, 64, 4, 32])
Tensor 'xk' shape: torch.Size([32, 64, 1, 32])
Tensor 'freqs_cis' shape: torch.Size([64, 16])

Inputs:
Tensor 'freqs_cis' shape: torch.Size([64, 16])
Tensor 'x' shape: torch.Size([32, 64, 4, 16])

Outputs:
Tensor 'output' shape: torch.Size([1, 64, 1, 16])

Outputs:
Tensor 'output[0]' shape: torch.Size([32, 64, 4, 32])
Tensor 'output[1]' shape: torch.Size([32, 64, 1, 32])

Inputs:
Tensor 'keys' shape: torch.Size([32, 484, 1, 32])
Tensor 'values' shape: torch.Size([32, 484, 1, 32])

Outputs:
Tensor 'output[0]' shape: torch.Size([32, 484, 4, 32])
Tensor 'output[1]' shape: torch.Size([32, 484, 4, 32])

Inputs:
Tensor 'queries' shape: torch.Size([32, 4, 64, 32])
Tensor 'keys' shape: torch.Size([32, 4, 484, 32])
Integer 'training': Value=False

Outputs:
Tensor 'outp

# MLP