In [1]:
# my virtual environments are rarely properly connected to jupyter so this fixes that. 
# you won't need this cell but running it won't hurt anything either
import sys
import os
current_dir = os.getcwd()  # Get the current working directory
venv_dir = os.path.join(current_dir, '../venv') 
python_version = str(sys.version_info.major) + '.' + str(sys.version_info.minor)
site_packages_path = os.path.join(venv_dir, 'lib', 'python' + python_version, 'site-packages')
sys.path.append(site_packages_path) 

In [2]:
# used to save & load models
import json
from dataclasses import asdict

# tokenizer
sys.path.append("..")  # Adds the parent directory to the path so we can see the tokenizer
from tokenizer_TinyStories import *
size = 512 # size options are 128, 256, 512 and 1024
path = f'../tokenizers/tiny_stories_tokenizer_{size}.model'
tokenizer = get_tokenizer(path) 

# config file
from config import *
cfg = Config()
cfg.vocab_len = tokenizer.vocab_len
print(cfg)

# model modules
from model import *

# inference code
from inference import *

# used to save & load models
import json
from dataclasses import asdict

Config(dim=128, vocab_len=512, device='cpu', num_layers=10, pre_connect_dropout=False, second_resid_norm=False, mlp_hidden_mult=2, mlp_bias=False, mlp_nonlinearity='GeLU', mlp_gated=True, num_q_heads=4, num_kv_heads=1, theta=10000, max_seq_len=512, scale_first_resid=True, norm_type='RMSNorm', norm_affine=True, norm_bias=True, eps=1e-06, max_batch_size=1)


# Load a Pretrained Model

In [31]:
# pretrained model options:
# - 
name = ''

# Deserialize the JSON file back to a dictionary
with open(f'models/{name}.json', 'r') as f:
    config_dict = json.load(f)

# Convert the dictionary back to a dataclass object
cfg = Config(**config_dict)
cfg.device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Initialize a blank model
model = customGPT(cfg).to(cfg.device)  

# here's the path to a minGemma model that i've trained with roughly 1m parameters
path = f'models/{name}.pth'

# Load the saved state dictionary
model.load_state_dict(torch.load(path)) 
# REMEMBER TO CHANGE VALUES IN params TO MATCH THE MODEL YOU'VE LOADED

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

# If you only plan to do inference, switch to evaluation mode
model.eval()

2985.088 K parameters


Llama3(
  (tok_embeddings): Embedding(512, 128)
  (layers): ModuleList(
    (0-11): 12 x TransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=128, out_features=128, bias=False)
        (wk): Linear(in_features=128, out_features=32, bias=False)
        (wv): Linear(in_features=128, out_features=32, bias=False)
        (wo): Linear(in_features=128, out_features=128, bias=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=128, out_features=512, bias=False)
        (w2): Linear(in_features=512, out_features=128, bias=False)
        (w3): Linear(in_features=128, out_features=512, bias=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): Linear(in_features=128, out_features=512, bias=False)
  (criterion): CrossEntropyLoss()
)

# temporarily use this for testing until I have a model trained

In [3]:
model = customGPT(cfg, tokenizer).to(cfg.device)

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

print(model)

1463.552 K parameters
customGPT(
  (token_embedder): Embedding(512, 128)
  (layers): ModuleList(
    (0-9): 10 x ResidualLayer(
      (pre_attn_norm): Norm()
      (attn): MQSA(
        (Wq): Linear(in_features=128, out_features=128, bias=False)
        (Wk): Linear(in_features=128, out_features=32, bias=False)
        (Wv): Linear(in_features=128, out_features=32, bias=False)
        (Wo): Linear(in_features=128, out_features=128, bias=False)
      )
      (pre_mlp_norm): Norm()
      (mlp): MLP(
        (Wgate): Linear(in_features=128, out_features=256, bias=False)
        (Wup): Linear(in_features=128, out_features=256, bias=False)
        (Wdown): Linear(in_features=256, out_features=128, bias=False)
        (nonlinearity): GELU(approximate='none')
      )
    )
  )
  (final_norm): Norm()
  (criterion): CrossEntropyLoss()
)


# Inference

In [4]:
example_prompt = "Hello, world!"
generated_text = generate(
    example_prompt, 
    model, 
    tokenizer,
    max_gen_len = 5,
    temperature = 20.,
)
print(generated_text)

Hello, world!!!!!!


In [6]:
example_prompt = "Hello, world!"
generated_text = generate(
    example_prompt, 
    model, 
    tokenizer,
    memory_saver_div = 16,
    temperature = 20.,
)
print(generated_text)

maximum attention matrix size in memory will be 32x512 rather than 512x512

Hello, world!!!!!!UUdogmmmmmmmmmssfastpotpotpotpotpotpotpot!"    birdbirdbirdbirdforforforpotpotfveryveryverybestbestcccickccli" " getherdeureureureureureureureureureheheheomomknowknow0bugbugld5HHHMM



YeslikedlikedlikedlikedlikedlikedlikedlikedhedAAmairirithithithwhotogetherLucyTimhenhenhenhenhenhenpecpecpecpecpecpecpecpeclivedlivedlivedkingkingumpave5startedstartedstartedasasasustustbothbothbothourourkllededcanjjscarscarscarscarcamecamecamecamecamecamecameououanyanyanyanybackbackbackouriled"
"
wwnamedreucexcitedexcitedexcitedexcitedexcitedooowaterpotpotpotightightfloflofloflofloflofloSheplayinginryYoutbackTim????iriririririririririririr"TheyTheyTheyTheykind---girlgirlgirlgirlututgirmakearVVVVZthatthatch.
.
.
.
.
.
asasassmsmsmShesomethingOneSamSamSamSambl::::::ranranranyouyouteknowknowknowknowknowknowknowknowknowkekekekesawriedherehereakeakeclymoremoremorebecauseqqqqqqiteiteiteiteiteitethat!!!!!cccccccthatth