In [None]:
# Cell 1: Clone the repo
# Start from a clean slate: remove any previous checkout, then clone a fresh copy.
# Each `!` line runs in its own shell, so the `cd /content` prefix only applies to
# that line and does not change the kernel's working directory.
!cd /content && rm -rf bd3lms
!cd /content && git clone https://github.com/ntua-el21050/bd3lms.git
# Delete outputs of earlier runs; the training helper in this notebook writes its
# checkpoints under /content/repro_runs.
!rm -rf /content/repro_runs

Cloning into 'bd3lms'...
remote: Enumerating objects: 768, done.[K
remote: Counting objects: 100% (227/227), done.[K
remote: Compressing objects: 100% (51/51), done.[K
remote: Total 768 (delta 203), reused 176 (delta 176), pack-reused 541 (from 1)[K
Receiving objects: 100% (768/768), 1.78 MiB | 1.64 MiB/s, done.
Resolving deltas: 100% (495/495), done.


In [None]:
# Cell 2: Verify repo
# List the checkout so a failed or partial clone is visible before anything is installed or run.
!ls -l /content/bd3lms

total 2160
-rw-r--r-- 1 root root 862037 Jan 11 12:53 2503.09573v3.pdf
drwxr-xr-x 9 root root   4096 Jan 11 12:53 configs
-rw-r--r-- 1 root root  33535 Jan 11 12:53 dataloader.py
-rw-r--r-- 1 root root  44840 Jan 11 12:53 diffusion.py
-rw-r--r-- 1 root root 225205 Jan 11 12:53 graphical_abstract.png
-rw-r--r-- 1 root root  11357 Jan 11 12:53 LICENSE
-rw-r--r-- 1 root root   7873 Jan 11 12:53 main.py
-rw-r--r-- 1 root root   8405 Jan 11 12:53 metrics.py
drwxr-xr-x 3 root root   4096 Jan 11 12:53 models
-rw-r--r-- 1 root root   2538 Jan 11 12:53 noise_schedule.py
-rw-r--r-- 1 root root   1449 Jan 11 12:53 push_to_hf.py
-rw-r--r-- 1 root root  10070 Jan 11 12:53 README.md
-rw-r--r-- 1 root root    363 Jan 11 12:53 requirements.txt
drwxr-xr-x 7 root root   4096 Jan 11 12:53 scripts
drwxr-xr-x 4 root root   4096 Jan 11 12:53 ssd-lm
-rw-r--r-- 1 root root 327057 Jan 11 12:53 table_1_diagram_2.ipynb
-rw-r--r-- 1 root root 525005 Jan 11 12:53 table_2_reproduction.ipynb
-rw-r--r-- 1 root root 1

In [None]:
# Cell 3: Install dependencies
!pip install -q \
    torchmetrics==1.6.2 \
    datasets==3.3.2 \
    einops==0.8.1 \
    fsspec==2024.2.0 \
    hydra-core==1.3.2 \
    lightning==2.5.0.post0 \
    omegaconf==2.3.0 \
    packaging==23.2 \
    pandas==2.2.1 \
    rich==13.7.1 \
    scikit-learn==1.5.1 \
    timm==0.9.16 \
    transformers==4.49.0 \
    matplotlib==3.10.0 \
    wandb

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.4/40.4 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.0/44.0 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m931.6/931.6 kB[0m [31m59.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m485.4/485.4 kB[0m [31m38.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m170.9/170.9 kB[0m [31m14.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.5/154.5 kB[0m [31m15.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m815.2/815.2 kB[0m [31m54.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# Cell 4: Helper functions
import subprocess
import re
import os
import shutil
import sys
from pathlib import Path

def run_main(overrides, timeout=None):
    """Run ``bd3lms/main.py`` in a child Python process and return its output.

    The script path is relative, so it is resolved against the current working
    directory of the kernel. The command line is echoed before the run and the
    last 4000 characters of the captured output are echoed after it.

    Args:
        overrides: Iterable of argument strings appended after the script path
            (callers in this notebook pass Hydra-style ``key=value`` items).
        timeout: Seconds forwarded to ``subprocess.run``; ``None`` disables it.

    Returns:
        The complete captured stdout of the child, with stderr merged into it.

    Raises:
        RuntimeError: If the child exits with a non-zero return code.
    """
    argv = [sys.executable, "-u", "bd3lms/main.py"]
    argv.extend(overrides)

    # Only supply HYDRA_FULL_ERROR when the caller's environment has not set it.
    child_env = os.environ.copy()
    if "HYDRA_FULL_ERROR" not in child_env:
        child_env["HYDRA_FULL_ERROR"] = "1"

    print("\n$", " ".join(argv))
    result = subprocess.run(
        argv,
        env=child_env,
        text=True,
        timeout=timeout,
        check=False,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # interleave stderr with stdout in one stream
    )
    captured = result.stdout
    print(captured[-4000:])

    if result.returncode == 0:
        return captured
    raise RuntimeError(f"Command failed with return code {result.returncode}")

# Fallback patterns applied to the whole log text: `val/ppl: 1.23`, `val/ppl=1.23`,
# dict-style `'val/ppl': 1.23`, and table-style `val/ppl │ 1.23`.
_METRIC_PATTERNS = [
    re.compile(r"val/ppl\s*[:=]\s*([0-9]+(?:\.[0-9]+)?(?:e[+-]?\d+)?)", re.IGNORECASE),
    re.compile(r"'val/ppl'\s*:\s*([0-9]+(?:\.[0-9]+)?(?:e[+-]?\d+)?)", re.IGNORECASE),
    re.compile(r"val/ppl\s*[│|]\s*([0-9]+(?:\.[0-9]+)?(?:e[+-]?\d+)?)", re.IGNORECASE),
]

# ANSI CSI escape sequences (e.g. the colour code "\x1b[35m"). They contain digits,
# so they must be removed before searching for the number that follows the metric name.
_ANSI_ESCAPE = re.compile(r"\x1b\[[0-?]*[ -/]*[@-~]")

def extract_val_ppl(log_text: str):
    """Extract the most recent ``val/ppl`` value from captured log text.

    ANSI escape sequences are stripped first; otherwise a colourised line such
    as ``val/ppl │ \\x1b[35m 42.5`` would yield 35.0 (the colour code) instead
    of 42.5.

    Lines are scanned from last to first, and the first number found after
    ``val/ppl`` on a line wins. If no line yields a number, the whole text is
    searched with ``_METRIC_PATTERNS`` and the last collected hit is used.

    Args:
        log_text: Raw text output of a run.

    Returns:
        The parsed value as ``float``, or ``None`` when nothing matched.
    """
    clean_text = _ANSI_ESCAPE.sub("", log_text)

    # Walk backwards so the value reported last in the log is preferred.
    for line in reversed(clean_text.splitlines()):
        if "val/ppl" in line.lower():
            m = re.search(r"val/ppl.*?([0-9]+(?:\.[0-9]+)?(?:e[+-]?\d+)?)", line, re.IGNORECASE)
            if m:
                return float(m.group(1))

    # Fallback: these patterns allow whitespace (including newlines) between
    # the metric name, its separator and the value.
    hits = []
    for pat in _METRIC_PATTERNS:
        hits.extend(pat.findall(clean_text))
    return float(hits[-1]) if hits else None

def _small_loader_overrides(batch_size=8, num_workers=2):
    return [
        f"loader.global_batch_size={batch_size}",
        f"loader.eval_global_batch_size={batch_size}",
        f"loader.batch_size={batch_size}",
        f"loader.eval_batch_size={batch_size}",
        f"loader.num_workers={num_workers}",
        "trainer.accumulate_grad_batches=1",
    ]

def train_run(run_name, algo, block_size=None, from_pretrained=None,
              max_steps=1500, extra_overrides=None):
    """Launch a training run through ``run_main`` and return its last checkpoint.

    Any existing ``/content/repro_runs/<run_name>`` directory is deleted first,
    so each call starts from an empty output directory.

    Args:
        run_name: Sub-directory name under ``/content/repro_runs``.
        algo: Value for the ``algo=`` override.
        block_size: Optional value for ``block_size=``; omitted when ``None``.
        from_pretrained: Optional value for ``training.from_pretrained=``.
        max_steps: Value for ``trainer.max_steps=``.
        extra_overrides: Optional extra override strings, appended last so
            they come after every default on the command line.

    Returns:
        Path of ``checkpoints/last.ckpt`` inside the run directory, as ``str``.

    Raises:
        FileNotFoundError: If that checkpoint does not exist after the run.
        RuntimeError: Propagated from ``run_main`` when the command fails.
    """
    run_dir = Path("/content/repro_runs") / run_name
    if run_dir.exists():
        shutil.rmtree(run_dir)
    run_dir.mkdir(parents=True, exist_ok=True)

    # Assembled group by group; the resulting order is part of the command line.
    args = ["mode=train"]
    args += [
        "data=lm1b-wrap",
        "data.cache_dir=/content/bd3lms/data",
        "data.streaming=true",
        "data.max_train_samples=3000",
    ]
    args += ["model=tiny", "model.length=128", "model.attn_backend=sdpa"]
    args.append(f"algo={algo}")
    args += [
        "trainer.accelerator=gpu",
        "trainer.devices=1",
        "trainer.num_nodes=1",
        "trainer.precision=16-mixed",
        "trainer.num_sanity_val_steps=0",
        "trainer.log_every_n_steps=10",
        "trainer.val_check_interval=50",
        f"trainer.max_steps={max_steps}",
    ]
    args += ["data.max_valid_samples=100", "data.max_test_samples=100"]
    args += [
        f"checkpointing.save_dir=/content/repro_runs/{run_name}",
        "checkpointing.resume_from_ckpt=false",
    ]
    args.append("wandb=null")
    args += _small_loader_overrides(batch_size=8, num_workers=2)

    # Optional settings follow the defaults, caller-supplied extras go last.
    if block_size is not None:
        args.append(f"block_size={block_size}")
    if from_pretrained is not None:
        args.append(f"training.from_pretrained={from_pretrained}")
    args.extend(extra_overrides or [])

    run_main(args)

    last_ckpt = run_dir / "checkpoints" / "last.ckpt"
    if last_ckpt.exists():
        return str(last_ckpt)
    raise FileNotFoundError(f"Expected checkpoint not found: {last_ckpt}")

def eval_run(algo, checkpoint_path, block_size=None, extra_overrides=None):
    """Run ``mode=ppl_eval`` on a checkpoint and return the parsed ``val/ppl``.

    Args:
        algo: Value for the ``algo=`` override.
        checkpoint_path: Value for ``eval.checkpoint_path=``.
        block_size: Optional value for ``block_size=``; omitted when ``None``.
        extra_overrides: Optional extra override strings, appended last.

    Returns:
        The perplexity extracted from the run output by ``extract_val_ppl``.

    Raises:
        ValueError: If no ``val/ppl`` value could be parsed from the output.
        RuntimeError: Propagated from ``run_main`` when the command fails.
    """
    # Assembled group by group; the resulting order is part of the command line.
    args = ["mode=ppl_eval"]
    args += [
        "data=lm1b-wrap",
        "data.cache_dir=/content/bd3lms/data",
        "data.streaming=true",
        "data.max_test_samples=1000",
    ]
    args += ["model=tiny", "model.length=128", "model.attn_backend=sdpa"]
    args += [f"algo={algo}", f"eval.checkpoint_path={checkpoint_path}"]
    args += [
        "trainer.accelerator=gpu",
        "trainer.devices=1",
        "trainer.num_nodes=1",
        "trainer.precision=16-mixed",
        "trainer.num_sanity_val_steps=0",
    ]
    args.append("wandb=null")
    args += _small_loader_overrides(batch_size=8, num_workers=2)

    if block_size is not None:
        args.append(f"block_size={block_size}")
    args.extend(extra_overrides or [])

    parsed = extract_val_ppl(run_main(args))
    if parsed is None:
        raise ValueError("Could not parse val/ppl from output.")
    return parsed

# Reaching this line means every definition earlier in this cell executed without error.
print("Helper functions loaded!")

Helper functions loaded!


## Step 1: Train BD3-LM Base Model

Χρειαζόμαστε ένα base model για fine-tuning.
Το base χρησιμοποιεί **Linear U[0,1]** schedule (default).

In [None]:
# Cell 5: Train BD3-LM Base
# Banner: one print call with a newline separator emits the same three lines.
print("=" * 60, "Training BD3-LM BASE (block_size=128)...", "=" * 60, sep="\n")

# The base run uses a block as long as the 128-token sequence set in train_run.
# The extra overrides are appended after train_run's defaults on the command line.
bd3lm_base_ckpt = train_run(
    run_name="bd3lm_base_len128",
    algo="bd3lm",
    block_size=128,
    extra_overrides=[
        "training.resample=false",
        "algo.var_min=false",
        "trainer.val_check_interval=10",
    ],
)
print(f"✓ BD3-LM base checkpoint: {bd3lm_base_ckpt}")

Training BD3-LM BASE (block_size=128)...

$ /usr/bin/python3 -u bd3lms/main.py mode=train data=lm1b-wrap data.cache_dir=/content/bd3lms/data data.streaming=true data.max_train_samples=3000 model=tiny model.length=128 model.attn_backend=sdpa algo=bd3lm trainer.accelerator=gpu trainer.devices=1 trainer.num_nodes=1 trainer.precision=16-mixed trainer.num_sanity_val_steps=0 trainer.log_every_n_steps=10 trainer.val_check_interval=50 trainer.max_steps=1500 data.max_valid_samples=100 data.max_test_samples=100 checkpointing.save_dir=/content/repro_runs/bd3lm_base_len128 checkpointing.resume_from_ckpt=false wandb=null loader.global_batch_size=8 loader.eval_global_batch_size=8 loader.batch_size=8 loader.eval_batch_size=8 loader.num_workers=2 trainer.accumulate_grad_batches=1 block_size=128 training.resample=false algo.var_min=false trainer.val_check_interval=10
um=0]

Validation: |          | 0/? [00:00<?, ?it/s][A

Validation:   0%|          | 0/3 [00:00<?, ?it/s][A

Validation DataLoader 0:  

## Step 2: Table 8 - Noise Schedule Ablation

**ΚΡΙΣΙΜΟ:** Χρησιμοποιούμε `training.sampling_eps_min` και `training.sampling_eps_max`
για να ελέγξουμε το noise schedule κατά το training!

| Schedule | sampling_eps_min | sampling_eps_max |
|----------|------------------|------------------|
| Linear U[0,1] | 0.001 | 1.0 |
| Clipped U[0.3,0.8] | 0.3 | 0.8 |
| Clipped U[0.45,0.95] | 0.45 | 0.95 |

In [None]:
# Cell 6: Table 8 Experiments - CORRECTED

# Each entry pairs a display name with the training overrides that set the
# sampling range (`training.sampling_eps_min` / `training.sampling_eps_max`).
noise_schedules = [
    ("Linear U[0,1]",
     ["training.sampling_eps_min=0.001", "training.sampling_eps_max=1.0"]),
    ("Clipped U[0.3,0.8]",
     ["training.sampling_eps_min=0.3", "training.sampling_eps_max=0.8"]),
    ("Clipped U[0.45,0.95]",
     ["training.sampling_eps_min=0.45", "training.sampling_eps_max=0.95"]),
]

# Block sizes (L') to fine-tune and evaluate at.
block_sizes = [4, 16]

# One dict per (block size, schedule) pair; reset on every execution of this cell.
results_table8 = []

for Lprime in block_sizes:
    print("\n" + "=" * 60, f"EXPERIMENTS FOR L' = {Lprime}", "=" * 60, sep="\n")

    for label, eps_overrides in noise_schedules:
        print(f"\n--- {label} ---")

        # Directory-safe run name: drop the brackets, turn commas and spaces into "_".
        slug = label.translate(str.maketrans({"[": None, "]": None, ",": "_", " ": "_"}))
        run_name = f"bd3lm_schedule_{slug}_Lp{Lprime}"

        # Fine-tune from the base checkpoint; schedule overrides go last.
        ckpt_path = train_run(
            run_name=run_name,
            algo="bd3lm",
            block_size=Lprime,
            from_pretrained=bd3lm_base_ckpt,
            extra_overrides=[
                "training.resample=true",
                "algo.var_min=false",
                "trainer.val_check_interval=10",
                *eps_overrides,
            ],
        )

        # Evaluation passes no schedule overrides, so every checkpoint is scored
        # with the same evaluation settings.
        val_ppl = eval_run(
            algo="bd3lm",
            checkpoint_path=ckpt_path,
            block_size=Lprime,
            extra_overrides=["algo.var_min=false"],
        )

        record = {"block_size": Lprime, "schedule": label, "ppl": val_ppl}
        results_table8.append(record)
        print(f"✓ {label} (L'={Lprime}): PPL = {val_ppl:.2f}")

print("\n" + "=" * 60, "TABLE 8 EXPERIMENTS COMPLETE!", "=" * 60, sep="\n")


EXPERIMENTS FOR L' = 4

--- Linear U[0,1] ---

$ /usr/bin/python3 -u bd3lms/main.py mode=train data=lm1b-wrap data.cache_dir=/content/bd3lms/data data.streaming=true data.max_train_samples=3000 model=tiny model.length=128 model.attn_backend=sdpa algo=bd3lm trainer.accelerator=gpu trainer.devices=1 trainer.num_nodes=1 trainer.precision=16-mixed trainer.num_sanity_val_steps=0 trainer.log_every_n_steps=10 trainer.val_check_interval=50 trainer.max_steps=1500 data.max_valid_samples=100 data.max_test_samples=100 checkpointing.save_dir=/content/repro_runs/bd3lm_schedule_Linear_U0_1_Lp4 checkpointing.resume_from_ckpt=false wandb=null loader.global_batch_size=8 loader.eval_global_batch_size=8 loader.batch_size=8 loader.eval_batch_size=8 loader.num_workers=2 trainer.accumulate_grad_batches=1 block_size=4 training.from_pretrained=/content/repro_runs/bd3lm_base_len128/checkpoints/last.ckpt training.resample=true algo.var_min=false trainer.val_check_interval=10 training.sampling_eps_min=0.001 trai

RuntimeError: Command failed with return code 1

In [None]:
# Cell 7: Display Results

def print_table(rows):
    """Print a list of dict rows to stdout as an ASCII grid.

    Column names and their order come from the keys of the first row. Every
    value is converted with ``str``; a key missing from a later row renders as
    an empty cell. Each column is as wide as its longest cell or its header.

    Args:
        rows: Sequence of dicts. An empty sequence prints a placeholder message.
    """
    if not rows:
        print("No data to display.")
        return

    headers = list(rows[0].keys())
    body = [[str(record.get(name, "")) for name in headers] for record in rows]
    col_widths = [
        max(len(name), max(len(cells[idx]) for cells in body))
        for idx, name in enumerate(headers)
    ]

    # Horizontal rule: one dash per character of width plus the two padding spaces.
    rule = "+" + "+".join("-" * (width + 2) for width in col_widths) + "+"

    def render(cells):
        padded = (cell.ljust(width) for cell, width in zip(cells, col_widths))
        return "| " + " | ".join(padded) + " |"

    print(rule)
    print(render(headers))
    print(rule)
    for cells in body:
        print(render(cells))
    print(rule)

print("\n" + "=" * 70,
      "TABLE 8: EFFECT OF NOISE SCHEDULE ON LIKELIHOOD ESTIMATION",
      "=" * 70,
      sep="\n")

# L' = 4: keep both the unsorted subset and a copy sorted by ascending PPL;
# the pattern-verification cell reads both names.
print("\n--- L' = 4 (smaller block → heavier masking better) ---")
results_Lp4 = list(filter(lambda rec: rec["block_size"] == 4, results_table8))
results_Lp4_sorted = list(results_Lp4)
results_Lp4_sorted.sort(key=lambda rec: rec["ppl"])
print_table(results_Lp4_sorted)
if results_Lp4_sorted:
    # After the ascending sort, the first row has the lowest PPL.
    print(f"Best for L'=4: {results_Lp4_sorted[0]['schedule']}")

# L' = 16: same treatment.
print("\n--- L' = 16 (larger block → lighter masking better) ---")
results_Lp16 = list(filter(lambda rec: rec["block_size"] == 16, results_table8))
results_Lp16_sorted = list(results_Lp16)
results_Lp16_sorted.sort(key=lambda rec: rec["ppl"])
print_table(results_Lp16_sorted)
if results_Lp16_sorted:
    print(f"Best for L'=16: {results_Lp16_sorted[0]['schedule']}")

## Paper Results (Table 8) για Σύγκριση

| Noise Schedule | L'=4 PPL | L'=16 PPL |
|----------------|----------|----------|
| **Clipped U[0.45,0.95]** | **29.21** | 31.42 |
| Clipped U[0.3,0.8] | 29.38 | **31.12** |
| Linear U[0,1] | 30.18 | 31.72 |

**Expected Pattern (lower PPL is better):**
- L'=4: Clipped U[0.45,0.95] < Clipped U[0.3,0.8] < Linear (heavier masking better)
- L'=16: Clipped U[0.3,0.8] < Clipped U[0.45,0.95] < Linear (lighter masking better)

**Key Insight:** Avoiding extreme mask rates reduces training variance!

In [None]:
# Cell 8: Verify Pattern Match

def _first_ppl(rows, marker):
    """Return the 'ppl' of the first row whose schedule name contains marker, else None."""
    for row in rows:
        if marker in row['schedule']:
            return row['ppl']
    return None


print("\n" + "=" * 70, "PATTERN VERIFICATION", "=" * 70, sep="\n")

# L' = 4: only checked when all three schedules produced a result.
if len(results_Lp4_sorted) >= 3:
    linear_4 = _first_ppl(results_Lp4, 'Linear')
    clip_03_4 = _first_ppl(results_Lp4, '0.3')
    clip_045_4 = _first_ppl(results_Lp4, '0.45')

    print("\nL'=4:")
    print(f"  Linear U[0,1]:       {linear_4:.2f}")
    print(f"  Clipped U[0.3,0.8]:  {clip_03_4:.2f}")
    print(f"  Clipped U[0.45,0.95]: {clip_045_4:.2f}")

    # Only the Clipped U[0.45,0.95] vs. Linear relation is tested here.
    print(
        "  ✅ Clipped U[0.45,0.95] < Linear (CORRECT!)"
        if clip_045_4 < linear_4
        else "  ⚠️ Pattern not as expected (may need more training)"
    )

# L' = 16: same guard; the relation tested is Clipped U[0.3,0.8] vs. Linear.
if len(results_Lp16_sorted) >= 3:
    linear_16 = _first_ppl(results_Lp16, 'Linear')
    clip_03_16 = _first_ppl(results_Lp16, '0.3')
    clip_045_16 = _first_ppl(results_Lp16, '0.45')

    print("\nL'=16:")
    print(f"  Linear U[0,1]:       {linear_16:.2f}")
    print(f"  Clipped U[0.3,0.8]:  {clip_03_16:.2f}")
    print(f"  Clipped U[0.45,0.95]: {clip_045_16:.2f}")

    print(
        "  ✅ Clipped U[0.3,0.8] < Linear (CORRECT!)"
        if clip_03_16 < linear_16
        else "  ⚠️ Pattern not as expected (may need more training)"
    )

print("\n" + "=" * 70, "REPRODUCTION COMPLETE!", "=" * 70, sep="\n")