In [None]:
from google.colab import drive
import shutil
import os

# Mount Google Drive into the Colab filesystem under /content/drive.
drive.mount('/content/drive')

# Drive folder that holds the tokenizer and dataset files copied below.
# The TrainModel class defined later in this notebook also reads this global
# when it writes checkpoints to Drive.
drive_cs336_dir = "/content/drive/MyDrive/Colab/cs336"

# Copy tokenizer files from Drive to local
def setup_tokenizer_files():
   """Copy the tokenizer vocab and merges files from Google Drive to the working directory.

   Sources are looked up under the module-level ``drive_cs336_dir``; each file is
   copied (with metadata, via ``shutil.copy2``) to the same file name in the
   current directory.

   A status line is printed for every file, mirroring ``setup_data_files``, so a
   missing tokenizer file is visible here instead of surfacing later as a
   file-not-found error when the tokenizer is loaded.
   """
   tokenizer_files = (
       ("vocab", "tinystories_vocab.json"),
       ("merges", "tinystories_merges.txt"),
   )

   for label, filename in tokenizer_files:
       source = f"{drive_cs336_dir}/{filename}"
       if os.path.exists(source):
           shutil.copy2(source, filename)
           print(f"✓ Copied tokenizer {label}: {filename}")
       else:
           print(f"❌ Tokenizer {label} not found: {source}")

# Copy encoded data files from Drive to local
def setup_data_files():
   """Bring the encoded train/validation arrays from Google Drive into the working directory.

   For each of the two ``.npy`` files, looks under the module-level
   ``drive_cs336_dir``; when the file exists it is copied to the same name
   locally, otherwise a "not found" line is printed. Training is handled
   before validation.
   """
   datasets = (
       ("training", "TinyStoriesV2-GPT4-train.npy"),
       ("validation", "TinyStoriesV2-GPT4-valid.npy"),
   )

   for label, filename in datasets:
       source = f"{drive_cs336_dir}/{filename}"
       if not os.path.exists(source):
           print(f"❌ {label.capitalize()} data not found: {source}")
           continue
       shutil.copy2(source, filename)
       print(f"✓ Copied {label} data: {filename}")

# Copy the tokenizer files first, then the encoded datasets, from Drive into
# the current working directory so later cells can open them by bare file name.
setup_tokenizer_files()
setup_data_files()

# Now you can use the files locally:
# tokenizer = Tokenizer.from_files("tinystories_vocab.json", "tinystories_merges.txt", ["<|endoftext|>"])
# train_data = np.load("TinyStoriesV2-GPT4-train.npy")
# valid_data = np.load("TinyStoriesV2-GPT4-valid.npy")

Mounted at /content/drive
✓ Copied training data: TinyStoriesV2-GPT4-train.npy
✓ Copied validation data: TinyStoriesV2-GPT4-valid.npy


In [None]:
!pip install jaxtyping
!pip install wandb -qU

Collecting jaxtyping
  Downloading jaxtyping-0.3.2-py3-none-any.whl.metadata (7.0 kB)
Collecting wadler-lindig>=0.1.3 (from jaxtyping)
  Downloading wadler_lindig-0.1.7-py3-none-any.whl.metadata (17 kB)
Downloading jaxtyping-0.3.2-py3-none-any.whl (55 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/55.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m55.4/55.4 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading wadler_lindig-0.1.7-py3-none-any.whl (20 kB)
Installing collected packages: wadler-lindig, jaxtyping
Successfully installed jaxtyping-0.3.2 wadler-lindig-0.1.7


In [None]:
# Standard library imports
import argparse
import base64
import json
import os
import sys
from pathlib import Path
from typing import Iterable, Iterator

# Third-party imports
import numpy as np
import regex as re
from tqdm import tqdm

def pretokenize_for_encoding(text, special_tokens=None):
    """Split ``text`` into pretokens, each returned as a tuple of UTF-8 byte values.

    Ordinary text is cut with the pretokenization regex below. When
    ``special_tokens`` is given (and non-empty), the text is first split on
    those tokens - longest first, so a token that contains another one wins -
    and each special token is emitted as a single, unsplit pretoken.
    """
    pattern = re.compile(r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

    def split_plain(chunk):
        # One byte-tuple per regex match, in order of appearance.
        return [tuple(found.group(0).encode("utf-8")) for found in pattern.finditer(chunk)]

    # No special tokens: the whole text is ordinary text.
    if not special_tokens:
        return split_plain(text)

    known_specials = set(special_tokens)
    longest_first = sorted(special_tokens, key=len, reverse=True)
    alternation = '|'.join(re.escape(token) for token in longest_first)

    pieces = []
    # The capturing group keeps the special tokens themselves in the split output.
    for segment in re.split(f'({alternation})', text):
        if not segment:
            continue
        if segment in known_specials:
            pieces.append(tuple(segment.encode("utf-8")))
        else:
            pieces.extend(split_plain(segment))

    return pieces

class Tokenizer:
    def __init__(self, vocab: dict[int, bytes], merges: list[tuple[bytes, bytes]], special_tokens: list[str] | None = None):
        self.vocab = vocab
        self._rvocab = {v: k for k, v in vocab.items()}
        self.merges = merges
        self.special_tokens = special_tokens
        # Create merge lookup table for O(1) access
        self.merge_lookup = {}
        for i, (left, right) in enumerate(merges):
            pair = (left, right)
            self.merge_lookup[pair] = i

    @classmethod
    def from_files(cls, vocab_filepath: str, merge_filepath:str, special_tokens: list[str] = ["<|endoftext|>"]) -> "Tokenizer":
        with open(vocab_filepath, "r", encoding='utf-8') as f:
            vocab_data = json.load(f)
            vocab = {int(k): base64.b64decode(v.encode('utf-8')) for k, v in vocab_data.items()}

        with open(merge_filepath, "r", encoding='utf-8') as f:
            merges = []
            for line in f:
                parts = line.rstrip().split(" ")
                if len(parts) == 2:
                    left = base64.b64decode(parts[0])
                    right = base64.b64decode(parts[1])
                    merges.append((left, right))
        return Tokenizer(vocab, merges, special_tokens)

    def encode(self, text: str) -> list[int]:
        # Get pretokens (each is a tuple of bytes)
        print("Pretokenizing...")
        pretokens = pretokenize_for_encoding(text, self.special_tokens)
        all_tokens = []

        # Add progress bar for processing pretokens
        for pretoken in tqdm(pretokens, desc="Encoding pretokens", unit="pretoken"):
            # Check if this pretoken is a special token
            pretoken_bytes = bytes(pretoken)
            pretoken_str = pretoken_bytes.decode('utf-8', errors='ignore')

            if self.special_tokens and pretoken_str in self.special_tokens:
                # Handle special token - look it up directly in vocab
                if pretoken_bytes in self._rvocab:
                    special_token_id = self._rvocab[pretoken_bytes]
                    all_tokens.append(special_token_id)
                else:
                    raise ValueError(f"Special token '{pretoken_str}' not found in vocabulary")
            else:
                # Handle regular token - convert to individual bytes first
                tokens = []
                for byte_val in pretoken:
                    single_byte = bytes([byte_val])  # Convert int to bytes object
                    if single_byte in self._rvocab:
                        token_id = self._rvocab[single_byte]
                        tokens.append(token_id)
                    else:
                        raise ValueError(f"Byte {single_byte} (ASCII {byte_val}) not found in vocabulary")

                # Apply merges using efficient algorithm
                while True:
                    # Find the earliest merge available
                    earliest_merge = None  # (position, merge_index)

                    for i in range(len(tokens) - 1):
                        # Get the byte pair at position i
                        left_bytes = self.vocab[tokens[i]]
                        right_bytes = self.vocab[tokens[i + 1]]
                        pair = (left_bytes, right_bytes)

                        # Check if this pair has a merge rule
                        merge_index = self.merge_lookup.get(pair, -1)
                        if merge_index != -1:
                            # If this is the earliest merge found so far, save it
                            if earliest_merge is None or merge_index < earliest_merge[1]:
                                earliest_merge = (i, merge_index)

                    # If no merge found, we're done
                    if earliest_merge is None:
                        break

                    # Apply the earliest merge
                    pos, merge_idx = earliest_merge
                    left_bytes = self.vocab[tokens[pos]]
                    right_bytes = self.vocab[tokens[pos + 1]]
                    merged_bytes = left_bytes + right_bytes

                    if merged_bytes in self._rvocab:
                        merged_token_id = self._rvocab[merged_bytes]
                        # Replace the two tokens with the merged token
                        tokens = tokens[:pos] + [merged_token_id] + tokens[pos + 2:]
                    else:
                        # This shouldn't happen if vocab is consistent with merges
                        break

                # Add processed tokens from this pretoken to final result
                all_tokens.extend(tokens)

        return all_tokens

    def decode(self, ids: list[int]) -> str:
        decoded_bytes = b''.join(self.vocab[id] for id in ids)
        return decoded_bytes.decode("utf-8", errors="replace")

    def encode_iterable(self, iterable: Iterable[str]) -> Iterator[int]:
        for chunk in iterable:
            # Use your existing pretokenize_for_encoding function
            pretokens = pretokenize_for_encoding(chunk, self.special_tokens)

            for pretoken in pretokens:
                # Check if this pretoken is a special token
                pretoken_bytes = bytes(pretoken)
                pretoken_str = pretoken_bytes.decode('utf-8', errors='ignore')

                if self.special_tokens and pretoken_str in self.special_tokens:
                    # Handle special token - look it up directly in vocab
                    if pretoken_bytes in self._rvocab:
                        special_token_id = self._rvocab[pretoken_bytes]
                        yield special_token_id
                    else:
                        raise ValueError(f"Special token '{pretoken_str}' not found in vocabulary")
                    continue

                # Handle regular token - convert to individual bytes first
                tokens = []
                for byte_val in pretoken:
                    single_byte = bytes([byte_val])  # Convert int to bytes object
                    if single_byte in self._rvocab:
                        token_id = self._rvocab[single_byte]
                        tokens.append(token_id)
                    else:
                        raise ValueError(f"Byte {single_byte} (ASCII {byte_val}) not found in vocabulary")

                # Apply merges to this pretoken
                for left_bytes, right_bytes in self.merges:
                    new_tokens = []
                    i = 0
                    while i < len(tokens):
                        # Check if we can merge at position i
                        if (i < len(tokens) - 1 and
                            self.vocab[tokens[i]] == left_bytes and
                            self.vocab[tokens[i + 1]] == right_bytes):
                            # Merge: find token ID for merged bytes
                            merged_bytes = left_bytes + right_bytes
                            if merged_bytes in self._rvocab:
                                merged_token_id = self._rvocab[merged_bytes]
                                new_tokens.append(merged_token_id)
                                i += 2  # Skip both tokens
                            else:
                                new_tokens.append(tokens[i])
                                i += 1
                        else:
                            new_tokens.append(tokens[i])
                            i += 1
                    tokens = new_tokens

                # Yield the processed tokens from this pretoken
                for token_id in tokens:
                    yield token_id

In [None]:
from jaxtyping import Float, Int
import numpy.typing as npt
from torch import Tensor
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import einsum, rearrange


class Linear(nn.Module):
    def __init__(self, in_features:int, out_features:int, device: torch.device | None = None, dtype: torch.dtype | None =None):
        super().__init__()
        W = torch.empty(out_features, in_features, device=device, dtype=dtype)
        std = (2/(in_features + out_features)**(0.5))
        torch.nn.init.trunc_normal_(W, mean=0, std=std, a=-3*std, b=3*std)
        self.W = nn.Parameter(W)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return einsum(x, self.W, "... d_in, d_out d_in -> ... d_out")


class Embedding(nn.Module):

    def __init__(self, vocab_size:int, d_model:int, device: torch.device | None = None, dtype: torch.dtype | None =None):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        W = torch.empty(self.vocab_size, self.d_model, device=device, dtype=dtype)
        torch.nn.init.trunc_normal_(W, mean=0, std=1, a=-3, b=3)
        self.W = nn.Parameter(W)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        return self.W[token_ids]


class RMSNorm(nn.Module):

    def __init__(self, d_model: int, eps: float = 1e-5, device: torch.device | None = None, dtype: torch.dtype | None =None):
        super().__init__()
        # self.eps = eps
        # self.G = nn.Parameter(torch.ones(d_model, device=device, dtype=dtype))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x
        # in_dtype = x.dtype
        # x = x.to(torch.float32)
        # x_squared = x**2
        # x_squared_mean = x_squared.mean(-1, keepdim=True)
        # rms = (x_squared_mean + self.eps)**(0.5)
        # x_normalized = x / rms
        # result = einsum(x_normalized, self.G, "... d_model, d_model -> ... d_model")
        # return result.to(in_dtype)

class SWIGLU(nn.Module):

    def __init__(self, d_model: int, d_ff: int, device: torch.device | None = None, dtype: torch.dtype | None =None):
        super().__init__()
        self.W1 = Linear(d_model, d_ff, device, dtype)
        self.W2 = Linear(d_ff, d_model, device, dtype)
        self.W3 = Linear(d_model, d_ff, device, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w1x = self.W1(x)
        silux = w1x * torch.sigmoid(w1x)
        w3x = self.W3(x)
        elew1w3 = silux * w3x
        return self.W2(elew1w3)


class RotaryPositionalEmbedding(nn.Module):
    """Rotary position embedding applied through precomputed rotation matrices.

    ``r`` has shape ``(max_seq_len, d_k, d_k)``. ``r[i]`` is block-diagonal:
    for each pair of dimensions ``(2k, 2k+1)`` it holds the 2x2 rotation by the
    angle ``i / theta ** (2k / d_k)``.

    ``r`` is registered as a non-persistent buffer, so it moves with
    ``module.to(...)`` but is not written to ``state_dict`` (checkpoint keys
    are the same as when it was a plain attribute).
    """

    def __init__(self, theta: float, d_k: int, max_seq_len: int, device=None):
        super().__init__()
        half = d_k // 2

        # Angles are formed in float64 and rounded to float32 before cos/sin.
        inv_freq = torch.tensor(
            [1.0 / (theta ** (2 * k / d_k)) for k in range(half)], dtype=torch.float64
        )
        positions = torch.arange(max_seq_len, dtype=torch.float64)
        angles = torch.outer(positions, inv_freq).to(torch.float32)  # (max_seq_len, half)
        cos, sin = torch.cos(angles), torch.sin(angles)

        r = torch.zeros(max_seq_len, d_k, d_k)
        even = torch.arange(0, 2 * half, 2)
        odd = even + 1
        # Fill the four corners of every 2x2 rotation block for all positions at once.
        r[:, even, even] = cos
        r[:, even, odd] = -sin
        r[:, odd, even] = sin
        r[:, odd, odd] = cos

        self.register_buffer("r", r.to(device=device), persistent=False)

    def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
        """Rotate ``x`` (``(..., seq, d_k)``) by the matrices selected by ``token_positions``."""
        ri_token_pos = self.r[token_positions]
        return einsum(x, ri_token_pos, "... seq d_k_in, ... seq d_k_out d_k_in -> ... seq d_k_out")


def softmax(x: torch.Tensor, dim: int):
    """Softmax of ``x`` along ``dim``.

    The maximum along ``dim`` is subtracted before exponentiating so large
    inputs do not overflow.
    """
    exps = torch.exp(x - torch.amax(x, dim=dim, keepdim=True))
    return exps / torch.sum(exps, dim=dim, keepdim=True)

def scaled_dot_product_attention(
    Q: Float[Tensor, " ... queries d_k"],
    K: Float[Tensor, " ... keys d_k"],
    V: Float[Tensor, " ... values d_v"],
    mask: Float[Tensor, " ... queries keys"] | None = None
) -> Float[Tensor, " ... queries d_v"]:
    """Attention output ``softmax(Q K^T / sqrt(d_k)) V``.

    Where ``mask`` equals 0 the score is set to ``-inf`` before the softmax, so
    that key gets zero weight. The mask is moved to the scores' device first.
    """
    scale = Q.shape[-1] ** 0.5
    scores = einsum(Q, K, "... queries d_k, ... keys d_k -> ... queries keys") / scale

    if mask is not None:
        blocked = mask.to(scores.device) == 0
        scores = scores.masked_fill(blocked, float('-inf'))

    weights = softmax(scores, dim=-1)
    return einsum(weights, V, "... queries keys, ... keys d_v -> ... queries d_v")

class MultiheadAttention(nn.Module):
    """Causal multi-head self-attention with optional rotary position embedding.

    Args:
        d_model: width of the input and output.
        num_heads: number of heads; each head has ``d_model // num_heads`` dims.
        device: device for the projection weights.
        rope: optional module called as ``rope(x, token_positions)`` on the
            per-head queries and keys.
    """

    def __init__(self, d_model:int, num_heads:int, device=None, rope=None):
        super().__init__()
        self.num_heads = num_heads
        self.dk = d_model // num_heads
        self.d_model = d_model
        self.Q = Linear(d_model, self.dk * self.num_heads, device=device)
        self.K = Linear(d_model, self.dk * self.num_heads, device=device)
        self.V = Linear(d_model, self.dk * self.num_heads, device=device)
        self.Wo = Linear(self.dk * num_heads, d_model, device=device)
        self.rope = rope

    def forward(self, x: torch.Tensor, token_positions: torch.Tensor | None = None) -> torch.Tensor:
        """Attend over ``x`` of shape ``(..., seq, d_model)``.

        ``token_positions`` is only used when a ``rope`` module was given; if it
        is None, positions ``0 .. seq-1`` are used.
        """
        q, k, v = self.Q(x), self.K(x), self.V(x)
        q = rearrange(q, "... seq (num_heads dk) -> ... num_heads seq dk", num_heads=self.num_heads)
        k = rearrange(k, "... seq (num_heads dk) -> ... num_heads seq dk", num_heads=self.num_heads)
        v = rearrange(v, "... seq (num_heads dv) -> ... num_heads seq dv", num_heads=self.num_heads)

        if self.rope is not None:
            if token_positions is None:
                # Indexing the rotation table with None would add an axis instead
                # of selecting positions, so fall back to consecutive positions.
                token_positions = torch.arange(x.shape[-2], device=x.device)
            q = self.rope(q, token_positions)
            k = self.rope(k, token_positions)

        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        attn = rearrange(attn, "... num_heads seq dv -> ... seq (num_heads dv)")
        return self.Wo(attn)


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each with a residual connection."""

    def __init__(self, d_model, num_heads, d_ff, device=None, rope=None):
        super().__init__()
        self.mha = MultiheadAttention(d_model, num_heads, device=device, rope=rope)
        self.ffn = SWIGLU(d_model, d_ff, device=device)
        self.ln1 = RMSNorm(d_model, device=device)
        self.ln2 = RMSNorm(d_model, device=device)

    def forward(self, x: torch.Tensor, token_positions: torch.Tensor | None = None) -> torch.Tensor:
        """Apply the block to ``x`` of shape ``(..., seq, d_model)``.

        When ``token_positions`` is None, positions ``0 .. seq-1`` are used.
        """
        if token_positions is None:
            # The sequence axis is second from last, which also holds for inputs
            # without a batch axis; create the positions on the input's device.
            token_positions = torch.arange(x.shape[-2], device=x.device)
        x = x + self.mha(self.ln1(x), token_positions)
        x = x + self.ffn(self.ln2(x))
        return x

class Transformer(nn.Module):
    """Decoder-only language model: embedding, a stack of transformer blocks, final norm, LM head.

    Args:
        vocab_size: number of token ids; also the size of the output logits.
        context_length: maximum sequence length (size of the RoPE table).
        num_layers: number of transformer blocks.
        d_model: model width.
        num_heads: attention heads per block.
        d_ff: hidden width of the feed-forward block.
        rope_theta: RoPE base; when None no rotary embedding is applied.
        device: device for all weights.
    """

    def __init__(self, vocab_size:int, context_length:int, num_layers:int, d_model:int, num_heads:int, d_ff:int, rope_theta=None, device=None):
        super().__init__()
        self.emb = Embedding(vocab_size, d_model, device)
        d_k = d_model // num_heads
        # rope_theta defaults to None; building the table with None would fail,
        # so in that case the blocks simply run without RoPE.
        if rope_theta is not None:
            self.rope = RotaryPositionalEmbedding(rope_theta, d_k, context_length, device)
        else:
            self.rope = None
        # All blocks share the single RoPE module.
        self.blocks = nn.Sequential(*[TransformerBlock(d_model, num_heads, d_ff, device, self.rope) for _ in range(num_layers)])
        self.lnf = RMSNorm(d_model, device=device)
        self.lm_head = Linear(d_model, vocab_size, device=device)

    def forward(self, x:torch.Tensor):
        """Map token ids ``x`` to next-token logits with a trailing ``vocab_size`` axis."""
        tok_emb = self.emb(x)
        block_out = self.blocks(tok_emb)
        norm_out = self.lnf(block_out)
        logits = self.lm_head(norm_out)
        return logits

In [None]:
# Standard library imports
import argparse
import math
import os
from collections.abc import Callable, Iterable
from pathlib import Path
from typing import Any, BinaryIO, IO, Optional

# Third-party imports
import numpy as np
import numpy.typing as npt
import torch
import torch.nn as nn
from einops import einsum, rearrange
from jaxtyping import Float, Int
from torch import Tensor
from tqdm import tqdm

def cross_entropy(inputs: "Float[Tensor, ' ... vocab_size']", targets: "Int[Tensor, ' ...']") -> "Float[Tensor, '']":
    """Mean negative log-likelihood of ``targets`` under the logits ``inputs``.

    Args:
        inputs: unnormalised logits; the last axis is the vocabulary.
        targets: class indices with the same shape as ``inputs`` minus its last
            axis, so both ``(batch, vocab)`` / ``(batch,)`` and
            ``(batch, seq, vocab)`` / ``(batch, seq)`` work.

    Returns:
        Scalar tensor: the loss averaged over every target position.
    """
    # Subtract the row maximum so exp() cannot overflow.
    shifted = inputs - torch.max(inputs, dim=-1, keepdim=True).values
    log_probs = shifted - torch.log(torch.sum(torch.exp(shifted), dim=-1, keepdim=True))

    # Pick each position's own target along the vocab axis. Row-indexing with
    # arange(len(targets)) is only correct for 2-D logits.
    picked = log_probs.gather(-1, targets.unsqueeze(-1).long()).squeeze(-1)
    return -picked.mean()

class SGD(torch.optim.Optimizer):
    """SGD whose step size decays as ``lr / sqrt(t + 1)``, with ``t`` counted per parameter."""

    def __init__(self, params, lr=1e-3):
        super().__init__(params, {"lr": lr})

    def step(self, closure: Optional[Callable] = None):
        """Take one update step; returns ``closure()`` if a closure was given, else None."""
        loss = closure() if closure is not None else None

        for group in self.param_groups:
            base_lr = group["lr"]
            # Parameters that received no gradient are left untouched.
            for param in (q for q in group["params"] if q.grad is not None):
                param_state = self.state[param]
                steps_done = param_state.get("t", 0)
                param.data -= base_lr / math.sqrt(steps_done + 1) * param.grad.data
                param_state["t"] = steps_done + 1

        return loss

class AdamW(torch.optim.Optimizer):
    """Adam with decoupled weight decay.

    Per-parameter state: step counter ``t`` (starting at 1) and the first and
    second moment estimates ``m`` and ``v``.
    """

    def __init__(self, params, lr=1e-3, weight_decay=0.0, betas=(0.9,0.999), eps = 1e-8):
        super().__init__(params, dict(lr=lr, betas=betas, weight_decay=weight_decay, eps=eps))

    def step(self, closure: Optional[Callable] = None):
        """Take one update step; returns ``closure()`` if a closure was given, else None."""
        loss = closure() if closure is not None else None

        for group in self.param_groups:
            lr, eps, weight_decay = group["lr"], group["eps"], group["weight_decay"]
            beta1, beta2 = group["betas"]

            for param in group["params"]:
                if param.grad is None:
                    continue

                param_state = self.state[param]
                t = param_state["t"] if "t" in param_state else 1
                first = param_state["m"] if "m" in param_state else torch.zeros_like(param.data)
                second = param_state["v"] if "v" in param_state else torch.zeros_like(param.data)

                grad = param.grad.data
                first = beta1 * first + (1-beta1) * grad
                second = beta2*second + (1-beta2) * grad**2

                # Bias correction is folded into the step size.
                lr_t = lr * math.sqrt((1-beta2**t)) / (1 - beta1**t)
                param.data -= lr_t * first / (torch.sqrt(second) + eps)
                # Decoupled decay: shrink the weights directly, using the base lr.
                param.data -= lr * weight_decay * param.data

                param_state["t"] = t + 1
                param_state["m"] = first
                param_state["v"] = second

        return loss

def lr_cosine_schedule(
    it: int,
    max_lr: float,
    min_lr: float,
    warmup_steps: int,
    cosine_cycle_iters: int,
):
    """Learning rate at iteration ``it``: linear warmup, cosine decay, then a floor.

    - ``it < warmup_steps``: rises linearly from 0 towards ``max_lr``.
    - ``warmup_steps <= it < cosine_cycle_iters``: cosine decay from ``max_lr`` to ``min_lr``.
    - ``it >= cosine_cycle_iters``: stays at ``min_lr``.
    """
    if it < warmup_steps:
        lr = max_lr * it / warmup_steps
    elif it >= cosine_cycle_iters:
        lr = min_lr
    else:
        # Fraction of the decay phase that has elapsed, in [0, 1].
        progress = (it - warmup_steps) / (cosine_cycle_iters - warmup_steps)
        assert 0 <= progress <= 1
        # Goes from 1 at the start of the decay to 0 at its end.
        cosine_weight = 0.5 * (1.0 + math.cos(math.pi * progress))
        lr = min_lr + cosine_weight * (max_lr - min_lr)
    return lr

def gradient_clipping(parameters: Iterable[torch.nn.Parameter], max_l2_norm: float, eps=1e-6) -> None:
    """Given a set of parameters, clip their combined gradients to have l2 norm at most max_l2_norm.

    Args:
        parameters (Iterable[torch.nn.Parameter]): collection of trainable parameters.
        max_l2_norm (float): a positive value containing the maximum l2-norm.

    The gradients of the parameters (parameter.grad) should be modified in-place.
    """
    grad_params = [p for p in parameters if p.grad is not None]
    l2norm = torch.sqrt(sum([torch.sum(p.grad **2) for p in grad_params]))
    if l2norm < max_l2_norm:
        return l2norm
    for p in grad_params:
        p.grad *= (max_l2_norm/ (l2norm + eps))
    return l2norm


def get_batch(
    dataset: npt.NDArray, batch_size: int, context_length: int, device: str
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample ``batch_size`` random windows of ``context_length`` tokens from ``dataset``.

    Returns ``(inputs, labels)``, both int64 tensors of shape
    ``(batch_size, context_length)`` on ``device``; ``labels`` is ``inputs``
    shifted one token to the right in the dataset.
    """
    # Exclusive upper bound: the label window needs one extra token past the start.
    max_start = len(dataset) - context_length
    starts = [np.random.randint(0, max_start) for _ in range(batch_size)]

    inputs = np.empty((batch_size, context_length), dtype=np.int64)
    labels = np.empty((batch_size, context_length), dtype=np.int64)
    for row, start in enumerate(starts):
        stop = start + context_length
        inputs[row] = dataset[start:stop]
        labels[row] = dataset[start + 1:stop + 1]

    return (torch.from_numpy(inputs).to(device), torch.from_numpy(labels).to(device))

def save_checkpoint(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    iteration: int,
    out: str | os.PathLike | BinaryIO | IO[bytes],
):
    """
    Given a model, optimizer, and an iteration number, serialize them to disk.

    Args:
        model (torch.nn.Module): Serialize the state of this model.
        optimizer (torch.optim.Optimizer): Serialize the state of this optimizer.
        iteration (int): Serialize this value, which represents the number of training iterations
            we've completed.
        out (str | os.PathLike | BinaryIO | IO[bytes]): Path or file-like object to serialize the model, optimizer, and iteration to.
    """
    checkpoint = {
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "iteration": iteration
    }
    torch.save(checkpoint, out)



def load_checkpoint(
    src: str | os.PathLike | BinaryIO | IO[bytes],
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
):
    """
    Given a serialized checkpoint (path or file-like object), restore the
    serialized state to the given model and optimizer.
    Return the number of iterations that we previously serialized in
    the checkpoint.

    Args:
        src (str | os.PathLike | BinaryIO | IO[bytes]): Path or file-like object to serialized checkpoint.
        model (torch.nn.Module): Restore the state of this model.
        optimizer (torch.optim.Optimizer): Restore the state of this optimizer.
    Returns:
        int: the previously-serialized number of iterations.
    """
    checkpoint = torch.load(src)
    model.load_state_dict(checkpoint["model_state"])
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    return checkpoint["iteration"]

In [None]:
import wandb
import random
import math

# Authenticate this session with Weights & Biases; prompts for an API key when
# none is stored yet.
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mkl4kennylee81[0m ([33mkl4kennylee81-kenneth-personal[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
import wandb
import torch
import numpy as np
import os
from dataclasses import dataclass, asdict
from typing import Literal, cast, Optional, BinaryIO, IO
import time
from tqdm.notebook import tqdm

# Baseline training configuration. Callers spread this dict and override
# individual entries before building TrainModelArgs.
DefaultTrainModelArgs = dict(
    # Model architecture
    vocab_size=10000,
    context_length=256,
    num_layers=4,
    d_model=512,
    num_heads=16,
    d_ff=1344,
    rope_theta=10000,

    # AdamW
    weight_decay=0.01,
    betas=(0.9, 0.999),

    # Learning-rate schedule (warmup followed by cosine decay)
    max_learning_rate=1e-3,
    min_learning_rate=1e-5,
    warmup_iters=2000,
    cosine_cycle_iters=40960,

    # Local copies of the encoded datasets and tokenizer files
    training_set="TinyStoriesV2-GPT4-train.npy",
    validation_set="TinyStoriesV2-GPT4-valid.npy",
    tokenizer_vocab="tinystories_vocab.json",
    tokenizer_merges="tinystories_merges.txt",

    # Training loop
    validation_step_interval=500,
    checkpoint_step_interval=10000,
    # 40960 steps x batch 32 x context 256 = 335,544,320 tokens at these defaults
    steps=40960,
    batch_size=32,
    gradient_clipping=1.0,

    # Google Drive checkpointing
    save_gdrive=False,
    load_model_gdrive="",

    # Use the GPU when one is visible, otherwise the CPU
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),

    # Weights & Biases logging
    wandb_active=False,
    wandb_run="",
)

@dataclass
class TrainModelArgs:
    # model args
    vocab_size: int = 10000
    context_length: int = 256
    num_layers: int = 4
    d_model: int = 512
    num_heads: int = 16
    d_ff: int = 1344
    rope_theta: Optional[int] = 10000

    # adamw args
    weight_decay: float = 0.01
    betas: tuple[float, float] = (0.9, 0.999)

    # Learning rate schedule
    max_learning_rate: float = 1e-3
    min_learning_rate: float = 1e-5
    warmup_iters: int = 2000
    cosine_cycle_iters: int = 40960

    # training loop args
    training_set: str | os.PathLike | BinaryIO | IO[bytes] = "TinyStoriesV2-GPT4-train.npy"
    validation_set: str | os.PathLike | BinaryIO | IO[bytes] = "TinyStoriesV2-GPT4-valid.npy"
    tokenizer_vocab: str | os.PathLike | BinaryIO | IO[bytes] = "tinystories_vocab.json"
    tokenizer_merges: str | os.PathLike | BinaryIO | IO[bytes] = "tinystories_merges.txt"

    validation_step_interval: int = 500
    checkpoint_step_interval: int = 10000
    steps: int = 40960
    batch_size: int = 32
    gradient_clipping: Optional[float] = 1.0

    # gdrive
    save_gdrive: bool = False
    load_model_gdrive: str = ""

    # wandb logging
    wandb_active: bool = False
    wandb_run: Optional[str] = ""

    # device
    device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)

class TrainModel:
    """Owns the model, optimizer and datasets for one run and drives the training loop.

    Relies on names defined in other cells of this notebook: ``Transformer``,
    ``AdamW``, ``Tokenizer``, ``get_batch``, ``cross_entropy``,
    ``lr_cosine_schedule``, ``gradient_clipping``, ``save_checkpoint``,
    ``load_checkpoint`` and the global ``drive_cs336_dir`` (Drive checkpoints).
    """

    def __init__(self, args: TrainModelArgs):
        self.args = args
        self.cur_step = 0
        self.model = Transformer(
            vocab_size=args.vocab_size,
            context_length=args.context_length,
            num_layers=args.num_layers,
            num_heads=args.num_heads,
            d_model=args.d_model,
            d_ff=args.d_ff,
            rope_theta=args.rope_theta,
            device=args.device
        )
        self.optimizer = AdamW(
            self.model.parameters(),
            lr=self.args.max_learning_rate,
            weight_decay=args.weight_decay,
            betas=args.betas
        )

        self.tokenizer = Tokenizer.from_files(args.tokenizer_vocab, args.tokenizer_merges, ["<|endoftext|>"])

        # Memory-map the token arrays so they are not read into RAM up front.
        self.training_set = np.load(self.args.training_set, mmap_mode='r')
        self.validation_set = np.load(self.args.validation_set, mmap_mode='r')

        if args.wandb_active and wandb.run:
            wandb.watch(self.model, log=cast(Literal["gradients", "parameters", "all"], "gradients"), log_freq=10)

    def _forward_loss(self, x: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Run the model under bfloat16 autocast and return the mean cross-entropy."""
        # torch.autocast needs the device *type* as a string ("cuda"/"cpu"),
        # while args.device may be a torch.device.
        device_type = torch.device(self.args.device).type
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            logits = self.model(x)
        # Flatten (batch, seq, vocab) -> (batch*seq, vocab) so every position is
        # scored against its own target.
        return cross_entropy(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1))

    def evaluate(self):
        """Estimate validation loss and perplexity from random validation batches.

        The number of batches is derived from 1/1000 of the combined train and
        validation token count, with a minimum of one batch.

        Returns:
            (avg_loss, perplexity) as scalar tensors.
        """
        self.model.eval()
        with torch.no_grad():
            total_loss = 0.0
            total_size = self.training_set.size + self.validation_set.size
            eval_size = total_size // 1000
            num_batches = eval_size // (self.args.batch_size * self.args.context_length)

            num_batches = max(1, num_batches)

            for _ in range(num_batches):
                x, label = get_batch(self.validation_set, self.args.batch_size, self.args.context_length, device=self.args.device)
                total_loss += self._forward_loss(x, label).item()

            avg_loss = torch.tensor(total_loss / num_batches)
            perplexity = avg_loss.exp()
            return avg_loss, perplexity

    def train(self):
        """Run the training loop from ``cur_step`` up to ``args.steps``.

        Optionally resumes from ``args.load_model_gdrive``, validates every
        ``validation_step_interval`` steps and on the last step, writes Drive
        checkpoints every ``checkpoint_step_interval`` steps when
        ``save_gdrive`` is set, and logs to wandb when it is active.
        """
        if self.args.load_model_gdrive != "":
            self.cur_step = load_checkpoint(self.args.load_model_gdrive, self.model, self.optimizer)

        valid_loss, valid_perplexity = self.evaluate()
        if self.args.wandb_active and wandb.run:
            wandb.log({"valid_loss": valid_loss.item(), "valid_perplexity": valid_perplexity.item()}, step=self.cur_step)

        pbar = tqdm(range(self.cur_step, self.args.steps))
        start_time = time.time()
        tokens_processed = 0

        for step in pbar:
            step_start_time = time.time()

            self.cur_step = step
            self.model.train()
            self.optimizer.zero_grad()

            lr = lr_cosine_schedule(
                step,
                self.args.max_learning_rate,
                self.args.min_learning_rate,
                self.args.warmup_iters,
                self.args.cosine_cycle_iters)
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr

            x, targets = get_batch(self.training_set, self.args.batch_size, self.args.context_length, device=self.args.device)
            loss = self._forward_loss(x, targets)
            loss.backward()
            l2norm = gradient_clipping(self.model.parameters(), self.args.gradient_clipping)
            self.optimizer.step()

            # Calculate metrics
            batch_tokens = x.shape[0] * x.shape[1]
            tokens_processed += batch_tokens
            elapsed_time = time.time() - start_time
            tokens_per_second = tokens_processed / elapsed_time if elapsed_time > 0 else 0
            dt = time.time() - step_start_time

            if self.args.save_gdrive and step % self.args.checkpoint_step_interval == 0 and step > 0 :
                os.makedirs(f'{drive_cs336_dir}/output', exist_ok=True)
                save_checkpoint(self.model, self.optimizer, step, f'{drive_cs336_dir}/output/checkpoint-{step}.pth')

            if (step % self.args.validation_step_interval == 0 and step > 0) or (step == self.args.steps-1):
                valid_loss, valid_perplexity = self.evaluate()

            pbar.set_postfix({
                "loss": f"{loss.item():.2f}",
                "valid_loss": f"{valid_loss.item():.2f}",
                "valid_perplexity": f"{valid_perplexity.item():.2f}",
            })

            if self.args.wandb_active and wandb.run:
                wandb.log({
                    "train_loss": loss.item(),
                    "train_perplexity": loss.exp().item(),
                    "valid_loss": valid_loss.item(),
                    "valid_perplexity": valid_perplexity.item(),
                    # Log a plain number rather than a tensor.
                    "grad_norm": float(l2norm),
                    "lr": lr,
                    "tokens_per_second": tokens_per_second,
                    "step_time_seconds": dt,
                    "gpu_memory_gb": torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0,
                    "tokens_processed": tokens_processed,
                }, step=step)

        # The loop variable is unbound when the loop body never ran (e.g. resuming
        # a finished run), so name the final checkpoint after cur_step instead.
        final_step = self.cur_step

        # Save final checkpoint
        if self.args.save_gdrive:
            # No periodic checkpoint may have run, so make sure the folder exists.
            os.makedirs(f'{drive_cs336_dir}/output', exist_ok=True)
            save_checkpoint(self.model, self.optimizer, final_step, f'{drive_cs336_dir}/output/checkpoint-{final_step}.pth')

        if self.args.wandb_active and wandb.run:
            local_checkpoint_path = f'{self.args.wandb_run}-checkpoint-{final_step}.pth'
            save_checkpoint(self.model, self.optimizer, final_step, local_checkpoint_path)

            artifact = wandb.Artifact(f"{self.args.wandb_run}-checkpoint_{final_step}", type="model")
            artifact.add_file(local_checkpoint_path)
            wandb.log_artifact(artifact)

            # Clean up local file after uploading to wandb
            os.remove(local_checkpoint_path)


In [None]:
import wandb
from dataclasses import asdict

# Sweep definition handed to wandb: an exhaustive grid over batch size,
# with runs ranked by the lowest logged 'valid_loss'.
sweep_config = dict(
    method='grid',
    metric=dict(name='valid_loss', goal='minimize'),
    parameters=dict(
        batch_size=dict(values=[32, 64, 128, 256]),
    ),
)

def train_sweep():
    """Run a single trial of the batch-size sweep (callback for ``wandb.agent``).

    Reads ``batch_size`` from the sweep-provided wandb config, derives the
    step counts and learning rates from it on top of ``DefaultTrainModelArgs``,
    trains a ``TrainModel`` and closes the wandb run when done.

    The run is opened as a context manager so it is finished even when
    training raises or is interrupted, instead of being left open for the
    next ``wandb.init`` call in the same kernel.
    """
    # NOTE(review): the group name says "lr-sweep" although the swept
    # parameter is batch_size; left unchanged so existing runs are not
    # regrouped -- confirm whether a rename is wanted.
    with wandb.init(group="Tinystories-lr-sweep", force=True) as run:
        config = run.config
        batch_size = config.batch_size

        # Learning rates are scaled linearly with batch size, relative to a
        # base batch size of 16.
        lr_scale = batch_size // 16

        # Create training arguments - only override what's different from defaults
        train_args = {
            **DefaultTrainModelArgs,

            # Fixed budget of 160,000 sequences regardless of batch size
            # (40,960,000 tokens processed).
            "steps": 160000 // batch_size,
            "batch_size": batch_size,
            # The training loop uses this value as a modulus; max(1, ...) keeps
            # it non-zero for batch sizes above 4096.
            "validation_step_interval": max(1, 4096 // batch_size),
            "cosine_cycle_iters": 160000 // batch_size,
            "warmup_iters": 8000 // batch_size,
            "max_learning_rate": 1e-3 * lr_scale,
            "min_learning_rate": 1e-5 * lr_scale,

            # wandb settings - always override for sweep
            "wandb_active": True,
            "wandb_run": f"batch_size_{batch_size}",
        }

        # Initialize and run training
        trainer = TrainModel(TrainModelArgs(**train_args))
        # Record the fully resolved arguments (defaults included) on the run.
        config.update(asdict(trainer.args))
        run.name = trainer.args.wandb_run
        trainer.train()

# Register the sweep definition with wandb; agents attach to the returned id
sweep_id = wandb.sweep(
    entity="kl4kennylee81-kenneth-personal",
    project="cs336-llm-assignment1",
    sweep=sweep_config,
)

# Emit the summary in one call; each argument lands on its own line
print(
    "Sweep created successfully!",
    f"Sweep ID: {sweep_id}",
    "Project: cs336-llm-assignment1",
    f"wandb agent {sweep_id}",
    sep="\n",
)

# Run the sweep agent
# wandb.agent(sweep_id, train_sweep, count=8)

Create sweep with ID: i8mzslpy
Sweep URL: https://wandb.ai/kl4kennylee81-kenneth-personal/cs336-llm-assignment1/sweeps/i8mzslpy
Sweep created successfully!
Sweep ID: i8mzslpy
Project: cs336-llm-assignment1
wandb agent i8mzslpy


[34m[1mwandb[0m: Agent Starting Run: iwalby0f with config:
[34m[1mwandb[0m: 	batch_size: 16




  0%|          | 0/10000 [00:00<?, ?it/s]

0,1
gpu_memory_gb,▂▄▁▄▃▄▂▅▃▄▁▇▁▃▃█▂▅▃▂▃▃▂▃▅▄▃▁▄▂▄▄▄▂▅▂▁▄▂▃
grad_norm,█▂▃▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,▅▆▇██████▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁
step_time_seconds,▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇▁█▁▁█▁▁▇▁▁▁▁▁▁▁▁▁▁▁▂▁
tokens_per_second,▁▁▄██▃▄▂▃▄▅▅▂▄▁▁▁▂▃▂▂▃▁▂▂▃▂▁▂▂▂▂▂▁▂▂▂▂▂▁
tokens_processed,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇████
train_loss,█▄▃▃▃▃▃▃▂▃▂▂▂▂▂▂▂▁▂▂▁▂▁▂▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁
train_perplexity,█▄▄▄▂▂▂▂▂▂▂▂▂▂▂▁▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
valid_loss,█▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
valid_perplexity,███▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
gpu_memory_gb,0.54792
grad_norm,0.41117
lr,1e-05
step_time_seconds,0.09436
tokens_per_second,52447.0584
tokens_processed,40960000.0
train_loss,1.76822
train_perplexity,5.86039
valid_loss,1.68932
valid_perplexity,5.41581


[34m[1mwandb[0m: Agent Starting Run: 1ms0miee with config:
[34m[1mwandb[0m: 	batch_size: 32




  0%|          | 0/5000 [00:00<?, ?it/s]

0,1
gpu_memory_gb,▂▂▄▃▄▄█▁▃▆▃▆▅▁▃▂▃▅▃▂▄▅▁▂▃▄▁▄▄▂▂▃▂▂▃▃▄▂▃▅
grad_norm,█▆▅▅▄▃▄▄▄▃▃▄▄▂▂▃▂▂▂▂▃▂▃▃▂▂▂▂▂▂▂▂▁▁▂▁▂▂▂▂
lr,▅▇█████████▇▇▆▆▆▆▅▅▅▅▅▅▄▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁
step_time_seconds,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁
tokens_per_second,██▃▃▁▂▂▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
tokens_processed,▁▁▁▁▁▂▂▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss,█▇█▆▇▆▆▅▄▆▅▄▄▅▃▄▄▃▃▂▃▃▄▂▂▂▂▂▂▃▂▂▂▂▂▂▁▁▁▁
train_perplexity,█▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
valid_loss,███▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
valid_perplexity,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
gpu_memory_gb,0.71196
grad_norm,0.29603
lr,2e-05
step_time_seconds,0.14097
tokens_per_second,61168.39582
tokens_processed,40960000.0
train_loss,1.87473
train_perplexity,6.51907
valid_loss,1.80881
valid_perplexity,6.10321


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: 0q84cjq0 with config:
[34m[1mwandb[0m: 	batch_size: 64




  0%|          | 0/2500 [00:00<?, ?it/s]

0,1
gpu_memory_gb,▁▁▁▁█▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁█▁█▁▁▁▁█▁▁▁▁▁▁
grad_norm,▁▁▁▁▁▁▁▁▁▁▁▂▂▂▁▃▃▂▃█▅▂▃▂▂▂▄▂▅▂▇▂▃▇▂▃█▃▂▂
lr,▆████████▇▇▇▇▇▇▆▅▅▅▄▄▄▄▄▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁
step_time_seconds,▁▁▁▁▁█▁█▁▂██▁▁▁▇▁▁▂▂▂▁▂▇▁▂▁▂▁▁▁▁▂▂▁▁█▁▂█
tokens_per_second,█▂▂▃▃▂▃▂▂▂▁▂▂▂▁▂▂▂▁▁▁▂▁▂▁▂▁▁▂▁▁▁▂▁▁▁▁▁▁▁
tokens_processed,▁▁▁▁▂▂▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇████
train_loss,█▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_perplexity,█▇▅▅▆▅▆▄▅▆▇▅▆▃▄▃▄▄▃▄▂▃▃▃▃▃▂▂▂▂▂▁▂▂▂▂▂▂▂▁
valid_loss,█▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
valid_perplexity,██▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
gpu_memory_gb,1.04062
grad_norm,0.67908
lr,4e-05
step_time_seconds,0.23545
tokens_per_second,67020.07536
tokens_processed,40960000.0
train_loss,2.20698
train_perplexity,9.08824
valid_loss,2.26537
valid_perplexity,9.6347


[34m[1mwandb[0m: Agent Starting Run: fdwacwaa with config:
[34m[1mwandb[0m: 	batch_size: 128




  0%|          | 0/1250 [00:00<?, ?it/s]

0,1
gpu_memory_gb,▁▇▄█▁▄▁▄▄▁▁▁▄▁▄▄▁▁▁█▁▁▁▁▁▁▁▁▁▄▁▁▄▁▁▁▁▁▂▁
grad_norm,▂▁▁▁▁▁▁▁▁▁▂█▁▁▂▁▁▁▂▁▁▁▁▄▁▁▁▂▁▁▁▁▁▁▁▂▂▁▁▂
lr,▁▂▅▆▇████████▇▇▇▇▇▆▆▅▅▅▅▅▄▄▄▃▃▂▂▂▂▂▁▁▁▁▁
step_time_seconds,▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁▁▁▁▁▁█▁█▁▁▁▁█▁▁▁▁▁▁▁▁▁▂▂▁
tokens_per_second,█▆▃▂▁▁▂▂▂▂▂▂▁▂▁▂▁▂▂▂▁▂▂▁▁▁▂▂▁▂▁▁▁▂▁▁▁▁▁▁
tokens_processed,▁▂▂▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇███
train_loss,██▇▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_perplexity,█▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
valid_loss,███▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
valid_perplexity,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
gpu_memory_gb,1.6947
grad_norm,0.33957
lr,8e-05
step_time_seconds,0.41715
tokens_per_second,72234.29814
tokens_processed,40960000.0
train_loss,2.50557
train_perplexity,12.25052
valid_loss,2.45273
valid_perplexity,11.62008


[34m[1mwandb[0m: Agent Starting Run: oyd5jpgn with config:
[34m[1mwandb[0m: 	batch_size: 256




  0%|          | 0/625 [00:00<?, ?it/s]

0,1
gpu_memory_gb,▅▅▅▅▅▂▂▂▁▁▂▃▆▁█▅▇▂▂▅▃▅▅▅▂▆▆▇▅▂▃▂▇▂▂▂▂▇▂▅
grad_norm,▂▃▂▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▂▂▁▁▁▁▂▂▂█▁▁▁▁▁▂▁▁▁▁▂▁
lr,▅███████▇▇▇▇▆▆▆▆▅▅▅▅▄▄▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁
step_time_seconds,▁▂▂▁▂▂▂█▂▂▂▇▁▁▁▁▁▁▁▁▁▁▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
tokens_per_second,█▃▄▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
tokens_processed,▁▁▁▁▁▂▂▂▂▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇█████
train_loss,█▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▂▂▁▁▁▁▁▁▁▁▁
train_perplexity,█▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
valid_loss,█▅▃▃▃▃▃▃▂▂▂▂▂▂▂▂▃▃▂▂▂▂▂▂▂▁▁▁▁▂▂▂▂▂▁▁▁▁▁▁
valid_perplexity,██▂▂▂▂▁▁▁▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
gpu_memory_gb,3.00647
grad_norm,0.44879
lr,0.00016
step_time_seconds,0.75532
tokens_per_second,74893.84578
tokens_processed,40960000.0
train_loss,3.28608
train_perplexity,26.73779
valid_loss,3.28382
valid_perplexity,26.67759


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Sweep Agent: Exiting.


In [None]:
def generate_text(trainer, input_text, max_length, topk=50, temperature=0.5):
    """Generate a continuation of ``input_text`` using top-k sampling.

    Args:
        trainer: Object exposing ``tokenizer`` (with ``encode``/``decode``),
            ``model`` (callable on a ``(1, seq)`` tensor of token ids) and
            ``args`` (with ``device`` and ``context_length``).
        input_text: Prompt string; its encoding starts the sequence.
        max_length: Upper bound on the total number of tokens, prompt
            included. If the prompt is already that long, nothing is
            generated and the prompt is returned decoded.
        topk: Number of most probable tokens to sample from at each step;
            clamped to the vocabulary size.
        temperature: Strictly positive divisor applied to the logits before
            the softmax.

    Returns:
        The decoded full sequence: prompt plus every generated token,
        including the end-of-text token if one was sampled.

    Raises:
        ValueError: If ``temperature`` is not strictly positive or ``topk``
            is smaller than 1.
    """
    # Fail early with a clear message: a non-positive temperature otherwise
    # produces inf/NaN probabilities that only blow up inside multinomial.
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")
    if topk < 1:
        raise ValueError(f"topk must be >= 1, got {topk}")

    tokenizer = trainer.tokenizer
    model = trainer.model
    context_length = trainer.args.context_length

    # Encode input and add batch dimension
    prompt_ids = tokenizer.encode(input_text)
    x = torch.tensor(prompt_ids, dtype=torch.long).unsqueeze(0).to(trainer.args.device)

    # First id of the encoded end-of-text marker is used as the stop token.
    # Assumes the tokenizer maps '<|endoftext|>' to one special token -- TODO confirm.
    eot = tokenizer.encode('<|endoftext|>')
    eot_id = eot[0] if eot else None

    # Remember the current mode so generation does not silently leave a
    # model that was training in eval mode.
    was_training = getattr(model, "training", False)
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(max_length - x.size(1)):  # Limit iterations
                # Feed only the most recent context window to the model, but
                # keep the whole sequence in x so the returned text still
                # contains the prompt and all earlier tokens.
                window = x[:, -context_length:]

                # Get last-position logits and apply temperature
                logits = model(window)[:, -1, :] / temperature
                probs = torch.nn.functional.softmax(logits, dim=-1)

                # Top-k sampling; k cannot exceed the vocabulary size.
                k = min(topk, probs.size(-1))
                topk_probs, topk_indices = torch.topk(probs, k, dim=-1)
                # multinomial accepts unnormalized weights, so the top-k
                # probabilities need no explicit renormalization.
                ix = torch.multinomial(topk_probs, 1)
                next_token = torch.gather(topk_indices, -1, ix)

                # Append token
                x = torch.cat((x, next_token), dim=1)

                # Stop once the end-of-text token is generated
                if next_token.item() == eot_id:
                    break
    finally:
        if was_training:
            model.train()

    # Decode and return
    return tokenizer.decode(x[0].tolist())