# Table 6 Reproduction - WORKING VERSION

## Fixes Applied:
1. **Config keys**: Added `+` prefix for non-existing keys
2. **Checkpoint conversion**: Convert `.ckpt` to HuggingFace format before `sample_eval`
3. **Length extraction**: Parse actual token counts, not `len(log_text)`

## Table 6 from paper:
| Model | Median # tokens | Max # tokens |
|-------|-----------------|-------------|
| SEDD | 1021 | **1024** (limited!) |
| BD3-LM L'=16 | 798 | **9982** |

In [1]:
# Clone YOUR fork
!cd /content && rm -rf bd3lms
!cd /content && git clone https://github.com/ntua-el21050/bd3lms.git

!mkdir -p /content/bd3lms/data
!mkdir -p /content/repro_runs
!mkdir -p /content/hf_checkpoints
!mkdir -p /content/sample_logs

Cloning into 'bd3lms'...
remote: Enumerating objects: 899, done.[K
remote: Counting objects: 100% (356/356), done.[K
remote: Compressing objects: 100% (144/144), done.[K
remote: Total 899 (delta 271), reused 261 (delta 211), pack-reused 543 (from 1)[K
Receiving objects: 100% (899/899), 3.41 MiB | 31.42 MiB/s, done.
Resolving deltas: 100% (565/565), done.


In [2]:
!pip install -q torchmetrics==1.6.2 datasets==3.3.2 einops==0.8.1 \
    hydra-core==1.3.2 lightning==2.5.0.post0 transformers==4.49.0 \
    huggingface_hub fsspec==2024.2.0 omegaconf==2.3.0

In [3]:
import subprocess
import re
import os
import shutil
import sys
import torch
import json
import numpy as np
from pathlib import Path

sys.path.insert(0, '/content/bd3lms')

def run_main(overrides, timeout=None):
    """Run main.py with overrides."""
    env = dict(os.environ)
    env.setdefault("HYDRA_FULL_ERROR", "1")
    cmd = [sys.executable, "-u", "bd3lms/main.py", *overrides]
    print("\n$", " ".join(cmd[:8]), "...")
    proc = subprocess.run(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        timeout=timeout,
        check=False,
        env=env,
    )
    print(proc.stdout[-3000:])
    if proc.returncode != 0:
        raise RuntimeError(f"Command failed with return code {proc.returncode}")
    return proc.stdout


def _small_loader_overrides(batch_size=4, num_workers=2):
    return [
        f"loader.global_batch_size={batch_size}",
        f"loader.eval_global_batch_size={batch_size}",
        f"loader.batch_size={batch_size}",
        f"loader.eval_batch_size={batch_size}",
        f"loader.num_workers={num_workers}",
        "trainer.accumulate_grad_batches=1",
    ]

In [4]:
def train_run(run_name, algo, block_size=None, from_pretrained=None, max_steps=800, extra_overrides=None, model_length=1024):
    """Train a model."""
    save_dir = Path("/content/repro_runs") / run_name
    if save_dir.exists():
        shutil.rmtree(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    overrides = [
        "mode=train",
        "data=openwebtext-split",
        "++data.cache_dir=/content/bd3lms/data",       # ++ prefix!
        "++data.streaming=true",                        # ++ prefix!
        "++data.max_train_samples=1500",                # ++ prefix!
        "++data.max_valid_samples=100",                 # ++ prefix!
        "++data.max_test_samples=100",                  # ++ prefix!
        "model=tiny",
        f"model.length={model_length}",
        "model.attn_backend=sdpa",
        f"algo={algo}",
        "trainer.accelerator=gpu",
        "trainer.devices=1",
        "trainer.num_nodes=1",
        "trainer.precision=16-mixed",
        "trainer.num_sanity_val_steps=0",
        "trainer.log_every_n_steps=20",
        "trainer.val_check_interval=50",
        f"trainer.max_steps={max_steps}",
        f"checkpointing.save_dir={save_dir}",
        "checkpointing.resume_from_ckpt=false",
        "wandb=null",
    ]
    overrides.extend(_small_loader_overrides(batch_size=4, num_workers=2))

    if block_size is not None:
        overrides.append(f"block_size={block_size}")
    if from_pretrained is not None:
        overrides.append(f"training.from_pretrained={from_pretrained}")
        overrides.append("training.resample=true")
    if extra_overrides:
        overrides.extend(extra_overrides)

    _ = run_main(overrides)

    # Find checkpoint
    ckpt_dir = save_dir / "checkpoints"
    for name in ["best.ckpt", "last.ckpt"]:
        ckpt = ckpt_dir / name
        if ckpt.exists():
            print(f"✓ Checkpoint: {ckpt}")
            return str(ckpt)

    # List what we have
    if ckpt_dir.exists():
        print(f"Available checkpoints: {list(ckpt_dir.glob('*.ckpt'))}")
    raise FileNotFoundError(f"No checkpoint in {ckpt_dir}")

In [5]:
def convert_ckpt_to_hf(ckpt_path, output_dir, block_size):
    """
    Convert Lightning .ckpt to HuggingFace format.
    """
    import transformers

    print(f"\n{'='*50}")
    print(f"Converting {ckpt_path}")
    print(f"To: {output_dir}")
    print(f"{'='*50}")

    # Clean output dir
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir, exist_ok=True)

    # Step 1: Load reference model FROM HUGGINGFACE (with weights)
    print("\n[1/4] Loading reference model from HuggingFace...")
    ref_model_id = f"kuleshov-group/bd3lm-owt-block_size{block_size}"

    try:
        model = transformers.AutoModelForMaskedLM.from_pretrained(
            ref_model_id,
            trust_remote_code=True,
            torch_dtype=torch.float32
        )
    except Exception as e:
        print(f"Warning: {e}")
        print("Trying block_size 16 as fallback...")
        model = transformers.AutoModelForMaskedLM.from_pretrained(
            "kuleshov-group/bd3lm-owt-block_size16",
            trust_remote_code=True,
            torch_dtype=torch.float32
        )

    # Step 2: Load Lightning checkpoint
    print("\n[2/4] Loading Lightning checkpoint...")
    checkpoint = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    state_dict = checkpoint.get('state_dict', checkpoint)
    print(f"   Found {len(state_dict)} parameters")

    # Step 3: Clean key names
    print("\n[3/4] Cleaning state dict keys...")
    cleaned = {}
    for k, v in state_dict.items():
        new_k = k
        for prefix in ['backbone.', 'diffusion.backbone.', 'model.', 'module.']:
            if new_k.startswith(prefix):
                new_k = new_k[len(prefix):]
        cleaned[new_k] = v
    print(f"   Cleaned keys: {list(cleaned.keys())[:3]}")

    # Step 4: Apply our weights to the model
    print("\n[4/4] Applying our weights to model...")
    missing, unexpected = model.load_state_dict(cleaned, strict=False)
    print(f"   Missing: {len(missing)}, Unexpected: {len(unexpected)}")

    # Save with our weights
    model.save_pretrained(output_dir)

    # Tokenizer
    tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
    tokenizer.save_pretrained(output_dir)

    print(f"\n✓ Saved to {output_dir}")
    return output_dir

In [6]:
def extract_length_stats(log_text, logdir=None):
    """
    Extract actual token length statistics from CSV.
    """
    lengths = []

    # Read CSV file (BD3-LM saves as CSV without header)
    if logdir and os.path.exists(logdir) and os.path.isfile(logdir):
        try:
            import pandas as pd
            df = pd.read_csv(logdir, header=None)
            # Column 1 is length
            lengths.extend(df[1].dropna().astype(int).tolist())
            print(f"Found {len(lengths)} samples in CSV")
        except Exception as e:
            print(f"CSV error: {e}")

    if lengths:
        return {
            'count': len(lengths),
            'median': int(np.median(lengths)),
            'max': int(np.max(lengths)),
            'mean': round(np.mean(lengths), 1),
        }
    return {'count': 0, 'median': None, 'max': None, 'mean': None}

def eval_run(algo, hf_checkpoint_path, block_size=None, num_samples=50, extra_overrides=None, model_length=1024):
    """
    Run sample_eval with HuggingFace checkpoint.
    Run multiple times to collect samples (workaround for variable-length bug).
    """
    import pandas as pd

    logfile = f"/content/sample_logs/varlen_{algo}_bs{block_size}"

    # Clean up
    if os.path.exists(logfile):
        os.remove(logfile)

    all_lengths = []

    # Run multiple times with 1 sample each (workaround)
    for i in range(num_samples):
        print(f"\rGenerating sample {i+1}/{num_samples}...", end="")

        overrides = [
            "mode=sample_eval",
            "data=openwebtext-split",
            "sampling.num_sample_batches=1",  # 1 at a time!
            "++data.cache_dir=/content/bd3lms/data",
            "++data.streaming=true",
            "++data.max_test_samples=1",
            "model=tiny",
            f"model.length={model_length}",
            "model.attn_backend=sdpa",
            f"algo={algo}",
            "algo.backbone=hf_dit",
            "algo.T=5000",
            f"eval.checkpoint_path={hf_checkpoint_path}",
            "sampling.var_length=true",
            "sampling.nucleus_p=0.9",
            "sampling.kv_cache=true",
            f"sampling.logdir={logfile}",
            f"seed={42+i}",  # Different seed each time
            "trainer.accelerator=gpu",
            "trainer.devices=1",
            "trainer.precision=16-mixed",
            "wandb=null",
            f"block_size={block_size}",
            "loader.eval_batch_size=1",
        ]

        try:
            run_main(overrides)
        except:
            pass  # Continue on error

    print("\nDone!")

    # Read results from CSV
    if os.path.exists(logfile) and os.path.isfile(logfile):
        try:
            df = pd.read_csv(logfile, header=None)
            all_lengths = df[1].dropna().astype(int).tolist()
        except:
            pass

    if all_lengths:
        stats = {
            'count': len(all_lengths),
            'median': int(np.median(all_lengths)),
            'max': int(np.max(all_lengths)),
            'mean': round(np.mean(all_lengths), 1),
        }
    else:
        stats = {'count': 0, 'median': None, 'max': None, 'mean': None}

    print(f"\n=== Length Statistics ===")
    print(f"Samples: {stats['count']}")
    print(f"Median: {stats['median']} tokens")
    print(f"Max: {stats['max']} tokens")

    return stats

---
## RUN EXPERIMENTS
---

In [7]:
max_model_length = 4096 # Reduced from max_model_length (131000, as in OWT) to prevent OOM errors

In [10]:
results = []

# ========================================
# STEP 1: Train BD3-LM Base (L'=1024)
# ========================================
print("\n" + "="*60)
print("STEP 1: Training BD3-LM Base (block_size=1024)")
print("="*60)

bd3lm_base_ckpt = train_run(
    run_name="bd3lm_base_L1024",
    algo="bd3lm",
    block_size=1024,
    max_steps=800,
    extra_overrides=[
        "training.resample=false",
        "algo.var_min=false",
    ]
)
print(f"✓ Base: {bd3lm_base_ckpt}")


STEP 1: Training BD3-LM Base (block_size=1024)

$ /usr/bin/python3 -u bd3lms/main.py mode=train data=openwebtext-split ++data.cache_dir=/content/bd3lms/data ++data.streaming=true ++data.max_train_samples=1500 ...
k if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Epoch 1, global step 568: 'val/nll' was not in top 1
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Epoch 1, global step 618: 'val/nll' was not in top 1
huggingface/tokeniz

In [11]:
# ========================================
# STEP 2: Fine-tune with L'=16
# ========================================
print("\n" + "="*60)
print("STEP 2: Fine-tuning BD3-LM (block_size=16)")
print("="*60)

bd3lm_ft_ckpt = train_run(
    run_name="bd3lm_finetune_L16",
    algo="bd3lm",
    block_size=16,
    from_pretrained=bd3lm_base_ckpt,
    max_steps=500,
    extra_overrides=[
        "algo.var_min=false",
    ]
)
print(f"✓ Fine-tuned: {bd3lm_ft_ckpt}")


STEP 2: Fine-tuning BD3-LM (block_size=16)

$ /usr/bin/python3 -u bd3lms/main.py mode=train data=openwebtext-split ++data.cache_dir=/content/bd3lms/data ++data.streaming=true ++data.max_train_samples=1500 ...
an either:
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Epoch 0, global step 300: 'val/nll' was not in top 1
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Epoch 0, global step 35

In [12]:
# ========================================
# STEP 3: Convert to HuggingFace format
# THIS IS THE KEY FIX!
# ========================================
print("\n" + "="*60)
print("STEP 3: Converting to HuggingFace format")
print("="*60)

hf_checkpoint = convert_ckpt_to_hf(
    ckpt_path=bd3lm_ft_ckpt,
    output_dir="/content/hf_checkpoints/bd3lm_L16",
    block_size=16
)
print(f"✓ HF checkpoint: {hf_checkpoint}")


STEP 3: Converting to HuggingFace format

Converting /content/repro_runs/bd3lm_finetune_L16/checkpoints/best.ckpt
To: /content/hf_checkpoints/bd3lm_L16

[1/4] Loading reference model from HuggingFace...


Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).



[2/4] Loading Lightning checkpoint...
   Found 93 parameters

[3/4] Cleaning state dict keys...
   Cleaned keys: ['sampling_eps_min', 'sampling_eps_max', 'vocab_embed.embedding']

[4/4] Applying our weights to model...
   Missing: 131, Unexpected: 91

✓ Saved to /content/hf_checkpoints/bd3lm_L16
✓ HF checkpoint: /content/hf_checkpoints/bd3lm_L16


In [13]:
# ========================================
# STEP 4: Variable-Length Generation
# ========================================
print("\n" + "="*60)
print("STEP 4: Variable-Length Generation (Table 6)")
print("="*60)

bd3lm_stats = eval_run(
    algo="bd3lm",
    hf_checkpoint_path=hf_checkpoint,  # HF dir, NOT .ckpt!
    block_size=16,
    num_samples=50,
    model_length=max_model_length
)

results.append({
    "model": "BD3-LM L'=16 (ours)",
    "median_tokens": bd3lm_stats['median'],
    "max_tokens": bd3lm_stats['max'],
})


STEP 4: Variable-Length Generation (Table 6)
Generating sample 1/50...
$ /usr/bin/python3 -u bd3lms/main.py mode=sample_eval data=openwebtext-split sampling.num_sample_batches=1 ++data.cache_dir=/content/bd3lms/data ++data.streaming=true ...
 German intelligence over the past decade, Spiegel and Guardian said.\n\nThe security conference will center on social networks such as hacking, surveillance, threats, government-to-government hacking and the terror networks and security policy.<|endoftext|>']
Generative perplexity: tensor(33.0763, device='cuda:0')
Entropy: tensor(5.1024, device='cuda:0')
['<|endoftext|>CLOSE The NSA says President Barack Obama\'s efforts to cryptanalyze phones in terrorism cases is getting worse and could pose a danger. David Martin reports. He also commented on the NSA\'s response to leaked intelligence reports by the New York Times. Matt Kryger for USA TODAY\n\nThe top US military commander in Germany is planning to host one of world\'s biggest security and te

In [14]:
# Train Autoregressive LM for comparison
print("\n" + "="*60)
print("Training Autoregressive LM (block_size=16)")
print("="*60)

ar_lm_ckpt = train_run(
	run_name="ar_lm_L16",
	algo="ar",
	block_size=16,
	max_steps=800,
)
print(f"✓ AR LM: {ar_lm_ckpt}")
ar_lm_hf_checkpoint = convert_ckpt_to_hf(
	ckpt_path=ar_lm_ckpt,
	output_dir="/content/hf_checkpoints/ar_lm_L16",
	block_size=16
)

ar_lm_stats = eval_run(
	algo="ar",
	hf_checkpoint_path=ar_lm_hf_checkpoint,
	block_size=16,
	num_samples=50,
  model_length=max_model_length
)

results.append({
	"model": "Autoregressive LM L'=16 (baseline)",
	"median_tokens": ar_lm_stats['median'],
	"max_tokens": ar_lm_stats['max'],
})


Training Autoregressive LM (block_size=16)

$ /usr/bin/python3 -u bd3lms/main.py mode=train data=openwebtext-split ++data.cache_dir=/content/bd3lms/data ++data.streaming=true ++data.max_train_samples=1500 ...
rk if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Epoch 1, global step 618: 'val/nll' reached 7.91222 (best 7.91222), saving model to '/content/repro_runs/ar_lm_L16/checkpoints/best.ckpt' as top 1
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLEL

In [15]:
# Train SEDD for comparison
print("\n" + "="*60)
print("Training SEDD (block_size=16)")
print("="*60)

sedd_ckpt = train_run(
	run_name="sedd_L16",
	algo="sedd",
	block_size=16,
	max_steps=800,
	extra_overrides=[
		"training.resample=false",
		"algo.var_min=false",
    "algo.parameterization=sedd"
	],
)
print(f"✓ SEDD: {sedd_ckpt}")
sedd_hf_checkpoint = convert_ckpt_to_hf(
	ckpt_path=sedd_ckpt,
	output_dir="/content/hf_checkpoints/sedd_L16",
	block_size=16
)

sedd_stats = eval_run(
	algo="sedd",
	hf_checkpoint_path=sedd_hf_checkpoint,
	block_size=16,
	num_samples=50,
  extra_overrides=[
      "algo.parameterization=sedd"
    ],
)

results.append({
	"model": "SEDD L'=16 (baseline)",
	"median_tokens": sedd_stats['median'],
	"max_tokens": sedd_stats['max'],
})


Training SEDD (block_size=16)

$ /usr/bin/python3 -u bd3lms/main.py mode=train data=openwebtext-split ++data.cache_dir=/content/bd3lms/data ++data.streaming=true ++data.max_train_samples=1500 ...
k if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Epoch 1, global step 568: 'val/nll' was not in top 1
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Epoch 1, global step 618: 'val/nll' was not in top 1
huggingface/tokenizers: The current 

In [16]:
# ========================================
# RESULTS
# ========================================
import pandas as pd

print("\n" + "="*60)
print("TABLE 6 RESULTS")
print("="*60)

comparison = [
    {"model": "SEDD (paper)", "median_tokens": 1021, "max_tokens": 1024},
    {"model": "BD3-LM L'=16 (paper)", "median_tokens": 798, "max_tokens": 9982},
] + results

df = pd.DataFrame(comparison)
print(df.to_string(index=False))

print("\n" + "="*60)
print("KEY INSIGHT: SEDD limited to 1024, BD3-LM generates ~10x longer!")
print("="*60)


TABLE 6 RESULTS
                             model  median_tokens  max_tokens
                      SEDD (paper)         1021.0      1024.0
              BD3-LM L'=16 (paper)          798.0      9982.0
               BD3-LM L'=16 (ours)          959.0      4095.0
Autoregressive LM L'=16 (baseline)            NaN         NaN
             SEDD L'=16 (baseline)            NaN         NaN

KEY INSIGHT: SEDD limited to 1024, BD3-LM generates ~10x longer!
