In [None]:
from google.colab import drive

# Mount Google Drive at /content/drive; the project directory used by this notebook lives there.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
%cd /content/drive/MyDrive/University (M.S., Stanford)/CS 229S - Systems for ML/cs229s-nanoGPT-rmg

/content/drive/MyDrive/University (M.S., Stanford)/CS 229S - Systems for ML/cs229s-nanoGPT-rmg


In [None]:
# !pip install torch numpy transformers datasets tiktoken wandb tqdm memory-profiler torcheval

In [None]:
# !python data/wikitext/prepare.py

In [None]:
import os
import time
import math
import pickle
from contextlib import nullcontext

import numpy as np
import torch

from model import GPTConfig, GPT
from pruning import convert_to_prunable, compress_layers, PrunableLinear

# -----------------------------------------------------------------------------
# config values for this fine-tuning / pruning run on the 'wikitext' dataset.
# NOTE: with init_from='resume' the model-shape values below (n_layer, n_head, n_embd,
# block_size, bias) are overwritten by the values stored in the checkpoint.
# I/O
out_dir = 'out'
eval_interval = 100
log_interval = 1
eval_iters = 200
eval_only = False # if True, script exits right after the first eval
always_save_checkpoint = False # if True, always save a checkpoint after each eval
init_from = 'resume' # 'gpt2-medium' # 'scratch' or 'resume' or 'gpt2*'
# wandb logging
wandb_log = False # disabled by default
wandb_project = 'cs229s'
wandb_run_name = 'gpt2' # 'run' + str(time.time())
# data
dataset = 'wikitext'
gradient_accumulation_steps = 5 * 8 # used to simulate larger batch sizes
batch_size = 4 # if gradient_accumulation_steps > 1, this is the micro-batch size
block_size = 1024
# model
n_layer = 12
n_head = 12
n_embd = 768
dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
bias = False # do we use bias inside LayerNorm and Linear layers?
# adamw optimizer
learning_rate = 6e-4 # max learning rate
max_iters = 100 # number of training iterations run by each finetune() call
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
# learning rate decay settings
decay_lr = True # whether to decay the learning rate
warmup_iters = 2000 # how many steps to warm up for; get_lr() stays in linear warmup while iter_num < this
lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla
min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
# DDP settings
backend = 'nccl' # 'nccl', 'gloo', etc.
# system
device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
# -----------------------------------------------------------------------------
# Snapshot every int/float/bool/str global whose name has no leading underscore.
# Because this scans globals(), any such variable that already exists in the kernel
# when this cell runs is captured too, not only the settings above.
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
config = {k: globals()[k] for k in config_keys} # will be useful for logging
# -----------------------------------------------------------------------------


# single process on a single GPU (no DDP), so the seed needs no per-rank offset
seed_offset = 0
tokens_per_iter = gradient_accumulation_steps * batch_size * block_size
print(f"tokens per iteration will be: {tokens_per_iter:,}")

os.makedirs(out_dir, exist_ok=True)
torch.manual_seed(1337 + seed_offset)
# permit TF32 for CUDA matmuls and for cuDNN
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# torch.autocast takes the bare device family ('cuda' / 'cpu'), not e.g. 'cuda:0'
if 'cuda' in device:
    device_type = 'cuda'
else:
    device_type = 'cpu'

# translate the dtype name from the config into the torch dtype object
# note: float16 data type will automatically use a GradScaler
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]

# autocast context on GPU; a do-nothing context manager on CPU
if device_type == 'cpu':
    ctx = nullcontext()
else:
    ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# poor man's data loader
# The token files are memory-mapped read-only and interpreted as uint16 values;
# get_batch() slices windows out of them and casts each window to int64.
data_dir = os.path.join('data', dataset)
train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
def get_batch(split):
    """Sample one random batch of (input, target) token windows.

    `split` selects `train_data` when it equals 'train' and `val_data` otherwise.
    Returns two int64 tensors of shape (batch_size, block_size) on `device`;
    the target window is the input window shifted one token to the right.
    """
    source = val_data if split != 'train' else train_data
    starts = torch.randint(len(source) - block_size, (batch_size,))

    inputs, targets = [], []
    for start in starts:
        inputs.append(torch.from_numpy(source[start:start + block_size].astype(np.int64)))
        targets.append(torch.from_numpy(source[start + 1:start + 1 + block_size].astype(np.int64)))
    x = torch.stack(inputs)
    y = torch.stack(targets)

    if device_type != 'cuda':
        return x.to(device), y.to(device)
    # pinned host memory lets the copy to the GPU run asynchronously (non_blocking=True)
    x = x.pin_memory().to(device, non_blocking=True)
    y = y.pin_memory().to(device, non_blocking=True)
    return x, y

# starting values; the init_from='resume' branch replaces both from the checkpoint
iter_num, best_val_loss = 0, 1e9

# use the vocab size stored next to the dataset in meta.pkl, when that file exists
# (pickle.load executes arbitrary code on load, so only point this at trusted files)
meta_vocab_size = None
meta_path = os.path.join(data_dir, 'meta.pkl')
if os.path.exists(meta_path):
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    meta_vocab_size = meta['vocab_size']
    print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")

# model init
# Arguments handed to GPTConfig, seeded from the config cell; vocab_size is filled in
# by whichever branch below runs.
model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
                  bias=bias, vocab_size=None, dropout=dropout)
if init_from == 'scratch':
    # init a new model from scratch
    print("Initializing a new model from scratch")
    # determine the vocab size we'll use for from-scratch training
    if meta_vocab_size is None:
        print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)")
    model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)
elif init_from == 'resume':
    print(f"Resuming training from {out_dir}")
    # resume training from a checkpoint.
    ckpt_path = os.path.join(out_dir, 'ckpt_500_pruned_0.42.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    checkpoint_model_args = checkpoint['model_args']
    # force these config attributes to be equal otherwise we can't even resume training
    # the rest of the attributes (e.g. dropout) can stay as set in the config cell
    for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
        model_args[k] = checkpoint_model_args[k]
    # create the model
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)
    # convert before load_state_dict so the checkpoint is loaded into the prunable form
    # (presumably the checkpoint carries PrunableLinear entries such as masks; confirm in pruning.py)
    convert_to_prunable(model, device=device_type)
    state_dict = checkpoint['model']
    # strip the '_orig_mod.' prefix that some checkpoints carry on their keys
    unwanted_prefix = '_orig_mod.'
    for k,v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    iter_num = checkpoint['iter_num']
    best_val_loss = checkpoint['best_val_loss']
elif init_from.startswith('gpt2'):
    print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
    # initialize from OpenAI GPT-2 weights
    override_args = dict(dropout=dropout)
    model = GPT.from_pretrained(init_from, override_args)
    # read off the created config params, so we can store them into checkpoint correctly
    for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
        model_args[k] = getattr(model.config, k)
else:
    # Without this branch an unrecognised value skipped every case above and only
    # failed further down with a NameError on `model`.
    raise ValueError(f"unknown init_from {init_from!r}: expected 'scratch', 'resume' or 'gpt2*'")
# crop down the model block size if desired, using model surgery
if block_size < model.config.block_size:
    model.crop_block_size(block_size)
    model_args['block_size'] = block_size # so that the checkpoint will have the right value
model.to(device)

# initialize a GradScaler. If enabled=False scaler is a no-op
# (it is enabled only when dtype == 'float16')
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

# optimizer
# built by the model's configure_optimizers() from the config-cell hyperparameters
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
if init_from == 'resume':
    # restore the optimizer state that was saved in the checkpoint
    optimizer.load_state_dict(checkpoint['optimizer'])
checkpoint = None # free up memory

# helps estimate an arbitrarily accurate loss over either split using many batches
@torch.no_grad()
def estimate_loss():
    """Average the model loss over `eval_iters` random batches of each split.

    Puts the global `model` in eval mode for the measurement and back in train
    mode afterwards. Returns a dict with keys 'train' and 'val', each holding the
    mean loss as a 0-dim tensor.
    """
    model.eval()
    mean_losses = {}
    for split in ('train', 'val'):
        per_batch = torch.zeros(eval_iters)
        for idx in range(eval_iters):
            inputs, targets = get_batch(split)
            with ctx:
                _, batch_loss = model(inputs, targets)
            per_batch[idx] = batch_loss.item()
        mean_losses[split] = per_batch.mean()
    model.train()
    return mean_losses

# learning rate decay scheduler (cosine with warmup)
def get_lr(it):
    """Learning rate for iteration `it`: linear warmup, cosine decay, then a floor.

    Reads the globals learning_rate, warmup_iters, lr_decay_iters and min_lr.
    """
    if it >= warmup_iters:
        if it > lr_decay_iters:
            # past the decay horizon: stay at the floor
            return min_lr
        # fraction of the decay phase completed so far, within [0, 1]
        progress = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
        assert 0 <= progress <= 1
        # cosine weight runs from 1 (start of decay) down to 0 (end of decay)
        cosine_weight = 0.5 * (1.0 + math.cos(math.pi * progress))
        return min_lr + cosine_weight * (learning_rate - min_lr)
    # still warming up: scale linearly from 0 towards learning_rate
    return learning_rate * it / warmup_iters

# logging
# wandb is imported only when logging is switched on, so the package is not needed otherwise
if wandb_log:
    import wandb
    wandb.init(project=wandb_project, name=wandb_run_name, config=config)

# Drop the references to the loaded checkpoint so its memory can be reclaimed.
# `state_dict` is bound only by the init_from='resume' branch, so a plain `del`
# raised NameError for 'scratch' / 'gpt2*' runs; pop() tolerates a missing name.
globals().pop('state_dict', None)
globals().pop('checkpoint', None)
torch.cuda.empty_cache()

tokens per iteration will be: 163,840
Resuming training from out
number of parameters: 353.77M
num decayed parameter tensors: 98, with 354,501,632 parameters
num non-decayed parameter tensors: 194, with 321,536 parameters
using fused AdamW: True


In [None]:
# convert_to_prunable(model, device=device_type)

In [None]:
def finetune(model, iter_num, best_val_loss, eval_only=False):
    """Run up to `max_iters` optimizer steps on `model`, evaluating every `eval_interval`.

    Besides its arguments this relies on notebook globals: `optimizer`, `scaler`, `ctx`,
    `get_batch`, `get_lr`, `estimate_loss`, `model_args`, `config` and the config values.

    Args:
        model: the model to train; called as `model(X, Y)` and expected to return (logits, loss).
        iter_num: global iteration counter to continue from (drives LR schedule and eval cadence).
        best_val_loss: best validation loss seen so far.
        eval_only: if True, leave the loop on the first pass before any optimizer step.

    Returns:
        (iter_num, best_val_loss): the advanced counter and the possibly improved best loss.
        When it improves, best_val_loss is returned (and checkpointed) as a plain Python float.
    """
    X, Y = get_batch('train') # fetch the very first batch
    t0 = time.time()
    running_mfu = -1.0
    for local_iter_num in range(max_iters):

        # determine and set the learning rate for this iteration
        lr = get_lr(iter_num) if decay_lr else learning_rate
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # evaluate the loss on train/val sets and write checkpoints
        if iter_num % eval_interval == 0:
            losses = estimate_loss()
            print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
            if wandb_log:
                wandb.log({
                    "iter": iter_num,
                    "train/loss": losses['train'],
                    "val/loss": losses['val'],
                    "lr": lr,
                    "mfu": running_mfu*100, # convert to percentage
                })
            if losses['val'] < best_val_loss or always_save_checkpoint:
                # estimate_loss() returns 0-dim tensors; keep a plain float so the value that
                # is returned and checkpointed is not a tensor tied to a device
                best_val_loss = float(losses['val'])
                if iter_num > 0:
                    checkpoint = {
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'model_args': model_args,
                        'iter_num': iter_num,
                        'best_val_loss': best_val_loss,
                        'config': config,
                    }
                    print(f"saving checkpoint to {out_dir}")
                    torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
        if local_iter_num == 0 and eval_only:
            break

        # forward backward update, with optional gradient accumulation to simulate larger batch size
        # and using the GradScaler if data type is float16
        for micro_step in range(gradient_accumulation_steps):
            with ctx:
                logits, loss = model(X, Y)
                loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
            # immediately async prefetch next batch while model is doing the forward pass on the GPU
            X, Y = get_batch('train')
            # backward pass, with gradient scaling if training in fp16
            scaler.scale(loss).backward()
        # clip the gradient
        if grad_clip != 0.0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        # step the optimizer and scaler if training in fp16
        scaler.step(optimizer)
        scaler.update()
        # flush the gradients as soon as we can, no need for this memory anymore
        optimizer.zero_grad(set_to_none=True)

        # timing and logging
        t1 = time.time()
        dt = t1 - t0
        t0 = t1
        if iter_num % log_interval == 0:
            # get loss as float. note: this is a CPU-GPU sync point
            # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
            lossf = loss.item() * gradient_accumulation_steps
            if local_iter_num >= 5: # let the training loop settle a bit
                mfu = model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
                running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
            print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")

        iter_num += 1
    return iter_num, best_val_loss

In [None]:
# Train for max_iters iterations, carrying iter_num and best_val_loss forward.
iter_num, best_val_loss = finetune(model, iter_num, best_val_loss)

step 0: train loss 3.7232, val loss 3.7326
iter 0: loss 3.8476, time 138304.75ms, mfu -100.00%
iter 1: loss 3.5523, time 31620.78ms, mfu -100.00%
iter 2: loss 3.8951, time 31535.70ms, mfu -100.00%
iter 3: loss 3.6432, time 31759.12ms, mfu -100.00%
iter 4: loss 3.6798, time 31835.23ms, mfu -100.00%
iter 5: loss 3.7793, time 31738.59ms, mfu 4.01%
iter 6: loss 3.6765, time 31699.78ms, mfu 4.01%
iter 7: loss 3.6175, time 31734.08ms, mfu 4.01%
iter 8: loss 3.5956, time 31774.57ms, mfu 4.01%
iter 9: loss 3.7610, time 31749.35ms, mfu 4.01%
iter 10: loss 3.7035, time 31711.85ms, mfu 4.01%
iter 11: loss 3.5853, time 31625.33ms, mfu 4.01%
iter 12: loss 3.5244, time 31618.93ms, mfu 4.01%
iter 13: loss 3.5588, time 31592.44ms, mfu 4.02%
iter 14: loss 3.5339, time 31542.19ms, mfu 4.02%
iter 15: loss 3.6042, time 31568.39ms, mfu 4.02%
iter 16: loss 3.4638, time 31736.23ms, mfu 4.02%
iter 17: loss 3.4432, time 31840.00ms, mfu 4.02%
iter 18: loss 3.5116, time 31844.93ms, mfu 4.02%
iter 19: loss 3.4349

In [None]:
# torch.save(model.state_dict(), "./out/backup_100_it.pt")

In [None]:
# Evaluation pass only: with eval_only=True finetune() leaves its loop before any optimizer step.
iter_num, best_val_loss = finetune(model, iter_num, best_val_loss, eval_only=True)

step 100: train loss 2.9703, val loss 3.0115
saving checkpoint to out


In [None]:
# Prune with m=0.015 and print what prune() returns
# (presumably the overall fraction of pruned weights; confirm in the model's prune()).
pcnt_pruned = model.prune(m=0.015)
print(pcnt_pruned)

0.12141100003021132


In [None]:
#
# for name, module in model.named_modules():
#   if isinstance(module, PrunableLinear):
#     print(name, module.n_pruned)
# # model.transformer['h'][2].attn.c_attn.n_pruned

In [None]:
# checkpoint = {
#     'model': model.state_dict(),
#     'optimizer': optimizer.state_dict(),
#     'model_args': model_args,
#     'iter_num': iter_num,
#     'best_val_loss': best_val_loss,
#     'config': config,
# }

In [None]:
# pcnt_pruned = model.prune(p=0.1)
# losses = estimate_loss()
# print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

In [None]:
# Fine-tune for another max_iters iterations after the pruning step above.
iter_num, best_val_loss = finetune(model, iter_num, best_val_loss)

step 100: train loss 2.9889, val loss 3.0161
iter 100: loss 3.0003, time 138688.28ms, mfu -100.00%
iter 101: loss 2.9146, time 31454.99ms, mfu -100.00%
iter 102: loss 2.9878, time 31515.09ms, mfu -100.00%
iter 103: loss 2.8782, time 31510.14ms, mfu -100.00%
iter 104: loss 2.9141, time 31503.95ms, mfu -100.00%
iter 105: loss 2.9519, time 31454.09ms, mfu 4.05%
iter 106: loss 3.0229, time 31744.26ms, mfu 4.04%
iter 107: loss 2.9633, time 31592.72ms, mfu 4.04%
iter 108: loss 2.9199, time 31489.54ms, mfu 4.04%
iter 109: loss 3.0572, time 31549.54ms, mfu 4.04%
iter 110: loss 2.9676, time 31664.88ms, mfu 4.04%
iter 111: loss 2.9564, time 31685.03ms, mfu 4.04%
iter 112: loss 3.0131, time 31741.12ms, mfu 4.04%
iter 113: loss 3.0980, time 31670.20ms, mfu 4.03%
iter 114: loss 3.0155, time 31741.99ms, mfu 4.03%
iter 115: loss 3.0021, time 31737.56ms, mfu 4.03%
iter 116: loss 2.9743, time 31589.28ms, mfu 4.03%
iter 117: loss 2.9903, time 31544.30ms, mfu 4.03%
iter 118: loss 2.9818, time 31551.33ms,

In [None]:
# Show the current global iteration counter and the best validation loss so far.
print(iter_num, best_val_loss)

200 tensor(3.0115, device='cuda:0')


In [None]:
# Evaluate without training: eval_only=True makes finetune() break before the first optimizer step.
iter_num, best_val_loss = finetune(model, iter_num, best_val_loss, eval_only=True)

step 200: train loss 2.8916, val loss 2.9354


In [None]:
# Re-check the counter and best validation loss after the evaluation pass.
print(iter_num, best_val_loss)

200 tensor(2.9343)


In [None]:
# print(f"saving checkpoint to {out_dir}")
# torch.save({
#     'model': model.state_dict(),
#     'optimizer': optimizer.state_dict(),
#     'model_args': model_args,
#     'iter_num': iter_num,
#     'best_val_loss': best_val_loss,
#     'config': config,
# }, os.path.join(out_dir, 'ckpt_200_pruned_0.12.pt'))

saving checkpoint to out


In [None]:
# Multiply every PrunableLinear weight by its mask in place, outside autograd
# (presumably this zeroes the pruned entries; confirm the mask convention in pruning.py).
with torch.no_grad():
    for module in model.modules():
        if not isinstance(module, PrunableLinear):
            continue
        module.weight *= module.mask

In [None]:
# Column indices where row 1 of block 2's attn.c_proj mask is False
# (presumably the pruned positions; confirm in pruning.py).
torch.where(~model.transformer['h'][2].attn.c_proj.mask[1])

(tensor([   8,   11,   13,   14,   24,   26,   27,   30,   31,   34,   36,   40,
           42,   44,   47,   49,   52,   55,   57,   62,   66,   71,   81,   85,
           90,   92,   96,  106,  107,  109,  111,  113,  114,  115,  118,  124,
          125,  126,  129,  131,  132,  135,  138,  144,  151,  163,  172,  178,
          193,  206,  211,  215,  216,  241,  242,  247,  248,  254,  257,  260,
          263,  281,  288,  289,  297,  299,  300,  318,  322,  327,  329,  361,
          368,  394,  397,  398,  403,  405,  408,  410,  411,  413,  460,  461,
          468,  479,  484,  490,  491,  493,  494,  495,  499,  518,  519,  530,
          534,  537,  538,  548,  549,  554,  555,  567,  570,  575,  579,  580,
          581,  593,  596,  600,  601,  605,  606,  607,  610,  614,  634,  641,
          646,  647,  650,  667,  672,  675,  680,  700,  704,  707,  708,  711,
          716,  719,  721,  727,  728,  731,  733,  734,  737,  740,  744,  745,
          747,  753,  758,  

In [None]:
# Prune again with m=0.025; the return value is shown as the cell output
# (presumably the overall pruned fraction; confirm in the model's prune()).
model.prune(m=0.025)

0.20027845300998892

In [None]:
# Multiply every PrunableLinear weight by its mask in place, outside autograd
# (presumably this zeroes the pruned entries; confirm the mask convention in pruning.py).
with torch.no_grad():
    for module in model.modules():
        if not isinstance(module, PrunableLinear):
            continue
        module.weight *= module.mask

In [None]:
# Column indices where row 1 of block 2's attn.c_proj mask is False, after the m=0.025 prune.
torch.where(~model.transformer['h'][2].attn.c_proj.mask[1])

(tensor([   1,    2,    8,    9,   11,   13,   14,   23,   24,   26,   27,   29,
           30,   31,   34,   36,   40,   41,   42,   43,   44,   45,   47,   49,
           52,   55,   57,   59,   62,   65,   66,   69,   71,   72,   74,   75,
           76,   78,   80,   81,   83,   85,   87,   89,   90,   92,   96,  104,
          105,  106,  107,  109,  110,  111,  113,  114,  115,  117,  118,  120,
          122,  123,  124,  125,  126,  129,  131,  132,  133,  135,  136,  138,
          144,  148,  149,  151,  163,  172,  175,  178,  182,  186,  193,  195,
          197,  202,  204,  205,  206,  210,  211,  214,  215,  216,  241,  242,
          246,  247,  248,  252,  254,  257,  260,  263,  266,  267,  271,  274,
          281,  283,  288,  289,  297,  299,  300,  302,  303,  312,  318,  320,
          322,  327,  329,  333,  334,  336,  351,  361,  363,  368,  369,  383,
          391,  393,  394,  395,  397,  398,  403,  405,  406,  408,  410,  411,
          412,  413,  417,  

In [None]:
# Fine-tune for max_iters iterations at the new pruning level.
iter_num, best_val_loss = finetune(model, iter_num, best_val_loss)

step 200: train loss 2.9388, val loss 2.9758
iter 200: loss 2.9429, time 137213.41ms, mfu -100.00%
iter 201: loss 2.8565, time 31290.64ms, mfu -100.00%
iter 202: loss 2.9032, time 31250.92ms, mfu -100.00%
iter 203: loss 2.8096, time 31192.49ms, mfu -100.00%
iter 204: loss 2.8390, time 31329.26ms, mfu -100.00%
iter 205: loss 2.8935, time 31586.36ms, mfu 4.03%
iter 206: loss 2.9585, time 31219.90ms, mfu 4.04%
iter 207: loss 2.8916, time 31300.63ms, mfu 4.04%
iter 208: loss 2.8385, time 31321.18ms, mfu 4.04%
iter 209: loss 2.9934, time 31317.19ms, mfu 4.04%
iter 210: loss 2.8834, time 31414.56ms, mfu 4.04%
iter 211: loss 2.8969, time 31447.02ms, mfu 4.05%
iter 212: loss 2.9224, time 31478.04ms, mfu 4.05%
iter 213: loss 3.0015, time 31521.42ms, mfu 4.04%
iter 214: loss 2.9259, time 31520.53ms, mfu 4.04%
iter 215: loss 2.9143, time 31496.70ms, mfu 4.04%
iter 216: loss 2.8812, time 31443.80ms, mfu 4.04%
iter 217: loss 2.8985, time 31420.98ms, mfu 4.05%
iter 218: loss 2.9008, time 31309.13ms,

In [None]:
# Write a named snapshot (model weights, optimizer state, run metadata) into out_dir.
print(f"saving checkpoint to {out_dir}")
torch.save(
    dict(
        model=model.state_dict(),
        optimizer=optimizer.state_dict(),
        model_args=model_args,
        iter_num=iter_num,
        best_val_loss=best_val_loss,
        config=config,
    ),
    os.path.join(out_dir, 'ckpt_300_pruned_0.20.pt'),
)

saving checkpoint to out


In [None]:
# Column indices where row 1 of block 2's attn.c_proj mask is False, re-checked after fine-tuning.
torch.where(~model.transformer['h'][2].attn.c_proj.mask[1])

(tensor([   1,    2,    8,    9,   11,   13,   14,   23,   24,   26,   27,   29,
           30,   31,   34,   36,   40,   41,   42,   43,   44,   45,   47,   49,
           52,   55,   57,   59,   62,   65,   66,   69,   71,   72,   74,   75,
           76,   78,   80,   81,   83,   85,   87,   89,   90,   92,   96,  104,
          105,  106,  107,  109,  110,  111,  113,  114,  115,  117,  118,  120,
          122,  123,  124,  125,  126,  129,  131,  132,  133,  135,  136,  138,
          144,  148,  149,  151,  163,  172,  175,  178,  182,  186,  193,  195,
          197,  202,  204,  205,  206,  210,  211,  214,  215,  216,  241,  242,
          246,  247,  248,  252,  254,  257,  260,  263,  266,  267,  271,  274,
          281,  283,  288,  289,  297,  299,  300,  302,  303,  312,  318,  320,
          322,  327,  329,  333,  334,  336,  351,  361,  363,  368,  369,  383,
          391,  393,  394,  395,  397,  398,  403,  405,  406,  408,  410,  411,
          412,  413,  417,  

In [None]:
# Evaluation pass only (eval_only=True): finetune() exits its loop before any optimizer step.
iter_num, best_val_loss = finetune(model, iter_num, best_val_loss, eval_only=True)

step 300: train loss 2.8775, val loss 2.9221
saving checkpoint to out


In [None]:
# Prune with m=0.0385; the return value is shown as the cell output
# (presumably the overall pruned fraction; confirm in the model's prune()).
model.prune(m=0.0385)

0.3025793780012331

In [None]:
# Column indices where row 1 of block 2's attn.c_proj mask is False, after the m=0.0385 prune.
torch.where(~model.transformer['h'][2].attn.c_proj.mask[1])

(tensor([   0,    1,    2,    6,    8,    9,   11,   12,   13,   14,   17,   23,
           24,   25,   26,   27,   29,   30,   31,   33,   34,   36,   40,   41,
           42,   43,   44,   45,   47,   49,   50,   52,   53,   54,   55,   56,
           57,   59,   62,   63,   64,   65,   66,   69,   70,   71,   72,   74,
           75,   76,   78,   80,   81,   83,   85,   86,   87,   89,   90,   91,
           92,   95,   96,   97,  100,  101,  103,  104,  105,  106,  107,  109,
          110,  111,  113,  114,  115,  117,  118,  120,  122,  123,  124,  125,
          126,  129,  131,  132,  133,  134,  135,  136,  138,  143,  144,  146,
          148,  149,  151,  163,  166,  170,  171,  172,  175,  178,  180,  182,
          184,  186,  187,  191,  193,  195,  197,  198,  200,  201,  202,  204,
          205,  206,  210,  211,  213,  214,  215,  216,  217,  229,  235,  236,
          238,  241,  242,  243,  246,  247,  248,  252,  253,  254,  257,  258,
          260,  263,  264,  

In [None]:
# Fine-tune for max_iters iterations at the new pruning level.
iter_num, best_val_loss = finetune(model, iter_num, best_val_loss)

step 300: train loss 2.9700, val loss 3.0600
iter 300: loss 3.0705, time 136722.15ms, mfu -100.00%
iter 301: loss 2.8066, time 31074.61ms, mfu -100.00%
iter 302: loss 2.9218, time 31145.96ms, mfu -100.00%
iter 303: loss 3.0061, time 31397.18ms, mfu -100.00%
iter 304: loss 2.9983, time 31248.36ms, mfu -100.00%
iter 305: loss 2.9643, time 31098.34ms, mfu 4.09%
iter 306: loss 2.8019, time 31065.81ms, mfu 4.09%
iter 307: loss 2.6629, time 31057.56ms, mfu 4.10%
iter 308: loss 2.8768, time 31195.51ms, mfu 4.09%
iter 309: loss 2.8451, time 31230.92ms, mfu 4.09%
iter 310: loss 2.9171, time 31225.30ms, mfu 4.09%
iter 311: loss 2.8218, time 31206.47ms, mfu 4.09%
iter 312: loss 2.7488, time 31154.20ms, mfu 4.09%
iter 313: loss 2.9174, time 31145.44ms, mfu 4.09%
iter 314: loss 2.8661, time 31177.37ms, mfu 4.09%
iter 315: loss 2.7546, time 31179.94ms, mfu 4.09%
iter 316: loss 2.8442, time 31100.56ms, mfu 4.09%
iter 317: loss 2.7941, time 31260.97ms, mfu 4.09%
iter 318: loss 2.9990, time 31173.99ms,

In [None]:
# Multiply every PrunableLinear weight by its mask in place, outside autograd
# (presumably this zeroes the pruned entries; confirm the mask convention in pruning.py).
with torch.no_grad():
    for module in model.modules():
        if not isinstance(module, PrunableLinear):
            continue
        module.weight *= module.mask

In [None]:
# Write a named snapshot (model weights, optimizer state, run metadata) into out_dir.
print(f"saving checkpoint to {out_dir}")
torch.save(
    dict(
        model=model.state_dict(),
        optimizer=optimizer.state_dict(),
        model_args=model_args,
        iter_num=iter_num,
        best_val_loss=best_val_loss,
        config=config,
    ),
    os.path.join(out_dir, 'ckpt_400_pruned_0.30.pt'),
)

saving checkpoint to out


In [None]:
# Column indices where row 1 of block 2's attn.c_proj weight is exactly zero.
torch.where(model.transformer['h'][2].attn.c_proj.weight[1] == 0.0)

(tensor([   0,    1,    2,    6,    8,    9,   11,   12,   13,   14,   17,   23,
           24,   25,   26,   27,   29,   30,   31,   33,   34,   36,   40,   41,
           42,   43,   44,   45,   47,   49,   50,   52,   53,   54,   55,   56,
           57,   59,   62,   63,   64,   65,   66,   69,   70,   71,   72,   74,
           75,   76,   78,   80,   81,   83,   85,   86,   87,   89,   90,   91,
           92,   95,   96,   97,  100,  101,  103,  104,  105,  106,  107,  109,
          110,  111,  113,  114,  115,  117,  118,  120,  122,  123,  124,  125,
          126,  129,  131,  132,  133,  134,  135,  136,  138,  143,  144,  146,
          148,  149,  151,  163,  166,  170,  171,  172,  175,  178,  180,  182,
          184,  186,  187,  191,  193,  195,  197,  198,  200,  201,  202,  204,
          205,  206,  210,  211,  213,  214,  215,  216,  217,  229,  235,  236,
          238,  241,  242,  243,  246,  247,  248,  252,  253,  254,  257,  258,
          260,  263,  264,  

In [None]:
# Evaluation pass only (eval_only=True): finetune() exits its loop before any optimizer step.
iter_num, best_val_loss = finetune(model, iter_num, best_val_loss, eval_only=True)

step 400: train loss 2.8797, val loss 2.9228


In [None]:
# Prune with m=0.055; the return value is shown as the cell output
# (presumably the overall pruned fraction; confirm in the model's prune()).
model.prune(m=0.055)

0.41900721067045993

In [None]:
# Multiply every PrunableLinear weight by its mask in place, outside autograd
# (presumably this zeroes the pruned entries; confirm the mask convention in pruning.py).
with torch.no_grad():
    for module in model.modules():
        if not isinstance(module, PrunableLinear):
            continue
        module.weight *= module.mask

In [None]:
# Fine-tune for max_iters iterations at the new pruning level.
iter_num, best_val_loss = finetune(model, iter_num, best_val_loss)

step 400: train loss 3.1741, val loss 3.2333
iter 400: loss 3.1238, time 135262.17ms, mfu -100.00%
iter 401: loss 3.1533, time 30914.75ms, mfu -100.00%
iter 402: loss 3.1931, time 30965.67ms, mfu -100.00%
iter 403: loss 3.0026, time 31072.06ms, mfu -100.00%
iter 404: loss 3.0646, time 31042.03ms, mfu -100.00%
iter 405: loss 2.9705, time 31105.72ms, mfu 4.09%
iter 406: loss 2.9675, time 31059.01ms, mfu 4.09%
iter 407: loss 3.1748, time 31050.29ms, mfu 4.09%
iter 408: loss 2.8497, time 30968.47ms, mfu 4.10%
iter 409: loss 2.9522, time 31007.30ms, mfu 4.10%
iter 410: loss 3.0652, time 31036.07ms, mfu 4.10%
iter 411: loss 3.0838, time 31105.06ms, mfu 4.10%
iter 412: loss 3.0813, time 31143.65ms, mfu 4.10%
iter 413: loss 3.1053, time 31089.94ms, mfu 4.10%
iter 414: loss 3.0127, time 31063.78ms, mfu 4.10%
iter 415: loss 3.1330, time 30994.39ms, mfu 4.10%
iter 416: loss 2.9300, time 30954.18ms, mfu 4.10%
iter 417: loss 2.8523, time 30911.01ms, mfu 4.10%
iter 418: loss 3.0828, time 30956.65ms,

In [None]:
# Column indices where row 1 of block 2's attn.c_proj weight has magnitude below 1e-3.
torch.where(torch.abs(model.transformer['h'][2].attn.c_proj.weight[1]) < 1e-3)

(tensor([   0,    1,    2,    4,    6,    8,    9,   10,   11,   12,   13,   14,
           15,   16,   17,   19,   20,   21,   22,   23,   24,   25,   26,   27,
           28,   29,   30,   31,   32,   33,   34,   35,   36,   37,   38,   39,
           40,   41,   42,   43,   44,   45,   46,   47,   49,   50,   51,   52,
           53,   54,   55,   56,   57,   59,   62,   63,   64,   65,   66,   67,
           69,   70,   71,   72,   74,   75,   76,   78,   80,   81,   83,   84,
           85,   86,   87,   89,   90,   91,   92,   93,   95,   96,   97,   98,
          100,  101,  103,  104,  105,  106,  107,  108,  109,  110,  111,  113,
          114,  115,  116,  117,  118,  120,  121,  122,  123,  124,  125,  126,
          129,  131,  132,  133,  134,  135,  136,  137,  138,  143,  144,  146,
          148,  149,  151,  154,  158,  159,  163,  164,  165,  166,  169,  170,
          171,  172,  175,  176,  178,  180,  182,  184,  186,  187,  188,  191,
          193,  195,  197,  

In [None]:
# Evaluation pass only (eval_only=True), run before the masks are re-applied below.
iter_num, best_val_loss = finetune(model, iter_num, best_val_loss, eval_only=True)

step 500: train loss 2.9011, val loss 2.9731


In [None]:
# Multiply every PrunableLinear weight by its mask in place, outside autograd
# (presumably this zeroes the pruned entries; confirm the mask convention in pruning.py).
with torch.no_grad():
    for module in model.modules():
        if not isinstance(module, PrunableLinear):
            continue
        module.weight *= module.mask

In [None]:
# Evaluation pass only (eval_only=True), repeated after the masks were re-applied above.
iter_num, best_val_loss = finetune(model, iter_num, best_val_loss, eval_only=True)

step 500: train loss 2.9045, val loss 2.9726


In [None]:
# Write a named snapshot (model weights, optimizer state, run metadata) into out_dir.
print(f"saving checkpoint to {out_dir}")
torch.save(
    dict(
        model=model.state_dict(),
        optimizer=optimizer.state_dict(),
        model_args=model_args,
        iter_num=iter_num,
        best_val_loss=best_val_loss,
        config=config,
    ),
    os.path.join(out_dir, 'ckpt_500_pruned_0.42.pt'),
)

saving checkpoint to out


In [None]:
# Prune with m=0.068; the return value is shown as the cell output
# (presumably the overall pruned fraction; confirm in the model's prune()).
model.prune(m=0.068)

0.5028553322450019

In [None]:
# Multiply every PrunableLinear weight by its mask in place, outside autograd
# (presumably this zeroes the pruned entries; confirm the mask convention in pruning.py).
with torch.no_grad():
    for module in model.modules():
        if not isinstance(module, PrunableLinear):
            continue
        module.weight *= module.mask

In [None]:
# Fine-tune for max_iters iterations at the new pruning level.
iter_num, best_val_loss = finetune(model, iter_num, best_val_loss)

step 500: train loss 3.2638, val loss 3.3162
iter 500: loss 3.2477, time 132647.88ms, mfu -100.00%
iter 501: loss 3.1438, time 30853.46ms, mfu -100.00%
iter 502: loss 3.1076, time 30571.51ms, mfu -100.00%
iter 503: loss 3.0438, time 30600.49ms, mfu -100.00%
iter 504: loss 3.0423, time 30792.53ms, mfu -100.00%
iter 505: loss 3.0970, time 30822.93ms, mfu 4.13%
iter 506: loss 3.1580, time 30886.38ms, mfu 4.13%
iter 507: loss 3.0724, time 30856.76ms, mfu 4.13%
iter 508: loss 3.0667, time 30817.67ms, mfu 4.13%
iter 509: loss 3.1683, time 30746.05ms, mfu 4.13%
iter 510: loss 3.0291, time 30764.77ms, mfu 4.13%
iter 511: loss 3.0810, time 30697.59ms, mfu 4.13%
iter 512: loss 3.0712, time 30701.52ms, mfu 4.13%
iter 513: loss 3.1551, time 30681.75ms, mfu 4.14%
iter 514: loss 3.0471, time 30693.77ms, mfu 4.14%
iter 515: loss 3.0497, time 30658.60ms, mfu 4.14%
iter 516: loss 2.9981, time 30815.83ms, mfu 4.14%
iter 517: loss 3.0083, time 30804.96ms, mfu 4.14%
iter 518: loss 3.0388, time 30867.43ms,

In [None]:
# Evaluation pass only (eval_only=True), run before the masks are re-applied below.
iter_num, best_val_loss = finetune(model, iter_num, best_val_loss, eval_only=True)

step 600: train loss 2.9010, val loss 3.0341


In [None]:
# Multiply every PrunableLinear weight by its mask in place, outside autograd
# (presumably this zeroes the pruned entries; confirm the mask convention in pruning.py).
with torch.no_grad():
    for module in model.modules():
        if not isinstance(module, PrunableLinear):
            continue
        module.weight *= module.mask

In [None]:
# Evaluation pass only (eval_only=True), repeated after the masks were re-applied above.
iter_num, best_val_loss = finetune(model, iter_num, best_val_loss, eval_only=True)

step 600: train loss 2.9593, val loss 3.0223


In [None]:
# Write a named snapshot (model weights, optimizer state, run metadata) into out_dir.
print(f"saving checkpoint to {out_dir}")
torch.save(
    dict(
        model=model.state_dict(),
        optimizer=optimizer.state_dict(),
        model_args=model_args,
        iter_num=iter_num,
        best_val_loss=best_val_loss,
        config=config,
    ),
    os.path.join(out_dir, 'ckpt_600_pruned_0.50.pt'),
)

saving checkpoint to out


In [None]:
# Prune with m=0.085; the return value is shown as the cell output
# (presumably the overall pruned fraction; confirm in the model's prune()).
model.prune(m=0.085)

0.6007125011397088

In [None]:
# Multiply every PrunableLinear weight by its mask in place, outside autograd
# (presumably this zeroes the pruned entries; confirm the mask convention in pruning.py).
with torch.no_grad():
    for module in model.modules():
        if not isinstance(module, PrunableLinear):
            continue
        module.weight *= module.mask

In [None]:
# Fine-tune for max_iters iterations at the new pruning level.
iter_num, best_val_loss = finetune(model, iter_num, best_val_loss)

step 600: train loss 4.2862, val loss 4.3404
iter 600: loss 4.1979, time 132184.71ms, mfu -100.00%
iter 601: loss 4.0288, time 30408.08ms, mfu -100.00%
iter 602: loss 3.9112, time 30343.78ms, mfu -100.00%
iter 603: loss 3.6945, time 30407.14ms, mfu -100.00%
iter 604: loss 3.6111, time 30394.88ms, mfu -100.00%
iter 605: loss 3.4950, time 30443.56ms, mfu 4.18%
iter 606: loss 3.3665, time 30463.75ms, mfu 4.18%
iter 607: loss 3.5673, time 30534.09ms, mfu 4.18%
iter 608: loss 3.2651, time 30558.14ms, mfu 4.18%
iter 609: loss 3.3153, time 30650.13ms, mfu 4.18%
iter 610: loss 3.4048, time 30542.66ms, mfu 4.18%
iter 611: loss 3.4150, time 30517.91ms, mfu 4.18%
iter 612: loss 3.4047, time 30363.45ms, mfu 4.18%
iter 613: loss 3.4133, time 30456.97ms, mfu 4.18%
iter 614: loss 3.3169, time 30475.37ms, mfu 4.18%
iter 615: loss 3.4294, time 30528.85ms, mfu 4.18%
iter 616: loss 3.2098, time 30426.96ms, mfu 4.18%
iter 617: loss 3.1102, time 30382.02ms, mfu 4.18%
iter 618: loss 3.3461, time 30485.32ms,

In [None]:
# Evaluation-only call (eval_only=True) on the weights as the finetuning run
# above left them; the recorded output is a single "step N" loss line.
# NOTE(review): confirm in finetune's definition that eval_only=True performs
# no optimizer steps.
iter_num, best_val_loss = finetune(model, iter_num, best_val_loss, eval_only=True)

step 700: train loss 3.0812, val loss 3.1637


In [None]:
# Re-apply the pruning masks after finetuning: each PrunableLinear weight is
# multiplied in place by its own mask. torch.no_grad() keeps the in-place
# update from being tracked by autograd.
# NOTE(review): assumes module.mask is a 0/1 tensor broadcastable to
# module.weight — confirm in pruning.py.
with torch.no_grad():
  for layer in (m for m in model.modules() if isinstance(m, PrunableLinear)):
    layer.weight *= layer.mask

In [None]:
# Evaluation-only call run right after the masks were multiplied back into
# the weights, so the reported losses describe the masked model.
# NOTE(review): a gap versus the previous eval would suggest masked weights
# drifted from zero during finetuning — confirm whether finetune enforces
# the mask.
iter_num, best_val_loss = finetune(model, iter_num, best_val_loss, eval_only=True)

step 700: train loss 3.0911, val loss 3.1691


In [None]:
# Write model and optimizer state plus run bookkeeping (model_args, iter_num,
# best_val_loss, config) to a stage-specific file under out_dir.
print(f"saving checkpoint to {out_dir}")
snapshot = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'model_args': model_args,
    'iter_num': iter_num,
    'best_val_loss': best_val_loss,
    'config': config,
}
snapshot_file = os.path.join(out_dir, 'ckpt_700_pruned_0.60.pt')
torch.save(snapshot, snapshot_file)

saving checkpoint to out


In [None]:
# Next pruning step. The call is the cell's last expression, so its return
# value is echoed as the cell output.
# NOTE(review): the echoed value (~0.72 in the recorded run) looks like the
# resulting overall sparsity fraction; the meaning of `m` is defined in
# model.prune — confirm there.
model.prune(m=0.11)

0.7190681658676042

In [None]:
# Apply the masks produced by the prune step above: each PrunableLinear
# weight is multiplied in place by its mask. torch.no_grad() keeps the
# in-place update from being tracked by autograd.
# NOTE(review): assumes module.mask is a 0/1 tensor broadcastable to
# module.weight — confirm in pruning.py.
with torch.no_grad():
  for layer in (m for m in model.modules() if isinstance(m, PrunableLinear)):
    layer.weight *= layer.mask

In [None]:
# Sanity check: display the current iteration counter before launching the
# next finetuning run (the bare expression is echoed as the cell output).
iter_num

700

In [None]:
# Finetune the model after the prune + mask cells above. finetune returns the
# updated iteration counter and best validation loss, which are rebound here
# so the following cells continue from them.
# NOTE(review): the recorded log starts at step 700 and the next eval reports
# step 800, so one call appears to cover 100 iterations — confirm against
# finetune's definition.
iter_num, best_val_loss = finetune(model, iter_num, best_val_loss)

step 700: train loss 8.1779, val loss 8.1862
iter 700: loss 8.1901, time 129157.44ms, mfu -100.00%
iter 701: loss 7.7622, time 29898.45ms, mfu -100.00%
iter 702: loss 7.3745, time 30075.06ms, mfu -100.00%
iter 703: loss 6.9212, time 30213.90ms, mfu -100.00%
iter 704: loss 6.5888, time 30212.24ms, mfu -100.00%
iter 705: loss 6.1762, time 30242.17ms, mfu 4.21%
iter 706: loss 5.6560, time 30225.20ms, mfu 4.21%
iter 707: loss 5.5904, time 30214.18ms, mfu 4.21%
iter 708: loss 5.3330, time 30116.91ms, mfu 4.21%
iter 709: loss 5.1531, time 30044.39ms, mfu 4.22%
iter 710: loss 5.0524, time 30178.26ms, mfu 4.22%
iter 711: loss 4.9326, time 30309.63ms, mfu 4.21%
iter 712: loss 4.9083, time 29992.96ms, mfu 4.22%
iter 713: loss 4.6380, time 30146.25ms, mfu 4.22%
iter 714: loss 4.7053, time 30237.18ms, mfu 4.22%
iter 715: loss 4.6639, time 30244.66ms, mfu 4.22%
iter 716: loss 4.5953, time 30259.72ms, mfu 4.22%
iter 717: loss 4.5339, time 30259.80ms, mfu 4.21%
iter 718: loss 4.3792, time 30218.27ms,

In [None]:
# Evaluation-only call (eval_only=True) on the weights as the finetuning run
# above left them; the recorded output is a single "step N" loss line.
# NOTE(review): confirm in finetune's definition that eval_only=True performs
# no optimizer steps.
iter_num, best_val_loss = finetune(model, iter_num, best_val_loss, eval_only=True)

step 800: train loss 3.4667, val loss 3.5555


In [None]:
# Re-apply the pruning masks after finetuning: each PrunableLinear weight is
# multiplied in place by its own mask. torch.no_grad() keeps the in-place
# update from being tracked by autograd.
# NOTE(review): assumes module.mask is a 0/1 tensor broadcastable to
# module.weight — confirm in pruning.py.
with torch.no_grad():
  for layer in (m for m in model.modules() if isinstance(m, PrunableLinear)):
    layer.weight *= layer.mask

In [None]:
# Evaluation-only call run right after the masks were multiplied back into
# the weights, so the reported losses describe the masked model.
# NOTE(review): a gap versus the previous eval would suggest masked weights
# drifted from zero during finetuning — confirm whether finetune enforces
# the mask.
iter_num, best_val_loss = finetune(model, iter_num, best_val_loss, eval_only=True)

step 800: train loss 3.5245, val loss 3.5823


In [None]:
# Write model and optimizer state plus run bookkeeping (model_args, iter_num,
# best_val_loss, config) to a stage-specific file under out_dir.
print(f"saving checkpoint to {out_dir}")
snapshot = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'model_args': model_args,
    'iter_num': iter_num,
    'best_val_loss': best_val_loss,
    'config': config,
}
snapshot_file = os.path.join(out_dir, 'ckpt_800_pruned_0.72.pt')
torch.save(snapshot, snapshot_file)

saving checkpoint to out


In [None]:
# Next pruning step. The call is the cell's last expression, so its return
# value is echoed as the cell output.
# NOTE(review): the echoed value (~0.81 in the recorded run) looks like the
# resulting overall sparsity fraction; the meaning of `m` is defined in
# model.prune — confirm there.
model.prune(m=0.135)

0.8087050581631369

In [None]:
# Apply the masks produced by the prune step above: each PrunableLinear
# weight is multiplied in place by its mask. torch.no_grad() keeps the
# in-place update from being tracked by autograd.
# NOTE(review): assumes module.mask is a 0/1 tensor broadcastable to
# module.weight — confirm in pruning.py.
with torch.no_grad():
  for layer in (m for m in model.modules() if isinstance(m, PrunableLinear)):
    layer.weight *= layer.mask

In [None]:
# Finetune the model after the prune + mask cells above. finetune returns the
# updated iteration counter and best validation loss, which are rebound here
# so the following cells continue from them.
# NOTE(review): the recorded log starts at step 800 and the next eval reports
# step 900, so one call appears to cover 100 iterations — confirm against
# finetune's definition.
iter_num, best_val_loss = finetune(model, iter_num, best_val_loss)

step 800: train loss 8.3409, val loss 8.3226
iter 800: loss 8.3951, time 127371.10ms, mfu -100.00%
iter 801: loss 7.7727, time 29615.42ms, mfu -100.00%
iter 802: loss 7.3765, time 29540.40ms, mfu -100.00%
iter 803: loss 7.2748, time 29732.66ms, mfu -100.00%
iter 804: loss 7.1498, time 29776.35ms, mfu -100.00%
iter 805: loss 6.8713, time 29875.20ms, mfu 4.26%
iter 806: loss 6.7723, time 29766.01ms, mfu 4.26%
iter 807: loss 6.6045, time 29648.06ms, mfu 4.27%
iter 808: loss 6.4459, time 29570.16ms, mfu 4.27%
iter 809: loss 6.3131, time 29668.17ms, mfu 4.27%
iter 810: loss 6.1586, time 29706.61ms, mfu 4.27%
iter 811: loss 6.0910, time 29771.78ms, mfu 4.27%
iter 812: loss 5.8455, time 29775.88ms, mfu 4.27%
iter 813: loss 5.7877, time 29784.54ms, mfu 4.27%
iter 814: loss 5.7291, time 29713.06ms, mfu 4.28%
iter 815: loss 5.6286, time 29556.17ms, mfu 4.28%
iter 816: loss 5.6160, time 29772.07ms, mfu 4.28%
iter 817: loss 5.4840, time 29848.01ms, mfu 4.28%
iter 818: loss 5.4755, time 29516.44ms,

In [None]:
# Evaluation-only call (eval_only=True) on the weights as the finetuning run
# above left them; the recorded output is a single "step N" loss line.
# NOTE(review): confirm in finetune's definition that eval_only=True performs
# no optimizer steps.
iter_num, best_val_loss = finetune(model, iter_num, best_val_loss, eval_only=True)

step 900: train loss 3.9550, val loss 4.0399


In [None]:
# Re-apply the pruning masks after finetuning: each PrunableLinear weight is
# multiplied in place by its own mask. torch.no_grad() keeps the in-place
# update from being tracked by autograd.
# NOTE(review): assumes module.mask is a 0/1 tensor broadcastable to
# module.weight — confirm in pruning.py.
with torch.no_grad():
  for layer in (m for m in model.modules() if isinstance(m, PrunableLinear)):
    layer.weight *= layer.mask

In [None]:
# Evaluation-only call run right after the masks were multiplied back into
# the weights, so the reported losses describe the masked model.
# NOTE(review): a gap versus the previous eval would suggest masked weights
# drifted from zero during finetuning — confirm whether finetune enforces
# the mask.
iter_num, best_val_loss = finetune(model, iter_num, best_val_loss, eval_only=True)

step 900: train loss 4.1152, val loss 4.2020


In [None]:
# Write model and optimizer state plus run bookkeeping (model_args, iter_num,
# best_val_loss, config) to a stage-specific file under out_dir.
print(f"saving checkpoint to {out_dir}")
snapshot = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'model_args': model_args,
    'iter_num': iter_num,
    'best_val_loss': best_val_loss,
    'config': config,
}
snapshot_file = os.path.join(out_dir, 'ckpt_900_pruned_0.81.pt')
torch.save(snapshot, snapshot_file)

saving checkpoint to out


In [None]:
# Next pruning step. The call is the cell's last expression, so its return
# value is echoed as the cell output.
# NOTE(review): the echoed value (~0.90 in the recorded run) looks like the
# resulting overall sparsity fraction; the meaning of `m` is defined in
# model.prune — confirm there.
model.prune(m=0.175)

0.9022883983709039

In [None]:
# Apply the masks produced by the prune step above: each PrunableLinear
# weight is multiplied in place by its mask. torch.no_grad() keeps the
# in-place update from being tracked by autograd.
# NOTE(review): assumes module.mask is a 0/1 tensor broadcastable to
# module.weight — confirm in pruning.py.
with torch.no_grad():
  for layer in (m for m in model.modules() if isinstance(m, PrunableLinear)):
    layer.weight *= layer.mask

In [None]:
# Finetune the model after the prune + mask cells above. finetune returns the
# updated iteration counter and best validation loss, which are rebound here
# so the following cells continue from them.
# NOTE(review): the recorded log starts at step 900 and the next eval reports
# step 1000, so one call appears to cover 100 iterations — confirm against
# finetune's definition.
iter_num, best_val_loss = finetune(model, iter_num, best_val_loss)

step 900: train loss 9.6733, val loss 9.6650
iter 900: loss 9.7478, time 124005.52ms, mfu -100.00%
iter 901: loss 8.7884, time 29169.71ms, mfu -100.00%
iter 902: loss 8.3018, time 28942.13ms, mfu -100.00%
iter 903: loss 8.1438, time 29191.02ms, mfu -100.00%
iter 904: loss 7.9312, time 29174.26ms, mfu -100.00%
iter 905: loss 7.8321, time 29220.66ms, mfu 4.36%
iter 906: loss 7.8898, time 29105.11ms, mfu 4.36%
iter 907: loss 7.7316, time 29048.25ms, mfu 4.36%
iter 908: loss 7.7245, time 28927.65ms, mfu 4.37%
iter 909: loss 7.7123, time 28945.70ms, mfu 4.37%
iter 910: loss 7.5970, time 29150.14ms, mfu 4.37%
iter 911: loss 7.5676, time 29204.94ms, mfu 4.37%
iter 912: loss 7.5733, time 29101.15ms, mfu 4.37%
iter 913: loss 7.4346, time 29039.16ms, mfu 4.37%
iter 914: loss 7.3628, time 28961.09ms, mfu 4.37%
iter 915: loss 7.2901, time 29100.33ms, mfu 4.37%
iter 916: loss 7.4284, time 29232.42ms, mfu 4.37%
iter 917: loss 7.2996, time 29121.57ms, mfu 4.37%
iter 918: loss 7.2370, time 28995.35ms,

In [None]:
# Evaluation-only call (eval_only=True) on the weights as the finetuning run
# above left them; the recorded output is a single "step N" loss line.
# NOTE(review): confirm in finetune's definition that eval_only=True performs
# no optimizer steps.
iter_num, best_val_loss = finetune(model, iter_num, best_val_loss, eval_only=True)

step 1000: train loss 5.6313, val loss 5.6840


In [None]:
# Re-apply the pruning masks after finetuning: each PrunableLinear weight is
# multiplied in place by its own mask. torch.no_grad() keeps the in-place
# update from being tracked by autograd.
# NOTE(review): assumes module.mask is a 0/1 tensor broadcastable to
# module.weight — confirm in pruning.py.
with torch.no_grad():
  for layer in (m for m in model.modules() if isinstance(m, PrunableLinear)):
    layer.weight *= layer.mask

In [None]:
# Evaluation-only call run right after the masks were multiplied back into
# the weights, so the reported losses describe the masked model.
# NOTE(review): a gap versus the previous eval would suggest masked weights
# drifted from zero during finetuning — confirm whether finetune enforces
# the mask.
iter_num, best_val_loss = finetune(model, iter_num, best_val_loss, eval_only=True)

step 1000: train loss 6.1090, val loss 6.1618


In [None]:
# Write model and optimizer state plus run bookkeeping (model_args, iter_num,
# best_val_loss, config) to a stage-specific file under out_dir.
print(f"saving checkpoint to {out_dir}")
snapshot = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'model_args': model_args,
    'iter_num': iter_num,
    'best_val_loss': best_val_loss,
    'config': config,
}
snapshot_file = os.path.join(out_dir, 'ckpt_1000_pruned_0.90.pt')
torch.save(snapshot, snapshot_file)

saving checkpoint to out
