# This notebook is designed for teaching and testing: it helps you visualize the shapes of the tensors that flow through each module

In [1]:
# My virtual environments are rarely connected to Jupyter properly, so this cell adds the
# project's ./venv site-packages directory to sys.path by hand.
# You probably won't need this cell, but running it won't hurt anything either.
import sys
import os

current_dir = os.getcwd()  # assumes the notebook was launched from the directory that contains ./venv
venv_dir = os.path.join(current_dir, 'venv')
python_version = str(sys.version_info.major) + '.' + str(sys.version_info.minor)
if os.name == 'nt':
    # Windows venvs keep packages in venv\Lib\site-packages (no per-version subdirectory)
    site_packages_path = os.path.join(venv_dir, 'Lib', 'site-packages')
else:
    # POSIX venvs keep packages in venv/lib/pythonX.Y/site-packages
    site_packages_path = os.path.join(venv_dir, 'lib', 'python' + python_version, 'site-packages')

# append (lowest import priority) and only once, so re-running the cell doesn't pile up duplicates
if site_packages_path not in sys.path:
    sys.path.append(site_packages_path)

In [2]:
# config file + tokenizer + global RNG seeding
import random

import torch

from config import ModelConfig, TrainConfig
from tools import import_from_nested_path

# Seed every RNG the later cells draw from (torch.randn / torch.randint / random.randint)
# so Restart & Run All reproduces the same logged min/max values.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

cfg = ModelConfig()
tcfg = TrainConfig()
print(cfg)
print(tcfg)

# import the tokenizer specified by cfg
imported_objects = import_from_nested_path(['custom_tokenizers', cfg.tokenizer], 'tokenizer', ['get_tokenizer'])
get_tokenizer = imported_objects.get('get_tokenizer')
# Size the tokenizer from cfg.vocab_len rather than a hard-coded number so it can't drift out of
# sync with the model's vocabulary. Assuming 'bpe', valid sizes are 512, 1024 and 2048.
tokenizer = get_tokenizer(size = cfg.vocab_len)

ModelConfig(dim=8, device='mps', out_weight_share=True, linear_bias=False, tokenizer='bpe_tinyStories', vocab_len=512, num_layers=2, second_resid_norm=False, mlp_hidden_mult=4, mlp_nonlinearity='SiLU', mlp_gated=True, num_q_heads=2, num_kv_heads=1, head_dim=4, theta=10000, max_seq_len=10, scale_first_resid=True, norm_type='RMSNorm', norm_affine=True, norm_bias=True, eps=1e-06)
TrainConfig(model_name='2024-06-30|21-25-01', dataset_name='noanabeshima/TinyStoriesV2', data_subset=None, streaming=False, micro_batch_size=4, grad_accum_steps=2, max_iters=4, eval_interval=2, eval_samples=1, checkpoint_interval=None, beta1=0.9, beta2=0.95, epsilon=1e-08, weight_decay=0.05, grad_clip=1.0, lr_init=1e-06, lr_max=0.01, lr_min=0.0001, warmup_iters=0, final_flat_iters=0, anneal_type='cos', num_restarts=0, T_mult=2)


# Norms

In [3]:
from modules.norm import Norm

In [4]:
%%time

### RMSNorm

# Create an instance of RMSNorm
module = Norm(cfg.dim, 'RMSNorm').to(cfg.device)

# let's take a look
print("number of parameters: %.2fK" % (module.get_num_params()/1e3,))
print(module)

# Initially, logging is disabled by default
module.enable_logging()

### Optionally disabling printing for sub-functions
#module.disable_function_logging('CosineNorm')
#module.disable_function_logging('LayerNorm')
#module.disable_function_logging('RMSNorm')

x = torch.randn(tcfg.micro_batch_size,cfg.max_seq_len,cfg.dim).to(cfg.device)

# Call the forward method - logging will occur
output = module(x)

# Disable logging. 
# This isn't actually necessary since we won't be using this object again but that's how you'd do it
module.disable_logging()

# clearing up ram jic we're training later
del module, x, output

number of parameters: 0.02K
Norm()

Inputs:
Tensor 'x' shape: torch.Size([4, 10, 8]), dtype: torch.float32, device: mps:0, min/max: -2.465/2.547

Inputs:
Tensor 'x' shape: torch.Size([4, 10, 8]), dtype: torch.float32, device: mps:0, min/max: -2.465/2.547

Outputs:
Tensor 'output' shape: torch.Size([4, 10, 8]), dtype: torch.float32, device: mps:0, min/max: -2.294/2.624

Outputs:
Tensor 'output' shape: torch.Size([4, 10, 8]), dtype: torch.float32, device: mps:0, min/max: -2.294/2.624
CPU times: user 62.6 ms, sys: 40 ms, total: 103 ms
Wall time: 110 ms


In [5]:
%%time

# LayerNorm
module = Norm(cfg.dim, 'LayerNorm').to(cfg.device)
module.enable_logging()

# you can also have it optionally print out all tensors in full
module.enable_full_tensor_printing()
# i recommend only doing this with very small toy values for your hyperparameters, otherwise this gets too big

### Optionally disabling printing for sub-functions
#module.disable_function_logging('CosineNorm')
#module.disable_function_logging('LayerNorm')
#module.disable_function_logging('RMSNorm')

x = torch.randn(tcfg.micro_batch_size,cfg.max_seq_len,cfg.dim).to(cfg.device)
output = module(x)
module.disable_logging()
del module, x, output


Inputs:
Tensor 'x' shape: torch.Size([4, 10, 8]), dtype: torch.float32, device: mps:0, min/max: -2.912/2.715
Full tensor content:
tensor([[[ 1.2539e+00,  1.1740e+00, -3.1399e-01, -4.2712e-02, -1.2338e+00,
          -5.3306e-01, -8.4329e-02, -3.9619e-01],
         [-4.9208e-02, -7.3625e-01, -1.8756e+00,  3.2498e-01, -1.1523e+00,
          -1.0863e+00, -9.6733e-01,  6.0719e-01],
         [-5.8460e-01,  5.9410e-01, -9.7479e-01, -4.2559e-01,  1.8048e+00,
          -3.7047e-01,  8.9476e-02,  1.7904e+00],
         [ 1.8794e-01, -8.1233e-01, -9.5944e-01, -9.0259e-01, -1.4441e+00,
          -1.9472e+00, -3.1793e-01,  1.0997e+00],
         [ 6.5820e-01,  2.2273e+00, -1.4156e+00,  1.5682e-02, -1.3306e-01,
           6.6103e-01, -2.6299e-01, -2.3854e-01],
         [ 5.4644e-02,  3.9237e-01, -6.1842e-01, -7.6074e-01, -3.7094e-01,
          -6.9196e-01,  1.7675e+00,  3.7976e-01],
         [ 9.1040e-03,  5.2133e-01,  3.6259e-01, -2.5275e-01,  6.6164e-01,
           8.2407e-01, -5.6123e-01, -1.1834e

# Attention

In [6]:
from modules.attention import SelfAttention, PrecomputeRotaryFrequencies

In [7]:
%%time

# first up let's look at training

# Create an instance of multi-head self-attention
module = SelfAttention(cfg.dim, cfg.head_dim, cfg.num_q_heads, cfg.num_kv_heads, cfg.max_seq_len, cfg.linear_bias, device=cfg.device)
print("number of parameters: %.2fK" % (module.get_num_params()/1e3,))
print(module)

# Initially, logging is disabled by default
module.enable_logging()

# optionally enabling printing of every single input/output tensor
#module.enable_full_tensor_printing()

### Optionally disabling printing for sub-functions
#module.disable_function_logging('apply_precompute_freqs')
#module.disable_function_logging('reshape_for_broadcast')
#module.disable_function_logging('match_headcount')
#module.disable_function_logging('attend')
#module.disable_function_logging('calc_output')

# precompute RoPE frequencies, causal mask, and dummy input data
precompute_freqs = PrecomputeRotaryFrequencies(cfg.head_dim, cfg.max_seq_len, cfg.theta, cfg.device)
freqs = precompute_freqs()
mask = torch.ones(cfg.max_seq_len, cfg.max_seq_len, dtype=torch.bool, device=cfg.device).triu(diagonal=1)
x = torch.randn(tcfg.micro_batch_size,cfg.max_seq_len,cfg.dim).to(cfg.device)

# Call the forward method - logging will occur
output = module(x, freqs, mask, training=True)

# Disable logging. 
# This isn't actually necessary since we won't be using this object again but that's how you'd do it
module.disable_logging()

# clearing up ram jic we're training later
del module, freqs, mask, x, output

number of parameters: 0.19K
SelfAttention(
  (Wq): Linear(in_features=8, out_features=8, bias=False)
  (Wk): Linear(in_features=8, out_features=4, bias=False)
  (Wv): Linear(in_features=8, out_features=4, bias=False)
  (Wo): Linear(in_features=8, out_features=8, bias=False)
)

Inputs:
Tensor 'x' shape: torch.Size([4, 10, 8]), dtype: torch.float32, device: mps:0, min/max: -2.691/3.249
Dict 'freqs':
    Tensor 'freqs[cos]' shape: torch.Size([1, 10, 1, 4]), dtype: torch.float32, device: mps:0, min/max: -0.990/1.000
    Tensor 'freqs[sin]' shape: torch.Size([1, 10, 1, 4]), dtype: torch.float32, device: mps:0, min/max: -0.959/0.989
Tensor 'mask' shape: torch.Size([10, 10]), dtype: torch.bool, device: mps:0, min/max: 0.000/1.000
Other-type 'cache_len': Type=NoneType, Value=None
Other-type 'kv_cache': Type=NoneType, Value=None
Bool 'training': Value=True

Inputs:
Tensor 'q' shape: torch.Size([4, 10, 2, 4]), dtype: torch.float32, device: mps:0, min/max: -1.522/1.880
Tensor 'k' shape: torch.Siz

In [8]:
%%time

# now let's do it for inference

module = SelfAttention(cfg.dim, cfg.head_dim, cfg.num_q_heads, cfg.num_kv_heads, cfg.max_seq_len, cfg.linear_bias, device=cfg.device)
module.enable_logging()
#module.disable_function_logging('apply_precompute_freqs')
#module.disable_function_logging('reshape_for_broadcast')
#module.disable_function_logging('match_headcount')
#module.disable_function_logging('attend')
#module.disable_function_logging('calc_output')

# optionally enabling printing of every single input/output tensor
#module.enable_full_tensor_printing()

# precompute RoPE frequencies, causal mask, and dummy input data
precompute_freqs = PrecomputeRotaryFrequencies(cfg.head_dim, cfg.max_seq_len, cfg.theta, cfg.device)
freqs_cis = precompute_freqs()
mask = torch.ones(cfg.max_seq_len, cfg.max_seq_len, dtype=torch.bool, device=cfg.device).triu(diagonal=1)
# setting up for kv caching
context_chunk_len = cfg.max_seq_len // 4
cache_len = random.randint(1, 3 * context_chunk_len)
seq_len = cache_len + context_chunk_len
kv_cache = {
    'k': torch.zeros((tcfg.micro_batch_size, cfg.max_seq_len, cfg.num_kv_heads, cfg.head_dim), device=cfg.device),
    'v': torch.zeros((tcfg.micro_batch_size, cfg.max_seq_len, cfg.num_kv_heads, cfg.head_dim), device=cfg.device)
}
# need to extend the mask with zeros for the cached values
mask = torch.nn.functional.pad(mask[:context_chunk_len, :context_chunk_len], (cache_len, 0, 0, 0), value=False).bool()
x = torch.randn(tcfg.micro_batch_size,context_chunk_len,cfg.dim).to(cfg.device)

# Call the forward method - logging will occur
output = module(x, freqs_cis, mask, cache_len, kv_cache)

# Disable logging. 
# This isn't actually necessary since we won't be using this object again but that's how you'd do it
module.disable_logging()

# clearing up ram jic we're training later
del module, freqs_cis, mask, cache_len, context_chunk_len, seq_len, kv_cache, x, output


Inputs:
Tensor 'x' shape: torch.Size([4, 2, 8]), dtype: torch.float32, device: mps:0, min/max: -2.401/2.478
Dict 'freqs':
    Tensor 'freqs[cos]' shape: torch.Size([1, 10, 1, 4]), dtype: torch.float32, device: mps:0, min/max: -0.990/1.000
    Tensor 'freqs[sin]' shape: torch.Size([1, 10, 1, 4]), dtype: torch.float32, device: mps:0, min/max: -0.959/0.989
Tensor 'mask' shape: torch.Size([2, 5]), dtype: torch.bool, device: mps:0, min/max: 0.000/1.000
Integer 'cache_len': Value=3
Dict 'kv_cache':
    Tensor 'kv_cache[k]' shape: torch.Size([4, 10, 1, 4]), dtype: torch.float32, device: mps:0, min/max: 0.000/0.000
    Tensor 'kv_cache[v]' shape: torch.Size([4, 10, 1, 4]), dtype: torch.float32, device: mps:0, min/max: 0.000/0.000
Bool 'training': Value=False

Inputs:
Tensor 'q' shape: torch.Size([4, 2, 2, 4]), dtype: torch.float32, device: mps:0, min/max: -1.750/1.798
Tensor 'k' shape: torch.Size([4, 2, 1, 4]), dtype: torch.float32, device: mps:0, min/max: -1.492/1.116
Dict 'freqs_cis':
    T

# Multi-Layer Perceptron

In [9]:
from modules.mlp import MLP

In [10]:
%%time

# GeGLU
module = MLP(
    cfg.dim, 
    int(cfg.dim * cfg.mlp_hidden_mult * 2/3), 
    cfg.dim, 
    'GeLU', 
    gated=True, 
    bias=cfg.linear_bias, 
    dropout_rate = 0.1
).to(cfg.device)
print("number of parameters: %.2fK" % (module.get_num_params()/1e3,))
print(module)
module.enable_logging()

# optionally enabling printing of every single input/output tensor
#module.enable_full_tensor_printing()

x = torch.randn(tcfg.micro_batch_size,cfg.max_seq_len,cfg.dim).to(cfg.device)
output = module(x, training=True)
module.disable_logging()
del module, x, output

number of parameters: 0.50K
MLP(
  (Wup): Linear(in_features=8, out_features=21, bias=False)
  (Wgate): Linear(in_features=8, out_features=21, bias=False)
  (Wdown): Linear(in_features=21, out_features=8, bias=False)
  (nonlinearity): GELU(approximate='none')
)

Inputs:
Tensor 'x' shape: torch.Size([4, 10, 8]), dtype: torch.float32, device: mps:0, min/max: -2.388/2.765
Bool 'training': Value=True

Outputs:
Tensor 'output' shape: torch.Size([4, 10, 8]), dtype: torch.float32, device: mps:0, min/max: -0.461/0.734
CPU times: user 26.1 ms, sys: 3.76 ms, total: 29.8 ms
Wall time: 30.1 ms


In [11]:
%%time

# not gated, testing every other nonlinearity
module = MLP(
    cfg.dim, 
    cfg.dim * cfg.mlp_hidden_mult, 
    cfg.dim, 
    'ReLU', 
    gated=False, 
    bias=True, 
    dropout_rate = 0.1
).to(cfg.device)
print("number of parameters: %.2fK" % (module.get_num_params()/1e3,))
print(module)
module.enable_logging()

# optionally enabling printing of every single input/output tensor
#module.enable_full_tensor_printing()

x = torch.randn(tcfg.micro_batch_size,cfg.max_seq_len,cfg.dim).to(cfg.device)
output = module(x, training=True)
module.disable_logging()
del module, x, output

number of parameters: 0.55K
MLP(
  (Wup): Linear(in_features=8, out_features=32, bias=True)
  (Wdown): Linear(in_features=32, out_features=8, bias=True)
  (nonlinearity): ReLU()
)

Inputs:
Tensor 'x' shape: torch.Size([4, 10, 8]), dtype: torch.float32, device: mps:0, min/max: -3.319/2.575
Bool 'training': Value=True

Outputs:
Tensor 'output' shape: torch.Size([4, 10, 8]), dtype: torch.float32, device: mps:0, min/max: -0.687/1.157
CPU times: user 27.3 ms, sys: 3.71 ms, total: 31 ms
Wall time: 30.9 ms


# ResidualLayer

In [12]:
from modules.layer import Layer

In [13]:
%%time

# TRAINING
module = Layer(cfg).to(cfg.device)
print("number of parameters: %.2fK" % (module.get_num_params()/1e3,))
print(module)

module.enable_logging()
#module.disable_function_logging('attn_connect')
#module.disable_function_logging('mlp_connect')
### enabling printing for sub-modules
#module.pre_attn_norm.enable_logging()
#module.attn.enable_logging()
#module.post_attn_norm.enable_logging()
#module.pre_mlp_norm.enable_logging()
#module.mlp.enable_logging()
#module.post_mlp_norm.enable_logging()

# optionally enabling printing of every single input/output tensor
#module.enable_full_tensor_printing()

# precompute RoPE frequencies, causal mask, and dummy input data
precompute_freqs = PrecomputeRotaryFrequencies(cfg.head_dim, cfg.max_seq_len, cfg.theta, cfg.device)
freqs_cis = precompute_freqs()
mask = torch.ones(cfg.max_seq_len, cfg.max_seq_len, dtype=torch.bool, device=cfg.device).triu(diagonal=1)
mask = torch.triu(mask, diagonal=1)
x = torch.randn(tcfg.micro_batch_size,cfg.max_seq_len,cfg.dim).to(cfg.device)

output = module(x, freqs_cis, mask, training=True)
module.disable_logging()
del module,freqs_cis, mask, x, output

number of parameters: 0.73K
Layer(
  (pre_attn_norm): Norm()
  (attn): SelfAttention(
    (Wq): Linear(in_features=8, out_features=8, bias=False)
    (Wk): Linear(in_features=8, out_features=4, bias=False)
    (Wv): Linear(in_features=8, out_features=4, bias=False)
    (Wo): Linear(in_features=8, out_features=8, bias=False)
  )
  (pre_mlp_norm): Norm()
  (mlp): MLP(
    (Wup): Linear(in_features=8, out_features=21, bias=False)
    (Wgate): Linear(in_features=8, out_features=21, bias=False)
    (Wdown): Linear(in_features=21, out_features=8, bias=False)
    (nonlinearity): SiLU()
  )
)

Inputs:
Tensor 'x' shape: torch.Size([4, 10, 8]), dtype: torch.float32, device: mps:0, min/max: -2.756/2.746
Dict 'freqs':
    Tensor 'freqs[cos]' shape: torch.Size([1, 10, 1, 4]), dtype: torch.float32, device: mps:0, min/max: -0.990/1.000
    Tensor 'freqs[sin]' shape: torch.Size([1, 10, 1, 4]), dtype: torch.float32, device: mps:0, min/max: -0.959/0.989
Tensor 'mask' shape: torch.Size([10, 10]), dtype: 

In [14]:
%%time

# INFERENCE
module = Layer(cfg).to(cfg.device)
print("number of parameters: %.2fK" % (module.get_num_params()/1e3,))
print(module)

module.enable_logging()
#module.disable_function_logging('attn_connect')
#module.disable_function_logging('mlp_connect')
#module.pre_attn_norm.enable_logging()
#module.attn.enable_logging()
#module.post_attn_norm.enable_logging()
#module.pre_mlp_norm.enable_logging()
#module.mlp.enable_logging()
#module.post_mlp_norm.enable_logging()

# optionally enabling printing of every single input/output tensor
#module.enable_full_tensor_printing()

# precompute RoPE frequencies, causal mask, and dummy input data
precompute_freqs = PrecomputeRotaryFrequencies(cfg.head_dim, cfg.max_seq_len, cfg.theta, cfg.device)
freqs_cis = precompute_freqs()
mask = torch.ones(cfg.max_seq_len, cfg.max_seq_len, dtype=torch.bool, device=cfg.device).triu(diagonal=1)
mask = torch.triu(mask, diagonal=1)
# setting up for kv caching
cache_len = cfg.max_seq_len // 3
context_chunk_len = cfg.max_seq_len // 4
seq_len = cache_len + context_chunk_len
kv_cache = {
    'k': torch.zeros((tcfg.micro_batch_size, cfg.max_seq_len, cfg.num_kv_heads, cfg.head_dim), device=cfg.device),
    'v': torch.zeros((tcfg.micro_batch_size, cfg.max_seq_len, cfg.num_kv_heads, cfg.head_dim), device=cfg.device)
}
# need to extend the mask with zeros for the cached values
mask = torch.nn.functional.pad(mask[:context_chunk_len, :context_chunk_len], (cache_len, 0, 0, 0), value=False).bool()
# these don't use seq_len because those entries should already be in the kv cache
#freqs_cis = freqs_cis[:context_chunk_len]
x = torch.randn(tcfg.micro_batch_size,context_chunk_len,cfg.dim).to(cfg.device)

output = module(x, freqs_cis, mask, cache_len, kv_cache)
module.disable_logging()
del module, freqs_cis, mask, cache_len, context_chunk_len, seq_len, kv_cache, x, output

number of parameters: 0.73K
Layer(
  (pre_attn_norm): Norm()
  (attn): SelfAttention(
    (Wq): Linear(in_features=8, out_features=8, bias=False)
    (Wk): Linear(in_features=8, out_features=4, bias=False)
    (Wv): Linear(in_features=8, out_features=4, bias=False)
    (Wo): Linear(in_features=8, out_features=8, bias=False)
  )
  (pre_mlp_norm): Norm()
  (mlp): MLP(
    (Wup): Linear(in_features=8, out_features=21, bias=False)
    (Wgate): Linear(in_features=8, out_features=21, bias=False)
    (Wdown): Linear(in_features=21, out_features=8, bias=False)
    (nonlinearity): SiLU()
  )
)

Inputs:
Tensor 'x' shape: torch.Size([4, 2, 8]), dtype: torch.float32, device: mps:0, min/max: -1.830/2.133
Dict 'freqs':
    Tensor 'freqs[cos]' shape: torch.Size([1, 10, 1, 4]), dtype: torch.float32, device: mps:0, min/max: -0.990/1.000
    Tensor 'freqs[sin]' shape: torch.Size([1, 10, 1, 4]), dtype: torch.float32, device: mps:0, min/max: -0.959/0.989
Tensor 'mask' shape: torch.Size([2, 5]), dtype: tor

# Full Model

In [15]:
from modules.model import Model

In [16]:
%%time

# TRAINING
module = Model(cfg).to(cfg.device)
print("number of parameters: %.2fM" % (module.get_num_params()/1e6,))
print(module)

module.enable_logging()
### enabling printing for sub-modules
module.precompute_freqs.enable_logging()
#module.layers[0].enable_logging()
for i in range(cfg.num_layers):
    module.layers[i].enable_logging()
module.final_norm.enable_logging()

# optionally enabling printing of every single input/output tensor
#module.enable_full_tensor_printing()

input_token_ids = torch.randint(tokenizer.vocab_len, (tcfg.micro_batch_size, cfg.max_seq_len)).to(cfg.device)
target_token_ids = torch.randint(tokenizer.vocab_len, (tcfg.micro_batch_size, cfg.max_seq_len)).to(cfg.device)

output, loss = module(input_token_ids, target_token_ids=target_token_ids)
print(loss)
del module, input_token_ids, target_token_ids, output, loss

number of parameters: 0.01M
Model(
  (token_embedder): Embedding(512, 8)
  (precompute_freqs): PrecomputeRotaryFrequencies()
  (layers): ModuleList(
    (0-1): 2 x Layer(
      (pre_attn_norm): Norm()
      (attn): SelfAttention(
        (Wq): Linear(in_features=8, out_features=8, bias=False)
        (Wk): Linear(in_features=8, out_features=4, bias=False)
        (Wv): Linear(in_features=8, out_features=4, bias=False)
        (Wo): Linear(in_features=8, out_features=8, bias=False)
      )
      (pre_mlp_norm): Norm()
      (mlp): MLP(
        (Wup): Linear(in_features=8, out_features=21, bias=False)
        (Wgate): Linear(in_features=8, out_features=21, bias=False)
        (Wdown): Linear(in_features=21, out_features=8, bias=False)
        (nonlinearity): SiLU()
      )
    )
  )
  (final_norm): Norm()
  (output): Linear(in_features=8, out_features=512, bias=False)
  (criterion): CrossEntropyLoss()
)

Inputs:
Tensor 'input_token_ids' shape: torch.Size([4, 10]), dtype: torch.int64, dev

In [17]:
%%time

# Inference
module = Model(cfg).to(cfg.device)
print("number of parameters: %.2fM" % (module.get_num_params()/1e6,))
print(module)

module.enable_logging()
### enabling printing for sub-modules
#for i in range(cfg.num_layers):
    #module.layers[i].enable_logging()
#module.final_norm.enable_logging()

# optionally enabling printing of every single input/output tensor
#module.enable_full_tensor_printing()

input_token_ids = torch.randint(tokenizer.vocab_len, (tcfg.micro_batch_size, cfg.max_seq_len // 4)).to(cfg.device)
kv_cache = [{ # Initialize kv caches for each layer
                "k": torch.zeros((tcfg.micro_batch_size, cfg.max_seq_len, cfg.num_kv_heads, cfg.head_dim), device=cfg.device),
                "v": torch.zeros((tcfg.micro_batch_size, cfg.max_seq_len, cfg.num_kv_heads, cfg.head_dim), device=cfg.device),
            } for _ in range(cfg.num_layers)]

output, kv_cache = module(input_token_ids, cache_len = cfg.max_seq_len // 3, kv_cache = kv_cache)

del module, input_token_ids, kv_cache, output

number of parameters: 0.01M
Model(
  (token_embedder): Embedding(512, 8)
  (precompute_freqs): PrecomputeRotaryFrequencies()
  (layers): ModuleList(
    (0-1): 2 x Layer(
      (pre_attn_norm): Norm()
      (attn): SelfAttention(
        (Wq): Linear(in_features=8, out_features=8, bias=False)
        (Wk): Linear(in_features=8, out_features=4, bias=False)
        (Wv): Linear(in_features=8, out_features=4, bias=False)
        (Wo): Linear(in_features=8, out_features=8, bias=False)
      )
      (pre_mlp_norm): Norm()
      (mlp): MLP(
        (Wup): Linear(in_features=8, out_features=21, bias=False)
        (Wgate): Linear(in_features=8, out_features=21, bias=False)
        (Wdown): Linear(in_features=21, out_features=8, bias=False)
        (nonlinearity): SiLU()
      )
    )
  )
  (final_norm): Norm()
  (output): Linear(in_features=8, out_features=512, bias=False)
  (criterion): CrossEntropyLoss()
)

Inputs:
Tensor 'input_token_ids' shape: torch.Size([4, 2]), dtype: torch.int64, devi