Coherence-Routed Transformer: Adaptive attention routing using phase coherence metrics.
CoRT is a transformer architecture that uses phase coherence metrics from circular statistics to intelligently route tokens between expensive full attention and lightweight mixing operations. The result is faster training with better perplexity.
from cort import CoRTModel, CoRTConfig
# Create model
config = CoRTConfig(vocab_size=50257, d_model=512, n_layers=6)
model = CoRTModel(config)
# Forward pass
import torch
input_ids = torch.randint(0, 50257, (2, 128))
logits = model(input_ids) # [2, 128, 50257]
# Generate text
generated = model.generate(input_ids[:, :10], max_new_tokens=50)Not all tokens need full attention. In natural language, most tokens are contextually coherent - their relationships are predictable and don't require expensive O(n²) attention to resolve.
CoRT measures this coherence using R̄ (R-bar), the Mean Resultant Length from circular statistics:
R̄ = |mean(exp(i · φ))|
Where φ represents token phases derived from embeddings. R̄ ranges from 0 (uniform/scattered) to 1 (perfectly aligned).
Routing Decision:
- Low R̄ → Complex relationships → Full attention
- High R̄ → Aligned representations → Lightweight mixing
┌─────────────────────────────────────────────────────────────┐
│ CoRT Layer │
├─────────────────────────────────────────────────────────────┤
│ │
│ Input ──► LayerNorm ──► CoherentRouter │
│ │ │
│ ┌─────────┴─────────┐ │
│ ▼ ▼ │
│ Low Coherence High Coherence │
│ (need attention) (can be mixed) │
│ │ │ │
│ ▼ ▼ │
│ MultiHeadAttention CoherenceMixer │
│ │ │ │
│ └─────────┬─────────┘ │
│ ▼ │
│ Merge ──► LayerNorm ──► FFN │
│ │
│ ▼ │
│ Output │
│ │
└─────────────────────────────────────────────────────────────┘
Trained on WikiText-103 for 3 epochs with 48M parameter models:
| Model | Val Perplexity | Throughput | vs Standard |
|---|---|---|---|
| Standard Transformer | 1577.18 | 10,751 tok/s | baseline |
| CoRT | 2.38 | 10,239 tok/s | -99.8% PPL |
CoRT achieves 99.8% lower perplexity than a standard transformer with comparable throughput and only 8% more parameters.
pip install cort-transformerFor training utilities:
pip install cort-transformer[training]For development:
pip install cort-transformer[dev]from cort import compute_phase_coherence, compute_local_coherence
# Per-token coherence
hidden = torch.randn(2, 128, 512)
coherence = compute_phase_coherence(hidden, mode="token") # [2, 128]
# Local windowed coherence
local_coh = compute_local_coherence(hidden, window=8) # [2, 128]from cort import CoRTConfig
# Preset sizes
small = CoRTConfig.small() # ~48M params
medium = CoRTConfig.medium() # ~124M params
large = CoRTConfig.large() # ~350M params
# Custom configuration
config = CoRTConfig(
vocab_size=50257,
d_model=768,
n_heads=12,
n_layers=12,
route_frac=0.15, # 15% tokens to attention
adaptive_routing=True, # PID-tuned routing
coherence_mode="combined",
)CoRT includes a PID controller that dynamically adjusts the routing fraction based on observed coherence levels:
config = CoRTConfig(
adaptive_routing=True,
route_frac=0.15, # Base fraction, will be adjusted
)The controller targets optimal coherence, routing more tokens to attention when coherence is low and fewer when high.
from cort import CoRTModel, CoRTConfig
from cort.utils import Trainer, TrainingConfig
# Model
model = CoRTModel(CoRTConfig.small())
# Training config
train_config = TrainingConfig(
learning_rate=3e-4,
weight_decay=0.1,
log_interval=100,
)
# Train
trainer = Trainer(model, train_config)
trainer.train(train_loader, num_epochs=3)# Save
model.save_pretrained("my_model")
# Load
model = CoRTModel.from_pretrained("my_model")Token embeddings are converted to phases via sigmoid mapping to [0, 2π]:
phases = 2π · sigmoid(hidden_states)The coherence R̄ is computed as the magnitude of the mean unit vector:
R̄ = sqrt(mean(cos(φ))² + mean(sin(φ))²)A learned router combines:
- Entropy-based importance (variance of hidden states)
- Phase coherence (R̄ score)
- Learned weights (task-specific routing)
score = w_entropy · entropy + w_coherence · R̄ + w_learned · f(h)Low scores → attention path, high scores → mixing path.
- Attention path: Standard multi-head self-attention
- Mixing path: Lightweight phase-aware token mixing
The paths are computed in parallel and merged based on routing decisions.
A PID controller monitors coherence levels and adjusts the routing fraction to maintain optimal efficiency:
route_frac += Kp·error + Ki·∫error + Kd·d(error)/dt| Class | Description |
|---|---|
CoRTModel |
Complete transformer model |
CoRTConfig |
Model configuration |
CoRTLayer |
Single transformer layer |
CoherentRouter |
Phase coherence-based router |
CoherenceMixer |
Lightweight token mixer |
| Function | Description |
|---|---|
compute_phase_coherence |
Compute R̄ from hidden states |
compute_local_coherence |
Windowed local coherence |
r_bar_from_phases |
R̄ from phase angles |
import torch
from cort import CoRTModel, CoRTConfig
# Create model
config = CoRTConfig(vocab_size=50257, d_model=512, n_layers=6)
model = CoRTModel(config).cuda()
# Training loop
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
for batch in dataloader:
x, y = batch[:, :-1].cuda(), batch[:, 1:].cuda()
logits = model(x)
loss = F.cross_entropy(logits.view(-1, config.vocab_size), y.view(-1))
optimizer.zero_grad()
loss.backward()
optimizer.step()import tiktoken
enc = tiktoken.get_encoding("gpt2")
prompt = "Once upon a time"
input_ids = torch.tensor([enc.encode(prompt)]).cuda()
generated = model.generate(
input_ids,
max_new_tokens=100,
temperature=0.8,
top_p=0.9,
)
print(enc.decode(generated[0].tolist()))# Get per-layer routing stats
logits, stats = model(input_ids, return_stats=True)
for i, layer_stats in enumerate(stats):
print(f"Layer {i}: {layer_stats['attn_ratio']:.1%} to attention")@software{cort2024,
author = {Vaca, Dylan},
title = {CoRT: Coherence-Routed Transformer},
year = {2024},
url = {https://github.com/followthesapper/cort}
}MIT License - see LICENSE for details.
Contributions welcome! Please read our contributing guidelines and submit pull requests.
CoRT builds on ideas from:
- Circular statistics and the Mean Resultant Length (R̄)
- Adaptive computation in transformers
- Mixture of Experts routing mechanisms