# Next-Concept Prediction v11.1

So the idea here is to:
1. turn minGemma into a matryoshkaGPT ~~(or maybe just a MatFormer, aka ignoring the attention mechanism, for simplicity?)~~
2. have the smallest level deal with tokens, the medium level with first-level concepts, and the large level with second-level concepts
    - figure out the ideal token vs. concept layout for the sequences
3. make a way to dynamically generate concept vectors based on cosine similarity rather than storing huge vectors
4. can I figure out a way to have tokens dynamically use any given level rather than being stuck with a preset number of combinations? Probably not. Maybe a single sequence plus an MoE router that decides which level to output at a given moment and then automatically concatenates any old tokens into the largest possible size?! Not sure how that would work, but maybe.

In [1]:
# Importing pytorch
import torch
import torch.nn as nn
from torch.nn import functional as F

# imports for the tokenizer
from tokenizer import SimpleTokenizer, loaded_stoi, loaded_merges

# Imports used for the config
import dataclasses 
from typing import Optional

# Imports used for the model
import re
from typing import Any, List, Sequence, Tuple, Union

# used in the training loop
import time

In [2]:
# read the raw training corpus into a single string
with open('input.txt', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

# tokenizer built from the saved vocabulary (stoi) and merge table
tokenizer = SimpleTokenizer(loaded_stoi, loaded_merges)

vocab length:  128


In [3]:
@dataclasses.dataclass # a class meant specifically to just hold data
class Config:
    """
    The default configuration & hyperparameters for next-concept predictor.

    Every setting below is an annotated dataclass field, so any of them can be
    overridden per instance, e.g. ``Config(theta=10000.0, levels=2)``.
    Instance values are validated in ``__post_init__``.
    """
    ### boring hyperparameters ###
    vocab_size: int = tokenizer.vocab_len
    max_seq_len: int = 256
    num_hidden_layers: int = 4
    num_q_heads: int = 4
    num_kv_heads: int = 1
    embed_dim: int = 128
    mlp_multiplier: int = 4
    head_dim: int = 32
    theta: float = 100.0 # RoPE base frequency
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'

    ### hyperparameters related to matryoshka concept embeddings ###
    levels: int = 3 # how many nested widths the *_list properties produce
    split: int = 2 # size ratio between consecutive levels

    def __post_init__(self):
        # validate this instance's actual values, so settings passed to Config(...)
        # are checked too, not only the class defaults
        if self.num_q_heads % self.num_kv_heads != 0:
            raise ValueError(
                f"num_q_heads ({self.num_q_heads}) must be divisible by "
                f"num_kv_heads ({self.num_kv_heads})"
            )

    @property
    def embed_dim_list(self):
        """Embedding width per level, smallest first; the last entry equals embed_dim."""
        return [self.embed_dim // (self.split ** (i-1)) for i in range(self.levels, 0, -1)]

    @property
    def head_dim_list(self):
        """Per-head width per level, smallest first; the last entry equals head_dim."""
        return [self.head_dim // (self.split ** (i-1)) for i in range(self.levels, 0, -1)]

    @property
    def seq_len_list(self):
        """Sequence length per level, smallest first; the last entry equals max_seq_len."""
        return [self.max_seq_len // (self.split ** (i-1)) for i in range(self.levels, 0, -1)]

config = Config()
#print(config.embed_dim_list, config.head_dim_list, config.seq_len_list)

In [4]:
def RoPE(x: torch.Tensor, dim: int, theta: float = 10000.0) -> torch.Tensor:
    """Rotate query/key features by position-dependent angles (rotary embedding).

    Args:
        x: tensor whose axis 1 is the sequence axis; MQA passes it as
            (batch, seq_len, num_heads, head_dim).
        dim: width of the rotated feature axis; one frequency is built for
            every other index in [0, dim).
        theta: base of the geometric frequency progression.

    Returns:
        A tensor with the same shape and dtype as ``x``. Feature k is paired
        with feature k + dim//2 and the pair is rotated as one complex number.
    """
    dev = x.device
    positions = torch.arange(x.size(1), device=dev)

    # one inverse frequency per feature pair, then one angle per (position, pair)
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, device=dev).float() / dim))
    angles = torch.outer(positions, inv_freq).float()
    rotations = torch.polar(torch.ones_like(angles), angles)  # unit complex numbers

    # move heads ahead of the sequence axis, then treat (first half, second half)
    # of the features as (real, imaginary) parts
    heads_first = x.transpose(1, 2).float()
    first_half, second_half = torch.chunk(heads_first, 2, dim=-1)
    as_complex = torch.view_as_complex(torch.stack((first_half, second_half), dim=-1))

    # rotate (broadcasting over batch and heads) and go back to real pairs
    rotated = torch.view_as_real(as_complex * rotations.unsqueeze(0)).type_as(x)

    # real parts back into the first half of the features, imaginary into the second
    real_part, imag_part = torch.chunk(rotated, 2, dim=-1)
    merged = torch.cat((real_part, imag_part), dim=-2)
    merged = merged.reshape(merged.shape[0], merged.shape[1], merged.shape[2], -1)
    return merged.transpose(1, 2)

In [22]:
class MQA(nn.Module):
    """Causal multi-query / grouped-query self-attention that can run at a narrower width.

    For an input of width ``d_i`` (at most ``embed_dim``), every head uses only its
    first ``head_dim * d_i // embed_dim`` columns of the shared Q/K/V projection and
    the matching rows of the output projection. The result is projected back to ``d_i``.
    """

    def __init__(self, config: "Config"):
        super().__init__()

        self.num_q_heads = config.num_q_heads
        self.num_kv_heads = config.num_kv_heads
        assert self.num_q_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_q_heads // self.num_kv_heads

        self.embed_dim = config.embed_dim
        self.head_dim = config.head_dim
        self.theta = config.theta

        # fused projection; columns are laid out as [Q heads | K heads | V heads], head_dim per head
        self.Wqkv = nn.Parameter(torch.Tensor(self.embed_dim, (self.num_q_heads + 2 * self.num_kv_heads) * self.head_dim))
        nn.init.uniform_(self.Wqkv, -((1 / self.embed_dim) ** 0.5), (1 / self.embed_dim) ** 0.5)

        self.Wo = nn.Parameter(torch.Tensor(self.num_q_heads * self.head_dim, self.embed_dim))
        nn.init.uniform_(self.Wo, -((1 / (self.num_q_heads * self.head_dim)) ** 0.5), (1 / (self.num_q_heads * self.head_dim)) ** 0.5)

        # lower-triangular boolean mask (True = position may be attended to).
        # Registered as a non-persistent buffer so `.to(device)` moves it along with the
        # model (a plain attribute stays on CPU), while state_dict keys are unchanged.
        mask = torch.tril(torch.ones((config.max_seq_len, config.max_seq_len), dtype=torch.uint8)
                          ).view(1, 1, config.max_seq_len, config.max_seq_len).to(dtype=torch.bool)
        self.register_buffer('mask', mask, persistent=False)

    def _sub_head_dim(self, d_i: int) -> int:
        """Return the per-head width used for an input of width ``d_i``, as an int.

        Raises:
            ValueError: if ``head_dim * d_i`` is not a multiple of ``embed_dim``.
        """
        # integer arithmetic: the result is used for slicing and .view(), which reject floats
        h_i, remainder = divmod(self.head_dim * d_i, self.embed_dim)
        if remainder:
            raise ValueError(
                f"input width {d_i} does not map to a whole per-head width "
                f"(head_dim={self.head_dim}, embed_dim={self.embed_dim})"
            )
        return h_i

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal self-attention to ``x`` of shape (batch, seq_len, d_i); returns the same shape."""
        batch_size, input_len, d_i = x.shape
        h_i = self._sub_head_dim(d_i)

        # splicing our projection to get the correct sub-matrices:
        # the first d_i rows, and the first h_i columns of every head
        Wq, Wk, Wv = self.Wqkv.split([self.num_q_heads * self.head_dim,
                                      self.num_kv_heads * self.head_dim,
                                      self.num_kv_heads * self.head_dim], dim=-1)
        Wq = torch.cat([Wq[:d_i, j*self.head_dim:j*self.head_dim + h_i] for j in range(self.num_q_heads)], dim=1)
        Wk = torch.cat([Wk[:d_i, j*self.head_dim:j*self.head_dim + h_i] for j in range(self.num_kv_heads)], dim=1)
        Wv = torch.cat([Wv[:d_i, j*self.head_dim:j*self.head_dim + h_i] for j in range(self.num_kv_heads)], dim=1)

        # Applies the linear projection to the hidden state to retrieve our q, k & v projections
        xq, xk, xv = x @ Wq, x @ Wk, x @ Wv

        # separate the heads: [batch_size, input_len, n_heads, h_i]
        xq = xq.view(batch_size, -1, self.num_q_heads, h_i)
        xk = xk.view(batch_size, -1, self.num_kv_heads, h_i)
        xv = xv.view(batch_size, -1, self.num_kv_heads, h_i)

        # Applies rotary positional embeddings to queries and keys to incorporate positional information.
        xq = RoPE(xq, h_i, self.theta)
        xk = RoPE(xk, h_i, self.theta)

        # If there are fewer KV heads than query heads, repeat each KV head so every query head has a partner.
        if self.num_kv_heads != self.num_q_heads:
            xk = torch.repeat_interleave(xk, self.num_queries_per_kv, dim=2) # [batch_size, input_len, n_q_heads, h_i]
            xv = torch.repeat_interleave(xv, self.num_queries_per_kv, dim=2)

        # heads become a batch-like dimension: [batch_size, n_q_heads, input_len, h_i]
        q = xq.transpose(1, 2)
        k = xk.transpose(1, 2)
        v = xv.transpose(1, 2)

        # scaled dot-product scores: [batch_size, n_q_heads, input_len, input_len]
        logits = torch.matmul(q, k.transpose(2, 3)) * (h_i ** -0.5)

        # Applies the lower-triangular mask to the attention logits
        logits = torch.where(self.mask[..., :input_len, :input_len].expand_as(logits),
                             logits,
                             torch.tensor(-1e30, device=logits.device, dtype=logits.dtype))

        # Applies softmax to the logits to obtain attention probabilities
        scores = F.softmax(logits, dim=-1)

        # weighted sum of values: [batch_size, n_q_heads, input_len, h_i]
        output = scores @ v

        # merge heads back together: [batch_size, input_len, n_q_heads * h_i]
        output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1)

        # Applies the final linear projection to the attention output, mapping it back to d_i
        Wo = torch.cat([self.Wo[j*self.head_dim:j*self.head_dim + h_i, :d_i] for j in range(self.num_q_heads)], dim=0)
        return output @ Wo

In [23]:
class MLP(nn.Module):
    """GELU-gated feed-forward block whose weights can be sliced for a narrower input.

    For an input of width ``d_i``, only the top-left ``d_i x (d_i * mlp_multiplier)``
    corner of the gate/up weights (and the matching slice of the down weights and
    biases) is used, so the output has width ``d_i``.
    """

    def __init__(self, embed_dim: int, mlp_multiplier: int):
        super().__init__()
        self.mlp_multiplier = mlp_multiplier
        self.hidden_size = embed_dim * mlp_multiplier

        # uniform init bounds are 1/sqrt(fan_in) of the projection they belong to
        in_bound = (1/embed_dim) ** 0.5
        out_bound = (1/self.hidden_size) ** 0.5

        def uniform_param(*shape, bound):
            # allocate, then fill in place from U(-bound, bound)
            param = nn.Parameter(torch.Tensor(*shape))
            torch.nn.init.uniform_(param, -bound, bound)
            return param

        # gate branch
        self.Wgate = uniform_param(embed_dim, self.hidden_size, bound=in_bound)
        self.Bgate = uniform_param(self.hidden_size, bound=in_bound)

        # up projection
        self.Wup = uniform_param(embed_dim, self.hidden_size, bound=in_bound)
        self.Bup = uniform_param(self.hidden_size, bound=in_bound)

        # down projection
        self.Wdown = uniform_param(self.hidden_size, embed_dim, bound=out_bound)
        self.Bdown = uniform_param(embed_dim, bound=out_bound)

    def forward(self, x):
        width = x.shape[-1]
        hidden = width * self.mlp_multiplier
        gate_pre = x @ self.Wgate[:width, :hidden] + self.Bgate[:hidden]
        up = x @ self.Wup[:width, :hidden] + self.Bup[:hidden]
        fused = F.gelu(gate_pre) * up
        return fused @ self.Wdown[:hidden, :width] + self.Bdown[:width]

In [24]:
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization with a learnable ``(1 + weight)`` scale.

    The scale vector is sliced to the input's last dimension, so one instance can
    normalize any width up to ``dim``.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # zeros make (1 + weight) == 1, so the layer starts as a pure normalization
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x: torch.Tensor, model: int = 0) -> torch.Tensor:
        # `model` is accepted but not used
        width = x.shape[-1]
        # normalize in float32, then cast back to the caller's dtype
        normed = self._norm(x.float()).type_as(x)
        scale = 1 + self.weight[:width]
        return normed * scale

# ------ BOOKMARK ------- 

In [15]:
class Layer(nn.Module):
    """One transformer block: attention, then MLP.

    Each sublayer's input is RMS-normalized, its output is RMS-normalized again,
    and the result is added back to the residual stream.
    """

    def __init__(self, config: Config):
        super().__init__()
        width = config.embed_dim

        self.mqa = MQA(config)
        self.mlp = MLP(width, config.mlp_multiplier)

        self.pre_mqa_norm = RMSNorm(width)
        self.post_mqa_norm = RMSNorm(width)
        self.pre_mlp_norm = RMSNorm(width)
        self.post_mlp_norm = RMSNorm(width)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attn_out = self.post_mqa_norm(self.mqa(self.pre_mqa_norm(x)))
        x = x + attn_out
        mlp_out = self.post_mlp_norm(self.mlp(self.pre_mlp_norm(x)))
        return x + mlp_out

In [16]:
class customGPT(nn.Module):
    """Decoder-only language model: token embedding -> stack of `Layer`s -> final RMSNorm.

    The embedding matrix is reused as the output projection, so logits are
    ``hidden @ embedder.weight.T``.
    """

    def __init__(self,
        config: "Config", # the hyperparameters
        tokenizer: "SimpleTokenizer", # the tokenizer. we don't always store the tokenizer inside of the model, but it doesn't matter here
    ):
        super().__init__()
        self.config = config

        # the attention heads need to cleanly divide up the embed_dim of the model so that we can split it all apart & combine back together
        assert config.embed_dim % config.num_q_heads == 0

        self.max_seq_len = config.max_seq_len
        self.head_dim = config.head_dim
        self.vocab_size = config.vocab_size
        self.tokenizer = tokenizer

        # the embedding matrix. for converting tokens to the first residual state, and the last residual state to logits
        self.embedder = nn.Embedding(self.vocab_size, config.embed_dim)
        self.scaling = config.embed_dim ** 0.5 # scale factor applied to the token embeddings

        # a stack of config.num_hidden_layers transformer blocks
        self.layers = nn.ModuleList(Layer(config) for _ in range(config.num_hidden_layers))

        # normalization applied after the last layer, before the output projection
        self.final_norm = RMSNorm(config.embed_dim)

        # the loss function
        self.criterion = nn.CrossEntropyLoss()

    def forward(
        self,
        input_token_ids: torch.Tensor, # a shape (batch_size, input_seq_len) tensor of integer token ids
        target_token_ids: Optional[torch.Tensor] = None, # a shape (batch_size, input_seq_len) tensor of token ids to train on
        ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(logits, loss)``.

        ``logits`` has shape (batch_size, input_len, vocab_size). ``loss`` is the mean
        cross-entropy against ``target_token_ids``, or None when no targets are given.
        """
        # turn the input tokens into the first residual state using the embedding matrix
        x = self.embedder(input_token_ids) * self.scaling # (batch_size, input_len, embed_dim)

        for layer in self.layers:
            x = layer(x)

        # Apply normalization to the output of the final decoder layer
        x = self.final_norm(x)

        # the embedding matrix is also used as the output layer
        logits = torch.matmul(x, self.embedder.weight.t()) # (batch_size, input_len, vocab_size)

        if target_token_ids is None: # if we're not training, then we don't need to calculate loss
            loss = None
        else:
            vocab_size = logits.shape[-1]
            # flatten batch & sequence for cross-entropy; reshape (not view) also accepts non-contiguous targets
            loss = self.criterion(logits.reshape(-1, vocab_size),
                                  target_token_ids.reshape(-1))

        return logits, loss

    @torch.no_grad() # no need to keep track of gradients during inference
    def Sampler(
        self,
        logits: torch.Tensor, # shape (batch_size, input_len, vocab_size)
        temperature: float, # controls how boring vs random the outputs should be
        top_p: float, # the maximum cumulative probability of output options we're willing to consider
        top_k: int, # the maximum number of output options we're willing to consider
    ) -> torch.Tensor:
        """
        Sample one next-token id per sequence from the logits at the last position.

        Supports temperature scaling, top-p (nucleus) and top-k filtering; both filters
        are applied, so whichever keeps fewer tokens wins. Does not modify ``logits``.
        Returns a (batch_size, 1) tensor of token ids.
        """
        # last position of each sequence; divide out-of-place so the caller's logits are left untouched
        logits = logits[:, -1, :] / temperature # (batch_size, vocab_size)

        # Calculate probabilities with softmax over the vocab dimension
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)

        # sort the probabilities for use in top-p & top-k. both are (batch_size, vocab_size)
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)

        ### top-p: drop a token once the probability mass *before* it already exceeds top_p
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        top_ps_mask = (probs_sum - probs_sort) > top_p
        probs_sort = torch.where(top_ps_mask, 0, probs_sort)

        ### top-k: drop every token ranked at position top_k or later (0-based rank in sorted order)
        top_ks_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device)
        top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
        top_ks_mask = top_ks_mask >= top_k
        probs_sort = torch.where(top_ks_mask, 0, probs_sort)

        # Re-normalization so that total probabilities add up to 1
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))

        # rearrange the filtered probabilities back to vocabulary order
        probs = torch.gather(probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))

        # samples from the distribution
        next_token_id = torch.multinomial(probs, num_samples=1)

        return next_token_id # (batch_size, 1)

    @torch.no_grad() # generation never needs gradients; avoids building an autograd graph every step
    def generate(
        self,
        prompt: str,
        output_len: int = 100, # the model will output 100 tokens by default
        temperature: float = 0.95, # 0.95 is pretty close to not even using temperature at all (1.0 would be no effect)
        top_p: float = 1.0, # defaulting to 1 means we essentially don't use top-p
        top_k: int = 65, # NOTE: top_k only has no effect when it is >= vocab_size; with the 128-token vocab this default does filter. Pass top_k=self.vocab_size to disable it
    ) -> str:
        """Continue ``prompt`` by ``output_len`` sampled tokens and return the decoded text.

        Asserts that prompt tokens + ``output_len`` fit within ``config.max_seq_len``.
        """
        # encoding the prompt into token indices
        tokens = self.tokenizer.encode(prompt)

        # (1, prompt_len) tensor on the same device as the model's weights
        device = self.embedder.weight.device
        tokens = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)

        # count tokens along dim 1; len(tokens) would be the batch size (always 1)
        prompt_len = tokens.shape[1]
        assert prompt_len + output_len <= self.config.max_seq_len, (
            f"prompt ({prompt_len} tokens) + output_len ({output_len}) exceeds "
            f"max_seq_len ({self.config.max_seq_len})"
        )

        for _step in range(output_len):
            # feed at most the most recent max_seq_len tokens; the loss (None) is ignored
            logits, _ = self(tokens[:, -self.max_seq_len:])

            next_token = self.Sampler(
                logits = logits,
                temperature = temperature,
                top_p = top_p,
                top_k = top_k
            )

            # add our new token to the sequence
            tokens = torch.cat((tokens, next_token), dim=1)

        # decode our list of tokens to an actual string
        return self.tokenizer.decode(tokens.squeeze(0).tolist())

# Training-related Functions

In [9]:
# Train and test splits: encode the whole corpus once, then cut it at the 90% mark
data = torch.tensor(tokenizer.encode(text), dtype=torch.long)
n = int(len(data) * 0.9)
train_data, val_data = data[:n], data[n:]  # the final 10% is held out for validation

In [10]:
def get_batch(split, batch_size):
    """Sample ``batch_size`` random windows of ``config.max_seq_len`` tokens.

    ``split == 'train'`` draws from ``train_data``; any other value draws from
    ``val_data``. Returns ``(x, y)`` on ``config.device``, where ``y`` is ``x``
    shifted one token ahead (the next-token targets).
    """
    source = val_data if split != 'train' else train_data
    window = config.max_seq_len
    # highest start is len - window - 1, so the shifted target window stays in range
    starts = torch.randint(len(source) - window, (batch_size,))
    inputs = torch.stack([source[s:s + window] for s in starts])
    targets = torch.stack([source[s + 1:s + window + 1] for s in starts])
    return inputs.to(config.device), targets.to(config.device)

In [11]:
@torch.no_grad()
def estimate_loss(model, batch_size, eval_iters = 5): # to estimate loss during the training loop
    """Average the model's loss over ``eval_iters`` random batches from each split.

    The model runs in eval mode during measurement and is then returned to
    whichever mode (train or eval) it was in before the call.

    Returns:
        dict mapping 'train' and 'val' to a 0-dim tensor holding the mean loss.
    """
    out = {}
    was_training = model.training
    model.eval() # sets model to eval mode
    try:
        for split in ['train', 'val']:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split, batch_size)
                _, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # restore the caller's mode (even if a batch raised) instead of always forcing train mode
        model.train(was_training)
    return out

# Instantiate a brand new model

In [17]:
model = customGPT(config, tokenizer).to(config.device)

# parameter count in thousands, followed by the module tree
print(sum(map(torch.numel, model.parameters())) / 1e3, 'K parameters')

print(model)

975.616 K parameters
customGPT(
  (embedder): Embedding(128, 128)
  (layers): ModuleList(
    (0-3): 4 x Layer(
      (mqa): MQA(
        (qkv_proj): Linear(in_features=128, out_features=192, bias=False)
        (o_proj): Linear(in_features=128, out_features=128, bias=False)
      )
      (mlp): MLP(
        (gate_proj): Linear(in_features=128, out_features=512, bias=True)
        (up_proj): Linear(in_features=128, out_features=512, bias=True)
        (down_proj): Linear(in_features=512, out_features=128, bias=True)
      )
      (pre_mqa_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (post_mqa_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (pre_mlp_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (post_mlp_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
  )
  (final_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (criterion): CrossEntropyLoss()
)


# Load a Pretrained Model

In [19]:
# Initialize a blank model
model = customGPT(config, tokenizer).to(config.device)

# path to a saved customGPT checkpoint (roughly 1M parameters); '?' is a placeholder for the real filename
path = 'models/?.pth'

# Load the saved state dictionary.
# map_location remaps tensors onto config.device, so a checkpoint saved on GPU
# also loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(torch.load(path, map_location=config.device))
# REMEMBER TO CHANGE VALUES IN CONFIG TO MATCH THE MODEL YOU'VE LOADED

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

# If you only plan to do inference, switch to evaluation mode
model.eval()

# If you plan to continue training the model, switch to training mode
#model.train()

972.416 K parameters


minGemma(
  (embedder): Embedding(128, 128)
  (model): Body(
    (layers): ModuleList(
      (0-3): 4 x Layer(
        (self_attn): Attention(
          (qkv_proj): Linear(in_features=128, out_features=192, bias=False)
          (o_proj): Linear(in_features=128, out_features=128, bias=False)
        )
        (mlp): MLP(
          (gate_proj): Linear(in_features=128, out_features=512, bias=True)
          (up_proj): Linear(in_features=128, out_features=512, bias=True)
          (down_proj): Linear(in_features=512, out_features=128, bias=True)
        )
        (input_layernorm): RMSNorm()
        (post_attention_layernorm): RMSNorm()
      )
    )
    (norm): RMSNorm()
  )
  (criterion): CrossEntropyLoss()
)

# Training

In [18]:
# AdamW optimizer. Not the original Gemma recipe, but these settings work for this tiny model.
learning_rate, weight_decay = 3e-4, 0.01
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# training-loop controls
max_iters = 10      # total number of optimizer steps
eval_interval = 2   # estimate train/val loss every this many steps
batch_size = 32     # sequences per step

In [19]:
start_time = time.time()

# Enable anomaly detection. uncomment these lines if you need to do extensive debugging
#torch.autograd.set_detect_anomaly(True)

# loop variable is `step` rather than `iter`, so the built-in iter() is not shadowed in the notebook namespace
for step in range(max_iters):

    # sample a batch of data
    xb, yb = get_batch('train', batch_size)

    # train
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # every once in a while (and on the final step) evaluate the loss on train and val sets
    if step % eval_interval == 0 or step == max_iters - 1:
        elapsed_time = time.time() - start_time
        losses = estimate_loss(model, batch_size)
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, time elapsed: {elapsed_time:.2f} seconds")

# Disable anomaly detection after the training loop
#torch.autograd.set_detect_anomaly(False)

step 0: train loss 120.9438, val loss 120.9641, time elapsed: 0.79 seconds
step 2: train loss 117.1023, val loss 117.3068, time elapsed: 6.65 seconds
step 4: train loss 113.4949, val loss 113.9522, time elapsed: 12.06 seconds
step 6: train loss 110.0082, val loss 110.3763, time elapsed: 17.70 seconds
step 8: train loss 106.7198, val loss 107.1945, time elapsed: 23.63 seconds
step 9: train loss 105.0491, val loss 105.6224, time elapsed: 29.06 seconds


# Saving your model

In [20]:
import os

# save the model currently held in memory
# the filename specifies the model's class, hyperparameters, and date/time it was saved.
# Date and time are joined with '_' because '|' is not allowed in Windows filenames.
os.makedirs('models', exist_ok=True) # torch.save does not create missing directories
torch.save(model.state_dict(),
           f'models/{model.__class__.__name__}'
           f'-vocab_size{config.vocab_size}'
           f'-max_seq_len{config.max_seq_len}'
           f'-num_hidden_layers{config.num_hidden_layers}'
           f'-num_q_heads{config.num_q_heads}'
           f'-num_kv_heads{config.num_kv_heads}'
           f'-embed_dim{config.embed_dim}'
           f'-mlp_multiplier{config.mlp_multiplier}'
           f'-head_dim{config.head_dim}'
           f'-theta{config.theta}'
           f'--{time.strftime("%Y-%m-%d_%H-%M-%S")}.pth')

# Inference

In [20]:
input_str = "JULIET:\nO Romeo, Romeo! wherefore art thou " # the classic line
# generate() budgets the context window in tokens, so measure the prompt in tokens, not characters
max_useable_output_len = config.max_seq_len - len(tokenizer.encode(input_str))
output = model.generate(input_str, output_len = max_useable_output_len)
print(output)

JULIET:
O Romeo, Romeo! wherefore art thou                                                                                                                                                                                                                      


In [21]:
# length of the decoded output string, in characters (not tokens)
len(output)

256