# Table 5 Reproduction: Zero-Shot Perplexity

**Ίδιες παράμετροι με Table 4**, διαφορετικό evaluation dataset.

| | Table 4 | Table 5 |
|---|---|---|
| Training | OWT | OWT (ίδιο) |
| Evaluation | OWT | Wikitext, PTB, LM1B, κλπ |

In [1]:
# Clone repo
!cd /content && rm -rf bd3lms
!cd /content && git clone https://github.com/ntua-el21050/bd3lms.git

# Create directories
!mkdir -p /content/bd3lms/data
!mkdir -p /content/repro_runs

Cloning into 'bd3lms'...
remote: Enumerating objects: 841, done.[K
remote: Counting objects: 100% (298/298), done.[K
remote: Compressing objects: 100% (100/100), done.[K
remote: Total 841 (delta 240), reused 224 (delta 198), pack-reused 543 (from 1)[K
Receiving objects: 100% (841/841), 3.00 MiB | 38.83 MiB/s, done.
Resolving deltas: 100% (534/534), done.


In [2]:
!pip install -q torchmetrics==1.6.2 datasets==3.3.2 einops==0.8.1 \
    hydra-core==1.3.2 lightning==2.5.0.post0 transformers==4.49.0

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.4/40.4 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.0/44.0 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m931.6/931.6 kB[0m [31m56.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m485.4/485.4 kB[0m [31m32.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.5/154.5 kB[0m [31m17.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m815.2/815.2 kB[0m [31m30.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.0/10.0 MB[0m [31m25.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m183.9/183.9 kB[0m [31m20.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━

In [3]:
import subprocess
import re
import os
import shutil
import sys
from pathlib import Path
import pandas as pd

def run_main(overrides, timeout=None):
    """Launch ``bd3lms/main.py`` in a child interpreter and return its log text.

    Args:
        overrides: Iterable of ``key=value`` command-line override strings
            appended after the script path.
        timeout: Seconds forwarded to ``subprocess.run``; ``None`` means no
            limit.

    Returns:
        The child's stdout with stderr merged into it, as a single string.

    Raises:
        RuntimeError: If the child exits with a non-zero return code.
    """
    child_env = os.environ.copy()
    # Only supply a default; a value already exported by the user wins.
    if "HYDRA_FULL_ERROR" not in child_env:
        child_env["HYDRA_FULL_ERROR"] = "1"

    # The script path is relative, so the current directory must contain bd3lms/.
    argv = [sys.executable, "-u", "bd3lms/main.py"]
    argv += overrides
    print("\n$", " ".join(argv))

    completed = subprocess.run(
        argv,
        env=child_env,
        timeout=timeout,
        text=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # interleave stderr into the captured stdout
        check=False,
    )
    captured = completed.stdout
    # Echo only the tail so the notebook output stays readable.
    print(captured[-4000:])
    if completed.returncode:
        raise RuntimeError(f"Command failed with return code {completed.returncode}")
    return captured

# Unsigned decimal with optional fraction and exponent, captured as group 1.
_NUMBER = r"([0-9]+(?:\.[0-9]+)?(?:e[+-]?\d+)?)"

# Stricter whole-text fallbacks: `val/ppl: x`, `'val/ppl': x`, `val/ppl | x`.
# `\s*` also matches newlines, so these can match a value on the following line.
_METRIC_PATTERNS = [
    re.compile(r"val/ppl\s*[:=]\s*" + _NUMBER, re.IGNORECASE),
    re.compile(r"'val/ppl'\s*:\s*" + _NUMBER, re.IGNORECASE),
    re.compile(r"val/ppl\s*[│|]\s*" + _NUMBER, re.IGNORECASE),
]

# First number that follows `val/ppl` on the same line.
_LINE_METRIC = re.compile(r"val/ppl.*?" + _NUMBER, re.IGNORECASE)

# ANSI CSI escape sequences (colours, cursor movement). Their numeric
# parameters would otherwise be mistaken for the metric value.
_ANSI_ESCAPE = re.compile(r"\x1b\[[0-?]*[ -/]*[@-~]")


def extract_val_ppl(log_text: str):
    """Return the last ``val/ppl`` value found in a run log, or ``None``.

    ANSI escape sequences are removed first. The log is then scanned from the
    last line backwards, and the first number following ``val/ppl`` on a line
    is returned. If no line yields a value, the stricter ``_METRIC_PATTERNS``
    are applied to the whole text and the last hit is returned.

    Args:
        log_text: Combined stdout/stderr of a run; may be empty or ``None``.

    Returns:
        The parsed perplexity as ``float``, or ``None`` when nothing matches.
    """
    if not log_text:
        return None

    clean = _ANSI_ESCAPE.sub("", log_text)

    # Latest report wins, so walk the lines bottom-up.
    for line in reversed(clean.splitlines()):
        match = _LINE_METRIC.search(line)
        if match:
            return float(match.group(1))

    hits = [value for pattern in _METRIC_PATTERNS for value in pattern.findall(clean)]
    return float(hits[-1]) if hits else None

def _small_loader_overrides(batch_size=4, num_workers=2):
    """Build loader overrides that use one small batch size everywhere.

    Intended for OWT runs with a 1024-token context length.

    Args:
        batch_size: Value applied to all four loader batch-size settings.
        num_workers: Value for ``loader.num_workers``.

    Returns:
        A new list of override strings; gradient accumulation is pinned to 1.
    """
    batch_keys = (
        "global_batch_size",
        "eval_global_batch_size",
        "batch_size",
        "eval_batch_size",
    )
    overrides = [f"loader.{key}={batch_size}" for key in batch_keys]
    overrides.append(f"loader.num_workers={num_workers}")
    overrides.append("trainer.accumulate_grad_batches=1")
    return overrides

In [4]:
def train_run(run_name, algo, block_size=None, from_pretrained=None, max_steps=800, extra_overrides=None):
    """Train one Table 5 model in a fresh run directory and return its checkpoint.

    The training configuration is meant to be identical to the Table 4 setup.
    Any existing ``/content/repro_runs/<run_name>`` directory is deleted first.

    Args:
        run_name: Sub-directory name under ``/content/repro_runs``.
        algo: Value for the ``algo=`` override.
        block_size: If not ``None``, appended as ``block_size=<value>``.
        from_pretrained: If not ``None``, appended as
            ``training.from_pretrained=<value>``.
        max_steps: Value for ``trainer.max_steps``.
        extra_overrides: Optional further overrides, appended last.

    Returns:
        Path of ``<run dir>/checkpoints/last.ckpt`` as a string.

    Raises:
        FileNotFoundError: If that checkpoint does not exist after the run.
        RuntimeError: Propagated from ``run_main`` when the process fails.
    """
    save_dir = Path("/content/repro_runs") / run_name
    # Wipe any previous attempt so the checkpoint check below is meaningful.
    if save_dir.exists():
        shutil.rmtree(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    # Data config (same as Table 4)
    data_cfg = [
        "data=openwebtext-split",
        "data.cache_dir=/content/bd3lms/data",
        "data.streaming=true",
        "data.max_train_samples=1200",
        "data.max_valid_samples=100",
        "data.max_test_samples=100",
    ]
    # Model config (same as Table 4)
    model_cfg = [
        "model=tiny",
        "model.length=1024",
        "model.attn_backend=sdpa",
        f"algo={algo}",
    ]
    # Trainer config (same as Table 4)
    trainer_cfg = [
        "trainer.accelerator=gpu",
        "trainer.devices=1",
        "trainer.num_nodes=1",
        "trainer.precision=16-mixed",
        "trainer.num_sanity_val_steps=0",
        "trainer.log_every_n_steps=10",
        "trainer.val_check_interval=10",
        f"trainer.max_steps={max_steps}",
        f"checkpointing.save_dir=/content/repro_runs/{run_name}",
        "checkpointing.resume_from_ckpt=false",
        "wandb=null",
    ]

    overrides = [
        "mode=train",
        *data_cfg,
        *model_cfg,
        *trainer_cfg,
        *_small_loader_overrides(batch_size=4, num_workers=2),
    ]

    # Optional settings go last, in this fixed order.
    if block_size is not None:
        overrides.append(f"block_size={block_size}")
    if from_pretrained is not None:
        overrides.append(f"training.from_pretrained={from_pretrained}")
    if extra_overrides:
        overrides += extra_overrides

    run_main(overrides)

    ckpt = save_dir / "checkpoints" / "last.ckpt"
    if ckpt.exists():
        return str(ckpt)
    raise FileNotFoundError(f"Expected checkpoint not found: {ckpt}")

In [5]:
def eval_zeroshot(algo, checkpoint_path, dataset_name, block_size=None, extra_overrides=None):
    """Run ``mode=ppl_eval`` for a checkpoint on the given dataset config.

    Zero-shot setting: the model was trained on OWT and is evaluated on a
    different dataset. Config names used in this notebook are ``wikitext2``,
    ``wikitext103``, ``lm1b-gpt2`` and ``lambada``.

    Args:
        algo: Value for the ``algo=`` override; must match training.
        checkpoint_path: Checkpoint passed as ``eval.checkpoint_path``.
        dataset_name: Value for the ``data=`` override.
        block_size: If not ``None``, appended as ``block_size=<value>``.
        extra_overrides: Optional further overrides, appended last.

    Returns:
        The ``val/ppl`` value parsed from the run log, as ``float``.

    Raises:
        ValueError: If no ``val/ppl`` value can be parsed from the log.
        RuntimeError: Propagated from ``run_main`` when the process fails.
    """
    # Only the dataset differs from the in-distribution evaluation.
    data_cfg = [
        f"data={dataset_name}",
        "data.cache_dir=/content/bd3lms/data",
        "data.streaming=true",
        "data.max_test_samples=500",  # same as the Table 4 evaluation
    ]
    # Model settings must match the ones used for training.
    model_cfg = [
        "model=tiny",
        "model.length=1024",
        "model.attn_backend=sdpa",
        f"algo={algo}",
        f"eval.checkpoint_path={checkpoint_path}",
    ]
    trainer_cfg = [
        "trainer.accelerator=gpu",
        "trainer.devices=1",
        "trainer.num_nodes=1",
        "trainer.precision=16-mixed",
        "trainer.num_sanity_val_steps=0",
        "wandb=null",
    ]

    overrides = [
        "mode=ppl_eval",
        *data_cfg,
        *model_cfg,
        *trainer_cfg,
        *_small_loader_overrides(batch_size=4, num_workers=2),
    ]
    if block_size is not None:
        overrides.append(f"block_size={block_size}")
    if extra_overrides:
        overrides += extra_overrides

    ppl = extract_val_ppl(run_main(overrides))
    if ppl is not None:
        return ppl
    raise ValueError("Could not parse val/ppl from output.")

## Training (Ίδιο με Table 4)

In [6]:
results = []
# Maps a model key (e.g. "AR", "BD3-LM_L16") to its trained checkpoint path.
CHECKPOINTS = {}

# ------------------------------------------------------------------------------
# 1) Autoregressive baseline
# ------------------------------------------------------------------------------
print("=" * 60, "Training AR baseline...", "=" * 60, sep="\n")
ar_run = "ar_tiny_owt_len1024"
CHECKPOINTS["AR"] = train_run(run_name=ar_run, algo="ar")
print(f"✓ AR checkpoint: {CHECKPOINTS['AR']}")

Training AR baseline...

$ /usr/bin/python3 -u bd3lms/main.py mode=train data=openwebtext-split data.cache_dir=/content/bd3lms/data data.streaming=true data.max_train_samples=1200 data.max_valid_samples=100 data.max_test_samples=100 model=tiny model.length=1024 model.attn_backend=sdpa algo=ar trainer.accelerator=gpu trainer.devices=1 trainer.num_nodes=1 trainer.precision=16-mixed trainer.num_sanity_val_steps=0 trainer.log_every_n_steps=10 trainer.val_check_interval=10 trainer.max_steps=800 checkpointing.save_dir=/content/repro_runs/ar_tiny_owt_len1024 checkpointing.resume_from_ckpt=false wandb=null loader.global_batch_size=4 loader.eval_global_batch_size=4 loader.batch_size=4 loader.eval_batch_size=4 loader.num_workers=2 trainer.accumulate_grad_batches=1
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been us

In [7]:
# ------------------------------------------------------------------------------
# 2) Diffusion baselines: SEDD and MDLM
# ------------------------------------------------------------------------------
for algo_name in ("sedd", "mdlm"):
    # The checkpoint key is the upper-cased algorithm name.
    display_name = algo_name.upper()
    print("=" * 60, f"Training {display_name} baseline...", "=" * 60, sep="\n")
    run_name = f"{algo_name}_tiny_owt_len1024"
    baseline_ckpt = train_run(
        run_name,
        algo=algo_name,
        extra_overrides=[
            "training.resample=false",
            "algo.var_min=false",
            "algo.clip_search_widths=[]",
        ],
    )
    CHECKPOINTS[display_name] = baseline_ckpt
    print(f"✓ {display_name} checkpoint: {CHECKPOINTS[display_name]}")

Training SEDD baseline...

$ /usr/bin/python3 -u bd3lms/main.py mode=train data=openwebtext-split data.cache_dir=/content/bd3lms/data data.streaming=true data.max_train_samples=1200 data.max_valid_samples=100 data.max_test_samples=100 model=tiny model.length=1024 model.attn_backend=sdpa algo=sedd trainer.accelerator=gpu trainer.devices=1 trainer.num_nodes=1 trainer.precision=16-mixed trainer.num_sanity_val_steps=0 trainer.log_every_n_steps=10 trainer.val_check_interval=10 trainer.max_steps=800 checkpointing.save_dir=/content/repro_runs/sedd_tiny_owt_len1024 checkpointing.resume_from_ckpt=false wandb=null loader.global_batch_size=4 loader.eval_global_batch_size=4 loader.batch_size=4 loader.eval_batch_size=4 loader.num_workers=2 trainer.accumulate_grad_batches=1 training.resample=false algo.var_min=false algo.clip_search_widths=[]
l/nll' was not in top 1

Epoch 2:  29%|██▉       | 100/340 [00:28<01:08,  3.48it/s, v_num=0]
Epoch 2:  29%|██▉       | 100/340 [00:28<01:08,  3.48it/s, v_num=0

In [8]:
# ------------------------------------------------------------------------------
# 3) BD3-LM base training (block_size = 1024 = L)
# ------------------------------------------------------------------------------
print("=" * 60, "Training BD3-LM BASE (block_size=1024)...", "=" * 60, sep="\n")
bd3lm_base_run = "bd3lm_base_owt_len1024"
# This checkpoint is the starting point for the fine-tuning runs.
bd3lm_base_ckpt = train_run(
    run_name=bd3lm_base_run,
    algo="bd3lm",
    block_size=1024,
    extra_overrides=[
        "training.resample=false",
        "algo.var_min=false",
        "algo.clip_search_widths=[]",
    ],
)
print(f"✓ BD3-LM base checkpoint: {bd3lm_base_ckpt}")

Training BD3-LM BASE (block_size=1024)...

$ /usr/bin/python3 -u bd3lms/main.py mode=train data=openwebtext-split data.cache_dir=/content/bd3lms/data data.streaming=true data.max_train_samples=1200 data.max_valid_samples=100 data.max_test_samples=100 model=tiny model.length=1024 model.attn_backend=sdpa algo=bd3lm trainer.accelerator=gpu trainer.devices=1 trainer.num_nodes=1 trainer.precision=16-mixed trainer.num_sanity_val_steps=0 trainer.log_every_n_steps=10 trainer.val_check_interval=10 trainer.max_steps=800 checkpointing.save_dir=/content/repro_runs/bd3lm_base_owt_len1024 checkpointing.resume_from_ckpt=false wandb=null loader.global_batch_size=4 loader.eval_global_batch_size=4 loader.batch_size=4 loader.eval_batch_size=4 loader.num_workers=2 trainer.accumulate_grad_batches=1 block_size=1024 training.resample=false algo.var_min=false algo.clip_search_widths=[]
l/nll' was not in top 1

Epoch 2:  29%|██▉       | 100/340 [00:37<01:30,  2.66it/s, v_num=0]
Epoch 2:  29%|██▉       | 100/34

In [9]:
# ------------------------------------------------------------------------------
# 4) BD3-LM fine-tuning (block_size = 16, 8, 4)
# ------------------------------------------------------------------------------
for Lprime in (16, 8, 4):
    print("=" * 60, f"Fine-tuning BD3-LM (block_size={Lprime})...", "=" * 60, sep="\n")
    finetune_run = f"bd3lm_finetune_owt_Lp{Lprime}"
    # Every block size starts from the same base checkpoint.
    finetuned_ckpt = train_run(
        finetune_run,
        algo="bd3lm",
        block_size=Lprime,
        from_pretrained=bd3lm_base_ckpt,
        extra_overrides=[
            "training.resample=true",
            "algo.var_min=false",
            "algo.clip_search_widths=[]",
        ],
    )
    CHECKPOINTS[f"BD3-LM_L{Lprime}"] = finetuned_ckpt
    print(f"✓ BD3-LM (L'={Lprime}) checkpoint: {CHECKPOINTS[f'BD3-LM_L{Lprime}']}")

print("\n" + "=" * 60, "ALL TRAINING COMPLETE!", "=" * 60, sep="\n")

Fine-tuning BD3-LM (block_size=16)...

$ /usr/bin/python3 -u bd3lms/main.py mode=train data=openwebtext-split data.cache_dir=/content/bd3lms/data data.streaming=true data.max_train_samples=1200 data.max_valid_samples=100 data.max_test_samples=100 model=tiny model.length=1024 model.attn_backend=sdpa algo=bd3lm trainer.accelerator=gpu trainer.devices=1 trainer.num_nodes=1 trainer.precision=16-mixed trainer.num_sanity_val_steps=0 trainer.log_every_n_steps=10 trainer.val_check_interval=10 trainer.max_steps=800 checkpointing.save_dir=/content/repro_runs/bd3lm_finetune_owt_Lp16 checkpointing.resume_from_ckpt=false wandb=null loader.global_batch_size=4 loader.eval_global_batch_size=4 loader.batch_size=4 loader.eval_batch_size=4 loader.num_workers=2 trainer.accumulate_grad_batches=1 block_size=16 training.from_pretrained=/content/repro_runs/bd3lm_base_owt_len1024/checkpoints/last.ckpt training.resample=true algo.var_min=false algo.clip_search_widths=[]
l/nll' was not in top 1

Epoch 2:  65%|██

In [14]:
def eval_owt(algo, checkpoint_path, block_size=None, extra_overrides=None):
    """Evaluate a checkpoint on OWT (sanity check; ~2000 PPL expected, as in Table 4).

    This is the same ``mode=ppl_eval`` run as ``eval_zeroshot`` with the
    dataset fixed to ``openwebtext-split``, so it delegates to that function
    rather than repeating the override list.

    Args:
        algo: Value for the ``algo=`` override; must match training.
        checkpoint_path: Checkpoint passed as ``eval.checkpoint_path``.
        block_size: If not ``None``, appended as ``block_size=<value>``.
        extra_overrides: Optional further overrides, appended last.

    Returns:
        The ``val/ppl`` value parsed from the run log, as ``float``.

    Raises:
        ValueError: If no ``val/ppl`` value can be parsed from the log.
    """
    return eval_zeroshot(
        algo,
        checkpoint_path,
        "openwebtext-split",
        block_size=block_size,
        extra_overrides=extra_overrides,
    )

In [15]:
# ------------------------------------------------------------------------------
# OWT sanity check - should match the Table 4 results (~2000 PPL)
# ------------------------------------------------------------------------------
owt_results = []

# Reference perplexities; the only place these numbers are written down.
EXPECTED_OWT_PPL = {
    "AR": 2035, "SEDD": 2120, "MDLM": 2100,
    "BD3-LM L'=16": 1940, "BD3-LM L'=8": 1940, "BD3-LM L'=4": 1935,
}

print("=" * 60)
print("OWT SANITY CHECK")
print("=" * 60)

# One row per model: (display name, algo, CHECKPOINTS key, block_size, extra overrides).
# AR is evaluated without extra overrides; the other models disable algo.var_min.
owt_eval_specs = [("AR", "ar", "AR", None, None)]
owt_eval_specs += [
    (name, algo, name, None, ["algo.var_min=false"])
    for name, algo in [("SEDD", "sedd"), ("MDLM", "mdlm")]
]
owt_eval_specs += [
    (f"BD3-LM L'={Lprime}", "bd3lm", f"BD3-LM_L{Lprime}", Lprime, ["algo.var_min=false"])
    for Lprime in [16, 8, 4]
]

for name, algo, ckpt_key, block_size, extra in owt_eval_specs:
    expected = EXPECTED_OWT_PPL[name]
    ppl = eval_owt(algo, CHECKPOINTS[ckpt_key], block_size=block_size, extra_overrides=extra)
    owt_results.append({"Model": name, "OWT_PPL": ppl, "Expected": expected})
    print(f"✓ {name}: {ppl:.1f} (expected ~{expected})")

# Summary
print("\n" + "=" * 60)
owt_df = pd.DataFrame(owt_results)
print(owt_df.to_string(index=False))


OWT SANITY CHECK

$ /usr/bin/python3 -u bd3lms/main.py mode=ppl_eval data=openwebtext-split data.cache_dir=/content/bd3lms/data data.streaming=true data.max_test_samples=500 model=tiny model.length=1024 model.attn_backend=sdpa algo=ar eval.checkpoint_path=/content/repro_runs/ar_tiny_owt_len1024/checkpoints/last.ckpt trainer.accelerator=gpu trainer.devices=1 trainer.num_nodes=1 trainer.precision=16-mixed trainer.num_sanity_val_steps=0 wandb=null loader.global_batch_size=4 loader.eval_global_batch_size=4 loader.batch_size=4 loader.eval_batch_size=4 loader.num_workers=2 trainer.accumulate_grad_batches=1
| 26740/27619 [06:37<00:13, 67.23it/s]
Validation DataLoader 0:  97%|█████████▋| 26760/27619 [06:38<00:12, 67.23it/s]
Validation DataLoader 0:  97%|█████████▋| 26780/27619 [06:38<00:12, 67.23it/s]
Validation DataLoader 0:  97%|█████████▋| 26800/27619 [06:38<00:12, 67.23it/s]
Validation DataLoader 0:  97%|█████████▋| 26820/27619 [06:38<00:11, 67.23it/s]
Validation DataLoader 0:  97%|███████

## Zero-Shot Evaluation

In [16]:
# ------------------------------------------------------------------------------
# Zero-shot datasets: (display name, data config name)
# ------------------------------------------------------------------------------
ZEROSHOT_DATASETS = list(
    {
        "Wikitext2": "wikitext2",      # fixed: the config name is not "wikitext"
        "Wikitext103": "wikitext103",  # additional option
        "LM1B": "lm1b-gpt2",           # fixed: the config name is not "lm1b"
        "Lambada": "lambada",
    }.items()
)

# Models to evaluate: (display name, algo, CHECKPOINTS key, block_size)
EVAL_MODELS = [
    (baseline, baseline.lower(), baseline, None)
    for baseline in ("AR", "SEDD", "MDLM")
] + [
    (f"BD3-LM L'={size}", "bd3lm", f"BD3-LM_L{size}", size)
    for size in (16, 8, 4)
]

In [17]:
# Evaluate every trained model on every zero-shot dataset
zeroshot_results = []

for dataset_name, data_cfg in ZEROSHOT_DATASETS:
    print("\n" + "=" * 60)
    print(f"ZERO-SHOT: {dataset_name}")
    print("=" * 60)

    for model_name, algo, ckpt_key, block_size in EVAL_MODELS:
        ckpt = CHECKPOINTS.get(ckpt_key)
        if not ckpt:
            print(f"  ⚠ Skip {model_name}: no checkpoint")
            continue

        # Every model except AR gets the var_min override.
        extra = [] if algo == "ar" else ["algo.var_min=false"]

        print(f"\n→ {model_name}...")
        # Best effort: a failed run is recorded as a missing value so the
        # remaining model/dataset pairs still get evaluated.
        try:
            ppl = eval_zeroshot(algo, ckpt, data_cfg, block_size, extra)
            print(f"  ✓ PPL = {ppl}")
        except Exception as e:
            ppl = None
            print(f"  ✗ Error: {str(e)[:100]}")

        zeroshot_results.append(dict(Model=model_name, Dataset=dataset_name, PPL=ppl))

print("\n" + "=" * 60, "ZERO-SHOT EVALUATION COMPLETE!", "=" * 60, sep="\n")


ZERO-SHOT: Wikitext2

→ AR...

$ /usr/bin/python3 -u bd3lms/main.py mode=ppl_eval data=wikitext2 data.cache_dir=/content/bd3lms/data data.streaming=true data.max_test_samples=500 model=tiny model.length=1024 model.attn_backend=sdpa algo=ar eval.checkpoint_path=/content/repro_runs/ar_tiny_owt_len1024/checkpoints/last.ckpt trainer.accelerator=gpu trainer.devices=1 trainer.num_nodes=1 trainer.precision=16-mixed trainer.num_sanity_val_steps=0 wandb=null loader.global_batch_size=4 loader.eval_global_batch_size=4 loader.batch_size=4 loader.eval_batch_size=4 loader.num_workers=2 trainer.accumulate_grad_batches=1
.938160: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-01-25 13:53:14.958480: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register c

In [18]:
# Lay the results out like Table 5: one row per model, one column per dataset
df = pd.DataFrame(zeroshot_results)
pivot = df.pivot(index="Model", columns="Dataset", values="PPL")

# Put the rows in the paper's order, dropping models that were never evaluated
model_order = ["AR", "SEDD", "MDLM", "BD3-LM L'=16", "BD3-LM L'=8", "BD3-LM L'=4"]
rows_present = [model for model in model_order if model in pivot.index]
pivot = pivot.reindex(rows_present)

print("\n" + "=" * 70, "TABLE 5: Zero-Shot Validation Perplexities", "=" * 70, sep="\n")
print(pivot.to_string())


TABLE 5: Zero-Shot Validation Perplexities
Dataset       LM1B      Lambada  Wikitext103    Wikitext2
Model                                                    
AR             NaN  1550.629028  2875.344482  2875.344482
SEDD           NaN  1562.114380  3335.748779  3335.748779
MDLM           NaN  1555.910767  3282.915283  3282.915283
BD3-LM L'=16   NaN  1414.840698  3124.815674  3124.815674
BD3-LM L'=8    NaN  1436.505859  3177.801270  3177.801270
BD3-LM L'=4    NaN  1437.541992  3143.730713  3143.730713


In [19]:
# Paper reference values for comparison with the reproduced table
print("\n" + "=" * 70)
print("PAPER VALUES (Table 5):")
print("=" * 70)
paper_df = pd.DataFrame({
    "Model": ["AR", "SEDD", "MDLM", "BD3-LM L'=4"],
    "PTB": [81.07, 96.33, 90.96, 96.81],
    "Wikitext": [25.32, 35.98, 33.22, 31.31],
    "LM1B": [51.14, 68.14, 64.94, 60.88],
    "Lambada": [52.13, 48.93, 48.29, 50.03],
}).set_index("Model")
print(paper_df.to_string())

# The ordering BD3-LM < MDLM < SEDD does not hold on every dataset in the table
# above, so derive the datasets where it does instead of asserting it globally.
ordering_holds_on = [
    dataset
    for dataset in paper_df.columns
    if paper_df.at["BD3-LM L'=4", dataset]
    < paper_df.at["MDLM", dataset]
    < paper_df.at["SEDD", dataset]
]
print(
    "\nExpected: BD3-LM < MDLM < SEDD for diffusion models on "
    + ", ".join(ordering_holds_on)
)


PAPER VALUES (Table 5):
               PTB  Wikitext   LM1B  Lambada
Model                                       
AR           81.07     25.32  51.14    52.13
SEDD         96.33     35.98  68.14    48.93
MDLM         90.96     33.22  64.94    48.29
BD3-LM L'=4  96.81     31.31  60.88    50.03

Expected: BD3-LM < MDLM < SEDD for diffusion models


In [21]:
# List the keys actually present before indexing CHECKPOINTS by name
print("Available CHECKPOINTS keys:", list(CHECKPOINTS))

Available CHECKPOINTS keys: ['AR', 'SEDD', 'MDLM', 'BD3-LM_L16', 'BD3-LM_L8', 'BD3-LM_L4']


In [26]:
# Self-contained LM1B evaluation helper
DATA_DIR = "/content/bd3lms/data"


# lm1b-gpt2 with streaming turned off, still capped by max_test_samples
def eval_zeroshot_lm1b_gpt2(algo, checkpoint_path, block_size=None):
    """Evaluate a checkpoint on ``lm1b-gpt2`` without streaming, 500 test samples.

    Args:
        algo: Value for the ``algo=`` override.
        checkpoint_path: Checkpoint passed as ``eval.checkpoint_path``.
        block_size: If truthy, appended as ``block_size=<value>``.

    Returns:
        The parsed ``val/ppl`` as ``float``, or ``None`` when the run fails,
        exceeds the 1800-second timeout, or its log has no parsable value.
    """
    overrides = [
        "mode=ppl_eval",
        "data=lm1b-gpt2",
        f"data.cache_dir={DATA_DIR}",
        "data.streaming=false",
        "data.max_test_samples=500",  # keep the sample cap
        "model=tiny",
        "model.length=1024",
        "model.attn_backend=sdpa",
        f"algo={algo}",
        f"eval.checkpoint_path={checkpoint_path}",
        "trainer.accelerator=gpu",
        "trainer.devices=1",
        "wandb=null",
        "loader.eval_batch_size=4",
    ]

    # All three non-AR algorithms get the var_min override.
    if algo in ("sedd", "mdlm", "bd3lm"):
        overrides += ["algo.var_min=false"]
    if block_size:
        overrides += [f"block_size={block_size}"]

    # Best effort: report the error and signal failure with None.
    try:
        log_text = run_main(overrides, timeout=1800)
        return extract_val_ppl(log_text)
    except Exception as e:
        print(f"  Error: {e}")
        return None

# Try the AR model alone before launching the remaining evaluations
print("Testing LM1B-GPT2 (no streaming, 500 samples)...")
ppl = eval_zeroshot_lm1b_gpt2(algo="ar", checkpoint_path=CHECKPOINTS["AR"])
print(f"AR PPL: {ppl}")

Testing LM1B-GPT2 (no streaming, 500 samples)...

$ /usr/bin/python3 -u bd3lms/main.py mode=ppl_eval data=lm1b-gpt2 data.cache_dir=/content/bd3lms/data data.streaming=false data.max_test_samples=500 model=tiny model.length=1024 model.attn_backend=sdpa algo=ar eval.checkpoint_path=/content/repro_runs/ar_tiny_owt_len1024/checkpoints/last.ckpt trainer.accelerator=gpu trainer.devices=1 wandb=null loader.eval_batch_size=4
ore details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------

2026-01-25 15:02:37.720109: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. Y

In [27]:
# ------------------------------------------------------------------------------
# Full LM1B evaluation (no streaming)
# ------------------------------------------------------------------------------

print("=" * 60, "ZERO-SHOT: LM1B (no streaming)", "=" * 60, sep="\n")

# AR is not re-run; its value was entered by hand.
# NOTE(review): presumably copied from the single-model test run - confirm.
lm1b_results = {"AR": 2388.21}

# Remaining models: (display name, algo, CHECKPOINTS key, block_size)
remaining_models = [
    ("SEDD", "sedd", "SEDD", None),
    ("MDLM", "mdlm", "MDLM", None),
]
remaining_models += [
    (f"BD3-LM L'={size}", "bd3lm", f"BD3-LM_L{size}", size)
    for size in (16, 8, 4)
]

for model_name, algo, ckpt_key, block_size in remaining_models:
    print(f"\n→ {model_name}...")
    ppl = eval_zeroshot_lm1b_gpt2(algo, CHECKPOINTS[ckpt_key], block_size=block_size)
    lm1b_results[model_name] = ppl
    # The helper returns None on failure.
    print(f"  ✓ PPL = {ppl}" if ppl else "  ✗ Failed")

print("\n" + "=" * 60, "LM1B Results Summary:", "=" * 60, sep="\n")
for model, ppl in lm1b_results.items():
    print(f"{model:>15}: {ppl}")

ZERO-SHOT: LM1B (no streaming)

→ SEDD...

$ /usr/bin/python3 -u bd3lms/main.py mode=ppl_eval data=lm1b-gpt2 data.cache_dir=/content/bd3lms/data data.streaming=false data.max_test_samples=500 model=tiny model.length=1024 model.attn_backend=sdpa algo=sedd eval.checkpoint_path=/content/repro_runs/sedd_tiny_owt_len1024/checkpoints/last.ckpt trainer.accelerator=gpu trainer.devices=1 wandb=null loader.eval_batch_size=4 algo.var_min=false
ore details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------

2026-01-25 15:11:38.748023: I tensorflow/core/util/port.cc:153] oneDNN custom oper