In [1]:
!git clone https://github.com/karpathy/nanoGPT.git
%cd nanoGPT
!pip install tiktoken

Cloning into 'nanoGPT'...
remote: Enumerating objects: 689, done.[K
remote: Total 689 (delta 0), reused 0 (delta 0), pack-reused 689 (from 1)[K
Receiving objects: 100% (689/689), 975.24 KiB | 16.25 MiB/s, done.
Resolving deltas: 100% (382/382), done.
/content/nanoGPT


In [2]:
%%writefile student_model.py
"""
RWKV-inspired Student model compatible with nanoGPT's training infrastructure.
GPU-optimized with vectorized WKV computation.
"""

import math
import inspect
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class StudentConfig:
  block_size: int = 1024
  vocab_size: int = 50304
  n_layer: int = 5
  n_embd: int = 128
  dropout: float = 0.1
  bias: bool = False


class RWKVBlock(nn.Module):
  def __init__(self, config, layer_id):
      super().__init__()
      dim = config.n_embd
      self.layer_id = layer_id

      self.ln1 = nn.LayerNorm(dim)
      self.time_decay = nn.Parameter(torch.ones(dim) * -5.0)
      self.time_first = nn.Parameter(torch.ones(dim) * math.log(0.3))

      self.time_mix_k = nn.Parameter(torch.ones(1, 1, dim) * 0.5)
      self.time_mix_v = nn.Parameter(torch.ones(1, 1, dim) * 0.5)
      self.time_mix_r = nn.Parameter(torch.ones(1, 1, dim) * 0.5)

      self.key = nn.Linear(dim, dim, bias=False)
      self.value = nn.Linear(dim, dim, bias=False)
      self.receptance = nn.Linear(dim, dim, bias=False)
      self.output = nn.Linear(dim, dim, bias=False)

      self.ln2 = nn.LayerNorm(dim)
      self.channel_mix_k = nn.Parameter(torch.ones(1, 1, dim) * 0.5)
      self.channel_mix_r = nn.Parameter(torch.ones(1, 1, dim) * 0.5)

      hidden = int(dim * 2.5)
      self.ffn_key = nn.Linear(dim, hidden, bias=False)
      self.ffn_value = nn.Linear(hidden, dim, bias=False)
      self.ffn_receptance = nn.Linear(dim, dim, bias=False)

      self.drop = nn.Dropout(config.dropout)

  def forward(self, x):
      B, T, D = x.shape

      residual = x
      x_norm = self.ln1(x)
      x_prev = F.pad(x_norm, (0, 0, 1, 0))[:, :-1, :]

      xk = x_norm * self.time_mix_k + x_prev * (1 - self.time_mix_k)
      xv = x_norm * self.time_mix_v + x_prev * (1 - self.time_mix_v)
      xr = x_norm * self.time_mix_r + x_prev * (1 - self.time_mix_r)

      k = self.key(xk)
      v = self.value(xv)
      r = torch.sigmoid(self.receptance(xr))

      wkv = self.wkv_compute(k, v)
      x = residual + self.drop(self.output(r * wkv))

      residual = x
      x_norm = self.ln2(x)
      x_prev = F.pad(x_norm, (0, 0, 1, 0))[:, :-1, :]

      xk = x_norm * self.channel_mix_k + x_prev * (1 - self.channel_mix_k)
      xr = x_norm * self.channel_mix_r + x_prev * (1 - self.channel_mix_r)

      k = torch.square(F.relu(self.ffn_key(xk)))
      kv = self.ffn_value(k)
      r = torch.sigmoid(self.ffn_receptance(xr))

      x = residual + self.drop(r * kv)
      return x

  def wkv_compute(self, k, v):
      """Vectorized WKV computation - GPU friendly."""
      B, T, D = k.shape
      w = torch.exp(-torch.exp(self.time_decay))  # (D,)
      u = torch.exp(self.time_first)  # (D,)

      ek = torch.exp(k)  # (B, T, D)
      ekv = ek * v  # (B, T, D)

      # Build powers of w for rescaling: w^0, w^1, ..., w^{T-1}
      t_idx = torch.arange(T, device=k.device, dtype=k.dtype)
      w_powers = w.unsqueeze(0) ** t_idx.unsqueeze(1)  # (T, D)
      w_inv_powers = 1.0 / (w_powers + 1e-10)  # (T, D)

      # Rescale to remove exponential decay: x_scaled[t] = w^{-t} * x[t]
      ekv_scaled = ekv * w_inv_powers.unsqueeze(0)  # (B, T, D)
      ek_scaled = ek * w_inv_powers.unsqueeze(0)  # (B, T, D)

      # Cumulative sum in scaled domain
      ekv_cumsum = torch.cumsum(ekv_scaled, dim=1)  # (B, T, D)
      ek_cumsum = torch.cumsum(ek_scaled, dim=1)  # (B, T, D)

      # Shift right to get "previous" accumulated state (exclusive prefix sum)
      zeros = torch.zeros(B, 1, D, device=k.device, dtype=k.dtype)
      ekv_cumsum_prev = torch.cat([zeros, ekv_cumsum[:, :-1, :]], dim=1)
      ek_cumsum_prev = torch.cat([zeros, ek_cumsum[:, :-1, :]], dim=1)

      # Rescale back by w^{t-1} (shifted powers)
      w_powers_prev = torch.cat([torch.ones(1, D, device=k.device, dtype=k.dtype), w_powers[:-1, :]], dim=0)
      a_prev = ekv_cumsum_prev * w_powers_prev.unsqueeze(0)  # (B, T, D)
      b_prev = ek_cumsum_prev * w_powers_prev.unsqueeze(0)  # (B, T, D)

      # Final WKV: (a_prev + u * ek * v) / (b_prev + u * ek + eps)
      numer = a_prev + u * ekv
      denom = b_prev + u * ek + 1e-8
      wkv = numer / denom

      return wkv


class Student(nn.Module):
  def __init__(self, config):
      super().__init__()
      self.config = config

      self.transformer = nn.ModuleDict(dict(
          wte = nn.Embedding(config.vocab_size, config.n_embd),
          wpe = nn.Embedding(config.block_size, config.n_embd),
          drop = nn.Dropout(config.dropout),
          h = nn.ModuleList([RWKVBlock(config, layer_id=i) for i in range(config.n_layer)]),
          ln_f = nn.LayerNorm(config.n_embd),
      ))
      self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
      self.transformer.wte.weight = self.lm_head.weight

      self.apply(self._init_weights)
      print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

  def get_num_params(self, non_embedding=True):
      n_params = sum(p.numel() for p in self.parameters())
      if non_embedding:
          n_params -= self.transformer.wpe.weight.numel()
      return n_params

  def _init_weights(self, module):
      if isinstance(module, nn.Linear):
          torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
          if module.bias is not None:
              torch.nn.init.zeros_(module.bias)
      elif isinstance(module, nn.Embedding):
          torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

  def forward(self, idx, targets=None):
      device = idx.device
      b, t = idx.size()
      assert t <= self.config.block_size
      pos = torch.arange(0, t, dtype=torch.long, device=device)

      tok_emb = self.transformer.wte(idx)
      pos_emb = self.transformer.wpe(pos)
      x = self.transformer.drop(tok_emb + pos_emb)

      for block in self.transformer.h:
          x = block(x)

      x = self.transformer.ln_f(x)

      if targets is not None:
          logits = self.lm_head(x)
          loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
      else:
          logits = self.lm_head(x[:, [-1], :])
          loss = None

      return logits, loss

  def crop_block_size(self, block_size):
      assert block_size <= self.config.block_size
      self.config.block_size = block_size
      self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])

  def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
      param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
      decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
      nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]

      optim_groups = [
          {'params': decay_params, 'weight_decay': weight_decay},
          {'params': nodecay_params, 'weight_decay': 0.0}
      ]

      num_decay_params = sum(p.numel() for p in decay_params)
      num_nodecay_params = sum(p.numel() for p in nodecay_params)
      print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
      print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")

      fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
      use_fused = fused_available and device_type == 'cuda'
      extra_args = dict(fused=True) if use_fused else dict()
      optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
      print(f"using fused AdamW: {use_fused}")

      return optimizer

  def estimate_mfu(self, fwdbwd_per_iter, dt):
      N = self.get_num_params()
      cfg = self.config
      flops_per_token = 6 * N
      flops_per_fwdbwd = flops_per_token * cfg.block_size
      flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
      flops_achieved = flops_per_iter * (1.0/dt)
      flops_promised = 312e12
      mfu = flops_achieved / flops_promised
      return mfu

  @torch.no_grad()
  def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
      for _ in range(max_new_tokens):
          idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
          logits, _ = self(idx_cond)
          logits = logits[:, -1, :] / temperature
          if top_k is not None:
              v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
              logits[logits < v[:, [-1]]] = -float('Inf')
          probs = F.softmax(logits, dim=-1)
          idx_next = torch.multinomial(probs, num_samples=1)
          idx = torch.cat((idx, idx_next), dim=1)
      return idx


Writing student_model.py


In [3]:
%%writefile config/train_shakespeare_student.py
out_dir = 'out-shakespeare-student'
eval_interval = 250
eval_iters = 200
log_interval = 10

always_save_checkpoint = False
wandb_log = False

dataset = 'shakespeare_char'
gradient_accumulation_steps = 1
batch_size = 64
block_size = 256

n_layer = 5
n_embd = 128
dropout = 0.1

model_type = 'student'

learning_rate = 1e-3
max_iters = 5000
lr_decay_iters = 5000
min_lr = 1e-4
beta2 = 0.99
warmup_iters = 100


Writing config/train_shakespeare_student.py


In [4]:
# Add Student model import and support
import fileinput
import sys

# Read train.py
with open('train.py', 'r') as f:
    content = f.read()

# Add import
content = content.replace(
    "from model import GPTConfig, GPT",
    "from model import GPTConfig, GPT\nfrom student_model import StudentConfig, Student"
)

# Add model_type default
content = content.replace(
    "bias = False # do we use bias inside LayerNorm and Linear layers?",
    "bias = False # do we use bias inside LayerNorm and Linear layers?\nmodel_type = 'gpt' # 'gpt' or 'student'"
)

# Add Student model init
old_init = '''if init_from == 'scratch':
    # init a new model from scratch
    print("Initializing a new model from scratch")
    # determine the vocab size we'll use for from-scratch training
    if meta_vocab_size is None:
        print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)")
    model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)'''

new_init = '''if init_from == 'scratch':
    # init a new model from scratch
    print("Initializing a new model from scratch")
    # determine the vocab size we'll use for from-scratch training
    if meta_vocab_size is None:
        print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)")
    model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304
    if model_type == 'student':
        student_args = {k: v for k, v in model_args.items() if k in ['n_layer', 'n_embd', 'block_size', 'vocab_size', 'dropout']}
        conf = StudentConfig(**student_args)
        model = Student(conf)
    else:
        gptconf = GPTConfig(**model_args)
        model = GPT(gptconf)'''

content = content.replace(old_init, new_init)

content = content.replace(
    "if iter_num % eval_interval == 0 and master_process:",
    "if iter_num > 0 and iter_num % eval_interval == 0 and master_process:"
)

with open('train.py', 'w') as f:
    f.write(content)

print("train.py patched!")


train.py patched!


In [11]:
%%writefile sample.py
"""
Sample from a trained model
"""
import os
import pickle
from contextlib import nullcontext
import torch
import tiktoken
from model import GPTConfig, GPT
from student_model import StudentConfig, Student

# -----------------------------------------------------------------------------
init_from = 'resume'
out_dir = 'out-shakespeare-student'
start = "\n"
num_samples = 3
max_new_tokens = 500
temperature = 0.8
top_k = 200
seed = 1337
device = 'cuda'
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
compile = False
exec(open('configurator.py').read())
# -----------------------------------------------------------------------------

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
device_type = 'cuda' if 'cuda' in device else 'cpu'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# model
if init_from == 'resume':
    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    model_args = checkpoint['model_args']

    # Detect model type
    model_type = checkpoint.get('model_type', None)
    if model_type is None:
        state_keys = list(checkpoint['model'].keys())
        model_type = 'student' if any('time_decay' in k for k in state_keys) else 'gpt'

    if model_type == 'student':
        student_args = {k: v for k, v in model_args.items() if k in ['n_layer', 'n_embd', 'block_size', 'vocab_size', 'dropout']}
        conf = StudentConfig(**student_args)
        model = Student(conf)
    else:
        gptconf = GPTConfig(**model_args)
        model = GPT(gptconf)

    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k,v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)

elif init_from.startswith('gpt2'):
    model = GPT.from_pretrained(init_from, dict(dropout=0.0))

model.eval()
model.to(device)
if compile:
    model = torch.compile(model)

# look for the meta pickle in case it is available in the dataset folder
load_meta = False
if init_from == 'resume' and 'config' in checkpoint and 'dataset' in checkpoint['config']:
    meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl')
    load_meta = os.path.exists(meta_path)
if load_meta:
    print(f"Loading meta from {meta_path}...")
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    stoi, itos = meta['stoi'], meta['itos']
    encode = lambda s: [stoi[c] for c in s]
    decode = lambda l: ''.join([itos[i] for i in l])
else:
    print("No meta.pkl found, assuming GPT-2 encodings...")
    enc = tiktoken.get_encoding("gpt2")
    encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
    decode = lambda l: enc.decode(l)

# encode the beginning of the prompt
if start.startswith('FILE:'):
    with open(start[5:], 'r', encoding='utf-8') as f:
        start = f.read()
start_ids = encode(start)
x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])

# run generation
with torch.no_grad():
    with ctx:
        for k in range(num_samples):
            y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
            print(decode(y[0].tolist()))
            print('---------------')

Overwriting sample.py


In [5]:
!python data/shakespeare_char/prepare.py

length of dataset in characters: 1,115,394
all the unique characters: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
vocab size: 65
train has 1,003,854 tokens
val has 111,540 tokens


# Train the Models

In [6]:
!python train.py config/train_shakespeare_student.py --device=cuda --compile=False --log_interval=10 --max_iters=1000

Overriding config with config/train_shakespeare_student.py:
out_dir = 'out-shakespeare-student'
eval_interval = 250
eval_iters = 200
log_interval = 10

always_save_checkpoint = False
wandb_log = False

dataset = 'shakespeare_char'
gradient_accumulation_steps = 1
batch_size = 64
block_size = 256

n_layer = 5
n_embd = 128
dropout = 0.1

model_type = 'student'

learning_rate = 1e-3
max_iters = 5000
lr_decay_iters = 5000
min_lr = 1e-4
beta2 = 0.99
warmup_iters = 100

Overriding: device = cuda
Overriding: compile = False
Overriding: log_interval = 10
Overriding: max_iters = 1000
tokens per iteration will be: 16,384
  self.setter(val)
found vocab_size = 65 (inside data/shakespeare_char/meta.pkl)
Initializing a new model from scratch
number of parameters: 0.83M
  scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
num decayed parameter tensors: 62, with 863,488 parameters
num non-decayed parameter tensors: 32, with 4,096 parameters
using fused AdamW: True
iter 0: loss 4.2015, tim

In [7]:
!python train.py config/train_shakespeare_char.py --device=cuda --compile=False --log_interval=10 --max_iters=1000

Overriding config with config/train_shakespeare_char.py:
# train a miniature character-level shakespeare model
# good for debugging and playing on macbooks and such

out_dir = 'out-shakespeare-char'
eval_interval = 250 # keep frequent because we'll overfit
eval_iters = 200
log_interval = 10 # don't print too too often

# we expect to overfit on this small dataset, so only save when val improves
always_save_checkpoint = False

wandb_log = False # override via command line if you like
wandb_project = 'shakespeare-char'
wandb_run_name = 'mini-gpt'

dataset = 'shakespeare_char'
gradient_accumulation_steps = 1
batch_size = 64
block_size = 256 # context of up to 256 previous characters

# baby GPT model :)
n_layer = 6
n_head = 6
n_embd = 384
dropout = 0.2

learning_rate = 1e-3 # with baby networks can afford to go a bit higher
max_iters = 5000
lr_decay_iters = 5000 # make equal to max_iters usually
min_lr = 1e-4 # learning_rate / 10 usually
beta2 = 0.99 # make a bit bigger because number of 

# Sample the Models

In [12]:
!python sample.py --out_dir=out-shakespeare-student --start="ROMEO:"

Overriding: out_dir = out-shakespeare-student
Overriding: start = ROMEO:
  self.setter(val)
number of parameters: 0.83M
Loading meta from data/shakespeare_char/meta.pkl...
ROMEO:
You shall you will be your some is to Barnar:
What then? what's he would yet will have?

MERCUTIO:
Why crown?

PRINCE EDWARD:
You shall be a come mile will, and it end.

LUCIO:
Your drop the countreater of you so. Ah, no mean.

KING RICHARD III:
Marry, let me love.

RICHARD:
Ay, sir, thy world that more not what evil sound,
For this cheek me strong of his guilt and thrust,
Is no worse my mother all of which Prince,
So murder some still is a hope of the hearth,
Which as I tear the such not carel
---------------
ROMEO:
Mine you shall my soul that you so you myself and
The world madam stand against the world.

DUKE VINCENTIO:
Why, castle. And you not peace thou ever
About over the pass'd of this moved by mind the heart;
Which of you to prove his heart, if I content me
As wantone her against the way soul.

Clown:


In [13]:
!python sample.py --out_dir=out-shakespeare-char --start="ROMEO:"

Overriding: out_dir = out-shakespeare-char
Overriding: start = ROMEO:
  self.setter(val)
number of parameters: 10.65M
Loading meta from data/shakespeare_char/meta.pkl...
ROMEO:
All the bride will the comple
And the way down will desire to a
heaven thee us are to be determitted with my father,
And I must be proof in heart my heart.

CLIFFORD:
In ministress in the son, and let him be joy
Well not peace thee won him from than speak you:
That I may perform to be the that that more not when evil
To the trouble hand him from of my heart
To known breathe son; if he must not hale got forget.

CAMILLO:
Why, so hard you saw his brother'd state: but knew and
His hand, and for hi
---------------
ROMEO:
One to the marriage of the measure,
Before the noble two that take the crown
But not us but the much friends.

DUKE VINCENTIO:
And let me stronged the country's death?

LUCIO:
For over some so under in meetimage and come,
Against his wife his view, or I change beauty against thee;
And she they long 