In [1]:
"""
Contains the base Tokenizer class and a few common helper functions.
The base class also contains the (common) save/load functionality.
It would be possible to be a lot more strict about the interface and
e.g. isolating all regex/pattern parts to the RegexTokenizer, but
some concessions are made for simplicity.
"""
import unicodedata

# -----------------------------------------------------------------------------
# a few helper functions useful for both BasicTokenizer and RegexTokenizer

def get_stats(ids, counts=None):
    """
    Given a list of integers, return a dictionary of counts of consecutive pairs
    Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
    Optionally allows to update an existing dictionary of counts
    """
    counts = {} if counts is None else counts
    for pair in zip(ids, ids[1:]): # iterate consecutive elements
        counts[pair] = counts.get(pair, 0) + 1
    return counts


def merge(ids, pair, idx):
    """
    In the list of integers (ids), replace all consecutive occurrences
    of pair with the new integer token idx
    Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
    """
    newids = []
    i = 0
    while i < len(ids):
        # if not at the very last position AND the pair matches, replace it
        if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids

# first two helper functions...
def replace_control_characters(s: str) -> str:
    # we don't want to print control characters
    # which distort the output (e.g. \n or much worse)
    # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python/19016117#19016117
    # http://www.unicode.org/reports/tr44/#GC_Values_Table
    chars = []
    for ch in s:
        if unicodedata.category(ch)[0] != "C":
            chars.append(ch) # this character is ok
        else:
            chars.append(f"\\u{ord(ch):04x}") # escape
    return "".join(chars)

def render_token(t: bytes) -> str:
    # pretty print a token, escaping control characters
    s = t.decode('utf-8', errors='replace')
    s = replace_control_characters(s)
    return s

# -----------------------------------------------------------------------------
# the base Tokenizer class

class Tokenizer:
    """Base class for Tokenizers"""

    def __init__(self):
        # default: vocab size of 256 (all bytes), no merges, no patterns
        self.merges = {} # (int, int) -> int
        self.pattern = "" # str
        self.special_tokens = {} # str -> int, e.g. {'<|endoftext|>': 100257}
        self.vocab = self._build_vocab() # int -> bytes

    def train(self, text, vocab_size, verbose=False):
        # Tokenizer can train a vocabulary of size vocab_size from text
        raise NotImplementedError

    def encode(self, text):
        # Tokenizer can encode a string into a list of integers
        raise NotImplementedError

    def decode(self, ids):
        # Tokenizer can decode a list of integers into a string
        raise NotImplementedError

    def _build_vocab(self):
        # vocab is simply and deterministically derived from merges
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")
        return vocab

    def save(self, file_prefix):
        """
        Saves two files: file_prefix.vocab and file_prefix.model
        This is inspired (but not equivalent to!) sentencepiece's model saving:
        - model file is the critical one, intended for load()
        - vocab file is just a pretty printed version for human inspection only
        """
        # write the model: to be used in load() later
        model_file = file_prefix + ".model"
        with open(model_file, 'w') as f:
            # write the version, pattern and merges, that's all that's needed
            f.write("minbpe v1\n")
            f.write(f"{self.pattern}\n")
            # write the special tokens, first the number of them, then each one
            f.write(f"{len(self.special_tokens)}\n")
            for special, idx in self.special_tokens.items():
                f.write(f"{special} {idx}\n")
            # the merges dict
            for idx1, idx2 in self.merges:
                f.write(f"{idx1} {idx2}\n")
        # write the vocab: for the human to look at
        vocab_file = file_prefix + ".vocab"
        inverted_merges = {idx: pair for pair, idx in self.merges.items()}
        with open(vocab_file, "w", encoding="utf-8") as f:
            for idx, token in self.vocab.items():
                # note: many tokens may be partial utf-8 sequences
                # and cannot be decoded into valid strings. Here we're using
                # errors='replace' to replace them with the replacement char �.
                # this also means that we couldn't possibly use .vocab in load()
                # because decoding in this way is a lossy operation!
                s = render_token(token)
                # find the children of this token, if any
                if idx in inverted_merges:
                    # if this token has children, render it nicely as a merge
                    idx0, idx1 = inverted_merges[idx]
                    s0 = render_token(self.vocab[idx0])
                    s1 = render_token(self.vocab[idx1])
                    f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
                else:
                    # otherwise this is leaf token, just print it
                    # (this should just be the first 256 tokens, the bytes)
                    f.write(f"[{s}] {idx}\n")

    def load(self, model_file):
        """Inverse of save() but only for the model file"""
        assert model_file.endswith(".model")
        # read the model file
        merges = {}
        special_tokens = {}
        idx = 256
        with open(model_file, 'r', encoding="utf-8") as f:
            # read the version
            version = f.readline().strip()
            assert version == "minbpe v1"
            # read the pattern
            self.pattern = f.readline().strip()
            # read the special tokens
            num_special = int(f.readline().strip())
            for _ in range(num_special):
                special, special_idx = f.readline().strip().split()
                special_tokens[special] = int(special_idx)
            # read the merges
            for line in f:
                idx1, idx2 = map(int, line.split())
                merges[(idx1, idx2)] = idx
                idx += 1
        self.merges = merges
        self.special_tokens = special_tokens
        self.vocab = self._build_vocab()

In [2]:
"""
Minimal (byte-level) Byte Pair Encoding tokenizer.

Algorithmically follows along the GPT tokenizer:
https://github.com/openai/gpt-2/blob/master/src/encoder.py

Unlike BasicTokenizer:
- RegexTokenizer handles an optional regex splitting pattern.
- RegexTokenizer handles optional special tokens.
"""

import regex as re


# the main GPT text split patterns, see
# https://github.com/openai/tiktoken/blob/main/tiktoken_ext/openai_public.py
GPT2_SPLIT_PATTERN = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""


class RegexTokenizer(Tokenizer):

    def __init__(self, pattern=None):
        """
        - pattern: optional string to override the default (GPT-4 split pattern)
        - special_tokens: str -> int dictionary of special tokens
          example: {'<|endoftext|>': 100257}
        """
        super().__init__()
        self.pattern = GPT4_SPLIT_PATTERN if pattern is None else pattern
        self.compiled_pattern = re.compile(self.pattern)
        self.special_tokens = {}
        self.inverse_special_tokens = {}

    def train(self, text, vocab_size, verbose=False):
        assert vocab_size >= 256
        num_merges = vocab_size - 256

        # split the text up into text chunks
        text_chunks = re.findall(self.compiled_pattern, text)

        # input text preprocessing
        ids = [list(ch.encode("utf-8")) for ch in text_chunks]

        # iteratively merge the most common pairs to create new tokens
        merges = {} # (int, int) -> int
        vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes
        for i in range(num_merges):
            # count the number of times every consecutive pair appears
            stats = {}
            for chunk_ids in ids:
                # passing in stats will update it in place, adding up counts
                get_stats(chunk_ids, stats)
            # find the pair with the highest count
            pair = max(stats, key=stats.get)
            # mint a new token: assign it the next available id
            idx = 256 + i
            # replace all occurrences of pair in ids with idx
            ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids]
            # save the merge
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
            # prints
            if verbose:
                print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")

        # save class variables
        self.merges = merges # used in encode()
        self.vocab = vocab   # used in decode()

    def register_special_tokens(self, special_tokens):
        # special_tokens is a dictionary of str -> int
        # example: {"<|endoftext|>": 100257}
        self.special_tokens = special_tokens
        self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}

    def decode(self, ids):
        # given ids (list of integers), return Python string
        part_bytes = []
        for idx in ids:
            if idx in self.vocab:
                part_bytes.append(self.vocab[idx])
            elif idx in self.inverse_special_tokens:
                part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8"))
            else:
                raise ValueError(f"invalid token id: {idx}")
        text_bytes = b"".join(part_bytes)
        text = text_bytes.decode("utf-8", errors="replace")
        return text

    def _encode_chunk(self, text_bytes):
        # return the token ids
        # let's begin. first, convert all bytes to integers in range 0..255
        ids = list(text_bytes)
        while len(ids) >= 2:
            # find the pair with the lowest merge index
            stats = get_stats(ids)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            # subtle: if there are no more merges available, the key will
            # result in an inf for every single pair, and the min will be
            # just the first pair in the list, arbitrarily
            # we can detect this terminating case by a membership check
            if pair not in self.merges:
                break # nothing else can be merged anymore
            # otherwise let's merge the best pair (lowest merge index)
            idx = self.merges[pair]
            ids = merge(ids, pair, idx)
        return ids

    def encode_ordinary(self, text):
        """Encoding that ignores any special tokens."""
        # split text into chunks of text by categories defined in regex pattern
        text_chunks = re.findall(self.compiled_pattern, text)
        # all chunks of text are encoded separately, then results are joined
        ids = []
        for chunk in text_chunks:
            chunk_bytes = chunk.encode("utf-8") # raw bytes
            chunk_ids = self._encode_chunk(chunk_bytes)
            ids.extend(chunk_ids)
        return ids

    def encode(self, text, allowed_special="none_raise"):
        """
        Unlike encode_ordinary, this function handles special tokens.
        allowed_special: can be "all"|"none"|"none_raise" or a custom set of special tokens
        if none_raise, then an error is raised if any special token is encountered in text
        this is the default tiktoken behavior right now as well
        any other behavior is either annoying, or a major footgun
        """
        # decode the user desire w.r.t. handling of special tokens
        special = None
        if allowed_special == "all":
            special = self.special_tokens
        elif allowed_special == "none":
            special = {}
        elif allowed_special == "none_raise":
            special = {}
            assert all(token not in text for token in self.special_tokens)
        elif isinstance(allowed_special, set):
            special = {k: v for k, v in self.special_tokens.items() if k in allowed_special}
        else:
            raise ValueError(f"allowed_special={allowed_special} not understood")
        if not special:
            # shortcut: if no special tokens, just use the ordinary encoding
            return self.encode_ordinary(text)
        # otherwise, we have to be careful with potential special tokens in text
        # we handle special tokens by splitting the text
        # based on the occurrence of any exact match with any of the special tokens
        # we can use re.split for this. note that surrounding the pattern with ()
        # makes it into a capturing group, so the special tokens will be included
        special_pattern = "(" + "|".join(re.escape(k) for k in special) + ")"
        special_chunks = re.split(special_pattern, text)
        # now all the special characters are separated from the rest of the text
        # all chunks of text are encoded separately, then results are joined
        ids = []
        for part in special_chunks:
            if part in special:
                # this is a special token, encode it separately as a special case
                ids.append(special[part])
            else:
                # this is an ordinary sequence, encode it normally
                ids.extend(self.encode_ordinary(part))
        return ids

In [3]:
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
import pandas as pd
import os
import inspect


class DataLoader:
    def __init__(self, B, T):
        self.B = B
        self.T = T

        vocab_size = 488
        tokenizer = RegexTokenizer()
        df = pd.read_csv("newspeak.csv")
        text = "".join([v for v in df["Newspeak Translation"]])

        if not os.path.exists("tok400.model"):
            tokenizer.train(text, vocab_size=vocab_size, verbose=True) # 9746 is the size of the vocabulary
            tokenizer.save("tok400") # writes tok32k.model and tok32k.vocab
        
        tokenizer.load("tok400.model") # loads the model back from disk
        self.tokens = torch.tensor(tokenizer.encode(text), dtype=torch.long)
        print("loaded {len(self.tokens)} tokens")
        print("1 epoch = {len(self.tokens) // (B * T)} batches")

        self.curent_idx = 0

    def next_batch(self):
        # example batch of 4 sequences of length 6. we'll use the first 24 tokens as input, 
        # and the last 24 tokens as labels. this is a common trick in language modeling, 
        # where we want to predict the next token in the sequence.
        B, T = self.B, self.T
        buf = self.tokens[self.curent_idx:self.curent_idx + B * T + 1]
        x = (buf[:-1]).view(B, T)
        y = (buf[1:]).view(B, T)

        self.curent_idx += B * T
        if self.curent_idx + B * T > len(self.tokens):
            self.curent_idx = 0
        return x, y


@dataclass
class Config:
    vocab_size: int = 512
    n_embed: int = 384
    n_heads: int = 6
    n_layers: int = 6
    context_length: int = 256
    dropout: float = 0.2
    bias: bool = True 


class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embed),
            wpe = nn.Embedding(config.context_length, config.n_embed),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layers)]),
            ln_f = nn.LayerNorm(config.n_embed, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embed, config.vocab_size, bias=False)

        self.transformer.wte.weight = self.lm_head.weight # tie weights | weight sharing scheme

        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                # this is the projection layer of the attention head
                # we want to initialize it with the same weights as the embedding layer
                # multiply n_layers by 2, because we have 2 additive connections in the block
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layers))

        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def _init_weights(self, module):
        if isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
    
    def get_num_params(self, non_embedding=True):
        """
        Returns the number of parameters in the model.
        If non_embedding is True, it excludes the embedding layer.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    @classmethod
    def from_pretrained(cls, model_name):
        # load pretrained model from huggingface
        assert model_name in ["gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"]
        from transformers import GPT2LMHeadModel

        config_args = {
            "gpt2": {"n_embed": 768, "n_heads": 12, "n_layers": 12},
            "gpt2-medium": {"n_embed": 1024, "n_heads": 16, "n_layers": 24},
            "gpt2-large": {"n_embed": 1280, "n_heads": 20, "n_layers": 36},
            "gpt2-xl": {"n_embed": 1600, "n_heads": 25, "n_layers": 48},
        }[model_name]
        config_args["context_length"] = 1024
        config_args["vocab_size"] = 50257
        config = Config(**config_args)
        model = GPT(config)
        sd = model.state_dict()
        sd_keys = [k for k in sd.keys() if not k.endswith(".attn.bias")]

        model_hf = GPT2LMHeadModel.from_pretrained(model_name)
        sd_hf = model_hf.state_dict()
        # remove the extra keys from the huggingface model
        sd_keys_hf = [k for k in sd_hf.keys() if not k.endswith(".attn.bias") or not k.endswith(".attn.masked_bias")]
        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
        for k in sd_keys_hf:
            if k not in sd_keys:
                print(f"Warning: {k} not in model state dict")
        assert len(sd_keys_hf) == len(sd_keys), f"Mismatch in number of keys: {len(sd_keys_hf)} != {len(sd_keys)}"
        for k in sd_keys_hf:
            if any(k.endswith(w) for w in transposed):
                assert sd_hf[k].shape[::-1] == sd[k].shape, f"Shape mismatch for {k}: {sd_hf[k].shape} != {sd[k].shape}"
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].t())
        
            else:
                assert sd_hf[k].shape == sd[k].shape, f"Shape mismatch for {k}: {sd_hf[k].shape} != {sd[k].shape}"
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])

        return model
    
    def forward(self, idx, targets=None):
        B, T = idx.shape # B, T
        # make sure the input is within the context length
        assert T <= self.config.context_length, f"Cannot forward, model block size is {self.config.context_length}, but input is {T}"
        # forward the token embeddings and position embeddings
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device) # shape (T)
        pos_emb = self.transformer.wpe(pos) # positional embedding of shape (T, n_embed)
        tok_emb = self.transformer.wte(idx) # token embedding of shape (B, T, n_embed)
        x = tok_emb + pos_emb
        # forward the transformer blocks
        for block in self.transformer.h:
            x = block(x)
        # forward the final layer norm and the lm head
        x = self.transformer.ln_f(x)

        if targets is not None:
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            logits = self.lm_head(x[:, [-1], :]) # shape (B, 1, vocab_size)
            loss = None
            
        return logits, loss

    def configure_optimizers(self, weight_decay, learning_rate, device):
        # start with all of the candidate parameters (that require grad)
        param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
        # create optim groups. any parameters that is 2D will be weight delayed, otherwise not
        # i.e. all weight tensors in matmuls + embeddings decay , all biases and layernorms dont
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": nodecay_params, "weight_decay": 0.0},
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"number of decay params: {num_decay_params / 1e6:.2f}M")
        print(f"number of nodecay params: {num_nodecay_params / 1e6:.2f}M")
        # create the optimizer and use fused version if available 
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and 'cuda' in device
        print("using fused adamw" if use_fused else "using adamw")
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
        return optimizer


class LayerNorm(nn.Module):
    """LayerNorm with optioanl bias"""

    def __init__(self, n_dim, bias):
        super().__init__()
        self.weigth = nn.Parameter(torch.ones(n_dim))
        self.bias = nn.Parameter(torch.zeros(n_dim)) if bias else None

    def forward(self, x):
        return F.layer_norm(x, self.weigth.shape, self.weigth, self.bias, 1e-5)
    

class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attn = CausalSelfAttention(config)
        self.mlp = FeedForward(config)
        self.ln_1 = nn.LayerNorm(config.n_embed, bias=config.bias)
        self.ln_2 = nn.LayerNorm(config.n_embed, bias=config.bias)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
    

class FeedForward(nn.Module):
    def __init__(self, config):   
        super().__init__()
        
        self.c_fc = nn.Linear(config.n_embed, 4 * config.n_embed, bias=config.bias)
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = nn.Linear(4 * config.n_embed, config.n_embed, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x
    

class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embed % config.n_heads == 0, "n_embed must be divisible by n_heads"
        self.c_attn = nn.Linear(config.n_embed, 3 * config.n_embed)
        self.c_proj = nn.Linear(config.n_embed, config.n_embed)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_heads = config.n_heads
        self.n_embed = config.n_embed
        self.dropout = config.dropout
        self.register_buffer("bias", torch.tril(torch.ones(config.context_length, config.context_length)).view(1, 1, config.context_length, config.context_length))
        
    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.c_attn(x).split(self.n_embed, dim=2)
        k = k.view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2) # B, n_heads, T, C // n_heads
        q = q.view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2) # B, n_heads, T, C // n_heads
        v = v.view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2) # B, n_heads, T, C // n_heads
        
        # compute attention scores
        # att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
        # att = F.softmax(att, dim=-1)
        # y = att @ v # B, n_heads, T, T x B, n_heads, T, head_size  -> B, n_heads, T, head_size
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        y = y.transpose(1, 2).contiguous().view(B, T, C) # B, T, n_heads * head_size
        y = self.c_proj(y) # B, T, C
        return y
    


In [None]:
import time

device = 'cuda' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu')
print("Using device: ", device)

total_batch_size, B, T = 524288, 16, 1024 # Total Batch Size, Micro Batch Size, Sequence Length
assert total_batch_size % (B * T) == 0, f"total_batch_size {total_batch_size} must be divisible by B*T {B*T}"
grad_accumulation_steps = total_batch_size // (B * T)
print("total_batch_size: ", total_batch_size)
print("=> calculated grad_accumulation_steps: ", grad_accumulation_steps)

train_loader = DataLoader(B=4, T=32)
torch.set_float32_matmul_precision('medium') # set 'medium' for bfloat16, 'high' for tf32, 'highest' for float32

# loss to expect upon initialization is -log(1/vocab_size) = -log(1/488) = 6.2
model = GPT(Config())
#model.eval()
model.to(device)
if device not in ['mps', 'cpu']:
    model = torch.compile(model) # optimises model for the cuda device

max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_steps = 10
max_steps = 50

def lr_scheduler(it):
    # linear warmup for warmup_iters steps
    if it < warmup_steps:
        return max_lr * (it+1) / warmup_steps
    # if it > lr_decay_iters, return min_lr
    if it > max_steps:
        return min_lr
    # in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1 + math.cos(math.pi * decay_ratio)) # coeff starts at 1 and goes to 0
    return min_lr + coeff * (max_lr - min_lr)

# optimisation
optimizer = model.configure_optimizers(weight_decay=0.1, learning_rate=6e-4, device=device)
for step in range(max_steps):
    t0 = time.time()
    optimizer.zero_grad()
    loss_accum = 0.0
    for micro_step in range(grad_accumulation_steps):
        x, y = train_loader.next_batch()
        x, y = x.to(device), y.to(device)
        if device in ['mps', 'cpu']:
            # MPS does not support autocast yet
            logits, loss = model(x, y)
        else:
            # autocast automatically handles the precision and the device
            with torch.autocast(device_type=device, dtype=torch.bfloat16):
                logits, loss = model(x, y)
        loss = loss / grad_accumulation_steps
        loss_accum += loss.detach()
        loss.backward()
    # clips gradients to prevent exploding gradients caused by bad data or initialization
    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    # determine the learning rate
    lr = lr_scheduler(step)
    for param_group in optimizer.param_groups:
        # update the parameters
        param_group['lr'] = lr
    optimizer.step()
    # torch.cuda.synchronize(device=device)
    t1 = time.time()
    tokens_processed = (train_loader.B * train_loader.T) * grad_accumulation_steps
    tokens_per_second = tokens_processed / (t1 - t0)
    print(f"step: {step} | loss: {loss_accum.item():.6f} | lr: {lr:.4e} | norm: {norm:.4e} | time: {(t1 - t0)*1000:.2f}ms | tokens/s: {tokens_per_second:.2f}")

# TODO: 
# choose a larger dataset
# load the dataset from huggingface and preprocess it
# split dataset with train/val/test and load with dataloader separately
# implement validation step in train loop
# do the logging
# save the model to huggingface

import sys; sys.exit()
# inference
num_return_sequences = 5
max_length = 30

torch.manual_seed(42)
torch.cuda.manual_seed(42)
while x.size(1) < max_length:
    with torch.no_grad():
        logits = model(x)
        # get the last token
        logits = logits[:, -1, :] # B, T, C -> B, C
        # sample from the distribution
        probs = F.softmax(logits, dim=-1)
        topk_probs, topk_indices = torch.topk(probs, k=50, dim=-1)
        next_token = torch.multinomial(topk_probs, num_samples=1)
        next_token = torch.gather(topk_indices, -1, next_token)
        x = torch.cat((x, next_token), dim=1)
    
for i in range(num_return_sequences):
    tokens = x[i, :max_length].tolist()
    decoded = tokenizer.decode(tokens)
    print(f"> {decoded}")

Using device:  mps
total_batch_size:  524288
=> calculated grad_accumulation_steps:  32
loaded {len(self.tokens)} tokens
1 epoch = {len(self.tokens) // (B * T)} batches
number of parameters: 10.84M
number of decay params: 10.91M
number of nodecay params: 0.03M
using adamw


  from .autonotebook import tqdm as notebook_tqdm


step: 0 | loss: 6.284475 | lr: 6.0000e-05 | norm: 1.0211e+01 | time: 1419.85ms | tokens/s: 2884.82
step: 1 | loss: 5.584973 | lr: 1.2000e-04 | norm: 6.1113e+00 | time: 571.91ms | tokens/s: 7161.99
step: 2 | loss: 5.116330 | lr: 1.8000e-04 | norm: 4.4456e+00 | time: 544.29ms | tokens/s: 7525.44
step: 3 | loss: 4.736885 | lr: 2.4000e-04 | norm: 5.6074e+00 | time: 548.44ms | tokens/s: 7468.50
step: 4 | loss: 4.274448 | lr: 3.0000e-04 | norm: 1.8683e+01 | time: 553.03ms | tokens/s: 7406.50
step: 5 | loss: 3.855610 | lr: 3.6000e-04 | norm: 8.9580e+00 | time: 554.29ms | tokens/s: 7389.66
step: 6 | loss: 3.420907 | lr: 4.2000e-04 | norm: 1.1704e+01 | time: 557.43ms | tokens/s: 7347.98
step: 7 | loss: 3.091822 | lr: 4.8000e-04 | norm: 7.4358e+00 | time: 553.31ms | tokens/s: 7402.67
step: 8 | loss: 2.812276 | lr: 5.4000e-04 | norm: 8.0808e+00 | time: 550.07ms | tokens/s: 7446.38
step: 9 | loss: 2.536636 | lr: 6.0000e-04 | norm: 5.5387e+00 | time: 548.23ms | tokens/s: 7471.29
step: 10 | loss: 2.

SystemExit: 

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
