# Train and test your own minGrok (or load mine)

### First, the setup

In [1]:
import sys
# add the local virtual environment's site-packages directory to the import search path
# NOTE(review): the path hard-codes python3.10 — confirm it matches the venv actually in use
sys.path.append('./venv/lib/python3.10/site-packages')
import dataclasses
# NOTE(review): `torch` and `minGrok` are used later in this notebook without an explicit
# import, so they presumably come in through this star import — confirm against model.py
from model import *
from tokenizer import SimpleTokenizer, loaded_stoi, loaded_merges

In [2]:
# read the entire dataset from disk into one string
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# preview the first 200 characters. It's just one continuous text document
# with all of the works of shakespeare back-to-back
print(text[:200])

# every distinct character that occurs in the text (sorted), and how many there are
chars = sorted(set(text))
v = len(chars)
print(chars, v, sep='\n')

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you
['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
65


In [3]:
# build the tokenizer from the stoi mapping and merges imported from the tokenizer module
tokenizer = SimpleTokenizer(loaded_stoi, loaded_merges)
print("vocab length: ", tokenizer.vocab_len)

# round-trip a sample line: text -> token ids -> text
encoded_text = tokenizer.encode("JULIET:\nO Romeo, Romeo! wherefore art thou Romeo?")
decoded_text = tokenizer.decode(encoded_text)

# show both directions of the round trip
print("Encoded:", encoded_text)
print("Decoded:", decoded_text)

vocab length:  128
Encoded: [22, 33, 24, 21, 17, 32, 71, 27, 1, 30, 53, 83, 53, 66, 30, 53, 83, 53, 2, 1, 61, 87, 93, 105, 43, 1, 77, 58, 1, 65, 67, 1, 30, 53, 83, 53, 12]
Decoded: JULIET:
O Romeo, Romeo! wherefore art thou Romeo?


In [4]:
@dataclasses.dataclass
class Config:
    """Hyperparameters for the tiny minGrok model built in this notebook.

    Every setting is an annotated dataclass field, so each one can be overridden
    per instance (e.g. ``Config(dropout=0.0)``) and appears in ``repr`` and
    ``dataclasses.asdict``. Trailing comments give Grok's value where it is known.
    """
    # Size of the vocabulary, taken from the tokenizer (not the raw character count `v`).
    # In Grok it's 131,072
    vocab_size: int = tokenizer.vocab_len

    # The maximum sequence length that this model might ever be used with.
    max_position_embeddings: int = 256 # in Grok it's 8,192

    # The number of layers in the model.
    num_layers: int = 4 # In Grok it's 64

    # The number of attention heads used in the attention layers of the model.
    num_attention_heads: int = 4 # In Grok it's 48

    # The number of key-value heads for implementing attention.
    num_key_value_heads: int = 1 # In Grok it's 8

    # The hidden size of the model, AKA the embedding dimension. Each token embedding vector will be this long
    hidden_size: int = 96 # In Grok it's 6,144

    # How much wider should the inner dimension of the experts be than the model's embedding dimension?
    embedding_multiplier_scale: int = 2 # In Grok it's roughly 5.33

    # how many experts?
    tot_num_experts: int = 4 # in Grok it's 8

    # how many active experts per token?
    chosen_num_experts: int = 2 # in Grok it's also 2

    # what amount of noise should be injected into the router during training?
    noise_std: float = 0.1 # the value for Grok has not been shared

    # When we create a loss to encourage all experts to be used, how should that loss be weighted?
    lambadada: float = 10 # Grok's value has not been shared
    # excuse my silly naming

    # The dimension of each attention head
    head_dim: int = 24 # In Grok it's 128

    # The epsilon used by the rms normalization layers.
    rms_norm_eps: float = 1e-5 # this is to promote numerical stability & prevent dividing by 0

    # the scaling factor that determines the frequencies for the rotary positional encodings
    rope_theta: float = 100.0 # Grok and most models use 10,000
    # smaller models should use a smaller theta, but I'm just guessing here. 1000 might work too

    # whether to use a linear layer after normalization
    use_scale: bool = True # same in Grok

    # run on the GPU when one is available, otherwise fall back to the CPU
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'

    # the dropout rate to use during training
    dropout: float = 0.05

config = Config()

### Training a model

In [5]:
# Tokenize the whole text once, then split it by position:
# the first 90% becomes the training set and the remainder is held out for validation
data = torch.tensor(tokenizer.encode(text), dtype=torch.long)
n = int(0.9*len(data))
train_data, val_data = data[:n], data[n:]

In [6]:
# batch sampler used for both training and evaluation
def get_batch(split, batch_size):
    """Sample a random batch of input windows and their next-token targets.

    Args:
        split: 'train' selects the training data; any other value selects the validation data.
        batch_size: number of windows to sample.

    Returns:
        (x, y), both of shape (batch_size, config.max_position_embeddings) and moved to
        config.device; y is x shifted one position to the right in the source data.
    """
    source = train_data if split == 'train' else val_data
    window = config.max_position_embeddings

    # random start offsets; the upper bound leaves room for the one-step-shifted target window
    starts = torch.randint(len(source) - window, (batch_size,))

    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start:start + window])
        targets.append(source[start + 1:start + window + 1])

    x = torch.stack(inputs).to(config.device)
    y = torch.stack(targets).to(config.device)
    return x, y

In [7]:
@torch.no_grad()
def estimate_loss(model, batch_size, eval_iters = 10):
    """Estimate the mean loss on the train and validation splits.

    Used periodically during the training loop. The model is put in eval mode while
    the batches are scored and is then returned to whichever mode it was in when this
    function was called, so evaluating a model that is in eval mode does not silently
    switch it back to training mode.

    Args:
        model: the model to score; called as ``model(X, Y)`` and expected to return
            ``(logits, loss)``.
        batch_size: number of sequences per evaluation batch.
        eval_iters: number of random batches averaged per split.

    Returns:
        dict with keys 'train' and 'val', each a scalar tensor holding the mean loss.
    """
    was_training = model.training # remember the caller's mode so it can be restored
    out = {}
    model.eval() # sets model to eval mode
    try:
        for split in ['train', 'val']:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split, batch_size)
                logits, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # restore the previous mode even if scoring a batch raised
        model.train(was_training)
    return out

In [8]:
# build a newly initialised minGrok and move it onto the configured device
model = minGrok(config, tokenizer).to(config.device)

# report the model size, in thousands of parameters
print(sum(param.numel() for param in model.parameters())/1e3, 'K parameters')

# show the module tree
print(model)

992.352 K parameters
minGrok(
  (embedder): Embedding(128, 96)
  (layers): ModuleList(
    (0-3): 4 x DecoderLayer(
      (mqa): MQA(
        (qkv_proj): Linear(in_features=96, out_features=144, bias=False)
        (o_proj): Linear(in_features=96, out_features=96, bias=False)
      )
      (moe): MoELayer(
        (experts): ModuleList(
          (0-3): 4 x Expert(
            (layer1): Linear(in_features=96, out_features=384, bias=False)
            (layer2): Linear(in_features=192, out_features=96, bias=False)
          )
        )
        (router): Router(
          (router_weights): Linear(in_features=96, out_features=4, bias=False)
        )
      )
      (pre_mqa_norm): RMSNorm()
      (post_mqa_norm): RMSNorm()
      (pre_moe_norm): RMSNorm()
      (post_moe_norm): RMSNorm()
      (drop): Dropout(p=0.05, inplace=False)
    )
  )
  (final_norm): RMSNorm()
  (criterion): CrossEntropyLoss()
)


In [10]:
import time

# AdamW optimizer settings
# this is not what they used, but this learning rate & weight decay work for our tiny minGrok
learning_rate = 1e-5 # used 3e-4 for the first 5000 iters
weight_decay = 0.01
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)

# total number of training iterations
max_iters = 2000

# evaluate the train/val loss every this many iterations
eval_interval = 100

# number of sequences per batch
batch_size = 32

In [11]:
start_time = time.time()

# Enable anomaly detection. uncomment these lines if you need to do extensive debugging
#torch.autograd.set_detect_anomaly(True)

# make sure the model is in training mode (e.g. dropout active) before the first step,
# even if it was previously loaded and left in eval mode
model.train()

# the loop variable is `step` rather than `iter` so the builtin iter() is not shadowed
# in the notebook's global namespace
for step in range(max_iters):

    # sample a batch of data
    xb, yb = get_batch('train', batch_size)

    # train: forward pass, clear old gradients, backpropagate, update the weights
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # every once in a while evaluate the loss on train and val sets
    if step % eval_interval == 0 or step == max_iters - 1:
        # elapsed time is taken before the evaluation batches are run
        current_time = time.time()
        elapsed_time = current_time - start_time
        losses = estimate_loss(model, batch_size)
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, time elapsed: {elapsed_time:.2f} seconds")

# Disable anomaly detection after the training loop
#torch.autograd.set_detect_anomaly(False)

step 0: train loss 2.4603, val loss 2.6593, time elapsed: 1.15 seconds
step 100: train loss 2.4402, val loss 2.6460, time elapsed: 102.51 seconds
step 200: train loss 2.4155, val loss 2.6502, time elapsed: 201.02 seconds
step 300: train loss 2.4238, val loss 2.6558, time elapsed: 302.31 seconds
step 400: train loss 2.4197, val loss 2.6455, time elapsed: 399.47 seconds
step 500: train loss 2.4427, val loss 2.6603, time elapsed: 496.25 seconds
step 600: train loss 2.4317, val loss 2.6517, time elapsed: 594.00 seconds
step 700: train loss 2.4523, val loss 2.6506, time elapsed: 695.84 seconds
step 800: train loss 2.4124, val loss 2.6615, time elapsed: 799.43 seconds
step 900: train loss 2.4369, val loss 2.6301, time elapsed: 900.91 seconds
step 1000: train loss 2.4354, val loss 2.6568, time elapsed: 1002.66 seconds
step 1100: train loss 2.4033, val loss 2.6486, time elapsed: 1098.36 seconds
step 1200: train loss 2.4403, val loss 2.6508, time elapsed: 1194.98 seconds
step 1300: train loss 2

In [12]:
# save the model currently held in memory
# the filename specifies the model's class, hyperparameters, and date/time it was saved
import os

# Ensure the directory exists; exist_ok avoids the check-then-create race and is a no-op
# when the directory is already there
model_dir = 'models'
os.makedirs(model_dir, exist_ok=True)

# Build the filename from the model class, the config values, the optimizer/training
# settings and a timestamp
filename = (f'{model.__class__.__name__}'
            f'-v{config.vocab_size}'
            f'-max_t{config.max_position_embeddings}'
            f'-layers{config.num_layers}'
            f'-heads{config.num_attention_heads}'
            f'-kv_heads{config.num_key_value_heads}'
            f'-hidden{config.hidden_size}'
            f'-embedding_multiplier_scale{config.embedding_multiplier_scale}'
            f'-head_dim{config.head_dim}'
            f'-theta{config.rope_theta}'
            f'-lr{learning_rate}'
            f'-decay{weight_decay}'
            f'-tot_num_experts{config.tot_num_experts}'
            f'-chosen_num_experts{config.chosen_num_experts}'
            f'-use_scale{config.use_scale}'
            f'-batch{batch_size}'
            f'-train_iter{max_iters}'
            f'--{time.strftime("%Y-%m-%d_%H-%M-%S")}.pth')

# Save the model's weights (state dict only, not the whole module object)
model_path = os.path.join(model_dir, filename)
torch.save(model.state_dict(), model_path)

### Alternatively, you can load the 1M-parameter model I already trained

In [8]:
# Initialize a blank model
model = minGrok(config, tokenizer).to(config.device)

# here's the path to a minGrok model that i've trained with roughly 1m parameters
path = 'models/minGrok-v128-max_t256-layers4-heads4-kv_heads1-hidden96-embedding_multiplier_scale2-head_dim24-theta100.0-lr0.0003-decay0.01-tot_num_experts4-chosen_num_experts2-use_scaleTrue-batch32-train_iter5000--2024-03-21_18-20-32.pth'

# Load the saved state dictionary
# map_location remaps the stored tensors onto config.device, so a checkpoint that was
# saved on a GPU can also be loaded on a CPU-only machine
# NOTE(review): torch.load unpickles the file — only load checkpoints from a source you trust
model.load_state_dict(torch.load(path, map_location=config.device))
# REMEMBER TO CHANGE VALUES IN CONFIG TO MATCH THE MODEL YOU'VE LOADED

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

# If you only plan to do inference, switch to evaluation mode
model.eval()

# If you plan to continue training the model, switch to training mode
#model.train()

992.352 K parameters


minGrok(
  (embedder): Embedding(128, 96)
  (layers): ModuleList(
    (0-3): 4 x DecoderLayer(
      (mqa): MQA(
        (qkv_proj): Linear(in_features=96, out_features=144, bias=False)
        (o_proj): Linear(in_features=96, out_features=96, bias=False)
      )
      (moe): MoELayer(
        (experts): ModuleList(
          (0-3): 4 x Expert(
            (layer1): Linear(in_features=96, out_features=384, bias=False)
            (layer2): Linear(in_features=192, out_features=96, bias=False)
          )
        )
        (router): Router(
          (router_weights): Linear(in_features=96, out_features=4, bias=False)
        )
      )
      (pre_mqa_norm): RMSNorm()
      (post_mqa_norm): RMSNorm()
      (pre_moe_norm): RMSNorm()
      (post_moe_norm): RMSNorm()
      (drop): Dropout(p=0.05, inplace=False)
    )
  )
  (final_norm): RMSNorm()
  (criterion): CrossEntropyLoss()
)

### Testing (performing inference)

In [13]:
# the classic line, used as the prompt
input_str = "JULIET:\nO Romeo, Romeo! wherefore art thou"

# leave room for the prompt inside the model's maximum sequence length
# NOTE(review): len(input_str) counts characters, while the context limit is presumably in
# tokens; the tokenizer merges characters, so this looks like a conservative bound — confirm
max_useable_output_len = config.max_position_embeddings - len(input_str)

output = model.generate(input_str, output_len=max_useable_output_len)
print(output)

JULIET:
O Romeo, Romeo! wherefore art thou in.

Tome?

Nurse:
Third peaguisrener:
Lo, show and go yours, here mace meraticome
For a thee be oneeget and the lambron a it-ntard; whileTHerle you fair murfeen a 'tis to like.

MENEnguill Yort death their honour mind,
If such therese the curry woront, I that mine,
Why the stays is of him in still.

F
