A production-ready implementation of a discrete diffusion model for text generation. Unlike autoregressive models (like GPT), this model uses bidirectional attention and learns to denoise masked tokens through iterative refinement.
Check out the fine-tuned version: https://github.com/nihilisticneuralnet/HinDiffusionLM
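Conceptually, generation starts from a sequence full of mask tokens and repeatedly lets the model re-predict every position, committing the most confident guesses first. The sketch below illustrates the idea only: the `model(seq)` forward signature and the fixed per-step unmasking budget are assumptions, and the repo's actual `generate()` (shown further down) gates commits with a `confidence_threshold` instead.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def denoise(model, seq, mask_id, steps=8, temperature=1.0):
    # Conceptual sketch only (assumes model(seq) -> (B, T, vocab) logits
    # and batch size 1); seq holds mask_id at the positions to fill.
    for step in range(steps):
        masked = seq == mask_id                          # positions still hidden
        if not masked.any():
            break
        logits = model(seq) / temperature                # bidirectional: sees all positions
        conf, pred = F.softmax(logits, dim=-1).max(-1)   # best guess + its confidence
        conf = conf.masked_fill(~masked, -1.0)           # only compete over masked slots
        k = max(1, int(masked.sum()) // (steps - step))  # unmask a fraction per step
        idx = conf.view(-1).topk(k).indices              # most confident masked slots
        seq.view(-1)[idx] = pred.view(-1)[idx]           # commit those tokens
    return seq
```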
```bash
pip install torch numpy matplotlib
```

```bash
# Prepare your data
echo "Your training text here..." > data.txt

# Train the model
python diffusion_lm.py --train

# Generate text
python diffusion_lm.py
```

For finer control, training can also be driven from Python:

```python
from diffusion_lm import (
    MaskedDiffusionLM,
    ModelConfig,
    TrainingConfig,
    Trainer,
    load_data
)

# Load data
data, itos, stoi, mask_token_id = load_data('data.txt')
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

# Configure model
model_config = ModelConfig(
    vocab_size=len(itos),
    n_layer=6,        # Number of transformer blocks
    n_head=6,         # Number of attention heads
    n_embd=384,       # Embedding dimension
    block_size=256,   # Maximum sequence length
    dropout=0.1
)

# Configure training
train_config = TrainingConfig(
    batch_size=64,
    learning_rate=3e-4,
    max_iters=10000,
    warmup_iters=500,
    eval_interval=500
)

# Create and train model
model = MaskedDiffusionLM(model_config)
trainer = Trainer(model, train_config, mask_token_id, train_data, val_data, itos)
trainer.train()
```
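Under the hood, the Trainer presumably optimizes a masked-denoising objective: sample a masking rate, hide that fraction of each training sequence, and take cross-entropy only on the hidden positions. A minimal sketch of one such step (the exact rate schedule and loss weighting in `diffusion_lm.py` may differ):

```python
import torch
import torch.nn.functional as F

def masked_denoising_loss(model, x, mask_token_id):
    # Hypothetical training step for a masked diffusion LM;
    # x is a (B, T) batch of token IDs.
    B, T = x.shape
    rate = torch.rand(B, 1, device=x.device)             # per-sequence masking rate
    hidden = torch.rand(B, T, device=x.device) < rate    # which tokens to corrupt
    noisy = x.masked_fill(hidden, mask_token_id)         # replace with the mask token
    logits = model(noisy)                                # (B, T, vocab)
    return F.cross_entropy(logits[hidden], x[hidden])    # score hidden positions only
```

Once a checkpoint exists, generation is a separate call: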
```python
from diffusion_lm import MaskedDiffusionLM, GenerationConfig
import torch

# Load model
model = MaskedDiffusionLM.from_pretrained('checkpoints/best_model.pt')
model.eval()

# Prepare the prompt (token IDs, e.g. encoded with stoi from load_data)
prompt = torch.tensor([[1, 2, 3, 4]], device='cuda')

# Configure generation
gen_config = GenerationConfig(
    max_new_tokens=500,
    temperature=0.8,
    confidence_threshold=0.95,
    top_k=3
)

# Generate
output, stats = model.generate(prompt, gen_config, mask_token_id)
print(f"Generated in {stats['total_steps']} steps")
print(f"Average tokens per step: {stats['avg_tokens_per_step']:.2f}")from utils import ModelEvaluator
```python
from utils import ModelEvaluator

evaluator = ModelEvaluator(model, itos, device='cuda')

# Calculate perplexity
ppl = evaluator.calculate_perplexity(val_data, mask_token_id)
print(f"Perplexity: {ppl:.2f}")

# Analyze decoding trajectory
trajectory = evaluator.analyze_decoding_trajectory(
    prompt,
    max_new_tokens=500,
    mask_token_id=mask_token_id
)

# Visualize
evaluator.plot_decoding_trajectory(trajectory, 'trajectory.png')

# Test different masking rates
masking_results = evaluator.evaluate_masking_rates(
    val_data,
    mask_token_id,
    masking_rates=[0.1, 0.3, 0.5, 0.7, 0.9]
)
evaluator.plot_masking_rate_analysis(masking_results, 'masking_rates.png')
```
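Note that perplexity for a bidirectional model is not the left-to-right quantity reported for GPT-style models; a common substitute masks a fraction of the validation tokens and exponentiates the cross-entropy on those positions. A sketch of what `calculate_perplexity` plausibly computes (its exact protocol may differ):

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def masked_perplexity(model, x, mask_token_id, rate=0.15):
    # Hypothetical protocol: hide a fixed fraction of tokens in a
    # (B, T) batch and exponentiate the mean loss on those positions.
    hidden = torch.rand(x.shape, device=x.device) < rate
    noisy = x.masked_fill(hidden, mask_token_id)
    logits = model(noisy)                                # (B, T, vocab)
    return F.cross_entropy(logits[hidden], x[hidden]).exp().item()
```

`evaluate_masking_rates` presumably sweeps this masking rate over the listed values, showing how reconstruction quality degrades as more of the sequence is hidden.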