In [1]:
# Clone the ntua-el21050 fork of bd3lms into /content; later cells expect the repo at /content/bd3lms.
# NOTE: git refuses to clone into an existing non-empty directory, so re-running this cell fails once the repo is present.
!cd /content && git clone https://github.com/ntua-el21050/bd3lms.git

Cloning into 'bd3lms'...
remote: Enumerating objects: 759, done.[K
remote: Counting objects: 100% (222/222), done.[K
remote: Compressing objects: 100% (51/51), done.[K
remote: Total 759 (delta 196), reused 174 (delta 171), pack-reused 537 (from 2)[K
Receiving objects: 100% (759/759), 1.75 MiB | 5.62 MiB/s, done.
Resolving deltas: 100% (489/489), done.


In [2]:
# List the checkout to confirm main.py, configs/ and the other sources are present.
!ls -l /content/bd3lms

total 1740
-rw-r--r-- 1 root root 862037 Jan  6 21:41 2503.09573v3.pdf
drwxr-xr-x 9 root root   4096 Jan  6 21:41 configs
-rw-r--r-- 1 root root  33535 Jan  6 21:41 dataloader.py
-rw-r--r-- 1 root root  44163 Jan  6 21:41 diffusion.py
-rw-r--r-- 1 root root 225205 Jan  6 21:41 graphical_abstract.png
-rw-r--r-- 1 root root  11357 Jan  6 21:41 LICENSE
-rw-r--r-- 1 root root   7873 Jan  6 21:41 main.py
-rw-r--r-- 1 root root   8405 Jan  6 21:41 metrics.py
drwxr-xr-x 3 root root   4096 Jan  6 21:41 models
-rw-r--r-- 1 root root   2538 Jan  6 21:41 noise_schedule.py
-rw-r--r-- 1 root root   1449 Jan  6 21:41 push_to_hf.py
-rw-r--r-- 1 root root  10070 Jan  6 21:41 README.md
-rw-r--r-- 1 root root    363 Jan  6 21:41 requirements.txt
drwxr-xr-x 7 root root   4096 Jan  6 21:41 scripts
drwxr-xr-x 4 root root   4096 Jan  6 21:41 ssd-lm
-rw-r--r-- 1 root root 525005 Jan  6 21:41 table2_final.ipynb
-rw-r--r-- 1 root root   7162 Jan  6 21:41 utils.py


In [3]:
# Install pinned versions of the libraries the bd3lms code is run with.
# TODO(review): wandb is the only unpinned package; pin it so Restart-and-Run-All stays reproducible.
!pip install -q \
    torchmetrics==1.6.2 \
    datasets==3.3.2 \
    einops==0.8.1 \
    fsspec==2024.2.0 \
    hydra-core==1.3.2 \
    lightning==2.5.0.post0 \
    omegaconf==2.3.0 \
    packaging==23.2 \
    pandas==2.2.1 \
    rich==13.7.1 \
    scikit-learn==1.5.1 \
    timm==0.9.16 \
    transformers==4.49.0 \
    matplotlib==3.10.0 \
    wandb

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.4/40.4 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.0/44.0 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m931.6/931.6 kB[0m [31m28.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m485.4/485.4 kB[0m [31m31.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m170.9/170.9 kB[0m [31m14.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.5/154.5 kB[0m [31m11.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m815.2/815.2 kB[0m [31m39.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# Helpers to run Hydra commands and parse val/ppl from output
import os
import re
import shutil
import subprocess
import sys
from pathlib import Path

def run_main(overrides, timeout=None, script="bd3lms/main.py", tail_chars=4000):
    """Run ``python -u <script> <overrides...>`` and return its combined stdout/stderr text.

    Args:
        overrides: Hydra command-line overrides, passed as separate argv entries
            (no shell is involved, so no quoting is needed).
        timeout: Seconds before ``subprocess.TimeoutExpired`` is raised; None waits forever.
        script: Entry-point script. The default is relative to the current working
            directory (the notebook clones the repo into /content).
        tail_chars: How many trailing characters of the log to echo for quick
            visibility; 0 or None echoes the whole log.

    Returns:
        The full captured output (not just the echoed tail), so callers can parse metrics.

    Raises:
        RuntimeError: If the process exits with a non-zero return code.
    """
    env = dict(os.environ)
    # Full Hydra tracebacks make a failed run debuggable from the echoed tail.
    env.setdefault("HYDRA_FULL_ERROR", "1")
    # -u: unbuffered child output, so stdout and stderr interleave in the order written.
    cmd = [sys.executable, "-u", script, *overrides]
    print("\n$", " ".join(cmd))
    proc = subprocess.run(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        timeout=timeout,
        check=False,
        env=env,
    )
    log_text = proc.stdout
    print(log_text[-tail_chars:] if tail_chars else log_text)
    if proc.returncode != 0:
        raise RuntimeError(f"Command failed with return code {proc.returncode}")
    return log_text

# Number literal as printed by Lightning / tqdm / rich: "1176", "1221.41", "1.22e+03".
_NUMBER = r"([0-9]+(?:\.[0-9]+)?(?:e[+-]?[0-9]+)?)"

# ANSI colour/style escape sequences; stripped so their digits (e.g. "\x1b[0m") are never parsed.
_ANSI_ESCAPE = re.compile(r"\x1b\[[0-9;]*[A-Za-z]")

_METRIC_PATTERNS = [
    # key: value / key=value, key optionally quoted, value optionally wrapped in tensor(...):
    # "val/ppl=1.22e+03", "'val/ppl': 1221.4", "'val/ppl': tensor(1221.4)"
    re.compile(r"val/ppl['\"]?[ \t]*[:=][ \t]*(?:tensor\()?" + _NUMBER, re.IGNORECASE),
    # Lightning results-table row, rich ("val/ppl │ 1221.4") or plain ("val/ppl     1221.4")
    re.compile(r"val/ppl(?:[ \t]*[│|][ \t]*|[ \t]+)" + _NUMBER, re.IGNORECASE),
]


def extract_val_ppl(log_text: str):
    """Return the last ``val/ppl`` value printed in ``log_text``, or None if there is none.

    Lines are scanned from the end, so the most recently printed value wins. The number
    must directly follow ``val/ppl`` (after an optional quote and a ``:``/``=``/``│``/``|``
    separator, or plain whitespace). Lines that merely mention the metric, such as
    "Metric val/ppl improved by 12.300 ...", are skipped rather than parsed.
    """
    for raw_line in reversed(log_text.splitlines()):
        line = _ANSI_ESCAPE.sub("", raw_line)
        if "val/ppl" not in line.lower():
            continue
        for pattern in _METRIC_PATTERNS:
            match = pattern.search(line)
            if match:
                return float(match.group(1))
    return None

def _small_loader_overrides(batch_size=8, num_workers=2):
    """Overrides needed to avoid huge default batch sizes on Colab."""
    return [
        f"loader.global_batch_size={batch_size}",
        f"loader.eval_global_batch_size={batch_size}",
        f"loader.batch_size={batch_size}",
        f"loader.eval_batch_size={batch_size}",
        f"loader.num_workers={num_workers}",
        "trainer.accumulate_grad_batches=1",
    ]

def train_run(run_name, algo, block_size=None, from_pretrained=None, max_steps=800, extra_overrides=None,
              reuse_existing=False):
    """Train a tiny LM1B model with ``main.py`` and return the path of its ``last.ckpt``.

    By default the run directory ``/content/repro_runs/<run_name>`` is deleted and
    recreated, so every call trains from scratch (or from ``from_pretrained``).

    Args:
        run_name: Run directory name under ``/content/repro_runs``; that directory is
            also passed as ``checkpointing.save_dir``.
        algo: Value for the Hydra ``algo`` override (e.g. ``"ar"``, ``"mdlm"``, ``"bd3lm"``).
        block_size: Optional ``block_size`` override; omitted when None.
        from_pretrained: Optional checkpoint path passed as ``training.from_pretrained``.
        max_steps: Value for ``trainer.max_steps``.
        extra_overrides: Additional Hydra overrides, appended after all the defaults below.
        reuse_existing: If True and ``checkpoints/last.ckpt`` already exists in the run
            directory, skip training and return it (useful on Restart-and-Run-All).
            Default False keeps the always-retrain behaviour.

    Returns:
        The checkpoint path as a string.

    Raises:
        RuntimeError: If ``main.py`` exits with a non-zero status (raised by ``run_main``).
        FileNotFoundError: If the run finishes but ``checkpoints/last.ckpt`` is missing.
    """
    save_dir = Path("/content/repro_runs") / run_name
    # Where main.py is expected to leave the final checkpoint (verified after training).
    ckpt = save_dir / "checkpoints" / "last.ckpt"
    if reuse_existing and ckpt.exists():
        print(f"Reusing existing checkpoint: {ckpt}")
        return str(ckpt)

    # Start from an empty directory so the existence check below can only succeed
    # on a checkpoint written by this run.
    if save_dir.exists():
        shutil.rmtree(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    overrides = [
        "mode=train",
        "data=lm1b-wrap",
        "data.cache_dir=/content/bd3lms/data",
        "data.streaming=true",
        "data.max_train_samples=5000",
        "model=tiny",
        "model.length=128",
        "model.attn_backend=sdpa",
        f"algo={algo}",
        "trainer.accelerator=gpu",
        "trainer.devices=1",
        "trainer.num_nodes=1",
        "trainer.precision=16-mixed",
        "trainer.num_sanity_val_steps=0",
        "trainer.log_every_n_steps=10",
        "trainer.val_check_interval=50",
        f"trainer.max_steps={max_steps}",
        # For LM1B, validation uses the 'test' split in this codebase, so cap both splits.
        "data.max_valid_samples=100",
        "data.max_test_samples=100",
        f"checkpointing.save_dir={save_dir}",
        "checkpointing.resume_from_ckpt=false",
        "wandb=null",
    ]
    overrides.extend(_small_loader_overrides(batch_size=8, num_workers=2))
    if block_size is not None:
        overrides.append(f"block_size={block_size}")
    if from_pretrained is not None:
        overrides.append(f"training.from_pretrained={from_pretrained}")
    if extra_overrides:
        overrides.extend(extra_overrides)

    run_main(overrides)
    if not ckpt.exists():
        raise FileNotFoundError(f"Expected checkpoint not found: {ckpt}")
    return str(ckpt)

def eval_run(algo, checkpoint_path, block_size=None, extra_overrides=None):
    """Evaluate a checkpoint with ``main.py mode=ppl_eval`` and return the parsed val/ppl.

    Uses the same tiny-model / LM1B / single-GPU settings as the training helper and
    caps ``data.max_test_samples`` at 1000.

    Raises:
        RuntimeError: If ``main.py`` exits with a non-zero status (raised by ``run_main``).
        ValueError: If no val/ppl value can be found in the captured log.
    """
    optional_overrides = [] if block_size is None else [f"block_size={block_size}"]
    if extra_overrides:
        optional_overrides += extra_overrides

    overrides = [
        "mode=ppl_eval",
        "data=lm1b-wrap",
        "data.cache_dir=/content/bd3lms/data",
        "data.streaming=true",
        # For LM1B, `get_dataloaders` maps validation to the 'test' split
        "data.max_test_samples=1000",
        "model=tiny",
        "model.length=128",
        "model.attn_backend=sdpa",
        f"algo={algo}",
        f"eval.checkpoint_path={checkpoint_path}",
        "trainer.accelerator=gpu",
        "trainer.devices=1",
        "trainer.num_nodes=1",
        "trainer.precision=16-mixed",
        "trainer.num_sanity_val_steps=0",
        "wandb=null",
        *_small_loader_overrides(batch_size=8, num_workers=2),
        *optional_overrides,
    ]

    ppl = extract_val_ppl(run_main(overrides))
    if ppl is None:
        raise ValueError("Could not parse val/ppl from output. Try increasing the log tail or printing full logs.")
    return ppl

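## 3) Run experiments and build a mini Table 3
This section follows the paper's training procedure at tiny scale:
- Train a base BD3-LM once with block size $L' = L$ (here 128).
- Fine-tune that base model for block sizes $L' \in \{16, 8, 4\}$ with `training.resample=true`.
- Train tiny AR, SEDD, and MDLM baselines from scratch.

Compared with the paper, the number of training steps, the number of samples, and the model size are reduced so the runs fit on Colab. In addition, every diffusion run sets `algo.var_min=false` and `algo.clip_search_widths=[]`, which turns off the `var_min` noise-schedule search.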

In [None]:
from pathlib import Path
import os
import shutil
import sys
# import pandas as pd

results = []

# Every diffusion training run turns off var_min and the clip-width search to keep the
# tiny runs cheap; at evaluation time only var_min is overridden.
SCHEDULE_SEARCH_OFF = ["algo.var_min=false", "algo.clip_search_widths=[]"]
EVAL_DIFFUSION_OVERRIDES = ["algo.var_min=false"]
DIFFUSION_BASELINES = {"sedd": "SEDD", "mdlm": "MDLM"}
BD3LM_BLOCK_SIZES = (16, 8, 4)

# 1) Autoregressive baseline: tiny model trained from scratch.
ar_run = "ar_tiny_len128"
ar_ckpt = train_run(ar_run, algo="ar")
ar_ppl = eval_run(algo="ar", checkpoint_path=ar_ckpt)
results.append({"model": "Autoregressive", "block_size_Lprime": "-", "val_ppl": ar_ppl})

# 2) Diffusion baselines (SEDD, MDLM): tiny models trained from scratch with training.resample=false.
for algo_name, display_name in DIFFUSION_BASELINES.items():
    run_name = f"{algo_name}_tiny_len128"
    ckpt = train_run(
        run_name,
        algo=algo_name,
        extra_overrides=["training.resample=false", *SCHEDULE_SEARCH_OFF],
    )
    ppl = eval_run(algo=algo_name, checkpoint_path=ckpt, extra_overrides=[*EVAL_DIFFUSION_OVERRIDES])
    results.append({"model": display_name, "block_size_Lprime": "-", "val_ppl": ppl})

# 3) BD3-LM: train one base model with L' = L = 128, then fine-tune it at each smaller block size L'.
bd3lm_base_run = "bd3lm_base_len128"
bd3lm_base_ckpt = train_run(
    bd3lm_base_run,
    algo="bd3lm",
    block_size=128,
    extra_overrides=["training.resample=false", *SCHEDULE_SEARCH_OFF],
)

for Lprime in BD3LM_BLOCK_SIZES:
    finetune_run = f"bd3lm_finetune_Lp{Lprime}"
    finetune_ckpt = train_run(
        finetune_run,
        algo="bd3lm",
        block_size=Lprime,
        from_pretrained=bd3lm_base_ckpt,
        extra_overrides=["training.resample=true", *SCHEDULE_SEARCH_OFF],
    )
    ppl = eval_run(
        algo="bd3lm",
        checkpoint_path=finetune_ckpt,
        block_size=Lprime,
        extra_overrides=[*EVAL_DIFFUSION_OVERRIDES],
    )
    results.append({"model": "Block diffusion (BD3LM)", "block_size_Lprime": Lprime, "val_ppl": ppl})


$ /usr/bin/python3 -u bd3lms/main.py mode=train data=lm1b-wrap data.cache_dir=/content/bd3lms/data data.streaming=true data.max_train_samples=5000 model=tiny model.length=128 model.attn_backend=sdpa algo=ar trainer.accelerator=gpu trainer.devices=1 trainer.num_nodes=1 trainer.precision=16-mixed trainer.num_sanity_val_steps=0 trainer.log_every_n_steps=10 trainer.val_check_interval=50 trainer.max_steps=800 data.max_valid_samples=100 data.max_test_samples=100 checkpointing.save_dir=/content/repro_runs/ar_tiny_len128 checkpointing.resume_from_ckpt=false wandb=null loader.global_batch_size=8 loader.eval_global_batch_size=8 loader.batch_size=8 loader.eval_batch_size=8 loader.num_workers=2 trainer.accumulate_grad_batches=1
|████████▏ | 120/146 [01:53<00:24,  1.05it/s, v_num=0]
Epoch 3:  96%|█████████▌| 140/146 [01:55<00:04,  1.22it/s, v_num=0]
Epoch 3:  96%|█████████▌| 140/146 [01:55<00:04,  1.22it/s, v_num=0]
Epoch 3: 100%|██████████| 146/146 [01:55<00:00,  1.26it/s, v_num=0]
Epoch 3: 100%|

In [6]:
def print_table(rows, float_format=None):
    """Print ``rows`` (a sequence of dicts) as a plain-text grid with ``+``/``-``/``|`` borders.

    Columns are the union of all row keys in first-seen order, so a key that first
    appears in a later row is still shown; missing cells render as empty strings.
    Cells are left-justified and each column is as wide as its widest cell or header.

    Args:
        rows: Sequence of dicts; an empty sequence prints "No data to display.".
        float_format: Optional ``format()`` spec (e.g. ``".2f"``) applied to float
            cells; None keeps plain ``str()`` formatting.
    """
    if not rows:
        print("No data to display.")
        return

    # dict.fromkeys de-duplicates keys across rows while keeping first-seen order.
    columns = list(dict.fromkeys(key for row in rows for key in row))

    def render(value):
        if float_format is not None and isinstance(value, float):
            return format(value, float_format)
        return str(value)

    header = {col: str(col) for col in columns}
    str_rows = [
        {col: render(row.get(col, "")) for col in columns}
        for row in rows
    ]

    widths = {
        col: max(len(header[col]), max(len(row[col]) for row in str_rows))
        for col in columns
    }

    def print_separator():
        print("+" + "+".join("-" * (widths[col] + 2) for col in columns) + "+")

    def print_row(row):
        print(
            "| " +
            " | ".join(row[col].ljust(widths[col]) for col in columns) +
            " |"
        )

    print_separator()
    print_row(header)
    print_separator()
    for row in str_rows:
        print_row(row)
    print_separator()

# Render the `results` list collected by the experiment cell as a text table.
print_table(results)

+-------------------------+-------------------+--------------------+
| model                   | block_size_Lprime | val_ppl            |
+-------------------------+-------------------+--------------------+
| Autoregressive          | -                 | 1221.4136962890625 |
| SEDD                    | -                 | 1403.5390625       |
| MDLM                    | -                 | 1370.174560546875  |
| Block diffusion (BD3LM) | 16                | 1345.452880859375  |
| Block diffusion (BD3LM) | 8                 | 1210.863037109375  |
| Block diffusion (BD3LM) | 4                 | 1176.90283203125   |
+-------------------------+-------------------+--------------------+
