In [1]:
from tools.checkpoint import model_from_checkpoint
from training.eval import Evaluator
from data.data_gen_stream import DistributedDataGenerator
import torch

# --- Setup: seed, checkpoint, validation data stream, evaluator -------------
# Locations and device are named once here so the calls below stay readable.
CHECKPOINT_PATH = '/Users/jonathanmiddleton/models/checkpoints/820m-base/20251126T2111-val3.340-step043946-tokens18000281600-run1-final.pt'
VAL_DATA_PATTERN = "/Users/jonathanmiddleton/projects/daisy/data/dclm_baseline/dclm_baseline_val_000000.bin"
DEVICE = 'mps'

# Seed torch before anything else in the notebook runs.
torch.manual_seed(1337)

# The checkpoint loader returns the model together with its hyperparameters.
model, hparams = model_from_checkpoint(CHECKPOINT_PATH, device=DEVICE)
model.eval()

# Single-process configuration: rank 0 of a world of size 1.
# NOTE(review): start_shard=1 while the pattern names one file ending in
# 000000 -- confirm this is the intended starting shard.
ddg = DistributedDataGenerator(
    filename_pattern=VAL_DATA_PATTERN,
    batch_size=8192,
    rank=0,
    world_size=1,
    start_shard=1,
    device=DEVICE,
)

# Evaluator reads from the generator above; the attention window length is
# taken from the checkpoint's hyperparameters.
evaluator = Evaluator(
    data_generator=ddg,
    distributed_enabled=False,
    rank=0,
    attn_window_len=hparams['train_attention_window_len'],
    val_type='pretraining',
)



In [2]:
def get_scalar_views(model):
    """Slice ``model.scalars`` into its three per-layer groups.

    With ``L = len(model.blocks)``, the flattened scalar tensor is read as:

    * elements ``[0, L)``    -> ``skip``        (shape ``(L,)``)
    * elements ``[L, 3L)``   -> ``lambdas``     (shape ``(L, 2)``)
    * elements ``[3L, 5L)``  -> ``sa_lambdas``  (shape ``(L, 2)``)

    All three results are views, so in-place edits write through to
    ``model.scalars``. Elements past ``5L`` (if any) are ignored.

    Returns:
        tuple: ``(skip, lambdas, sa_lambdas)``.
    """
    n_layers = len(model.blocks)
    flat = model.scalars.view(-1)

    # Boundaries of the two paired regions that follow the skip weights.
    lam_end = 3 * n_layers
    sa_end = 5 * n_layers

    skip_weights = flat[:n_layers]
    lambda_pairs = flat[n_layers:lam_end].view(n_layers, 2)
    sa_lambda_pairs = flat[lam_end:sa_end].view(n_layers, 2)
    return skip_weights, lambda_pairs, sa_lambda_pairs

def print_scalar_stats(model):
    """Print min / max / mean / std for each scalar group of ``model``.

    The groups come from ``get_scalar_views(model)`` and are printed in the
    order skip weights, lambdas, sa_lambdas, one dict per line.
    """
    groups = get_scalar_views(model)
    labels = ("skip_weights:", "lambdas:", "sa_lambdas:")

    # Statistics are read-only, so no autograd tracking is needed.
    with torch.no_grad():
        for label, values in zip(labels, groups):
            summary = {
                'min': float(values.min()),
                'max': float(values.max()),
                'mean': float(values.mean()),
                'std': float(values.std()),
            }
            print(label, summary)

In [3]:
def sample_loss(model, total_tokens, samples=10, schedule=.83, eval_fn=None):
    """Run several evaluation passes, printing each result and the mean loss.

    Args:
        model: Model forwarded to the evaluation callable.
        total_tokens: Token budget forwarded to every evaluation call. The
            caller's value is honoured as given.
        samples: Number of evaluation passes to average over (must be >= 1).
        schedule: Forwarded as the ``schedule`` keyword; the default ``.83``
            is the pretraining progress of the loaded checkpoint.
        eval_fn: Callable invoked as
            ``eval_fn(model=..., total_tokens=..., schedule=...)`` and
            returning a mapping with a ``'val_loss'`` entry. Defaults to the
            notebook-level ``evaluator.eval``.

    Raises:
        ValueError: If ``samples`` is less than 1 (the mean would be undefined).
    """
    if samples < 1:
        raise ValueError(f"samples must be >= 1, got {samples}")
    if eval_fn is None:
        # Fall back to the Evaluator built in the notebook's setup cell.
        eval_fn = evaluator.eval

    acc_loss = 0
    for _ in range(samples):
        d = eval_fn(model=model, total_tokens=total_tokens, schedule=schedule)
        acc_loss += d['val_loss']
        print(d)
    print(f"avg loss: {acc_loss/samples}")

def renormalize_scalar_pairs_(model, target_norm=1.0, eps=1e-6):
    """Rescale every lambda / sa_lambda pair of ``model`` in place.

    With ``L = len(model.blocks)``, elements ``[L, 3L)`` and ``[3L, 5L)`` of
    the flattened ``model.scalars`` are each viewed as ``(L, 2)``. Every row
    is multiplied by ``target_norm / max(||row||_2, eps)``, so rows whose norm
    is at least ``eps`` end up with L2 norm ``target_norm``. The first ``L``
    elements are left untouched.

    Args:
        model: Object exposing ``blocks`` (sized) and ``scalars`` (tensor).
        target_norm: Desired L2 norm of each pair.
        eps: Lower clamp on the pair norm, guarding against division by zero.

    Returns:
        None. ``model.scalars`` is modified in place.
    """
    n_layers = len(model.blocks)
    with torch.no_grad():
        flat = model.scalars.view(-1)

        # Build both (L, 2) views before mutating anything, so a shape
        # mismatch fails before any value has been changed.
        pair_views = [
            flat[begin * n_layers:(begin + 2) * n_layers].view(n_layers, 2)
            for begin in (1, 3)
        ]

        for pairs in pair_views:
            pair_norm = pairs.norm(dim=-1, keepdim=True).clamp_min(eps)
            pairs.mul_(target_norm / pair_norm)

def clone_with_renormalized_scalars(model, target_norm=1.0, eps=1e-6):
    """Return a deep copy of ``model`` with its scalar pairs renormalized.

    The copy is passed through ``renormalize_scalar_pairs_``; the input
    ``model`` itself is not modified.

    Args:
        model: Model to copy.
        target_norm: Forwarded to ``renormalize_scalar_pairs_``.
        eps: Forwarded to ``renormalize_scalar_pairs_``.

    Returns:
        The renormalized deep copy.
    """
    from copy import deepcopy

    renormalized = deepcopy(model)
    renormalize_scalar_pairs_(renormalized, eps=eps, target_norm=target_norm)
    return renormalized

In [4]:
# Baseline: scalar statistics and average validation loss of the unmodified
# checkpoint.
total_tokens = 100_000  # forwarded to evaluator.eval as its total_tokens argument

# Re-seed and reset the evaluator's generator before measuring, exactly as the
# renormalized-model cell does, so this cell is idempotent on re-run and both
# measurements start from the same state.
torch.manual_seed(1337)
evaluator.reset_generator()

print("original stats:")
print_scalar_stats(model)
sample_loss(model, total_tokens)

original stats:
skip_weights: {'min': 0.0881648138165474, 'max': 1.0, 'mean': 0.8417758941650391, 'std': 0.3415984809398651}
lambdas: {'min': -17.922931671142578, 'max': 52.3514404296875, 'mean': 13.798450469970703, 'std': 16.747438430786133}
sa_lambdas: {'min': 0.5, 'max': 78.92021179199219, 'mean': 25.31561279296875, 'std': 21.081050872802734}
{'val_loss': 2.839808146158854, 'val_acc': None, 'epoch': None, 'ema_dloss_per_token': nan}
{'val_loss': 2.9695704778035483, 'val_acc': None, 'epoch': None, 'ema_dloss_per_token': 1.297623316446943e-06}
{'val_loss': 2.7490504582722983, 'val_acc': None, 'epoch': None, 'ema_dloss_per_token': 2.4677626291911003e-07}
{'val_loss': 2.9539562861124673, 'val_acc': None, 'epoch': None, 'ema_dloss_per_token': 7.874608675638839e-07}
{'val_loss': 2.7987181345621743, 'val_acc': None, 'epoch': None, 'ema_dloss_per_token': 8.55081526438398e-08}
{'val_loss': 2.785984992980957, 'val_acc': None, 'epoch': None, 'ema_dloss_per_token': 2.165628210703595e-08}
{'val_

In [6]:
# Renormalization experiment: evaluate a deep copy of `model` whose lambda and
# sa_lambda pairs were rescaled by clone_with_renormalized_scalars (default
# target_norm=1.0). The original `model` is left unchanged.
# NOTE(review): execution counts jump from In [4] to In [6]; confirm the
# notebook reproduces these numbers under Restart Kernel -> Run All.
ns_model = clone_with_renormalized_scalars(model)
# Re-seed torch and reset the evaluator's generator before evaluating.
# Presumably this rewinds the validation stream so the run below sees the same
# data as the baseline -- TODO confirm against Evaluator.reset_generator.
torch.manual_seed(1337)
evaluator.reset_generator()

print("new stats:")
print_scalar_stats(ns_model)
# total_tokens is the module-level value assigned in the baseline cell.
sample_loss(ns_model, total_tokens)

new stats:
skip_weights: {'min': 0.0881648138165474, 'max': 1.0, 'mean': 0.8417758941650391, 'std': 0.3415984809398651}
lambdas: {'min': -0.9981178045272827, 'max': 0.9999814033508301, 'mean': 0.4664236307144165, 'std': 0.5399631857872009}
sa_lambdas: {'min': 0.007431285455822945, 'max': 0.9999724626541138, 'mean': 0.5980982780456543, 'std': 0.3832336366176605}
{'val_loss': 13.876227060953775, 'val_acc': None, 'epoch': None, 'ema_dloss_per_token': 2.3206257127815456e-05}
{'val_loss': 13.836037953694662, 'val_acc': None, 'epoch': None, 'ema_dloss_per_token': 1.612381266769348e-05}
{'val_loss': 13.820976257324219, 'val_acc': None, 'epoch': None, 'ema_dloss_per_token': 1.1241483778274105e-05}
{'val_loss': 13.80954360961914, 'val_acc': None, 'epoch': None, 'ema_dloss_per_token': 7.83474070167664e-06}
{'val_loss': 13.743260701497396, 'val_acc': None, 'epoch': None, 'ema_dloss_per_token': 5.2854697668084145e-06}
{'val_loss': 13.867200215657553, 'val_acc': None, 'epoch': None, 'ema_dloss_per_