# This notebook is designed for teaching/testing purposes to help you visualize the tensor shapes that go through each module

In [1]:
# my virtual environments are rarely properly connected to jupyter so this fixes that.
# you probably won't need this cell, and running it (even repeatedly) won't hurt anything either:
# a site-packages dir is only added if it exists and isn't already on sys.path.
import sys
import os

current_dir = os.getcwd()  # the venv is expected at ./venv relative to the working directory
venv_dir = os.path.join(current_dir, 'venv')
python_version = str(sys.version_info.major) + '.' + str(sys.version_info.minor)
# POSIX venv layout: venv/lib/pythonX.Y/site-packages
site_packages_path = os.path.join(venv_dir, 'lib', 'python' + python_version, 'site-packages')
# Windows venv layout: venv/Lib/site-packages
windows_site_packages_path = os.path.join(venv_dir, 'Lib', 'site-packages')

for candidate in (site_packages_path, windows_site_packages_path):
    # skip missing dirs, and avoid piling up duplicate entries when the cell is re-run
    if os.path.isdir(candidate) and candidate not in sys.path:
        sys.path.append(candidate)

In [2]:
# Load the model config and the tokenizer it names; seed the RNGs used by the dummy data below.
import random

import torch

from config import ModelConfig
from tools import import_from_nested_path

cfg = ModelConfig()
print(cfg)

# import the tokenizer specified by cfg.tokenizer
imported_objects = import_from_nested_path(['tokenizers', cfg.tokenizer], 'tokenizer', ['get_tokenizer'])
get_tokenizer = imported_objects.get('get_tokenizer')
if get_tokenizer is None:
    # .get() returns None on a miss, which would otherwise surface later as "'NoneType' object is not callable"
    raise ImportError(f"tokenizer '{cfg.tokenizer}' does not provide get_tokenizer")
# Build the tokenizer at cfg.vocab_len so its vocabulary matches the config the model is built from
# (the model's embedding table appears to be sized from cfg.vocab_len).
# For bpe-family tokenizers the available sizes were listed as 95, 128, 256, 512, 1024 and 2048.
tokenizer = get_tokenizer(size = cfg.vocab_len)

# Seed every RNG used below (random.randint for cache lengths, torch.randn / torch.randint for dummy data)
# so Restart & Run All reproduces the same shapes and values.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

ModelConfig(dim=32, device='cpu', dropout_rate=0.1, weight_tying=True, tokenizer='bpe_v2', vocab_len=1024, num_layers=6, second_resid_norm=False, mlp_hidden_mult=4, mlp_bias=False, mlp_nonlinearity='SiLU', mlp_gated=True, num_q_heads=2, num_kv_heads=1, head_dim=16, theta=10000, max_seq_len=128, ca_num_q_heads=2, ca_num_kv_heads=1, ca_head_dim=16, scale_first_resid=True, norm_type='RMSNorm', norm_affine=True, norm_bias=True, eps=1e-06, pool_type='sum', pre_pool_norm=True, pool_output_linear=False, pool_bias=False, compress_freq='constant', compress_freq_n=1, fs_mult=4, fs_periods=3, fs_loss_lambda=1.0, max_batch_size=1)


# Norms

In [3]:
from modules.norm import Norm

In [4]:
%%time

### RMSNorm

# Create an instance of RMSNorm
module = Norm(cfg.dim, 'RMSNorm').to(cfg.device)

# let's take a look
print(sum(p.numel() for p in module.parameters())/1e3, 'K parameters')
print(module)

# Initially, logging is disabled by default
module.enable_logging()

### Optionally disabling printing for sub-functions
#module.disable_function_logging('CosineNorm')
#module.disable_function_logging('LayerNorm')
#module.disable_function_logging('RMSNorm')

x = torch.randn(32,cfg.max_seq_len,cfg.dim).to(cfg.device)

# Call the forward method - logging will occur
output = module(x)

# Disable logging. 
# This isn't actually necessary since we won't be using this object again but that's how you'd do it
module.disable_logging()

# clearing up ram jic we're training later
del module, x, output

0.064 K parameters
Norm()

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 32])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 32])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 32])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 32])
CPU times: user 3.42 ms, sys: 1.5 ms, total: 4.91 ms
Wall time: 3.57 ms


In [5]:
%%time

# LayerNorm
module = Norm(cfg.dim, 'LayerNorm').to(cfg.device)
module.enable_logging()

### Optionally disabling printing for sub-functions
#module.disable_function_logging('CosineNorm')
#module.disable_function_logging('LayerNorm')
#module.disable_function_logging('RMSNorm')

x = torch.randn(32,cfg.max_seq_len,cfg.dim).to(cfg.device)
output = module(x)
module.disable_logging()
del module, x, output


Inputs:
Tensor 'x' shape: torch.Size([32, 128, 32])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 32])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 32])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 32])
CPU times: user 7.52 ms, sys: 1.48 ms, total: 9 ms
Wall time: 4.56 ms


# Attention

In [6]:
from modules.mqa import MQA, futureSightMQA, precompute_freqs_cis

In [7]:
%%time

# first up let's look at self-attention training

# Create an instance of multi-head self-attention
module = MQA(
        cfg.dim,
        cfg.head_dim,
        cfg.num_q_heads,
        cfg.num_kv_heads,
        cfg.max_seq_len,
        #cfg.max_batch_size, # if you don't pass in a max_batch_size then the module will be incapable of kv caching
).to(cfg.device)
print(sum(p.numel() for p in module.parameters())/1e3, 'K parameters')
print(module)

# Initially, logging is disabled by default
module.enable_logging()

### Optionally disabling printing for sub-functions
#module.disable_function_logging('apply_rotary_emb')
#module.disable_function_logging('reshape_for_broadcast')
#module.disable_function_logging('match_headcount')
#module.disable_function_logging('attend')
#module.disable_function_logging('calc_output')

# precompute RoPE frequencies, causal mask, and dummy input data
freqs_cis = precompute_freqs_cis(
    cfg.head_dim,
    cfg.max_seq_len,
    cfg.theta
).to(cfg.device)
mask = torch.full(
    (cfg.max_seq_len, cfg.max_seq_len),
    float("-inf"),
    device=cfg.device
)
mask = torch.triu(mask, diagonal=1)
x = torch.randn(32,cfg.max_seq_len,cfg.dim).to(cfg.device)

# Call the forward method - logging will occur
output = module(x, x, x, freqs_cis, mask, training=True)

# Disable logging. 
# This isn't actually necessary since we won't be using this object again but that's how you'd do it
module.disable_logging()

# clearing up ram jic we're training later
del module, freqs_cis, mask, x, output

3.072 K parameters
MQA(
  (Wq): Linear(in_features=32, out_features=32, bias=False)
  (Wk): Linear(in_features=32, out_features=16, bias=False)
  (Wv): Linear(in_features=32, out_features=16, bias=False)
  (Wo): Linear(in_features=32, out_features=32, bias=False)
)

Inputs:
Tensor 'q' shape: torch.Size([32, 128, 32])
Tensor 'k' shape: torch.Size([32, 128, 32])
Tensor 'v' shape: torch.Size([32, 128, 32])
Tensor 'freqs_cis' shape: torch.Size([128, 8])
Tensor 'mask' shape: torch.Size([128, 128])
Other-type 'cache_len': Type=NoneType, Value=None
Integer 'training': Value=True

Inputs:
Tensor 'q' shape: torch.Size([32, 128, 2, 16])
Tensor 'k' shape: torch.Size([32, 128, 1, 16])
Tensor 'freqs_cis' shape: torch.Size([128, 8])

Inputs:
Tensor 'freqs_cis' shape: torch.Size([128, 8])
Tensor 'x' shape: torch.Size([32, 128, 2, 8])

Outputs:
Tensor 'output' shape: torch.Size([1, 128, 1, 8])

Outputs:
Tensor 'output[0]' shape: torch.Size([32, 128, 2, 16])
Tensor 'output[1]' shape: torch.Size([32, 12

In [8]:
%%time

# now let's do it for self-attention inference

module = MQA(
        cfg.dim,
        cfg.head_dim,
        cfg.num_q_heads,
        cfg.num_kv_heads,
        cfg.max_seq_len,
        cfg.max_batch_size
).to(cfg.device)
module.enable_logging()
#module.disable_function_logging('apply_rotary_emb')
#module.disable_function_logging('reshape_for_broadcast')
#module.disable_function_logging('match_headcount')
#module.disable_function_logging('attend')
#module.disable_function_logging('calc_output')

# precompute RoPE frequencies, causal mask, and dummy input data
freqs_cis = precompute_freqs_cis(
    cfg.head_dim,
    cfg.max_seq_len,
    cfg.theta
).to(cfg.device)
mask = torch.full(
    (cfg.max_seq_len, cfg.max_seq_len),
    float("-inf"),
    device=cfg.device
)
mask = torch.triu(mask, diagonal=1)

# setting up for kv caching
context_chunk_len = cfg.max_seq_len // 4
cache_len = random.randint(1, 3 * context_chunk_len)
seq_len = cache_len + context_chunk_len
# need to extend the mask with zeros for the cached values
mask = mask[:context_chunk_len, :context_chunk_len]
mask = torch.hstack(
            [torch.zeros((context_chunk_len, cache_len)), mask]
        )

# these don't use seq_len because those entries should already be in the kv cache
freqs_cis = freqs_cis[:context_chunk_len]
x = torch.randn(cfg.max_batch_size,context_chunk_len,cfg.dim).to(cfg.device)

# Call the forward method - logging will occur
output = module(x, x, x, freqs_cis, mask, cache_len)

# Disable logging. 
# This isn't actually necessary since we won't be using this object again but that's how you'd do it
module.disable_logging()

# clearing up ram jic we're training later
del module, freqs_cis, mask, cache_len, context_chunk_len, seq_len, x, output


Inputs:
Tensor 'q' shape: torch.Size([1, 32, 32])
Tensor 'k' shape: torch.Size([1, 32, 32])
Tensor 'v' shape: torch.Size([1, 32, 32])
Tensor 'freqs_cis' shape: torch.Size([32, 8])
Tensor 'mask' shape: torch.Size([32, 33])
Integer 'cache_len': Value=1
Integer 'training': Value=False

Inputs:
Tensor 'q' shape: torch.Size([1, 32, 2, 16])
Tensor 'k' shape: torch.Size([1, 32, 1, 16])
Tensor 'freqs_cis' shape: torch.Size([32, 8])

Inputs:
Tensor 'freqs_cis' shape: torch.Size([32, 8])
Tensor 'x' shape: torch.Size([1, 32, 2, 8])

Outputs:
Tensor 'output' shape: torch.Size([1, 32, 1, 8])

Outputs:
Tensor 'output[0]' shape: torch.Size([1, 32, 2, 16])
Tensor 'output[1]' shape: torch.Size([1, 32, 1, 16])

Inputs:
Tensor 'k' shape: torch.Size([1, 33, 1, 16])
Tensor 'v' shape: torch.Size([1, 33, 1, 16])

Outputs:
Tensor 'output[0]' shape: torch.Size([1, 33, 2, 16])
Tensor 'output[1]' shape: torch.Size([1, 33, 2, 16])

Inputs:
Tensor 'q' shape: torch.Size([1, 2, 32, 16])
Tensor 'k' shape: torch.Size

In [9]:
%%time

# now cross-attention, which should be the same whether doing training or inference

# Create an instance of future sight cross-attention
module = futureSightMQA(
        cfg.dim,
        cfg.head_dim,
        cfg.num_q_heads,
        cfg.num_kv_heads,
        cfg.max_seq_len
).to(cfg.device)
print(sum(p.numel() for p in module.parameters())/1e3, 'K parameters')
print(module)

# Initially, logging is disabled by default
module.enable_logging()

### Optionally disabling printing for sub-functions
#module.disable_function_logging('match_headcount')
#module.disable_function_logging('attend')
#module.disable_function_logging('calc_output')

x = torch.randn(32,cfg.max_seq_len,cfg.dim).to(cfg.device)
z = torch.randn(32,cfg.max_seq_len,2,cfg.dim).to(cfg.device)

# Call the forward method - logging will occur
output = module(x, z, training=True)

# clearing up ram jic we're training later
del module, x, z, output

3.072 K parameters
futureSightMQA(
  (Wq): Linear(in_features=32, out_features=32, bias=False)
  (Wk): Linear(in_features=32, out_features=16, bias=False)
  (Wv): Linear(in_features=32, out_features=16, bias=False)
  (Wo): Linear(in_features=32, out_features=32, bias=False)
)

Inputs:
Tensor 'q' shape: torch.Size([32, 128, 32])
Tensor 'kv' shape: torch.Size([32, 128, 2, 32])
Integer 'training': Value=True

Inputs:
Tensor 'k' shape: torch.Size([4096, 2, 1, 16])
Tensor 'v' shape: torch.Size([4096, 2, 1, 16])

Outputs:
Tensor 'output[0]' shape: torch.Size([4096, 2, 2, 16])
Tensor 'output[1]' shape: torch.Size([4096, 2, 2, 16])

Inputs:
Tensor 'q' shape: torch.Size([4096, 2, 1, 16])
Tensor 'k' shape: torch.Size([4096, 2, 2, 16])

Outputs:
Tensor 'output' shape: torch.Size([4096, 2, 1, 2])

Inputs:
Tensor 'logits' shape: torch.Size([4096, 2, 1, 2])
Tensor 'v' shape: torch.Size([4096, 2, 2, 16])
Integer 'training': Value=True

Outputs:
Tensor 'output' shape: torch.Size([4096, 1, 32])

Output

# MLP

In [10]:
from modules.mlp import MLP

In [11]:
%%time

# GeGLU
module = MLP(
    cfg.dim, 
    int(cfg.dim * cfg.mlp_hidden_mult * 2/3), 
    cfg.dim, 
    'GeLU', 
    gated=True, 
    bias=False, 
    dropout_rate = 0.1
).to(cfg.device)
print(sum(p.numel() for p in module.parameters())/1e3, 'K parameters')
print(module)
module.enable_logging()

x = torch.randn(32,cfg.max_seq_len,cfg.dim).to(cfg.device)
output = module(x, training=True)
module.disable_logging()
del module, x, output

8.16 K parameters
MLP(
  (Wup): Linear(in_features=32, out_features=85, bias=False)
  (Wgate): Linear(in_features=32, out_features=85, bias=False)
  (Wdown): Linear(in_features=85, out_features=32, bias=False)
  (nonlinearity): GELU(approximate='none')
)

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 32])
Integer 'training': Value=True

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 32])
CPU times: user 7.71 ms, sys: 3.11 ms, total: 10.8 ms
Wall time: 6.75 ms


In [12]:
%%time

# not gated, testing every other nonlinearity
module = MLP(
    cfg.dim, 
    cfg.dim * cfg.mlp_hidden_mult, 
    cfg.dim, 
    'ReLU', 
    gated=False, 
    bias=False, 
    dropout_rate = 0.1
).to(cfg.device)
print(sum(p.numel() for p in module.parameters())/1e3, 'K parameters')
print(module)
module.enable_logging()

x = torch.randn(32,cfg.max_seq_len,cfg.dim).to(cfg.device)
output = module(x, training=True)
module.disable_logging()
del module, x, output

8.192 K parameters
MLP(
  (Wup): Linear(in_features=32, out_features=128, bias=False)
  (Wdown): Linear(in_features=128, out_features=32, bias=False)
  (nonlinearity): ReLU()
)

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 32])
Integer 'training': Value=True

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 32])
CPU times: user 6.4 ms, sys: 2.12 ms, total: 8.52 ms
Wall time: 6.51 ms


# ResidualLayer

In [13]:
from modules.layer import Layer

In [14]:
%%time

# TRAINING w/ only self-attention
module = Layer(cfg).to(cfg.device)
print(sum(p.numel() for p in module.parameters())/1e3, 'K parameters')
print(module)

module.enable_logging()
#module.disable_function_logging('self_attn_connect')
#module.disable_function_logging('mlp_connect')
### enabling printing for sub-modules
#module.pre_attn_norm.enable_logging()
#module.attn.enable_logging()
#module.post_attn_norm.enable_logging()
#module.pre_mlp_norm.enable_logging()
#module.mlp.enable_logging()
#module.post_mlp_norm.enable_logging()

# precompute RoPE frequencies, causal mask, and dummy input data
freqs_cis = precompute_freqs_cis(
    cfg.head_dim,
    cfg.max_seq_len,
    cfg.theta
).to(cfg.device)
mask = torch.full(
    (cfg.max_seq_len, cfg.max_seq_len),
    float("-inf"),
    device=cfg.device
)
mask = torch.triu(mask, diagonal=1)
x = torch.randn(32,cfg.max_seq_len,cfg.dim).to(cfg.device)

output = module(x, freqs_cis, mask, training=True)
module.disable_logging()
del module, freqs_cis, mask, x, output

11.36 K parameters
Layer(
  (pre_self_attn_norm): Norm()
  (self_attn): MQA(
    (Wq): Linear(in_features=32, out_features=32, bias=False)
    (Wk): Linear(in_features=32, out_features=16, bias=False)
    (Wv): Linear(in_features=32, out_features=16, bias=False)
    (Wo): Linear(in_features=32, out_features=32, bias=False)
  )
  (pre_mlp_norm): Norm()
  (mlp): MLP(
    (Wup): Linear(in_features=32, out_features=85, bias=False)
    (Wgate): Linear(in_features=32, out_features=85, bias=False)
    (Wdown): Linear(in_features=85, out_features=32, bias=False)
    (nonlinearity): SiLU()
  )
)

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 32])
Tensor 'freqs_cis' shape: torch.Size([128, 8])
Tensor 'mask' shape: torch.Size([128, 128])
Other-type 'cache_len': Type=NoneType, Value=None
Other-type 'y': Type=NoneType, Value=None
Integer 'training': Value=True

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 32])
Tensor 'freqs_cis' shape: torch.Size([128, 8])
Tensor 'mask' shape: torch.Size([128, 1

In [15]:
%%time

# INFERENCE
module = Layer(cfg).to(cfg.device)
print(sum(p.numel() for p in module.parameters())/1e3, 'K parameters')
print(module)

module.enable_logging()
#module.disable_function_logging('self_attn_connect')
#module.disable_function_logging('mlp_connect')
#module.pre_attn_norm.enable_logging()
#module.attn.enable_logging()
#module.post_attn_norm.enable_logging()
#module.pre_mlp_norm.enable_logging()
#module.mlp.enable_logging()
#module.post_mlp_norm.enable_logging()

# precompute RoPE frequencies, causal mask, and dummy input data
freqs_cis = precompute_freqs_cis(
    cfg.head_dim,
    cfg.max_seq_len,
    cfg.theta
).to(cfg.device)
mask = torch.full(
    (cfg.max_seq_len, cfg.max_seq_len),
    float("-inf"),
    device=cfg.device
)
mask = torch.triu(mask, diagonal=1)
# setting up for kv caching
cache_len = cfg.max_seq_len // 3
context_chunk_len = cfg.max_seq_len // 4
seq_len = cache_len + context_chunk_len
# need to extend the mask with zeros for the cached values
mask = mask[:context_chunk_len, :context_chunk_len]
mask = torch.hstack(
            [torch.zeros((context_chunk_len, cache_len)), mask]
        )
# these don't use seq_len because those entries should already be in the kv cache
freqs_cis = freqs_cis[:context_chunk_len]
x = torch.randn(1,context_chunk_len,cfg.dim).to(cfg.device)

output = module(x, freqs_cis, mask, cache_len)
module.disable_logging()
del module, freqs_cis, mask, cache_len, context_chunk_len, seq_len, x, output

11.36 K parameters
Layer(
  (pre_self_attn_norm): Norm()
  (self_attn): MQA(
    (Wq): Linear(in_features=32, out_features=32, bias=False)
    (Wk): Linear(in_features=32, out_features=16, bias=False)
    (Wv): Linear(in_features=32, out_features=16, bias=False)
    (Wo): Linear(in_features=32, out_features=32, bias=False)
  )
  (pre_mlp_norm): Norm()
  (mlp): MLP(
    (Wup): Linear(in_features=32, out_features=85, bias=False)
    (Wgate): Linear(in_features=32, out_features=85, bias=False)
    (Wdown): Linear(in_features=85, out_features=32, bias=False)
    (nonlinearity): SiLU()
  )
)

Inputs:
Tensor 'x' shape: torch.Size([1, 32, 32])
Tensor 'freqs_cis' shape: torch.Size([32, 8])
Tensor 'mask' shape: torch.Size([32, 74])
Integer 'cache_len': Value=42
Other-type 'y': Type=NoneType, Value=None
Integer 'training': Value=False

Inputs:
Tensor 'x' shape: torch.Size([1, 32, 32])
Tensor 'freqs_cis' shape: torch.Size([32, 8])
Tensor 'mask' shape: torch.Size([32, 74])
Integer 'cache_len': Val

In [16]:
%%time

# Layer w/ future sight cross-attention and w/ kv caching in the self-attention enabled while TRAINING
module = Layer(cfg, cross_attn=True).to(cfg.device)
print(sum(p.numel() for p in module.parameters())/1e3, 'K parameters')
print(module)

module.enable_logging()
#module.disable_function_logging('self_attn_connect')
#module.disable_function_logging('future_sight_connect')
#module.disable_function_logging('mlp_connect')
### enabling printing for sub-modules
#module.pre_self_attn_norm.enable_logging()
#module.self_attn.enable_logging()
#module.post_self_attn_norm.enable_logging()
#module.pre_future_sight_norm_x.enable_logging()
#module.pre_future_sight_norm_z.enable_logging()
#module.future_sight.enable_logging()
#module.post_future_sight_norm.enable_logging()
#module.pre_mlp_norm.enable_logging()
#module.mlp.enable_logging()
#module.post_mlp_norm.enable_logging()

# precompute RoPE frequencies, causal mask, and dummy input data
freqs_cis = precompute_freqs_cis(
    cfg.head_dim,
    cfg.max_seq_len,
    cfg.theta
).to(cfg.device)
mask = torch.full(
    (cfg.max_seq_len, cfg.max_seq_len),
    float("-inf"),
    device=cfg.device
)
mask = torch.triu(mask, diagonal=1)
x = torch.randn(32,cfg.max_seq_len,cfg.dim).to(cfg.device)
y = torch.randn(32,cfg.max_seq_len, 3, cfg.dim).to(cfg.device)

output = module(x, freqs_cis, mask, y=y, training=True)
module.disable_logging()
del module, freqs_cis, mask, x, y, output

14.56 K parameters
Layer(
  (pre_self_attn_norm): Norm()
  (self_attn): MQA(
    (Wq): Linear(in_features=32, out_features=32, bias=False)
    (Wk): Linear(in_features=32, out_features=16, bias=False)
    (Wv): Linear(in_features=32, out_features=16, bias=False)
    (Wo): Linear(in_features=32, out_features=32, bias=False)
  )
  (pre_future_sight_norm_x): Norm()
  (pre_future_sight_norm_y): Norm()
  (future_sight): futureSightMQA(
    (Wq): Linear(in_features=32, out_features=32, bias=False)
    (Wk): Linear(in_features=32, out_features=16, bias=False)
    (Wv): Linear(in_features=32, out_features=16, bias=False)
    (Wo): Linear(in_features=32, out_features=32, bias=False)
  )
  (pre_mlp_norm): Norm()
  (mlp): MLP(
    (Wup): Linear(in_features=32, out_features=85, bias=False)
    (Wgate): Linear(in_features=32, out_features=85, bias=False)
    (Wdown): Linear(in_features=85, out_features=32, bias=False)
    (nonlinearity): SiLU()
  )
)

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 

In [17]:
%%time

# Layer w/ future sight cross-attention and w/o kv caching in the self-attention enabled during INFERENCE
module = Layer(cfg, cross_attn = True, kv_cache = False).to(cfg.device)
print(sum(p.numel() for p in module.parameters())/1e3, 'K parameters')
print(module)

module.enable_logging()
#module.disable_function_logging('self_attn_connect')
#module.disable_function_logging('future_sight_connect')
#module.disable_function_logging('mlp_connect')
#module.pre_self_attn_norm.enable_logging()
#module.self_attn.enable_logging()
#module.post_self_attn_norm.enable_logging()
#module.pre_future_sight_norm_x.enable_logging()
#module.pre_future_sight_norm_z.enable_logging()
#module.future_sight.enable_logging()
#module.post_future_sight_norm.enable_logging()
#module.pre_mlp_norm.enable_logging()
#module.mlp.enable_logging()
#module.post_mlp_norm.enable_logging()

# precompute RoPE frequencies, causal mask, and dummy input data
freqs_cis = precompute_freqs_cis(cfg.head_dim,
                                 cfg.max_seq_len,
                                 cfg.theta).to(cfg.device)
mask = torch.full((cfg.max_seq_len, cfg.max_seq_len),
                  float("-inf"),
                  device=cfg.device)
mask = torch.triu(mask, diagonal=1)
x = torch.randn(32,cfg.max_seq_len,cfg.dim).to(cfg.device)
y = torch.randn(32,cfg.max_seq_len, 3, cfg.dim).to(cfg.device)

output = module(x, freqs_cis, mask, y=y)
module.disable_logging()
del module, freqs_cis, mask, x, y, output

14.56 K parameters
Layer(
  (pre_self_attn_norm): Norm()
  (self_attn): MQA(
    (Wq): Linear(in_features=32, out_features=32, bias=False)
    (Wk): Linear(in_features=32, out_features=16, bias=False)
    (Wv): Linear(in_features=32, out_features=16, bias=False)
    (Wo): Linear(in_features=32, out_features=32, bias=False)
  )
  (pre_future_sight_norm_x): Norm()
  (pre_future_sight_norm_y): Norm()
  (future_sight): futureSightMQA(
    (Wq): Linear(in_features=32, out_features=32, bias=False)
    (Wk): Linear(in_features=32, out_features=16, bias=False)
    (Wv): Linear(in_features=32, out_features=16, bias=False)
    (Wo): Linear(in_features=32, out_features=32, bias=False)
  )
  (pre_mlp_norm): Norm()
  (mlp): MLP(
    (Wup): Linear(in_features=32, out_features=85, bias=False)
    (Wgate): Linear(in_features=32, out_features=85, bias=False)
    (Wdown): Linear(in_features=85, out_features=32, bias=False)
    (nonlinearity): SiLU()
  )
)

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 

# Full Model

In [18]:
from modules.model import Model

In [19]:
%%time

# TRAINING
module = Model(cfg).to(cfg.device)
print(sum(p.numel() for p in module.parameters())/1e3, 'K parameters')
print(module)

module.enable_logging()
### enabling printing for sub-modules
module.body_layers[0].enable_logging()
module.body_layers[0].disable_function_logging('self_attn_connect') # disabling some functions of some sub-modules
module.body_layers[0].disable_function_logging('mlp_connect')
module.first_fs_layer.enable_logging()
module.first_fs_layer.disable_function_logging('mlp_connect')
module.fs_layers[0].enable_logging()
module.fs_layers[0].disable_function_logging('mlp_connect')
module.fs_layers[-1].enable_logging()
module.fs_layers[-1].disable_function_logging('mlp_connect')

input_token_ids = torch.randint(tokenizer.vocab_len, (32, cfg.max_seq_len)).to(cfg.device)
target_token_ids = torch.randint(tokenizer.vocab_len, (32, cfg.max_seq_len)).to(cfg.device)

output, loss = module(input_token_ids, target_token_ids=target_token_ids)
print(output.shape, loss)
del module, input_token_ids, target_token_ids, output

110.752 K parameters
Model(
  (token_embedder): Embedding(1027, 32)
  (body_layers): ModuleList(
    (0-2): 3 x Layer(
      (pre_self_attn_norm): Norm()
      (self_attn): MQA(
        (Wq): Linear(in_features=32, out_features=32, bias=False)
        (Wk): Linear(in_features=32, out_features=16, bias=False)
        (Wv): Linear(in_features=32, out_features=16, bias=False)
        (Wo): Linear(in_features=32, out_features=32, bias=False)
      )
      (pre_mlp_norm): Norm()
      (mlp): MLP(
        (Wup): Linear(in_features=32, out_features=85, bias=False)
        (Wgate): Linear(in_features=32, out_features=85, bias=False)
        (Wdown): Linear(in_features=85, out_features=32, bias=False)
        (nonlinearity): SiLU()
      )
    )
  )
  (first_fs_layer): Layer(
    (pre_self_attn_norm): Norm()
    (self_attn): MQA(
      (Wq): Linear(in_features=32, out_features=32, bias=False)
      (Wk): Linear(in_features=32, out_features=16, bias=False)
      (Wv): Linear(in_features=32, out_

In [20]:
%%time

# Inference w/out kv caching
module = Model(cfg).to(cfg.device)
print(sum(p.numel() for p in module.parameters())/1e3, 'K parameters')
print(module)

module.enable_logging()
### enabling printing for sub-modules
module.body_layers[0].enable_logging()
module.body_layers[0].disable_function_logging('self_attn_connect') # disabling some functions of some sub-modules
module.body_layers[0].disable_function_logging('mlp_connect')
module.first_fs_layer.enable_logging()
module.first_fs_layer.disable_function_logging('mlp_connect')
module.fs_layers[0].enable_logging()
module.fs_layers[0].disable_function_logging('mlp_connect')
module.fs_layers[-1].enable_logging()
module.fs_layers[-1].disable_function_logging('mlp_connect')

input_token_ids = torch.randint(tokenizer.vocab_len, (1, cfg.max_seq_len // 4)).to(cfg.device)

output, _ = module(input_token_ids)
print(output.shape)
del module, input_token_ids, output

110.752 K parameters
Model(
  (token_embedder): Embedding(1027, 32)
  (body_layers): ModuleList(
    (0-2): 3 x Layer(
      (pre_self_attn_norm): Norm()
      (self_attn): MQA(
        (Wq): Linear(in_features=32, out_features=32, bias=False)
        (Wk): Linear(in_features=32, out_features=16, bias=False)
        (Wv): Linear(in_features=32, out_features=16, bias=False)
        (Wo): Linear(in_features=32, out_features=32, bias=False)
      )
      (pre_mlp_norm): Norm()
      (mlp): MLP(
        (Wup): Linear(in_features=32, out_features=85, bias=False)
        (Wgate): Linear(in_features=32, out_features=85, bias=False)
        (Wdown): Linear(in_features=85, out_features=32, bias=False)
        (nonlinearity): SiLU()
      )
    )
  )
  (first_fs_layer): Layer(
    (pre_self_attn_norm): Norm()
    (self_attn): MQA(
      (Wq): Linear(in_features=32, out_features=32, bias=False)
      (Wk): Linear(in_features=32, out_features=16, bias=False)
      (Wv): Linear(in_features=32, out_

In [21]:
%%time

# Inference w/ kv caching
### low key i'm thinking i just abandon kv-caching on this model for now and only bring it back if/when i decide to keep it.
#   pretty sure it wouldn't be a difficult fix or anything but i straight up don't wanna do it
module = Model(cfg).to(cfg.device)
print(sum(p.numel() for p in module.parameters())/1e3, 'K parameters')
print(module)

module.enable_logging()
### enabling printing for sub-modules
module.body_layers[0].enable_logging()
module.body_layers[0].disable_function_logging('self_attn_connect') # disabling some functions of some sub-modules
module.body_layers[0].disable_function_logging('mlp_connect')
module.first_fs_layer.enable_logging()
module.first_fs_layer.disable_function_logging('mlp_connect')
module.fs_layers[0].enable_logging()
module.fs_layers[0].disable_function_logging('mlp_connect')
module.fs_layers[-1].enable_logging()
module.fs_layers[-1].disable_function_logging('mlp_connect')

input_token_ids = torch.randint(tokenizer.vocab_len, (1, cfg.max_seq_len // 4)).to(cfg.device)

output, _ = module(input_token_ids, cache_len = cfg.max_seq_len // 3)
print(output.shape)
del module, input_token_ids, output

110.752 K parameters
Model(
  (token_embedder): Embedding(1027, 32)
  (body_layers): ModuleList(
    (0-2): 3 x Layer(
      (pre_self_attn_norm): Norm()
      (self_attn): MQA(
        (Wq): Linear(in_features=32, out_features=32, bias=False)
        (Wk): Linear(in_features=32, out_features=16, bias=False)
        (Wv): Linear(in_features=32, out_features=16, bias=False)
        (Wo): Linear(in_features=32, out_features=32, bias=False)
      )
      (pre_mlp_norm): Norm()
      (mlp): MLP(
        (Wup): Linear(in_features=32, out_features=85, bias=False)
        (Wgate): Linear(in_features=32, out_features=85, bias=False)
        (Wdown): Linear(in_features=85, out_features=32, bias=False)
        (nonlinearity): SiLU()
      )
    )
  )
  (first_fs_layer): Layer(
    (pre_self_attn_norm): Norm()
    (self_attn): MQA(
      (Wq): Linear(in_features=32, out_features=32, bias=False)
      (Wk): Linear(in_features=32, out_features=16, bias=False)
      (Wv): Linear(in_features=32, out_

RuntimeError: The size of tensor a (32) must match the size of tensor b (74) at non-singleton dimension 3