In [1]:
import os
from pathlib import Path
import csv
from tqdm.auto import tqdm
import pandas as pd

from hydra import initialize_config_dir, compose
from hydra.utils import instantiate
import resolver as _

from datasets import load_from_disk
from transformers import AutoTokenizer

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import torch.distributed.checkpoint as dcp
from torch.utils.data import DataLoader
from torch.distributed.checkpoint.stateful import Stateful

In [2]:
# --- env vars for torch.distributed ---
# Single-process rendezvous settings: this notebook runs as rank 0 of a
# world of size 1 on the local machine.
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["LOCAL_RANK"] = "0"
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500"

# Pin this process to the first GPU; `device` is reused by later cells.
torch.cuda.set_device(0)
device = torch.device("cuda:0")

# Guard the initialization so the cell is idempotent: calling
# init_process_group a second time in the same kernel raises because the
# default process group already exists.
if not dist.is_initialized():
    dist.init_process_group(
        backend="nccl",
        rank=0,
        world_size=1,
    )

print("✅ torch.distributed initialized (1 GPU)")

✅ torch.distributed initialized (1 GPU)


In [None]:
# --- paths ---
# Checkpoint directory handed to dcp.load() as `checkpoint_id`.
CKPT_DIR = "/storage_nvme_4/nano/models/from_helios/mqa_dff3520/step_320000"
# On-disk Hugging Face dataset opened with load_from_disk().
DATASET_DIR = "/storage_nvme_1/llm-random/datasets/c4/long_context_2048n8192"
# Destination CSV for the per-token losses (written, then read back below).
OUT_CSV = "MQA_3520_loss_k16_42B.csv"
# Tokenizer truncation length; the collate function also requires sequences
# to be exactly this long.
SEQ_LEN = 2048   # set to 8192 if you want hard truncation to context size
# Number of dataset rows per DataLoader batch.
BATCH_SIZE = 32
# Name of the Hydra config composed from ./configs.
exp_config_name = "ctx_scl_k16_long_mqa"    # remember to pick grid variables

In [4]:
# Hydra configs are looked up in ./configs relative to the current working
# directory, so the notebook must be started from the directory containing it.
config_dir = str(Path.cwd() / "configs")

# Compose the experiment config selected by `exp_config_name`.
with initialize_config_dir(config_dir=config_dir, version_base=None):
    cfg = compose(config_name=exp_config_name)

# Build the model described by cfg.model, move it to the GPU and switch it to
# eval mode.
model = instantiate(cfg.model, _convert_="all").to(device)
model.eval()

print(f"Model instantiated on {device}")
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

# Wrap the model in FSDP; the checkpoint is loaded into this wrapper in a
# later cell.
fsdp_model = FSDP(model)

Model instantiated on cuda:0
Parameters: 311,628,800




In [5]:
class ModelOnly(Stateful):
    def __init__(self, model):
        self.model = model

    def state_dict(self):
        return {"model": self.model.state_dict()}

    def load_state_dict(self, sd):
        self.model.load_state_dict(sd["model"], strict=True)

# Load the checkpoint at CKPT_DIR into the FSDP-wrapped model. Only the
# "app" -> "model" entries are requested via ModelOnly.
# NOTE(review): presumably the checkpoint also stores optimizer state under
# other keys, which is what "optimizer skipped" refers to — confirm.
state = {"app": ModelOnly(fsdp_model)}
dcp.load(state, checkpoint_id=CKPT_DIR)

# Make sure the wrapper is in eval mode after loading.
fsdp_model.eval()
print("✅ Model loaded for inference (optimizer skipped)")




✅ Model loaded for inference (optimizer skipped)




In [6]:
# Open the dataset saved on disk at DATASET_DIR.
ds = load_from_disk(DATASET_DIR)

# Quick sanity checks: dataset summary, column names, the keys of the first
# row, and the first 200 characters of its "text" field (if present).
print(ds)
print("Columns:", ds.column_names)
print("Example keys:", ds[0].keys())
print("Text preview:", (ds[0]["text"][:200] + "...") if "text" in ds[0] else "NO 'text' COLUMN")


Dataset({
    features: ['text', 'timestamp', 'url', 'length'],
    num_rows: 8192
})
Columns: ['text', 'timestamp', 'url', 'length']
Example keys: dict_keys(['text', 'timestamp', 'url', 'length'])
Text preview: Welcome to Boston Mamas Rock! – where we’re giving a voice to fabulous local mamas from all walks of life. Read on for today’s interview with Susan Dorson & Amy Weitzman, two local moms on a mission t...


In [7]:
# Load the fast "gpt2" tokenizer; it is used by the collate function to turn
# raw text into token ids.
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)

# len(tokenizer) reports the vocabulary size.
print("Tokenizer vocab size:", len(tokenizer))

Tokenizer vocab size: 50257


In [8]:
def collate_no_pad(batch, tok=None, seq_len=None):
    """Tokenize a batch of dataset rows and keep only full-length samples.

    Each text is tokenized without special tokens and truncated to
    ``seq_len`` tokens. Samples that tokenize to fewer than ``seq_len``
    tokens are dropped individually, together with their ``url`` and
    ``timestamp``, so no padding is ever needed.

    Args:
        batch: list of dataset rows; each row must provide the keys
            ``"text"``, ``"url"`` and ``"timestamp"``.
        tok: tokenizer callable; defaults to the notebook-level ``tokenizer``.
        seq_len: required sequence length; defaults to the notebook-level
            ``SEQ_LEN``.

    Returns:
        ``None`` if no sample reached ``seq_len`` tokens, otherwise a dict with
        ``"input_ids"`` (LongTensor of shape ``[kept, seq_len]``) and the
        ``"url"`` / ``"timestamp"`` lists of the kept samples.
    """
    # Resolve notebook-level defaults lazily so explicit arguments can be
    # supplied instead of relying on globals.
    if tok is None:
        tok = tokenizer
    if seq_len is None:
        seq_len = SEQ_LEN

    # Deliberately no return_tensors="pt": without padding, a batch with
    # differing lengths cannot be converted to one tensor and the tokenizer
    # raises before any length filtering could happen.
    enc = tok(
        [ex["text"] for ex in batch],
        add_special_tokens=False,
        truncation=True,
        max_length=seq_len,
    )

    kept_ids, kept_urls, kept_timestamps = [], [], []
    for ex, ids in zip(batch, enc["input_ids"]):
        # keep only samples that actually reached seq_len
        if len(ids) != seq_len:
            continue
        kept_ids.append(ids)
        kept_urls.append(ex["url"])
        kept_timestamps.append(ex["timestamp"])

    if not kept_ids:
        return None  # nothing usable in this batch

    return {
        "input_ids": torch.tensor(kept_ids, dtype=torch.long),  # [kept, seq_len]
        "url": kept_urls,
        "timestamp": kept_timestamps,
    }

import torch.nn.functional as F

@torch.no_grad()
def batch_per_token_losses(model, input_ids, device=None):
    """Compute the next-token cross-entropy loss for every position.

    Args:
        model: an ``nn.Module`` called as ``model(input_ids)``; it must return
            either the logits tensor ``[B, T, V]`` or an object with a
            ``.logits`` attribute holding it.
        input_ids: integer tensor of token ids, shape ``[B, T]``.
        device: device to run on. When omitted, the device of the model's
            first parameter is used (or the tensor's own device if the model
            has no parameters), so the function does not depend on a
            notebook-level global.

    Returns:
        ``(losses, targets)``, both on CPU with shape ``[B, T-1]``:
        ``losses[b, t]`` is the loss of predicting ``input_ids[b, t+1]`` from
        the logits at position ``t``; ``targets`` is ``input_ids[:, 1:]``.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else input_ids.device

    input_ids = input_ids.to(device)        # [B, T]

    out = model(input_ids)
    logits = out.logits if hasattr(out, "logits") else out  # [B, T, V]

    # Shift by one: position t predicts token t+1, so the last logit and the
    # first token have no counterpart.
    logits = logits[:, :-1, :]   # [B, T-1, V]
    targets = input_ids[:, 1:]   # [B, T-1]

    # reduction="none" keeps one loss per token; reshape back to [B, T-1].
    losses = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        targets.reshape(-1),
        reduction="none",
    ).reshape(targets.shape)

    return losses.cpu(), targets.cpu()

In [9]:
# Iterate over the dataset in its stored order; collate_no_pad tokenizes each
# batch and returns None when a batch has to be skipped.
loader = DataLoader(
    ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_no_pad,
)

# Evaluate through the FSDP wrapper: that is the object the checkpoint was
# loaded into, and calling the inner module directly bypasses the wrapper's
# forward.
fsdp_model.eval()

# One CPU tensor of per-token losses, shape [B, T-1], per processed batch.
all_losses = []

for batch in tqdm(loader):
    if batch is None:
        continue

    losses, targets = batch_per_token_losses(fsdp_model, batch["input_ids"])
    all_losses.append(losses)


  0%|          | 0/256 [00:00<?, ?it/s]

In [10]:
import pandas as pd

def tensors_rows_to_csv(tensors, path="tensors.csv"):
    """Concatenate tensors along dim 0 and write the result as a CSV file.

    Args:
        tensors: iterable of tensors that can be concatenated along their
            first dimension (e.g. per-batch ``[B, N]`` loss tensors).
        path: destination CSV path.

    The file has one row per row of the concatenated tensor and no index
    column; the header holds the pandas default integer column labels.
    """
    host_chunks = [chunk.detach().cpu() for chunk in tensors]
    table = torch.cat(host_chunks, dim=0)   # (total rows across tensors, N)
    frame = pd.DataFrame(table.numpy())
    frame.to_csv(path, index=False)


In [11]:
# Persist all collected per-token losses to OUT_CSV (one row per sequence).
tensors_rows_to_csv(all_losses, path=OUT_CSV)

In [12]:
# Read the CSV back and display the first rows as a quick sanity check of
# what was written.
df = pd.read_csv(OUT_CSV)
df.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,2037,2038,2039,2040,2041,2042,2043,2044,2045,2046
0,0.38003,8.621256,6.870979,7.890626,8.752855,1.675966,7.656805,4.865206,2.13833,3.451378,...,2.176812,7.168613,4.727401,3.990843,1.318014,1.044751,2.883807,5.411947,4.330785,1.583291
1,10.219949,7.931793,9.692278,2.566181,5.39457,5.408353,0.116439,4.205744,6.292889,3.137556,...,5.322657,2.968103,2.034466,2.765616,3.947982,3.603583,8.287707,0.894379,1.263876,1.33297
2,1.066014,6.690225,2.764734,3.897488,4.028049,5.534327,1.602152,6.999635,4.307511,10.394428,...,3.829848,0.10463,0.028694,2.512537,3.888413,0.497884,6.186173,0.112987,11.303635,4.428379
3,3.484951,0.06547,5.952685,9.535594,8.614286,7.266953,0.302838,2.081389,0.015525,0.015402,...,0.000632,1.034212,0.010907,1.349706,0.044986,0.014288,0.000107,0.001127,0.002771,0.006391
4,3.171163,13.471144,3.424521,8.452765,6.60923,3.802885,5.457229,13.996242,4.587484,4.963999,...,4.449986,2.700965,0.352735,0.698385,4.6511,9.894043,10.196079,0.055088,4.668685,2.494915
