# Transformer in PyTorch

- source: https://github.com/karpathy/nanoGPT

In [1]:
import math
import inspect
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F

## Globals

In [2]:
@dataclass
class Config:
    context_len: int = 8
    vocab_size: int = 65
    n_embd: int = 32
    n_head: int = 4
    n_layer: int = 3
    dropout: float = 0.2
    bias: bool = False
    flash_attn: bool = False

In [3]:
torch.manual_seed(1337)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Running on device: ", device)

Running on device:  cpu


## Scratch Model

In [4]:
class SelfAttention(nn.Module):
    """ One head of self attention """

    def __init__(self, config):
        super().__init__()

        assert config.n_embd % config.n_head == 0
        self.head_size = config.n_embd // config.n_head
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        self.flash_attn = config.flash_attn

        # One transformation with K,Q,V and all heads combined
        self.lin_attn = nn.Linear(config.n_embd, 3 * config.n_embd, config.bias)
        self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        self.attn_dropout = nn.Dropout(config.dropout)
        self.proj_dropout = nn.Dropout(config.dropout)

        # Flash attention is fast on GPU but support is only in PyTorch >= 2.0
        if not config.flash_attn:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer("tril", torch.tril(torch.ones(config.context_len, config.context_len))
                                            .view(1, 1, config.context_len, config.context_len))

    def forward(self, x):
        B,T,C = x.shape

        # Calculate k, q, v and reshape to the separate the heads from the key-query space
        k, q, v = self.lin_attn(x).split(self.n_embd, dim=2) # (B, T, n_head * head_size)
        k = k.view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B,nh,T,hs)
        q = q.view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B,nh,T,hs)
        v = v.view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B,nh,T,hs)

        # Calculate the attention affinities (B,nh,T,hs) x (B,nh,hs,T) = (B,nh,T,T)
        if self.flash_attn:
            attn = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, droupout_p=self.dropout if self.training else 0, is_causal=True)
        else:
            attn = q @ k.transpose(-2, -1) * self.head_size**-0.5 # (B,nh,T,hs) x (B,nh,hs,T) = (B,nh,T,T)
            attn = attn.masked_fill(self.tril[:,:,:T,:T] == 0, float('-inf')) # (B,nh,T,T)
            attn = F.softmax(attn, dim=-1) # (B,nh,T,T)
            attn = self.attn_dropout(attn) # (B,nh,T,T)

            # Perform weighted aggregation of values
            attn = attn @ v # (B,nh,T,T) @ (B,nh,T,hs) --> (B,nh,T,hs)

        out = attn.transpose(1,2).contiguous().view(B, T, C) # Concat the heads
        out = self.proj(out)
        out = self.proj_dropout(out)

        return out

In [5]:
class FeedForward(nn.Module):
    """" Single layer neural network that acts on each token embedding indepentendly """

    def __init__(self, config):
        super().__init__()

        self.fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) # Inner layer has 4x the embedding size in the paper
        self.gelu = nn.GELU()
        self.proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) # Projection back into residual pathway
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.fc(x)
        x = self.gelu(x)
        x = self.proj(x)
        x = self.dropout(x)
        return x

In [6]:
class Block(nn.Module):
    """ Transformer block that alternates attention and computation """

    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd, bias=config.bias)
        self.sa = SelfAttention(config)
        self.ln2 = nn.LayerNorm(config.n_embd, bias=config.bias)
        self.ffwd = FeedForward(config)

    def forward(self, x):
        # Note, in original paper, layer norm was applied after the transformations
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

In [74]:
class Transformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.context_len is not None
        self.config = config

        # Each token is embedded into a vector
        self.token_embedding_table = nn.Embedding(config.vocab_size, config.n_embd)
        # Add a positional encoding embedding
        self.position_embedding_table = nn.Embedding(config.context_len, config.n_embd)
        # Dropout for embeddings
        self.dropout = nn.Dropout(config.dropout)
        # Set of transformer blocks
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        # Final layer norm
        self.ln_f = nn.LayerNorm(config.n_embd, bias=config.bias)
        # Bring the embeddings back to vocab_size at the output
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying
        self.token_embedding_table.weight = self.lm_head.weight

        # Init all weights
        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith('proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))


        # report number of parameters
        print("number of parameters: %.3fK" % (self.get_num_params()/1e3,))

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.position_embedding_table.weight.numel()
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        # idx and targets are both (B,T) tensors of ints
        device = idx.device
        B, T = idx.size()
        assert T <= self.config.context_len, f"Cannot forward sequence of length {T}, context_len is only {self.config.context_len}"

        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = self.dropout(tok_emb + pos_emb) # (B,T,C)
        for block in self.blocks:
            x = block(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)

        # Calculate the cross entropy loss for each example
        if targets == None:
            # Optimization to only predict the next word for inference
            logits = self.lm_head(x[:,[-1],:]) # (B,1,C)
            loss = None
        else:
            logits = self.lm_head(x) # (B,T,C)
            B, T, C = logits.shape
            logits = logits.view(B*T, C) # reshape for cross entropy loss
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # get the predicted next tokens
            logits, _ = self(idx[:,-self.config.context_len:])
            # focus only on the last time step for generation (already done in model)
            logits = logits[:, -1, :] / temperature # (B, C)
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to get probabilities of next token
            probs = F.softmax(logits, dim=1) # (B, C)
            # sample a new token from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sample to sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)

        return idx

## PyTorch Implementation

- Additional reference: https://pytorch.org/tutorials/beginner/transformer_tutorial.html

In [75]:
from torch.nn import TransformerEncoder, TransformerEncoderLayer

class TransformerModelPyTorch(nn.Module):

    def __init__(self, config):

        super().__init__()
        self.model_type = 'Transformer'
        assert config.vocab_size is not None
        assert config.context_len is not None
        self.config = config

        # Each token is embedded into a vector
        self.token_embedding_table = nn.Embedding(config.vocab_size, config.n_embd)
        # Add a positional encoding embedding
        self.position_embedding_table = nn.Embedding(config.context_len, config.n_embd)
        # Dropout for embeddings
        self.dropout = nn.Dropout(config.dropout)
        # Set of transformer blocks
        decoder_layer = TransformerEncoderLayer(config.n_embd, config.n_head, 4 * config.n_embd, config.dropout, 'gelu', batch_first=True, norm_first=True, bias=config.bias, device=device)
        self.transformer_decoder = TransformerEncoder(decoder_layer, config.n_layer)
        # Final layer norm
        self.ln_f = nn.LayerNorm(config.n_embd, bias=config.bias)
        # Bring the embeddings back to vocab_size at the output
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying
        self.token_embedding_table.weight = self.lm_head.weight

        # Init all weights
        self.init_weights()

        # report number of parameters
        print("number of parameters: %.3fK" % (self.get_num_params()/1e3,))

    def init_weights(self) -> None:
        initrange = 0.1
        self.token_embedding_table.weight.data.uniform_(-initrange, initrange)
        self.position_embedding_table.weight.data.uniform_(-initrange, initrange)
        self.lm_head.weight.data.uniform_(-initrange, initrange)

    def forward(self, idx, targets=None):
        # idx and targets are both (B,T) tensors of ints
        device = idx.device
        B, T = idx.size()
        assert T <= self.config.context_len, f"Cannot forward sequence of length {T}, context_len is only {self.config.context_len}"

        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = self.dropout(tok_emb + pos_emb) # (B,T,C)

        """Generate a square causal mask for the sequence. The masked positions are filled with float('-inf').
        Unmasked positions are filled with float(0.0).
        """
        src_mask = nn.Transformer.generate_square_subsequent_mask(x.size()[1]).to(device)
        x = self.transformer_decoder(x, src_mask, is_causal=True)
        x = self.ln_f(x) # (B,T,C)

        # Calculate the cross entropy loss for each example
        if targets == None:
            # Optimization to only predict the next word for inference
            logits = self.lm_head(x[:,[-1],:]) # (B,1,C)
            loss = None
        else:
            logits = self.lm_head(x) # (B,T,C)
            B, T, C = logits.shape
            logits = logits.view(B*T, C) # reshape for cross entropy loss
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.position_embedding_table.weight.numel()
        return n_params

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # get the predicted next tokens
            logits, _ = self(idx[:,-self.config.context_len:])
            # focus only on the last time step for generation (already done in model)
            logits = logits[:, -1, :] / temperature # (B, C)
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to get probabilities of next token
            probs = F.softmax(logits, dim=1) # (B, C)
            # sample a new token from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sample to sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)

        return idx

## Training Basic - Shakespeare

In [76]:
# Start with a dataset to train on. Use the tiny shakespeare dataset
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2024-06-05 16:50:28--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.5’


2024-06-05 16:50:28 (13.5 MB/s) - ‘input.txt.5’ saved [1115394/1115394]



In [77]:
# Training params
batch_size = 32
learning_rate = 1e-3
max_iters = 2000
eval_interval = 500
eval_iters = 200
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Running on device: ", device)

config = Config()
# config.flash_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
print("Flash attention enabled: ", config.flash_attn)

Running on device:  cpu
Flash attention enabled:  False


In [78]:
with open ('input.txt', 'r', encoding='utf-8)') as f:
    text = f.read()

# Define vocabulary
chars = sorted(list(set(text)))
vocab_size = len(chars)

# Create encoder and decoder
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for i,ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

# Data batch loading
def get_batch(split):
    # Generate small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - config.context_len, (batch_size,))
    x = torch.stack([data[i:i+config.context_len] for i in ix])
    y = torch.stack([data[i+1:i+config.context_len+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

# Estimate the loss from training and validation sets over a number of batches
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [80]:
# Training loop
# Define the model and send to device
model = Transformer(config)
m = model.to(device)

# Create the optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):

  # Every once in a while, evaluate the loss on the train and val sets
  if iter % eval_interval == 0:
    losses = estimate_loss()
    print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

  xb, yb = get_batch('train')

  logits, loss = model(xb, yb)
  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()

# Test model generation
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))

number of parameters: 39.168K
step 0: train loss 4.1752, val loss 4.1739
step 500: train loss 2.5423, val loss 2.5476
step 1000: train loss 2.4393, val loss 2.4251
step 1500: train loss 2.3876, val loss 2.3827

NIRG y byodromns the yaud, avin vkis glond tlond itos kendeicukig, Bre iumrt bemerecpy hyow.


ULor afdemou y
Ippakrt shah ire htharceweele o afoeaptememl:
Thisff I o alyo arit, bratis thabe aot ne herer ithie theallrdeee Cef, whe yazje mon.
Goxyurwen aNss bus rort,
'nan-Sonalds to kissergavion, and thedt ir eyis an:
Thit shi'y cirse tratitich, sfulfacs
Gicel ghyogh
Reand J kezroul gsre dand dlarror meardtt tthislo thiant is? I ald:
Puncow movore, thil: te wis:

That ceyd to sre ofrigeen a tislp


In [79]:
# Training loop
# Define the model and send to device
model = TransformerModelPyTorch(config)
m = model.to(device)

# Create the optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):

  # Every once in a while, evaluate the loss on the train and val sets
  if iter % eval_interval == 0:
    losses = estimate_loss()
    print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

  xb, yb = get_batch('train')

  logits, loss = model(xb, yb)
  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()

# Test model generation
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))



number of parameters: 39.168K
step 0: train loss 4.2980, val loss 4.2918
step 500: train loss 2.5799, val loss 2.5785
step 1000: train loss 2.4679, val loss 2.4451
step 1500: train loss 2.4312, val loss 2.4040

Tat far yeat ho re iprf clomUn ernel.he ancpected mopot'te shos mosor sereaf tope me nold se cerlens! thentll cons?

Ass,alt tristhend.


MANU:
ANur
Whiess nore avy,
TOTrel lavem.
AAURNAng apillo sveity ser anke comy ul bitoecy,
An vfum heen?
Mit.

Shis to ly, Sod sockond Meth sov hes as leyIr tor meri pestes sevilctelldokss h bed al fy.

OKThele ospe' ivord;
Tneafel helury cusoce: is yoouden? ginerewae, Cerenirst.

ILA ININF:
peat wemit onergind fyin Hiq thhast Loth notgomit gowroulachey ghird.


## Extra Nano GPT

### Model

In [12]:
class Transformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.context_len is not None
        self.config = config

        # Each token is embedded into a vector
        self.token_embedding_table = nn.Embedding(config.vocab_size, config.n_embd)
        # Add a positional encoding embedding
        self.position_embedding_table = nn.Embedding(config.context_len, config.n_embd)
        # Dropout for embeddings
        self.dropout = nn.Dropout(config.dropout)
        # Set of transformer blocks
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        # Final layer norm
        self.ln_f = nn.LayerNorm(config.n_embd, bias=config.bias)
        # Bring the embeddings back to vocab_size at the output
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying
        self.token_embedding_table.weight = self.lm_head.weight

        # Init all weights
        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith('proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))


        # report number of parameters
        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.position_embedding_table.weight.numel()
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        # idx and targets are both (B,T) tensors of ints
        device = idx.device
        B, T = idx.size()
        assert T <= self.config.context_len, f"Cannot forward sequence of length {T}, context_len is only {self.config.context_len}"

        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = self.dropout(tok_emb + pos_emb) # (B,T,C)
        for block in self.blocks:
            x = block(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)

        # Calculate the cross entropy loss for each example
        if targets == None:
            # Optimization to only predict the next word for inference
            logits = self.lm_head(x[:,[-1],:]) # (B,1,C)
            loss = None
        else:
            logits = self.lm_head(x) # (B,T,C)
            B, T, C = logits.shape
            logits = logits.view(B*T, C) # reshape for cross entropy loss
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # get the predicted next tokens
            logits, _ = self(idx[:,-self.config.context_len:])
            # focus only on the last time step for generation (already done in model)
            logits = logits[:, -1, :] / temperature # (B, C)
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to get probabilities of next token
            probs = F.softmax(logits, dim=1) # (B, C)
            # sample a new token from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sample to sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)

        return idx

    def crop_block_size(self, block_size):
        """
        Model surgery to decrease the block size if necessary.
        e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
        but want to use a smaller block size for some smaller, simpler model
        """
        assert block_size <= self.config.block_size
        self.config.block_size = block_size
        self.position_embedding_table.weight = nn.Parameter(self.position_embedding_table.weight[:block_size])
        for block in self.blocks:
            if hasattr(block.attn, 'bias'):
                block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]

    @classmethod
    def from_pretrained(cls, model_type, override_args=None):
        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
        override_args = override_args or {} # default to empty dict
        # only dropout can be overridden see more notes below
        assert all(k == 'dropout' for k in override_args)
        from transformers import GPT2LMHeadModel
        print("loading weights from pretrained gpt: %s" % model_type)

        # n_layer, n_head and n_embd are determined from model_type
        config_args = {
            'gpt2':         dict(n_layer=12, n_head=12, n_embd=768),  # 124M params
            'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
            'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
            'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
        }[model_type]
        print("forcing vocab_size=50257, block_size=1024, bias=True")
        config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
        config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
        config_args['bias'] = True # always True for GPT model checkpoints
        # we can override the dropout rate, if desired
        if 'dropout' in override_args:
            print(f"overriding dropout rate to {override_args['dropout']}")
            config_args['dropout'] = override_args['dropout']
        # create a from-scratch initialized minGPT model
        config = Config(**config_args)
        model = Transformer(config)
        sd = model.state_dict()
        sd_keys = sd.keys()
        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param

        # init a huggingface/transformers model
        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()

        # copy while ensuring all of the parameters are aligned and match in names and shapes
        sd_keys_hf = sd_hf.keys()
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
        # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
        # this means that we have to transpose these weights when we import them
        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
        for k in sd_keys_hf:
            if any(k.endswith(w) for w in transposed):
                # special treatment for the Conv1D weights we need to transpose
                assert sd_hf[k].shape[::-1] == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].t())
            else:
                # vanilla copy over the other parameters
                assert sd_hf[k].shape == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])

        return model

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f"using fused AdamW: {use_fused}")

        return optimizer

    def estimate_mfu(self, fwdbwd_per_iter, dt):
        """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
        # first estimate the number of flops we do per iteration.
        # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
        N = self.get_num_params()
        cfg = self.config
        L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.context_len
        flops_per_token = 6*N + 12*L*H*Q*T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        # express our flops throughput as ratio of A100 bfloat16 peak flops
        flops_achieved = flops_per_iter * (1.0/dt) # per second
        flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
        mfu = flops_achieved / flops_promised
        return mfu

### Training Full

In [13]:
"""
This training script can be run both on a single gpu in debug mode,
and also in a larger training run with distributed data parallel (ddp).

To run on a single GPU, example:
$ python train.py --batch_size=32 --compile=False

To run with DDP on 4 gpus on 1 node, example:
$ torchrun --standalone --nproc_per_node=4 train.py

To run with DDP on 4 gpus across 2 nodes, example:
- Run on the first (master) node with example IP 123.456.123.456:
$ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py
- Run on the worker node:
$ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py
(If your cluster does not have Infiniband interconnect prepend NCCL_IB_DISABLE=1)
"""

import os
import time
import math
import pickle
from contextlib import nullcontext

import numpy as np
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

# -----------------------------------------------------------------------------
# default config values designed to train a gpt2 (124M) on OpenWebText
# I/O
out_dir = 'out'
eval_interval = 500
log_interval = 500
eval_iters = 200
eval_only = False # if True, script exits right after the first eval
always_save_checkpoint = False # if True, always save a checkpoint after each eval
init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*'

# wandb logging
wandb_log = False # disabled by default
wandb_project = 'owt'
wandb_run_name = 'gpt2'

# data
dataset = 'input.txt' # 'openwebtext'
gradient_accumulation_steps = 5 * 8 # used to simulate larger batch sizes
batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size

# model
context_len: int = 8
n_embd: int = 32
n_head: int = 4
n_layer: int = 3
dropout: float = 0.2
bias: bool = False
flash_attn: bool = False

# adamw optimizer
learning_rate = 1e-3 # max learning rate
max_iters = 1000 # total number of training iterations
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0

# learning rate decay settings
decay_lr = True # whether to decay the learning rate
warmup_iters = 500 # how many steps to warm up for
lr_decay_iters = max_iters # should be ~= max_iters per Chinchilla
min_lr = learning_rate / 10 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla

# DDP settings
backend = 'nccl' # 'nccl', 'gloo', etc.

# system
device = 'cuda' if torch.cuda.is_available() else 'cpu' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
compile = True # use PyTorch 2.0 to compile the model to be faster

In [14]:
# various inits, derived attributes, I/O setup
ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
if ddp:
    init_process_group(backend=backend)
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device)
    master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
    seed_offset = ddp_rank # each process gets a different seed
    # world_size number of processes will be training simultaneously, so we can scale
    # down the desired gradient accumulation iterations per process proportionally
    assert gradient_accumulation_steps % ddp_world_size == 0
    gradient_accumulation_steps //= ddp_world_size
else:
    # if not ddp, we are running on a single gpu, and one process
    master_process = True
    seed_offset = 0
    ddp_world_size = 1

tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * context_len
print(f"tokens per iteration will be: {tokens_per_iter:,}")

if master_process:
    os.makedirs(out_dir, exist_ok=True)

torch.manual_seed(1337 + seed_offset)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast

# note: float16 data type will automatically use a GradScaler
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)


tokens per iteration will be: 3,840


In [17]:
# poor man's data loader
data_dir = os.path.join('data', dataset)
# def get_batch(split):
#     # We recreate np.memmap every batch to avoid a memory leak, as per
#     # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
#     if split == 'train':
#         data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
#     else:
#         data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
#     ix = torch.randint(len(data) - context_len, (batch_size,))
#     x = torch.stack([torch.from_numpy((data[i:i+context_len]).astype(np.int64)) for i in ix])
#     y = torch.stack([torch.from_numpy((data[i+1:i+1+context_len]).astype(np.int64)) for i in ix])
#     if device_type == 'cuda':
#         # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
#         x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
#     else:
#         x, y = x.to(device), y.to(device)
#     return x, y

with open ('input.txt', 'r', encoding='utf-8)') as f:
  text = f.read()

# Define vocabulary
chars = sorted(list(set(text)))
vocab_size = len(chars)

# Create encoder and decoder
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for i,ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

# Data batch loading
def get_batch(split):
  # Generate small batch of data of inputs x and targets y
  data = train_data if split == 'train' else val_data
  ix = torch.randint(len(data) - context_len, (batch_size,))
  x = torch.stack([data[i:i+context_len] for i in ix])
  y = torch.stack([data[i+1:i+context_len+1] for i in ix])
  x, y = x.to(device), y.to(device)
  return x, y

In [18]:
# init these up here, can override if init_from='resume' (i.e. from a checkpoint)
iter_num = 0
best_val_loss = 1e9

# attempt to derive vocab_size from the dataset
meta_path = os.path.join(data_dir, 'meta.pkl')
meta_vocab_size = None
if os.path.exists(meta_path):
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    meta_vocab_size = meta['vocab_size']
    print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")

In [19]:
# model init
model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, context_len=context_len,
                  bias=bias, vocab_size=None, dropout=dropout, flash_attn=flash_attn) # start with model_args from command line
if init_from == 'scratch':
    # init a new model from scratch
    print("Initializing a new model from scratch")
    # determine the vocab size we'll use for from-scratch training
    if meta_vocab_size is None:
        print(f"defaulting to vocab_size of GPT-2 to {vocab_size}")
    model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else vocab_size # 50304
    config = Config(**model_args)
    model = Transformer(config)
elif init_from == 'resume':
    print(f"Resuming training from {out_dir}")
    # resume training from a checkpoint.
    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    checkpoint_model_args = checkpoint['model_args']
    # force these config attributes to be equal otherwise we can't even resume training
    # the rest of the attributes (e.g. dropout) can stay as desired from command line
    for k in ['n_layer', 'n_head', 'n_embd', 'context_len', 'bias', 'vocab_size']:
        model_args[k] = checkpoint_model_args[k]
    # create the model
    config = Config(**model_args)
    model = Transformer(config)
    state_dict = checkpoint['model']
    # fix the keys of the state dictionary :(
    # honestly no idea how checkpoints sometimes get this prefix, have to debug more
    unwanted_prefix = '_orig_mod.'
    for k,v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    iter_num = checkpoint['iter_num']
    best_val_loss = checkpoint['best_val_loss']
elif init_from.startswith('gpt2'):
    print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
    # initialize from OpenAI GPT-2 weights
    override_args = dict(dropout=dropout)
    model = Transformer.from_pretrained(init_from, override_args)
    # read off the created config params, so we can store them into checkpoint correctly
    for k in ['n_layer', 'n_head', 'n_embd', 'context_len', 'bias', 'vocab_size']:
        model_args[k] = getattr(model.config, k)

# crop down the model block size if desired, using model surgery
if context_len < model.config.context_len:
    model.crop_context_len(context_len)
    model_args['context_len'] = context_len # so that the checkpoint will have the right value
model.to(device)

# initialize a GradScaler. If enabled=False scaler is a no-op
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

# optimizer
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
if init_from == 'resume':
    optimizer.load_state_dict(checkpoint['optimizer'])
checkpoint = None # free up memory

# compile the model
if compile:
    print("compiling the model... (takes a ~minute)")
    unoptimized_model = model
    model = torch.compile(model) # requires PyTorch 2.0

# wrap model into DDP container
if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])

Initializing a new model from scratch
defaulting to vocab_size of GPT-2 to 65
number of parameters: 0.04M
num decayed parameter tensors: 14, with 39,200 parameters
num non-decayed parameter tensors: 7, with 224 parameters
using fused AdamW: False
compiling the model... (takes a ~minute)




In [20]:
# helps estimate an arbitrarily accurate loss over either split using many batches
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            with ctx:
                logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

# learning rate decay scheduler (cosine with warmup)
def get_lr(it):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) if it > lr_decay_iters, return min learning rate
    if it > lr_decay_iters:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
    return min_lr + coeff * (learning_rate - min_lr)

# logging
if wandb_log and master_process:
    import wandb
    wandb.init(project=wandb_project, name=wandb_run_name, config=config)

In [21]:
# training loop
X, Y = get_batch('train') # fetch the very first batch
t0 = time.time()
local_iter_num = 0 # number of iterations in the lifetime of this process
raw_model = model.module if ddp else model # unwrap DDP container if needed
running_mfu = -1.0
while True:

    # determine and set the learning rate for this iteration
    lr = get_lr(iter_num) if decay_lr else learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # evaluate the loss on train/val sets and write checkpoints
    if iter_num % eval_interval == 0 and master_process:
        losses = estimate_loss()
        print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        if wandb_log:
            wandb.log({
                "iter": iter_num,
                "train/loss": losses['train'],
                "val/loss": losses['val'],
                "lr": lr,
                "mfu": running_mfu*100, # convert to percentage
            })
        if losses['val'] < best_val_loss or always_save_checkpoint:
            best_val_loss = losses['val']
            if iter_num > 0:
                checkpoint = {
                    'model': raw_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'model_args': model_args,
                    'iter_num': iter_num,
                    'best_val_loss': best_val_loss,
                    'config': config,
                }
                print(f"saving checkpoint to {out_dir}")
                torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
    if iter_num == 0 and eval_only:
        break

    # forward backward update, with optional gradient accumulation to simulate larger batch size
    # and using the GradScaler if data type is float16
    for micro_step in range(gradient_accumulation_steps):
        if ddp:
            # in DDP training we only need to sync gradients at the last micro step.
            # the official way to do this is with model.no_sync() context manager, but
            # I really dislike that this bloats the code and forces us to repeat code
            # looking at the source of that context manager, it just toggles this variable
            model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
        with ctx:
            logits, loss = model(X, Y)
            loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
        # immediately async prefetch next batch while model is doing the forward pass on the GPU
        X, Y = get_batch('train')
        # backward pass, with gradient scaling if training in fp16
        scaler.scale(loss).backward()
    # clip the gradient
    if grad_clip != 0.0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    # step the optimizer and scaler if training in fp16
    scaler.step(optimizer)
    scaler.update()
    # flush the gradients as soon as we can, no need for this memory anymore
    optimizer.zero_grad(set_to_none=True)

    # timing and logging
    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    if iter_num % log_interval == 0 and master_process:
        # get loss as float. note: this is a CPU-GPU sync point
        # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
        lossf = loss.item() * gradient_accumulation_steps
        if local_iter_num >= 5: # let the training loop settle a bit
            mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
            running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
        print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
    iter_num += 1
    local_iter_num += 1

    # termination conditions
    if iter_num > max_iters:
        break

if ddp:
    destroy_process_group()

No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'


step 0: train loss 4.1859, val loss 4.1853
iter 0: loss 4.1935, time 99632.64ms, mfu -100.00%
step 500: train loss 2.4877, val loss 2.4647
saving checkpoint to out
iter 500: loss 2.3937, time 799.66ms, mfu 0.00%
step 1000: train loss 2.3400, val loss 2.3360
saving checkpoint to out
iter 1000: loss 2.3475, time 1266.48ms, mfu 0.00%


### Benchmark

In [22]:
"""
A much shorter version of training for benchmarking
"""
import os
from contextlib import nullcontext
import numpy as np
import time
import torch

# -----------------------------------------------------------------------------
# model
context_len: int = 8
n_embd: int = 32
n_head: int = 4
n_layer: int = 3
dropout: float = 0.2
bias: bool = False
flash_attn: bool = False
vocab_size = 65

real_data = False
profile = True
seed = 1337
device = 'cuda' if torch.cuda.is_available() else 'cpu' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
compile = True # use PyTorch 2.0 to compile the model to be faster
# -----------------------------------------------------------------------------

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# data loading init
if real_data:
    dataset = 'openwebtext'
    data_dir = os.path.join('data', dataset)
    train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
    def get_batch(split):
        data = train_data # note ignore split in benchmarking script
        ix = torch.randint(len(data) - context_len, (batch_size,))
        x = torch.stack([torch.from_numpy((data[i:i+context_len]).astype(np.int64)) for i in ix])
        y = torch.stack([torch.from_numpy((data[i+1:i+1+context_len]).astype(np.int64)) for i in ix])
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
        return x, y
else:
    # alternatively, if fixed data is desired to not care about data loading
    x = torch.randint(vocab_size, (batch_size, context_len), device=device)
    y = torch.randint(vocab_size, (batch_size, context_len), device=device)
    get_batch = lambda split: (x, y)

# model init
model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, context_len=context_len,
                  bias=bias, vocab_size=vocab_size, dropout=dropout, flash_attn=flash_attn) # start with model_args from command line
config = Config(**model_args)
model = Transformer(config)

optimizer = model.configure_optimizers(weight_decay=1e-2, learning_rate=1e-4, betas=(0.9, 0.95), device_type=device_type)

if compile:
    print("Compiling model...")
    model = torch.compile(model) # pytorch 2.0

if profile:
    # useful docs on pytorch profiler:
    # - tutorial https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html
    # - api https://pytorch.org/docs/stable/profiler.html#torch.profiler.profile
    wait, warmup, active = 5, 5, 5
    num_steps = wait + warmup + active
    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
        schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1),
        on_trace_ready=torch.profiler.tensorboard_trace_handler('./bench_log'),
        record_shapes=False,
        profile_memory=False,
        with_stack=False, # incurs an additional overhead, disable if not needed
        with_flops=True,
        with_modules=False, # only for torchscript models atm
    ) as prof:

        X, Y = get_batch('train')
        for k in range(num_steps):
            with ctx:
                logits, loss = model(X, Y)
            X, Y = get_batch('train')
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
            lossf = loss.item()
            print(f"{k}/{num_steps} loss: {lossf:.4f}")

            prof.step() # notify the profiler at end of each step

else:

    # simple benchmarking
    torch.cuda.synchronize()
    for stage, num_steps in enumerate([10, 20]): # burnin, then benchmark
        t0 = time.time()
        X, Y = get_batch('train')
        for k in range(num_steps):
            with ctx:
                logits, loss = model(X, Y)
            X, Y = get_batch('train')
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
            lossf = loss.item()
            print(f"{k}/{num_steps} loss: {lossf:.4f}")
        torch.cuda.synchronize()
        t1 = time.time()
        dt = t1-t0
        mfu = model.estimate_mfu(batch_size * 1 * num_steps, dt)
        if stage == 1:
            print(f"time per iteration: {dt/num_steps*1000:.4f}ms, MFU: {mfu*100:.2f}%")

number of parameters: 0.04M
num decayed parameter tensors: 14, with 39,200 parameters
num non-decayed parameter tensors: 7, with 224 parameters
using fused AdamW: False
Compiling model...
0/15 loss: 4.1906
1/15 loss: 4.1784
2/15 loss: 4.1711
3/15 loss: 4.1773
4/15 loss: 4.1589
5/15 loss: 4.1610
6/15 loss: 4.1455
7/15 loss: 4.1403
8/15 loss: 4.1380
9/15 loss: 4.1252
10/15 loss: 4.1281
11/15 loss: 4.1289
12/15 loss: 4.1158
13/15 loss: 4.1070
14/15 loss: 4.1037


  warn("CUDA is not available, disabling CUDA profiling")
