In [4]:
import os
import json
import math
import torch
from tqdm.auto import tqdm
import argparse

from load_model import load_model
from bayes_radd import zero_shot_ppl, zero_shot_mc_ppl
import data

# ─── SETTINGS ────────────────────────────────────────────────────────────────
# Checkpoint identifier handed to load_model().
MODEL_PATH = "JingyangOu/radd-t-dce"
# Sequence length used for both the dataloader and the perplexity routines.
SEQ_LEN = 1024
BATCH_SIZE = 16
# Cap on evaluated batches: 1 is a quick smoke test, None runs the full split.
MAX_BATCHES = 1
# Forwarded as K to zero_shot_mc_ppl (presumably the MC sample count — confirm).
K = 64
# Output directory; created right away so it exists for anything written later.
OUT_DIR = "BayesRADD"
os.makedirs(OUT_DIR, exist_ok=True)

# ─── DEVICE & SEED ────────────────────────────────────────────────────────────
# Seed the global torch RNG before the model is loaded or any sampling happens.
torch.manual_seed(42)
# Use the GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# ─── MODEL & NOISE ───────────────────────────────────────────────────────────
# load_model returns the network together with its companion noise object.
model, noise = load_model(MODEL_PATH, device)
# Evaluation only: switch the network to eval mode.
model.eval()

# ─── DATA LOADER ─────────────────────────────────────────────────────────────
def build_loader():
    """Build the WikiText-2 validation loader used by the evaluation below.

    Packs the module-level BATCH_SIZE and SEQ_LEN together with fixed
    single-GPU, non-distributed options into an ``argparse.Namespace`` and
    forwards it to ``data.get_valid_dataloaders``.

    Returns:
        Whatever ``data.get_valid_dataloaders`` returns for this configuration
        (presumably a single iterable dataloader — confirm against ``data``).
    """
    loader_options = {
        "cache_dir": "data",
        "batch_size": BATCH_SIZE,
        "length": SEQ_LEN,
        "valid_dataset": "wikitext2",
        "ngpus": 1,
    }
    namespace = argparse.Namespace(**loader_options)
    return data.get_valid_dataloaders(namespace, distributed=False)

# ─── RUN ONE-SHOT BAYESRADD PPL ───────────────────────────────────────────────
if __name__ == "__main__":
    loader = build_loader()

    # Keyword arguments common to both perplexity estimators.
    _eval_kwargs = dict(
        model=model,
        noise=noise,
        device=device,
        dataloader=loader,
        max_batches=MAX_BATCHES,
        sequence_length=SEQ_LEN,
    )

    # Baseline estimate first, then the Monte-Carlo variant on the same loader.
    vanilla = zero_shot_ppl(**_eval_kwargs)
    mc = zero_shot_mc_ppl(K=K, mask_rate=0.05, **_eval_kwargs)

    for report_line in (
        f"Vanilla λ-DCE PPL = {vanilla:.2f}",
        f"MC-marginal PPL   = {mc:.2f}",
    ):
        print(report_line)

DCE PPL:   0%|          | 0/1 [00:00<?, ?it/s]

MC double‐mask PPL:   0%|          | 0/1 [00:00<?, ?it/s]

Vanilla λ-DCE PPL = 42.37
MC-marginal PPL   = 49.41
