In [None]:
# Jupyter kernels are often not wired to the project's virtual environment, so this cell
# adds the venv's site-packages directory to sys.path by hand.
import sys
import os
current_dir = os.getcwd()  # assumes the notebook is started from the directory that contains ./venv
venv_dir = os.path.join(current_dir, 'venv')
python_version = str(sys.version_info.major) + '.' + str(sys.version_info.minor)
if os.name == 'nt':
    # Windows venvs keep packages in venv\Lib\site-packages (no per-version subdirectory)
    site_packages_path = os.path.join(venv_dir, 'Lib', 'site-packages')
else:
    site_packages_path = os.path.join(venv_dir, 'lib', 'python' + python_version, 'site-packages')
# Append only once (re-running the cell must not stack duplicates) and only if the directory exists
if os.path.isdir(site_packages_path) and site_packages_path not in sys.path:
    sys.path.append(site_packages_path)

In [8]:
# Importing pytorch
import torch
import torch.nn as nn
from torch.nn import functional as F

# imports for the debugging/demonstration setup
import functools
import inspect

# imports for the tokenizer
from tiny_stories_tokenizer import *

# Imports used for the config
import dataclasses 
from typing import Optional

# for the dataset
from torch.utils.data import Dataset, DataLoader
import pickle

# Imports used for the model
import re
from typing import Any, List, Sequence, Tuple, Union

# used in the training loop
import time

# used to save & load models
import json
from dataclasses import asdict

In [9]:
# Decorator used throughout for debugging/demonstration: it prints a method's inputs and outputs
# when logging is switched on, which keeps print statements out of the model code itself.
def log_io(func):
    """
    Wrap a LoggingModule method so that its inputs and outputs are printed when logging is on.

    The wrapper reads ``self.logging_enabled`` and ``self.disabled_logging_functions``. If logging
    is off globally, or this method's name is disabled, the method is simply called. Otherwise
    tensors are printed by shape, ints/floats by value, tuples recursively, and anything else by
    type and value.
    """
    # computed once at decoration time instead of on every call
    signature = inspect.signature(func)

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # Check if logging is enabled globally and for the specific function
        if not self.logging_enabled or func.__name__ in self.disabled_logging_functions:
            return func(self, *args, **kwargs)

        def log_item(item, name, level=0, is_root=False):
            indent = "    " * level
            if isinstance(item, torch.Tensor):
                print(f"{indent}Tensor '{name}' shape: {item.shape}")
            elif isinstance(item, tuple):
                if is_root and level == 0:
                    # a method returning several values: list them individually, not as one tuple
                    for idx, sub_item in enumerate(item):
                        log_item(sub_item, f"{name}[{idx}]", level)
                else:
                    print(f"{indent}Tuple '{name}':")
                    for idx, sub_item in enumerate(item):
                        log_item(sub_item, f"{name}[{idx}]", level + 1)
            elif isinstance(item, int):
                print(f"{indent}Integer '{name}': Value={item}")
            elif isinstance(item, float):
                print(f"{indent}Float '{name}': Value={item}")
            else:
                print(f"{indent}Other-type '{name}': Type={type(item).__name__}, Value={item}")

        print(f"\n{'='*10}Entering {self.__class__.__name__}.{func.__name__}{'='*10}")
        print("Inputs:")
        # Bind by parameter name. Zipping positional names with kwargs.values() mislabels keyword
        # arguments that are passed out of order or that skip earlier (defaulted) parameters.
        bound_arguments = signature.bind(self, *args, **kwargs).arguments
        for name, value in list(bound_arguments.items())[1:]:  # [1:] skips `self`
            log_item(value, name)

        result = func(self, *args, **kwargs)
        print("\nOutputs:")
        log_item(result, "output", is_root=isinstance(result, tuple))

        print(f"{'='*10}Exiting {self.__class__.__name__}.{func.__name__}{'='*10}")
        return result
    return wrapper

class LoggingModule(nn.Module):
    """
    nn.Module base class holding the switches that the ``log_io`` decorator reads.

    ``logging_enabled`` is the global on/off switch (off by default).
    ``disabled_logging_functions`` is a set of method names whose logging stays muted
    even while logging is on.
    """

    def __init__(self):
        super().__init__()
        self.disabled_logging_functions = set()
        self.logging_enabled = False

    def enable_logging(self):
        """Turn on input/output printing for this module's decorated methods."""
        self.logging_enabled = True

    def disable_logging(self):
        """Turn off input/output printing for this module's decorated methods."""
        self.logging_enabled = False

    def disable_function_logging(self, func_name):
        """Mute logging for the method called ``func_name``."""
        muted = self.disabled_logging_functions
        muted.add(func_name)

    def enable_function_logging(self, func_name):
        """Un-mute logging for ``func_name``; a name that was never muted is ignored."""
        muted = self.disabled_logging_functions
        muted.discard(func_name)

In [10]:
# build the tokenizer; available size options are 128, 256, 512 and 1024
tokenizer = get_tokenizer(size=128)

In [28]:
@dataclasses.dataclass # a class meant specifically to just hold data
class Config:
    """
    The default configuration & hyperparameters for the customGPT model.

    Every setting carries a type annotation so that it is a real dataclass field. Unannotated
    class attributes are ignored by dataclasses, so they would be missing from
    ``dataclasses.asdict`` (used when saving a config to JSON) and rejected by
    ``Config(**config_dict)`` when loading.
    """
    ### boring hyperparameters ###
    vocab_size: int = tokenizer.vocab_len
    max_seq_len: int = 128
    num_hidden_layers: int = 8
    num_q_heads: int = 4
    num_kv_heads: int = 1  # must evenly divide num_q_heads (checked in __post_init__)
    embed_dim: int = 128
    mlp_multiplier: int = 4
    head_dim: int = 32
    theta: float = 100.0  # base passed to RoPE
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    batch_size: int = 32
    dropout_rate: float = 0.1 # this has only been implemented into the MLP so far

    def __post_init__(self):
        # validate every instance; an assert in the class body only ever checked the defaults
        assert self.num_q_heads % self.num_kv_heads == 0, \
            "num_q_heads must be a multiple of num_kv_heads"

config = Config()

In [12]:
def RoPE(x: torch.Tensor, dim: int, theta: float = 10000.0) -> torch.Tensor:
    """
    Rotate a query or key tensor by position-dependent angles (rotary position embedding).

    Positions are taken along axis 1 of ``x``. Feature ``i`` is paired with feature
    ``i + d/2`` of the last axis, and each pair is rotated by ``position * theta**(-2i/dim)``.
    The result has the same shape and dtype as ``x``.
    Assumes x is (batch, seq_len, heads, head_dim) as produced in MQA.forward — TODO confirm for other callers.
    """
    dev = x.device
    positions_count = x.size(1)

    # one inverse frequency per feature pair, then one rotation angle per (position, pair)
    exponents = torch.arange(0, dim, 2, device=dev).float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(positions_count, device=dev)
    angles = torch.outer(positions, inv_freq).float()
    rotations = torch.polar(torch.ones_like(angles), angles)  # unit complex numbers, complex64

    # move heads ahead of positions and view each (first-half, second-half) pair as a complex number
    heads_first = x.transpose(1, 2).float()
    first_half, second_half = torch.chunk(heads_first, 2, dim=-1)
    paired = torch.stack((first_half, second_half), dim=-1)
    as_complex = torch.view_as_complex(paired)

    # rotate (the leading unsqueeze lets the angles broadcast over batch and heads), then unpack
    rotated = torch.view_as_real(as_complex * rotations.unsqueeze(0)).type_as(x)
    real_part, imag_part = torch.chunk(rotated, 2, dim=-1)
    merged = torch.cat((real_part, imag_part), dim=-2)
    b, h, s = merged.shape[:3]
    return merged.reshape(b, h, s, -1).transpose(1, 2)

In [13]:
class MQA(LoggingModule):
    """
    Causal self-attention with shared key/value heads (multi-query when num_kv_heads == 1).

    Queries use ``num_q_heads`` heads and keys/values use ``num_kv_heads`` heads. Each K/V head
    is repeated for ``num_q_heads // num_kv_heads`` query heads, and RoPE is applied to queries
    and keys. Input and output are (batch_size, input_len, embed_dim).
    """

    def __init__(self, config: Config):
        super().__init__()

        self.num_q_heads = config.num_q_heads
        self.num_kv_heads = config.num_kv_heads
        assert self.num_q_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_q_heads // self.num_kv_heads

        self.embed_dim = config.embed_dim
        self.head_dim = config.head_dim
        self.theta = config.theta
        self.max_seq_len = config.max_seq_len

        # Calculates the total size for all projections.
        self.q_size = self.num_q_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        # Defines the scaling factor for the attention scores.
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = nn.Linear(self.embed_dim, (self.num_q_heads + 2 * self.num_kv_heads) * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_q_heads * self.head_dim, self.embed_dim, bias=False)

        # Lower-triangular boolean mask: True where a position may attend.
        causal_mask = torch.tril(torch.ones((config.max_seq_len, config.max_seq_len), dtype=torch.uint8)
                                 ).view(1, 1, config.max_seq_len, config.max_seq_len).to(dtype=torch.bool)
        # Registered as a buffer so model.to(device) moves it together with the weights. A plain
        # attribute stays on the CPU and breaks torch.where on GPU. persistent=False keeps it out
        # of state_dict, so checkpoints saved without it still load.
        self.register_buffer('mask', causal_mask, persistent=False)

    @log_io
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Extracts batch size and input sequence length from the hidden states tensor.
        batch_size, input_len, _ = x.shape
        # the mask only covers max_seq_len positions; fail with a clear message instead of a shape error
        assert input_len <= self.max_seq_len, \
            f"input length {input_len} exceeds max_seq_len {self.max_seq_len}"

        # Applies the linear projection to the hidden state to retrieve our q, k & v projections
        qkv = self.qkv_proj(x)

        # Splits the combined QKV tensor into separate query, key and value tensors.
        xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # Separate the heads: [batch_size, input_len, n_heads, head_dim]
        xq = xq.view(batch_size, -1, self.num_q_heads, self.head_dim)
        xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
        xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)

        # Applies rotary positional embeddings to queries and keys to incorporate positional information.
        xq = RoPE(xq, self.head_dim, self.theta)
        xk = RoPE(xk, self.head_dim, self.theta)

        # If there are fewer KV heads than query heads, repeat each KV head to match the query head count.
        if self.num_kv_heads != self.num_q_heads:
            xk = torch.repeat_interleave(xk, self.num_queries_per_kv, dim=2) # [batch_size, input_len, n_local_heads, head_dim]
            xv = torch.repeat_interleave(xv, self.num_queries_per_kv, dim=2)

        # Put heads before positions for the batched matmuls.
        q = xq.transpose(1, 2) # [batch_size, n_local_heads, input_len, head_dim]
        k = xk.transpose(1, 2)
        v = xv.transpose(1, 2)

        # Scaled dot-product attention scores.
        logits = torch.matmul(q, k.transpose(2, 3)) * self.scaling # [batch_size, n_local_heads, input_len, input_len]

        # Applies the lower-triangular mask: disallowed positions get a large negative value
        logits = torch.where(self.mask[..., :input_len, :input_len].expand_as(logits),
                             logits,
                             torch.tensor(-1e30, device=logits.device, dtype=logits.dtype))

        # Applies softmax to the logits to obtain attention probabilities
        scores = F.softmax(logits, dim=-1)

        # Weighted sum of values.
        output = torch.matmul(scores, v) # [batch_size, n_local_heads, input_len, head_dim]

        # Merge the heads back into one feature dimension.
        output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1) # [batch_size, input_len, num_q_heads * head_dim]

        # Final projection back to embed_dim.
        output = self.o_proj(output)

        return output

In [14]:
class MLP(LoggingModule):
    """
    Gated feed-forward block: ``down_proj(gelu(gate_proj(x)) * up_proj(x))`` followed by dropout.

    Args:
        embed_dim: size of the input features.
        mlp_multiplier: the hidden width is ``embed_dim * mlp_multiplier``.
        output_dim: size of the output features.
        dropout_rate: dropout probability applied to the block's output.
    """
    def __init__(self,
                 embed_dim: int,
                 mlp_multiplier: int,
                 output_dim: int,
                 dropout_rate: float = 0.1):
        super().__init__()
        self.mlp_multiplier = mlp_multiplier
        self.hidden_size = embed_dim * mlp_multiplier
        self.dropout_rate = dropout_rate

        # the gate, up and down projections
        self.gate_proj = nn.Linear(embed_dim, self.hidden_size)
        self.up_proj = nn.Linear(embed_dim, self.hidden_size)
        self.down_proj = nn.Linear(self.hidden_size, output_dim)

    @log_io
    def forward(self, x: torch.Tensor, training: Optional[bool] = None) -> torch.Tensor:
        """
        Apply the gated MLP and then dropout.

        training: whether to apply dropout. ``None`` (the default) follows the module's own mode
        (``self.training``, toggled by ``model.train()`` / ``model.eval()``). Layer calls this
        without a ``training`` argument, so a fixed ``False`` default meant dropout never ran.
        """
        if training is None:
            training = self.training
        output = self.down_proj(F.gelu(self.gate_proj(x)) * self.up_proj(x))
        return F.dropout(output, p=self.dropout_rate, training=training)

In [15]:
class RMSNorm(LoggingModule):
    """
    Root-mean-square normalization over the last dimension.

    Each vector is divided by sqrt(mean(x**2) + eps). When ``use_scale`` is true, the result
    is multiplied by a learned per-feature scale initialised to ones.
    """
    def __init__(self, num_features, eps=1e-5, use_scale=True):
        super().__init__()
        self.eps = eps
        if use_scale:
            self.scale = nn.Parameter(torch.ones(num_features))
        else:
            self.scale = None

    @log_io
    def forward(self, inputs):
        inv_rms = torch.rsqrt(inputs.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        normalized = inputs * inv_rms
        if self.scale is None:
            return normalized
        return normalized * self.scale

In [16]:
class Layer(LoggingModule):
    """
    One transformer block: attention followed by a gated MLP.

    Each sub-layer is applied as a residual update ``x + post_norm(sublayer(pre_norm(x)))``.
    """
    def __init__(self, config: Config):
        super().__init__()

        self.mqa = MQA(config)
        self.mlp = MLP(config.embed_dim, config.mlp_multiplier, config.embed_dim, config.dropout_rate)

        # four independent RMSNorms, registered in this order
        for norm_name in ('pre_mqa_norm', 'post_mqa_norm', 'pre_mlp_norm', 'post_mlp_norm'):
            setattr(self, norm_name, RMSNorm(config.embed_dim, use_scale=True))

    @log_io
    def forward(self, x: torch.Tensor ) -> torch.Tensor:
        attention_update = self.post_mqa_norm(self.mqa(self.pre_mqa_norm(x)))
        x = x + attention_update
        mlp_update = self.post_mlp_norm(self.mlp(self.pre_mlp_norm(x)))
        return x + mlp_update

In [17]:
class customGPT(LoggingModule):
    """
    Decoder-only transformer language model built from ``Layer`` blocks.

    Token ids are embedded and scaled by sqrt(embed_dim), passed through ``num_hidden_layers``
    layers and a final RMSNorm, then projected back to vocabulary logits with the same embedding
    matrix (weight tying). The embedding has ``vocab_size + 1`` rows. The extra id
    ``vocab_size`` is the padding id written by ``process_batch``, and ``generate`` stops
    when it is sampled.
    """

    def __init__(self,
        config: Config, # the hyperparameters
        tokenizer: tokenizer, # the tokenizer. we don't always store the tokenizer inside of the model, but it doesn't matter here
    ):
        super().__init__()
        self.config = config

        # the attention heads need to cleanly divide up the embed_dim of the model so that we can split it all apart & combine back together
        assert config.embed_dim % config.num_q_heads == 0

        self.max_seq_len = config.max_seq_len
        self.head_dim = config.head_dim
        self.vocab_size = config.vocab_size
        self.tokenizer = tokenizer

        # the embedding matrix: converts tokens to the first residual state, and the last residual
        # state to logits. +1 row for the extra padding / end-of-sequence id (== vocab_size)
        self.embedder = nn.Embedding(self.vocab_size+1, config.embed_dim)
        self.scaling = config.embed_dim ** 0.5 # for normalizing the first embedding

        # the stack of transformer blocks
        self.layers = nn.ModuleList(Layer(config) for _ in range(config.num_hidden_layers))

        # normalization applied after the last layer
        self.final_norm = RMSNorm(config.embed_dim, use_scale=True)

        # the loss function
        # NOTE(review): padding targets (id vocab_size) are not ignored, so they contribute to the loss
        self.criterion = nn.CrossEntropyLoss()

    @log_io
    def forward(
        self,
        input_token_ids: torch.Tensor, # a shape (batch_size, input_len) tensor of integer token ids
        target_token_ids: torch.Tensor = None, # a shape (batch_size, input_len) tensor of token ids to train on
        ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Run the model.

        Returns:
            (logits, loss): logits of shape (batch_size, input_len, vocab_size + 1), and the mean
            cross-entropy against ``target_token_ids``, or ``None`` when no targets are given.
        """
        # turn the input tokens into the first residual state using the embedding matrix
        x = self.embedder(input_token_ids) * self.scaling # (batch_size, input_len) -> (batch_size, input_len, embed_dim)

        for layer in self.layers:
            x = layer(x)

        # Apply normalization to the output of the final layer
        x = self.final_norm(x)

        # the embedding matrix, shape (vocab_size + 1, embed_dim), doubles as the output layer
        logits = torch.matmul(x, self.embedder.weight.t()) # (batch_size, input_len, embed_dim) @ (embed_dim, vocab_size + 1)

        if target_token_ids is None: # if we're not training, then we don't need to calculate loss
            loss = None
        else:
            # flatten batch & positions before calculating cross-entropy loss
            batch_size, input_len, vocab_size = logits.shape
            loss = self.criterion(logits.reshape(batch_size*input_len, vocab_size),
                                  target_token_ids.reshape(batch_size*input_len))

        return logits, loss

    @torch.no_grad() # no need to keep track of gradients during inference
    @log_io
    def Sampler(
        self,
        logits: torch.Tensor, # shape (batch_size, input_len, vocab_size)
        temperature: float, # controls how boring vs random the outputs should be
        top_p: float, # the maximum cumulative probability of output options we're willing to consider
        top_k: int, # the maximum number of output options we're willing to consider
    ) -> torch.Tensor:
        """
        Sample the next token id for each sequence from the last position's logits.

        Applies temperature scaling, then keeps only tokens that pass both the top-p (nucleus)
        and top-k filters, renormalises, and draws one sample.

        Returns:
            Tensor of shape (batch_size, 1) with the sampled token ids.
        """
        # Select the last position and apply temperature. Done out of place so the caller's
        # logits tensor is not modified (the old in-place div_ wrote through the slice view).
        logits = logits[:, -1, :] / temperature # (batch_size, vocab_size)

        # Calculate probabilities with softmax along the vocab dimension.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)

        # sort the probabilities for use in top-p & top-k. both are (batch_size, vocab_size)
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)

        ### top-p
        # cumulative probabilities in sorted order
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        # True marks tokens excluded by top-p (the mass before them already exceeds top_p)
        top_ps_mask = (probs_sum - probs_sort) > top_p
        probs_sort = torch.where(top_ps_mask, 0, probs_sort)

        ### top-k
        # rank of each sorted entry: 0, 1, 2, ...
        top_ks_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device)
        top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
        # True marks tokens excluded by top-k (rank >= top_k), so only the top_k best are kept
        top_ks_mask = top_ks_mask >= top_k

        # keep only tokens allowed by both filters
        probs_sort = torch.where(top_ks_mask, 0, probs_sort)

        # Re-normalization so that total probabilities add up to 1
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))

        # undo the sort so probabilities line up with token ids again
        probs = torch.gather(probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))

        # samples from the distribution
        next_token_id = torch.multinomial(probs, num_samples=1)

        return next_token_id # returns the predicted token

    @log_io
    def generate(
        self,
        prompt: str,
        output_len: int = None,
        temperature: float = 1.0, # defaulting to 1.0 means we essentially don't use temperature
        top_p: float = 1.0, # defaulting to 1.0 means we essentially don't use top-p
        top_k: int = config.vocab_size, # setting top_k = vocab_size means we're effectively not using top_k at all
    ) -> str:
        """
        Extend ``prompt`` by up to ``output_len`` sampled tokens and return the decoded text.

        Args:
            prompt: text to continue; it must encode to at least one token.
            output_len: number of tokens to sample. Defaults to filling the context window
                (``max_seq_len`` minus the prompt's token count).
            temperature, top_p, top_k: sampling controls forwarded to ``Sampler``.

        Returns:
            The decoded prompt plus continuation. Generation stops early if the padding /
            end-of-sequence id (``vocab_size``) is sampled; that id is not appended.

        Raises:
            ValueError: if the prompt encodes to no tokens.
            AssertionError: if prompt tokens + ``output_len`` exceed ``max_seq_len``.

        No autograd graph is built. The model's train/eval mode is left as it is.
        """
        # encoding the prompt into token indices
        prompt_ids = self.tokenizer.encode(prompt)
        if len(prompt_ids) == 0:
            raise ValueError("prompt must encode to at least one token")

        if output_len is None:
            output_len = self.max_seq_len - len(prompt_ids)

        # Count the prompt's tokens from the list: len() of the (1, n) tensor built below is always 1.
        assert len(prompt_ids) + output_len <= self.max_seq_len, (
            f"prompt ({len(prompt_ids)} tokens) + output_len ({output_len}) "
            f"exceeds max_seq_len ({self.max_seq_len})"
        )

        # (1, prompt_len) batch of long ids on the same device as the model's weights
        tokens = torch.tensor(prompt_ids, dtype=torch.long, device=self.embedder.weight.device).unsqueeze(0)

        with torch.no_grad(): # inference only
            for _step in range(output_len):
                # feed at most the LAST max_seq_len tokens (the size of the attention mask)
                logits, _ = self(tokens[:, -self.max_seq_len:])

                next_token = self.Sampler(
                    logits = logits,
                    temperature = temperature,
                    top_p = top_p,
                    top_k = top_k
                )
                if next_token.item() == self.vocab_size: break

                # add our new token to the sequence
                tokens = torch.cat((tokens, next_token), dim=1)

        # decode our list of tokens to an actual string
        return self.tokenizer.decode(tokens.squeeze(0).tolist())

# Training-related Functions

In [18]:
class TinyStoriesDataset(Dataset):
    """
    Map-style dataset over a pickled, indexable collection of stories.

    Items are returned unchanged; tokenization happens later, in ``process_batch``.
    NOTE(review): pickle.load can execute arbitrary code — only load files you created or trust.
    """

    def __init__(self, file_path):
        self.file_path = file_path
        with open(file_path, 'rb') as handle:
            stories = pickle.load(handle)
        self.data = stories

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Batches come out as lists of raw story strings; tokenization and padding happen in process_batch.
# The batch size is read from config.batch_size, so change it there before running this cell.

train_dataset = TinyStoriesDataset('tiny_stories_train_data.pkl')
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)

val_dataset = TinyStoriesDataset('tiny_stories_val_data.pkl')
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)

# Manual iterators: each next() pulls one batch, and they raise StopIteration once an epoch is used up
train_loader_iter = iter(train_loader)
val_loader_iter = iter(val_loader)
next_batch = next(train_loader_iter)  # Fetch the next batch
# Sanity check: batch length & type, the first story's character count & text, then its token count & ids
print(len(next_batch), type(next_batch), '\n\n', len(next_batch[0]), next_batch[0], '\n\n', len(tokenizer.encode(next_batch[0])), tokenizer.encode(next_batch[0]))

32 <class 'list'> 

 660 Once upon a time, there were two friends who were walking in the park. They came to a bakery, which smelled yummy. One of the friends asked the other, "what do you want to eat?".
The other friend replied, "I want an original pie!" The baker took out a pie and shook it. The friends could feel the pie shaking in the baker's hands. They looked so excited!
But when the baker gave the pie to the friends, it was not an original pie! It was a sour one! The friends were sad and grumpy. All their hopes of eating an original pie had been shaken away.
The friends sadly left the bakery, with no pie in their hands. They never came back to the bakery again. The end. 

 481 [36, 61, 50, 52, 1, 68, 63, 95, 1, 48, 1, 67, 96, 52, 79, 77, 83, 1, 70, 89, 52, 1, 67, 70, 62, 1, 53, 121, 92, 51, 66, 1, 70, 55, 62, 1, 70, 89, 52, 1, 85, 59, 58, 102, 1, 82, 1, 77, 1, 63, 97, 58, 75, 110, 1, 50, 108, 52, 1, 80, 1, 48, 1, 49, 48, 58, 89, 72, 79, 70, 55, 56, 50, 55, 1, 66, 60, 52, 101, 78

In [19]:
def process_batch(batch, tokenizer, config):
    """
    Turn a batch of raw strings into next-token-prediction input/target tensors.

    Each string is tokenized and cut to at most ``config.max_seq_len + 1`` tokens. Every row is
    then right-padded with the id ``config.vocab_size`` to ``min(longest, max_seq_len) + 1``
    tokens. Inputs drop the last column and targets drop the first, so targets are inputs
    shifted left by one.

    Args:
    batch (list of str): The batch of strings to process.
    tokenizer: object whose ``encode(str)`` returns a list of int token ids.
    config: object with ``max_seq_len`` and ``vocab_size`` attributes.

    Returns:
    torch.Tensor, torch.Tensor: long tensors of shape (len(batch), min(longest, max_seq_len)).
    """
    limit = config.max_seq_len
    pad_id = config.vocab_size

    encoded = []
    for text in batch:
        encoded.append(tokenizer.encode(text)[:limit + 1])

    # every row gets this many ids, which leaves room for the one-token shift
    width = min(max(map(len, encoded)), limit) + 1
    rows = [ids + [pad_id] * (width - len(ids)) for ids in encoded]

    full = torch.tensor(rows, dtype=torch.long)
    return full[:, :-1], full[:, 1:]

# Example usage: pull one training batch, convert it, and display the (input, target) tensors.
# Each target row is its input row shifted left by one token.
x,y = process_batch(next(train_loader_iter), tokenizer, config)
x.shape,x, y.shape,y
# you may notice the beginnings often look the same, but that's only because many stories begin with something like "Once upon a time,..."

(torch.Size([32, 128]),
 tensor([[ 36, 109,   1,  ...,  48,   1,  49],
         [ 36,  61,  50,  ...,  90,   1,  54],
         [ 23,  92,   1,  ...,  51,   1,  50],
         ...,
         [123,   1,  81,  ...,  48,  69,  52],
         [ 36,  61,  50,  ...,   1, 126,   1],
         [ 36,  61,  50,  ...,  63,  56,  73]]),
 torch.Size([32, 128]),
 tensor([[109,   1, 122,  ...,   1,  49, 115],
         [ 61,  50,  52,  ...,   1,  54,  95],
         [ 92,   1,  81,  ...,   1,  50,  48],
         ...,
         [  1,  81,   1,  ...,  69,  52,   1],
         [ 61,  50,  52,  ..., 126,   1,  97],
         [ 61,  50,  52,  ...,  56,  73,  73]]))

In [26]:
@torch.no_grad()
def estimate_loss(model, eval_iters = 5, loaders=None): # to estimate loss during the training loop
    """
    Average the model's loss over up to ``eval_iters`` batches from each data split.

    Each split draws from its own loader, so 'train' really measures training data. Each split
    gets a fresh iterator, so an exhausted shared iterator can never raise StopIteration here.

    Args:
        model: module whose ``forward(inputs, targets)`` returns ``(logits, loss)``.
        eval_iters: maximum number of batches averaged per split.
        loaders: optional mapping from split name to an iterable of raw-string batches;
            defaults to ``{'train': train_loader, 'val': val_loader}``.

    Returns:
        dict mapping each split name to a 0-dim tensor holding the mean loss
        (nan if that loader yields no batches).
    """
    if loaders is None:
        loaders = {'train': train_loader, 'val': val_loader}

    was_training = model.training
    model.eval() # disable training-only behaviour while measuring
    out = {}
    for split, loader in loaders.items():
        losses = []
        # zip stops early if the loader has fewer than eval_iters batches
        for _, batch in zip(range(eval_iters), loader):
            X, Y = process_batch(batch, tokenizer, config)
            _, loss = model(X.to(config.device), Y.to(config.device))
            losses.append(loss.item())
        out[split] = torch.tensor(losses).mean()
    model.train(was_training) # restore whichever mode the model was in
    return out

# Instantiate a brand new model

In [29]:
# build a freshly initialized model and move its parameters to the configured device
model = customGPT(config, tokenizer).to(config.device)

# print the number of parameters in the model (in thousands)
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

# show the module hierarchy
print(model)

1930.496 K parameters
customGPT(
  (embedder): Embedding(129, 128)
  (layers): ModuleList(
    (0-7): 8 x Layer(
      (mqa): MQA(
        (qkv_proj): Linear(in_features=128, out_features=192, bias=False)
        (o_proj): Linear(in_features=128, out_features=128, bias=False)
      )
      (mlp): MLP(
        (gate_proj): Linear(in_features=128, out_features=512, bias=True)
        (up_proj): Linear(in_features=128, out_features=512, bias=True)
        (down_proj): Linear(in_features=512, out_features=128, bias=True)
      )
      (pre_mqa_norm): RMSNorm()
      (post_mqa_norm): RMSNorm()
      (pre_mlp_norm): RMSNorm()
      (post_mlp_norm): RMSNorm()
    )
  )
  (final_norm): RMSNorm()
  (criterion): CrossEntropyLoss()
)


# Training

In [32]:
# create a PyTorch optimizer
# these AdamW settings worked for a tiny minGemma model; they were not tuned specifically for this model
learning_rate = 3e-4
weight_decay = 0.01
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# number of training iterations (optimizer steps); 2 is only enough for a smoke test
max_iters = 2

# evaluate & print the loss every `eval_interval` iterations (and always on the final iteration)
eval_interval = 1

In [33]:
start_time = time.time()

# Enable anomaly detection. uncomment these lines if you need to do extensive debugging
#torch.autograd.set_detect_anomaly(True)

# The loop variable is `step`, not `iter`: naming it `iter` shadows the builtin for the rest of
# the notebook and breaks any later `iter(train_loader)` call (including the one below).
for step in range(max_iters):

    # sample a batch of data, starting a new epoch when the training iterator runs out
    try:
        raw_batch = next(train_loader_iter)
    except StopIteration:
        train_loader_iter = iter(train_loader)
        raw_batch = next(train_loader_iter)
    xb, yb = process_batch(raw_batch, tokenizer, config)

    # train
    logits, loss = model(xb.to(config.device), yb.to(config.device))
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # every once in a while evaluate the loss on train and val sets
    if step % eval_interval == 0 or step == max_iters - 1:
        elapsed_time = time.time() - start_time
        losses = estimate_loss(model)
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, time elapsed: {elapsed_time:.2f} seconds")

# Disable anomaly detection after the training loop
#torch.autograd.set_detect_anomaly(False)

step 0: train loss 112.0365, val loss 111.9393, time elapsed: 1.18 seconds
step 1: train loss 106.3015, val loss 106.3094, time elapsed: 3.80 seconds


# Saving your model

In [20]:
# Save the weights (.pth) and the config (.json) under one shared, timestamped name in models/.
# To reload, set `name` in the load cell to the part after 'models/' (without extension).
os.makedirs('models', exist_ok=True)  # torch.save / open() fail if the directory is missing
# '_' separates date and time because '|' is not allowed in Windows file names
name = f'models/{model.__class__.__name__}_{time.strftime("%Y-%m-%d_%H-%M-%S")}'
torch.save(model.state_dict(), f'{name}.pth')

# Convert the dataclass object to a dictionary
config_dict = asdict(config)

# Serialize the dictionary to a JSON file
with open(f'{name}.json', 'w') as f:
    json.dump(config_dict, f)

# Load a Pretrained Model

In [14]:
# Set `name` to the file stem of a model saved by the save cell, without the 'models/' prefix
# or the extension (e.g. 'customGPT_<date>_<time>'). '?' is only a placeholder and will fail.
name = '?'

# Deserialize the JSON file back to a dictionary
with open(f'models/{name}.json', 'r') as f:
    config_dict = json.load(f)

# A saved device (e.g. 'cuda') may not exist on this machine, so use the one available now
if 'device' in config_dict:
    config_dict['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'

# Convert the dictionary back to a dataclass object
config = Config(**config_dict)

# Initialize a blank model with the loaded hyperparameters
# (the global `tokenizer` must be the same one the model was trained with)
model = customGPT(config, tokenizer).to(config.device)

path = f'models/{name}.pth'

# Load the saved state dictionary; map_location remaps tensors saved on a different device
model.load_state_dict(torch.load(path, map_location=config.device))

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

# If you only plan to do inference, switch to evaluation mode
model.eval()

# If you plan to continue training the model, switch to training mode
#model.train()

FileNotFoundError: [Errno 2] No such file or directory: 'models/?.pth'

# Inference

In [34]:
input_str = "JULIET:\nO Romeo, Romeo! wherefore art thou " # the classic line
# the context limit is measured in tokens, not characters
max_useable_output_len = config.max_seq_len - len(tokenizer.encode(input_str))
# temperature far above 1.0 divides the logits down, flattening the distribution (very random output)
output = model.generate(input_str, output_len = max_useable_output_len, temperature=50.0)
print(output)

JULIET:
O Romeo, Romeo! wherefore art thou ouplayitx-. ennIigSlemWQarppribebePam8ttinayGozouonOUhTheyutGlo G;KWomroufandhaiheTheywststayall:Hplay?ppooz;;beanKKTheyeramsaDdari
uO43


In [35]:
# character length of the generated text (the prompt is included in `output`)
len(output)

179

In [36]:
# token count of a sample generation (apparently pasted from an earlier run), for comparison with its character length
len(tokenizer.encode('JULIET:\nO Romeo, Romeo! wherefore art thou 9b\nerAk!ot11EEerr$L$$$toplayplayplay2dwSou1waswas. sdha9ll,wasulTheJheomJJ, zplayplayTheyplayver?zEer;NP5000baThey?playplayplaver, stonmm4allCiPPha2CenenhiswawasirXandiMor,A'))

136