In [1]:
from pathlib import Path

import torch
import torch.nn.functional as F
import datasets as hfds

import olmo
from olmo.config import TrainConfig
from olmo.util import clean_opt
from olmo.torch_util import seed_all
from olmo.data import build_train_dataloader, IterableDataset

from olmo.optim import build_optimizer
from olmo.config import TrainConfig
from olmo.checkpoint import load_state_dict
from olmo.model import OLMo

from zsl_config import ZSL_DIR_OUT_OLMO, ZSL_DIR_ANALYSIS, ZSL_DIR_DATA

from typing import Union, Optional

In [2]:

# Model family and evaluation dataset; together they name the tokenized eval data directory.
MODEL_CLASS = 'olmo'
DATASET = 'c4_en_val'
# All outputs of this notebook go to ZSL_DIR_ANALYSIS / ANALYSIS_NAME.
ANALYSIS_NAME = 'per_token_change_in_loss-dt=1'

# Training runs to analyse (the suffix presumably encodes model size).
RUNS = [
    f'1028-rmsnorm-{size}'
    for size in ('14m', '37m', '78m', '144m', '285m', '472m')
]

VERBOSE = True     # run the sanity-check blocks guarded by `if VERBOSE:`
OVERWRITE = False  # recompute cached outputs even when the files already exist

OUT_DIR = ZSL_DIR_ANALYSIS / ANALYSIS_NAME
OUT_DIR.mkdir(parents=True, exist_ok=True)

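# Measure per-token change in loss after one optimizer step ($\Delta t=1$)

For every run in `RUNS` and every checkpoint step $t$, we take a single optimizer step on each of a few training batches. We record per-token losses on that training batch and on the evaluation set (`DATASET`), both before ("init") and after ("post") the step.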

## Data loading

### Train batch

In [3]:
def load_train_batch(run, step, *args) -> torch.Tensor:
    """Reconstruct the global training batch of `run` that follows `step` optimizer steps.

    The run's ``config.yaml`` is loaded, with `args` applied as config overrides after
    ``clean_opt``. It is pointed at the shared global indices file, and the per-device batch
    size is set to the global batch size so one dataloader batch is the whole global batch.
    The dataset start index is ``global_train_batch_size * step`` examples, so the returned
    batch is presumably the one consumed by training step ``step + 1``
    (TODO confirm against the OLMo trainer).

    Side effect: reseeds the global RNGs via ``seed_all(cfg.seed)``.

    Returns:
        The ``input_ids`` tensor of that batch (observed shape for RUNS[0]:
        ``torch.Size([512, 1024])``).
    """
    yaml_path = ZSL_DIR_OUT_OLMO / run / 'config.yaml'
    cfg = TrainConfig.load(yaml_path, [clean_opt(s) for s in args])
    # Set `global_indices_file` to shared path
    cfg.data.global_indices_file = ZSL_DIR_OUT_OLMO / "train_data/global_indices.npy"
    # Load single batch instead of distributed
    cfg.device_train_batch_size = cfg.global_train_batch_size
    seed_all(cfg.seed)
    train_loader = build_train_dataloader(cfg)
    assert isinstance(train_loader.dataset, IterableDataset)

    # Skip the examples consumed by the first `step` optimizer steps.
    global_train_examples_seen_this_epoch = cfg.global_train_batch_size * step
    train_loader.dataset.start_index = global_train_examples_seen_this_epoch
    batch = next(iter(train_loader))
    return batch['input_ids']

def get_device_bsz(run):
    """Return ``device_train_batch_size`` from the training config of `run`."""
    return TrainConfig.load(ZSL_DIR_OUT_OLMO / run / 'config.yaml').device_train_batch_size

if VERBOSE:
    # Sanity check: one reconstructed global batch vs. the per-device batch size of the run.
    batch = load_train_batch(RUNS[0], step=1)
    print("Loaded batch with shape: {}".format(batch.shape))
    print("Original batch size: {}".format(get_device_bsz(RUNS[0])))
    del batch

Loaded batch with shape: torch.Size([512, 1024])
Original batch size: 64


### Eval dataloader

In [4]:
# Pre-tokenized evaluation set; fail early if it has not been created yet.
tokenized_eval_data = ZSL_DIR_DATA / 'tokenized' / f'{MODEL_CLASS}-{DATASET}'
assert tokenized_eval_data.exists()

def get_dataloader(bsz: int = 4, device: str = 'cpu'):
    """Build a sequential (unshuffled) DataLoader over the tokenized evaluation set.

    Args:
        bsz: number of sequences per batch.
        device: intended compute device. When it is not ``'cpu'``, batches are pinned
            for that device; the tensors themselves are still yielded on CPU.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding dicts with an ``input_ids`` tensor.
    """
    dataset = hfds.load_from_disk(tokenized_eval_data)
    dataset.set_format(type='torch', columns=['input_ids'])
    pin_memory = device != 'cpu'
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=bsz,
        shuffle=False,
        pin_memory=pin_memory,
        # Only name a pin-memory device when actually pinning. Passing one together with
        # pin_memory=False makes torch warn every time an iterator is created, and the
        # experiment iterates this loader many times.
        pin_memory_device=device if pin_memory else "",
    )

if VERBOSE:
    # Sanity check of the eval dataloader: number of batches, batch shape and device.
    dataloader = get_dataloader(device='cuda' if torch.cuda.is_available() else 'cpu')
    print(f"len(dataloader): {len(dataloader)}")
    # Fetch the first batch once instead of building a fresh iterator for every print.
    first_input_ids = next(iter(dataloader))['input_ids']
    print(f"batch shape: {first_input_ids.shape}")
    print(f"batch device: {first_input_ids.device}")
    del dataloader, first_input_ids

len(dataloader): 244
batch shape: torch.Size([4, 1024])
batch device: cpu


### Model loading

In [5]:
def load_model(run: str, step: int, device="cuda", overrides: Optional[list] = None):
    """Load the unsharded model of checkpoint `step` of `run`.

    Args:
        run: run directory name under ``ZSL_DIR_OUT_OLMO``.
        step: checkpoint step; weights are read from ``step{step}-unsharded/model.pt``.
        device: used as the ``model.init_device`` config override and as ``map_location``
            when loading the weights.
        overrides: extra config overrides applied to the checkpoint's ``config.yaml``.
            The caller's list is never modified.

    Returns:
        An ``OLMo`` model with the checkpoint's state dict loaded.
    """
    ckpt_dir = ZSL_DIR_OUT_OLMO / run / f"step{step}-unsharded"
    # Build a new list. Appending to `overrides` in place would grow the caller's list on
    # every call (e.g. the shared CPU overrides list used by the experiment loop).
    overrides = [*(overrides or []), f"model.init_device={device}"]
    cfg = TrainConfig.load(
        ckpt_dir / "config.yaml", validate_paths=False, overrides=overrides
    )
    # init_params=False: presumably skips random init, since the weights are loaded next.
    model = OLMo(cfg.model, init_params=False)
    state_dict = load_state_dict(ckpt_dir, "model.pt", map_location=device)
    model.load_state_dict(state_dict)
    return model


def load_optimizer(
    model, run: str, step: int, device="cuda", overrides: Optional[list] = None
):
    """Rebuild the optimizer for `model` and restore its state from checkpoint `step` of `run`.

    Args:
        model: model whose parameters the optimizer manages (e.g. from ``load_model``).
        run: run directory name under ``ZSL_DIR_OUT_OLMO``.
        step: checkpoint step; state is read from ``step{step}-unsharded/optim.pt``.
        device: used as the ``model.init_device`` config override and as ``map_location``
            when loading the optimizer state.
        overrides: extra config overrides; the caller's list is never modified.

    Returns:
        The optimizer built by ``build_optimizer`` with the checkpointed state loaded.
    """
    ckpt_dir = ZSL_DIR_OUT_OLMO / run / f"step{step}-unsharded"
    # New list: appending in place would grow the caller's list on every call.
    overrides = [*(overrides or []), f"model.init_device={device}"]
    cfg = TrainConfig.load(
        ckpt_dir / "config.yaml", validate_paths=False, overrides=overrides
    )
    optimizer = build_optimizer(cfg, model)
    state_dict = load_state_dict(ckpt_dir, "optim.pt", map_location=device)
    optimizer.load_state_dict(state_dict)
    return optimizer


def get_model_steps(run: str):
    """Return the sorted step numbers of the unsharded checkpoints of `run`.

    Checkpoint directories are matched with the glob ``step[1-9]*-unsharded``, so a
    ``step0`` checkpoint is never included.
    """
    found_steps = []
    for ckpt_dir in (ZSL_DIR_OUT_OLMO / run).glob("step[1-9]*-unsharded"):
        found_steps.append(int(ckpt_dir.name.replace("step", "").replace("-unsharded", "")))
    found_steps.sort()
    return found_steps


if VERBOSE:
    # Sanity check: load one checkpoint of RUNS[0] with its optimizer and list every
    # parameter next to its AdamW `exp_avg` state (shape / dtype / device).
    run = RUNS[0]
    steps = get_model_steps(run)
    print(f"steps for run {run}: {steps}")
    step = steps[5]

    verbose_device = "cuda" if torch.cuda.is_available() else "cpu"
    # Same rule as the experiment cell: flash attention is disabled when running on CPU.
    verbose_overrides = ["model.flash_attention=False"] if verbose_device == "cpu" else []

    print("=" * 20, "MODEL", "=" * 20)
    model = load_model(run, step, device=verbose_device, overrides=verbose_overrides)
    print(model)

    print("=" * 20, "OPTIMIZER", "=" * 20)
    optimizer = load_optimizer(model, run, step, device='cpu')
    print(optimizer)

    print("=" * 20, "PARAMS", "=" * 20)
    for n, p in model.named_parameters():
        o = optimizer.state[p]
        print("Param: ", n, p.shape, p.dtype, p.device)
        print("Optim: ", n, o['exp_avg'].shape, o['exp_avg'].dtype, o['exp_avg'].device)
        print('.' * 80)

    # Also drop `optimizer` (it references the model's parameters, so `del model` alone
    # would not free them) and `o`, so nothing from this check leaks into later cells.
    del run, step, steps, model, optimizer, n, p, o, verbose_device, verbose_overrides

steps for run 1028-rmsnorm-14m: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144]


  return torch.load(path, map_location=map_location)


OLMo(
  (transformer): ModuleDict(
    (wte): Embedding(50304, 256)
    (emb_drop): Dropout(p=0.0, inplace=False)
    (ln_f): RMSLayerNorm()
    (blocks): ModuleList(
      (0-3): 4 x OLMoSequentialBlock(
        (dropout): Dropout(p=0.0, inplace=False)
        (act): SwiGLU()
        (attn_out): Linear(in_features=256, out_features=256, bias=False)
        (ff_out): Linear(in_features=128, out_features=256, bias=False)
        (rotary_emb): RotaryEmbedding()
        (att_proj): Linear(in_features=256, out_features=768, bias=False)
        (ff_proj): Linear(in_features=256, out_features=256, bias=False)
        (attn_norm): RMSLayerNorm()
        (ff_norm): RMSLayerNorm()
      )
    )
  )
)
AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.95)
    capturable: False
    differentiable: False
    eps: 1e-05
    foreach: None
    fused: None
    initial_lr: 0.00126
    lr: 0.000144144
    max_grad_norm: 1.0
    max_grad_norm_ratio: None
    maximize: False
    param_names: 

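## Experiment

All results are cached as `.pt` files under `OUT_DIR`. Existing files are skipped unless `OVERWRITE = True`.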

In [6]:
# Number of consecutive training batches measured per checkpoint (one optimizer step each).
num_train_batches = 5
# Batch size used when iterating over the evaluation set.
bsz_eval = 64

if torch.cuda.is_available():
    device = "cuda"
    overrides = []
else:
    device = "cpu"
    # presumably flash attention requires a GPU, so it is switched off for CPU runs
    overrides = ["model.flash_attention=False"]

# train dataloader + cache tokens
# The training batches measured for every run/checkpoint are those right after the last
# checkpoint of RUNS[0] (presumably data not seen during training). Each is cached on disk.
run = RUNS[0]
last_ckpt_step = get_model_steps(run)[-1]  # glob the checkpoints once, not per batch
train_batch_steps = [last_ckpt_step + i for i in range(num_train_batches)]
train_batches = {}
for batch_step in train_batch_steps:
    out_path = OUT_DIR / f"tokens/train-batch={batch_step}.pt"
    if out_path.exists() and not OVERWRITE:
        batch = torch.load(out_path, weights_only=True)
    else:
        batch = load_train_batch(run, batch_step)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(batch, out_path)
    train_batches[batch_step] = batch

# Microbatching of RUNS[0]; hoisted out of the loop because it does not depend on the batch
# step. The experiment loop recomputes these per run.
if train_batches:
    microbsz = get_device_bsz(RUNS[0])
    num_microbatches = batch.shape[0] // microbsz

# eval dataloader + cache tokens (the eval tokens are re-saved every time this cell runs)
eval_dataloader = get_dataloader(bsz=bsz_eval, device=device)
eval_tokens = torch.cat(
    [eval_batch["input_ids"].cpu() for eval_batch in eval_dataloader], dim=0
)
out_path = OUT_DIR / f"tokens/eval-{DATASET}.pt"
out_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(eval_tokens, out_path)

# Shared context managers: inference mode for forward-only passes, bf16 autocast on `device`.
inf_ctx = torch.inference_mode()
amp_ctx = torch.amp.autocast(device, dtype=torch.bfloat16)

def _per_token_losses(model, token_batches, prefix, tag):
    """Forward-only, unreduced next-token cross-entropy of `model` over `token_batches`.

    Args:
        model: the model to evaluate; it is switched to eval mode.
        token_batches: iterable of 2D token-id tensors (one sequence per row).
        prefix: progress-line prefix, e.g. ``"[run][step]"``.
        tag: progress-line tag, e.g. ``"init-eval"``.

    Returns:
        A 1D CPU tensor with the per-token losses of every batch, flattened and
        concatenated in iteration order.
    """
    model.eval()
    losses = []
    with inf_ctx, amp_ctx:
        for i, tokens in enumerate(token_batches):
            print(f"{prefix}[{i}][{tag}]\t\t\t", end="\r")
            tokens = tokens.to(device)
            input_ids = tokens[:, :-1].contiguous()
            labels = tokens[:, 1:].flatten()
            logits = model(input_ids).logits.flatten(0, 1)
            losses.append(F.cross_entropy(logits, labels, reduction="none").detach().cpu())
            del logits
    return torch.cat(losses, dim=0)


def _eval_token_batches():
    """Yield the ``input_ids`` tensor of each evaluation batch."""
    for eval_batch in eval_dataloader:
        yield eval_batch["input_ids"]


def _load_bf16_model(run, step):
    """Load checkpoint `step` of `run` on `device` and cast its weights to bfloat16."""
    # Pass a copy so the loader can never grow the shared `overrides` list across calls.
    return load_model(run, step, overrides=list(overrides), device=device).to(torch.bfloat16)


for run in RUNS:
    out_dir = OUT_DIR / run
    out_dir.mkdir(exist_ok=True, parents=False)

    steps = get_model_steps(run)
    for t in steps:
        # Lazily loaded. Reset to None after every optimizer step, so each train batch
        # starts from the untouched weights of checkpoint `t`.
        model = None

        # 1. init eval losses (one file per checkpoint, shared by all train batches)
        path_eval_init = out_dir / f".eval-{DATASET}/step{t}.pt"
        print(f"[{run}][{t}] {path_eval_init}", end="\r")
        if OVERWRITE or not path_eval_init.exists():
            path_eval_init.parent.mkdir(parents=True, exist_ok=True)
            model = _load_bf16_model(run, t)
            losses = _per_token_losses(model, _eval_token_batches(), f"[{run}][{t}]", "init-eval")
            torch.save(losses, path_eval_init)

        for bs in train_batch_steps:
            print(f"[{run}][{t}][{bs}]\t\t\t", end="\r")
            batch_dir = out_dir / f"train-batch={bs}"
            prefix = f"[{run}][{t}][{bs}]"

            # 0.1 skip batch dir if already done and no overwrite
            path_train_init = batch_dir / f"init/train/step{t}.pt"
            path_train_post = batch_dir / f"post/train/step{t}.pt"
            spath_eval_init = batch_dir / f"init/eval-{DATASET}/step{t}.pt"
            path_eval_post = batch_dir / f"post/eval-{DATASET}/step{t}.pt"
            outputs = (path_train_init, path_train_post, path_eval_post, spath_eval_init)
            if not OVERWRITE and all(path.exists() for path in outputs):
                continue

            # 0.2 load model (checkpoint weights) and optimizer
            if model is None:
                model = _load_bf16_model(run, t)
            optimizer = load_optimizer(model, run, t, device=device)
            optimizer.zero_grad()
            for pg in optimizer.param_groups:
                # A stored lr of 0 would make the step a no-op; fall back to the initial lr.
                if pg['lr'] == 0:
                    assert run.startswith('1028-rmsnorm'), f"Learning rate hack might be wrong here."
                    pg['lr'] = pg['initial_lr']

            # 0.3 load train microbatches
            batch = train_batches[bs]
            global_batch_num_tokens = batch.shape[0] * (batch.shape[1] - 1)
            microbsz = get_device_bsz(run)
            # `split` keeps a trailing partial microbatch, so every example counted in
            # `global_batch_num_tokens` also contributes to the gradient.
            train_microbatch_dataloader = [mb.to(device) for mb in batch.split(microbsz)]
            num_microbatches = len(train_microbatch_dataloader)

            # 1. init eval losses (symlink to the per-checkpoint file from above)
            stale_link = spath_eval_init.is_symlink() and not spath_eval_init.exists()
            if stale_link or (OVERWRITE and spath_eval_init.exists()):
                spath_eval_init.unlink()
            if not spath_eval_init.exists():
                spath_eval_init.parent.mkdir(parents=True, exist_ok=True)
                # Absolute target, so the link resolves independently of its own directory.
                spath_eval_init.symlink_to(path_eval_init.resolve())

            # 2. init train losses (forward/backward on all microbatches, then one update)
            # NOTE(review): the weights, and therefore the AdamW update, are bfloat16 here.
            # Also, no gradient clipping is applied before `optimizer.step()` even though the
            # param groups carry `max_grad_norm` (unless OLMo's optimizer clips inside
            # step()). Confirm both match the intended training dynamics.
            losses = []
            model.train()
            for i, microbatch in enumerate(train_microbatch_dataloader):
                print(f"{prefix}[{i}][init-train]\t\t\t", end="\r")
                input_ids = microbatch[:, :-1]
                labels = microbatch[:, 1:].flatten()
                with amp_ctx:
                    logits = model(input_ids).logits.flatten(0, 1)
                    _losses = F.cross_entropy(logits, labels, reduction="none")
                    # Normalise by the token count of the whole global batch, so the
                    # accumulated gradient is that of the global-batch mean loss.
                    loss = _losses.sum() / global_batch_num_tokens
                loss.backward()
                losses.append(_losses.detach().cpu())
                del logits, _losses, loss
            optimizer.step()
            del optimizer
            if OVERWRITE or not path_train_init.exists():
                path_train_init.parent.mkdir(parents=True, exist_ok=True)
                torch.save(torch.cat(losses, dim=0), path_train_init)

            # 3. post train losses
            if OVERWRITE or not path_train_post.exists():
                path_train_post.parent.mkdir(parents=True, exist_ok=True)
                losses = _per_token_losses(model, train_microbatch_dataloader, prefix, "post-train")
                torch.save(losses, path_train_post)

            # 4. post eval losses
            if OVERWRITE or not path_eval_post.exists():
                path_eval_post.parent.mkdir(parents=True, exist_ok=True)
                losses = _per_token_losses(model, _eval_token_batches(), prefix, "post-eval")
                torch.save(losses, path_eval_post)

            # `optimizer.step()` updated the weights in place. Drop this model so the next
            # train batch reloads checkpoint `t` instead of stacking updates from earlier batches.
            model = None

[1028-rmsnorm-472m][262144][262148][15][post-eval]				a/zsl_scratch/analysis/per_token_change_in_loss-dt=1/1028-rmsnorm-472m/.eval-c4_en_val/step262144.pt