# Table 8 Reproduction — LITE Budget

Αναπαραγωγή Table 8 από Arriola et al. (ICLR 2025).

**Budget:** Base 5K steps / 10K samples, Finetune 3K steps / 5K samples

**Εκτ. χρόνος:** ~1 ώρα σε Colab T4

**Πειράματα:** 5 noise schedules × 2 block sizes (L'=4, L'=16)

**Workflow:** Base training → Fine-tune per schedule → Evaluate (Linear noise schedule)

In [26]:
# Clone project repo
import os
if not os.path.exists('/content/bd3lms'):
    !cd /content && git clone https://github.com/ntua-el21050/bd3lms.git
else:
    print("Repo already exists, skipping clone")

!mkdir -p /content/repro_runs_v2

Repo already exists, skipping clone


In [27]:
# Install dependencies
%%capture
!pip install -r bd3lms/requirements.txt


# Step 1: Patch diffusion.py (Print variance)

In [28]:

import os
import subprocess
import shutil
import sys
import re
import pandas as pd

# Patch diffusion.py so that every logged validation variance is also printed to
# stdout as a "VARIANCE: ..." line, which extract_valid_var() later parses.
diffusion_file = '/content/bd3lms/diffusion.py'
with open(diffusion_file, 'r') as f:
    content = f.read()

if 'VARIANCE:' not in content:
    print("🛠️ Patching diffusion.py...")
    # Exact text of the logging call to anchor on (whitespace is significant).
    old_log = """      self.log(f'valid_var_{round(eps_min, 2)} - {round(eps_max, 2)}',
                all_vars / len(var),
                on_epoch=True,
                on_step=False,
                sync_dist=True)"""

    # Same call, preceded by a print of the value.
    # NOTE(review): assumes `all_vars / len(var)` supports .item() — confirm in diffusion.py.
    new_log = """      _var_val = (all_vars / len(var)).item()
      print(f'VARIANCE: valid_var_{round(eps_min, 2)} - {round(eps_max, 2)} = {_var_val:.4f}')
      self.log(f'valid_var_{round(eps_min, 2)} - {round(eps_max, 2)}',
                all_vars / len(var),
                on_epoch=True,
                on_step=False,
                sync_dist=True)"""

    if old_log in content:
        content = content.replace(old_log, new_log)
        with open(diffusion_file, 'w') as f:
            f.write(content)
        # Clear pycache
        shutil.rmtree('/content/bd3lms/__pycache__', ignore_errors=True)
        print("Patch applied.")
    else:
        # str.replace() is a silent no-op when the anchor is missing; report it
        # instead of claiming success, because variance extraction depends on it.
        print("WARNING: expected self.log(...) block not found in diffusion.py; "
              "patch NOT applied, so no 'VARIANCE:' lines will be printed.")
else:
    print("diffusion.py already patched.")

diffusion.py already patched.


In [29]:
# Utils

# LITE budget: optimiser steps and training-sample caps for each phase.
BASE_MAX_STEPS, BASE_MAX_SAMPLES = 5_000, 10_000          # base model training
FINETUNE_MAX_STEPS, FINETUNE_MAX_SAMPLES = 3_000, 5_000   # per-schedule fine-tuning

def run_main(overrides):
    """Run ``main.py`` in the cloned repo with the given Hydra overrides.

    The child process runs from /content/bd3lms with HYDRA_FULL_ERROR=1 and
    stderr merged into stdout. Output is collected in full and returned only
    after the process exits.

    Args:
        overrides: list of command-line override strings appended after main.py.

    Returns:
        The combined stdout/stderr text of the process.

    Raises:
        RuntimeError: if the process exits with a non-zero return code. The
            last 2000 characters of its output are printed first.
    """
    child_env = {**os.environ, "HYDRA_FULL_ERROR": "1"}
    command = [sys.executable, '-u', 'main.py', *overrides]

    # Only echo the last three overrides to keep the notebook output short.
    shown_tail = ' '.join(overrides[-3:])
    print(f"\n$ python main.py ... {shown_tail}")

    completed = subprocess.run(
        command,
        cwd='/content/bd3lms',
        env=child_env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    )
    if completed.returncode == 0:
        return completed.stdout

    print(f"FAILED:\n{completed.stdout[-2000:]}")
    raise RuntimeError(f"Command failed with code {completed.returncode}")

def train_run(run_name, algo, block_size, from_pretrained=None, extra_overrides=None, is_base=False):
    """Train (or reuse) one run and return its checkpoint path and training log.

    The run is stored under /content/repro_runs_final/<run_name>. If
    checkpoints/last.ckpt already exists there, training is skipped.

    The training output is saved to <save_dir>/train_log.txt after a successful
    run and is reloaded on a cache hit, so that "VARIANCE:" lines can still be
    parsed when the notebook is re-executed. If no saved log exists, the
    returned log is an empty string.

    Args:
        run_name: directory name of the run below the output root.
        algo: value for the ``algo=`` override.
        block_size: value for the ``block_size=`` override.
        from_pretrained: optional checkpoint path passed as
            ``training.from_pretrained=``.
        extra_overrides: optional list of additional override strings, appended
            last.
        is_base: selects the BASE_* budget when true, else the FINETUNE_* budget.

    Returns:
        Tuple ``(ckpt_path, log_text)``.

    Raises:
        RuntimeError: propagated from run_main() when training fails.
    """
    save_dir = f"/content/repro_runs_final/{run_name}"
    ckpt_path = f"{save_dir}/checkpoints/last.ckpt"
    log_path = f"{save_dir}/train_log.txt"

    if os.path.exists(ckpt_path):
        print(f" Checkpoint exists: {run_name}")
        cached_log = ""
        if os.path.exists(log_path):
            # Reuse the log of the run that produced this checkpoint.
            with open(log_path, "r", encoding="utf-8", errors="replace") as f:
                cached_log = f.read()
        return ckpt_path, cached_log

    # Set budget
    steps = BASE_MAX_STEPS if is_base else FINETUNE_MAX_STEPS
    samples = BASE_MAX_SAMPLES if is_base else FINETUNE_MAX_SAMPLES

    overrides = [
        "mode=train",
        "data=lm1b-wrap", "data.cache_dir=/content/bd3lms/data",
        "data.streaming=true", f"data.max_train_samples={samples}",
        "model=tiny", "model.length=128", "model.attn_backend=sdpa",
        f"algo={algo}",
        "trainer.accelerator=gpu", "trainer.devices=1", "trainer.num_nodes=1",
        "trainer.precision=16-mixed", "trainer.num_sanity_val_steps=0",
        "trainer.log_every_n_steps=20", "trainer.val_check_interval=100",
        f"trainer.max_steps={steps}",
        "data.max_valid_samples=800", "data.max_test_samples=100",
        f"checkpointing.save_dir={save_dir}",
        "checkpointing.resume_from_ckpt=false",
        "wandb=null",
        "loader.global_batch_size=8", "loader.eval_global_batch_size=8",
        "loader.batch_size=8", "loader.eval_batch_size=8",
        "loader.num_workers=2", "trainer.accumulate_grad_batches=1",
        f"block_size={block_size}",
    ]

    if from_pretrained:
        overrides.append(f"training.from_pretrained={from_pretrained}")

    if extra_overrides:
        overrides.extend(extra_overrides)

    log = run_main(overrides)

    # Persist the log next to the checkpoint so a later cache hit can return it.
    os.makedirs(save_dir, exist_ok=True)
    with open(log_path, "w", encoding="utf-8", errors="replace") as f:
        f.write(log)

    if not os.path.exists(ckpt_path):
        print(f"WARNING: training finished but {ckpt_path} was not found")

    return ckpt_path, log

def eval_run(checkpoint_path, algo, block_size):
    """Run ``mode=ppl_eval`` on a checkpoint and parse the reported perplexity.

    Evaluation is launched through run_main() with ``noise.type=loglinear`` and
    ``algo.var_min=false``. The output is scanned from the last line backwards
    for a line mentioning "val/ppl" (case-insensitive). The first number that
    follows it on that line is returned.

    Args:
        checkpoint_path: value for ``eval.checkpoint_path=``.
        algo: value for ``algo=``.
        block_size: value for ``block_size=``.

    Returns:
        The perplexity as a float, or None when no matching line is found.
    """
    eval_overrides = [
        "mode=ppl_eval",
        "data=lm1b-wrap",
        "data.cache_dir=/content/bd3lms/data",
        "data.streaming=true",
        "data.max_test_samples=1000",
        "model=tiny",
        "model.length=128",
        "model.attn_backend=sdpa",
        f"algo={algo}",
        f"eval.checkpoint_path={checkpoint_path}",
        "trainer.accelerator=gpu",
        "trainer.devices=1",
        "trainer.precision=16-mixed",
        "trainer.num_sanity_val_steps=0",
        "wandb=null",
        "noise.type=loglinear",
        "algo.var_min=false",
        "data.max_train_samples=1000",
        "data.max_valid_samples=100",
        "loader.global_batch_size=8",
        "loader.batch_size=8",
        f"block_size={block_size}",
    ]
    ppl_pattern = re.compile(r"val/ppl.*?([0-9]+(?:\.[0-9]+)?)", re.IGNORECASE)

    output_lines = run_main(eval_overrides).splitlines()
    # Walk backwards so the most recent report wins.
    for text in output_lines[::-1]:
        if "val/ppl" not in text.lower():
            continue
        found = ppl_pattern.search(text)
        if found:
            return float(found.group(1))
    return None

def extract_valid_var(log_text, key):
    """Return the last variance value printed for ``key`` in a training log.

    Only lines containing the substring ``"VARIANCE: <key>"`` are considered.
    On each such line the first ``= <number>`` occurrence is parsed. Lines
    where no number follows an equals sign are ignored.

    Args:
        log_text: full text of the training output.
        key: metric name as printed after "VARIANCE: " (matched as a substring).

    Returns:
        The value from the last parsable matching line as a float, or None if
        there is none.
    """
    marker = f"VARIANCE: {key}"
    number_after_equals = re.compile(r"=\s*([0-9]+\.?[0-9]*)")

    hits = (
        number_after_equals.search(row)
        for row in log_text.splitlines()
        if marker in row
    )
    values = [float(hit.group(1)) for hit in hits if hit]
    return values[-1] if values else None

# Step 2. Run base training (if missing)

In [30]:
# Train (or reuse) the base model that every fine-tuning run starts from.
banner_rule = '=' * 60
print(f"\n{banner_rule}\nSTEP A: BASE MODEL\n{banner_rule}")

base_run_name = "bd3lm_base_len128_vfinal"
base_run_settings = dict(
    run_name=base_run_name,
    algo="bd3lm",
    block_size=128,
    from_pretrained=None,
    extra_overrides=["training.resample=false", "algo.var_min=false"],
    is_base=True,
)
# The training log of the base run is not used, so it is discarded.
bd3lm_base_ckpt, _ = train_run(**base_run_settings)
print(f"Base Checkpoint: {bd3lm_base_ckpt}")


STEP A: BASE MODEL

$ python main.py ... block_size=128 training.resample=false algo.var_min=false
Base Checkpoint: /content/repro_runs_final/bd3lm_base_len128_vfinal/checkpoints/last.ckpt


# Step 3: Fine-tuning and evaluation

In [31]:
# Reference perplexities used for the "Paper PPL" column, keyed by block size L'
# and then by schedule name.
paper_data = {
    4: {
        "Clipped U[0.45,0.95]": 29.21,
        "Clipped U[0.3,0.8]": 29.38,
        "Linear U[0,1]": 30.18,
        "Logarithmic": 30.36,
        "Square root": 31.41,
    },
    16: {
        "Clipped U[0.45,0.95]": 31.42,
        "Clipped U[0.3,0.8]": 31.12,
        "Linear U[0,1]": 31.72,
        "Square": 31.43,
        "Cosine": 31.41,
    },
}


def _clipped_schedule(eps_min, eps_max):
    """Build the (name, overrides, variance key) entry for a clipped uniform schedule."""
    return (
        f"Clipped U[{eps_min},{eps_max}]",
        [f"training.sampling_eps_min={eps_min}", f"training.sampling_eps_max={eps_max}"],
        f"valid_var_{eps_min} - {eps_max}",
    )


def _full_range_schedule(name, noise_type=None):
    """Build an entry sampling over [0.001, 1.0], optionally with a noise.type override."""
    overrides = ["training.sampling_eps_min=0.001", "training.sampling_eps_max=1.0"]
    if noise_type is not None:
        overrides.insert(0, f"noise.type={noise_type}")
    # extract_valid_var() matches this key as a substring of the printed metric name.
    return (name, overrides, "valid_var_0.0 - 1")


def _shared_schedules():
    """Entries common to both block sizes (fresh lists on every call)."""
    return [
        _clipped_schedule("0.45", "0.95"),
        _clipped_schedule("0.3", "0.8"),
        _full_range_schedule("Linear U[0,1]"),
    ]


# Per block size: list of (schedule name, extra training overrides, variance key).
schedules = {
    4: _shared_schedules() + [
        _full_range_schedule("Logarithmic", "log"),
        _full_range_schedule("Square root", "square_root"),
    ],
    16: _shared_schedules() + [
        _full_range_schedule("Square", "square"),
        _full_range_schedule("Cosine", "cosine"),
    ],
}

# Fine-tune the base checkpoint once per (block size, schedule) pair, then
# evaluate each fine-tuned checkpoint and collect PPL and variance per block size.
all_results = {}

for Lp in [4, 16]:
    print(f"\n{'='*70}\nTABLE 8 — Block Size L'={Lp}\n{'='*70}")
    results = []

    for name, sched_ov, var_key in schedules[Lp]:
        print(f"\n--- Processing: {name} ---")
        # Turn the schedule name into a directory-friendly run name.
        safe_name = name.replace("[","").replace("]","").replace(",","_").replace(" ","_")
        run_name = f"bd3lm_fine_{safe_name}_Lp{Lp}"

        # A. TRAIN (With Variance Monitoring)
        ckpt, train_log = train_run(
            run_name,
            algo="bd3lm",
            block_size=Lp,
            from_pretrained=bd3lm_base_ckpt,
            extra_overrides=[
                "training.resample=true",
                "algo.var_min=true",
                "algo.clip_search_widths=[0.5]",
                "algo.fix_clipping=true",
            ] + sched_ov,
            is_base=False
        )

        # B. EXTRACT VARIANCE & EVAL PPL
        var_nelbo = extract_valid_var(train_log, key=var_key)
        ppl = eval_run(ckpt, algo="bd3lm", block_size=Lp)

        # Both helpers may return None. Format defensively so that a missing
        # value does not raise, and so that a legitimate 0.0 is not shown as N/A.
        ppl_text = f"{ppl:.2f}" if ppl is not None else "N/A"
        var_text = f"{var_nelbo}" if var_nelbo is not None else "N/A"
        print(f"  ✓ PPL={ppl_text}, Var={var_text}")

        results.append({"Schedule": name, "PPL": ppl, "Var NELBO": var_nelbo})
        # Publish after every schedule so partial results survive an interruption.
        all_results[Lp] = results


TABLE 8 ‚Äî Block Size L'=4

--- Processing: Clipped U[0.45,0.95] ---

$ python main.py ... algo.fix_clipping=true training.sampling_eps_min=0.45 training.sampling_eps_max=0.95

$ python main.py ... loader.global_batch_size=8 loader.batch_size=8 block_size=4
  ‚úì PPL=1199.43, Var=15.2313

--- Processing: Clipped U[0.3,0.8] ---

$ python main.py ... algo.fix_clipping=true training.sampling_eps_min=0.3 training.sampling_eps_max=0.8

$ python main.py ... loader.global_batch_size=8 loader.batch_size=8 block_size=4
  ‚úì PPL=1101.03, Var=27.6118

--- Processing: Linear U[0,1] ---

$ python main.py ... algo.fix_clipping=true training.sampling_eps_min=0.001 training.sampling_eps_max=1.0

$ python main.py ... loader.global_batch_size=8 loader.batch_size=8 block_size=4
  ‚úì PPL=751.15, Var=260.1035

--- Processing: Logarithmic ---

$ python main.py ... noise.type=log training.sampling_eps_min=0.001 training.sampling_eps_max=1.0

$ python main.py ... loader.global_batch_size=8 loader.batch_si

# Step 4: Log Results

In [32]:
# Render one styled table per block size, sorted by measured perplexity and
# annotated with the paper's reference value for the same schedule.
column_formats = {"PPL": "{:.2f}", "Var NELBO": "{:.4f}", "Paper PPL": "{:.2f}"}

for Lp in all_results:
    res_list = all_results[Lp]
    print(f"\n\n=== FINAL RESULTS TABLE 8 (L'={Lp}) ===")

    df = pd.DataFrame(res_list).sort_values("PPL")
    # Add Paper Reference
    df['Paper PPL'] = df['Schedule'].map(paper_data[Lp])

    styled_table = df.style.format(column_formats)
    styled_table = styled_table.background_gradient(subset=["PPL"])
    display(styled_table)



=== FINAL RESULTS TABLE 8 (L'=4) ===


ModuleNotFoundError: No module named 'numpy.rec'