In [None]:
import os
from pathlib import Path
import csv
from tqdm.auto import tqdm
import pandas as pd

from hydra import initialize_config_dir, compose
from hydra.utils import instantiate
import resolver as _

from datasets import load_from_disk
from transformers import AutoTokenizer

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import torch.distributed.checkpoint as dcp
from torch.utils.data import DataLoader

from src.core.checkpointing import TrainingState



In [2]:
# --- env vars for torch.distributed ---
# FSDP and torch.distributed.checkpoint (used in later cells) expect an
# initialized default process group, even on a single GPU. We fake a world of
# size 1 through the env:// rendezvous variables.
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["LOCAL_RANK"] = "0"
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500"

torch.cuda.set_device(0)
device = torch.device("cuda:0")

# init_process_group raises if the default group already exists. Guarding it
# makes this cell safe to re-run in the same kernel.
if not dist.is_initialized():
    dist.init_process_group(
        backend="nccl",
        rank=0,
        world_size=1,
    )

print("✅ torch.distributed initialized (1 GPU)")

✅ torch.distributed initialized (1 GPU)


In [3]:
# --- paths ---
# Every setting can be overridden with a PTL_* environment variable (e.g.
# PTL_CKPT_DIR=/my/ckpt). This lets the notebook run on machines with a
# different storage layout. The defaults reproduce the original run.
CKPT_DIR = os.environ.get(
    "PTL_CKPT_DIR", "/storage_nvme_4/nano/models/112902/0/step_100"
)  # sharded checkpoint directory passed to dcp.load
DATASET_DIR = os.environ.get(
    "PTL_DATASET_DIR", "/storage_nvme_1/llm-random/datasets/c4/long_context_2048n8192"
)  # dataset read with datasets.load_from_disk
OUT_CSV = os.environ.get("PTL_OUT_CSV", "per_token_loss.csv")
SEQ_LEN = int(os.environ.get("PTL_SEQ_LEN", "512"))   # set to 8192 if you want hard truncation to context size
BATCH_SIZE = int(os.environ.get("PTL_BATCH_SIZE", "32"))

In [4]:
# Compose the Hydra config from ./configs and build the model described by cfg.model.
config_dir = str(Path.cwd() / "configs")
with initialize_config_dir(config_dir=config_dir, version_base=None):
    cfg = compose(config_name="tiny_remote_ctx_scl")

model = instantiate(cfg.model, _convert_="all")
model = model.to(device)
model.eval()

n_params = sum(param.numel() for param in model.parameters())
print(f"Model instantiated on {device}")
print(f"Parameters: {n_params:,}")

# Wrap in FSDP so the sharded checkpoint loaded in the next cell has a matching target.
fsdp_model = FSDP(model)

Model instantiated on cuda:0
Parameters: 295,900,160




In [5]:
# Re-create the optimizer and scheduler so TrainingState has objects to restore
# checkpoint state into. lr=0.0 is presumably a placeholder: this notebook never
# steps the optimizer.
optimizer = torch.optim.AdamW(fsdp_model.parameters(), lr=0.0)
scheduler_factory = instantiate(cfg.trainer.scheduler)
scheduler = scheduler_factory(optimizer=optimizer, n_steps=cfg.trainer.n_steps)

# TrainingState (src.core.checkpointing) presumably implements DCP's Stateful
# protocol. The checkpoint is loaded under the "app" key.
training_state = TrainingState(fsdp_model, optimizer, scheduler)
state = {"app": training_state}
dcp.load(state, checkpoint_id=CKPT_DIR)

print("✅ Sharded checkpoint loaded via TrainingState")


  device = getattr(value, "device", None)
  and md.size != obj.size()
  dim_0_size = sharded_tensor.size()[0]  # type: ignore[index]
  tensor_numel = sharded_tensor.size().numel()  # type: ignore[union-attr]
  tensor = tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size())


✅ Sharded checkpoint loaded via TrainingState




In [6]:
# Load the saved dataset and print a quick schema / content sanity check.
ds = load_from_disk(DATASET_DIR)

print(ds)
print("Columns:", ds.column_names)

first_row = ds[0]
print("Example keys:", first_row.keys())
if "text" in first_row:
    text_preview = first_row["text"][:200] + "..."
else:
    text_preview = "NO 'text' COLUMN"
print("Text preview:", text_preview)


Dataset({
    features: ['text', 'timestamp', 'url', 'length'],
    num_rows: 8192
})
Columns: ['text', 'timestamp', 'url', 'length']
Example keys: dict_keys(['text', 'timestamp', 'url', 'length'])
Text preview: Welcome to Boston Mamas Rock! – where we’re giving a voice to fabulous local mamas from all walks of life. Read on for today’s interview with Susan Dorson & Amy Weitzman, two local moms on a mission t...


In [7]:
# GPT-2 BPE tokenizer; use_fast=True selects the `tokenizers`-backed implementation.
# NOTE(review): assumes the model was trained with this vocabulary — confirm against cfg.model.
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
vocab_size = len(tokenizer)
print("Tokenizer vocab size:", vocab_size)

Tokenizer vocab size: 50257


In [8]:
def collate_no_pad(batch):
    """Tokenize dataset rows into fixed-length, unpadded sequences of SEQ_LEN tokens.

    Texts are tokenized without special tokens and truncated to ``SEQ_LEN``.
    Samples still shorter than ``SEQ_LEN`` are dropped individually; no padding
    is ever added. ``url``/``timestamp`` are filtered in lockstep, so they stay
    aligned with the rows of ``input_ids``.

    Args:
        batch: list of dataset rows, each a mapping with ``"text"``, ``"url"``
            and ``"timestamp"`` keys.

    Returns:
        ``{"input_ids": LongTensor[K, SEQ_LEN], "url": list, "timestamp": list}``
        for the K kept samples, or ``None`` if no sample reached ``SEQ_LEN``.
        Callers are expected to skip ``None`` batches.
    """
    # Tokenize to Python lists. return_tensors="pt" would make the tokenizer
    # raise on ragged lengths before short samples could be filtered out.
    enc = tokenizer(
        [ex["text"] for ex in batch],
        add_special_tokens=False,
        truncation=True,
        max_length=SEQ_LEN,
    )

    # Keep only samples that actually reached SEQ_LEN, carrying their metadata along.
    kept = [
        (ids, ex["url"], ex["timestamp"])
        for ids, ex in zip(enc["input_ids"], batch)
        if len(ids) == SEQ_LEN
    ]
    if not kept:
        return None  # nothing in this batch is long enough

    ids_rows, urls, timestamps = zip(*kept)
    return {
        "input_ids": torch.tensor(ids_rows, dtype=torch.long),  # [K, SEQ_LEN]
        "url": list(urls),
        "timestamp": list(timestamps),
    }

import torch.nn.functional as F

@torch.no_grad()
def batch_per_token_losses(model, input_ids):
    """Compute the unreduced next-token cross-entropy at every position of a batch.

    Args:
        model: callable returning either an object with a ``.logits`` attribute
            or the logits tensor itself (assumed ``[B, T, V]``).
        input_ids: integer token ids ``[B, T]``; moved to the global ``device``.

    Returns:
        ``(losses, targets)`` on CPU, both ``[B, T-1]``. ``losses[b, t]`` scores
        ``targets[b, t] == input_ids[b, t + 1]`` using the logits at position ``t``.
    """
    ids = input_ids.to(device)
    output = model(ids)
    # Accept both ModelOutput-style objects and raw logits tensors.
    all_logits = getattr(output, "logits", output)
    vocab = all_logits.size(-1)

    # Logits at position t predict token t+1: drop the last logit and the first id.
    pred_logits = all_logits[:, :-1, :]
    next_ids = ids[:, 1:]

    flat_losses = F.cross_entropy(
        pred_logits.reshape(-1, vocab),
        next_ids.reshape(-1),
        reduction="none",
    )
    per_token = flat_losses.view_as(next_ids)
    return per_token.cpu(), next_ids.cpu()

In [9]:
# Run every batch through the model and keep per-token losses in memory.
# They are written to OUT_CSV by a later cell.
loader = DataLoader(
    ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_no_pad,
)

# Call the FSDP wrapper that the checkpoint was loaded into rather than the
# bare inner module. FSDP manages (and may shard) the parameters and unshards
# them in its own forward; calling the inner module directly bypasses that.
# eval() on the wrapper propagates to the inner model.
fsdp_model.eval()

all_losses = []      # one [K, SEQ_LEN-1] loss tensor per kept batch
all_targets = []     # matching target token ids, same shapes as all_losses
all_urls = []        # one entry per kept sample, aligned with loss rows
all_timestamps = []  # one entry per kept sample, aligned with loss rows

for batch in tqdm(loader):
    if batch is None:  # collate_no_pad kept no sample from this batch
        continue

    losses, targets = batch_per_token_losses(fsdp_model, batch["input_ids"])
    all_losses.append(losses)
    all_targets.append(targets)
    all_urls.extend(batch["url"])
    all_timestamps.extend(batch["timestamp"])

assert all_losses, f"no sample reached SEQ_LEN={SEQ_LEN} tokens; nothing to evaluate"


  0%|          | 0/256 [00:00<?, ?it/s]

In [10]:
import pandas as pd

def tensors_rows_to_csv(tensors, path="tensors.csv"):
    """Concatenate tensors along dim 0 and write the result as one CSV table.

    Each row of each input tensor becomes one CSV row. Columns get pandas'
    default integer labels ``0 .. C-1``, and no index column is written.

    Args:
        tensors: non-empty iterable of tensors that agree in every dim except
            dim 0 (in this notebook: per-batch ``[K, SEQ_LEN-1]`` loss tensors).
        path: output CSV path; overwritten if it already exists.

    Raises:
        ValueError: if ``tensors`` is empty (e.g. every batch was dropped).
    """
    cpu_tensors = [t.detach().cpu() for t in tensors]
    if not cpu_tensors:
        raise ValueError("tensors_rows_to_csv: no tensors to write")
    stacked = torch.cat(cpu_tensors, dim=0)  # [total rows across all tensors, C]
    pd.DataFrame(stacked.numpy()).to_csv(path, index=False)


In [11]:
# Wide format: one row per kept sample, one column per target position.
tensors_rows_to_csv(all_losses, path=OUT_CSV)

In [12]:
# Read back the table written above: one row per sample, one column per target position.
df = pd.read_csv(OUT_CSV)
df.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,501,502,503,504,505,506,507,508,509,510
0,4.085172,9.934204,8.996014,11.50411,9.986458,5.382188,8.986745,7.447004,7.199303,4.985841,...,7.789421,5.581819,8.445265,3.161999,8.101936,2.602781,5.626294,8.515235,11.292056,7.314084
1,10.822108,9.538289,11.703181,2.757541,6.331216,6.654324,3.294222,9.728536,5.596427,7.050842,...,5.177627,8.263029,4.265507,5.648528,8.326301,3.465414,1.184136,9.803685,10.910895,8.787147
2,3.437046,4.038119,7.277862,6.908977,7.173857,9.113975,4.560616,8.288863,3.227307,11.024434,...,8.091841,10.49587,4.027703,11.38557,8.660601,6.807249,3.459799,4.571454,7.038602,2.673347
3,9.886657,3.729474,8.190474,9.9484,9.285893,11.86869,10.385446,5.357697,10.134768,3.330999,...,11.794025,7.370442,10.62501,11.624344,9.877463,12.368013,4.121586,4.333259,7.55892,7.106237
4,7.346539,13.069561,12.154629,11.288257,7.26408,4.670476,4.505062,10.830602,10.59691,4.051549,...,9.952202,9.714775,2.942371,10.537815,3.000067,13.041159,10.456656,6.487325,2.98719,9.749381
