## Anchored Diffusion Inference Notebook
Use this notebook to load a trained ADLM checkpoint and generate text samples or compute evaluation metrics on a few samples.

### Prerequisites
- The working directory is set to the repository root.
- A trained checkpoint is available on disk.
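
Both prerequisites can be checked programmatically before running anything expensive. The sketch below is illustrative and not part of the repository: `check_prerequisites` is a hypothetical helper, and it assumes the repository root is recognizable by a `configs` directory, the same convention the configuration cell in this notebook relies on.

```python
from pathlib import Path


def check_prerequisites(repo_root: Path, checkpoint: Path) -> list[str]:
    """Return a list of human-readable problems; an empty list means all good.

    Hypothetical helper: assumes the repository root contains a `configs`
    directory and that the checkpoint is a regular file on disk.
    """
    problems = []
    if not (repo_root / "configs").is_dir():
        problems.append(f"No 'configs' directory under {repo_root}")
    if not checkpoint.is_file():
        problems.append(f"Checkpoint not found: {checkpoint}")
    return problems
```

Calling `check_prerequisites(Path.cwd(), Path('ckpts/adlm-large.ckpt'))` and printing any returned messages gives a quick sanity check before the heavier cells run.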

Adjust the parameters in the next cell to point to your checkpoint and sampling configuration.

In [1]:
from pathlib import Path
import os

# --- User configuration ----------------------------------------------------
# Resolve the repository root from ADLM_PROJECT_ROOT, or by walking up from the current directory.
ENV_ROOT = os.environ.get('ADLM_PROJECT_ROOT')
start_path = Path(ENV_ROOT).expanduser().resolve() if ENV_ROOT else Path.cwd().resolve()


def find_repo_root(path: Path) -> Path:
    """Return the first directory at or above `path` that contains `configs/`."""
    for candidate in [path, *path.parents]:
        if (candidate / 'configs').is_dir():
            return candidate
    raise FileNotFoundError(
        f'Could not locate a Hydra config directory starting from {path}. '
        'Set ADLM_PROJECT_ROOT or edit PROJECT_ROOT below.'
    )


PROJECT_ROOT = find_repo_root(start_path)
CONFIG_DIR = PROJECT_ROOT / 'configs'
CHECKPOINT_PATH = PROJECT_ROOT / 'ckpts/adlm-large.ckpt'
# Supported checkpoints: adlm-medium.ckpt (262B tokens), adlm-large.ckpt (524B tokens)
NUM_SAMPLE_BATCHES = 10  # Number of batches to sample in this notebook
OUTPUT_JSON = PROJECT_ROOT / 'outputs/notebook_adlm_samples.json'

# Hydra/Lightning overrides mirroring the command-line evaluation settings
OVERRIDES = [
    'mode=sample_eval',
    'loader.batch_size=1',
    'loader.eval_batch_size=1',
    'eval.perplexity_batch_size=1',
    'data=openwebtext-split',
    'model=small',
    'parameterization=subs',
    'backbone=dit',
    'model.length=1024',
    'time_conditioning=false',
    '+wandb.offline=true',
    'T=0',
    'sampling.steps=4096',
    'sampling.nucleus_p=0.9',
    'sampling.sampler=remdm-loop',
    'sampling.eta=0.02',
    'sampling.t_on=0.55',
    'sampling.t_off=0.05',
    'sampling.alpha_on=0.9',
    f'sampling.num_sample_batches={NUM_SAMPLE_BATCHES}',
    f'sampling.generated_seqs_path={OUTPUT_JSON.as_posix()}',
    f'eval.checkpoint_path={CHECKPOINT_PATH.as_posix()}',
]

# Ensure we operate from the project root so Hydra finds configs
os.chdir(PROJECT_ROOT)
os.environ.setdefault('HYDRA_FULL_ERROR', '1')
print('Working directory set to', Path.cwd())
print('Checkpoint path:', CHECKPOINT_PATH)

Working directory set to /hdd1/lr/discrete/ADLM
Checkpoint path: /hdd1/lr/discrete/ADLM/ckpts/adlm-large.ckpt
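
Each entry in `OVERRIDES` uses Hydra's `key=value` override syntax, where a leading `+` adds a key that is not in the base config. As a rough aid for spotting typos or accidental duplicates before composing the config, the sketch below splits such strings into a dictionary. `parse_overrides` is a hypothetical helper; it handles only the plain and `+`-prefixed forms used in this notebook, not Hydra's full override grammar.

```python
def parse_overrides(overrides: list[str]) -> dict[str, str]:
    """Split simple `key=value` overrides into a dict, rejecting duplicates.

    Illustrative only: sweeps, deletions with `~`, and other parts of
    Hydra's override grammar are not handled.
    """
    parsed = {}
    for item in overrides:
        key, sep, value = item.partition("=")
        if not sep:
            raise ValueError(f"Not a key=value override: {item!r}")
        key = key.lstrip("+")  # '+key' and '++key' both name the same config path
        if key in parsed:
            raise ValueError(f"Duplicate override for {key!r}")
        parsed[key] = value
    return parsed


example = parse_overrides(
    ["mode=sample_eval", "+wandb.offline=true", "sampling.steps=4096"]
)
# Values stay strings here; Hydra performs the actual type conversion.
print(example)
```

Passing the notebook's own `OVERRIDES` list to this helper would raise if the same key were overridden twice.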


In [2]:
import logging
from hydra import compose, initialize
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf

import dataloader
import utils
from adlm_main import _load_from_checkpoint, generate_samples

# Reset Hydra to allow multiple executions in the same notebook
GlobalHydra.instance().clear()

config_dir = PROJECT_ROOT / 'configs'
if not config_dir.is_dir():
    raise FileNotFoundError(
        f"Hydra config directory not found at {config_dir}. "
        "Run the configuration cell (In [1]) to set PROJECT_ROOT correctly or update the path above.")

notebook_dir = PROJECT_ROOT / 'notebooks'
config_path_rel = os.path.relpath(config_dir, start=notebook_dir)
with initialize(version_base=None, config_path=config_path_rel):
    cfg = compose(config_name='config_adlm', overrides=OVERRIDES)

print('Mode:', cfg.mode)
print('Sampling batches:', cfg.sampling.num_sample_batches)
print('Checkpoint:', cfg.eval.checkpoint_path)
print('Generated sequences will be saved to:', cfg.sampling.generated_seqs_path)

# Optional: inspect a compact view of the config
display_dict = OmegaConf.to_container(cfg, resolve=True)
{k: display_dict[k] for k in ['mode', 'data', 'model', 'parameterization', 'sampling']}

Mode: sample_eval
Sampling batches: 10
Checkpoint: /hdd1/lr/discrete/ADLM/ckpts/adlm-large.ckpt
Generated sequences will be saved to: /hdd1/lr/discrete/ADLM/outputs/notebook_adlm_samples.json


{'mode': 'sample_eval',
 'data': {'train': 'openwebtext-train',
  'valid': 'openwebtext-valid',
  'tokenizer_name_or_path': 'gpt2',
  'cache_dir': '/hdd1/lr/data',
  'wrap': True,
  'streaming': False},
 'model': {'name': 'small',
  'type': 'ddit',
  'hidden_size': 768,
  'cond_dim': 128,
  'length': 1024,
  'n_blocks': 12,
  'n_heads': 12,
  'scale_by_sigma': True,
  'dropout': 0.1,
  'tie_word_embeddings': False},
 'parameterization': 'subs',
 'sampling': {'predictor': 'ddpm_cache',
  'steps': 4096,
  'noise_removal': True,
  'num_sample_batches': 10,
  'num_sample_log': 2,
  'semi_ar': False,
  'stride_length': 1,
  'num_strides': 1,
  'generated_seqs_path': '/hdd1/lr/discrete/ADLM/outputs/notebook_adlm_samples.json',
  'nucleus_p': 0.9,
  'eta': 0.02,
  'sampler': 'remdm-loop',
  't_on': 0.55,
  't_off': 0.05,
  'alpha_on': 0.9,
  'dfm': False}}
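
The nested view above uses the same dotted paths as the override strings (for example `sampling.steps`). To compare the composed config against the overrides key by key, a plain dictionary such as the one returned by `OmegaConf.to_container` can be flattened into dotted keys. The helper name `flatten` is hypothetical and this is only a minimal sketch; lists are kept as leaf values.

```python
def flatten(config: dict, prefix: str = "") -> dict:
    """Flatten nested dicts into {'a.b.c': value} form."""
    flat = {}
    for key, value in config.items():
        path = f"{prefix}.{key}" if prefix else str(key)
        if isinstance(value, dict):
            flat.update(flatten(value, path))
        else:
            flat[path] = value
    return flat


# A small excerpt of the config printed above.
excerpt = {
    "mode": "sample_eval",
    "sampling": {"steps": 4096, "nucleus_p": 0.9, "sampler": "remdm-loop"},
}
flat = flatten(excerpt)
print(flat["sampling.steps"])
```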

In [3]:
# Prepare tokenizer, logger, and model
import torch
logger = logging.getLogger('adlm_inference')
if not logger.handlers:
    handler = logging.StreamHandler()
    handler.setLevel(logging.INFO)
    logger.addHandler(handler)
logger.setLevel(logging.INFO)
logger.propagate = False  # avoid duplicate lines if a root-logger handler is also configured

tokenizer = dataloader.get_tokenizer(cfg)
logger.info('Tokenizer loaded. Vocabulary size: %d', tokenizer.vocab_size)

device = 'cuda' if torch.cuda.is_available() and torch.cuda.device_count() > 0 else 'cpu'
logger.info('Using device: %s', device)

model = _load_from_checkpoint(config=cfg, tokenizer=tokenizer)
model = model.to(device)
model.eval()
logger.info('Model loaded and moved to %s', device)

Tokenizer loaded. Vocabulary size: 50257
Using device: cuda
Model loaded and moved to cuda
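
Re-running a logging setup cell in a notebook can produce duplicated log lines, either because handlers accumulate on the named logger or because records also propagate to a handler on the root logger. The sketch below shows one idempotent setup with propagation disabled. `get_notebook_logger` is a hypothetical helper, not part of the repository, and whether propagation is the cause in a given environment is an assumption to verify.

```python
import io
import logging


def get_notebook_logger(name: str, stream=None) -> logging.Logger:
    """Return a logger with exactly one stream handler, safe to call repeatedly."""
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.propagate = False  # keep records away from any root-logger handlers
    if not logger.handlers:
        handler = logging.StreamHandler(stream)  # stream=None means sys.stderr
        handler.setLevel(logging.INFO)
        logger.addHandler(handler)
    return logger


# Calling the helper twice must not double the output.
buffer = io.StringIO()
log = get_notebook_logger("demo_inference", buffer)
log = get_notebook_logger("demo_inference", buffer)
log.info("Using device: %s", "cpu")
lines = buffer.getvalue().splitlines()
```

After this runs, `lines` holds a single entry even though the helper was called twice.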


In [4]:
# Run sampling/inference
if 'cfg' not in globals():
    raise RuntimeError("Configuration not initialized. Run the Hydra compose cell (In [2]) before sampling.")
if 'logger' not in globals():
    raise RuntimeError("Logger not initialized. Run the setup cell (In [3]) before sampling.")
if 'tokenizer' not in globals():
    raise RuntimeError("Tokenizer not initialized. Run the setup cell (In [3]) before sampling.")

# Create the output directory up front, in case generate_samples writes the JSON itself
OUTPUT_JSON.parent.mkdir(parents=True, exist_ok=True)

samples, gen_ppl, entropies = generate_samples(cfg, logger, tokenizer)

logger.info('Generated %d samples', len(samples))
logger.info('Average entropy: %.4f', sum(entropies) / max(len(entropies), 1))
logger.info('Generative perplexity: %.3f', gen_ppl)

# Preview a few sample strings
for idx, text in enumerate(samples[:5]):
    print('=' * 80)
    print(f'Sample {idx + 1}')
    print(text)

print('Results JSON saved to:', cfg.sampling.generated_seqs_path)

Generating samples.

Starting Batch  0
100%|██████████| 4096/4096 [01:30<00:00, 45.15it/s]
Starting Batch  1
100%|██████████| 4096/4096 [01:30<00:00, 45.18it/s]
Starting Batch  2
100%|██████████| 4096/4096 [01:30<00:00, 45.29it/s]
Starting Batch  3
100%|██████████| 4096/4096 [01:30<00:00, 45.08it/s]
Starting Batch  4
100%|██████████| 4096/4096 [01:30<00:00, 45.06it/s]
Starting Batch  5
100%|██████████| 4096/4096 [01:30<00:00, 45.31it/s]
Starting Batch  6
100%|██████████| 4096/4096 [01:30<00:00, 45.40it/s]
Starting Batch  7
100%|██████████| 4096/4096 [01:30<00:00, 45.34it/s]
Starting Batch  8
100%|██████████| 4096/4096 [01:30<00:00, 45.17it/s]
Starting Batch  9
100%|██████████| 4096/4096 [01:30<00:00, 45.04it/s]

Generated 10 samples
Average entropy: 5.2350
Generative perplexity: 15.828
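
As a rough guide to reading the two summary numbers, the sketch below assumes (this notebook does not confirm it) that entropy is the Shannon entropy, in nats, of the unigram token distribution within one sample, and that generative perplexity is the exponential of the mean per-token negative log-likelihood under an external evaluation model. The function names `unigram_entropy` and `perplexity` are hypothetical, and the repository's implementation may differ.

```python
import math
from collections import Counter


def unigram_entropy(token_ids: list[int]) -> float:
    """Shannon entropy (nats) of the token frequency distribution in one sample."""
    if not token_ids:
        raise ValueError("token_ids must be non-empty")
    counts = Counter(token_ids)
    total = len(token_ids)
    return -sum((c / total) * math.log(c / total) for c in counts.values())


def perplexity(token_nlls: list[float]) -> float:
    """exp(mean negative log-likelihood), with NLLs given in nats per token."""
    if not token_nlls:
        raise ValueError("token_nlls must be non-empty")
    return math.exp(sum(token_nlls) / len(token_nlls))


# Two tokens used equally often give an entropy of ln(2).
two_way = unigram_entropy([1, 1, 2, 2])
# A constant per-token NLL of ln(4) corresponds to a perplexity of 4.
ppl = perplexity([math.log(4)] * 3)
```

Under this definition, a 1024-token sample in which every token is distinct would reach the ceiling of ln(1024) ≈ 6.93 nats, so lower values indicate more token repetition within a sample.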


Sample 1
<|endoftext|> It’s smart. It’s much better than having a government, which is very monopolistic. You can find entrepreneurs who solve problems by giving different people problems. If you’re an entrepreneur, you can do all kinds of things faster than if you had a government.

Do you have a vision of what a decentralized government could look like?

My vision is that only a small minority of people ever want to solve a problem, and if they do, they have an overwhelming chance of being able to solve it. The problem with that is because our entire society is wasteful and the whole world is, like, a small minority of all that cares. And when a small minority tries to solve a problem, they usually get their ass kicked and aren’t able to solve it on their own.

What do you think of the kind of global problems that we’re facing now before we have a decentralized government like what we have today?

Who’s heard of a decentralized government? I’ve never even heard of. There are so many 