In [1]:
# importing the model config
from params import *

# importing minLlama3
from model import *

# used in the training loop
import time

# used to save & load models
import json
from dataclasses import asdict

# Load a Pretrained Model

In [4]:
# Pretrained checkpoints available under models/ (each is a <name>.json config plus a <name>.pth weights file):
#    2m parameters, context length = 256, trained for 500 iterations w/ batch size of 32 and no dropout: 'Llama3_2024-04-19|04-00-15'
name = 'Llama3_2024-04-19|04-00-15'

# Deserialize the JSON file back to a dictionary
with open(f'models/{name}.json', 'r') as f:
    params_dict = json.load(f)

# Convert the dictionary back to a dataclass object, then override the device
# so the checkpoint runs on whatever hardware this machine actually has
params = ModelArgs(**params_dict)
params.device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Initialize a blank model
model = Llama3(params, tokenizer).to(params.device)

# path to the saved weights that belong to the config loaded above
path = f'models/{name}.pth'

# Load the saved state dictionary.
# map_location remaps the stored tensors onto params.device; without it, a checkpoint
# whose tensors were saved on a GPU cannot be deserialized on a CPU-only machine.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
model.load_state_dict(torch.load(path, map_location=params.device))
# The values in params have to match the checkpoint being loaded; here they do
# because they were read from the .json saved under the same name.

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

# If you only plan to do inference, switch to evaluation mode
model.eval()

1968.256 K parameters


Llama3(
  (tok_embeddings): Embedding(256, 128)
  (layers): ModuleList(
    (0-7): 8 x TransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=128, out_features=128, bias=False)
        (wk): Linear(in_features=128, out_features=32, bias=False)
        (wv): Linear(in_features=128, out_features=32, bias=False)
        (wo): Linear(in_features=128, out_features=128, bias=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=128, out_features=512, bias=False)
        (w2): Linear(in_features=512, out_features=128, bias=False)
        (w3): Linear(in_features=128, out_features=512, bias=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): Linear(in_features=128, out_features=256, bias=False)
  (criterion): CrossEntropyLoss()
)

# Inference

In [5]:
# Prompt: the classic line, stopped right after the "R" so the model has to carry on from there
input_str = (
    "JULIET:\n"
    "O Romeo, Romeo! wherefore art thou R"
)

In [9]:
# Baseline generation: only the prompt is passed, so every other generation
# setting keeps its default value.
# print() is used so the newlines in the generated text show up as line breaks.
print(model.generate(input_str))

JULIET:
O Romeo, Romeo! wherefore art thou Romeo sleep
the first be the dish of the sea, where you are too,
The king of my friends, and the city of the world.

Second Citizen:
He is the ground to signifier who to give
the princes of the pitchance of the victory.

MERCUTIO:
You was the man the world of William Sir William Saint Plantagenet?

LUCIO:
He comes your way will prove you what he weep


#### Now let's use `memory_saver_div` to take advantage of KV caching: memory usage then scales linearly with sequence length, in exchange for *potential* quality degradation. `memory_saver_div` must be a power of 2; it is used to calculate the maximum length of the query's sequence-length dimension in the attention matrix (with `memory_saver_div = 4` and a context length of 256, the attention matrix is at most 64x256 rather than 256x256).

In [13]:
# Generate again, this time spelling out the sampling settings and turning on
# the memory saver.
#
# - temperature / top_p: the default values that Llama3's official code has set.
# - top_k: meta's code doesn't actually implement top_k selection; it is added
#   here as an alternative.
# - memory_saver_div: the divisor used to cap the query's sequence length in the
#   attention matrix (the run below reports 64x256 rather than 256x256 for 4).
# - max_gen_len: our model doesn't have a built-in <endoftext> token, so we have
#   to say when to stop generating. The budget is the context length minus
#   len(input_str), which counts characters.
#   NOTE(review): presumably meant to approximate the prompt's token count -
#   confirm against the tokenizer.
output = model.generate(
    input_str,
    temperature=0.6,
    top_p=0.9,
    top_k=16,
    memory_saver_div=4,
    max_gen_len=params.max_seq_len - len(input_str),
)
print(output)

maximum attention matrix size will be 64x256 rather than 256x256

JULIET:
O Romeo, Romeo! wherefore art thou Rome,
When I am a look that shall we were a gods,
I cannot recompany to the king to please the blood.

First Servingman:
O, good some little blood with her brother and plain as whose founds,
Which is this day of his extremity:
I have a princely spirit of the field,
And when a man of his common grow.

POMPEY:
Well, be 
