In [1]:
# Importing pytorch
import torch
import torch.nn as nn
from torch.nn import functional as F

# used for the tokenizer
import pickle
import os

# Imports used for the config
import dataclasses 
from typing import Optional

# Imports used for the model
import re
from typing import Any, List, Sequence, Tuple, Union

# used in the training loop
import time

# The Dataset

the dataset we'll be using is just TinyShakespeare for sake of simplicity & ability to do run/train locally on any computer

In [2]:
# load the dataset
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# the first 200 characters. It's just one continuous text document with all of the works of shakespeare back-to-back
print(text[:200])

# here are all the unique characters that occur in this text and how many there are
chars = sorted(list(set(text)))
v = len(chars)
print('\n', chars, v)

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you

 ['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'] 65


# The Tokenizer

We'll be using a very simple tokenizer I trained off of the TinyShakespeare dataset that has 128 total tokens and ignores stuff like special tokens & regex. 

In [3]:
# Load the tokenizer data using pickle
with open('./tokenizers/tokenizer.model', 'rb') as f:
    loaded_tokenizer_data = pickle.load(f)

# Extract the stoi mapping and merges from the loaded data
loaded_stoi = loaded_tokenizer_data['stoi']
loaded_merges = loaded_tokenizer_data['merges']

class SimpleTokenizer:
    def __init__(self, stoi, merges):
        self.stoi = stoi
        self.merges = merges
        self.itos = {i: s for s, i in stoi.items()}  # Inverse mapping for decoding

        self.vocab_len = len(stoi) + len(merges)

    def encode(self, text):
        # Convert the text to a list of token IDs, using space for unknown characters
        tokens = [self.stoi.get(c, self.stoi[' ']) for c in text]

        # Perform merging with the possibility of nested merges
        i = 0
        while i < len(tokens) - 1:
            pair = (tokens[i], tokens[i + 1])
            if pair in self.merges:
                # Replace the current pair with its merged token
                merged_token = self.merges[pair]
                tokens[i] = merged_token
                del tokens[i + 1]

                # Move back to handle possible nested merges
                if i > 0:
                    i -= 1
            else:
                i += 1

        return tokens

    def decode(self, tokens):
        def expand_token(token):
            # Base case: if the token is a direct mapping, return its character
            if token in self.itos:
                return self.itos[token]
            # Recursive case: if the token is a merged token, expand its constituents
            elif token in self.merges.values():
                pair = next(key for key, value in self.merges.items() if value == token)
                return ''.join(expand_token(t) for t in pair)
            # Fallback for unknown tokens
            else:
                return ''

        # Decode each token in the list, handling nested merges recursively
        return ''.join(expand_token(token) for token in tokens)
        
# Example usage
# Assuming loaded_stoi and loaded_merges are already loaded from the tokenizer.model file

tokenizer = SimpleTokenizer(loaded_stoi, loaded_merges)
print("vocab length: ", tokenizer.vocab_len)

# Encoding text
encoded_text = tokenizer.encode("JULIET:\nO Romeo, Romeo! wherefore art thou Romeo?")
print("Encoded:", encoded_text)

# Decoding back
decoded_text = tokenizer.decode(encoded_text)
print("Decoded:", decoded_text)

vocab length:  128
Encoded: [22, 33, 24, 21, 17, 32, 71, 27, 1, 30, 53, 83, 53, 66, 30, 53, 83, 53, 2, 1, 61, 87, 93, 105, 43, 1, 77, 58, 1, 65, 67, 1, 30, 53, 83, 53, 12]
Decoded: JULIET:
O Romeo, Romeo! wherefore art thou Romeo?


# Config

In [4]:
@dataclasses.dataclass # a class meant specifically to just hold data
class Config:
    """ 
    The default configuration & hyperparameters for minGemma. regular Gemma 2b & 7b are defined below 
    Explanations for many of these will be more clear later when they actually get used
    """
    # The number of tokens in the vocabulary.
    vocab_size: int = tokenizer.vocab_len
    
    # The maximum sequence length that this model might ever be used with.
    max_position_embeddings: int = 256
    
    # The number of layers in the model.
    num_hidden_layers: int = 4
    
    # The number of attention heads used in the attention layers of the model.
    num_attention_heads: int = 4
    
    # The number of key-value heads for implementing multi-query attention.
    num_key_value_heads: int = 1 
    
    # The hidden size of the model, AKA the embedding dimension. Each token embedding vector will be this long
    hidden_size: int = 128 
    
    # The inner dimension of the MLP part of the decoder layer
    intermediate_size: int = 512
    
    # The number of head dimensions
    head_dim: int = 32
    
    # The epsilon used by the rms normalization layers.
    rms_norm_eps: float = 1e-6 # this is to promote numerical stability & prevent dividing by 0
    
    # the scaling factor that determines the frequencies for the rotary positional encodings
    rope_theta = 100.0
    # smaller models should use a smaller theta, but I'm just guessing here. 1000 might work too. 10,000 is the usual

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # the number of predictions ahead that the FutureFormer can see
    jirachi = 2

    # whether to print out absolutely everything that happens during inference
    verbose = False

config = Config()

# Rotary Positional Encoding (RoPE)

Gemma uses RoPE, which is a popular relative positional encoding scheme. Positional encodings are designed to help the model understand the order of tokens in the text since transformers don't have a built-in ordering to them when reading a sequence. Instead of telling Gemma the precise location of each token (which would be "absolute" positional encoding), relative positional encodings only reveal to Gemma the placement of different tokens relative to each other. RoPE is a relative positional encoding scheme that functions by multiplying against query & key matrices in the attention mechanism (rather than, for example, addition along the token embeddings) and uses trigonometric functions to effecitvely "rotate" the matrices used in self-attention. RoPE is the standard-to-beat nowadays as it's also find in other notable open-source models like Llama. 

The essential idea is that any given vector within the query & key matrices has a location $m$ and $n$ respectively. Remember a given vector within a query or key matrix corresponds to a specific token in the sequence. You multiply the rotary matrix corresponding to that position against each the query & key, and then when you further multiply those together you effectively have a comparison that takes into account how far away the two tokens are. 

If you like math, then you'll see the relation to a kernel. For some arbitrary $ 0 < \epsilon \leq \frac{\pi}{2N}$ is chosen, where $N$ is the maximum sequence length.

$$ \text{RoPE}(x, m) = x e^{mi\epsilon} $$
$$ \langle\text{RoPE}(q_j, m), \text{RoPE}(k_j, n)\rangle = \langle q_je^{mi\epsilon}, k_je^{ni\epsilon} \rangle $$
$$ = q_jk_je^{mi\epsilon}e^{ni\epsilon} $$
$$ = q_jk_je^{(m-n)i\epsilon} $$
$$ = \text{RoPE}(q_jk_j, m-n) $$

To get a more thorough understanding, check out [the original paper](https://arxiv.org/abs/2104.09864), [this slightly more approachable but very thorough guide](https://blog.eleuther.ai/rotary-embeddings/) or [this far more approachable youtube video with visuals of the rotation](https://www.youtube.com/watch?v=o29P0Kpobz0). I'm not going to go super in-depth in my commenting on the code since this topic has been around for like 3 years now and it's not one of my interests and i'm so tired

In [5]:
def apply_rotary_emb(x: torch.Tensor, dim: int, theta: float = 10000.0) -> torch.Tensor:
    """Applies the rotary embedding to the inputted query or key tensor"""
    # Get sequence length
    seq_len = x.size(1)
    device = x.device
    
    # Dynamically compute frequency cis based on the input sequence length
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))
    t = torch.arange(seq_len, device=device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64

    # Apply rotary embeddings to the input tensor
    x_ = torch.view_as_complex(torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1))
    x_out = torch.view_as_real(x_ * freqs_cis.unsqueeze(0)).type_as(x)  # Ensure batch dimension is handled
    x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
    x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2], -1).transpose(1, 2)

    return x_out

# Root Mean Square Layer Normalization

Gemma uses [Root Mean Square Layer Normalization](https://arxiv.org/pdf/1910.07467.pdf) (RMSnorm) which is also not particularly new or interesting. As with other forms of layer normalization, the goal is to keep each layer within the same distribution so as to ensure smooth optimization. 

The root mean square is a statistic defined by squaring every entry in a given vector, averaging them, and then square-rooting that, which gives us a single floating point value.
$$ \text{RMS}(a) = \sqrt{\frac{1}{n}\sum\limits_{i=1}^n a_i^2} $$
Then we just divide the vector by that value
$$ \bar{a}_i = \frac{a_i}{\text{RMS}(a_i)}g_i$$

When the mean of summed inputs happens to equal zero, then RMSnorm lines up exactly with the older method [LayerNorm](https://arxiv.org/abs/1607.06450v1). The thesis of RMSnorm is essentially that scaling the variance is the thing that's actually helpful, and centering the mean is not so useful, therefore it makes sense for efficiency's sake to just not set the mean to 0. I'd also like to point out that RMSnorm places the vector onto a hypersphere (a sphere but with many more than 3 dimensions) with radius $\sqrt{n}$. This doesn't really matter here, but it's important for other projects I plan to release in the coming months.

In [6]:
class RMSNorm(torch.nn.Module):
    """
    Implements the RMS Normalization (Root Mean Square Normalization) layer.
    RMSNorm is a variant of layer normalization that normalizes the activations
    of the previous layer based on their root mean square value.

    Parameters:
    - dim (int): The dimension of the input features the normalization is applied to.
    - eps (float): A small value added to the denominator for numerical stability. Default is 1e-6.
    - add_unit_offset (bool): If True, adds a unit (1) to the learned scaling coefficient, effectively
      starting with no scaling. If False, the scaling coefficient starts from zero. Default is True.
    """

    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
        add_unit_offset: bool = True,
    ):
        super().__init__() 
        self.eps = eps  # Small epsilon value for numerical stability since you can't divide by 0
        self.add_unit_offset = add_unit_offset  # Flag to determine if a unit should be added to the weight
        
        # Initialize the weight parameter with zeros, which will be learned during training.
        # The shape of the weight is [dim], meaning one weight per feature dimension.
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        """
        Private helper function to normalize the input tensor.

        Parameters:
        - x (Tensor): The input tensor to normalize.

        Returns:
        - Tensor: The normalized tensor.
        """
        # Calculate the root mean square value for each feature (across the last dimension),
        # then use reciprocal square root (rsqrt) for normalization.
        # Add self.eps to the denominator for numerical stability.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass of the RMSNorm layer.

        Parameters:
        - x (Tensor): The input tensor to normalize.

        Returns:
        - output: The normalized and scaled tensor.
        """
        # Normalize the input tensor using the _norm function and ensure the data type matches the input.
        x = self._norm(x.float()).type_as(x)
        
        # If add_unit_offset is True, scale the normalized tensor by (1 + self.weight),
        # effectively starting with a scaling factor of 1 (no scaling).
        # Otherwise, scale by self.weight only.
        if self.add_unit_offset:
            output = x * (1 + self.weight)
        else:
            output = x * self.weight
            
        # Return the scaled output tensor.
        return output

# Multi-Layer Perceptron

Check out this [recent survey of 400 activation functions](https://arxiv.org/pdf/2402.09092.pdf) or [google's exploration of gated activation functions](https://arxiv.org/pdf/2002.05202.pdf) which introduced GeGLU, which is what we use here

In [7]:
class MLP(nn.Module):
    """
    This class implements a multi-layer perceptron with a GeGLU gating mechanism. The GeGLU
    activation combines a standard GeLU activation with a learned gating mechanism, enabling
    the network to control the flow of information more dynamically.

    Attributes:
        gate_proj (nn.Linear): A linear layer that transforms the input tensor to an intermediate
                               representation, which is then passed through a GeLU activation for
                               gating purposes.
        up_proj (nn.Linear): Another linear layer that transforms the input tensor to the same
                             intermediate representation but without gating. This representation
                             is element-wise multiplied by the gated tensor from `gate_proj`.
        down_proj (nn.Linear): A linear layer that transforms the gated and combined tensor back
                               to the original dimension of the hidden size, producing the final output.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
    ):
        """
        Initializes the GemmaMLP module.

        Parameters:
            hidden_size (int): The size of the input and output tensors.
            intermediate_size (int): The size of the tensor after the initial transformation
                                     and before the gating and final projection. This is typically
                                     larger than the hidden size to allow for a richer representation.
        """
        super().__init__()

        # Linear transformation for the gating mechanism, projecting input to an intermediate size.
        self.gate_proj = nn.Linear(hidden_size, intermediate_size)

        # Linear transformation for the input tensor, also projecting to the intermediate size but
        # intended for element-wise multiplication with the gated output.
        self.up_proj = nn.Linear(hidden_size, intermediate_size)

        # Linear transformation to project the gated and combined tensor back to the original
        # hidden size, completing the MLP structure.
        self.down_proj = nn.Linear(intermediate_size, hidden_size)

    def forward(self, x):
        """
        Defines the forward pass of the GemmaMLP module.

        Parameters:
            x (Tensor): The input tensor to the MLP.

        Returns:
            Tensor: The output tensor after applying the GeGLU gating mechanism and the MLP transformations.
        """
        # Applies linear transformation for gating.
        gate = self.gate_proj(x)

        # Applies GeLU activation to the gate, introducing non-linearity and enabling the gating mechanism.
        gate = F.gelu(gate)

        # Applies another linear transformation to the input tensor for subsequent combination with the gate.
        up = self.up_proj(x)

        # Element-wise multiplication of the gated tensor with the transformed input tensor, modulating
        # the input based on the gate's activation.
        fuse = gate * up

        # Applies the final linear transformation to project the modulated tensor back to the hidden size.
        outputs = self.down_proj(fuse)

        # Returns the final output tensor of the MLP, after gating and modulation.
        return outputs

# Attention

so in a normal NTP mask we'd just use lower-tril
```
[[1,0,0,0,0],
[1,1,0,0,0],
[1,1,1,0,0],
[1,1,1,1,0],
[1,1,1,1,1]]
```
but here we want the model to have the option of attending to a couple tokens into the future, but excluding the token that's its job to predict
```
[[1,0,1,0,0],
[1,1,0,1,0],
[1,1,1,0,1],
[1,1,1,1,0],
[1,1,1,1,1]]
```
how many? for now i've set that as a configurable parameter

In [11]:
def create_extended_attention_mask(seq_length, future_tokens):
    """ creates our future-sight attention mask """
    # this is a regular next-token prediction mask that only lets tokens see into the past, not the future
    tril = torch.tril(torch.ones((seq_length, seq_length)))
    
    # Instead of using loops, create a mask for future tokens
    # The diagonal of ones will be extended to the right by `future_tokens` positions
    if future_tokens > 0:
        # the 1's in the top-right of this matrix allow the model to see in the future
        future_sight_mask = torch.tril(torch.ones((seq_length, seq_length)), diagonal=future_tokens+1)

        # this will be used to mask out the next token specifically
        mask_out_next_token = torch.tril(torch.ones((seq_length, seq_length)), diagonal=1)

        # further into the future tokens minus the next token plus all the past/present tokens
        mask = future_sight_mask -1 * mask_out_next_token + tril
    else:
        # This function defaults to a regular NTP mask if you don't input anything
        mask = tril
        
    return mask

mask = create_extended_attention_mask(8, 2).unsqueeze(0).unsqueeze(0)
mask.shape, mask

(torch.Size([1, 1, 8, 8]),
 tensor([[[[1., 0., 1., 1., 0., 0., 0., 0.],
           [1., 1., 0., 1., 1., 0., 0., 0.],
           [1., 1., 1., 0., 1., 1., 0., 0.],
           [1., 1., 1., 1., 0., 1., 1., 0.],
           [1., 1., 1., 1., 1., 0., 1., 1.],
           [1., 1., 1., 1., 1., 1., 0., 1.],
           [1., 1., 1., 1., 1., 1., 1., 0.],
           [1., 1., 1., 1., 1., 1., 1., 1.]]]]))

In [12]:
class Attention(nn.Module):
    """
    Implements Multi-Query Attention which supports a distinct number of attention heads for queries and key-values (KV).
    In the case where the same number of queries and key-values are used, this implemenation is equivalent to regular Multi-Head Attention.  
    """
    
    def __init__(self, config: Config):
        super().__init__()

        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        
        # Ensures that the number of query heads is evenly divisible by the number of KV heads.
        assert self.num_heads % self.num_kv_heads == 0
        # Determines the number of query heads associated with each KV head.
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        self.hidden_size = config.hidden_size
        self.head_dim = config.head_dim
        self.theta = config.rope_theta

        # Calculates the total size for all query projections.
        self.q_size = self.num_heads * self.head_dim
        # Calculates the total size for all key and value projections.
        self.kv_size = self.num_kv_heads * self.head_dim

        # Defines the scaling factor for the attention scores.
        self.scaling = self.head_dim**-0.5

        # Initializes the linear projection layer for queries, keys, and values
        self.qkv_proj = nn.Linear(self.hidden_size, (self.num_heads + 2 * self.num_kv_heads) * self.head_dim, bias=False)
        # Initializes the output projection layer, mapping the concatenated attention outputs back to the hidden size.
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        # for our attention mask we'll use very large negative values to prevent attending to certain tokens
        #mask_negatives = torch.full((1, 1, config.max_position_embeddings, config.max_position_embeddings),-2.3819763e38).to(torch.float)
        # then we'll replace the lower triangular ones with 0's to allow attention to see past tokens
        #mask = torch.triu(mask_negatives, diagonal=1).to(config.device)
        # then we'll use `register_buffer` to define self.mask as a tensor that shouldn't undergo gradient descent
        #self.register_buffer('mask', mask)
    
        # let's do a multiplicative next-token prediction attention mask
        #mask = torch.triu(torch.ones((1, 1, config.max_position_embeddings, config.max_position_embeddings), device=config.device), diagonal=1)
        #mask = 1.0 - mask  # Invert the mask: 1s become 0s (masked), and 0s become 1s (unmasked)
        mask = create_extended_attention_mask(config.max_position_embeddings, 
                                              config.jirachi).unsqueeze(0).unsqueeze(0)
        self.register_buffer('mask', mask)


    def forward(self,
                # The input tensor to the attention mechanism. shape (batch_size, input_len, hidden_size)
                hidden_states: torch.Tensor, 
                ) -> torch.Tensor:

        # Ensures the input tensor is 3-dimensional (batch_size, input_len, hidden_size).
        hidden_states_shape = hidden_states.shape
        assert len(hidden_states_shape) == 3

        # Extracts batch size and input sequence length from the hidden states tensor.
        batch_size, input_len, _ = hidden_states_shape

        # Applies the linear projection to the hidden state to retrieve our q, k & v projections
        qkv = self.qkv_proj(hidden_states)
        
        # Splits the combined QKV tensor into separate tensors for queries (xq), keys (xk), and values (xv) based on their respective sizes.
        xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size],dim=-1)
        # for readability's sake it would've made more sense to do these separately; this is just more efficient

        # Reshapes each of the Q, K, and V tensors to separate the heads and align the dimensions for attention operations.
        xq = xq.view(batch_size, -1, self.num_heads, self.head_dim)
        xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
        xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)

        # Applies rotary positional embeddings to queries and keys to incorporate positional information.
        xq = apply_rotary_emb(xq, self.head_dim, self.theta)
        xk = apply_rotary_emb(xk, self.head_dim, self.theta)

        # If the number of KV heads is different from the number of query heads, adjusts keys and values to match the query heads count.
        if self.num_kv_heads != self.num_heads:
            # [batch_size, input_len, n_local_heads, head_dim]
            key = torch.repeat_interleave(xk, self.num_queries_per_kv, dim=2)
            value = torch.repeat_interleave(xv, self.num_queries_per_kv, dim=2)

        # Transposes Q, K, and V tensors to align them for the batch matrix multiplication in attention calculation.
        # [batch_size, n_local_heads, input_len, head_dim]
        q = xq.transpose(1, 2)
        k = xk.transpose(1, 2)
        v = xv.transpose(1, 2)

        # Calculates attention scores by performing a batch matrix multiplication between queries and keys, followed by scaling.
        # [batch_size, n_local_heads, input_len, input_len]
        scores = torch.matmul(q, k.transpose(2, 3)) * self.scaling
        
        # Applies the lower-triangular mask to the attention scores
        #scores = scores + self.mask[...,:input_len, :input_len] # make sure mask is the correct size. input_len <= max_seq_len
        # is it weird that we're masking with addition of 0's & big negatives instead of multiplication of 1's and 0's or -inf's? 
        # as far as i'm aware it's weird, although my knowledge could just be out of date
        # Let's do multiplicative masking instead
        scores = scores * self.mask[...,:input_len, :input_len]

        # Applies softmax to the scores to obtain attention probabilities
        scores = F.softmax(scores, dim=-1)
        
        # Computes the weighted sum of values based on the attention scores to obtain the output of the attention mechanism.
        # [batch_size, n_local_heads, input_len, head_dim]
        output = torch.matmul(scores, v)

        # Reshapes the attention output to match the expected output dimensions, combining the heads back into the hidden dimension.
        # [batch_size, input_len, hidden_dim]
        output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1)

        # Applies the final linear projection to the attention output, mapping it back to the hidden size dimension.
        output = self.o_proj(output)

        return output

# Layer

interestingly, here we normalize not only before or after the decoder block, but also in-between the attention and MLP. Other than that though there's nothing unusual about these layers

In [13]:
class Layer(nn.Module):
    """
    A decoder layer that integrates the GemmaAttention mechanism and multi-layer perceptron (MLP). It includes
    normalization steps both before and after the attention mechanism to stabilize and accelerate training.
    """

    def __init__(self, config: Config):
        super().__init__()

        # Initializes the GemmaAttention mechanism with parameters from the config, enabling self-attention within the decoder layer.
        self.self_attn = Attention(config)
        
        # Initializes the GemmaMLP module, providing a non-linear transformation after the attention mechanism.
        self.mlp = MLP(
            # the hidden dimension of the model
            hidden_size = config.hidden_size,
            # the number of nodes in the center of the two feedforward layers
            intermediate_size = config.intermediate_size,
        )
        
        # Applies RMSNorm normalization to the input of the decoder layer for stable training dynamics.
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps = config.rms_norm_eps)
        
        # Applies RMSNorm after the attention mechanism and before the MLP to ensure the output is well-conditioned for further processing.
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps = config.rms_norm_eps)

    def forward(self,
                # The input tensor to the decoder layer. shape (batch_size, input_len, hidden_size)
                hidden_states: torch.Tensor
                ) -> torch.Tensor:
        
        # Self Attention Block
        # Stores the original input for use as a residual connection, aiding in mitigating the vanishing gradient problem
        residual = hidden_states
        # Normalizes the input before processing by the attention mechanism.
        hidden_states = self.input_layernorm(hidden_states)
        # Processes the normalized input through the GemmaAttention mechanism
        hidden_states = self.self_attn(hidden_states=hidden_states)
        # The aforementioned residual connection
        hidden_states = residual + hidden_states

        # MLP Block
        # Again, stores the output of the attention block for use as a residual connection before processing by the MLP.
        residual = hidden_states
        # Normalizes the output of the attention block before passing it to the MLP, ensuring a stable input distribution.
        hidden_states = self.post_attention_layernorm(hidden_states)
        # Transforms the normalized attention output through the MLP, introducing additional non-linearity and capacity to the model.
        hidden_states = self.mlp(hidden_states)
        # Another residual connection
        hidden_states = residual + hidden_states

        return hidden_states

# The Body of the Model

In [14]:
class Body(nn.Module):
    """ the class that loops through each layer of Gemma """

    def __init__(self, config: Config):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        # Initialize a sequence of DecoderLayer instances as specified by the number of hidden layers in the config
        self.layers = nn.ModuleList(Layer(config) for _ in range(config.num_hidden_layers))

        # Initialize a normalization layer to be applied after the last decoder layer, stabilizing the output
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self,
                # The first residual state of the model. shape (batch_size, input_len, hidden_size)
                hidden_states: torch.Tensor
               ) -> torch.Tensor:

        # Iteratively process the input through each DecoderLayer, passing along necessary parameters for attention.
        # The hidden states are updated at each layer, progressively incorporating more complex dependencies and transformations.
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states = layer(hidden_states=hidden_states)
        
        # Apply normalization to the output of the final decoder layer, ensuring the model's output is well-conditioned for subsequent use.
        hidden_states = self.norm(hidden_states)
        
        return hidden_states

# The Model itself

In [24]:
class FutureFormer(nn.Module):

    def __init__(self,
        config: Config, # the hyperparameters
        tokenizer: tokenizer, # the tokenizer. we don't always store the tokenizer inside of the model, but it doesn't matter here
    ):
        super().__init__()
        self.config = config

        # the attention heads need to cleanly divide up the hidden_size of the model so that we can split it all apart & combine back together
        assert config.hidden_size % config.num_attention_heads == 0

        self.max_seq_len = config.max_position_embeddings
        self.head_dim = config.head_dim
        self.vocab_size = config.vocab_size
        self.tokenizer = tokenizer

         # the embedding matrix. for converting tokens to the first residual state, and the last residual state to logits
        self.embedder = nn.Embedding(self.vocab_size, config.hidden_size)
        
        # the body of the model; all the transformer decoder layers
        self.model = Body(config)

        # the loss function
        self.criterion = nn.CrossEntropyLoss()

        # the number of time periods into the future the model can look
        self.jirachi = config.jirachi

        # should we print absolutely everything in the generate function?
        self.verbose = config.verbose

    def forward(
        self,
        input_token_ids: torch.Tensor, # a shape (batch_size, input_seq_len) list of integer token ids
        target_token_ids: torch.Tensor = None, # a shape (batch_size, input_seq_len) list of token ids to train on
        ) -> torch.Tensor:

        # turn the input tokens into the first resudial state using the embedding matrix
        # (batch_size, input_len) & (vocab_size, hidden_size) -> (batch_size, input_len, hidden_size)
        hidden_states = self.embedder(input_token_ids)
        
        # Gemma normalizes the embedding by sqrt(hidden_size)
        hidden_states = hidden_states * (self.config.hidden_size**0.5)

        # this is where the bulk of the calculations are performed, the actual decoder layers
        hidden_states = self.model(hidden_states=hidden_states) # -> (batch_size, input_len, hidden_size)

        # grabbing the weights of the embedding matrix shape (vocab_size, hidden_dim) for use as the output layer
        embedder_weight = self.embedder.weight

        # the embedding matrix is also used as the output layer
        # this saves on parameters & makes sense for interpretability
        # (batch_size, input_len, hidden_size) @ (hidden_size, vocab_size) -> (batch_size, input_len, vocab_size)
        logits = torch.matmul(hidden_states, embedder_weight.t())
        
        if target_token_ids is None: # if we're not training, then we don't need to calculate loss
            loss = None
        else:
            # if we are training
            batch_size, input_len, vocab_size = logits.shape
            # then we reshape our logits & targets before calculating cross-entropy loss
            loss = self.criterion(logits.view(batch_size*input_len, vocab_size), 
                                  target_token_ids.view(batch_size*input_len))
        
        return logits, loss

    @torch.no_grad() # no need to keep track of gradients during inference
    def Sampler(
        self,
        logits: torch.Tensor, # shape (batch_size, vocab_size)
        temperature: float, # controls how boring vs random the outputs should be
        top_p: float, # the maximum cumulative probability of output options we're willing to consider
        top_k: int, # the maximum number of output options we're willing to consider
    ) -> torch.Tensor:
        """
        The Sampler function is responsible for generating token predictions from Gemma's output.
        It supports temperature scaling, top-p (nucleus) sampling, and top-k sampling 
        The class operates as follows:
        
        1. Computes logits by multiplying the selected hidden states with the transposed embedding matrix. 
    
        2. Temperature is used to scale the logits, making the distribution over tokens sharper (lower temperature) 
        or flatter (higher temperature), which affects the randomness of the sampling (flatter -> more random)
    
        3. The softmax function is applied to the scaled logits to obtain a probability distribution over the vocabulary.
    
        4. For top-p sampling, the function computes the cumulative sum of the sorted probabilities and masks out tokens until the 
        cumulative probability exceeds the threshold defined by `top_ps`. This allows the model to focus on a subset of the most 
        probable tokens while ignoring the long tail of less likely tokens. 
        We to ignore long tail probabilities to avoid nonsensical output
    
        5. For top-k sampling, the function masks out all tokens except the `k` most likely ones, as specified by `top_ks`. 
        This ensures that the model only considers a fixed number of the most probable tokens for the next token prediction.
    
        6. After applying both the top-p and top-k masks, the probabilities are re-normalized so that they sum up to 1
    
        7. The function then samples from the re-normalized probability distribution to select the next token. 
        """
        # Select the last element for each sequence.
        # (batch_size, input_len, vocab_size) -> (batch_size, vocab_size)
        #logits = logits[:,-1,:]
        # instead we're going to pass in logits that have already been selected
        
        # Apply temperature scaling
        # (batch_size, vocab_size) / float -> (batch_size, vocab_size)
        logits.div_(temperature) # div_ is an in-place operation which is ok since we don't record gradients during inference

        # Calculate probabilities with softmax.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float) # dim=-1 is the vocab_size dimension that we calculate along

        # sort the probabilities to for use in top-p & top-k
        # both are (batch_size, vocab_size)
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        # probs_sort contains float probabilities while probs_idx contains integer indices

        # calculating top-p
        # creates same-size tensor of cumulatve probabilities instead of indivdiual probs
        probs_sum = torch.cumsum(probs_sort, dim=-1) 
        # mask where 0's are top-p selections & 1's are to be excluded
        top_ps_mask = (probs_sum - probs_sort) > top_p
        # the original probabilities with excluded tokens changed to 0.0
        probs_sort = torch.where(top_ps_mask, 0, probs_sort) 

        # calculating top_k
        # create a shape (vocab_size) tensor that just iterates up by 1's
        top_ks_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device) 
        # expand our mask along the batch_size dimension to become size (batch_size, vocab_size)
        # "expand" means copy the original into this new size, so each length vocab_size row is the same
        top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
        # top_ks is a list of integers. we keep whichever entries in top_ks_mask are greater than their corresponding entries in top_ks
        top_ks_mask = top_ks_mask >= top_k

        # we'll be combining top-p with top-k and using whichever gives us fewer tokens. a very conservative approach
        # this trims probs_sort to also fit within our top_k requirement
        probs_sort = torch.where(top_ks_mask, 0, probs_sort)

        # Re-normalization so that total probabilities add up to 1
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        
        # now we rearrange the modified probabilities in probs_sort back to their original order according to probs_idx
        probs = torch.gather(probs_sort,
                             dim=-1,
                             index=torch.argsort(probs_idx, dim=-1))
        
        # samples from the distribution
        next_token_id = torch.multinomial(probs, num_samples=1)
        
        return next_token_id # returns the predicted token

    def replacement_index(self, n: int):
        # calculates the correct index to select from logits or replace from the temporary token sequence
        return -1 -(n * (n + 1) // 2)

    def generate(
        self,
        prompt: str,
        output_len: int = 100, # the model will output 100 tokens
        temperature: float = 0.95, # 0.95 is pretty close to not even using temperature at all (1.0 would be no effect)
        top_p: float = 1.0, # defaulting to 1 means we essentially don't use top-p
        top_k: int = 65, # setting top_k = vocab_size means we're effectively not using top_k at all
    ) -> str: 
        """Generates responses for given prompts using our FutureFormer algorithm"""
        
        # encoding the prompt into token indices
        tokens = self.tokenizer.encode(prompt)

        # turning it into the right tensor shape
        tokens = torch.tensor(tokens, device=config.device).unsqueeze(0)
        
        # we wouldn't want to go past the maximum context length we trained on
        assert len(tokens) + output_len <= self.max_seq_len

        # we'll be keeping track of the actual tokens and a set that will change
        tokens_temp = tokens

        # the initialization loop to build up our first set of future tokens
        if self.verbose: print("----------- initialization loop ----------")
        for i in range(self.jirachi):
            if self.verbose: print("i: ", i)
                
            for _ in range(i+1):
                if self.verbose: print("_: ", _)
                    
                for j in range(i+1):
                    if self.verbose: print("j: ", j)
                        
                    # get the model's output logits and ignore the loss, which would be a NoneType object
                    logits, _ = self(tokens_temp[:,:self.max_seq_len])
                    if self.verbose: print("logits: ", logits.shape)
                    
                    # select the timestep of interest
                    logits = logits[:, self.replacement_index(j), :]
                    if self.verbose: print("replacement index j: ", self.replacement_index(j))
                    if self.verbose: print("logits; ", logits.shape)
            
                    next_token = self.Sampler(logits = logits,
                                              temperature = temperature,
                                              top_p = top_p,
                                              top_k = top_k)
                    if self.verbose: print("next_token: ", next_token.shape, next_token)
        
                    if j == 0:
                        # add the token to the sequence
                        tokens_temp = torch.cat((tokens_temp[:, -self.max_seq_len:], next_token), dim=1)
                        if self.verbose: print("tokens_temp: ", tokens_temp.shape, tokens_temp)
                    else:
                        # replace the previously predicted token
                        tokens_temp[:,self.replacement_index(j)] = next_token # originally i
                        if self.verbose: print("replacement index j: ", self.replacement_index(j)) # originally i
                        if self.verbose: print("tokens_temp: ", tokens_temp.shape, tokens_temp)

        # the more repetitive loop that actually gets our final tokens
        if self.verbose: print("--------- inference loop ---------")
        for i in range(output_len):
            if self.verbose: print("i: ", i)
                
            for j in range(self.jirachi+1):
                if self.verbose: print("j: ", j)
                    
                # get the model's output logits and ignore the loss, which would be a NoneType object
                logits, _ = self(tokens_temp[:,:self.max_seq_len])
                if self.verbose: print("logits: ", logits.shape)
                    
                # select the timestep of interest
                logits = logits[:, self.replacement_index(j), :]
                if self.verbose: print("replacement index j: ", self.replacement_index(j))
                if self.verbose: print("logits: ", logits.shape)
            
                next_token = self.Sampler(logits = logits,
                                          temperature = temperature,
                                          top_p = top_p,
                                          top_k = top_k)
                if self.verbose: print("next_token: ", next_token.shape, next_token)
                
                if j == 0:
                    # add the token to the sequence
                    tokens_temp = torch.cat((tokens_temp[:, -self.max_seq_len:], next_token), dim=1)
                    if self.verbose: print("tokens_temp: ", tokens_temp.shape, tokens_temp)
                else:
                    tokens_temp[:,self.replacement_index(j)] = next_token # originally i
                    if self.verbose: print("replacement index j: ", self.replacement_index(j)) # originally i
                    if self.verbose: print("tokens_temp: ", tokens_temp.shape, tokens_temp)

                    if j == self.jirachi:  # originally i
                        # append the largest model's output to the actual sequence idx
                        tokens = torch.cat((tokens, next_token), dim=1)
                        if self.verbose: print("tokens: ", tokens.shape, tokens)

        # decode our list of tokens to an actual string
        output = self.tokenizer.decode(tokens.squeeze(0).tolist())

        return output

        #for i in range(output_len):
        #    # get the model's output logits and ignore the loss, which would be a NoneType object
        #    logits, _ = self(tokens[:,:self.max_seq_len])
        #    
        #    next_token = self.Sampler(
        #        logits = logits, # the actual output of the model
        #        temperature = temperature,
        #        top_p = top_p,
        #        top_k = top_k
        #    )
#
        #    # add our new token to the sequence
        #    tokens = torch.cat((tokens, next_token), dim=1)
#
        ## decode our list of tokens to an actual string
        #output = self.tokenizer.decode(tokens.squeeze(0).tolist())
#
        #return output

# Training-related Functions

In [16]:
# Train and test splits
data = torch.tensor(tokenizer.encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be our training dataset, the rest for validation
train_data = data[:n]
val_data = data[n:]

In [17]:
# data loading for training which generates a small batch of data of inputs x and targets y
def get_batch(split, batch_size):
    # whether we grab from our training or validation dataset
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - config.max_position_embeddings, (batch_size,))
    x = torch.stack([data[i:i+config.max_position_embeddings] for i in ix])
    y = torch.stack([data[i+1:i+config.max_position_embeddings+1] for i in ix])
    x, y = x.to(config.device), y.to(config.device)
    return x, y

In [18]:
# a demonstration of what a batch with batch_size=1 looks like. Notice the one-token offset in characters
xb, yb = get_batch('train', 1)
print(xb)
print(tokenizer.decode(xb.squeeze(0).tolist()))
print("-------")
print(yb)
print(tokenizer.decode(yb.squeeze(0).tolist()))

tensor([[121,   1,  57,  87,   0,  57,  92,  81,   1,  92,  95,   1,  90,   1,
          51,  73,  43,   1,  43,  63,  89,   1,  84,   1,  91,  43,   1, 100,
          65,  39,  50,   1,  65,  70,   1,  39,   1,  41,  78,   8,   0,  37,
          67,   1,  49,  90,  61,   1, 102,  51,   1, 116,  66,  57, 112,  85,
          20,  27,  30,  32,  17,  26,  31,  21,  27,  71,  32,  77,  56,  63,
          66,  28,  43,  58,  56,  59, 108,  47,  53,  66,  21,   1,  51,  59,
          80,   1,  45,  53,   1, 100,  65,   1,  72,  43,  82,  18,  73,   1,
          69,   1,  14,  39,  54,  58,  74,  58,  39,   5,  57,   1, 124,  43,
          54,   1, 101,   1,  58,  93, 107,  59,  93,   1,  74,  71,  20,  43,
           1,  92,  65,   1,  72,   1,  48,  43, 119,  50,   1,  94,   1, 101,
           1, 117,  44,  43,   1,  69,   1,  46,  53, 111,  82,  20,  74,   1,
          88,  52,  45,  89,  58,   1,  42,  39,  59, 122,  58,  68,  66,  98,
          39, 114,  47,  44,  59,  50,   1,  14,  69

In [19]:
@torch.no_grad()
def estimate_loss(model, batch_size, eval_iters = 10): # to estimate loss during the training loop
    out = {}
    model.eval() # sets model to eval mode
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, batch_size)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train() # just resets to training mode
    return out

# Instantiating a brand new model

In [25]:
model = FutureFormer(config, tokenizer).to(config.device)

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

print(model)

972.416 K parameters
FutureFormer(
  (embedder): Embedding(128, 128)
  (model): Body(
    (layers): ModuleList(
      (0-3): 4 x Layer(
        (self_attn): Attention(
          (qkv_proj): Linear(in_features=128, out_features=192, bias=False)
          (o_proj): Linear(in_features=128, out_features=128, bias=False)
        )
        (mlp): MLP(
          (gate_proj): Linear(in_features=128, out_features=512, bias=True)
          (up_proj): Linear(in_features=128, out_features=512, bias=True)
          (down_proj): Linear(in_features=512, out_features=128, bias=True)
        )
        (input_layernorm): RMSNorm()
        (post_attention_layernorm): RMSNorm()
      )
    )
    (norm): RMSNorm()
  )
  (criterion): CrossEntropyLoss()
)


# Load a Pretrained Model

In [18]:
# Initialize a blank model
model = minGemma(config, tokenizer).to(config.device)  

# here's the path to a minGemma model that i've trained with roughly 1m parameters
path = 'models/minGemma-vocab_size128-max_position_embeddings256-num_hidden_layers4-num_attention_heads4-num_key_value_heads1-hidden_size128-intermediate_size512-head_dim32-rms_norm_eps1e-06-rope_theta100.0--2024-02-26|11-10-53.pth'

# Load the saved state dictionary
model.load_state_dict(torch.load(path))
# REMEMBER TO CHANGE VALUES IN CONFIG TO MATCH THE MODEL YOU'VE LOADED

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

# If you only plan to do inference, switch to evaluation mode
model.eval()

# If you plan to continue training the model, switch to training mode
#model.train()

964.352 K parameters


minGemma(
  (embedder): Embedding(65, 128)
  (model): GemmaBody(
    (layers): ModuleList(
      (0-3): 4 x GemmaDecoderLayer(
        (self_attn): GemmaAttention(
          (qkv_proj): Linear(in_features=128, out_features=192, bias=False)
          (o_proj): Linear(in_features=128, out_features=128, bias=False)
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=128, out_features=512, bias=True)
          (up_proj): Linear(in_features=128, out_features=512, bias=True)
          (down_proj): Linear(in_features=512, out_features=128, bias=True)
        )
        (input_layernorm): RMSNorm()
        (post_attention_layernorm): RMSNorm()
      )
    )
    (norm): RMSNorm()
  )
  (criterion): CrossEntropyLoss()
)

# Training

In [26]:
# create a PyTorch optimizer
# this is not what they used, but this learning rate & weight decay work for our tiny minGemma
learning_rate = 3e-4
weight_decay = 0.01
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# how long we want to train for
max_iters = 2

# how often we want to check & see how our loss is doing
eval_interval = 2

# batch size to use
batch_size = 32

In [27]:
start_time = time.time()

# Enable anomaly detection. uncomment these lines if you need to do extensive debugging
#torch.autograd.set_detect_anomaly(True)

for iter in range(max_iters):

    # sample a batch of data
    xb, yb = get_batch('train', batch_size)
    
    # train
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    
    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        current_time = time.time()
        elapsed_time = current_time - start_time
        losses = estimate_loss(model, batch_size)
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, time elapsed: {elapsed_time:.2f} seconds")

# Disable anomaly detection after the training loop
#torch.autograd.set_detect_anomaly(False)

step 0: train loss 129.3206, val loss 129.5448, time elapsed: 0.68 seconds
step 1: train loss 129.1268, val loss 129.3054, time elapsed: 5.94 seconds


# Saving your model

In [20]:
# save the model currently held in memory
# the filename specifies the model's class, hyperparameters, and date/time it was saved
torch.save(model.state_dict(),
           f'models/{model.__class__.__name__}'
           f'-vocab_size{config.vocab_size}'
           f'-max_position_embeddings{config.max_position_embeddings}'
           f'-num_hidden_layers{config.num_hidden_layers}'
           f'-num_attention_heads{config.num_attention_heads}'
           f'-num_key_value_heads{config.num_key_value_heads}'
           f'-hidden_size{config.hidden_size}'
           f'-intermediate_size{config.intermediate_size}'
           f'-head_dim{config.head_dim}'
           f'-rms_norm_eps{config.rms_norm_eps}'
           f'-rope_theta{config.rope_theta}'
           f'-jirachi{config.jirachi}'
           f'--{time.strftime("%Y-%m-%d|%H-%M-%S")}.pth')

# Inference

In [28]:
input_str = "JULIET:\nO Romeo, Romeo! wherefore art thou " # the classic line
max_useable_output_len = config.max_position_embeddings - len(input_str)
output = model.generate(input_str, output_len = max_useable_output_len)
print(output)

JULIET:
O Romeo, Romeo! wherefore art thou                                                                                                                                                                                                                      
