In [1]:
import math
import torch
from tqdm.auto import tqdm
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
from datasets     import load_dataset

# 👇 your RADD imports
from load_model       import load_model
from bayes_radd        import mc_marginal_tokenwise

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ——— GPT-2 baseline (just for reference) ————————————————
tok2   = GPT2TokenizerFast.from_pretrained("gpt2")
gpt2   = GPT2LMHeadModel.from_pretrained("gpt2").to(device).eval()

# ——— Your RADD model —————————————————————————————
# assume load_model returns (model, noise) and you’ve registered bayesRADD.mc_marginal_tokenwise
radd, noise = load_model("JingyangOu/radd-t-dce", device)
radd.eval()

# ——— helper to compute RADD MC‐marginal PPL on a batch ——————
@torch.no_grad()
def radd_ppl_on_block(x, K):
    # x: [1, L] contiguous chunk of token‐ids
    # we will only score the last trg_len tokens
    mc_p_true, _ = mc_marginal_tokenwise(radd, noise, x, K)  # → [1, L]
    # we only care about the tail positions (where we masked prefix)
    # so return the sum of negative log‐prob of those tokens
    return -(mc_p_true.clamp(min=1e-12).log().sum().item())

# ——— load WikiText-2 test as one long sequence ————————
test = load_dataset("wikitext", "wikitext-2-v1", split="test")
raw  = "\n\n".join(test["text"])
enc2 = tok2(raw, return_tensors="pt")
ids = enc2.input_ids.to(device)  # [1, N]
N   = ids.size(1)

# ——— sliding‐window settings ————————————————————
max_len, stride = 1024, 512
nll_sum_radd, n_tok_radd = 0.0, 0
nll_sum_gpt2, n_tok_gpt2 = 0.0, 0

K = 8  # number of MC samples for RADD

for i in tqdm(range(0, N, stride)):
    begin = max(i + stride - max_len, 0)
    end   = min(i + stride, N)
    trg_len = end - i

    block = ids[:, begin:end]            # [1, L]
    # — GPT2
    with torch.no_grad():
        # mask out the prefix in the labels
        labels = block.clone()
        labels[:, :-trg_len] = -100
        loss2 = gpt2(block, labels=labels).loss
    nll_sum_gpt2 += loss2.item() * trg_len
    n_tok_gpt2  += trg_len

    # — RADD MC marginal
    # we score that same block by calling mc_marginal_tokenwise,
    # but only summing over the last trg_len positions
    radd_nll = radd_ppl_on_block(block, K)
    nll_sum_radd += radd_nll
    n_tok_radd  += trg_len

    if end == N:
        break

ppl_gpt2 = math.exp(nll_sum_gpt2 / n_tok_gpt2)
ppl_radd = math.exp(nll_sum_radd  / n_tok_radd)

print(f"GPT-2 PPL  = {ppl_gpt2:.2f}")
print(f"RADD (K={K}) PPL = {ppl_radd:.2f}")

Token indices sequence length is longer than the specified maximum sequence length for this model (297300 > 1024). Running this sequence through the model will result in indexing errors


  0%|          | 0/581 [00:00<?, ?it/s]

`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.
  with torch.cuda.amp.autocast(dtype=self.dtype):
  with torch.cuda.amp.autocast(enabled=False):
  with torch.cuda.amp.autocast(enabled=False):


GPT-2 PPL  = 20.33
RADD (K=8) PPL = 8.83


In [4]:
import math
import torch
from tqdm.auto import tqdm
from transformers import GPT2TokenizerFast
from datasets import load_dataset
from load_model import load_model
from losses import get_loss_fn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load tokenizer & data
tok = GPT2TokenizerFast.from_pretrained("gpt2")
test = load_dataset("wikitext", "wikitext-2-v1", split="test")
raw  = "\n\n".join(test["text"])
ids  = tok(raw, return_tensors="pt").input_ids.to(device)
N    = ids.size(1)

# Load RADD and analytic DCE loss
model, noise = load_model("JingyangOu/radd-t-dce", device)
model.eval()
dce_loss_fn = get_loss_fn(
    noise,
    model.config.tokens + 1,  # include [MASK]
    train=False,
    loss_type="lambda_DCE"
)

@torch.no_grad()
def vanilla_radd_ppl(seq_len=1024, stride=512, max_batches=None):
    total_nll, total_toks = 0.0, 0
    batches = 0

    for i in tqdm(range(0, N, stride)):
        if max_batches and batches >= max_batches:
            break

        begin = max(i + stride - seq_len, 0)
        end   = min(i + stride, N)
        trg_len = end - i

        block = ids[:, begin:end]     # [1, L]
        L = block.size(1)

        # --- get *sum* NLL over the block (dce_loss_fn returns sum NLL) ---
        block_nll = dce_loss_fn(model, block).item()

        # --- isolate the tail (last trg_len tokens) ---
        tail_nll = block_nll * (trg_len / L)

        total_nll  += tail_nll
        total_toks += trg_len
        batches    += 1

        if end == N:
            break

    # average NLL per token, then exponentiate
    avg_nll = total_nll / total_toks
    return math.exp(avg_nll)

# Run it!
ppl_vanilla = vanilla_radd_ppl(seq_len=1024, stride=512, max_batches=100)
print(f"Vanilla RADD PPL = {ppl_vanilla:.2f}")

Token indices sequence length is longer than the specified maximum sequence length for this model (297300 > 1024). Running this sequence through the model will result in indexing errors


  0%|          | 0/581 [00:00<?, ?it/s]

Vanilla RADD PPL = 41.10
