# This notebook is for teaching and testing: it helps you visualize the shapes of the tensors that flow through each module

In [1]:
# My virtual environments are rarely properly connected to Jupyter, so this cell makes
# packages installed in ./venv importable from the current kernel.
# You probably won't need it, and running it (even repeatedly) won't hurt anything.
import sys
import os
current_dir = os.getcwd()  # Get the current working directory
venv_dir = os.path.join(current_dir, 'venv')
python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
if os.name == 'nt':
    # Windows venvs keep packages in venv\Lib\site-packages (no python version folder)
    site_packages_path = os.path.join(venv_dir, 'Lib', 'site-packages')
else:
    site_packages_path = os.path.join(venv_dir, 'lib', 'python' + python_version, 'site-packages')
# Only add the path if it exists and isn't already there, so re-running the cell
# doesn't pile duplicate (or nonexistent) entries onto sys.path.
if os.path.isdir(site_packages_path) and site_packages_path not in sys.path:
    sys.path.append(site_packages_path)

In [2]:
# Model hyperparameters come from config.py; print them so the shapes logged below
# can be read against them.
from config import ModelConfig
cfg = ModelConfig()
print(cfg)

# Dynamically import the tokenizer package named by cfg.tokenizer and grab its
# get_tokenizer factory (presumably tokenizers/<cfg.tokenizer>/tokenizer.py — verify in tools).
from tools import import_from_nested_path
get_tokenizer = import_from_nested_path(
    ['tokenizers', cfg.tokenizer],
    'tokenizer',
    ['get_tokenizer'],
).get('get_tokenizer')

# For 'bpe' tokenizers the available sizes are 95, 128, 256, 512, 1024 and 2048.
# NOTE(review): cfg.vocab_len is 8192, which differs from this size; here the tokenizer's
# vocab_len only bounds the random token ids used in the Full Model cell.
tokenizer = get_tokenizer(size=512)

import random
import torch

ModelConfig(dim=128, device='mps', tokenizer='bpe_v1', vocab_len=8192, num_layers=4, second_resid_norm=False, num_heads=4, head_dim=32, max_seq_len=128, mm_bias=False, pmem_size=336, pmem_count=2, scale_first_resid=True, norm_type='RMSNorm', norm_affine=True, norm_bias=True, eps=1e-06)


# Norms

In [3]:
from modules.norm import Norm

In [4]:
%%time

### RMSNorm

# Create an instance of RMSNorm
module = Norm(cfg.dim, 'RMSNorm').to(cfg.device)

# let's take a look
print(sum(p.numel() for p in module.parameters())/1e3, 'K parameters')
print(module)

# Initially, logging is disabled by default
module.enable_logging()

### Optionally disabling printing for sub-functions
#module.disable_function_logging('CosineNorm')
#module.disable_function_logging('LayerNorm')
#module.disable_function_logging('RMSNorm')

x = torch.randn(32,cfg.max_seq_len,cfg.dim).to(cfg.device)

# Call the forward method - logging will occur
output = module(x)

# Disable logging. 
# This isn't actually necessary since we won't be using this object again but that's how you'd do it
module.disable_logging()

# clearing up ram jic we're training later
del module, x, output

0.256 K parameters
Norm()

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 128])
CPU times: user 61.1 ms, sys: 41.3 ms, total: 102 ms
Wall time: 110 ms


# Leaky Avg

In [5]:
from modules.memory_mosaic import LeakyAvg

In [6]:
%%time

# Create an instance of context memory
module = LeakyAvg(cfg.max_seq_len, cfg.num_heads).to(cfg.device)
print(sum(p.numel() for p in module.parameters())/1e3, 'K parameters')
print(module)

# Initially, logging is disabled by default
module.enable_logging()
### Optionally disabling printing for sub-functions
#module.disable_function_logging('')

# Call the forward method - logging will occur
x = torch.randn(32,cfg.num_heads,cfg.max_seq_len,cfg.dim // cfg.num_heads).to(cfg.device)
output = module(x)

# Disable logging. 
# This isn't actually necessary since we won't be using this object again but that's how you'd do it
module.disable_logging()

# clearing up ram jic we're training later
del module, x, output

0.004 K parameters
LeakyAvg()

Inputs:
Tensor 'k' shape: torch.Size([32, 4, 128, 32])

Outputs:
Tensor 'output' shape: torch.Size([32, 4, 128, 32])
CPU times: user 35.9 ms, sys: 5.5 ms, total: 41.4 ms
Wall time: 44.7 ms


# Key Feature Extractor

In [7]:
from modules.memory_mosaic import KeyFeatureExtractor

In [8]:
%%time

# Create an instance of context memory
module = KeyFeatureExtractor(
    cfg.num_heads, 
    cfg.head_dim,
    cfg.dim, 
    cfg.mm_bias, 
    cfg.max_seq_len
).to(cfg.device)
print(sum(p.numel() for p in module.parameters())/1e3, 'K parameters')
print(module)

# Initially, logging is disabled by default
module.enable_logging()

### Optionally disabling printing for sub-functions
#module.disable_function_logging('make_key')
#module.disable_function_logging('scale_key')
### enabling printing for sub-modules
module.leaky_avg.enable_logging()

# Call the forward method - logging will occur
x = torch.randn(32,cfg.max_seq_len,cfg.dim).to(cfg.device)
output = module(x)

# Disable logging. 
# This isn't actually necessary since we won't be using this object again but that's how you'd do it
module.disable_logging()

# clearing up ram jic we're training later
del module, x, output

16.392 K parameters
KeyFeatureExtractor(
  (W_k): Linear(in_features=128, out_features=128, bias=False)
  (leaky_avg): LeakyAvg()
)

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 4, 128, 32])

Inputs:
Tensor 'k' shape: torch.Size([32, 4, 128, 32])

Outputs:
Tensor 'output' shape: torch.Size([32, 4, 128, 32])

Inputs:
Tensor 'k' shape: torch.Size([32, 4, 128, 32])
Integer 'scale_pow': Value=1

Outputs:
Tensor 'output' shape: torch.Size([32, 4, 128, 32])

Outputs:
Tensor 'output' shape: torch.Size([32, 4, 128, 32])
CPU times: user 41.1 ms, sys: 5.1 ms, total: 46.2 ms
Wall time: 51.3 ms


# Value Feature Extractor

In [9]:
from modules.memory_mosaic import ValFeatureExtractor

In [10]:
%%time

# Create an instance of context memory
module = ValFeatureExtractor(
    cfg.num_heads, 
    cfg.head_dim,
    cfg.dim, 
    cfg.mm_bias
).to(cfg.device)
print(sum(p.numel() for p in module.parameters())/1e3, 'K parameters')
print(module)

# Initially, logging is disabled by default
module.enable_logging()

### Optionally disabling printing for sub-functions
#module.disable_function_logging('make_val')
#module.disable_function_logging('scale_val')

# Call the forward method - logging will occur
x = torch.randn(32,cfg.max_seq_len,cfg.dim).to(cfg.device)
output = module(x)

# Disable logging. 
# This isn't actually necessary since we won't be using this object again but that's how you'd do it
module.disable_logging()

# clearing up ram jic we're training later
del module, x, output

16.392 K parameters
ValFeatureExtractor(
  (W_v): Linear(in_features=128, out_features=128, bias=False)
)

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 4, 32, 128])

Inputs:
Tensor 'v' shape: torch.Size([32, 4, 32, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 4, 128, 32])

Outputs:
Tensor 'output' shape: torch.Size([32, 4, 128, 32])
CPU times: user 27.8 ms, sys: 3.54 ms, total: 31.4 ms
Wall time: 33.5 ms


# Context Memory

In [11]:
from modules.memory_mosaic import ContextMem

In [12]:
%%time

# Create an instance of context memory
module = ContextMem(
    cfg.num_heads, 
    cfg.head_dim,
    cfg.dim, 
    cfg.mm_bias, 
    cfg.max_seq_len, 
    cfg.dropout_rate
).to(cfg.device)
print(sum(p.numel() for p in module.parameters())/1e3, 'K parameters')
print(module)

# Initially, logging is disabled by default
module.enable_logging()

### Optionally disabling printing for sub-functions
#module.disable_function_logging('attend')
#module.disable_function_logging('proj_vals')
#module.disable_function_logging('reassemble_heads')
### enabling printing for sub-modules
module.k_featurizer.enable_logging()
module.k_featurizer.leaky_avg.enable_logging()
module.v_featurizer.enable_logging()

# Call the forward method - logging will occur
x = torch.randn(32,cfg.max_seq_len,cfg.dim).to(cfg.device)
output = module(x)

# Disable logging. 
# This isn't actually necessary since we won't be using this object again but that's how you'd do it
module.disable_logging()

# clearing up ram jic we're training later
del module, x, output

49.168 K parameters
ContextMem(
  (k_featurizer): KeyFeatureExtractor(
    (W_k): Linear(in_features=128, out_features=128, bias=False)
    (leaky_avg): LeakyAvg()
  )
  (v_featurizer): ValFeatureExtractor(
    (W_v): Linear(in_features=128, out_features=128, bias=False)
  )
  (c_proj): Linear(in_features=128, out_features=128, bias=False)
  (resid_dropout): Dropout(p=0.1, inplace=False)
  (attn_dropout): Dropout(p=0.1, inplace=False)
)

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 4, 128, 32])

Inputs:
Tensor 'k' shape: torch.Size([32, 4, 128, 32])

Outputs:
Tensor 'output' shape: torch.Size([32, 4, 128, 32])

Inputs:
Tensor 'k' shape: torch.Size([32, 4, 128, 32])
Integer 'scale_pow': Value=1

Outputs:
Tensor 'output' shape: torch.Size([32, 4, 128, 32])

Outputs:
Tensor 'output' shape: torch.Size([32, 4, 128, 32])

Inputs:

# Persistent Memory

In [13]:
from modules.memory_mosaic import PersistentMem

In [14]:
%%time

# Create an instance of context memory
module = PersistentMem(
    cfg.num_heads, 
    cfg.head_dim,
    cfg.dim, 
    cfg.mm_bias, 
    cfg.max_seq_len, 
    cfg.pmem_count, 
    cfg.pmem_size, 
    cfg.dropout_rate
).to(cfg.device)
print(sum(p.numel() for p in module.parameters())/1e3, 'K parameters')
print(module)

# Initially, logging is disabled by default
module.enable_logging()

### Optionally disabling printing for sub-functions
#module.disable_function_logging('attend')
#module.disable_function_logging('proj_val')
#module.disable_function_logging('scale')
#module.disable_function_logging('reassemble_heads')
### enabling printing for sub-modules
module.k_featurizer.enable_logging()
module.k_featurizer.leaky_avg.enable_logging()

# Call the forward method - logging will occur
x = torch.randn(32,cfg.max_seq_len,cfg.dim).to(cfg.device).to(cfg.device)
output = module(x)

# Disable logging. 
# This isn't actually necessary since we won't be using this object again but that's how you'd do it
module.disable_logging()

# clearing up ram jic we're training later
del module, x, output

204.812 K parameters
PersistentMem(
  (k_featurizer): KeyFeatureExtractor(
    (W_k): Linear(in_features=128, out_features=128, bias=False)
    (leaky_avg): LeakyAvg()
  )
  (c_proj): Linear(in_features=128, out_features=128, bias=False)
  (resid_dropout): Dropout(p=0.1, inplace=False)
  (attn_dropout): Dropout(p=0.1, inplace=False)
)

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 128])
Integer 'scale_pow': Value=2

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 4, 128, 32])

Inputs:
Tensor 'k' shape: torch.Size([32, 4, 128, 32])

Outputs:
Tensor 'output' shape: torch.Size([32, 4, 128, 32])

Inputs:
Tensor 'k' shape: torch.Size([32, 4, 128, 32])
Integer 'scale_pow': Value=2

Outputs:
Tensor 'output' shape: torch.Size([32, 4, 128, 32])

Outputs:
Tensor 'output' shape: torch.Size([32, 4, 128, 32])

Inputs:
Integer 'y': Value=0
Tensor 'k' shape: torch.Size([32, 4, 128, 32])
Intege

# Residual Layer

In [15]:
from modules.layer import Layer

In [16]:
%%time

# TRAINING
module = Layer(cfg).to(cfg.device)
print(sum(p.numel() for p in module.parameters())/1e3, 'K parameters')
print(module)

module.enable_logging()
### Optionally disabling printing for sub-functions
#module.disable_function_logging('context_connect')
#module.disable_function_logging('persistent_connect')
### enabling printing for sub-modules
#module.pre_context_norm.enable_logging()
#module.context.enable_logging()
#module.post_context_norm.enable_logging()
#module.pre_persistent_norm.enable_logging()
#module.persistent.enable_logging()
#module.post_persistent_norm.enable_logging()

x = torch.randn(32,cfg.max_seq_len,cfg.dim).to(cfg.device)

output = module(x, training=True)
module.disable_logging()
del module, x, output

254.492 K parameters
Layer(
  (pre_context_norm): Norm()
  (context): ContextMem(
    (k_featurizer): KeyFeatureExtractor(
      (W_k): Linear(in_features=128, out_features=128, bias=False)
      (leaky_avg): LeakyAvg()
    )
    (v_featurizer): ValFeatureExtractor(
      (W_v): Linear(in_features=128, out_features=128, bias=False)
    )
    (c_proj): Linear(in_features=128, out_features=128, bias=False)
    (resid_dropout): Dropout(p=0.1, inplace=False)
    (attn_dropout): Dropout(p=0.1, inplace=False)
  )
  (pre_persistent_norm): Norm()
  (persistent): PersistentMem(
    (k_featurizer): KeyFeatureExtractor(
      (W_k): Linear(in_features=128, out_features=128, bias=False)
      (leaky_avg): LeakyAvg()
    )
    (c_proj): Linear(in_features=128, out_features=128, bias=False)
    (resid_dropout): Dropout(p=0.1, inplace=False)
    (attn_dropout): Dropout(p=0.1, inplace=False)
  )
)

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 128])
Integer 'training': Value=True

Inputs:
Tensor 'x' 

# Full Model

In [17]:
from modules.model import Model

In [18]:
%%time

# TRAINING
module = Model(cfg).to(cfg.device)
print(sum(p.numel() for p in module.parameters())/1e3, 'K parameters')
print(module)

module.enable_logging()
### enabling printing for sub-modules
module.layers[0].enable_logging() # we'll only look at one layer
module.final_norm.enable_logging()

input_token_ids = torch.randint(tokenizer.vocab_len, (32, cfg.max_seq_len)).to(cfg.device)
target_token_ids = torch.randint(tokenizer.vocab_len, (32, cfg.max_seq_len)).to(cfg.device)

output, loss = module(input_token_ids, target_token_ids=target_token_ids)
print(loss)
module.disable_logging()
del module, input_token_ids, target_token_ids, output

2067.184 K parameters
Model(
  (token_embedder): Embedding(8195, 128)
  (layers): ModuleList(
    (0-3): 4 x Layer(
      (pre_context_norm): Norm()
      (context): ContextMem(
        (k_featurizer): KeyFeatureExtractor(
          (W_k): Linear(in_features=128, out_features=128, bias=False)
          (leaky_avg): LeakyAvg()
        )
        (v_featurizer): ValFeatureExtractor(
          (W_v): Linear(in_features=128, out_features=128, bias=False)
        )
        (c_proj): Linear(in_features=128, out_features=128, bias=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
        (attn_dropout): Dropout(p=0.1, inplace=False)
      )
      (pre_persistent_norm): Norm()
      (persistent): PersistentMem(
        (k_featurizer): KeyFeatureExtractor(
          (W_k): Linear(in_features=128, out_features=128, bias=False)
          (leaky_avg): LeakyAvg()
        )
        (c_proj): Linear(in_features=128, out_features=128, bias=False)
        (resid_dropout): Dropout(p=0.1, inpl