In [1]:
# Jupyter kernels are often not launched from the project's virtual environment,
# so make the packages installed in ./venv importable by adding that venv's
# site-packages directory to sys.path.
import sys
import os

current_dir = os.getcwd()  # the venv is looked up relative to the current working directory
venv_dir = os.path.join(current_dir, 'venv')
python_version = f'{sys.version_info.major}.{sys.version_info.minor}'
if os.name == 'nt':
    # Windows venvs keep packages in venv\Lib\site-packages (no pythonX.Y level)
    site_packages_path = os.path.join(venv_dir, 'Lib', 'site-packages')
else:
    site_packages_path = os.path.join(venv_dir, 'lib', 'python' + python_version, 'site-packages')

# Re-running this cell must not stack duplicate entries on sys.path
if site_packages_path not in sys.path:
    sys.path.append(site_packages_path)

In [2]:
# importing the model config
from params import *

# importing minLlama3
from model import *

# used in the training loop
import time

# used to save & load models
import json
from dataclasses import asdict

# Load a Pretrained Model

In [31]:
# pretrained model options:
# 2m parameters, context length = 256, trained for 500 iterations w/ batch size of 32 and no dropout: 'Llama3_2024-04-19|04-00-15'
# 2m parameters, context length = 512, trained for 1000 iterations w/ batch size 32 and dropout 0.1: 'Llama3_2024-04-19|15-18-16'
# 3m parameters, context length = 512, trained for 1300 iterations w/ batch size of 24 and dropout 0.1: 'Llama3_2024-04-19|17-21-51'
name = 'Llama3_2024-04-19|17-21-51'

# Deserialize the saved hyperparameters (JSON) back to a dictionary
with open(f'models/{name}.json', 'r') as f:
    params_dict = json.load(f)

# Convert the dictionary back to a dataclass object, then override the device
# so the model runs on whatever hardware is available on this machine
params = ModelArgs(**params_dict)
params.device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Initialize a blank model
model = Llama3(params, tokenizer).to(params.device)

# path to the trained weights; it shares `name` with the config JSON loaded above,
# so the hyperparameters and the weights always belong to the same model
path = f'models/{name}.pth'

# Load the saved state dictionary. map_location remaps tensors that were saved
# from a GPU, so the checkpoint also loads on a CPU-only machine.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
model.load_state_dict(torch.load(path, map_location=params.device))

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

# If you only plan to do inference, switch to evaluation mode
model.eval()

2985.088 K parameters


Llama3(
  (tok_embeddings): Embedding(512, 128)
  (layers): ModuleList(
    (0-11): 12 x TransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=128, out_features=128, bias=False)
        (wk): Linear(in_features=128, out_features=32, bias=False)
        (wv): Linear(in_features=128, out_features=32, bias=False)
        (wo): Linear(in_features=128, out_features=128, bias=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=128, out_features=512, bias=False)
        (w2): Linear(in_features=512, out_features=128, bias=False)
        (w3): Linear(in_features=128, out_features=512, bias=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): Linear(in_features=128, out_features=512, bias=False)
  (criterion): CrossEntropyLoss()
)

# Inference

In [32]:
# the classic line, cut off mid-word so the model has something to continue
input_str = (
    "JULIET:\n"
    "O Romeo, Romeo! wherefore art thou R"
)

In [33]:
# continue the prompt, leaving every generate() argument at its default value
default_completion = model.generate(input_str)
print(default_completion)

JULIET:
O Romeo, Romeo! wherefore art thou Romeo?

JULIET:
The truth, my lord, for I have done for thee,
And I would not have it speak again.

ROMEO:
I would not speak for thee, for I have done.

ROMEO:
I would not speak as she as I love.

ROMEO:
Alack, that is thy soul should be thy son.

JULIET:
I would not be a man that loves it not.

JULIET:
O she did light! O woful day!

JULIET:
O heavens! O wife, what time thou hast believed
me to my right and haste to thee more.

JULIET:
O she is dead!

ROMEO:
I cannot tell it as I saw thee better than thou
Hath yet to thee and thou hast slain thee here.

ROMEO:
That I had been a false that knight should have
The friar of Romeo was to be thus and sea tongue,
May mark a fresh and prove thy looks in easy.

JULIET:
All love, and all thy shoulders give me from thee,
I have but tender mine own beloved.

JULIET:



#### Now let's use `memory_saver_div` to take advantage of KV caching, so that memory usage scales linearly with sequence length, in exchange for *potential* quality degradation. `memory_saver_div` must be a power of 2; it is used to calculate the maximum length of the query's sequence-length dimension in the attention matrix.

In [34]:
# Arguments for generate(), gathered in one place so they are easy to tweak
sampling_options = dict(
    # our model doesn't have a built-in <endoftext> token so we have to specify when to stop generating.
    # NOTE(review): len(input_str) counts characters, not tokens - confirm which unit generate() expects
    max_gen_len=params.max_seq_len - len(input_str),
    # divisor used to cap the query sequence length in the attention matrix; with a value of 8
    # the run below reports a 64x512 attention matrix rather than 512x512
    memory_saver_div=8,
    temperature=0.6,  # this is the default value that Llama3's official code has set
    top_p=0.9,  # this is the default value that Llama3's official code has set
    top_k=32,  # meta's code doesn't actually implement top_k selection but i've added it anyways as an alternative
)
output = model.generate(input_str, **sampling_options)
print(output)

maximum attention matrix size will be 64x512 rather than 512x512

JULIET:
O Romeo, Romeo! wherefore art thou Romeo?

JULIET:
That I may make thee pause; for I am in
a murderer to the fish of death's deliver'd,
I have not death to live thee on the time;
But on thy shame in the battle word,
By me a servant beauty may be spent.

ROMEO:
O misery name, God shall be thus for thy life.

JULIET:
O she did kill me for thy knee be made.

JULIET:
Thou wilt not speak? O true, let me see thee,
The words of Juliet, Juliet, and dead thy life;
For I have not heard thy love with hell!

JULIET:
Hold, Juliet, to thy guests, and makes me dead.

JULIET:
Hie thee to me: O pretty is night!

JULIET:
I am in all of wit and fast by thy life.

JULIET:
O she's dead friend, for what sad news?

JULIET:
Ay me! the duke and head the princes of joy!

JULIET:
O God, I said, for all for thee I have.


