# minGemma

our goal here is to teach all of the architectural concepts used in [Google's Gemma](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf). The code here is a frankensteinian mixture between a little bit of [Andrej Karpathy's minGPT](https://www.youtube.com/watch?v=kCc8FmEb1nY&t=5012s) and mostly [Google's pytorch implementation of Gemma](https://github.com/google/gemma_pytorch). As such, certain aspects of the Gemma model will or won't be included based on our learning goals, which are to understand the architecture rather than the tokenizer, data cleaning, and distributed training techniques.

#### Included:
- RoPE
- RMSnorm
- GeGLU activation
- Multi-Query Attention (only found in the smaller model)
- The extra normalization in the middle of each layer, between the attention mechanism and the MLP

#### Not Included: 
- **Gemma's training dataset:** The dataset & corresponding RLHF framework used to train Gemmma is out-of-scope for this lesson. This is unfortunate since a big part of Gemma's impressive performance is attributed to their high quantity and quality of data. I recommend reading the technical report for more information
- **Tokenization training:** The full 256,128 token vocabulary of Gemma is unusually large even for models orders of magnitude larger than Gemma (if I remember correctly, Llama's is around 32k). Their tokenizer also has interesting qualities such as beginning-of-sentence & end-of-sentence tokens & other tokens related to instruct tuning. The issue with training such a large token vocabulary is that many tokens are likely to rarely/never be trained, thus meaning you've wasted compute and leaving yourself vulnerable to the [SolidGoldMagikarp problem](https://www.lesswrong.com/posts/aPeJE8bSo6rAFoLqg/solidgoldmagikarp-plus-prompt-generation) unless you use an absurdly large dataset size like they did (2T and 6T tokens for the 2b and 7b models respectively). That being said, a better lesson than mine would've taught you how to train a more reasonably sized vocabulary. Instead of doing that, I recommend for homework you check out [Andrej Karpathy's recent guide on building tokenizers](https://www.youtube.com/watch?v=zduSFxRajkE) and implement a reasonably sized one yourself. Instead of that, we'll be using simple character-wise tokenization of TinyShakespeare by default, but the code will include the option to use Gemma's original gigantic tokenizer just so you can see what it looks like.
- **Highly parallelized training:** This is something I don't believe google included in their open-sourcing of these models which is unfortunate, but reasonable given that they trained on TPUs and no one owns TPUs except google. Rather, their open-sourced code (at least the pytorch version) is designed strictly for inference with their open-sourced weights. I recommend looking elsewhere if you want to learn about parallelized training techniques like sharding on GPUs
- **KV caching:** A standard method for saving on compute in transformers, KV caching involves saving individual key and value vectors calculated at each step of inference for use at future steps of inference. I've not included it here because it's an inference-specific mechanism whereas this code is meant to train, but if you develop your own inference version of this code then you 100% need to re-implment KV caching
- **Quantization:** This shouldn't be necessary for the absolutely tiny size of the default parameters I've set for minGemma
- **Batched Inference:** The open-sourced pytorch code for Gemma is setup to efficiently handle batched decoding of variable context lengths. The implementation here is only designed to perform inference on a single prompt at a time, which is all we really need for learning about the gemma architecture; i would hope no one is using this code for deployment.
- **Embedding Bias:** There was an odd reference to a bias term for use in the output layer after multiplying by the transpose of the embedding matrix. It was labeled as optional and didn't get used in google's own pytorch implementation so i assume it's inclusion was an accident, but it may have been utilized in different instantiations of the model. My interpretation of this idea is that some tokens are always more or less likely than others, which a bias term could capture. If anyone knows feel free to reach out or push an update to the repo

### ToDo
- see if i can remove typing from the attention softmax calculation
- clean everything up

In [1]:
# Importing pytorch
import torch
import torch.nn as nn
from torch.nn import functional as F

# Imports used for the config
import dataclasses 
from typing import Optional

# Imports used for the model
import re
from typing import Any, List, Sequence, Tuple, Union

# for optionally downloading the original Gemma tokenizer
import os
import requests
from sentencepiece import SentencePieceProcessor

# used in the training loop
import time

# The Dataset

the dataset we'll be using is just TinyShakespeare for sake of simplicity & ability to do run/train locally on any computer

In [2]:
# load the dataset
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# the first 200 characters. It's just one continuous text document with all of the works of shakespeare back-to-back
print(text[:200])

# here are all the unique characters that occur in this text and how many there are
chars = sorted(list(set(text)))
v = len(chars)
print('\n', chars, v)

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you
['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'] 65


# The Tokenizer(s)

We'll be tokenizing by individual characters rather than true tokenization, meaning every single letter, number, and symbol is its own token. Tokenization is an often underlooked part of the LLM design process, and I suggest you check out the [recent video by Andrej Karpathy](https://www.youtube.com/watch?v=zduSFxRajkE) that goes in-depth on how to make your own tokenizer. Think of it as a homework assignment to:
1. watch his video
2. build your own tokenizer with let's say 500 tokens in it
3. replace this character-wise tokenizer with your own

Right after defining `CharacterTokenizer` which is the one we'll actually be using, I've also defined `OriginalGemmaTokenizer` for demonstration purposes. The latter is just copy & pasted from [Google's own Gemma pytorch repo](https://github.com/google/gemma_pytorch). Again we will not be delving into the full details of their tokenizer; rather, just take a look after you watch the video by Andrej and see what you can learn. It seems google built Gemma's tokenizer on top of SentencePiece and included extra special characters for the beginning of a sentence, end of a sentence, and a padding token that's built to just fill in empty-space for sequences that aren't as long as the model's full context length. 

In [3]:
# the one we'll actually be using
class CharacterTokenizer:
    def __init__(self, chars: List[str]):
        """ Create mappings from characters to integers and vice versa """
        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for i, ch in enumerate(chars)}

    def encode(self, s: str) -> List[int]:
        """Converts a string into a list of character IDs"""
        return [self.stoi.get(c) for c in s] 

    def decode(self, t: List[int]) -> str:
        """Converts a list of character IDs back into a string."""
        return ''.join([self.itos.get(i) for i in t])

# The original
class OriginalGemmaTokenizer:
    def __init__(self, model_path: Optional[str]):
        # load tokenizer.
        assert os.path.isfile(model_path), model_path
        self.sp_model = SentencePieceProcessor(model_file=model_path)

        # BOS / EOS token IDs.
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        self.pad_id: int = self.sp_model.pad_id()
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

    def encode(self, s: str, bos: bool = True, eos: bool = False) -> List[int]:
        """Converts a string into a list of tokens."""
        assert isinstance(s, str)
        t = self.sp_model.encode(s)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int]) -> str:
        """Converts a list of tokens into a string."""
        return self.sp_model.decode(t)

# minGemma Config

this is the actual configuration with hyperparameters that we'll be using

In [None]:
@dataclasses.dataclass # a class meant specifically to just hold data
class GemmaConfig:
    """ 
    The default configuration & hyperparameters for minGemma. regular Gemma 2b & 7b are defined below 
    Explanations for many of these will be more clear later when they actually get used
    """
    # The number of tokens in the vocabulary.
    vocab_size: int = v # v was defined earlier when we loaded TinyShakespeare. In Gemma it's 256,128, which is HUGE
    
    # The maximum sequence length that this model might ever be used with.
    max_position_embeddings: int = 256
    # In Gemma it's 8192, but they also use (relatively large) tokens as opposed to characters so really theirs is much longer
    
    # The number of blocks in the model.
    num_hidden_layers: int = 4 # In Gemma 7b it's 28 and 2b it's 18
    
    # The number of attention heads used in the attention layers of the model.
    num_attention_heads: int = 4 # In Gemma 7b it's 16 and 2b it's 8
    
    # The number of key-value heads for implementing attention.
    num_key_value_heads: int = 1 # In Gemma 7b it's 16 and in 2b it's 1
    # Notice what with Gemma 7b num_attention_heads = num_key_value_heads whereas this is not true for 2b
    # this is because 7b uses regular multi-head attention (MHA) while 2b uses multi-query attention (MQA)
    # the difference is that MQA shares single key & value matrices across all heads whereas MHA, the more
    # traditional route, uses unique key & value matrices for each head. MQA is a popular way to reduce
    # total parameter count, and according to Google's internal testing supposedly MHA works better at the
    # 7b size while MQA works better at the 2b size
    
    # The hidden size of the model, AKA the embedding dimension. Each token embedding vector will be this long
    hidden_size: int = 128 # In Gemma 7b it's 3072 and in 2b it's 2048
    
    # The inner dimension of the MLP part of the decoder layer
    intermediate_size: int = 512 # In Gemma 7b it's 24576 and in 2b it's 16384
    
    # The number of head dimensions
    head_dim: int = 32 # In both Gemmas it's 256
    
    # The epsilon used by the rms normalization layers.
    rms_norm_eps: float = 1e-6 # this is to promote numerical stability & prevent dividing by 0
    
    # The path to the model tokenizer if you're not using our character-wise default
    tokenizer: Optional[str] = None 
    # setting to None because we'll be defining ours in this notebook rather than loading it from a file
    
    # the scaling factor that determines the frequencies for the rotary positional encodings
    rope_theta = 100.0 # Gemma and most models use 10,000
    # smaller models should use a smaller theta, but I'm just guessing here. 1000 might work too

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Original Gemma's Config

If you wanted to locally train either full-sized Gemma with its original tokenizer from scratch using only TinyShakespear, you could use these functions to do it. However this is not a good idea because
1. One of the major characteristics that makes the two Gemma models so good is the high quality and abnormally high quantity of data they are trained on, which TinyShakespeare definitely is not. If you wanted though it wouldn't be too difficult to substitute TinyShakespeare for a large open-source dataset currently up on huggingface.
2. Training a 7b or even 2b parameter model requires code optimized for decentralized training over multiple GPUs or TPUs, which this is certainly not.
3. The real Gemma's Tokenizer is too large (256k tokens) for even a regular chinchilla-optimal dataset size. When you have that many tokens, many (if not most) end up not being properly trained unless you use an absurdly large dataset size like Google did, and improperly trained tokens lead to the [SolidGoldMagikarp problem](https://www.lesswrong.com/posts/aPeJE8bSo6rAFoLqg/solidgoldmagikarp-plus-prompt-generation)

To be clear, 

**THE FOLLOWING CODE CELL IS ONLY PRESENT FOR DEMONSTRATION PURPOSES. IT DOES NOT ACTUALLY GET USED**

In [None]:
def get_config_for_7b() -> GemmaConfig:
    return GemmaConfig(
        vocab_size = 256128,
        max_position_embeddings = 8192,
        num_hidden_layers = 28,
        num_attention_heads = 16,
        num_key_value_heads = 16,
        hidden_size = 3072,
        intermediate_size = 24576,
        head_dim = 256,
        tokenizer = 'tokenizer/tokenizer.model', # to load the actual Gemma Tokenizer
        rope_theta = 10000.0
    )

def get_config_for_2b() -> GemmaConfig:
    return GemmaConfig(
        vocab_size = 256128,
        max_position_embeddings = 8192,
        num_hidden_layers = 18,
        num_attention_heads = 8,
        num_key_value_heads = 1,
        hidden_size = 2048,
        intermediate_size = 16384,
        head_dim = 256,
        tokenizer = 'tokenizer/tokenizer.model',
        rope_theta = 10000.0
    )

def download_original_tokenizer():
    """
    I'm too lazy to read through google's license and figure out whether it's ok for me to download & redistribute 
    their tokenizer, so instead i just had ChatGPT write a script that has you download it yourself
    """
    # Configuration
    DESTINATION_FOLDER_PATH = './tokenizer'  # Local path in your Jupyter environment
    FILE_NAME = 'tokenizer.model'  # Name of the file to download
    local_file_path = os.path.join(DESTINATION_FOLDER_PATH, FILE_NAME) # Define the local path to save the file

    # Check if the file already exists
    if os.path.exists(local_file_path):
        print(f'File already exists at {local_file_path}. No download needed.')
        return
    
    # Construct the URL to the raw file on GitHub
    url = f'https://raw.githubusercontent.com/google/gemma_pytorch/main/tokenizer/{FILE_NAME}'
    
    # Make the HTTP GET request
    response = requests.get(url)
    
    # Check if the request was successful
    if response.status_code == 200:
        # Ensure the destination folder exists
        os.makedirs(DESTINATION_FOLDER_PATH, exist_ok=True)
        
        # Write the content of the response to a local file
        with open(local_file_path, 'wb') as file:
            file.write(response.content)
        print(f'File successfully downloaded to {local_file_path}')
    else:
        print(f'Failed to download the file. HTTP status code: {response.status_code}')

ok enough messing around, let's actually instantiate our config and tokenizer

In [5]:
def get_model_config(variant: str = None) -> GemmaConfig:
    if variant == '7b':
        download_original_tokenizer()
        return get_config_for_7b(), OriginalGemmaTokenizer('tokenizer/tokenizer.model')
    elif variant == '2b':
        download_original_tokenizer()
        return get_config_for_2b(), OriginalGemmaTokenizer('tokenizer/tokenizer.model')
    else:
        return GemmaConfig(), CharacterTokenizer(chars) # the default config is minGemma

# you would add as input '7b' or '2b' to choose one of the original Gemma configs instead of minGemma
# but again, don't do that. they're only there for demonstration purposes
config, tokenizer = get_model_config()

# testing the tokenizer
print(tokenizer.encode('hello world!'))
print(tokenizer.decode(tokenizer.encode('hello world!')))

[46, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42, 2]
hello world!


# Rotary Positional Encoding (RoPE)

Gemma uses RoPE, which is a popular relative positional encoding scheme. Positional encodings are designed to help the model understand the order of tokens in the text since transformers don't have a built-in ordering to them when reading a sequence. Instead of telling Gemma the precise location of each token (which would be "absolute" positional encoding), relative positional encodings only reveal to Gemma the placement of different tokens relative to each other. RoPE is a relative positional encoding scheme that functions by multiplying against query & key matrices in the attention mechanism (rather than, for example, addition along the token embeddings) and uses trigonometric functions to effecitvely "rotate" the matrices used in self-attention. RoPE is the standard-to-beat nowadays as it's also find in other notable open-source models like Llama. 

The essential idea is that any given vector within the query & key matrices has a location $m$ and $n$ respectively. Remember a given vector within a query or key matrix corresponds to a specific token in the sequence. You multiply the rotary matrix corresponding to that position against each the query & key, and then when you further multiply those together you effectively have a comparison that takes into account how far away the two tokens are. 

If you like math, then you'll see the relation to a kernel. For some arbitrary $ 0 < \epsilon \leq \frac{\pi}{2N}$ is chosen, where $N$ is the maximum sequence length.

$$ \text{RoPE}(x, m) = x e^{mi\epsilon} $$
$$ \langle\text{RoPE}(q_j, m), \text{RoPE}(k_j, n)\rangle = \langle q_je^{mi\epsilon}, k_je^{ni\epsilon} \rangle $$
$$ = q_jk_je^{mi\epsilon}e^{ni\epsilon} $$
$$ = q_jk_je^{(m-n)i\epsilon} $$
$$ = \text{RoPE}(q_jk_j, m-n) $$

To get a more thorough understanding, check out [the original paper](https://arxiv.org/abs/2104.09864), [this slightly more approachable but very thorough guide](https://blog.eleuther.ai/rotary-embeddings/) or [this far more approachable youtube video with visuals of the rotation](https://www.youtube.com/watch?v=o29P0Kpobz0). I'm not going to go super in-depth in my commenting on the code since this topic has been around for like 3 years now and it's not one of my interests and i'm so tired

In [6]:
def apply_rotary_emb(x: torch.Tensor, dim: int, theta: float = 10000.0) -> torch.Tensor:
    """Applies the rotary embedding to the inputted query or key tensor"""
    # Get sequence length
    seq_len = x.size(1)
    device = x.device
    
    # Dynamically compute frequency cis based on the input sequence length
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))
    t = torch.arange(seq_len, device=device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64

    # Apply rotary embeddings to the input tensor
    x_ = torch.view_as_complex(torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1))
    x_out = torch.view_as_real(x_ * freqs_cis.unsqueeze(0)).type_as(x)  # Ensure batch dimension is handled
    x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
    x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2], -1).transpose(1, 2)

    return x_out

# Root Mean Square Layer Normalization

Gemma uses [Root Mean Square Layer Normalization](https://arxiv.org/pdf/1910.07467.pdf) (RMSnorm) which is also not particularly new or interesting. As with other forms of layer normalization, the goal is to keep each layer within the same distribution so as to ensure smooth optimization. 

The root mean square is a statistic defined by squaring every entry in a given vector, averaging them, and then square-rooting that, which gives us a single floating point value.
$$ \text{RMS}(a) = \sqrt{\frac{1}{n}\sum\limits_{i=1}^n a_i^2} $$
Then we just divide the vector by that value
$$ \bar{a}_i = \frac{a_i}{\text{RMS}(a_i)}g_i$$

When the mean of summed inputs happens to equal zero, then RMSnorm lines up exactly with the older method [LayerNorm](https://arxiv.org/abs/1607.06450v1). The thesis of RMSnorm is essentially that scaling the variance is the thing that's actually helpful, and centering the mean is not so useful, therefore it makes sense for efficiency's sake to just not set the mean to 0. I'd also like to point out that RMSnorm places the vector onto a hypersphere (a sphere but with many more than 3 dimensions) with radius $\sqrt{n}$. This doesn't really matter here, but it's important for other projects I plan to release in the coming months.

The weird thing with Gemma which we'll see later is that they normalize the residual not only either before (or after--before & after are effectively the same thing) the decoder block, but also in-between the attention and feedforward parts of the block.

In [7]:
class RMSNorm(torch.nn.Module):
    """
    Implements the RMS Normalization (Root Mean Square Normalization) layer.
    RMSNorm is a variant of layer normalization that normalizes the activations
    of the previous layer based on their root mean square value.

    Parameters:
    - dim (int): The dimension of the input features the normalization is applied to.
    - eps (float): A small value added to the denominator for numerical stability. Default is 1e-6.
    - add_unit_offset (bool): If True, adds a unit (1) to the learned scaling coefficient, effectively
      starting with no scaling. If False, the scaling coefficient starts from zero. Default is True.
    """

    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
        add_unit_offset: bool = True,
    ):
        super().__init__() 
        self.eps = eps  # Small epsilon value for numerical stability since you can't divide by 0
        self.add_unit_offset = add_unit_offset  # Flag to determine if a unit should be added to the weight
        
        # Initialize the weight parameter with zeros, which will be learned during training.
        # The shape of the weight is [dim], meaning one weight per feature dimension.
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        """
        Private helper function to normalize the input tensor.

        Parameters:
        - x (Tensor): The input tensor to normalize.

        Returns:
        - Tensor: The normalized tensor.
        """
        # Calculate the root mean square value for each feature (across the last dimension),
        # then use reciprocal square root (rsqrt) for normalization.
        # Add self.eps to the denominator for numerical stability.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass of the RMSNorm layer.

        Parameters:
        - x (Tensor): The input tensor to normalize.

        Returns:
        - output: The normalized and scaled tensor.
        """
        # Normalize the input tensor using the _norm function and ensure the data type matches the input.
        x = self._norm(x.float()).type_as(x)
        
        # If add_unit_offset is True, scale the normalized tensor by (1 + self.weight),
        # effectively starting with a scaling factor of 1 (no scaling).
        # Otherwise, scale by self.weight only.
        if self.add_unit_offset:
            output = x * (1 + self.weight)
        else:
            output = x * self.weight
            
        # Return the scaled output tensor.
        return output

# Multi-Layer Perceptron

The interesting thing here is the use of Gated GeLU (GeGLU) activation function as SwiGLU, which is what Llama uses. Check out this [recent survey of 400 activation functions](https://arxiv.org/pdf/2402.09092.pdf) or [google's exploration of gated activation functions](https://arxiv.org/pdf/2002.05202.pdf) which was cited in the Gemma paper for more information

In [8]:
class GemmaMLP(nn.Module):
    """
    This class implements a multi-layer perceptron with a GeGLU gating mechanism. The GeGLU
    activation combines a standard GeLU activation with a learned gating mechanism, enabling
    the network to control the flow of information more dynamically.

    Attributes:
        gate_proj (nn.Linear): A linear layer that transforms the input tensor to an intermediate
                               representation, which is then passed through a GeLU activation for
                               gating purposes.
        up_proj (nn.Linear): Another linear layer that transforms the input tensor to the same
                             intermediate representation but without gating. This representation
                             is element-wise multiplied by the gated tensor from `gate_proj`.
        down_proj (nn.Linear): A linear layer that transforms the gated and combined tensor back
                               to the original dimension of the hidden size, producing the final output.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
    ):
        """
        Initializes the GemmaMLP module.

        Parameters:
            hidden_size (int): The size of the input and output tensors.
            intermediate_size (int): The size of the tensor after the initial transformation
                                     and before the gating and final projection. This is typically
                                     larger than the hidden size to allow for a richer representation.
        """
        super().__init__()

        # Linear transformation for the gating mechanism, projecting input to an intermediate size.
        self.gate_proj = nn.Linear(hidden_size, intermediate_size)

        # Linear transformation for the input tensor, also projecting to the intermediate size but
        # intended for element-wise multiplication with the gated output.
        self.up_proj = nn.Linear(hidden_size, intermediate_size)

        # Linear transformation to project the gated and combined tensor back to the original
        # hidden size, completing the MLP structure.
        self.down_proj = nn.Linear(intermediate_size, hidden_size)

    def forward(self, x):
        """
        Defines the forward pass of the GemmaMLP module.

        Parameters:
            x (Tensor): The input tensor to the MLP.

        Returns:
            Tensor: The output tensor after applying the GeGLU gating mechanism and the MLP transformations.
        """
        # Applies linear transformation for gating.
        gate = self.gate_proj(x)

        # Applies GeLU activation to the gate, introducing non-linearity and enabling the gating mechanism.
        gate = F.gelu(gate)

        # Applies another linear transformation to the input tensor for subsequent combination with the gate.
        up = self.up_proj(x)

        # Element-wise multiplication of the gated tensor with the transformed input tensor, modulating
        # the input based on the gate's activation.
        fuse = gate * up

        # Applies the final linear transformation to project the modulated tensor back to the hidden size.
        outputs = self.down_proj(fuse)

        # Returns the final output tensor of the MLP, after gating and modulation.
        return outputs

# ------- BOOKMARK --------

# Attention

In [9]:
class GemmaAttention(nn.Module):
    """
    Implements Multi-Query Attention which supports a distinct number of attention heads for queries and key-values (KV).
    In the case where the same number of queries and key-values are used, this implemenation is equivalent to regular Multi-Head Attention.  
    """
    
    def __init__(
        self,
        config: GemmaConfig
        #hidden_size: int, # The dimensionality of input and output tensors.
        #num_heads: int, # The number of query attention heads
        #num_kv_heads: int, # The number of key-value heads, which can differ from the query heads
        #head_dim: int, # The dimensionality of each attention head
        #theta: float,
        #max_seq_len: int
    ):
        super().__init__()

        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        
        # Ensures that the number of query heads is evenly divisible by the number of KV heads.
        assert self.num_heads % self.num_kv_heads == 0
        # Determines the number of query heads associated with each KV head.
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        self.hidden_size = config.hidden_size
        self.head_dim = config.head_dim
        #self.max_seq_len = max_seq_len
        self.theta = config.rope_theta

        # Calculates the total size for all query projections.
        self.q_size = self.num_heads * self.head_dim
        # Calculates the total size for all key and value projections.
        self.kv_size = self.num_kv_heads * self.head_dim

        # Defines the scaling factor for the attention scores.
        self.scaling = self.head_dim**-0.5

        # Initializes the linear projection layer for queries, keys, and values
        #self.qkv_proj = Linear(self.hidden_size, (self.num_heads + 2 * self.num_kv_heads) * self.head_dim)
        self.qkv_proj = nn.Linear(self.hidden_size, (self.num_heads + 2 * self.num_kv_heads) * self.head_dim, bias=False)
        # Initializes the output projection layer, mapping the concatenated attention outputs back to the hidden size.
        #self.o_proj = Linear(self.num_heads * self.head_dim, self.hidden_size)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)


        # we'll use very large negative values to prevent attending to certain tokens
        mask_negatives = torch.full((1, 1, config.max_position_embeddings, config.max_position_embeddings),
                                 -2.3819763e38).to(torch.float)
        # then we'll replace the lower triangular ones with 0's to allow attention to see past tokens
        mask = torch.triu(mask_negatives, diagonal=1).to(config.device)
        # then we'll define self.mask as a tensor that shouldn't undergo gradient descent
        self.register_buffer('mask', mask)

    def forward(
        self,
        hidden_states: torch.Tensor, # The input tensor to the attention mechanism. shape (batch_size, input_len, hidden_size)
        #freqs_cis: torch.Tensor, # The frequencies for the rotary position embeddings
        #kv_write_indices: torch.Tensor, # Indices specifying where to write in the KV cache
        #kv_cache: Tuple[torch.Tensor, torch.Tensor], # A tuple of tensors holding the cached keys and values
        #mask: torch.Tensor, # The attention mask tensor, used to exclude future positions from attention calculations
    ) -> torch.Tensor:

        # Ensures the input tensor is 3-dimensional (batch_size, input_len, hidden_size).
        hidden_states_shape = hidden_states.shape
        assert len(hidden_states_shape) == 3

        # Extracts batch size and input sequence length from the hidden states tensor.
        batch_size, input_len, _ = hidden_states_shape

        # want to make sure they didn't input a sequence that's longer than the length we trained on
        #assert input_len <= max_seq_len # can't do this bc max_seq_len is not defined

        # Applies the linear projection to the hidden state to retrieve out
        qkv = self.qkv_proj(hidden_states)
        # Splits the combined QKV tensor into separate tensors for queries (xq), keys (xk), and values (xv) based on their respective sizes.
        xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size],
                               dim=-1)
        # for readability's sake it would've made more sense to do these separately; this is just more efficient

        # Reshapes each of the Q, K, and V tensors to separate the heads and align the dimensions for attention operations.
        xq = xq.view(batch_size, -1, self.num_heads, self.head_dim)
        xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
        xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)

        # Applies rotary positional embeddings to queries and keys to incorporate positional information.
        xq = apply_rotary_emb(xq, self.head_dim, self.theta)#, freqs_cis=freqs_cis)
        xk = apply_rotary_emb(xk, self.head_dim, self.theta)#, freqs_cis=freqs_cis)

        # Updates the key-value cache with the new keys and values for this forward pass, using provided indices.
        # [batch_size, input_len, n_local_kv_heads, head_dim]
        #k_cache, v_cache = kv_cache
        #k_cache.index_copy_(1, kv_write_indices, xk)
        #v_cache.index_copy_(1, kv_write_indices, xv)

        #key = k_cache
        #value = v_cache

        # If the number of KV heads is different from the number of query heads, adjusts keys and values to match the query heads count.
        if self.num_kv_heads != self.num_heads:
            # [batch_size, input_len, n_local_heads, head_dim]
            key = torch.repeat_interleave(xk,#key, 
                                          self.num_queries_per_kv, 
                                          dim=2)
            value = torch.repeat_interleave(xv,#value,
                                            self.num_queries_per_kv,
                                            dim=2)

        # Transposes Q, K, and V tensors to align them for the batch matrix multiplication in attention calculation.
        # [batch_size, n_local_heads, input_len, head_dim]
        q = xq.transpose(1, 2)#query.trnaspose(1,2)
        k = xk.transpose(1, 2)#key.transpose(1, 2)
        v = xv.transpose(1, 2)#value.transpose(1, 2)

        # Calculates attention scores by performing a batch matrix multiplication between queries and keys, followed by scaling.
        # [batch_size, n_local_heads, input_len, input_len]
        scores = torch.matmul(q, k.transpose(2, 3)) * self.scaling
        
        # Applies the lower-triangular mask to the attention scores
        #print("mask: ", self.mask.shape)
        #print("input_len: ", input_len)
        #print("scores: ", scores.shape)
        scores = scores + self.mask[...,:input_len, :input_len] # make sure mask is the correct size. input_len <= max_seq_len
        # is it weird that we're masking with addition of 0's & big negatives instead of multiplication of 1's and 0's? 
        # as far as i'm aware it's weird, although my knowledge could just be out of date

        # Applies softmax to the scores to obtain attention probabilities, 
        # and ensuring numerical stability by casting to float before softmax and back to original dtype after.
        scores = F.softmax(scores.float(), dim=-1).type_as(q) # i think i can remove the typing stuff since i'm not doing quantization

        # Computes the weighted sum of values based on the attention scores to obtain the output of the attention mechanism.
        # [batch_size, n_local_heads, input_len, head_dim]
        output = torch.matmul(scores, v)

        # Reshapes the attention output to match the expected output dimensions, combining the heads back into the hidden dimension.
        # [batch_size, input_len, hidden_dim]
        output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1)

        # Applies the final linear projection to the attention output, mapping it back to the hidden size dimension.
        output = self.o_proj(output)

        return output

# Layer

interestingly, here we normalize 

In [10]:
class GemmaDecoderLayer(nn.Module):
    """
    A decoder layer that integrates the GemmaAttention mechanism and multi-layer perceptron (MLP). It includes
    normalization steps both before and after the attention mechanism to stabilize and accelerate training.
    """

    def __init__(
        self,
        config: GemmaConfig,
    ):
        super().__init__()

        # Initializes the GemmaAttention mechanism with parameters from the config, enabling self-attention within the decoder layer.
        self.self_attn = GemmaAttention(config
            #hidden_size = config.hidden_size,
            #num_heads = config.num_attention_heads,
            #num_kv_heads = config.num_key_value_heads,
            #head_dim = config.head_dim,
            #theta = config.rope_theta,
            #max_seq_len = config.max_position_embeddings,
        )
        # Initializes the GemmaMLP module, providing a non-linear transformation after the attention mechanism.
        self.mlp = GemmaMLP(
            hidden_size = config.hidden_size,
            intermediate_size = config.intermediate_size,
        )
        # Applies RMSNorm normalization to the input of the decoder layer for stable training dynamics.
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps = config.rms_norm_eps)
        # Applies RMSNorm after the attention mechanism and before the MLP to ensure the output is well-conditioned for further processing.
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps = config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor, # The input tensor to the decoder layer. shape (batch_size, input_len, hidden_size)
        #freqs_cis: torch.Tensor, # The frequencies for the rotary position embeddings, enhancing the model's understanding of sequence order.
        #kv_write_indices: torch.Tensor, # Indices specifying where to write in the key-value cache, facilitating efficient memory use.
        #kv_cache: Tuple[torch.Tensor, torch.Tensor], # A cache for keys and values to allow reusing attention computations across decoding steps.
        #mask: torch.Tensor, # The attention mask tensor, used to exclude future positions from attention calculations
    ) -> torch.Tensor:
        
        # Self Attention Block
        # Stores the original input for use as a residual connection, aiding in mitigating the vanishing gradient problem
        residual = hidden_states
        # Normalizes the input before processing by the attention mechanism.
        hidden_states = self.input_layernorm(hidden_states)
        # Processes the normalized input through the GemmaAttention mechanism
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            #freqs_cis=freqs_cis,
            #kv_write_indices=kv_write_indices,
            #kv_cache=kv_cache,
            #mask=mask,
        )
        # The aforementioned residual connection
        hidden_states = residual + hidden_states

        # MLP Block
        # Again, stores the output of the attention block for use as a residual connection before processing by the MLP.
        residual = hidden_states
        # Normalizes the output of the attention block before passing it to the MLP, ensuring a stable input distribution.
        hidden_states = self.post_attention_layernorm(hidden_states)
        # Transforms the normalized attention output through the MLP, introducing additional non-linearity and capacity to the model.
        hidden_states = self.mlp(hidden_states)
        # Another residual connection
        hidden_states = residual + hidden_states

        return hidden_states

# The Body of the Model

In [11]:
class GemmaBody(nn.Module):
    """ the class that loops through each layer of Gemma """

    def __init__(self, config: GemmaConfig):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        # Initialize a sequence of GemmaDecoderLayer instances as specified by the number of hidden layers in the configuration
        #self.layers = nn.ModuleList()
        #for _ in range(config.num_hidden_layers):
            #self.layers.append(GemmaDecoderLayer(config))
        self.layers = nn.ModuleList(GemmaDecoderLayer(config) for _ in range(config.num_hidden_layers))

        # Initialize a normalization layer to be applied after the last decoder layer, stabilizing the output
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor, # The first residual state of the model. shape (batch_size, input_len, hidden_size)
        #freqs_cis: torch.Tensor, # The frequencies for the rotary position embeddings, enhancing the model's understanding of sequence order.
        #kv_write_indices: torch.Tensor, # Indices specifying where to write in the key-value cache, facilitating efficient memory use.
        #kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], # A cache for keys and values to allow reusing attention computations across decoding steps
        #mask: torch.Tensor, # The attention mask tensor, used to exclude future positions from attention calculations
    ) -> torch.Tensor:

        # Iteratively process the input through each GemmaDecoderLayer, passing along necessary parameters for attention.
        # The hidden states are updated at each layer, progressively incorporating more complex dependencies and transformations.
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states = layer(
                hidden_states=hidden_states,
                #freqs_cis=freqs_cis,
                #kv_write_indices=kv_write_indices,
                #kv_cache=kv_caches[i],
                #mask=mask,
            )
        
        # Apply normalization to the output of the final decoder layer, ensuring the model's output is well-conditioned for subsequent use.
        hidden_states = self.norm(hidden_states)
        
        return hidden_states

# The Model itself, and its generate() and load_weights() functions

In [12]:
class minGemma(nn.Module):

    def __init__(
        self,
        config: GemmaConfig,
        tokenizer: tokenizer,
    ):
        super().__init__()
        self.config = config

        # the attention heads need to cleanly divide up the hidden_size of the model so that we can split it all apart & combine back together
        assert config.hidden_size % config.num_attention_heads == 0

        self.max_seq_len = config.max_position_embeddings
        self.head_dim = config.head_dim
        self.vocab_size = config.vocab_size

        # uses whichever tokenizer was previously specified in the "Config" section
        self.tokenizer = tokenizer
        #self.tokenizer = tokenizer.Tokenizer(config.tokenizer)
        
        #self.embedder = Embedding(self.vocab_size, config.hidden_size) 
        self.embedder = nn.Embedding(self.vocab_size, config.hidden_size) # the embedding matrix
        self.model = GemmaBody(config) # the body of the model; all the transformer decoder layers
        #self.sampler = Sampler(self.vocab_size) # the function that probabailistically choses the next token

        # Pre-compute rotary embedding table.
        #freqs_cis = precompute_freqs_cis(self.head_dim,
        #                                 self.max_seq_len * 2,
        #                                 theta=config.rope_theta)
        # defines `self.freqs_cis` using `register_buffer()` which prevents the tensor from being trainable
        #self.register_buffer('freqs_cis', freqs_cis)

        # the loss function
        self.criterion = nn.CrossEntropyLoss()

    #@torch.no_grad()
    def forward( ####### need to add targets and make some of these optional
        self,
        input_token_ids: torch.Tensor,
        target_token_ids: torch.Tensor = None,
        #input_positions: torch.Tensor,
        #kv_write_indices: torch.Tensor,
        #kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        #mask: torch.Tensor,
        #output_positions: torch.Tensor,
        #temperatures: torch.Tensor,
        #top_ps: torch.Tensor,
        #top_ks: torch.Tensor,
        #**kwargs,
    ) -> torch.Tensor:
        #freqs_cis = self.freqs_cis.index_select(0, input_positions)
        #kv_write_indices = input_positions

        # turn the input tokens into the first resudial state using the embedding matrix
        # (batch_size, input_len) & (vocab_size, hidden_size) -> (batch_size, input_len, hidden_size)
        hidden_states = self.embedder(input_token_ids)
        
        # Gemma normalizes the embedding by sqrt(hidden_size).
        hidden_states = hidden_states * (self.config.hidden_size**0.5)

        # this is where the bulk of the calculations are performed, the actual decoder layers
        hidden_states = self.model(
            hidden_states=hidden_states,
            #freqs_cis=freqs_cis,
            #kv_write_indices=kv_write_indices,
            #kv_caches=kv_caches,
            #mask=mask,
        ) # -> (batch_size, input_len, hidden_size)

        # grabbing the weights of the embedding matrix shape (vocab_size, hidden_dim) for use as the output layer
        embedder_weight = self.embedder.weight
        #if self.config.quant:
            #embedder_weight = (
                #embedder_weight * self.embedder.weight_scaler.unsqueeze(-1))

        # the embedding matrix is also used as the output layer. this saves on parameters compared to using a separate matrix
        # (batch_size, input_len, hidden_size) @ (hidden_size, vocab_size) -> (batch_size, input_len, vocab_size)
        logits = torch.matmul(hidden_states, embedder_weight.t())

        # there's an optional embedding bias used by tokens that are always/generally more/less likely
        #if embedding_bias is not None: # i think this might've gotten passed in with **kwargs
            # (batch_size, input_len, vocab_size) + (vocab_size) -> (batch_size, input_len, vocab_size)
            #logits += embedding_bias

        #if temperatures is None:
            # selects the token with the largest logit. No need to do softmax since there's no temperature
            #return torch.argmax(logits, dim=-1).squeeze(dim=-1)
        
        if target_token_ids is None: # if we're not training
            loss = None
        else:
            batch_size, input_len, vocab_size = logits.shape
            loss = self.criterion(logits.view(batch_size*input_len, vocab_size), target_token_ids.view(batch_size*input_len))
        
        return logits, loss

    @torch.no_grad()
    def Sampler(
        self,
        #embedding: torch.Tensor, # shape (vocab_size, hidden_size)
        #hidden_states: torch.Tensor, # shape (batch_size, input_len, hidden_size)
        logits: torch.Tensor,
        #output_positions: torch.Tensor, # shape (batch_size) list of integer indices. I think they should be all -1's for NTP
        #temperatures: torch.Tensor, # shape (batch_size) list of float temperatures
        temperature: float,
        #top_ps: torch.Tensor, # shape (batch_size) list of float percentages to denote out top_p criteria
        top_p: float,
        #top_ks: torch.Tensor, # shape (batch_size) list of integers to denote our top_k criteria
        top_k: int,
        #embedding_bias: Optional[torch.Tensor] = None, # shape (vocab_size) meant to bias output towards certain tokens
    ) -> torch.Tensor:
        """
        The Sampler function is responsible for generating token predictions from Gemma's output.
        It supports temperature scaling, top-p (nucleus) sampling, and top-k sampling to control the diversity and quality of the generated text. 
        The class operates as follows:
    
        1. Selects the last hidden state for each sequence in the batch
    
        2. Computes logits by multiplying the selected hidden states with the transposed embedding matrix. 
        An optional embedding bias can be added to these logits to bias the predictions towards certain tokens.
    
        3. If no temperature is provided, the function returns the indices of the tokens with the highest logits, aka greedy decoding.
    
        4. If a temperature is provided, it is used to scale the logits, making the distribution over tokens sharper (lower temperature) 
        or flatter (higher temperature), which affects the randomness of the sampling.
    
        5. The softmax function is applied to the scaled logits to obtain a probability distribution over the vocabulary.
    
        6. For top-p sampling, the function computes the cumulative sum of the sorted probabilities and masks out tokens until the 
        cumulative probability exceeds the threshold defined by `top_ps`. This allows the model to focus on a subset of the most 
        probable tokens while ignoring the long tail of less likely tokens.
    
        7. For top-k sampling, the function masks out all tokens except the `k` most likely ones, as specified by `top_ks`. 
        This ensures that the model only considers a fixed number of the most probable tokens for the next token prediction.
    
        8. After applying both the top-p and top-k masks, the probabilities are re-normalized so that they sum up to 1
    
        9. The function then samples from the re-normalized probability distribution to select the next token. 
        This step introduces randomness into the text generation process, allowing the model to generate diverse and coherent text.
        """
        ### currently this setup still *kind of* allows for batch size. I think i can not change this and still
        # have it work but for clarity i prolly should change it

        # Select the last element for each sequence.
        # (batch_size, input_len, hidden_size) -> (batch_size, hidden_size)
        # the "1" selects along the 'input_len' dimension & `output_positions` is a list of indices
        #hidden_states = hidden_states.index_select(1, output_positions).squeeze(dim=1) # squeeze removes the input_len dimension
        # (batch_size, input_len, vocab_size) -> (batch_size, vocab_size)
        logits = logits[:,-1,:] # do i need a squeeze?
        
        # Apply temperature scaling
        # (batch_size, vocab_size) / float -> (batch_size, vocab_size)
        logits.div_(temperature)

        # Calculate probabilities with softmax.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float) # dim=-1 is the vocab_size dimension

        # sort the probabilities to for use in top-p & top-k
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) # both (batch_size, vocab_size)

        # calculating top-p
        # creates same-size tensor of cumulatve probabilities instead of indivdiual probs
        probs_sum = torch.cumsum(probs_sort, dim=-1) 
        # mask where 0's are top-p selections & 1's are to be excluded
        top_ps_mask = (probs_sum - probs_sort) > top_p#s.unsqueeze(dim=1) 
        # the original probabilities with excluded tokens changed to 0.0
        probs_sort = torch.where(top_ps_mask, 0, probs_sort) 

        # calculating top_k
        # create a shape (vocab_size) tensor that just iterates up by 1's
        top_ks_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device) 
        # expand our mask along the batch_size dimension to become size (batch_size, vocab_size)
        # "expand" means copy the original into this new size, so each length vocab_size row is the same
        top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
        # top_ks is a list of integers. we keep whichever entries in top_ks_mask are greater than their corresponding entries in top_ks
        top_ks_mask = top_ks_mask >= top_k#s.unsqueeze(dim=1)

        # we'll be combining top-p with top-k and using whichever gives us fewer tokens. a very conservative approach
        # further trim probs_sort to fit within our top_k requirement
        probs_sort = torch.where(top_ks_mask, 0, probs_sort)

        # Re-normalization so that total probabilities add up to 1
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        # used to rearrange the modified probabilities in probs_sort back to their original order according to probs_idx
        probs = torch.gather(probs_sort,
                             dim=-1,
                             index=torch.argsort(probs_idx, dim=-1))
        # samples from the distribution
        next_token_id = torch.multinomial(probs,
                                           num_samples=1)#,
                                           #replacement=True).squeeze(dim=-1)
        
        return next_token_id # returns the predicted tokens
        
    def generate(
        self,
        #prompts: Union[str, Sequence[str]],
        prompt: str,
        #device: Any,
        output_len: int = 100, # the model will output 100 tokens
        temperature: float = 0.95, # 0.95 is pretty close to not even using temperature at all (1.0 would be no effect)
        top_p: float = 1.0, # defaulting to 1 means we essentially don't use top-p
        top_k: int = 100, # this value for top-k is literally larger than our character-wise vocabulary bc default Gemma has such a huge vocab
    ) -> str: #Union[str, Sequence[str]]:
        """Generates responses for given prompts using Gemma model."""
        # encoding the prompt into token indices
        prompt_tokens = self.tokenizer.encode(prompt)
        prompt_tokens = torch.tensor(prompt_tokens, device=config.device).unsqueeze(0)
        
        # we wouldn't want to go past the maximum context length we trained on
        assert len(prompt_tokens) + output_len <= self.config.max_position_embeddings

        tokens = prompt_tokens
        for i in range(output_len):
            logits, _ = self(tokens[:,:self.max_seq_len])
            
            next_token = self.Sampler(
                #embedding=embedder_weight,
                #hidden_states=hidden_states,
                logits = logits,
                #output_positions=output_positions,
                #temperatures=temperatures,
                temperature = temperature,
                #top_ps=top_ps,
                top_p = top_p,
                #top_ks=top_ks,
                top_k = top_k
            )

            tokens = torch.cat((tokens, next_token), dim=1)

        output = self.tokenizer.decode(tokens.squeeze(0).tolist())

        return output

    #def load_weights(self, model_path: str):
    #    self.load_state_dict(
    #        torch.load(model_path, mmap=True)['model_state_dict'],
    #        strict=False,
    #    )

# Training-related Functions

In [13]:
# Train and test splits
data = torch.tensor(tokenizer.encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest validation
train_data = data[:n]
val_data = data[n:]

In [14]:
# data loading for training
def get_batch(split, batch_size):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - config.max_position_embeddings, (batch_size,))
    x = torch.stack([data[i:i+config.max_position_embeddings] for i in ix])
    y = torch.stack([data[i+1:i+config.max_position_embeddings+1] for i in ix])
    x, y = x.to(config.device), y.to(config.device)
    return x, y

In [15]:
xb, yb = get_batch('train', 1)
print(tokenizer.decode(xb.squeeze(0).tolist()))
print("-------")
print(tokenizer.decode(yb.squeeze(0).tolist()))

boy:
Perhaps thy childishness will move him more
Than can our reasons. There's no man in the world
More bound to 's mother; yet here he lets me prate
Like one i' the stocks. Thou hast never in thy life
Show'd thy dear mother any courtesy,
When she, poor he
-------
oy:
Perhaps thy childishness will move him more
Than can our reasons. There's no man in the world
More bound to 's mother; yet here he lets me prate
Like one i' the stocks. Thou hast never in thy life
Show'd thy dear mother any courtesy,
When she, poor hen


In [16]:
@torch.no_grad()
def estimate_loss(model, batch_size, eval_iters = 10): # to estimate loss later during the training loop
    out = {}
    model.eval() # sets model to eval mode
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, batch_size)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train() # just resets to training mode
    return out

# Instantiating a brand new model

In [20]:
model = minGemma(config, tokenizer).to(config.device)

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

print(model)

964.352 K parameters
minGemma(
  (embedder): Embedding(65, 128)
  (model): GemmaBody(
    (layers): ModuleList(
      (0-3): 4 x GemmaDecoderLayer(
        (self_attn): GemmaAttention(
          (qkv_proj): Linear(in_features=128, out_features=192, bias=False)
          (o_proj): Linear(in_features=128, out_features=128, bias=False)
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=128, out_features=512, bias=True)
          (up_proj): Linear(in_features=128, out_features=512, bias=True)
          (down_proj): Linear(in_features=512, out_features=128, bias=True)
        )
        (input_layernorm): RMSNorm()
        (post_attention_layernorm): RMSNorm()
      )
    )
    (norm): RMSNorm()
  )
  (criterion): CrossEntropyLoss()
)


# Load a Pretrained Model

In [18]:
# Initialize a blank model
model = minGemma(config, tokenizer).to(config.device)  

# here's the path to a minGemma model that i've trained with roughly 1m parameters
path = 'models/minGemma-vocab_size65-max_position_embeddings256-num_hidden_layers4-num_attention_heads4-num_key_value_heads1-hidden_size128-intermediate_size512-head_dim32-rms_norm_eps1e-06-rope_theta100.0--2024-02-23|02-10-08.pth'

# Load the saved state dictionary
model.load_state_dict(torch.load(path))
# REMEMBER TO CHANGE VALUES IN CONFIG TO MATCH THE MODEL YOU'VE LOADED

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

# If you only plan to do inference, switch to evaluation mode
model.eval()

# If you plan to continue training the model, switch to training mode
#model.train()

964.352 K parameters


minGemma(
  (embedder): Embedding(65, 128)
  (model): GemmaBody(
    (layers): ModuleList(
      (0-3): 4 x GemmaDecoderLayer(
        (self_attn): GemmaAttention(
          (qkv_proj): Linear(in_features=128, out_features=192, bias=False)
          (o_proj): Linear(in_features=128, out_features=128, bias=False)
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=128, out_features=512, bias=True)
          (up_proj): Linear(in_features=128, out_features=512, bias=True)
          (down_proj): Linear(in_features=512, out_features=128, bias=True)
        )
        (input_layernorm): RMSNorm()
        (post_attention_layernorm): RMSNorm()
      )
    )
    (norm): RMSNorm()
  )
  (criterion): CrossEntropyLoss()
)

# Training

In [22]:
# create a PyTorch optimizer
# this is not what they used, but this learning rate & weight decay work for our tiny minGemma
learning_rate = 3e-4
weight_decay = 0.01
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# how long we want to train for
max_iters = 10000

# how often we want to check & see how our loss is doing
eval_interval = 250

# how many batches to use
batch_size = 32

In [23]:
start_time = time.time()

# Enable anomaly detection. uncomment these lines if you need to do extensive debugging
#torch.autograd.set_detect_anomaly(True)

for iter in range(max_iters):

    # sample a batch of data
    xb, yb = get_batch('train', batch_size)
    
    # train
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    
    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        current_time = time.time()
        elapsed_time = current_time - start_time
        losses = estimate_loss(model, batch_size)
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, time elapsed: {elapsed_time:.2f} seconds")

# Disable anomaly detection after the training loop
#torch.autograd.set_detect_anomaly(False)

step 0: train loss 121.8347, val loss 121.7503, time elapsed: 0.57 seconds
step 250: train loss 3.1086, val loss 3.1733, time elapsed: 94.64 seconds
step 500: train loss 2.6935, val loss 2.7184, time elapsed: 190.55 seconds
step 750: train loss 2.4985, val loss 2.5294, time elapsed: 284.27 seconds
step 1000: train loss 2.3782, val loss 2.3996, time elapsed: 378.43 seconds
step 1250: train loss 2.3049, val loss 2.3334, time elapsed: 471.49 seconds
step 1500: train loss 2.2583, val loss 2.2705, time elapsed: 565.12 seconds
step 1750: train loss 2.2035, val loss 2.2442, time elapsed: 658.16 seconds
step 2000: train loss 2.1502, val loss 2.1993, time elapsed: 751.26 seconds
step 2250: train loss 2.1077, val loss 2.1747, time elapsed: 844.68 seconds
step 2500: train loss 2.0497, val loss 2.1414, time elapsed: 937.64 seconds
step 2750: train loss 2.0311, val loss 2.1084, time elapsed: 1032.37 seconds
step 3000: train loss 1.9717, val loss 2.0872, time elapsed: 1125.54 seconds
step 3250: trai

# Saving your model

In [24]:
# save the model currently held in memory
# the filename specifies the model's class, hyperparameters, and date/time it was saved
torch.save(model.state_dict(),
           f'models/{model.__class__.__name__}'
           f'-vocab_size{config.vocab_size}'
           f'-max_position_embeddings{config.max_position_embeddings}'
           f'-num_hidden_layers{config.num_hidden_layers}'
           f'-num_attention_heads{config.num_attention_heads}'
           f'-num_key_value_heads{config.num_key_value_heads}'
           f'-hidden_size{config.hidden_size}'
           f'-intermediate_size{config.intermediate_size}'
           f'-head_dim{config.head_dim}'
           f'-rms_norm_eps{config.rms_norm_eps}'
           f'-rope_theta{config.rope_theta}'
           f'--{time.strftime("%Y-%m-%d|%H-%M-%S")}.pth')

# Inference

In [25]:
input_str = "JULIET:\nO Romeo, Romeo! wherefore art thou Rom" # the classic line
max_useable_output_len = config.max_position_embeddings - len(input_str)
output = model.generate(input_str, output_len = max_useable_output_len)
print(output)

JULIET:
O Romeo, Romeo! wherefore art thou Romeo
May no consvation.

CLALEND:
Not mindly late on these of the defence's hath
As sorm to believed then oather hath makes my shame
To to me have do note, I knee, their trunken me?

WARWICK:
O, voiceed succe to 


*^^^ I think that looks pretty good. the model has gotten the general pattern of what english words & sentences look like. And either it's memorized Juliet's first line or it has developed some [induction heads](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html) to figure out in-context that an R should be followed by 'omeo'.*