In [1]:
# Fetch the Extension1 branch of the bd3lms repository; later cells run bd3lms/main.py from it.
# NOTE: git refuses to clone into an existing non-empty directory, so remove ./bd3lms before re-running to refresh it.
!git clone -b Extension1 https://github.com/ntua-el21050/bd3lms.git

Cloning into 'bd3lms'...
remote: Enumerating objects: 791, done.
remote: Counting objects: 100% (249/249), done.
remote: Compressing objects: 100% (63/63), done.
remote: Total 791 (delta 215), reused 193 (delta 186), pack-reused 542 (from 1)
Receiving objects: 100% (791/791), 2.40 MiB | 26.19 MiB/s, done.
Resolving deltas: 100% (509/509), done.


In [2]:
%%capture
!pip install -r bd3lms/requirements.txt

In [3]:
import os
import re
import shlex
import shutil
import subprocess
import sys
from pathlib import Path

def run_main(overrides, timeout=None, tail_chars=4000):
    """Run `python -u bd3lms/main.py <overrides>` and return the combined stdout/stderr text.

    The script path is relative, so it is resolved against the current working directory.

    Args:
        overrides: Hydra override strings, passed through as CLI arguments.
        timeout: Optional seconds before ``subprocess.run`` aborts the child
            (``subprocess.TimeoutExpired`` propagates to the caller).
        tail_chars: Number of trailing characters of the output to echo for quick
            visibility. The full output is always returned regardless.

    Returns:
        The child's stdout, with stderr merged into it.

    Raises:
        RuntimeError: If the child exits with a non-zero return code. The message
            includes the return code and the exact command line.
    """
    env = dict(os.environ)
    # Ask Hydra for full tracebacks unless the caller already configured it.
    env.setdefault("HYDRA_FULL_ERROR", "1")
    cmd = [sys.executable, "-u", "bd3lms/main.py", *overrides]
    # shlex.join quotes arguments such as `algo.clip_search_widths=[]`, so the
    # echoed line can be pasted into a shell unchanged.
    printable_cmd = shlex.join(cmd)
    print("\n$", printable_cmd)
    proc = subprocess.run(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        timeout=timeout,
        check=False,
        env=env,
    )
    # `s[-0:]` is the whole string, so a non-positive tail_chars prints nothing.
    print(proc.stdout[-tail_chars:] if tail_chars > 0 else "")
    if proc.returncode != 0:
        raise RuntimeError(
            f"Command failed with return code {proc.returncode}: {printable_cmd}"
        )
    return proc.stdout

_METRIC_PATTERNS = [
    # Key: value (some loggers print this)
    re.compile(r"val/ppl\s*[:=]\s*([0-9]+(?:\.[0-9]+)?(?:e[+-]?\d+)?)", re.IGNORECASE),
    re.compile(r"'val/ppl'\s*:\s*([0-9]+(?:\.[0-9]+)?(?:e[+-]?\d+)?)", re.IGNORECASE),

    # Lightning "rich" table row (note the unicode box character │)
    re.compile(r"val/ppl\s*[│|]\s*([0-9]+(?:\.[0-9]+)?(?:e[+-]?\d+)?)", re.IGNORECASE),
]

def extract_val_ppl(log_text: str):
    """Return the most recent ``val/ppl`` value found in ``log_text``, or None.

    Lines are scanned from the end, so the last reported value wins. A line only
    counts when the number follows the ``val/ppl`` key directly. Between the key
    and the number, the pattern allows:

    - an optional closing quote,
    - an optional ``:`` / ``=`` / ``│`` / ``|`` separator,
    - whitespace,
    - an optional ``tensor(`` prefix.

    Requiring adjacency avoids grabbing unrelated digits later on the same line.
    For example, ``'val/ppl' was not in top 1`` must not parse as 1.0.

    ``inf`` and ``nan`` are accepted, so a diverged run reports its real value
    instead of silently yielding an older, finite one from an earlier line.

    If no line matches, the function falls back to scanning the whole text with
    ``_METRIC_PATTERNS`` and returns the last hit.
    """
    line_pattern = re.compile(
        r"val/ppl['\"]?\s*[:=│|]?\s*(?:tensor\(\s*)?"
        r"((?:\d+(?:\.\d*)?|\.\d+)(?:e[+-]?\d+)?|inf\b|nan\b)",
        re.IGNORECASE,
    )
    for line in reversed(log_text.splitlines()):
        if "val/ppl" not in line.lower():
            continue
        match = line_pattern.search(line)
        if match:
            return float(match.group(1))

    # Fallback: scan the entire text with the known patterns.
    hits = [hit for pat in _METRIC_PATTERNS for hit in pat.findall(log_text)]
    return float(hits[-1]) if hits else None

def _small_loader_overrides(batch_size=8, num_workers=2):
    """Overrides needed to avoid huge default batch sizes on Colab."""
    return [
        f"loader.global_batch_size={batch_size}",
        f"loader.eval_global_batch_size={batch_size}",
        f"loader.batch_size={batch_size}",
        f"loader.eval_batch_size={batch_size}",
        f"loader.num_workers={num_workers}",
        "trainer.accumulate_grad_batches=1",
    ]

def train_run(
    run_name,
    algo,
    block_size=None,
    from_pretrained=None,
    max_steps=800,
    extra_overrides=None,
    runs_root="/content/repro_runs",
    batch_size=8,
    num_workers=2,
):
    """Train a model (optionally from a base checkpoint) and return the last.ckpt path.

    Args:
        run_name: Sub-directory of ``runs_root`` for this run. It is deleted and
            recreated if it already exists.
        algo: Value for the Hydra ``algo`` override (e.g. ``"bd3lm"``).
        block_size: If given, passed as the ``block_size`` override.
        from_pretrained: Optional checkpoint path, passed as ``training.from_pretrained``.
        max_steps: Value for ``trainer.max_steps``.
        extra_overrides: Additional Hydra overrides, appended after the defaults.
        runs_root: Parent directory for all run directories.
        batch_size: Forwarded to ``_small_loader_overrides``.
        num_workers: Forwarded to ``_small_loader_overrides``.

    Returns:
        ``<runs_root>/<run_name>/checkpoints/last.ckpt`` as a string.

    Raises:
        RuntimeError: If the training subprocess fails (raised by ``run_main``).
        FileNotFoundError: If the run finished without writing ``last.ckpt``.
    """
    save_dir = Path(runs_root) / run_name
    # Start from an empty directory so a stale last.ckpt from an earlier run
    # cannot satisfy the existence check at the end.
    if save_dir.exists():
        shutil.rmtree(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    overrides = [
        "mode=train",
        "data=lm1b-wrap",
        "data.cache_dir=/content/bd3lms/data",
        "data.streaming=true",
        "data.max_train_samples=5000",
        "model=tiny",
        "model.length=128",
        "model.attn_backend=sdpa",
        f"algo={algo}",
        "trainer.accelerator=gpu",
        "trainer.devices=1",
        "trainer.num_nodes=1",
        "trainer.precision=16-mixed",
        "trainer.num_sanity_val_steps=0",
        "trainer.log_every_n_steps=10",
        "trainer.val_check_interval=50",
        f"trainer.max_steps={max_steps}",
        # For LM1B, validation uses the 'test' split in this codebase, so both
        # caps are pinned small for the periodic validation passes.
        "data.max_valid_samples=100",
        "data.max_test_samples=100",
        # Derived from save_dir so the directory we wipe/check and the one the
        # trainer writes to cannot drift apart.
        f"checkpointing.save_dir={save_dir}",
        "checkpointing.resume_from_ckpt=false",
        "wandb=null",
    ]
    overrides.extend(_small_loader_overrides(batch_size=batch_size, num_workers=num_workers))
    if block_size is not None:
        overrides.append(f"block_size={block_size}")
    if from_pretrained is not None:
        overrides.append(f"training.from_pretrained={from_pretrained}")
    if extra_overrides:
        overrides.extend(extra_overrides)

    run_main(overrides)
    ckpt = save_dir / "checkpoints" / "last.ckpt"
    if not ckpt.exists():
        raise FileNotFoundError(f"Expected checkpoint not found: {ckpt}")
    return str(ckpt)

def eval_run(
    algo,
    checkpoint_path,
    block_size=None,
    extra_overrides=None,
    max_test_samples=1000,
    batch_size=8,
    num_workers=2,
):
    """Evaluate perplexity (``val/ppl``) for a given checkpoint and return it as a float.

    Args:
        algo: Value for the Hydra ``algo`` override (e.g. ``"bd3lm"``).
        checkpoint_path: Checkpoint passed as ``eval.checkpoint_path``.
        block_size: If given, passed as the ``block_size`` override.
        extra_overrides: Additional Hydra overrides, appended after the defaults.
        max_test_samples: Cap on the evaluation split (``data.max_test_samples``).
        batch_size: Forwarded to ``_small_loader_overrides``.
        num_workers: Forwarded to ``_small_loader_overrides``.

    Raises:
        RuntimeError: If the evaluation subprocess fails (raised by ``run_main``).
        ValueError: If no ``val/ppl`` value can be parsed from the run's output.
    """
    overrides = [
        "mode=ppl_eval",
        "data=lm1b-wrap",
        "data.cache_dir=/content/bd3lms/data",
        "data.streaming=true",
        # For LM1B, `get_dataloaders` maps validation to the 'test' split
        f"data.max_test_samples={max_test_samples}",
        "model=tiny",
        "model.length=128",
        "model.attn_backend=sdpa",
        f"algo={algo}",
        f"eval.checkpoint_path={checkpoint_path}",
        "trainer.accelerator=gpu",
        "trainer.devices=1",
        "trainer.num_nodes=1",
        "trainer.precision=16-mixed",
        "trainer.num_sanity_val_steps=0",
        "wandb=null",
    ]
    overrides.extend(_small_loader_overrides(batch_size=batch_size, num_workers=num_workers))
    if block_size is not None:
        overrides.append(f"block_size={block_size}")
    if extra_overrides:
        overrides.extend(extra_overrides)

    log_text = run_main(overrides)
    ppl = extract_val_ppl(log_text)
    if ppl is None:
        # run_main returns the complete output (only its console echo is
        # truncated), so printing a longer tail would not change this result.
        raise ValueError(
            "Could not parse val/ppl from the full ppl_eval output "
            f"(checkpoint: {checkpoint_path}). Check that the run logged val/ppl."
        )
    return ppl

In [4]:
import os
import shutil
import sys

# One row per fine-tuned block size; rendered by print_table in a later cell.
results = []

# Overrides shared by the base run and every fine-tuning run.
SHARED_TRAIN_OVERRIDES = ("algo.var_min=false", "algo.clip_search_widths=[]")

# Base model: block_size=128, the same as the model.length that train_run sets.
bd3lm_base_run = "bd3lm_base_len128"
bd3lm_base_ckpt = train_run(
    bd3lm_base_run,
    algo="bd3lm",
    block_size=128,
    extra_overrides=["training.resample=false", *SHARED_TRAIN_OVERRIDES],
)

# Fine-tune the base checkpoint at each smaller block size L', then evaluate it.
for Lprime in (16, 8, 4):
    finetune_run = f"bd3lm_finetune_Lp{Lprime}"
    finetune_ckpt = train_run(
        finetune_run,
        algo="bd3lm",
        block_size=Lprime,
        from_pretrained=bd3lm_base_ckpt,
        extra_overrides=["training.resample=true", *SHARED_TRAIN_OVERRIDES],
    )
    ppl = eval_run(
        algo="bd3lm",
        checkpoint_path=finetune_ckpt,
        block_size=Lprime,
        extra_overrides=["algo.var_min=false"],
    )
    row = dict(model="Block diffusion (BD3LM)", block_size_Lprime=Lprime, val_ppl=ppl)
    results.append(row)


$ /usr/bin/python3 -u bd3lms/main.py mode=train data=lm1b-wrap data.cache_dir=/content/bd3lms/data data.streaming=true data.max_train_samples=5000 model=tiny model.length=128 model.attn_backend=sdpa algo=bd3lm trainer.accelerator=gpu trainer.devices=1 trainer.num_nodes=1 trainer.precision=16-mixed trainer.num_sanity_val_steps=0 trainer.log_every_n_steps=10 trainer.val_check_interval=50 trainer.max_steps=800 data.max_valid_samples=100 data.max_test_samples=100 checkpointing.save_dir=/content/repro_runs/bd3lm_base_len128 checkpointing.resume_from_ckpt=false wandb=null loader.global_batch_size=8 loader.eval_global_batch_size=8 loader.batch_size=8 loader.eval_batch_size=8 loader.num_workers=2 trainer.accumulate_grad_batches=1 block_size=128 training.resample=false algo.var_min=false algo.clip_search_widths=[]
:  68%|██████▊   | 100/146 [00:23<00:10,  4.21it/s, v_num=0]Epoch 3, global step 538: 'val/nll' was not in top 1

Epoch 3:  82%|████████▏ | 120/146 [00:28<00:06,  4.27it/s, v_num=0]


In [5]:
def print_table(rows):
    """Print ``rows`` (a list of dicts) as an ASCII grid table.

    Column order is taken from the keys of the first row, and every cell is
    rendered with ``str``. A key missing from a later row renders as an empty
    cell. When ``rows`` is empty, only a placeholder line is printed.
    """
    if not rows:
        print("No data to display.")
        return

    headers = list(rows[0])
    cells = [[str(record.get(h, "")) for h in headers] for record in rows]

    # Each column is as wide as its longest entry, header included.
    col_widths = [
        max([len(h)] + [len(line[i]) for line in cells])
        for i, h in enumerate(headers)
    ]

    rule = "+" + "+".join("-" * (w + 2) for w in col_widths) + "+"

    def fmt(values):
        padded = (v.ljust(w) for v, w in zip(values, col_widths))
        return "| " + " | ".join(padded) + " |"

    print(rule)
    print(fmt(headers))
    print(rule)
    for line in cells:
        print(fmt(line))
    print(rule)

# Summarize the fine-tuning sweep collected in `results` as an ASCII table.
print_table(results)

+-------------------------+-------------------+--------------------+
| model                   | block_size_Lprime | val_ppl            |
+-------------------------+-------------------+--------------------+
| Block diffusion (BD3LM) | 16                | 252.72149658203125 |
| Block diffusion (BD3LM) | 8                 | 249.3388214111328  |
| Block diffusion (BD3LM) | 4                 | 246.02944946289062 |
+-------------------------+-------------------+--------------------+
