In [2]:
# Clone repo
!cd /content && rm -rf bd3lms
!cd /content && git clone https://github.com/ntua-el21050/bd3lms.git

# Create directories
!mkdir -p /content/bd3lms/data
!mkdir -p /content/repro_runs

Cloning into 'bd3lms'...
remote: Enumerating objects: 1181, done.[K
remote: Counting objects: 100% (445/445), done.[K
remote: Compressing objects: 100% (214/214), done.[K
remote: Total 1181 (delta 329), reused 306 (delta 230), pack-reused 736 (from 2)[K
Receiving objects: 100% (1181/1181), 8.28 MiB | 40.00 MiB/s, done.
Resolving deltas: 100% (736/736), done.


In [3]:
!pip install -q torchmetrics==1.6.2 datasets==3.3.2 einops==0.8.1 \
    hydra-core==1.3.2 lightning==2.5.0.post0 transformers==4.49.0

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/40.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.4/40.4 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.0/44.0 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m931.6/931.6 kB[0m [31m36.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m485.4/485.4 kB[0m [31m38.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m64.4/64.4 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.5/154.5 kB[0m [31m14.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m815.2/815.2 kB[0m [31m55.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [4]:
import subprocess
import re
import os
import shutil
import sys
from pathlib import Path

def run_main(overrides, timeout=None, script="bd3lms/main.py"):
    """Run ``python -u <script> <overrides...>`` and return its combined output.

    Args:
        overrides: Hydra command-line overrides (``key=value`` strings), passed
            verbatim as extra argv entries.
        timeout: Optional wall-clock limit in seconds, forwarded to
            ``subprocess.run``; ``subprocess.TimeoutExpired`` propagates if it
            is exceeded.
        script: Entry-point script. The default is a relative path, so it is
            resolved against the current working directory.

    Returns:
        The child's stdout and stderr interleaved into a single string.

    Raises:
        RuntimeError: if the child exits with a non-zero status. The last 4000
            characters of output are printed before raising.
    """
    env = dict(os.environ)
    # Ask Hydra for complete stack traces; an explicit user setting wins.
    env.setdefault("HYDRA_FULL_ERROR", "1")
    # The child logs show huggingface/tokenizers warning that the process was
    # forked after parallelism had been used, and asking for this variable to be
    # set explicitly. Defaulting it silences that per-worker warning while still
    # respecting a value the user already exported.
    env.setdefault("TOKENIZERS_PARALLELISM", "false")

    cmd = [sys.executable, "-u", str(script), *overrides]
    print("\n$", " ".join(cmd))
    proc = subprocess.run(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # interleave stderr into the returned text
        text=True,
        timeout=timeout,
        check=False,  # inspect the return code ourselves, after printing the tail
        env=env,
    )
    # Only echo the tail: training logs can be very long.
    print(proc.stdout[-4000:])
    if proc.returncode != 0:
        raise RuntimeError(f"Command failed with return code {proc.returncode}")
    return proc.stdout

# A perplexity value as it appears in the logs: integer, decimal or scientific
# notation, or the non-finite words ``inf`` / ``nan`` (matched only as whole
# words). Without the non-finite alternatives, a lazy search on a line such as
# ``val/ppl=inf, v_num=3`` would skip ``inf`` and return the unrelated ``3``.
_PPL_NUMBER = r"(\b(?:inf|nan)\b|[0-9]+(?:\.[0-9]+)?(?:e[+-]?\d+)?)"

# Whole-log fallback patterns (``key: v`` / ``key=v``, dict-style ``'key': v``,
# and table cells separated by ``│`` or ``|``). Because ``\s`` also matches
# newlines, these can catch a value printed on the line after its key.
_METRIC_PATTERNS = [
    re.compile(r"val/ppl\s*[:=]\s*" + _PPL_NUMBER, re.IGNORECASE),
    re.compile(r"'val/ppl'\s*:\s*" + _PPL_NUMBER, re.IGNORECASE),
    re.compile(r"val/ppl\s*[│|]\s*" + _PPL_NUMBER, re.IGNORECASE),
]

# Per-line pattern: the first value that follows ``val/ppl`` on the same line.
_PPL_ON_LINE = re.compile(r"val/ppl.*?" + _PPL_NUMBER, re.IGNORECASE)


def extract_val_ppl(log_text: str):
    """Return the last reported ``val/ppl`` value in ``log_text``, or ``None``.

    Lines are scanned from the end, and the first value following ``val/ppl`` on
    the last line that has one wins. If no such line exists, the whole text is
    searched with ``_METRIC_PATTERNS``, and the final hit of the last pattern
    that matched is used.

    Args:
        log_text: Combined stdout/stderr of a run.

    Returns:
        The perplexity as a float. This may be ``inf`` or ``nan`` when the log
        reports a non-finite value. Returns ``None`` when nothing is found.
    """
    for line in reversed(log_text.splitlines()):
        if "val/ppl" in line.lower():
            match = _PPL_ON_LINE.search(line)
            if match:
                return float(match.group(1))
    hits = [hit for pat in _METRIC_PATTERNS for hit in pat.findall(log_text)]
    return float(hits[-1]) if hits else None

def _small_loader_overrides(batch_size=4, num_workers=2):
    """Smaller batch size for OWT (1024 context length)."""
    return [
        f"loader.global_batch_size={batch_size}",
        f"loader.eval_global_batch_size={batch_size}",
        f"loader.batch_size={batch_size}",
        f"loader.eval_batch_size={batch_size}",
        f"loader.num_workers={num_workers}",
        "trainer.accumulate_grad_batches=1",
    ]

def train_run(run_name, algo, block_size=None, from_pretrained=None, max_steps=800,
              extra_overrides=None, runs_root="/content/repro_runs"):
    """Train one model for the Table 4 (OpenWebText, context length 1024) reproduction.

    The run directory ``<runs_root>/<run_name>`` is deleted and recreated first,
    and ``checkpointing.resume_from_ckpt=false`` is passed, so every call starts
    from scratch.

    Args:
        run_name: Name of the run directory under ``runs_root``.
        algo: Value for the Hydra ``algo`` override (e.g. ``"ar"``, ``"mdlm"``,
            ``"bd3lm"``).
        block_size: If given, appended as ``block_size=<value>``.
        from_pretrained: If given, appended as ``training.from_pretrained=<value>``.
        max_steps: Value for ``trainer.max_steps``.
        extra_overrides: Additional Hydra overrides, appended after all the
            overrides built here.
        runs_root: Parent directory of all run directories.

    Returns:
        Path (as ``str``) of ``<runs_root>/<run_name>/checkpoints/last.ckpt``.

    Raises:
        RuntimeError: if the training process exits with a non-zero status.
        FileNotFoundError: if training finished but ``last.ckpt`` is missing.
    """
    save_dir = Path(runs_root) / run_name
    if save_dir.exists():
        shutil.rmtree(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    overrides = [
        "mode=train",

        # Table 4 setup: OpenWebText with a 1024-token context.
        "data=openwebtext-split",           # OWT dataset
        "data.cache_dir=/content/bd3lms/data",
        "data.streaming=true",
        "data.max_train_samples=1200",      # scaled down
        "model=tiny",
        "model.length=1024",                # 1024 context (8x larger than the Table 3 setting)
        "model.attn_backend=sdpa",
        f"algo={algo}",

        # Trainer settings
        "trainer.accelerator=gpu",
        "trainer.devices=1",
        "trainer.num_nodes=1",
        "trainer.precision=16-mixed",       # mixed precision for speed
        "trainer.num_sanity_val_steps=0",
        "trainer.log_every_n_steps=10",
        "trainer.val_check_interval=10",    # small interval because the run uses few samples
        f"trainer.max_steps={max_steps}",
        "data.max_valid_samples=100",
        "data.max_test_samples=100",
        # Derived from save_dir so the directory we clear and the one we
        # later read last.ckpt from can never diverge.
        f"checkpointing.save_dir={save_dir}",
        "checkpointing.resume_from_ckpt=false",
        "wandb=null",
    ]
    overrides.extend(_small_loader_overrides(batch_size=4, num_workers=2))  # batch=4 to limit memory use

    if block_size is not None:
        overrides.append(f"block_size={block_size}")
    if from_pretrained is not None:
        overrides.append(f"training.from_pretrained={from_pretrained}")
    if extra_overrides:
        overrides.extend(extra_overrides)

    _ = run_main(overrides)
    ckpt = save_dir / "checkpoints" / "last.ckpt"
    if not ckpt.exists():
        raise FileNotFoundError(f"Expected checkpoint not found: {ckpt}")
    return str(ckpt)

def eval_run(algo, checkpoint_path, block_size=None, extra_overrides=None):
    """Evaluate the perplexity of a checkpoint for Table 4 (OpenWebText).

    Data and model overrides mirror the ones ``train_run`` uses (OWT split,
    tiny model, 1024 context, SDPA attention).

    Args:
        algo: Value for the Hydra ``algo`` override.
        checkpoint_path: Checkpoint passed as ``eval.checkpoint_path``.
        block_size: If given, appended as ``block_size=<value>``.
        extra_overrides: Additional Hydra overrides, appended last.

    Returns:
        The ``val/ppl`` value parsed from the run's output.

    Raises:
        RuntimeError: if the evaluation process exits with a non-zero status.
        ValueError: if no ``val/ppl`` value can be parsed from the output.
    """
    cli = ["mode=ppl_eval"]

    # Table 4 data/model configuration.
    cli += [
        "data=openwebtext-split",           # OWT dataset
        "data.cache_dir=/content/bd3lms/data",
        "data.streaming=true",
        "data.max_test_samples=500",        # evaluation samples
        "model=tiny",
        "model.length=1024",                # 1024 context
        "model.attn_backend=sdpa",
        f"algo={algo}",
        f"eval.checkpoint_path={checkpoint_path}",
    ]

    # Single-GPU trainer, no logging backend.
    cli += [
        "trainer.accelerator=gpu",
        "trainer.devices=1",
        "trainer.num_nodes=1",
        "trainer.precision=16-mixed",
        "trainer.num_sanity_val_steps=0",
        "wandb=null",
    ]

    cli += _small_loader_overrides(batch_size=4, num_workers=2)
    cli += [] if block_size is None else [f"block_size={block_size}"]
    cli += list(extra_overrides or [])

    ppl = extract_val_ppl(run_main(cli))
    if ppl is None:
        raise ValueError("Could not parse val/ppl from output.")
    return ppl

In [5]:
def _print_banner(title):
    """Print ``title`` between two 60-character rules of '='."""
    rule = "=" * 60
    print(rule)
    print(title)
    print(rule)


results = []

# 1) Autoregressive baseline.
_print_banner("Training AR baseline...")
ar_run = "ar_tiny_owt_len1024"
ar_ckpt = train_run(ar_run, algo="ar")
ar_ppl = eval_run(algo="ar", checkpoint_path=ar_ckpt)
results.append({"model": "Autoregressive", "block_size_Lprime": "-", "val_ppl": ar_ppl})
print(f"✓ AR PPL: {ar_ppl}")


# 2) Diffusion baselines: SEDD and MDLM, trained and evaluated with the same overrides.
diffusion_baselines = {"sedd": "SEDD", "mdlm": "MDLM"}
for algo_name, display_name in diffusion_baselines.items():
    _print_banner(f"Training {display_name} baseline...")
    run_name = f"{algo_name}_tiny_owt_len1024"
    ckpt = train_run(
        run_name,
        algo=algo_name,
        extra_overrides=[
            "training.resample=false",
            "algo.var_min=false",
            "algo.clip_search_widths=[]",
        ],
    )
    ppl = eval_run(algo=algo_name, checkpoint_path=ckpt, extra_overrides=["algo.var_min=false"])
    results.append({"model": display_name, "block_size_Lprime": "-", "val_ppl": ppl})
    print(f"✓ {display_name} PPL: {ppl}")


# 3) BD3-LM base model: block size equal to the full context (L' = L = 1024).
#    It is only trained here; the fine-tuned variants below start from it.
_print_banner("Training BD3-LM BASE (block_size=1024)...")
bd3lm_base_run = "bd3lm_base_owt_len1024"
bd3lm_base_ckpt = train_run(
    bd3lm_base_run,
    algo="bd3lm",
    block_size=1024,
    extra_overrides=[
        "training.resample=false",  # no resampling for the base run
        "algo.var_min=false",
        "algo.clip_search_widths=[]",
    ],
)
print("✓ BD3-LM base checkpoint saved")


# 4) BD3-LM fine-tuning from the base checkpoint at smaller block sizes L'.
for Lprime in (16, 8, 4):
    _print_banner(f"Fine-tuning BD3-LM (block_size={Lprime})...")
    finetune_run = f"bd3lm_finetune_owt_Lp{Lprime}"
    finetune_ckpt = train_run(
        finetune_run,
        algo="bd3lm",
        block_size=Lprime,
        from_pretrained=bd3lm_base_ckpt,  # initialise from the base run
        extra_overrides=[
            "training.resample=true",  # resampling enabled for fine-tuning
            "algo.var_min=false",
            "algo.clip_search_widths=[]",
        ],
    )
    ppl = eval_run(
        algo="bd3lm",
        checkpoint_path=finetune_ckpt,
        block_size=Lprime,
        extra_overrides=["algo.var_min=false"],
    )
    results.append({"model": "BD3-LM", "block_size_Lprime": Lprime, "val_ppl": ppl})
    print(f"✓ BD3-LM (L'={Lprime}) PPL: {ppl}")

print()
_print_banner("ALL EXPERIMENTS COMPLETE!")

Training AR baseline...

$ /usr/bin/python3 -u bd3lms/main.py mode=train data=openwebtext-split data.cache_dir=/content/bd3lms/data data.streaming=true data.max_train_samples=1200 model=tiny model.length=1024 model.attn_backend=sdpa algo=ar trainer.accelerator=gpu trainer.devices=1 trainer.num_nodes=1 trainer.precision=16-mixed trainer.num_sanity_val_steps=0 trainer.log_every_n_steps=10 trainer.val_check_interval=10 trainer.max_steps=800 data.max_valid_samples=100 data.max_test_samples=100 checkpointing.save_dir=/content/repro_runs/ar_tiny_owt_len1024 checkpointing.resume_from_ckpt=false wandb=null loader.global_batch_size=4 loader.eval_global_batch_size=4 loader.batch_size=4 loader.eval_batch_size=4 loader.num_workers=2 trainer.accumulate_grad_batches=1
elism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: T