In [1]:
# Jupyter kernels are often not launched from the project's virtual environment, so this
# cell makes the packages installed in ./venv importable from the notebook.
# You won't need it if your kernel already runs inside the venv, but running it is harmless.
import sys
import os

current_dir = os.getcwd()  # everything below is resolved relative to the working directory
venv_dir = os.path.join(current_dir, 'venv')
# NOTE: this assumes the venv was created with the same Python major.minor as the kernel.
python_version = f'{sys.version_info.major}.{sys.version_info.minor}'
if os.name == 'nt':
    # Windows venvs keep packages in <venv>\Lib\site-packages (no pythonX.Y directory)
    site_packages_path = os.path.join(venv_dir, 'Lib', 'site-packages')
else:
    site_packages_path = os.path.join(venv_dir, 'lib', 'python' + python_version, 'site-packages')

# Only extend sys.path when the directory really exists, and never add it twice,
# so re-running this cell leaves sys.path unchanged.
if os.path.isdir(site_packages_path) and site_packages_path not in sys.path:
    sys.path.append(site_packages_path)

In [2]:
# model modules
from model import *

# inference code
from inference import *

# used to save & load models
import json
from dataclasses import asdict

# Load a Pretrained Model

In [3]:
# tokenizer
# NOTE: the star import is what brings `get_tokenizer` into scope; keep it before the call below.
from tokenizer import *
size = 512 # selects which tokenizer file is used; options are 128, 256, 512 and 1024
path = f'./tokenizers/tiny_stories_tokenizer_{size}.model'  # relative to the working directory
tokenizer = get_tokenizer(path) 

# config file
# NOTE: the star import is what brings `Config` into scope.
from config import *
cfg = Config()  # built with the default values declared in config.py
# Keep the config's vocabulary length in sync with the tokenizer that was just created.
# (The recorded run printed vocab_len=515 for size=512 — presumably a few extra special
# tokens; confirm in tokenizer.py.)
cfg.vocab_len = tokenizer.vocab_len
# This cfg is only a preview: the next cell rebuilds cfg from the saved model's JSON file.
print(cfg)

Config(dim=128, vocab_len=515, device='cpu', num_layers=10, pre_connect_dropout=False, second_resid_norm=False, mlp_hidden_mult=2, mlp_bias=False, mlp_nonlinearity='GeLU', mlp_gated=True, num_q_heads=4, num_kv_heads=1, theta=10000, max_seq_len=512, scale_first_resid=True, norm_type='RMSNorm', norm_affine=True, norm_bias=True, eps=1e-06, max_batch_size=1)


In [4]:
# pretrained model options:
# - customGPT_2024-04-25|10-16-11: a ~1.5m parameter model that hasn't really been trained, just a test
import torch  # imported explicitly instead of relying on it leaking in through `from model import *`

# NOTE: '|' is not a legal filename character on Windows, so these checkpoints must be
# renamed there (and `name` updated to match).
name = 'customGPT_2024-04-25|10-16-11'

# Deserialize the JSON file back to a dictionary
with open(f'models/{name}.json', 'r') as f:
    config_dict = json.load(f)

# Convert the dictionary back to a dataclass object. This replaces the default Config
# built in the previous cell, so the hyperparameters always match the saved model.
cfg = Config(**config_dict)
# The device stored in the JSON is the one used at training time; override it with
# whatever is available on this machine.
cfg.device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Initialize a blank model
model = customGPT(cfg).to(cfg.device)

# path to the saved weights of the model selected above
path = f'models/{name}.pth'

# Load the saved state dictionary.
# map_location remaps the stored tensors onto cfg.device; without it a checkpoint that was
# saved on a GPU fails to load on a CPU-only machine.
# SECURITY: torch.load unpickles the file, so only load checkpoints you trust.
model.load_state_dict(torch.load(path, map_location=cfg.device))
# REMEMBER: the tokenizer chosen in the cell above must match the model you've loaded
# (this model's embedding table has 515 entries, i.e. the size-512 tokenizer).

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

# If you only plan to do inference, switch to evaluation mode
# (left as the last expression so the notebook also displays the module structure)
model.eval()

1463.936 K parameters


customGPT(
  (token_embedder): Embedding(515, 128)
  (layers): ModuleList(
    (0-9): 10 x ResidualLayer(
      (pre_attn_norm): Norm()
      (attn): MQSA(
        (Wq): Linear(in_features=128, out_features=128, bias=False)
        (Wk): Linear(in_features=128, out_features=32, bias=False)
        (Wv): Linear(in_features=128, out_features=32, bias=False)
        (Wo): Linear(in_features=128, out_features=128, bias=False)
      )
      (pre_mlp_norm): Norm()
      (mlp): MLP(
        (Wgate): Linear(in_features=128, out_features=256, bias=False)
        (Wup): Linear(in_features=128, out_features=256, bias=False)
        (Wdown): Linear(in_features=256, out_features=128, bias=False)
        (nonlinearity): GELU(approximate='none')
      )
    )
  )
  (final_norm): Norm()
  (criterion): CrossEntropyLoss()
)

# Inference

In [5]:
# Quick smoke test of the inference path with a very short continuation.
# NOTE(review): max_gen_len presumably caps the number of newly generated tokens, and a
# temperature of 20 presumably flattens the sampling distribution heavily — confirm both
# against inference.generate.
prompt = "Hello, world!"
output = generate(prompt, model, tokenizer, max_gen_len=5, temperature=20.)
print(output)

Hello, world!Tom2letde4


In [6]:
# Same prompt, this time exercising the memory_saver_div option.
# In the recorded run, memory_saver_div=16 made generate() report a 32x512 attention
# matrix instead of 512x512.
# NOTE(review): no max_gen_len is passed, so generation presumably runs up to the model's
# maximum sequence length; temperature=100 is extremely high — confirm in inference.generate.
prompt = "Hello, world!"
output = generate(prompt, model, tokenizer, memory_saver_div=16, temperature=100.)
print(output)

maximum attention matrix size in memory will be 32x512 rather than 512x512

Hello, world!ceallsMaxOnwantshldscaredwillgirXjtoysopget
HeinVenedbooclenamSarasaidcreishunmavenAbothgened ustjohadparthathelpbrhelpoutoundwereeveryck2daybestasuhascarver.
"fulasaysackanwouldpotWelookedlookedMiaileilelookMomsmnytAnnalotnamaroundnedaskediagr4keankscaredKSpotroingcatbedderThehere1hoeatmanso.
"0ucoodnolyzlyssjroomilaskeddidcrewater7enedwereeatatonesmidescevenustOndayOneorwasMia.didtoankiaput.
Cfore4ulpown$seusscaredagainfterumpOncewhatSheLucygetab!"
ppltThenq9somedogkneweq," ongimeoutderlaughedrain
tooumppotqomedadnamsaclomadbugTheyBengogetstmakerefidewnamedth2n!"
itedOneustpeheretoysoutlongusegrscaredenednowbutmewawithddwebir. irmanamellvenAseeBhomeile,ugRittiesurtedgrefefastyour
