# ToDo

- [ ] figure out how to implement memory_saver_div into the kv cache
- [ ] add dropout
- [ ] train bigger version (longer context length?)
- [ ] copy & paste into model.py
- [ ] make a train.py
- [ ] make a params.py
- [ ] build colab notebook

# Setup

In [None]:
# Jupyter kernels often are not wired to the project's virtual environment, so
# put ./venv's site-packages directory on sys.path by hand.
import sys
import os
current_dir = os.getcwd()  # the notebook is expected to run from the project root
venv_dir = os.path.join(current_dir, 'venv')
# e.g. "3.11" -- matches the "pythonX.Y" directory name used inside venv/lib
python_version = f'{sys.version_info.major}.{sys.version_info.minor}'
site_packages_path = os.path.join(
    venv_dir,
    'lib',
    f'python{python_version}',
    'site-packages',
)
sys.path.append(site_packages_path)

In [1]:
# Importing pytorch
import torch
import torch.nn as nn
from torch.nn import functional as F

# imports for the tokenizer
from tiny_shakespeare_tokenizer import *

# Imports used for the params
from dataclasses import dataclass
from typing import Optional

# for the dataset
from torch.utils.data import Dataset, DataLoader
import pickle

# Imports used for the model
import re
from typing import Any, List, Sequence, Tuple, Union
import math

# used in the training loop
import time

# used to save & load models
import json
from dataclasses import asdict

# Params

In [2]:
# Build the tokenizer shared by the rest of the notebook. get_tokenizer is presumably
# provided by the star import of tiny_shakespeare_tokenizer -- TODO confirm.
# The tokenizer's vocab_len is read by ModelArgs as the default vocab_size.
tokenizer = get_tokenizer(size = 256) # size options are 128, 256, 512 and 1024

In [3]:
@dataclass
class ModelArgs:
    """Hyperparameters consumed by the model classes in this notebook.

    The number in each trailing comment is the alternative value the author noted
    next to this notebook's small default (presumably the full-size Llama 3
    setting -- TODO confirm).
    """
    dim: int = 128 # 4096 -- embedding width (second argument of the token nn.Embedding)
    n_layers: int = 8 # 32 -- number of TransformerBlocks stacked in the model
    n_heads: int = 4 # 32 -- number of query heads; head_dim is dim // n_heads
    n_kv_heads: Optional[int] = 1 # None -- number of key/value heads; Attention falls back to n_heads when None
    vocab_size: int = tokenizer.vocab_len # -1 -- evaluated once, when this class body runs
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None # optional extra scaling of the feed-forward hidden size
    norm_eps: float = 1e-5 # epsilon handed to every RMSNorm
    rope_theta: float = 10000 # 500000 -- base passed to precompute_freqs_cis
    max_batch_size: int = 32 # batch dimension of the KV cache; also used as the training batch size
    max_seq_len: int = 256 # 2048 -- KV cache length, causal-mask size and training sequence length
    # NOTE: the two attributes below have no type annotation, so @dataclass does not turn
    # them into fields: they are plain class attributes, cannot be passed to ModelArgs(...),
    # and are left out of dataclasses.asdict(params) (and therefore of the saved JSON).
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dropout_rate = 0.1 # not read anywhere in the visible code (dropout is still on the ToDo list)

# module-level instance used by the data-loading, model and training cells below
params = ModelArgs()

# RMSNorm

In [4]:
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned per-feature gain."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        # added to the mean square before the reciprocal square root
        self.eps = eps
        # per-feature scale, initialised to ones
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Scale x by the reciprocal root-mean-square of its last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x):
        """Normalise in float32, cast back to x's dtype, then apply the learned gain."""
        normalised = self._norm(x.float())
        return normalised.type_as(x) * self.weight

# RoPE

In [5]:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Build the table of unit complex numbers used for rotary position embeddings.

    Args:
        dim: size of the dimension being rotated; one frequency is produced per pair.
        end: number of positions to precompute.
        theta: base of the geometric frequency progression.

    Returns:
        A complex64 tensor of shape (end, dim // 2) whose entry [t, j] is exp(i * t * freq_j).
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device, dtype=torch.float32)
    # angle for every (position, frequency) combination
    angles = torch.outer(positions, inv_freq)
    # magnitude 1, phase = angle
    return torch.polar(torch.ones_like(angles), angles)  # complex64

In [6]:
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """View freqs_cis so that it broadcasts against x.

    freqs_cis must have shape (x.shape[1], x.shape[-1]). The returned view keeps those
    two sizes at axis 1 and the last axis and has size 1 on every other axis of x.
    """
    rank = x.ndim
    assert 0 <= 1 < rank
    assert freqs_cis.shape == (x.shape[1], x.shape[-1]), f'freqs_cis.shape {freqs_cis.shape} != (x.shape[1], x.shape[-1]) {(x.shape[1], x.shape[-1])}'
    kept_axes = (1, rank - 1)
    target_shape = [size if axis in kept_axes else 1 for axis, size in enumerate(x.shape)]
    return freqs_cis.view(*target_shape)

In [7]:
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate queries and keys by the precomputed complex position factors.

    The last dimension of xq / xk is read as consecutive (real, imag) pairs, multiplied
    by freqs_cis in the complex domain, and flattened back. The flatten(3) calls imply
    4-D inputs; callers in this file pass (batch, seq, heads, head_dim).

    Returns:
        (xq_rotated, xk_rotated), each cast back to the dtype of its input.
    """
    def _to_complex(t: torch.Tensor) -> torch.Tensor:
        # (..., d) -> (..., d/2, 2) in float32 -> complex (..., d/2)
        pairs = t.float().reshape(*t.shape[:-1], -1, 2)
        return torch.view_as_complex(pairs)

    q_complex = _to_complex(xq)
    k_complex = _to_complex(xk)
    # the rotation table is shaped against the queries and reused for the keys
    rotation = reshape_for_broadcast(freqs_cis, q_complex)
    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)

# ATTENTION

In [8]:
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head n_rep times along the head axis.

    Equivalent to torch.repeat_interleave(x, dim=2, repeats=n_rep) for a
    (batch, seq, n_kv_heads, head_dim) tensor. With n_rep == 1 the input itself is returned.
    """
    bs, seqlen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    # add a repeat axis after the head axis, broadcast it, then fold it into the heads
    widened = x.unsqueeze(3).expand(bs, seqlen, n_kv_heads, n_rep, head_dim)
    return widened.reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)

In [9]:
class Attention(nn.Module):
    """Multi-head self-attention with shared key/value heads and an optional KV cache.

    Queries use n_heads heads; keys and values use n_kv_heads heads that are repeated
    n_rep times to line up with the queries. When forward() receives a start_pos, keys
    and values are written into (and read back from) preallocated cache tensors.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        # one k/v head per query head unless a smaller number is configured
        self.n_kv_heads = args.n_kv_heads if args.n_kv_heads is not None else args.n_heads
        self.n_rep = args.n_heads // self.n_kv_heads
        self.head_dim = args.dim // args.n_heads

        q_width = args.n_heads * self.head_dim
        kv_width = self.n_kv_heads * self.head_dim
        self.wq = nn.Linear(args.dim, q_width, bias=False)
        self.wk = nn.Linear(args.dim, kv_width, bias=False)
        self.wv = nn.Linear(args.dim, kv_width, bias=False)
        self.wo = nn.Linear(q_width, args.dim, bias=False)

        # The caches are plain tensor attributes (not parameters or registered buffers),
        # which is why forward() re-aligns them with the activations before use.
        cache_shape = (args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim)
        self.cache_k = torch.zeros(cache_shape, requires_grad = False).to(args.device)
        self.cache_v = torch.zeros(cache_shape, requires_grad = False).to(args.device)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
        start_pos: int = None,
    ):
        """Attend over x, optionally extending the KV cache.

        Args:
            x: activations of shape (batch, seqlen, dim).
            freqs_cis: rotary factors for the positions covered by x.
            mask: additive attention mask, or None for no masking.
            start_pos: cache offset of the first token in x; None skips the cache
                entirely (the training path).

        Returns:
            Tensor of shape (batch, seqlen, dim).
        """
        bsz, seqlen, _ = x.shape

        # project, then split the feature axis into heads
        xq = self.wq(x).view(bsz, seqlen, self.n_heads, self.head_dim)
        xk = self.wk(x).view(bsz, seqlen, self.n_kv_heads, self.head_dim)
        xv = self.wv(x).view(bsz, seqlen, self.n_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        if start_pos is None:
            # training: attend over exactly the tokens that were passed in
            keys, values = xk, xv
        else:
            # inference: store the new keys/values and attend over the cache prefix
            end_pos = start_pos + seqlen
            self.cache_k = self.cache_k.to(xq)
            self.cache_v = self.cache_v.to(xq)

            self.cache_k[:bsz, start_pos:end_pos] = xk
            self.cache_v[:bsz, start_pos:end_pos] = xv

            keys = self.cache_k[:bsz, :end_pos]
            values = self.cache_v[:bsz, :end_pos]

        # expand the k/v heads to match the query heads: (bs, kv_len, n_heads, head_dim)
        keys = repeat_kv(keys, self.n_rep)
        values = repeat_kv(values, self.n_rep)

        # move heads in front of the sequence axis: (bs, n_heads, len, head_dim)
        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        scale = math.sqrt(self.head_dim)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / scale
        if mask is not None:
            scores = scores + mask  # (bs, n_heads, seqlen, kv_len)
        # softmax runs in float32 and is cast back to the query dtype
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)

        attended = torch.matmul(scores, values)  # (bs, n_heads, seqlen, head_dim)
        merged = attended.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(merged)

# FFWD

In [10]:
class FeedForward(nn.Module):
    """Gated feed-forward block: w2(silu(w1(x)) * w3(x)), all projections bias-free."""

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        """
        Args:
            dim: input and output width.
            hidden_dim: nominal hidden width; the width actually used is 2/3 of this,
                optionally scaled, then rounded up to a multiple of multiple_of.
            multiple_of: granularity the hidden width is rounded up to.
            ffn_dim_multiplier: optional extra scale applied before rounding.
        """
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        # round up to the next multiple of multiple_of
        n_chunks = (hidden_dim + multiple_of - 1) // multiple_of
        hidden_dim = multiple_of * n_chunks

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        """Apply the gated projection to x and map it back to width dim."""
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)

# Layer

In [11]:
class TransformerBlock(nn.Module):
    """One pre-norm transformer layer: attention and feed-forward, each wrapped in a residual."""

    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        # bookkeeping attributes
        self.layer_id = layer_id
        self.dim = args.dim
        self.n_heads = args.n_heads
        self.head_dim = args.dim // args.n_heads

        # sub-modules
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
        start_pos: int = None,
    ):
        """Run the layer on x.

        freqs_cis, mask and start_pos are passed straight through to the attention
        module; each sub-layer sees a normalised input and its output is added back
        onto the un-normalised stream.
        """
        attn_out = self.attention(self.attention_norm(x), freqs_cis, mask, start_pos)
        h = x + attn_out
        ffn_out = self.feed_forward(self.ffn_norm(h))
        return h + ffn_out

# Model

In [19]:
class Llama3(nn.Module):
    """Llama-3-style decoder: token embedding, TransformerBlocks, final RMSNorm, output projection.

    ``forward`` is the training path and also returns the cross-entropy loss.
    ``forward_inference``, ``Sampler`` and ``generate`` make up the text-generation path,
    which uses the tokenizer handed to the constructor.
    """

    def __init__(self, params: ModelArgs, tokenizer):
        """
        Args:
            params: hyperparameters; also supplies the device used for the mask and
                for the prompt tensor in ``generate``.
            tokenizer: object providing ``encode(str)`` and ``decode(list)``, used by ``generate``.
        """
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers
        self.max_seq_len = params.max_seq_len
        self.tokenizer = tokenizer

        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(
            params.dim, 
            params.vocab_size, 
            bias=False)

        # Kept as a plain attribute rather than a buffer, so Module.to() does not move it;
        # both forward paths move it to the activations' device before slicing.
        self.freqs_cis = precompute_freqs_cis(
            params.dim // params.n_heads,
            params.max_seq_len * 2,
            params.rope_theta,)

        # additive causal mask: -inf strictly above the diagonal, 0 elsewhere
        mask = torch.full((params.max_seq_len, params.max_seq_len), 
                          float("-inf"), 
                          device=params.device)
        mask = torch.triu(mask, diagonal=1)
        self.register_buffer('mask', mask)

        self.criterion = nn.CrossEntropyLoss()

    def forward(self, # specifically for training
                tokens: torch.Tensor, 
                targets: torch.Tensor):
        """Training forward pass.

        Args:
            tokens: integer token ids of shape (batch, seqlen), seqlen <= max_seq_len.
            targets: token ids of the same shape as ``tokens``.

        Returns:
            (logits, loss): float logits of shape (batch, seqlen, vocab_size) and the
            cross-entropy loss over all positions.
        """
        bsz, seqlen = tokens.shape
        assert tokens.shape == targets.shape
        assert seqlen <= self.max_seq_len
        h = self.tok_embeddings(tokens)
        # Move the rotary table to the activations' device *before* slicing it; slicing
        # the unmoved table would mix devices as soon as the model is not on the CPU.
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[:seqlen]
        # causal mask cut down to the current length (the whole mask when seqlen == max_seq_len)
        mask = self.mask[:seqlen, :seqlen]
        
        for layer in self.layers:
            h = layer(h, freqs_cis, mask, start_pos=None)
        h = self.norm(h)
        logits = self.output(h).float()

        loss = self.criterion(
            logits.view(bsz * seqlen, self.vocab_size),
            targets.reshape(bsz * seqlen))
        
        return logits, loss

    @torch.inference_mode()
    def forward_inference(self, 
                          tokens: torch.Tensor,
                          start_pos: int,
                          max_context_window: int,
                         ):
        """Inference forward pass that writes into the attention layers' KV caches.

        Args:
            tokens: token ids of shape (batch, seqlen).
            start_pos: cache position of the first token in ``tokens``.
            max_context_window: accepted for interface compatibility; not read here.

        Returns:
            Float logits of shape (batch, seqlen, vocab_size).
        """
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        mask = self.mask[:seqlen, :seqlen]
        # When performing key-value caching, we compute the attention scores
        # only for the new sequence. Thus, the matrix of scores is of size
        # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
        # j > cache_len + i, since row i corresponds to token cache_len + i.
        mask = torch.hstack(
            [torch.zeros((seqlen, start_pos), device=tokens.device), mask]
        ).type_as(h)

        for layer in self.layers:
            h = layer(h, freqs_cis, mask, start_pos=start_pos)
        h = self.norm(h)
        logits = self.output(h).float()
        return logits

    @torch.inference_mode() # no need to keep track of gradients during inference
    def Sampler(
        self,
        logits: torch.Tensor, # shape (batch_size, input_len, vocab_size)
        temperature: float, # controls how boring vs random the outputs should be
        top_p: float, # the maximum cumulative probability of output options we're willing to consider
        top_k: int, # the maximum number of output options we're willing to consider
    ) -> torch.Tensor:
        """Sample one next-token id per sequence from the last position's logits.

        Applies temperature scaling, then keeps only tokens that satisfy both the top-p
        (nucleus) and the top-k restriction, renormalises and samples. A temperature of
        exactly 0 selects the highest-scoring token directly. The ``logits`` argument is
        not modified.

        Returns:
            Tensor of shape (batch_size, 1) holding the chosen token ids.
        """
        # Select the last element for each sequence.
        logits = logits[:,-1,:] # (batch_size, input_len, vocab_size) -> (batch_size, vocab_size)

        # temperature 0 would divide by zero below; treat it as greedy decoding instead
        if temperature == 0:
            return torch.argmax(logits, dim=-1, keepdim=True)
        
        # Apply temperature scaling out of place so the caller's tensor stays untouched
        logits = logits / temperature # (batch_size, vocab_size) / float -> (batch_size, vocab_size)

        # Calculate probabilities with softmax.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float) # dim=-1 is the vocab_size dimension that we calculate along

        # sort the probabilities for use in top-p & top-k. both are (batch_size, vocab_size)
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)

        ### calculating top-p
        # same-size tensor of cumulative probabilities instead of individual probs
        probs_sum = torch.cumsum(probs_sort, dim=-1) 
        # True where the probability mass *before* a token already exceeds top_p (token is excluded)
        top_ps_mask = (probs_sum - probs_sort) > top_p
        # the sorted probabilities with excluded tokens changed to 0.0
        probs_sort = torch.where(top_ps_mask, 0, probs_sort) 

        ### calculating top_k
        # rank of every sorted entry: a (vocab_size) tensor counting up from 0
        top_ks_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device) 
        # expand along the batch_size dimension to become size (batch_size, vocab_size)
        top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
        # True for every rank at or beyond top_k (token is excluded)
        top_ks_mask = top_ks_mask >= top_k

        # top-p and top-k are combined, so whichever keeps fewer tokens wins. a very conservative approach
        probs_sort = torch.where(top_ks_mask, 0, probs_sort)

        # Re-normalization so that total probabilities add up to 1
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        
        # rearrange the modified probabilities in probs_sort back to their original order according to probs_idx
        probs = torch.gather(probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
        
        # samples from the distribution
        next_token_id = torch.multinomial(probs, num_samples=1)
        
        return next_token_id # returns the predicted token

    @torch.inference_mode()
    def generate(
        self,
        prompt: str,
        max_gen_len: int = None,
        memory_saver_div: int = 1, # defaults to full max_seq_len**2 memory use. must be power of 2
        temperature: float = 0.6, # default value in meta's code
        top_p: float = 0.9, # default value in meta's code
        top_k: int = 32, # meta's code doesn't bother with topk
    ) -> str: 
        """Generate a continuation of ``prompt`` and return prompt + continuation as text.

        Args:
            prompt: text to continue; encoded with the model's tokenizer.
            max_gen_len: number of tokens to generate. None, or a value that would push
                the total past ``max_seq_len``, is capped so prompt + output fit.
            memory_saver_div: power of two; the model only attends to the most recent
                ``max_seq_len // memory_saver_div`` tokens at each step.
            temperature, top_p, top_k: forwarded to ``Sampler``.

        Returns:
            The decoded token sequence (prompt followed by the generated tokens).

        Raises:
            AssertionError: if ``memory_saver_div`` is not a positive power of two.
        """
        assert ((memory_saver_div & (memory_saver_div-1)) == 0) & (memory_saver_div > 0), f'memory_saver_div {memory_saver_div} must be power of 2'
        # at least one token of context, even if memory_saver_div exceeds max_seq_len
        max_context_window = max(1, self.max_seq_len // memory_saver_div)
        
        # encoding the prompt into token indices
        tokens = self.tokenizer.encode(prompt)
        
        if max_gen_len is None:
            max_gen_len = self.max_seq_len - len(tokens)
        elif max_gen_len + len(tokens) > self.max_seq_len:
            print(f'capping max_gen_len at max_seq_len={self.max_seq_len} including input')
            max_gen_len = self.max_seq_len - len(tokens)

        # turning it into the right tensor shape
        tokens = torch.tensor(tokens, device=self.params.device)
        tokens = tokens.unsqueeze(0) if len(tokens.shape)==1 else tokens # jic we need to add a batch dimension
        
        for _ in range(max_gen_len):
            # Every step re-encodes the visible window from cache position 0, so the
            # attention layers only ever read cache rows written in this same call and
            # start_pos + seqlen can never run past the cache.
            context = tokens[:, -max_context_window:]
            logits = self.forward_inference(
                context,
                start_pos = 0,
                max_context_window = max_context_window
            )
            
            next_token = self.Sampler(
                logits = logits,
                temperature = temperature,
                top_p = top_p,
                top_k = top_k
            )

            # add our new token to the sequence
            tokens = torch.cat((tokens, next_token), dim=1)

        # decode our list of tokens to an actual string
        output = self.tokenizer.decode(tokens.squeeze(0).tolist())

        return output
            

# Dataset

In [13]:
# load the dataset: the whole file is read into one string that the split cell below tokenizes
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# the first 200 characters. It's just one continuous text document with all of the works of shakespeare back-to-back
print(text[:200])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you


In [14]:
# Train and test splits
# the whole corpus is encoded once into a 1-D tensor of int64 token ids
data = torch.tensor(tokenizer.encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be our training dataset, the rest for validation
train_data = data[:n]
val_data = data[n:] # the split is positional (no shuffling), so validation is the tail of the corpus

In [15]:
# data loading for training: draws a random batch of inputs x and next-token targets y
def get_batch(split, batch_size):
    """Return (x, y), each of shape (batch_size, params.max_seq_len), on params.device.

    split == 'train' samples from train_data; any other value samples from val_data.
    y is x shifted one token to the right in the source data.
    """
    source = train_data if split == 'train' else val_data
    window = params.max_seq_len
    # random start offsets; the upper bound leaves room for the one-token shift of y
    starts = torch.randint(len(source) - window, (batch_size,))
    x = torch.stack([source[s:s+window] for s in starts])
    y = torch.stack([source[s+1:s+window+1] for s in starts])
    return x.to(params.device), y.to(params.device)

# Instantiate a brand new model

In [25]:
# build a freshly initialised model from the module-level params and move its
# parameters and registered buffers to the configured device
model = Llama3(params, tokenizer).to(params.device)

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

# show the module tree
print(model)

1968.256 K parameters
Llama3(
  (tok_embeddings): Embedding(256, 128)
  (layers): ModuleList(
    (0-7): 8 x TransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=128, out_features=128, bias=False)
        (wk): Linear(in_features=128, out_features=32, bias=False)
        (wv): Linear(in_features=128, out_features=32, bias=False)
        (wo): Linear(in_features=128, out_features=128, bias=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=128, out_features=512, bias=False)
        (w2): Linear(in_features=512, out_features=128, bias=False)
        (w3): Linear(in_features=128, out_features=512, bias=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): Linear(in_features=128, out_features=256, bias=False)
  (criterion): CrossEntropyLoss()
)


# Training

In [19]:
@torch.no_grad()
def estimate_loss(model, batch_size, eval_iters = 5): # to estimate loss during the training loop
    """Average the model's loss over a few random batches of each split.

    Args:
        model: called as model(X, targets=Y) and expected to return (logits, loss).
        batch_size: batch size handed to get_batch.
        eval_iters: number of batches averaged per split.

    Returns:
        dict with keys 'train' and 'val', each a 0-dim tensor holding the mean loss.
        The model is switched to eval mode for the measurement and to train mode afterwards.
    """
    def _mean_loss(split):
        # one loss value per sampled batch, averaged at the end
        batch_losses = torch.zeros(eval_iters)
        for step in range(eval_iters):
            inputs, labels = get_batch(split, batch_size)
            _, loss = model(inputs, targets=labels)
            batch_losses[step] = loss.item()
        return batch_losses.mean()

    model.eval() # sets model to eval mode
    out = {split: _mean_loss(split) for split in ['train', 'val']}
    model.train() # just resets to training mode
    return out

In [27]:
# create a PyTorch optimizer
# this is not what they used, but this learning rate & weight decay work for this tiny model
lr_init = 1e-2 # peak learning rate; the scheduler multiplies it by lr_lambda's return value
weight_decay = 0.02
optimizer = torch.optim.AdamW(model.parameters(), lr=lr_init, weight_decay=weight_decay)

# how long we want to train for
max_iters = 500

# how often we want to check & see how our loss is doing
eval_interval = 25

# Warmup setup
warmup_iters = 25  # Number of warmup iterations
warmup_factor = 1e-2  # Warmup factor (initial learning rate is multiplied by this factor)

lr_final = 1e-5  # Minimum learning rate (floor of the cosine decay in lr_lambda)

def lr_lambda(current_iter):
    """Return the factor that LambdaLR multiplies onto lr_init at iteration current_iter.

    Linear warmup from warmup_factor to 1 over warmup_iters iterations, then a cosine
    decay towards 0 that is floored at lr_final / lr_init.
    """
    if current_iter >= warmup_iters:
        # Cosine decay phase with minimum learning rate
        decay_iters = max_iters - warmup_iters
        phase = math.pi * (current_iter - warmup_iters) / decay_iters
        cosine_decay = 0.5 * (1 + math.cos(phase))
        floor = lr_final / lr_init
        return max(cosine_decay, floor)
    # Warmup phase
    return warmup_factor + (1 - warmup_factor) * current_iter / warmup_iters

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

In [28]:
start_time = time.time()

# Enable anomaly detection. uncomment these lines if you need to do extensive debugging
#torch.autograd.set_detect_anomaly(True)

# The loop variable is called `step` rather than `iter` so that the builtin iter()
# is not shadowed for the remaining cells of the notebook.
for step in range(max_iters):

    # sample a batch of data
    xb, yb = get_batch('train', params.max_batch_size)
    
    # train: forward pass, clear old gradients, backpropagate, update the weights
    logits, loss = model(xb, targets=yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Update the learning rate
    scheduler.step()
    
    # every once in a while evaluate the loss on train and val sets
    if step % eval_interval == 0 or step == max_iters - 1:
        # elapsed time is taken before the evaluation batches are run
        current_time = time.time()
        elapsed_time = current_time - start_time
        losses = estimate_loss(model, params.max_batch_size)
        current_lr = optimizer.param_groups[0]['lr']
        print(f"step {step:04d}: lr {current_lr:.6f}, train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, time elapsed: {elapsed_time:.2f} seconds")

# Disable anomaly detection after the training loop
#torch.autograd.set_detect_anomaly(False)

step 0: lr 0.000496, train loss 5.5877, val loss 5.5893, time elapsed: 1.07 seconds
step 25: lr 0.010000, train loss 3.2807, val loss 3.4082, time elapsed: 25.13 seconds
step 50: lr 0.009926, train loss 3.1071, val loss 3.2393, time elapsed: 49.10 seconds
step 75: lr 0.009718, train loss 2.9478, val loss 3.1434, time elapsed: 73.62 seconds
step 100: lr 0.009382, train loss 2.8066, val loss 3.0346, time elapsed: 97.60 seconds
step 125: lr 0.008925, train loss 2.7220, val loss 3.0126, time elapsed: 122.02 seconds
step 150: lr 0.008362, train loss 2.6286, val loss 2.9169, time elapsed: 146.24 seconds
step 175: lr 0.007707, train loss 2.5054, val loss 2.8463, time elapsed: 170.00 seconds
step 200: lr 0.006978, train loss 2.4186, val loss 2.7004, time elapsed: 194.29 seconds
step 225: lr 0.006195, train loss 2.3221, val loss 2.6358, time elapsed: 218.16 seconds
step 250: lr 0.005380, train loss 2.2511, val loss 2.5394, time elapsed: 242.06 seconds
step 275: lr 0.004554, train loss 2.1958, v

# Saving your model

In [40]:
# Make sure the output directory exists; torch.save does not create missing folders.
import os
os.makedirs('models', exist_ok=True)

# NOTE: the timestamp contains '|', which is not a legal filename character on Windows.
name = f'models/{model.__class__.__name__}_{time.strftime("%Y-%m-%d|%H-%M-%S")}'
# weights (and registered buffers) go into <name>.pth
torch.save(model.state_dict(), f'{name}.pth')

# Convert the dataclass object to a dictionary (only annotated dataclass fields are included)
params_dict = asdict(params)

# Serialize the dictionary to a JSON file next to the weights
with open(f'{name}.json', 'w') as f:
    json.dump(params_dict, f)

# Load a Pretrained Model

In [20]:
# base name (without extension) of the checkpoint inside models/
name = 'Llama3_2024-04-19|04-00-15'

# Deserialize the JSON file back to a dictionary
with open(f'models/{name}.json', 'r') as f:
    params_dict = json.load(f)

# Convert the dictionary back to a dataclass object, so the architecture matches the saved weights
params = ModelArgs(**params_dict)

# Initialize a blank model
model = Llama3(params, tokenizer).to(params.device)  

# path to the saved weights that belong to the JSON file loaded above
path = f'models/{name}.pth'

# Load the saved state dictionary. map_location remaps the stored tensors onto the
# device of this session, so a checkpoint saved on a GPU also loads on a CPU-only machine.
model.load_state_dict(torch.load(path, map_location=params.device)) 

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

# If you only plan to do inference, switch to evaluation mode
model.eval()

# If you plan to continue training the model, switch to training mode
#model.train()

1968.256 K parameters


Llama3(
  (tok_embeddings): Embedding(256, 128)
  (layers): ModuleList(
    (0-7): 8 x TransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=128, out_features=128, bias=False)
        (wk): Linear(in_features=128, out_features=32, bias=False)
        (wv): Linear(in_features=128, out_features=32, bias=False)
        (wo): Linear(in_features=128, out_features=128, bias=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=128, out_features=512, bias=False)
        (w2): Linear(in_features=512, out_features=128, bias=False)
        (w3): Linear(in_features=128, out_features=512, bias=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): Linear(in_features=128, out_features=256, bias=False)
  (criterion): CrossEntropyLoss()
)

# Inference

In [17]:
input_str = "JULIET:\nO Romeo, Romeo! wherefore art thou " # the classic line
# NOTE(review): len(input_str) counts characters, while generate() budgets in tokens;
# presumably the prompt never encodes to more tokens than characters -- TODO confirm
max_useable_output_len = params.max_seq_len - len(input_str)
output = model.generate(input_str, max_gen_len = max_useable_output_len)
print(output)

JULIET:
O Romeo, Romeo! wherefore art thou shalt a bond,
And marry to the blood of work with me,
So shall be more to a true work to the king.

KING RICHARD III:
We will not speak a bastard did in the rest,
As I am in bright of the triumphant,
And you have great a dead the house of it,
That prays of him in the first a while,
Or back in the hour of the world whom you
W


In [21]:
input_str = "JULIET:\nO Romeo, Romeo! wherefore art thou R" # the classic line
# NOTE(review): len(input_str) counts characters, while generate() budgets in tokens;
# presumably the prompt never encodes to more tokens than characters -- TODO confirm
max_useable_output_len = params.max_seq_len - len(input_str)
# memory_saver_div = 4 makes generate() use a context window of max_seq_len // 4 tokens
output = model.generate(input_str, max_gen_len = max_useable_output_len, memory_saver_div = 4)
print(output)

JULIET:
O Romeo, Romeo! wherefore art thou Richard of Henry to you.

DUKE OF AUMERLE:
Why, sir, sir.

MENENIUS:
What is my lordship.

KING RICHARD III:
So what I will do see the prince, and would a word,
Who have his work of this slave from the light,
And to be protest of the state of death
That would I know the rest spirit, I would have
That thus may speak the city of 
