# minGemma

Our goal here is to teach all of the architectural concepts used in [Google's Gemma](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf). The code here is a Frankensteinian mixture of a little bit of [Andrej Karpathy's minGPT](https://www.youtube.com/watch?v=kCc8FmEb1nY&t=5012s) and mostly [Google's PyTorch implementation of Gemma](https://github.com/google/gemma_pytorch). As such, certain aspects of the Gemma model are included or left out based on our learning goals.

#### Included:
- RoPE
- RMSnorm
- GeGLU activation
- Multi-Query Attention (for the smaller model)
- The extra normalization in the middle of each layer, between the attention mechanism and the MLP

#### Not Included: 
- **Gemma's training dataset:** The dataset & corresponding RLHF framework used to train Gemma is out-of-scope for this lesson. This is unfortunate since a big part of Gemma's impressive performance is attributed to the high quantity and quality of their data. I recommend reading the technical report for more information.
- **Tokenization training:** The full 256,128 token vocabulary of Gemma is unusually large, even compared to models orders of magnitude larger than Gemma (if I remember correctly, Llama's is around 32k). Their tokenizer also has interesting qualities such as beginning-of-sentence & end-of-sentence tokens & other tokens related to instruct tuning. The issue with training such a large token vocabulary is that many tokens are likely to be trained rarely or never, meaning you've wasted compute and left yourself vulnerable to the [SolidGoldMagikarp problem](https://www.lesswrong.com/posts/aPeJE8bSo6rAFoLqg/solidgoldmagikarp-plus-prompt-generation) unless you use an absurdly large dataset like they did (2T and 6T tokens for the 2b and 7b models respectively). That being said, a better lesson than mine would've taught you how to train a more reasonably sized vocabulary. Since I don't do that here, I recommend for homework you check out [Andrej Karpathy's recent guide on building tokenizers](https://www.youtube.com/watch?v=zduSFxRajkE) and implement a reasonably sized one yourself. Here we'll be using simple character-wise tokenization of TinyShakespeare by default, but the code will include the option to use Gemma's original gigantic tokenizer just so you can see what it looks like.
- **Highly parallelized training:** This is something I don't believe Google included in their open-sourcing of these models, which is unfortunate but reasonable given that they trained on TPUs and no one owns TPUs except Google. Rather, their open-sourced code (at least the PyTorch version) is designed strictly for inference with their open-sourced weights. I recommend looking elsewhere if you want to learn about parallelized training techniques like sharding on GPUs.
- **KV caching:** A standard method for saving on compute in transformers, KV caching involves saving the individual key and value vectors calculated at each step of inference for use at future steps of inference. I've not included it here because it's an inference-specific mechanism whereas this code is meant to train, but if you develop your own inference version of this code then you 100% need to re-implement KV caching.
- **Quantization:** This shouldn't be necessary for the absolutely tiny size of the default parameters I've set for minGemma
- **Batched Inference:** The open-sourced PyTorch code for Gemma is set up to efficiently handle batched decoding of variable context lengths. The implementation here is only designed to perform inference on a single prompt at a time, which is all we really need for learning about the Gemma architecture; I would hope no one is using this code for deployment.
- **Embedding Bias:** There was an odd reference to a bias term for use in the output layer after multiplying by the transpose of the embedding matrix. It was labeled as optional and didn't get used in Google's own PyTorch implementation so I assume its inclusion was an accident, but it may have been utilized in different instantiations of the model. If anyone knows, feel free to reach out.
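
For reference, the idea behind KV caching can be sketched in a few lines of plain Python. This is a toy single-head example with made-up 2-dimensional vectors and identity "projections" (both assumptions for illustration, not anything taken from Gemma's code). At each decoding step only the newest token's key and value are appended to the cache, and the result matches recomputing everything from scratch.

```python
import math

def attend(query, keys, values):
    """Single-head scaled dot-product attention for one query vector."""
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    peak = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]

# toy "token embeddings"; hypothetical values chosen for illustration
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

# with caching: each step only adds the newest token's key and value
k_cache, v_cache = [], []
cached_outputs = []
for x in tokens:
    k_cache.append(x)  # identity mapping stands in for the key projection
    v_cache.append(x)  # identity mapping stands in for the value projection
    cached_outputs.append(attend(x, k_cache, v_cache))

# without caching: every step rebuilds all keys and values from scratch
uncached_outputs = [
    attend(tokens[t], tokens[: t + 1], tokens[: t + 1]) for t in range(len(tokens))
]

print(cached_outputs == uncached_outputs)
```

In a real model the saving comes from skipping the key and value projections for every earlier token, which this toy hides behind the identity mapping.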

### ToDo
- look in-depth through the RoPE comments. I let ChatGPT write all of that so idk how much sense they make
- fill in all the empty links I've set up
- look into & talk about GeGLU in the MLP
- see if I can remove typing from the attention softmax calculation
- fix the calls to GemmaConfig in inits to allow for the 2b and 7b options

In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# config.py
import dataclasses # used for the config. this library is good at making classes that are full of just datatypes
from typing import Optional

# model.py
import re
from typing import Any, List, Sequence, Tuple, Union#, Optional
#from gemma import config as gemma_config
#from gemma import tokenizer

# for optionally downloading the original tokenizer
import os
import requests
from sentencepiece import SentencePieceProcessor

# training loop
import time

# The Dataset

In [2]:
# the dataset we'll be using is just TinyShakespeare
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
print(text[:200])

# here are all the unique characters that occur in this text. we'll do character-wise tokenization
chars = sorted(list(set(text)))
v = len(chars)
print(chars, v)

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you
['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'] 65


# The Tokenizer(s)

In [3]:
# turning it into an actual tokenizer
# this version was created by ChatGPT. not super confident in it. gonna watch Andrej's video on making a custom tokenizer
# also need to change the rest of the references to use this tokenizer rather
class CharacterTokenizer:
    def __init__(self, chars: List[str]):
        # Create mappings from characters to integers and vice versa
        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for i, ch in enumerate(chars)}

        ### not sure i need all these. I think they make it compatible with later code
        # Special tokens, adjust as needed
        #self.bos_char = '<'  # Beginning of string character
        #self.eos_char = '>'  # End of string character
        #self.pad_char = '~'  # Padding character

        # Add special tokens to mappings
        #special_tokens = [self.bos_char, self.eos_char, self.pad_char]
        #for token in special_tokens:
            #if token not in self.stoi:
                #index = len(self.stoi)
                #self.stoi[token] = index
                #self.itos[index] = token

        # Update vocabulary size
        #self.n_words = len(self.stoi)

        # Define special token IDs
        #self.bos_id = self.stoi[self.bos_char]
        #self.eos_id = self.stoi[self.eos_char]
        #self.pad_id = self.stoi[self.pad_char]

    def encode(self, s: str) -> List[int]:#, bos: bool = True, eos: bool = True
        """Converts a string into a list of character IDs"""
        #if bos:
            #s = self.bos_char + s
        #if eos:
            #s += self.eos_char
        return [self.stoi.get(c) for c in s] #, self.stoi[self.pad_char] # not sure how this works

    def decode(self, t: List[int]) -> str:
        """Converts a list of character IDs back into a string."""
        return ''.join([self.itos.get(i) for i in t])#, self.pad_char


# the original Gemma tokenizer
# I'm including this to use as an option if you want to try it with inference but you shouldn't use it for training minGemma
class GemmaTokenizer:
    def __init__(self, model_path: Optional[str]):
        # load tokenizer.
        assert os.path.isfile(model_path), model_path
        self.sp_model = SentencePieceProcessor(model_file=model_path)

        # BOS / EOS token IDs.
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        self.pad_id: int = self.sp_model.pad_id()
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

    def encode(self, s: str, bos: bool = True, eos: bool = False) -> List[int]:
        """Converts a string into a list of tokens."""
        assert isinstance(s, str)
        t = self.sp_model.encode(s)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int]) -> str:
        """Converts a list of tokens into a string."""
        return self.sp_model.decode(t)

# Config

In [4]:
@dataclasses.dataclass
class GemmaConfig:
    """ The default configuration & hyperparameters for minGemma. regular Gemma 2b & 7b are defined below """
    
    # The number of tokens in the vocabulary.
    vocab_size: int = v # defined earlier when we loaded TinyShakespeare. In Gemma it's 256,128
    
    # The maximum sequence length that this model might ever be used with.
    max_position_embeddings: int = 256 # In Gemma it's 8192
    
    # The number of blocks in the model.
    num_hidden_layers: int = 4 # In Gemma 7b it's 28 and 2b it's 18
    
    # The number of attention heads used in the attention layers of the model.
    num_attention_heads: int = 4 # In Gemma 7b it's 16 and 2b it's 8
    
    # The number of key-value heads for implementing attention.
    num_key_value_heads: int = 1 # In Gemma 7b it's 16 and in 2b it's 1
    # Notice that with Gemma 7b num_attention_heads = num_key_value_heads whereas this is not true for 2b
    # this is because 7b uses regular multi-head attention while 2b uses multi-query attention
    # the difference is that MQA shares a single key & value head across all query heads whereas MHA
    # gives every query head its own key & value head
    
    # The hidden size of the model.
    hidden_size: int = 128 # In Gemma 7b it's 3072 and in 2b it's 2048
    
    # The dimension of the MLP representations.
    intermediate_size: int = 512 # In Gemma 7b it's 24576 and in 2b it's 16384
    
    # The dimension of each attention head.
    head_dim: int = 32 # In both Gemmas it's 256
    
    # The epsilon used by the rms normalization layers.
    rms_norm_eps: float = 1e-6
    
    # The path to the model tokenizer.
    tokenizer: Optional[str] = 'tokenizer/CharacterTokenizer.model' # not sure what to do with this
    # instead of the original, we'll be using the functions defined above for our character-wise tokenization
    
    # The dtype of the weights.
    #dtype: str = 'bfloat16'
    
    # Whether a quantized version of the model is used.
    # quant: bool = False 
    # we won't be providing this option in our implementation. All corresponding code for it has been removed

    #def get_dtype(self) -> Optional[torch.dtype]:
        #"""Gets the torch dtype from the config dtype string."""
        #return _STR_DTYPE_TO_TORCH_DTYPE.get(self.dtype, None)

    # the scaling factor that determines the frequencies for the rotary positional encodings
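    # (the type annotation is required: without it the dataclass treats this as a plain class attribute
    # and GemmaConfig(rope_theta=...) in the 2b/7b configs below would raise a TypeError)
    rope_theta: float = 100.0 # smaller models should use a smaller theta. I'm just guessing here though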

    device = 'cuda' if torch.cuda.is_available() else 'cpu' #'mps' if torch.backends.mps.is_available() else

def get_config_for_7b() -> GemmaConfig:
    return GemmaConfig(
        vocab_size = 256128,
        max_position_embeddings = 8192,
        num_hidden_layers = 28,
        num_attention_heads = 16,
        num_key_value_heads = 16,
        hidden_size = 3072,
        intermediate_size = 24576,
        head_dim = 256,
        tokenizer = 'tokenizer/tokenizer.model',
        rope_theta = 10000.0
    )

def get_config_for_2b() -> GemmaConfig:
    return GemmaConfig(
        vocab_size = 256128,
        max_position_embeddings = 8192,
        num_hidden_layers = 18,
        num_attention_heads = 8,
        num_key_value_heads = 1,
        hidden_size = 2048,
        intermediate_size = 16384,
        head_dim = 256,
        tokenizer = 'tokenizer/tokenizer.model',
        rope_theta = 10000.0
    )

def download_original_tokenizer():
    """
    I'm too lazy to read through google's license and figure out whether it's ok for me to download & redistribute 
    their tokenizer, so instead i just had ChatGPT write a script that has you download it yourself
    """
    # Configuration
    DESTINATION_FOLDER_PATH = './tokenizer'  # Local path in your Jupyter environment
    FILE_NAME = 'tokenizer.model'  # Name of the file to download
    local_file_path = os.path.join(DESTINATION_FOLDER_PATH, FILE_NAME) # Define the local path to save the file

    # Check if the file already exists
    if os.path.exists(local_file_path):
        print(f'File already exists at {local_file_path}. No download needed.')
        return
    
    # Construct the URL to the raw file on GitHub
    url = f'https://raw.githubusercontent.com/google/gemma_pytorch/main/tokenizer/{FILE_NAME}'
    
    # Make the HTTP GET request
    response = requests.get(url)
    
    # Check if the request was successful
    if response.status_code == 200:
        # Ensure the destination folder exists
        os.makedirs(DESTINATION_FOLDER_PATH, exist_ok=True)
        
        # Write the content of the response to a local file
        with open(local_file_path, 'wb') as file:
            file.write(response.content)
        print(f'File successfully downloaded to {local_file_path}')
    else:
        print(f'Failed to download the file. HTTP status code: {response.status_code}')

def get_model_config(variant: Optional[str] = None) -> Tuple[GemmaConfig, Union[CharacterTokenizer, GemmaTokenizer]]:
    if variant == '7b':
        download_original_tokenizer()
        return get_config_for_7b(), GemmaTokenizer('tokenizer/tokenizer.model')
    elif variant == '2b':
        download_original_tokenizer()
        return get_config_for_2b(), GemmaTokenizer('tokenizer/tokenizer.model')
    else:
        return GemmaConfig(), CharacterTokenizer(chars) # the default config is minGemma
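
To make the multi-head vs. multi-query distinction concrete, here is a small sketch that computes the width of a fused query/key/value projection for each configuration. The helper name `qkv_width` is made up for this illustration; the formula is the same `(num_heads + 2 * num_kv_heads) * head_dim` expression the attention module uses for `qkv_proj`.

```python
def qkv_width(num_heads: int, num_kv_heads: int, head_dim: int) -> int:
    """Output width of a fused projection holding all query, key, and value heads."""
    return (num_heads + 2 * num_kv_heads) * head_dim

# (num_attention_heads, num_key_value_heads, head_dim) taken from the configs above
variants = {
    "minGemma": (4, 1, 32),
    "2b": (8, 1, 256),
    "7b": (16, 16, 256),
}

for name, (heads, kv_heads, head_dim) in variants.items():
    actual = qkv_width(heads, kv_heads, head_dim)
    full_mha = qkv_width(heads, heads, head_dim)  # what full multi-head attention would need
    print(f"{name}: fused QKV width = {actual} (full MHA would be {full_mha})")
```

For minGemma and 2b the key/value part of the projection shrinks to a single head's worth, while for 7b the two numbers are identical because it already uses one KV head per query head.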

In [5]:
# add as input '7b' or '2b' to choose one of the original Gemma configs instead of minGemma
# but if you do that only use it for inference; i guarantee you don't have enough compute to train locally
config, tokenizer = get_model_config()

# testing the tokenizer
print(tokenizer.encode('hello world!'))
print(tokenizer.decode(tokenizer.encode('hello world!')))

[46, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42, 2]
hello world!


In [6]:
# Train and test splits
data = torch.tensor(tokenizer.encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest validation
train_data = data[:n]
val_data = data[n:]

# Positional Encoding

Gemma uses RoPE, which is a popular relative positional encoding scheme. Positional encodings are designed to help the model understand the order of tokens in the text, since transformers don't have a built-in notion of order when reading a sequence. Instead of telling Gemma the precise location of each token (which would be "absolute" positional encoding), relative positional encodings only reveal to Gemma the placement of different tokens relative to each other. RoPE is a relative positional encoding scheme that functions within the attention mechanism (rather than, for example, along the embeddings) and uses trigonometric functions to effectively "rotate" the query and key vectors used in self-attention. RoPE is the standard to beat nowadays, as it's also found in other notable open-source models like Llama. To see more, check out [the original paper](https://arxiv.org/abs/2104.09864) or [this slightly more approachable guide](https://blog.eleuther.ai/rotary-embeddings/).
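
The "relative" property is easy to check numerically. The sketch below is plain Python rather than PyTorch, and it rotates just one 2-dimensional pair using a single assumed frequency `FREQ = 0.5`; real RoPE applies the same kind of rotation to every pair of dimensions, with a different frequency per pair. Treating the pair as a complex number, rotating by position is a multiplication by `exp(i * position * FREQ)`, and the resulting query-key score depends only on the distance between the two positions.

```python
import cmath

FREQ = 0.5  # assumed rotation frequency for this single 2-D pair

def rotate(vec: complex, position: int) -> complex:
    """Rotate a 2-D vector (stored as a complex number) by position * FREQ radians."""
    return vec * cmath.exp(1j * position * FREQ)

def score(q: complex, k: complex) -> float:
    """Dot product of two 2-D vectors stored as complex numbers."""
    return (q * k.conjugate()).real

q, k = complex(1.0, 2.0), complex(0.5, -1.0)  # made-up query and key

near = score(rotate(q, 3), rotate(k, 1))      # positions 3 and 1, offset 2
far = score(rotate(q, 103), rotate(k, 101))   # positions 103 and 101, offset 2
other = score(rotate(q, 3), rotate(k, 2))     # offset 1 gives a different score

print(abs(near - far) < 1e-9, abs(near - other) > 1e-3)
```

Shifting both tokens by 100 positions leaves the score unchanged, while changing the gap between them does not.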

In [7]:
def apply_rotary_emb(x: torch.Tensor, dim: int, theta: float = 10000.0) -> torch.Tensor:
    """Applies the rotary embedding to the query and key tensors dynamically."""
    # Get sequence length and ensure compatibility with both batched and single sequence inputs
    seq_len = x.size(1)
    device = x.device
    #print("x: ", x.shape)
    
    # Dynamically compute frequency cis based on the input sequence length
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))
    #print("freqs: ", freqs.shape)
    t = torch.arange(seq_len, device=device)
    freqs = torch.outer(t, freqs).float()
    #print("freqs: ", freqs.shape)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    #print("freqs_cis: ", freqs_cis.shape)

    # Apply rotary embeddings
    x_ = torch.view_as_complex(torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1))
    #print("x_: ", x_.shape)
    #print("freqs_cis: ", freqs_cis.unsqueeze(0).shape)
    x_out = torch.view_as_real(x_ * freqs_cis.unsqueeze(0)).type_as(x)  # Ensure batch dimension is handled
    x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
    x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2], -1).transpose(1, 2)

    return x_out

# Normalization

Gemma uses [RMSnorm](). Oddly in Gemma they normalize both the input and the output of each transformer sub-layer, whereas usually you would only normalize in one or the other.
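
As a quick worked example of what RMSNorm computes, the plain-Python sketch below divides one made-up 4-dimensional activation vector by its root mean square and then applies the learned scale. The function name `rms_norm` and the example values are assumptions for illustration. The `(1 + weight)` scaling follows the unit-offset convention of the `RMSNorm` class in this section, so a zero-initialized weight leaves the normalized vector unchanged.

```python
import math

def rms_norm(x, weight, eps=1e-6):
    """Divide x by its root mean square, then scale each feature by (1 + weight)."""
    mean_square = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_square + eps)
    return [v * inv_rms * (1.0 + w) for v, w in zip(x, weight)]

x = [1.0, -2.0, 3.0, -4.0]       # made-up activations
weight = [0.0, 0.0, 0.0, 0.0]    # zero-initialized, so the scale starts at 1

y = rms_norm(x, weight)
rms_after = math.sqrt(sum(v * v for v in y) / len(y))
print([round(v, 3) for v in y], round(rms_after, 3))
```

Unlike standard layer normalization there is no mean subtraction: the signs and relative sizes of the features are preserved and only the overall magnitude is rescaled to roughly 1.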

In [8]:
class RMSNorm(torch.nn.Module):
    """
    Implements the RMS Normalization (Root Mean Square Normalization) layer.
    RMSNorm is a variant of layer normalization that normalizes the activations
    of the previous layer based on their root mean square value.

    Parameters:
    - dim (int): The dimension of the input features the normalization is applied to.
    - eps (float): A small value added to the denominator for numerical stability. Default is 1e-6.
    - add_unit_offset (bool): If True, adds a unit (1) to the learned scaling coefficient, effectively
      starting with no scaling. If False, the scaling coefficient starts from zero. Default is True.
    """

    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
        add_unit_offset: bool = True,
    ):
        super().__init__()  # Initialize the base class (nn.Module)
        self.eps = eps  # Small epsilon value for numerical stability
        self.add_unit_offset = add_unit_offset  # Flag to determine if a unit should be added to the weight
        # Initialize the weight parameter with zeros, which will be learned during training.
        # The shape of the weight is [dim], one weight per feature dimension.
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        """
        Private helper function to normalize the input tensor.

        Parameters:
        - x (Tensor): The input tensor to normalize.

        Returns:
        - Tensor: The normalized tensor.
        """
        # Calculate the root mean square value for each feature (across the last dimension),
        # then use reciprocal square root (rsqrt) for normalization.
        # Add self.eps to the denominator for numerical stability.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass of the RMSNorm layer.

        Parameters:
        - x (Tensor): The input tensor to normalize.

        Returns:
        - output: The normalized and scaled tensor.
        """
        # Normalize the input tensor using the _norm function and ensure the data type matches the input.
        x = self._norm(x.float()).type_as(x)
        # If add_unit_offset is True, scale the normalized tensor by (1 + self.weight),
        # effectively starting with a scaling factor of 1 (no scaling).
        # Otherwise, scale by self.weight only.
        if self.add_unit_offset:
            output = x * (1 + self.weight)
        else:
            output = x * self.weight
        # Return the scaled output tensor.
        return output

# Multi-Layer Perceptron

The interesting thing here is the GeGLU activation function as opposed to ReLU, GeLU or SwiGLU. 
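
In a GeGLU layer the output is `GeLU(gate) * up`, where `gate` and `up` are two separate linear projections of the same input. SwiGLU has the same structure but swaps GeLU for SiLU (swish), while a plain ReLU or GeLU MLP has no multiplicative gate at all. The plain-Python sketch below applies the gating element-wise to made-up pre-activation values (the linear projections are left out), using the exact erf-based GeLU.

```python
import math

def gelu(x: float) -> float:
    """Exact GeLU: x times the standard normal CDF evaluated at x."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def geglu(gate, up):
    """Element-wise GeGLU: GeLU of the gate branch multiplied by the up branch."""
    return [gelu(g) * u for g, u in zip(gate, up)]

# made-up outputs of the two projections for one token
gate = [-3.0, 0.0, 3.0]
up = [2.0, 2.0, 2.0]

print([round(v, 3) for v in geglu(gate, up)])
```

A strongly negative gate shuts its feature almost completely, a zero gate blocks it entirely, and a strongly positive gate passes the `up` value through scaled by roughly the gate itself.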

In [9]:
class GemmaMLP(nn.Module):
    """
    This class implements a multi-layer perceptron with a GeGLU gating mechanism. The GeGLU
    activation combines a standard GeLU activation with a learned gating mechanism, enabling
    the network to control the flow of information more dynamically.

    Attributes:
        gate_proj (nn.Linear): A linear layer that transforms the input tensor to an intermediate
                               representation, which is then passed through a GeLU activation for
                               gating purposes.
        up_proj (nn.Linear): Another linear layer that transforms the input tensor to the same
                             intermediate representation but without gating. This representation
                             is element-wise multiplied by the gated tensor from `gate_proj`.
        down_proj (nn.Linear): A linear layer that transforms the gated and combined tensor back
                               to the original dimension of the hidden size, producing the final output.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
    ):
        """
        Initializes the GemmaMLP module.

        Parameters:
            hidden_size (int): The size of the input and output tensors.
            intermediate_size (int): The size of the tensor after the initial transformation
                                     and before the gating and final projection. This is typically
                                     larger than the hidden size to allow for a richer representation.
        """
        super().__init__()

        # Linear transformation for the gating mechanism, projecting input to an intermediate size.
        self.gate_proj = nn.Linear(hidden_size, intermediate_size)

        # Linear transformation for the input tensor, also projecting to the intermediate size but
        # intended for element-wise multiplication with the gated output.
        self.up_proj = nn.Linear(hidden_size, intermediate_size)

        # Linear transformation to project the gated and combined tensor back to the original
        # hidden size, completing the MLP structure.
        self.down_proj = nn.Linear(intermediate_size, hidden_size)

    def forward(self, x):
        """
        Defines the forward pass of the GemmaMLP module.

        Parameters:
            x (Tensor): The input tensor to the MLP.

        Returns:
            Tensor: The output tensor after applying the GeGLU gating mechanism and the MLP transformations.
        """
        # Applies linear transformation for gating.
        gate = self.gate_proj(x)

        # Applies GeLU activation to the gate, introducing non-linearity and enabling the gating mechanism.
        gate = F.gelu(gate)

        # Applies another linear transformation to the input tensor for subsequent combination with the gate.
        up = self.up_proj(x)

        # Element-wise multiplication of the gated tensor with the transformed input tensor, modulating
        # the input based on the gate's activation.
        fuse = gate * up

        # Applies the final linear transformation to project the modulated tensor back to the hidden size.
        outputs = self.down_proj(fuse)

        # Returns the final output tensor of the MLP, after gating and modulation.
        return outputs

# Attention

In [10]:
class GemmaAttention(nn.Module):
    """
    Implements Multi-Query Attention which supports a distinct number of attention heads for queries and key-values (KV).
    In the case where the same number of query heads and key-value heads are used, this implementation is equivalent to regular Multi-Head Attention.
    """
    
    def __init__(
        self,
        config: GemmaConfig
        #hidden_size: int, # The dimensionality of input and output tensors.
        #num_heads: int, # The number of query attention heads
        #num_kv_heads: int, # The number of key-value heads, which can differ from the query heads
        #head_dim: int, # The dimensionality of each attention head
        #theta: float,
        #max_seq_len: int
    ):
        super().__init__()

        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        
        # Ensures that the number of query heads is evenly divisible by the number of KV heads.
        assert self.num_heads % self.num_kv_heads == 0
        # Determines the number of query heads associated with each KV head.
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        self.hidden_size = config.hidden_size
        self.head_dim = config.head_dim
        #self.max_seq_len = max_seq_len
        self.theta = config.rope_theta

        # Calculates the total size for all query projections.
        self.q_size = self.num_heads * self.head_dim
        # Calculates the total size for all key and value projections.
        self.kv_size = self.num_kv_heads * self.head_dim

        # Defines the scaling factor for the attention scores.
        self.scaling = self.head_dim**-0.5

        # Initializes the linear projection layer for queries, keys, and values
        #self.qkv_proj = Linear(self.hidden_size, (self.num_heads + 2 * self.num_kv_heads) * self.head_dim)
        self.qkv_proj = nn.Linear(self.hidden_size, (self.num_heads + 2 * self.num_kv_heads) * self.head_dim, bias=False)
        # Initializes the output projection layer, mapping the concatenated attention outputs back to the hidden size.
        #self.o_proj = Linear(self.num_heads * self.head_dim, self.hidden_size)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)


        # we'll use very large negative values to prevent attending to certain tokens
        mask_negatives = torch.full((1, 1, config.max_position_embeddings, config.max_position_embeddings),
                                 -2.3819763e38).to(torch.float)
        # then we'll replace the lower triangular ones with 0's to allow attention to see past tokens
        mask = torch.triu(mask_negatives, diagonal=1).to(config.device)
        # then we'll define self.mask as a tensor that shouldn't undergo gradient descent
        self.register_buffer('mask', mask)

    def forward(
        self,
        hidden_states: torch.Tensor, # The input tensor to the attention mechanism. shape (batch_size, input_len, hidden_size)
        #freqs_cis: torch.Tensor, # The frequencies for the rotary position embeddings
        #kv_write_indices: torch.Tensor, # Indices specifying where to write in the KV cache
        #kv_cache: Tuple[torch.Tensor, torch.Tensor], # A tuple of tensors holding the cached keys and values
        #mask: torch.Tensor, # The attention mask tensor, used to exclude future positions from attention calculations
    ) -> torch.Tensor:

        # Ensures the input tensor is 3-dimensional (batch_size, input_len, hidden_size).
        hidden_states_shape = hidden_states.shape
        assert len(hidden_states_shape) == 3

        # Extracts batch size and input sequence length from the hidden states tensor.
        batch_size, input_len, _ = hidden_states_shape

        # want to make sure they didn't input a sequence that's longer than the length we trained on
        assert input_len <= self.mask.shape[-1] # the mask was built with max_position_embeddings entries per side

        # Applies the linear projection to the hidden states to obtain the combined query, key, and value tensor
        qkv = self.qkv_proj(hidden_states)
        # Splits the combined QKV tensor into separate tensors for queries (xq), keys (xk), and values (xv) based on their respective sizes.
        xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size],
                               dim=-1)
        # for readability's sake it would've made more sense to do these separately; this is just more efficient

        # Reshapes each of the Q, K, and V tensors to separate the heads and align the dimensions for attention operations.
        xq = xq.view(batch_size, -1, self.num_heads, self.head_dim)
        xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
        xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)

        # Applies rotary positional embeddings to queries and keys to incorporate positional information.
        xq = apply_rotary_emb(xq, self.head_dim, self.theta)#, freqs_cis=freqs_cis)
        xk = apply_rotary_emb(xk, self.head_dim, self.theta)#, freqs_cis=freqs_cis)

        # Updates the key-value cache with the new keys and values for this forward pass, using provided indices.
        # [batch_size, input_len, n_local_kv_heads, head_dim]
        #k_cache, v_cache = kv_cache
        #k_cache.index_copy_(1, kv_write_indices, xk)
        #v_cache.index_copy_(1, kv_write_indices, xv)

        #key = k_cache
        #value = v_cache

        # If the number of KV heads is different from the number of query heads, repeats the keys and values so every query head has a matching KV head.
        # (previously the repeated tensors were assigned to unused variables, so the code only worked through broadcasting when num_kv_heads == 1)
        if self.num_kv_heads != self.num_heads:
            # [batch_size, input_len, n_local_heads, head_dim]
            xk = torch.repeat_interleave(xk, self.num_queries_per_kv, dim=2)
            xv = torch.repeat_interleave(xv, self.num_queries_per_kv, dim=2)

        # Transposes Q, K, and V tensors to align them for the batch matrix multiplication in attention calculation.
        # [batch_size, n_local_heads, input_len, head_dim]
        q = xq.transpose(1, 2)
        k = xk.transpose(1, 2)
        v = xv.transpose(1, 2)

        # Calculates attention scores by performing a batch matrix multiplication between queries and keys, followed by scaling.
        # [batch_size, n_local_heads, input_len, input_len]
        scores = torch.matmul(q, k.transpose(2, 3)) * self.scaling
        
        # Applies the lower-triangular mask to the attention scores
        #print("mask: ", self.mask.shape)
        #print("input_len: ", input_len)
        #print("scores: ", scores.shape)
        scores = scores + self.mask[...,:input_len, :input_len] # make sure mask is the correct size. input_len <= max_seq_len
        # masking by adding 0's & big negatives (instead of multiplying by 1's & 0's) is the standard approach:
        # softmax turns a big negative score into a probability of ~0, whereas a score multiplied to 0 would still contribute exp(0) = 1

        # Applies softmax to the scores to obtain attention probabilities, 
        # and ensuring numerical stability by casting to float before softmax and back to original dtype after.
        scores = F.softmax(scores.float(), dim=-1).type_as(q) # i think i can remove the typing stuff since i'm not doing quantization

        # Computes the weighted sum of values based on the attention scores to obtain the output of the attention mechanism.
        # [batch_size, n_local_heads, input_len, head_dim]
        output = torch.matmul(scores, v)

        # Reshapes the attention output to match the expected output dimensions, combining the heads back into the hidden dimension.
        # [batch_size, input_len, hidden_dim]
        output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1)

        # Applies the final linear projection to the attention output, mapping it back to the hidden size dimension.
        output = self.o_proj(output)

        return output
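
The additive mask in `forward` relies on how softmax treats very negative scores. Here is a plain-Python sketch with made-up scores for a single query position, where the third key is a "future" token that should be hidden; the `softmax` helper is written out by hand for illustration.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

NEG = -2.3819763e38  # same large negative constant used to build the mask

scores = [2.0, 1.0, 3.0]          # made-up scores; the third position is in the future
additive = [0.0, 0.0, NEG]        # 0 keeps a position, NEG hides it
multiplicative = [1.0, 1.0, 0.0]  # the alternative: zero out the hidden score

added = softmax([s + m for s, m in zip(scores, additive)])
multiplied = softmax([s * m for s, m in zip(scores, multiplicative)])

print([round(p, 3) for p in added])       # future position gets zero weight
print([round(p, 3) for p in multiplied])  # future position still gets some weight
```

Multiplying the raw score by 0 does not remove the position, because a score of 0 is still a valid input to softmax; a multiplicative mask would have to be applied after the softmax and followed by renormalization.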

# Layer

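Interestingly, here we normalize twice per layer: once on the input to the attention mechanism and again between the attention mechanism and the MLP.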

In [11]:
class GemmaDecoderLayer(nn.Module):
    """
    A decoder layer that integrates the GemmaAttention mechanism and multi-layer perceptron (MLP). It includes
    normalization steps both before and after the attention mechanism to stabilize and accelerate training.
    """

    def __init__(
        self,
        config: GemmaConfig,
    ):
        super().__init__()

        # Initializes the GemmaAttention mechanism with parameters from the config, enabling self-attention within the decoder layer.
        self.self_attn = GemmaAttention(config
            #hidden_size = config.hidden_size,
            #num_heads = config.num_attention_heads,
            #num_kv_heads = config.num_key_value_heads,
            #head_dim = config.head_dim,
            #theta = config.rope_theta,
            #max_seq_len = config.max_position_embeddings,
        )
        # Initializes the GemmaMLP module, providing a non-linear transformation after the attention mechanism.
        self.mlp = GemmaMLP(
            hidden_size = config.hidden_size,
            intermediate_size = config.intermediate_size,
        )
        # Applies RMSNorm normalization to the input of the decoder layer for stable training dynamics.
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps = config.rms_norm_eps)
        # Applies RMSNorm after the attention mechanism and before the MLP to ensure the output is well-conditioned for further processing.
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps = config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor, # The input tensor to the decoder layer. shape (batch_size, input_len, hidden_size)
        #freqs_cis: torch.Tensor, # The frequencies for the rotary position embeddings, enhancing the model's understanding of sequence order.
        #kv_write_indices: torch.Tensor, # Indices specifying where to write in the key-value cache, facilitating efficient memory use.
        #kv_cache: Tuple[torch.Tensor, torch.Tensor], # A cache for keys and values to allow reusing attention computations across decoding steps.
        #mask: torch.Tensor, # The attention mask tensor, used to exclude future positions from attention calculations
    ) -> torch.Tensor:
        
        # Self Attention Block
        # Stores the original input for use as a residual connection, aiding in mitigating the vanishing gradient problem
        residual = hidden_states
        # Normalizes the input before processing by the attention mechanism.
        hidden_states = self.input_layernorm(hidden_states)
        # Processes the normalized input through the GemmaAttention mechanism
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            #freqs_cis=freqs_cis,
            #kv_write_indices=kv_write_indices,
            #kv_cache=kv_cache,
            #mask=mask,
        )
        # The aforementioned residual connection
        hidden_states = residual + hidden_states

        # MLP Block
        # Again, stores the output of the attention block for use as a residual connection before processing by the MLP.
        residual = hidden_states
        # Normalizes the output of the attention block before passing it to the MLP, ensuring a stable input distribution.
        hidden_states = self.post_attention_layernorm(hidden_states)
        # Transforms the normalized attention output through the MLP, introducing additional non-linearity and capacity to the model.
        hidden_states = self.mlp(hidden_states)
        # Another residual connection
        hidden_states = residual + hidden_states

        return hidden_states
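
The data flow in `forward` is the same pattern applied twice: normalize, apply a sub-layer, then add the result back onto the un-normalized input. Here is a plain-Python sketch of that pattern on a single vector. `rms_norm` is a simplified stand-in without a learned weight, and `toy_sublayer` is a hypothetical placeholder for attention or the MLP; neither is the notebook's implementation.

```python
import math

def rms_norm(x, eps=1e-6):
    """Simplified RMSNorm without a learned scale: x / sqrt(mean(x^2) + eps)."""
    mean_sq = sum(v * v for v in x) / len(x)
    return [v / math.sqrt(mean_sq + eps) for v in x]

def toy_sublayer(x):
    """Hypothetical stand-in for attention or the MLP (just halves the input)."""
    return [0.5 * v for v in x]

def pre_norm_residual(x, sublayer):
    """residual + sublayer(norm(x)): the sub-layer sees normalized input,
    but the residual stream itself is passed through un-normalized."""
    normed = rms_norm(x)
    return [r + s for r, s in zip(x, sublayer(normed))]

x = [3.0, -4.0, 0.0, 5.0]
after_attention = pre_norm_residual(x, toy_sublayer)          # attention block
after_mlp = pre_norm_residual(after_attention, toy_sublayer)  # MLP block
print(after_mlp)
```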

# The Body of the Model

In [12]:
class GemmaModel(nn.Module):
    """ the class that loops through each layer of Gemma """

    def __init__(self, config: GemmaConfig):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        # Initialize a sequence of GemmaDecoderLayer instances as specified by the number of hidden layers in the configuration
        #self.layers = nn.ModuleList()
        #for _ in range(config.num_hidden_layers):
            #self.layers.append(GemmaDecoderLayer(config))
        self.layers = nn.ModuleList(GemmaDecoderLayer(config) for _ in range(config.num_hidden_layers))

        # Initialize a normalization layer to be applied after the last decoder layer, stabilizing the output
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor, # The first residual state of the model. shape (batch_size, input_len, hidden_size)
        #freqs_cis: torch.Tensor, # The frequencies for the rotary position embeddings, enhancing the model's understanding of sequence order.
        #kv_write_indices: torch.Tensor, # Indices specifying where to write in the key-value cache, facilitating efficient memory use.
        #kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], # A cache for keys and values to allow reusing attention computations across decoding steps
        #mask: torch.Tensor, # The attention mask tensor, used to exclude future positions from attention calculations
    ) -> torch.Tensor:

        # Iteratively process the input through each GemmaDecoderLayer, passing along necessary parameters for attention.
        # The hidden states are updated at each layer, progressively incorporating more complex dependencies and transformations.
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states = layer(
                hidden_states=hidden_states,
                #freqs_cis=freqs_cis,
                #kv_write_indices=kv_write_indices,
                #kv_cache=kv_caches[i],
                #mask=mask,
            )
        
        # Apply normalization to the output of the final decoder layer, ensuring the model's output is well-conditioned for subsequent use.
        hidden_states = self.norm(hidden_states)
        
        return hidden_states

# The Model itself, and its generate() and load_weights() functions

In [51]:
class GemmaForCausalLM(nn.Module):

    def __init__(
        self,
        config: GemmaConfig,
        tokenizer: tokenizer,
    ):
        super().__init__()
        self.config = config

        # the attention heads need to cleanly divide up the hidden_size of the model so that we can split it all apart & combine back together
        assert config.hidden_size % config.num_attention_heads == 0

        self.max_seq_len = config.max_position_embeddings
        self.head_dim = config.head_dim
        self.vocab_size = config.vocab_size

        # uses whichever tokenizer was previously specified in the "Config" section
        self.tokenizer = tokenizer
        #self.tokenizer = tokenizer.Tokenizer(config.tokenizer)
        
        #self.embedder = Embedding(self.vocab_size, config.hidden_size) 
        self.embedder = nn.Embedding(self.vocab_size, config.hidden_size) # the embedding matrix
        self.model = GemmaModel(config) # the body of the model; all the transformer decoder layers
        #self.sampler = Sampler(self.vocab_size) # the function that probabilistically chooses the next token

        # Pre-compute rotary embedding table.
        #freqs_cis = precompute_freqs_cis(self.head_dim,
        #                                 self.max_seq_len * 2,
        #                                 theta=config.rope_theta)
        # defines `self.freqs_cis` using `register_buffer()` which prevents the tensor from being trainable
        #self.register_buffer('freqs_cis', freqs_cis)

        # the loss function
        self.criterion = nn.CrossEntropyLoss()

    #@torch.no_grad()
    def forward( ####### need to add targets and make some of these optional
        self,
        input_token_ids: torch.Tensor,
        target_token_ids: torch.Tensor = None,
        #input_positions: torch.Tensor,
        #kv_write_indices: torch.Tensor,
        #kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        #mask: torch.Tensor,
        #output_positions: torch.Tensor,
        #temperatures: torch.Tensor,
        #top_ps: torch.Tensor,
        #top_ks: torch.Tensor,
        #**kwargs,
    ) -> torch.Tensor:
        #freqs_cis = self.freqs_cis.index_select(0, input_positions)
        #kv_write_indices = input_positions

        # turn the input tokens into the first residual state using the embedding matrix
        # (batch_size, input_len) & (vocab_size, hidden_size) -> (batch_size, input_len, hidden_size)
        hidden_states = self.embedder(input_token_ids)
        
        # Gemma scales the embedding by sqrt(hidden_size).
        hidden_states = hidden_states * (self.config.hidden_size**0.5)

        # this is where the bulk of the calculations are performed, the actual decoder layers
        hidden_states = self.model(
            hidden_states=hidden_states,
            #freqs_cis=freqs_cis,
            #kv_write_indices=kv_write_indices,
            #kv_caches=kv_caches,
            #mask=mask,
        ) # -> (batch_size, input_len, hidden_size)

        # grabbing the weights of the embedding matrix shape (vocab_size, hidden_dim) for use as the output layer
        embedder_weight = self.embedder.weight
        #if self.config.quant:
            #embedder_weight = (
                #embedder_weight * self.embedder.weight_scaler.unsqueeze(-1))

        # the embedding matrix is also used as the output layer. this saves on parameters compared to using a separate matrix
        # (batch_size, input_len, hidden_size) @ (hidden_size, vocab_size) -> (batch_size, input_len, vocab_size)
        logits = torch.matmul(hidden_states, embedder_weight.t())

        # there's an optional embedding bias used by tokens that are always/generally more/less likely
        #if embedding_bias is not None: # i think this might've gotten passed in with **kwargs
            # (batch_size, input_len, vocab_size) + (vocab_size) -> (batch_size, input_len, vocab_size)
            #logits += embedding_bias

        #if temperatures is None:
            # selects the token with the largest logit. No need to do softmax since there's no temperature
            #return torch.argmax(logits, dim=-1).squeeze(dim=-1)
        
        if target_token_ids is None: # if we're not training
            loss = None
        else:
            batch_size, input_len, vocab_size = logits.shape
            loss = self.criterion(logits.view(batch_size*input_len, vocab_size), target_token_ids.view(batch_size*input_len))
        
        return logits, loss

    @torch.no_grad()
    def Sampler(
        self,
        #embedding: torch.Tensor, # shape (vocab_size, hidden_size)
        #hidden_states: torch.Tensor, # shape (batch_size, input_len, hidden_size)
        logits: torch.Tensor,
        #output_positions: torch.Tensor, # shape (batch_size) list of integer indices. I think they should be all -1's for NTP
        #temperatures: torch.Tensor, # shape (batch_size) list of float temperatures
        temperature: float,
        #top_ps: torch.Tensor, # shape (batch_size) list of float percentages to denote out top_p criteria
        top_p: float,
        #top_ks: torch.Tensor, # shape (batch_size) list of integers to denote our top_k criteria
        top_k: int,
        #embedding_bias: Optional[torch.Tensor] = None, # shape (vocab_size) meant to bias output towards certain tokens
    ) -> torch.Tensor:
        """
        The Sampler function is responsible for generating token predictions from Gemma's output.
        It supports temperature scaling, top-p (nucleus) sampling, and top-k sampling to control the diversity and quality of the generated text. 
        The class operates as follows:
    
        1. Selects the last hidden state for each sequence in the batch
    
        2. Computes logits by multiplying the selected hidden states with the transposed embedding matrix. 
        An optional embedding bias can be added to these logits to bias the predictions towards certain tokens.
    
        3. If no temperature is provided, the function returns the indices of the tokens with the highest logits, aka greedy decoding.
    
        4. If a temperature is provided, it is used to scale the logits, making the distribution over tokens sharper (lower temperature) 
        or flatter (higher temperature), which affects the randomness of the sampling.
    
        5. The softmax function is applied to the scaled logits to obtain a probability distribution over the vocabulary.
    
        6. For top-p sampling, the function computes the cumulative sum of the sorted probabilities and masks out tokens until the 
        cumulative probability exceeds the threshold defined by `top_ps`. This allows the model to focus on a subset of the most 
        probable tokens while ignoring the long tail of less likely tokens.
    
        7. For top-k sampling, the function masks out all tokens except the `k` most likely ones, as specified by `top_ks`. 
        This ensures that the model only considers a fixed number of the most probable tokens for the next token prediction.
    
        8. After applying both the top-p and top-k masks, the probabilities are re-normalized so that they sum up to 1
    
        9. The function then samples from the re-normalized probability distribution to select the next token. 
        This step introduces randomness into the text generation process, allowing the model to generate diverse and coherent text.
        """
        ### currently this setup still *kind of* allows for batch size. I think i can not change this and still
        # have it work but for clarity i prolly should change it

        # Select the last element for each sequence.
        # (batch_size, input_len, hidden_size) -> (batch_size, hidden_size)
        # the "1" selects along the 'input_len' dimension & `output_positions` is a list of indices
        #hidden_states = hidden_states.index_select(1, output_positions).squeeze(dim=1) # squeeze removes the input_len dimension
        # (batch_size, input_len, vocab_size) -> (batch_size, vocab_size)
        logits = logits[:,-1,:] # no squeeze needed; indexing with -1 already drops the input_len dimension
        
        # Apply temperature scaling
        # (batch_size, vocab_size) / float -> (batch_size, vocab_size)
        logits.div_(temperature)

        # Calculate probabilities with softmax.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float) # dim=-1 is the vocab_size dimension

        # sort the probabilities for use in top-p & top-k
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) # both (batch_size, vocab_size)

        # calculating top-p
        # creates same-size tensor of cumulative probabilities instead of individual probs
        probs_sum = torch.cumsum(probs_sort, dim=-1) 
        # mask where 0's are top-p selections & 1's are to be excluded
        top_ps_mask = (probs_sum - probs_sort) > top_p#s.unsqueeze(dim=1) 
        # the original probabilities with excluded tokens changed to 0.0
        probs_sort = torch.where(top_ps_mask, 0, probs_sort) 

        # calculating top_k
        # create a shape (vocab_size) tensor that just iterates up by 1's
        top_ks_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device) 
        # expand our mask along the batch_size dimension to become size (batch_size, vocab_size)
        # "expand" means copy the original into this new size, so each length vocab_size row is the same
        top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
        # top_k is a single integer. entries whose rank (0 = most likely) is >= top_k are marked True, meaning they'll be excluded
        top_ks_mask = top_ks_mask >= top_k#s.unsqueeze(dim=1)

        # we'll be combining top-p with top-k and using whichever gives us fewer tokens. a very conservative approach
        # further trim probs_sort to fit within our top_k requirement
        probs_sort = torch.where(top_ks_mask, 0, probs_sort)

        # Re-normalization so that total probabilities add up to 1
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        # used to rearrange the modified probabilities in probs_sort back to their original order according to probs_idx
        probs = torch.gather(probs_sort,
                             dim=-1,
                             index=torch.argsort(probs_idx, dim=-1))
        # samples from the distribution
        next_token_id = torch.multinomial(probs,
                                           num_samples=1)#,
                                           #replacement=True).squeeze(dim=-1)
        
        return next_token_id # returns the predicted tokens
        
    def generate(
        self,
        #prompts: Union[str, Sequence[str]],
        prompt: str,
        #device: Any,
        output_len: int = 100, # the model will output 100 tokens
        temperature: float = 0.95, # 0.95 is pretty close to not even using temperature at all (1.0 would be no effect)
        top_p: float = 1.0, # defaulting to 1 means we essentially don't use top-p
        top_k: int = 100, # this value for top-k is literally larger than our character-wise vocabulary bc default Gemma has such a huge vocab
    ) -> str: #Union[str, Sequence[str]]:
        """Generates responses for given prompts using Gemma model."""
        # encoding the prompt into token indices
        prompt_tokens = self.tokenizer.encode(prompt)
        # shape (1, prompt_len)
        prompt_tokens = torch.tensor(prompt_tokens, device=self.config.device).unsqueeze(0)
        
        # the prompt itself has to fit within the maximum context length we trained on
        # the prompt length is along dim 1, since dim 0 is the batch dimension
        assert prompt_tokens.shape[1] <= self.config.max_position_embeddings

        tokens = prompt_tokens
        for i in range(output_len):
            # if the sequence has grown past the context length, only feed in the most recent max_seq_len tokens
            logits, _ = self(tokens[:,-self.max_seq_len:])
            
            next_token = self.Sampler(
                #embedding=embedder_weight,
                #hidden_states=hidden_states,
                logits = logits,
                #output_positions=output_positions,
                #temperatures=temperatures,
                temperature = temperature,
                #top_ps=top_ps,
                top_p = top_p,
                #top_ks=top_ks,
                top_k = top_k
            )

            tokens = torch.cat((tokens, next_token), dim=1)

        output = self.tokenizer.decode(tokens.squeeze(0).tolist())

        return output

    def load_weights(self, model_path: str):
        self.load_state_dict(
            torch.load(model_path, mmap=True)['model_state_dict'],
            strict=False,
        )



    
        ##### this stuff is originally from the batched inference #####
        
        # When a single prompt is provided, treat it as a batch of 1.
        #is_str_prompt = isinstance(prompts, str)
        #if is_str_prompt:
            #prompts = [prompts]
        #batch_size = len(prompts)

        # encoding the batch of prompts
        #prompt_tokens = [self.tokenizer.encode(prompt) for prompt in prompts]
        #min_prompt_len = min(len(p) for p in prompt_tokens)
        #max_prompt_len = max(len(p) for p in prompt_tokens)
        #max_seq_len = max_prompt_len + output_len
    
        # build KV caches
        #kv_caches = []
        #for _ in range(self.config.num_hidden_layers):
            #size = (batch_size, max_seq_len, self.config.num_key_value_heads, self.config.head_dim)
            #dtype = self.config.get_dtype()
            #k_cache = torch.zeros(size=size, dtype=dtype, device=device)
            #v_cache = torch.zeros(size=size, dtype=dtype, device=device)
            #kv_caches.append((k_cache, v_cache))

        # prepare inputs
        #token_ids_tensor = torch.full((batch_size, max_seq_len),
        #                              self.tokenizer.pad_id, dtype=torch.int64)
        #input_token_ids_tensor = torch.full((batch_size, min_prompt_len),
        #                                    self.tokenizer.pad_id,
        #                                    dtype=torch.int64)
        #for i, p in enumerate(prompt_tokens):
        #    token_ids_tensor[i, :len(p)] = torch.tensor(p)
        #    input_token_ids_tensor[i, :min_prompt_len] = torch.tensor(
        #        p[:min_prompt_len])
        #token_ids_tensor = token_ids_tensor.to(device)
        #input_token_ids_tensor = input_token_ids_tensor.to(device)
        #prompt_mask_tensor = token_ids_tensor != self.tokenizer.pad_id
        #input_positions_tensor = torch.arange(0, min_prompt_len,
        #                                      dtype=torch.int64).to(device)
        #mask_tensor = torch.full((1, 1, max_seq_len, max_seq_len),
        #                         -2.3819763e38).to(torch.float)
        #mask_tensor = torch.triu(mask_tensor, diagonal=1).to(device)
        #curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
        #output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(
        #    device)
        #temperatures_tensor = torch.FloatTensor([temperature] * batch_size).to(
        #    device)
        #top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
        #top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
        #output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to(
        #    device)

        # Prefill up to min_prompt_len tokens, then treat other prefill as
        # decode and ignore output.
        #for i in range(max_seq_len - min_prompt_len):
        #    next_token_ids = self(
        #        input_token_ids=input_token_ids_tensor,
        #        input_positions=input_positions_tensor,
        #        #kv_write_indices=None,
        #        #kv_caches=kv_caches,
        #        mask=curr_mask_tensor,
        #        output_positions=output_positions_tensor,
        #        temperatures=temperatures_tensor,
        #        top_ps=top_ps_tensor,
        #        top_ks=top_ks_tensor,
        #    )

        #    curr_prompt_mask = prompt_mask_tensor.index_select(
        #        1, output_index).squeeze(dim=1)
        #    curr_token_ids = token_ids_tensor.index_select(
        #        1, output_index).squeeze(dim=1)
        #    output_token_ids = torch.where(curr_prompt_mask, curr_token_ids,
        #                                   next_token_ids).unsqueeze(dim=1)
        #    token_ids_tensor.index_copy_(1, output_index, output_token_ids)

        #    input_token_ids_tensor = output_token_ids
        #    input_positions_tensor = output_index.unsqueeze(dim=-1)
        #    curr_mask_tensor = mask_tensor.index_select(2,
        #                                                input_positions_tensor)
        #    output_positions_tensor = torch.tensor(0, dtype=torch.int64).to(
        #        device)
        #    output_index = output_index + 1

        # Detokenization.
        #token_ids = token_ids_tensor.tolist()
        #results = []
        #for i, tokens in enumerate(token_ids):
        #    trimmed_output = tokens[len(prompt_tokens[i]):len(prompt_tokens[i])
        #                            + output_len]
        #    if self.tokenizer.eos_id in trimmed_output:
        #        eos_index = trimmed_output.index(self.tokenizer.eos_id)
        #        trimmed_output = trimmed_output[:eos_index]
        #    results.append(self.tokenizer.decode(trimmed_output))

        # If a string was provided as input, return a string as output.
        #return results[0] if is_str_prompt else results
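
To make the top-p and top-k logic in `Sampler` concrete, here is a plain-Python sketch on an already-sorted toy distribution. It mirrors the two masks used above (`(cumsum - prob) > top_p` and `rank >= top_k`) followed by re-normalization. The function name `filter_sorted_probs` and the example numbers are illustrative assumptions.

```python
def filter_sorted_probs(probs_sorted, top_p, top_k):
    """Zero out entries of a descending-sorted distribution that fall outside
    top-p or top-k, then re-normalize the survivors to sum to 1."""
    kept = []
    cumulative = 0.0
    for rank, p in enumerate(probs_sorted):
        # Probability mass of all tokens ranked before this one;
        # equivalent to (cumsum - p) in the tensor version.
        mass_before = cumulative
        cumulative += p
        outside_top_p = mass_before > top_p
        outside_top_k = rank >= top_k
        kept.append(0.0 if (outside_top_p or outside_top_k) else p)
    total = sum(kept)
    return [p / total for p in kept]

probs_sorted = [0.5, 0.25, 0.125, 0.125]

# top_p = 0.6 drops every token whose preceding mass already exceeds 0.6
nucleus = filter_sorted_probs(probs_sorted, top_p=0.6, top_k=100)
# top_k = 1 keeps only the single most likely token
greedy_like = filter_sorted_probs(probs_sorted, top_p=1.0, top_k=1)

print(nucleus)
print(greedy_like)
```

The most likely token always has zero mass before it, so the top-p test never removes it and the filtered distribution cannot be empty as long as `top_k >= 1`.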

# Training Functions

In [52]:
# Train and test splits
data = torch.tensor(tokenizer.encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest validation
train_data = data[:n]
val_data = data[n:]

In [53]:
# data loading for training
def get_batch(split, batch_size):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - config.max_position_embeddings, (batch_size,))
    x = torch.stack([data[i:i+config.max_position_embeddings] for i in ix])
    y = torch.stack([data[i+1:i+config.max_position_embeddings+1] for i in ix])
    x, y = x.to(config.device), y.to(config.device)
    return x, y

In [54]:
xb, yb = get_batch('train', 1)
print(tokenizer.decode(xb.squeeze(0).tolist()))
print("-------")
print(tokenizer.decode(yb.squeeze(0).tolist()))

 limits of yon lime and stone:
And with him are the Lord Aumerle, Lord Salisbury,
Sir Stephen Scroop, besides a clergyman
Of hol
-------
limits of yon lime and stone:
And with him are the Lord Aumerle, Lord Salisbury,
Sir Stephen Scroop, besides a clergyman
Of holy
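
The two printouts differ by a one-character shift: each target is the token that follows the corresponding input position. Below is a toy version of the slicing in `get_batch`, using a list of integers instead of the encoded text. `toy_data`, `block_size`, and `make_example` are illustrative assumptions.

```python
def make_example(data, start, block_size):
    """Return (x, y) where y is x shifted one position to the right."""
    x = data[start : start + block_size]
    y = data[start + 1 : start + block_size + 1]
    return x, y

toy_data = list(range(10, 20))  # stands in for the encoded token ids
block_size = 4                  # stands in for config.max_position_embeddings

x, y = make_example(toy_data, start=2, block_size=block_size)
print(x)  # [12, 13, 14, 15]
print(y)  # [13, 14, 15, 16]

# Every target is the token that comes right after its input position.
assert all(toy_data[2 + i + 1] == y[i] for i in range(block_size))
```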


In [55]:
@torch.no_grad()
def estimate_loss(model, batch_size, eval_iters = 10): # to estimate loss later during the training loop
    out = {}
    model.eval() # sets model to eval mode
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, batch_size)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train() # just resets to training mode
    return out
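
The loss averaged by `estimate_loss` is a mean cross-entropy over every position in the batch, computed after the logits are flattened to shape (batch_size*input_len, vocab_size). The plain-Python sketch below reproduces that calculation on flattened rows and checks it against a known case: uniform logits give a loss of ln(vocab_size). The vocabulary size of 65 is an assumption for illustration, not a value taken from this notebook.

```python
import math

def cross_entropy(logits_rows, targets):
    """Mean of -log softmax(logits)[target] over all rows.
    logits_rows: list of per-position logit lists, i.e. shape (N, vocab_size).
    targets: list of N integer token ids."""
    total = 0.0
    for row, target in zip(logits_rows, targets):
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[target]
    return total / len(targets)

vocab_size = 65  # assumed size of a character-level vocabulary
uniform_logits = [[0.0] * vocab_size for _ in range(8)]  # 8 flattened positions
targets = [3, 1, 4, 1, 5, 9, 2, 6]

loss = cross_entropy(uniform_logits, targets)
print(loss, math.log(vocab_size))  # both are ln(65), roughly 4.17
```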

# Training

In [67]:
model = GemmaForCausalLM(config, tokenizer).to(config.device)

# create a PyTorch optimizer
# this is not what they used, but this learning rate & weight decay work for our tiny minGemma
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

964.352 K parameters
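
The count above includes the embedding matrix only once, because the same matrix is reused as the output layer. The sketch below shows that tied computation in plain Python on a tiny made-up embedding table. `embedding`, `embed`, `tied_logits`, and the sizes are illustrative assumptions.

```python
# Tiny embedding table: vocab_size = 3 tokens, hidden_size = 2.
embedding = [
    [1.0, 0.0],   # token 0
    [0.0, 1.0],   # token 1
    [1.0, 1.0],   # token 2
]

def embed(token_id):
    """Input side: look up the row for a token id."""
    return embedding[token_id]

def tied_logits(hidden_state):
    """Output side: dot the final hidden state with every embedding row,
    i.e. hidden_state @ embedding.T, reusing the same table."""
    return [sum(h * e for h, e in zip(hidden_state, row)) for row in embedding]

print(embed(2))                   # [1.0, 1.0]

hidden_state = [2.0, -1.0]        # pretend output of the decoder stack
print(tied_logits(hidden_state))  # [2.0, -1.0, 1.0]

vocab_size, hidden_size = len(embedding), len(embedding[0])
tied_params = vocab_size * hidden_size        # one shared matrix
untied_params = 2 * vocab_size * hidden_size  # separate input and output matrices
print(tied_params, untied_params)  # 6 12
```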


In [68]:
# how long we want to train for
max_iters = 1000

# how often we want to check & see how our loss is doing
eval_interval = 100

# how many sequences to include in each batch
batch_size = 32

In [65]:
start_time = time.time()

# Enable anomaly detection. uncomment these lines if you need to do extensive debugging
#torch.autograd.set_detect_anomaly(True)

for iter in range(max_iters):

    # sample a batch of data
    xb, yb = get_batch('train', batch_size)
    
    # train
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    
    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        current_time = time.time()
        elapsed_time = current_time - start_time
        losses = estimate_loss(model, batch_size)
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, time elapsed: {elapsed_time:.2f} seconds")

# Disable anomaly detection after the training loop
#torch.autograd.set_detect_anomaly(False)

step 0: train loss 3.5073, val loss 3.5963, time elapsed: 0.20 seconds
step 10: train loss 3.4327, val loss 3.4751, time elapsed: 3.37 seconds
step 20: train loss 3.3197, val loss 3.4030, time elapsed: 6.50 seconds
step 30: train loss 3.2524, val loss 3.2527, time elapsed: 9.74 seconds
step 40: train loss 3.1818, val loss 3.2755, time elapsed: 12.93 seconds
step 50: train loss 3.1007, val loss 3.1933, time elapsed: 15.97 seconds
step 60: train loss 3.0814, val loss 3.1016, time elapsed: 19.17 seconds
step 70: train loss 3.0129, val loss 3.0688, time elapsed: 22.49 seconds
step 80: train loss 3.0285, val loss 3.0388, time elapsed: 25.59 seconds
step 90: train loss 2.9795, val loss 3.0136, time elapsed: 28.78 seconds
step 99: train loss 2.9170, val loss 2.9423, time elapsed: 31.75 seconds


# Inference

In [66]:
input_str = "JULIET:\nO Romeo, Romeo! wherefore art thou Rom" # the classic line
output = model.generate(input_str)
print(output)

JULIET:
O Romeo, Romeo! wherefore art thou Rome slont d
Ae fo thnggOOOOO adI....

D
lank seho erBBBBBBurinoold orot td sour sen sdntitotAnhMosbsth
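
`generate` divides the logits by `temperature` before the softmax. The plain-Python sketch below shows what that does to a toy distribution. The logits and the helper name `softmax_with_temperature` are illustrative assumptions.

```python
import math

def softmax_with_temperature(logits, temperature):
    """Softmax of logits / temperature, computed stably."""
    scaled = [v / temperature for v in logits]
    m = max(scaled)
    exps = [math.exp(v - m) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.0]
for t in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])
```

Lower temperatures concentrate probability on the most likely token and higher temperatures spread it out. The default of 0.95 is close to 1.0, so it changes the distribution only slightly.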
