# minGrok

for the original guide see https://github.com/evintunador/minGrok

In [1]:
import sys
sys.path.append('./venv/lib/python3.10/site-packages')
import dataclasses
import torch
import torch.nn as nn
from torch.nn import functional as F
from tokenizer import SimpleTokenizer, loaded_stoi, loaded_merges

In [2]:
# load the dataset
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# the first 200 characters. It's just one continuous text document with all of the works of shakespeare back-to-back
print(text[:200])

# here are all the unique characters that occur in this text and how many there are
chars = sorted(list(set(text)))
v = len(chars)
print(chars)
print(v)

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you
['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
65


In [3]:
tokenizer = SimpleTokenizer(loaded_stoi, loaded_merges)
print("vocab length: ", tokenizer.vocab_len)

# Encoding text
encoded_text = tokenizer.encode("JULIET:\nO Romeo, Romeo! wherefore art thou Romeo?")
print("Encoded:", encoded_text)

# Decoding back
decoded_text = tokenizer.decode(encoded_text)
print("Decoded:", decoded_text)

vocab length:  128
Encoded: [22, 33, 24, 21, 17, 32, 71, 27, 1, 30, 53, 83, 53, 66, 30, 53, 83, 53, 2, 1, 61, 87, 93, 105, 43, 1, 77, 58, 1, 65, 67, 1, 30, 53, 83, 53, 12]
Decoded: JULIET:
O Romeo, Romeo! wherefore art thou Romeo?


In [4]:
@dataclasses.dataclass
class Config:
    # v was defined earlier when we loaded TinyShakespeare. In Grok it's 131,072
    vocab_size: int = tokenizer.vocab_len

    # The maximum sequence length that this model might ever be used with.
    max_position_embeddings: int = 256 # in Grok it's 8,192

    # The number of layers in the model.
    num_layers: int = 4 # In Grok it's 64

    # The number of attention heads used in the attention layers of the model.
    num_attention_heads: int = 4 # In Grok it's 48

    # The number of key-value heads for implementing attention.
    num_key_value_heads: int = 1 # In Grok it's 8

    # The hidden size of the model, AKA the embedding dimension. Each token embedding vector will be this long
    hidden_size: int = 96 # In Grok it's 6,144

    # How much wider should the inner dimension of the experts be than the model's embedding dimension?
    embedding_multiplier_scale: int = 2 # In Grok it's roughly 5.33

    # how many experts?
    tot_num_experts: int = 4 # in Grok it's 8

    # how many active experts per token?
    chosen_num_experts: int = 2 # in Grok it's also 2

    # what amount of noise should be injected into the router during training?
    noise_std = 0.1 # the value for Grok has not been shared

    # When we create a loss to encourage all experts to be used, how should that loss be weighted?
    lambadada = 10 # Grok's value has not been shared
    # excuse my silly naming

    # The number of head dimensions
    head_dim: int = 24 # In Grok it's 128

    # The epsilon used by the rms normalization layers.
    rms_norm_eps: float = 1e-5 # this is to promote numerical stability & prevent dividing by 0

    # the scaling factor that determines the frequencies for the rotary positional encodings
    rope_theta = 100.0 # Grok and most models use 10,000
    # smaller models should use a smaller theta, but I'm just guessing here. 1000 might work too

    # whether to use a linear layer after normalization
    use_scale: bool = True # same in Grok

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # the dropout rate to use during training
    dropout = 0.05
    
config = Config()

In [None]:
def apply_rotary_emb(x: torch.Tensor, dim: int, theta: float = 10000.0) -> torch.Tensor:
    """Applies the rotary embedding to the inputted query or key tensor"""
    # Get sequence length
    seq_len = x.size(1)
    device = x.device

    # Dynamically compute frequency cis based on the input sequence length
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))
    t = torch.arange(seq_len, device=device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64

    # Apply rotary embeddings to the input tensor
    x_ = torch.view_as_complex(torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1))
    x_out = torch.view_as_real(x_ * freqs_cis.unsqueeze(0)).type_as(x)  # Ensure batch dimension is handled
    x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
    x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2], -1).transpose(1, 2)

    return x_out

In [None]:
class MQA(nn.Module):
    """
    Implements Multi-Query Attention which supports a distinct number of attention heads for queries and key-values (KV).
    In the case where the same number of queries and key-values are used, this implemenation is equivalent to regular Multi-Head Attention.
    """
    def __init__(self, config):
        super().__init__()

        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        self.hidden_size = config.hidden_size
        self.head_dim = config.head_dim
        self.theta = config.rope_theta

        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.qkv_proj = nn.Linear(self.hidden_size, (self.num_heads + 2 * self.num_kv_heads) * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        # Create a mask tensor with shape [batch_size, num_heads, seq_len, seq_len]
        self.mask = torch.tril(torch.ones((config.max_position_embeddings, config.max_position_embeddings), 
                                     dtype=torch.uint8)).view(1, 1, config.max_position_embeddings, config.max_position_embeddings).to(dtype=torch.bool)
        #self.mask = mask.expand(-1, self.num_heads, -1, -1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states_shape = hidden_states.shape
        assert len(hidden_states_shape) == 3
        batch_size, input_len, _ = hidden_states_shape

        # Applies the linear projection to the hidden state to retrieve our q, k & v projections
        qkv = self.qkv_proj(hidden_states)
        xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size],dim=-1)

        # Reshapes each to separate the heads and align the dimensions for attention operations.
        xq = xq.view(batch_size, -1, self.num_heads, self.head_dim)
        xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
        xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)

        # Applies rotary positional embeddings to queries and keys to incorporate positional information.
        xq = apply_rotary_emb(xq, self.head_dim, self.theta)
        xk = apply_rotary_emb(xk, self.head_dim, self.theta)

        # If the number of KV heads is different from the number of query heads, adjusts keys and values to match the query heads count.
        if self.num_kv_heads != self.num_heads:
            xk = torch.repeat_interleave(xk, self.num_queries_per_kv, dim=2)
            xv = torch.repeat_interleave(xv, self.num_queries_per_kv, dim=2)

        # Transposes to align them for the batch matrix multiplication in attention calculation.
        q = xq.transpose(1, 2)
        k = xk.transpose(1, 2)
        v = xv.transpose(1, 2)

        # Calculates attention logits by performing a batch matrix multiplication between queries and keys
        logits = torch.matmul(q, k.transpose(2, 3))

        # Grok's unusual scaling method
        # If anyone knows why they use 0.08838834764831845 in Grok please lmk. Maybe it's a learned value?
        logits *= 0.08838834764831845
        # Next here we'll scale and clip our attention logits
        # the tanh is a nonlinear function that pushes all of the entries in logits into the range (-1, 1)
        # then they're scaled up to the range (-30, 30). The number 30 is an arbitrary choice
        # the purpose of this scaling is to regularize and prevent numerical stability that might otherwise mess with the upcoming softmax
        max_attn_val = torch.tensor(30.0, dtype = logits.dtype)
        logits = max_attn_val * torch.tanh(logits / max_attn_val)
        # other transformers would replace the last three lines with a multiplication by torch.sqrt(self.hidden_size)

        # Applies the lower-triangular mask to the attention logits
        logits = torch.where(self.mask[..., :input_len, :input_len].expand_as(logits), logits, torch.tensor(-1e30, device=logits.device, dtype=logits.dtype))

        # Applies softmax to the logits to obtain attention probabilities
        scores = F.softmax(logits, dim=-1)

        # Computes the weighted sum of values based on the attention scores to obtain the output of the attention mechanism.
        output = torch.matmul(scores, v)

        # Reshapes the attention output to match the expected output dimensions, combining the heads back into the hidden dimension.
        output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1)

        # Applies the final linear projection to the attention output, mapping it back to the hidden size dimension.
        output = self.o_proj(output)

        return output

In [None]:
class Expert(nn.Module):
    def __init__(self, model_dim, hidden_dim):
        super().__init__()
        self.layer1 = nn.Linear(model_dim, hidden_dim * 2, bias=False)  # Double the output for gating
        self.layer2 = nn.Linear(hidden_dim, model_dim, bias=False)  # Output layer remains the same

    def forward(self, x):
      # Split the output of the first layer for gating
        x, gate = self.layer1(x).chunk(2, dim=-1)

        # Apply GeLU to the gate, and then multiply element-wise
        x = F.gelu(gate) * x
        x = self.layer2(x)

        return x

In [None]:
class Router(nn.Module):
    def __init__(self, input_size, tot_num_experts, noise_std: float = 0.1):
        super().__init__()
        self.tot_num_experts = tot_num_experts
        self.router_weights = nn.Linear(input_size, tot_num_experts, bias=False)
        self.noise_std = noise_std

    def forward(self, inputs, training: bool = False):
        routing_logits = self.router_weights(inputs)
        if training: routing_logits = routing_logits + torch.randn_like(routing_logits) * self.noise_std
        routing_probs = F.softmax(routing_logits, dim=-1)
        return routing_probs

In [None]:
class MoELayer(nn.Module):
    def __init__(self, model_dim, expert_hidden_dim, tot_num_experts, chosen_num_experts, noise_std):
        super().__init__()
        self.model_dim = model_dim
        self.tot_num_experts = tot_num_experts
        self.chosen_num_experts = chosen_num_experts
        self.experts = nn.ModuleList([Expert(model_dim, expert_hidden_dim) for _ in range(tot_num_experts)])
        self.router = Router(model_dim, tot_num_experts, noise_std)

    def forward(self, inputs, training: bool = False):
        b, seq_len, _ = inputs.shape

        # get the output of all the experts
        expert_outputs = [expert(inputs.view(-1, self.model_dim)) for expert in self.experts]
        expert_outputs = torch.cat(expert_outputs, dim=0).view(b, seq_len, self.tot_num_experts, self.model_dim)

        # get the output of the router and create out expert mask
        routing_probs = F.softmax(self.router(inputs), dim=-1)
        with torch.no_grad():
          expert_indices = torch.topk(routing_probs, k=self.chosen_num_experts, sorted=True).indices
          multi_hot_indices = torch.zeros(b, seq_len, self.tot_num_experts, device=inputs.device)
          multi_hot_indices = multi_hot_indices.scatter(2, expert_indices, 1)

        # Apply the multi-hot mask (first expand dimensions for broadcasting)
        multi_hot_expanded = multi_hot_indices.unsqueeze(-1).expand_as(expert_outputs)
        output_masked = expert_outputs * multi_hot_expanded.float()

        # then weight our experts' outputs by the softmax values (which we first must broadcast to the right shape) and sum them
        routing_probs_expanded = routing_probs.unsqueeze(-1).expand_as(output_masked)
        MoE_output = (output_masked * routing_probs_expanded).sum(dim=2)

        return MoE_output, routing_probs # we also output routing_probs to be used in the loss function later

In [None]:
class RMSNorm(nn.Module): # the same RMSNorm we wrote earlier
    def __init__(self, num_features, eps=1e-5, use_scale=True):
        super(RMSNorm, self).__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(num_features)) if use_scale else None

    def forward(self, inputs):
        # Calculate the mean squared value for each feature
        mean_squared = inputs.pow(2).mean(dim=-1, keepdim=True)

        # Normalize inputs
        normed_inputs = inputs * torch.rsqrt(mean_squared + self.eps)

        # Apply scale if it exists
        if self.scale is not None:
            normed_inputs = normed_inputs * self.scale

        return normed_inputs

In [None]:
class DecoderLayer(nn.Module):
    """
    A decoder layer that integrates the Attention mechanism and MoE. It includes
    normalization steps both before and after the MQA and MoE but never actually normalized the residual connection
    """

    def __init__(self, config):
        super().__init__()

        self.mqa = MQA(config)

        self.moe = MoELayer(
            model_dim = config.hidden_size,
            expert_hidden_dim = config.hidden_size * config.embedding_multiplier_scale,
            tot_num_experts = config.tot_num_experts,
            chosen_num_experts = config.chosen_num_experts,
            noise_std = config.noise_std
        )

        self.pre_mqa_norm = RMSNorm(config.hidden_size, eps = config.rms_norm_eps, use_scale = config.use_scale)
        self.post_mqa_norm = RMSNorm(config.hidden_size, eps = config.rms_norm_eps, use_scale = config.use_scale)
        self.pre_moe_norm = RMSNorm(config.hidden_size, eps = config.rms_norm_eps, use_scale = config.use_scale)
        self.post_moe_norm = RMSNorm(config.hidden_size, eps = config.rms_norm_eps, use_scale = config.use_scale)

        self.drop = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor, training: bool = False) -> torch.Tensor:
        if training:
            x = x + self.drop(self.post_mqa_norm(self.mqa(self.pre_mqa_norm(x))))
            moe_out, routing_probs = self.moe(self.pre_moe_norm(x), training)
            x = x + self.drop(self.post_moe_norm(moe_out))
        else:
            x = x + self.post_mqa_norm(self.mqa(self.pre_mqa_norm(x)))
            moe_out, routing_probs = self.moe(self.pre_moe_norm(x), training)
            x = x + self.post_moe_norm(moe_out)
        return x, routing_probs

In [None]:
class minGrok(nn.Module):

    def __init__(self, config, tokenizer):
        super().__init__()
        self.config = config

        # the attention heads need to cleanly divide up the hidden_size of the model so that we can split it all apart & combine back together
        assert config.hidden_size % config.num_attention_heads == 0

        self.max_seq_len = config.max_position_embeddings
        self.head_dim = config.head_dim
        self.vocab_size = config.vocab_size
        self.tokenizer = tokenizer

         # the embedding matrix. for converting tokens to the first residual state, and the last residual state to logits
        self.embedder = nn.Embedding(self.vocab_size, config.hidden_size)

        # Initialize a sequence of DecoderLayer instances as specified by the number of layers in the config
        self.layers = nn.ModuleList(DecoderLayer(config) for _ in range(config.num_layers))

        # Initialize a normalization layer to be applied after the last decoder layer, stabilizing the output
        self.final_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # the primary loss function
        self.criterion = nn.CrossEntropyLoss()
    
        # the hyperparameter weighting the secondary loss function
        self.lambadada = config.lambadada

    def calc_moe_loss(self, routing_probs_list):
        # this is silly and inefficient but i'm tired and bored of this project ngl
        # basically i'm choosing to sum the per-layer MoE variances
        cum_var = torch.tensor(0.0) # this will be encouraged to be 0 so it doesn't even matter if we record the gradient
        for routing_probs in routing_probs_list:
            expert_usage = routing_probs.sum(dim=0)
            usage_mean = expert_usage.mean()
            expert_variance = ((expert_usage - usage_mean) ** 2).mean()
            cum_var = cum_var + expert_variance

        return cum_var

    # a more efficient version ChatGPT4 made that i'm too lazy to test, but go ahead if you want
    #def calc_moe_loss(self, routing_probs_list):
        # Concatenate all tensors along a new dimension (say, dim=0)
        # This results in a new tensor of shape (N, b, t, c) where N is the number of tensors in routing_probs_list
        #all_routing_probs = torch.cat([x.unsqueeze(0) for x in routing_probs_list], dim=0)
        
        # Sum across the batch (b) and time (t) dimensions, resulting in a shape of (N, c)
        #expert_usage = all_routing_probs.sum(dim=1).sum(dim=1)
        
        # Calculate the mean across the new dimension (N) and the experts (c), resulting in a single mean value
        #usage_mean = expert_usage.mean(dim=0).mean(dim=0)
        
        # Calculate the variance
        #expert_variance = ((expert_usage - usage_mean) ** 2).mean(dim=0).mean(dim=0)
        
        # Sum the variance across all layers (N)
        #cum_var = expert_variance.sum()
        
        #return cum_var

    def forward(
        self,
        input_token_ids: torch.Tensor, # a shape (batch_size, input_seq_len) list of integer token ids
        target_token_ids: torch.Tensor = None, # a shape (batch_size, input_seq_len) list of token ids to train on
        ) -> torch.Tensor:
        training = False if target_token_ids is None else True

        # turn the input tokens into the first resudial state using the embedding matrix
        x = self.embedder(input_token_ids) * self.config.hidden_size**0.5 # Grok normalizes the embedding by sqrt(hidden_size)

        # initialize a list to store the routing probs of each layer in
        routing_probs_list = []
        # Iteratively process the input through each DecoderLayer
        for i in range(len(self.layers)):
            layer = self.layers[i]
            x, routing_probs = layer(x, training)
            if training: routing_probs_list.append(routing_probs)

        # Apply normalization to the output of the final decoder layer
        x = self.final_norm(x)

        # grabbing the weights of the embedding matrix shape (vocab_size, hidden_dim) for use as the output layer
        embedder_weight = self.embedder.weight

        # the embedding matrix is also used as the output layer
        # this saves on parameters & makes sense for interpretability
        # (batch_size, input_len, hidden_size) @ (hidden_size, vocab_size) -> (batch_size, input_len, vocab_size)
        logits = torch.matmul(x, embedder_weight.t())

        if training: # if we are training
            batch_size, input_len, vocab_size = logits.shape

            # we reshape our logits & targets before calculating cross-entropy loss
            CEloss = self.criterion(logits.view(batch_size*input_len, vocab_size),
                                    target_token_ids.view(batch_size*input_len))
            
            # calculating the MoE loss that encourages all experts to be utilized
            MoEloss = self.calc_moe_loss(routing_probs_list)

            # our final loss value
            loss = CEloss + MoEloss * self.lambadada
        else:
            loss = None # if we're not training, then we don't need to calculate loss

        return logits, loss

    @torch.no_grad() # no need to keep track of gradients during inference
    def Sampler(
        self,
        logits: torch.Tensor, # shape (batch_size, input_len, vocab_size)
        temperature: float, # controls how boring vs random the outputs should be
        top_p: float, # the maximum cumulative probability of output options we're willing to consider
        top_k: int, # the maximum number of output options we're willing to consider
    ) -> torch.Tensor:
        """
        The Sampler function is responsible for generating token predictions from Grok's output.
        It supports temperature scaling, top-p (nucleus) sampling, and top-k sampling
        """
        # Select the last element for each sequence.
        logits = logits[:,-1,:]

        # Apply temperature scaling
        logits.div_(temperature) # div_ is an in-place operation which is ok since we don't record gradients during inference

        # Calculate probabilities
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)

        # sort the probabilities to for use in top-p & top-k
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)

        # calculating top_k
        probs_sum = torch.cumsum(probs_sort, dim=-1) # creates same-size tensor of cumulatve probabilities instead of indivdiual probs
        top_ps_mask = (probs_sum - probs_sort) > top_p # mask where 0's are top-p selections & 1's are to be excluded
        probs_sort = torch.where(top_ps_mask, 0, probs_sort)  # the original probabilities with excluded tokens changed to 0.0

        # calculating top_k
        top_ks_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device) # create a shape (vocab_size) tensor that just iterates up by 1's
        top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1) # expand our mask along the batch_size dimension to become size (batch_size, vocab_size)
        top_ks_mask = top_ks_mask >= top_k # top_ks is a list of integers. we keep whichever entries in top_ks_mask are greater than their corresponding entries in top_ks

        # we'll be combining top-p with top-k and using whichever gives us fewer tokens. a very conservative approach
        probs_sort = torch.where(top_ks_mask, 0, probs_sort)
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) # Re-normalization so that total probabilities add up to 1
        # now we rearrange the modified probabilities in probs_sort back to their original order according to probs_idx
        probs = torch.gather(probs_sort,
                             dim=-1,
                             index=torch.argsort(probs_idx, dim=-1))

        # samples from the distribution
        next_token_id = torch.multinomial(probs, num_samples=1)

        return next_token_id

    def generate(
        self,
        prompt: str,
        output_len: int = 100, # the model will output 100 tokens
        temperature: float = 0.95, # 0.95 is pretty close to not even using temperature at all (1.0 would be no effect)
        top_p: float = 1.0, # defaulting to 1 means we essentially don't use top-p
        top_k: int = 65, # setting top_k = vocab_size means we're effectively not using top_k at all
    ) -> str:
        """Generates responses for given prompts using Grok model."""

        # encoding the prompt into token indices
        tokens = self.tokenizer.encode(prompt)

        # turning it into the right tensor shape
        tokens = torch.tensor(tokens, device=self.config.device).unsqueeze(0)

        # we wouldn't want to go past the maximum context length we trained on
        assert len(tokens) + output_len <= self.config.max_position_embeddings

        for i in range(output_len):
            # get the model's output logits and ignore the loss, which would be a NoneType object
            logits, _ = self(tokens[:,:self.max_seq_len])

            next_token = self.Sampler(
                logits = logits, # the actual output of the model
                temperature = temperature,
                top_p = top_p,
                top_k = top_k
            )

            # add our new token to the sequence
            tokens = torch.cat((tokens, next_token), dim=1)

        # decode our list of tokens to an actual string
        output = self.tokenizer.decode(tokens.squeeze(0).tolist())

        return output

### training a model

In [5]:
# Train and test splits
data = torch.tensor(tokenizer.encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be our training dataset, the rest for validation
train_data = data[:n]
val_data = data[n:]

In [6]:
# data loading for training which generates a small batch of data of inputs x and targets y
def get_batch(split, batch_size):
    # whether we grab from our training or validation dataset
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - config.max_position_embeddings, (batch_size,))
    x = torch.stack([data[i:i+config.max_position_embeddings] for i in ix])
    y = torch.stack([data[i+1:i+config.max_position_embeddings+1] for i in ix])
    x, y = x.to(config.device), y.to(config.device)
    return x, y

In [7]:
@torch.no_grad()
def estimate_loss(model, batch_size, eval_iters = 10): # to periodically estimate loss during the training loop
    out = {}
    model.eval() # sets model to eval mode
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, batch_size)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train() # just resets to training mode
    return out

In [8]:
# instantiate a new model
model = minGrok(config, tokenizer).to(config.device)

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

print(model)

992.352 K parameters
minGrok(
  (embedder): Embedding(128, 96)
  (layers): ModuleList(
    (0-3): 4 x DecoderLayer(
      (mqa): MQA(
        (qkv_proj): Linear(in_features=96, out_features=144, bias=False)
        (o_proj): Linear(in_features=96, out_features=96, bias=False)
      )
      (moe): MoELayer(
        (experts): ModuleList(
          (0-3): 4 x Expert(
            (layer1): Linear(in_features=96, out_features=384, bias=False)
            (layer2): Linear(in_features=192, out_features=96, bias=False)
          )
        )
        (router): Router(
          (router_weights): Linear(in_features=96, out_features=4, bias=False)
        )
      )
      (pre_mqa_norm): RMSNorm()
      (post_mqa_norm): RMSNorm()
      (pre_moe_norm): RMSNorm()
      (post_moe_norm): RMSNorm()
      (drop): Dropout(p=0.05, inplace=False)
    )
  )
  (final_norm): RMSNorm()
  (criterion): CrossEntropyLoss()
)


In [10]:
# create a PyTorch optimizer
# this is not what they used, but this learning rate & weight decay work for our tiny minGemma
learning_rate = 3e-4
weight_decay = 0.01
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# how long we want to train for
max_iters = 5000

# how often we want to check & see how our loss is doing
eval_interval = 100

# batch size to use
batch_size = 32

import time as time

In [11]:
start_time = time.time()

# Enable anomaly detection. uncomment these lines if you need to do extensive debugging
#torch.autograd.set_detect_anomaly(True)

for iter in range(max_iters):

    # sample a batch of data
    xb, yb = get_batch('train', batch_size)

    # train
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        current_time = time.time()
        elapsed_time = current_time - start_time
        losses = estimate_loss(model, batch_size)
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, time elapsed: {elapsed_time:.2f} seconds")

# Disable anomaly detection after the training loop
#torch.autograd.set_detect_anomaly(False)

step 0: train loss 2.4603, val loss 2.6593, time elapsed: 1.15 seconds
step 100: train loss 2.4402, val loss 2.6460, time elapsed: 102.51 seconds
step 200: train loss 2.4155, val loss 2.6502, time elapsed: 201.02 seconds
step 300: train loss 2.4238, val loss 2.6558, time elapsed: 302.31 seconds
step 400: train loss 2.4197, val loss 2.6455, time elapsed: 399.47 seconds
step 500: train loss 2.4427, val loss 2.6603, time elapsed: 496.25 seconds
step 600: train loss 2.4317, val loss 2.6517, time elapsed: 594.00 seconds
step 700: train loss 2.4523, val loss 2.6506, time elapsed: 695.84 seconds
step 800: train loss 2.4124, val loss 2.6615, time elapsed: 799.43 seconds
step 900: train loss 2.4369, val loss 2.6301, time elapsed: 900.91 seconds
step 1000: train loss 2.4354, val loss 2.6568, time elapsed: 1002.66 seconds
step 1100: train loss 2.4033, val loss 2.6486, time elapsed: 1098.36 seconds
step 1200: train loss 2.4403, val loss 2.6508, time elapsed: 1194.98 seconds
step 1300: train loss 2

In [12]:
# save the model currently held in memory
# the filename specifies the model's class, hyperparameters, and date/time it was saved
import os

# Ensure the directory exists
model_dir = 'models'
if not os.path.exists(model_dir):
    os.makedirs(model_dir)

# Create a shorter, more concise filename
filename = (f'{model.__class__.__name__}'
           f'-v{config.vocab_size}'
           f'-max_t{config.max_position_embeddings}'
           f'-layers{config.num_layers}'
           f'-heads{config.num_attention_heads}'
           f'-kv_heads{config.num_key_value_heads}'
           f'-hidden{config.hidden_size}'
           f'-embedding_multiplier_scale{config.embedding_multiplier_scale}'
           f'-head_dim{config.head_dim}'
           f'-theta{config.rope_theta}'
           f'-lr{learning_rate}'
           f'-decay{weight_decay}'
            f'-tot_num_experts{config.tot_num_experts}'
            f'-chosen_num_experts{config.chosen_num_experts}'
            f'-use_scale{config.use_scale}'
           f'-batch{batch_size}'
            f'-train_iter{max_iters}'
           f'--{time.strftime("%Y-%m-%d_%H-%M-%S")}.pth')

# Save the model
model_path = os.path.join(model_dir, filename)
torch.save(model.state_dict(), model_path)

### Alternatively, you can load the 1m parameter model I already trained

In [8]:
# Initialize a blank model
model = minGrok(config, tokenizer).to(config.device)

# here's the path to a minGemma model that i've trained with roughly 1m parameters
path = 'models/minGrok-v128-max_t256-layers4-heads4-kv_heads1-hidden96-embedding_multiplier_scale2-head_dim24-theta100.0-lr0.0003-decay0.01-tot_num_experts4-chosen_num_experts2-use_scaleTrue-batch32-train_iter5000--2024-03-21_18-20-32.pth'

# Load the saved state dictionary
model.load_state_dict(torch.load(path))
# REMEMBER TO CHANGE VALUES IN CONFIG TO MATCH THE MODEL YOU'VE LOADED

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

# If you only plan to do inference, switch to evaluation mode
model.eval()

# If you plan to continue training the model, switch to training mode
#model.train()

992.352 K parameters


minGrok(
  (embedder): Embedding(128, 96)
  (layers): ModuleList(
    (0-3): 4 x DecoderLayer(
      (mqa): MQA(
        (qkv_proj): Linear(in_features=96, out_features=144, bias=False)
        (o_proj): Linear(in_features=96, out_features=96, bias=False)
      )
      (moe): MoELayer(
        (experts): ModuleList(
          (0-3): 4 x Expert(
            (layer1): Linear(in_features=96, out_features=384, bias=False)
            (layer2): Linear(in_features=192, out_features=96, bias=False)
          )
        )
        (router): Router(
          (router_weights): Linear(in_features=96, out_features=4, bias=False)
        )
      )
      (pre_mqa_norm): RMSNorm()
      (post_mqa_norm): RMSNorm()
      (pre_moe_norm): RMSNorm()
      (post_moe_norm): RMSNorm()
      (drop): Dropout(p=0.05, inplace=False)
    )
  )
  (final_norm): RMSNorm()
  (criterion): CrossEntropyLoss()
)

### Testing (performing inference)

In [13]:
input_str = "JULIET:\nO Romeo, Romeo! wherefore art thou" # the classic line
max_useable_output_len = config.max_position_embeddings - len(input_str)
output = model.generate(input_str, output_len = max_useable_output_len)
print(output)

JULIET:
O Romeo, Romeo! wherefore art thou in.

Tome?

Nurse:
Third peaguisrener:
Lo, show and go yours, here mace meraticome
For a thee be oneeget and the lambron a it-ntard; whileTHerle you fair murfeen a 'tis to like.

MENEnguill Yort death their honour mind,
If such therese the curry woront, I that mine,
Why the stays is of him in still.

F
