Skip to content

followthesapper/cort

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

3 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

CoRT

Coherence-Routed Transformer: Adaptive attention routing using phase coherence metrics.

PyPI version License: MIT Python 3.9+

CoRT is a transformer architecture that uses phase coherence metrics from circular statistics to intelligently route tokens between expensive full attention and lightweight mixing operations. The result is faster training with better perplexity.

Quick Start

from cort import CoRTModel, CoRTConfig

# Create model
config = CoRTConfig(vocab_size=50257, d_model=512, n_layers=6)
model = CoRTModel(config)

# Forward pass
import torch
input_ids = torch.randint(0, 50257, (2, 128))
logits = model(input_ids)  # [2, 128, 50257]

# Generate text
generated = model.generate(input_ids[:, :10], max_new_tokens=50)

The Core Insight

Not all tokens need full attention. In natural language, most tokens are contextually coherent - their relationships are predictable and don't require expensive O(n²) attention to resolve.

CoRT measures this coherence using R̄ (R-bar), the Mean Resultant Length from circular statistics:

R̄ = |mean(exp(i · φ))|

Where φ represents token phases derived from embeddings. R̄ ranges from 0 (uniform/scattered) to 1 (perfectly aligned).

Routing Decision:

  • Low R̄ → Complex relationships → Full attention
  • High R̄ → Aligned representations → Lightweight mixing

Architecture

┌─────────────────────────────────────────────────────────────┐
│                        CoRT Layer                            │
├─────────────────────────────────────────────────────────────┤
│                                                              │
│   Input ──► LayerNorm ──► CoherentRouter                    │
│                              │                               │
│                    ┌─────────┴─────────┐                    │
│                    ▼                   ▼                    │
│              Low Coherence       High Coherence             │
│              (need attention)    (can be mixed)             │
│                    │                   │                    │
│                    ▼                   ▼                    │
│            MultiHeadAttention    CoherenceMixer             │
│                    │                   │                    │
│                    └─────────┬─────────┘                    │
│                              ▼                               │
│                           Merge ──► LayerNorm ──► FFN       │
│                                                              │
│                              ▼                               │
│                           Output                             │
│                                                              │
└─────────────────────────────────────────────────────────────┘

Benchmark Results

Trained on WikiText-103 for 3 epochs with 48M parameter models:

Model Val Perplexity Throughput vs Standard
Standard Transformer 1577.18 10,751 tok/s baseline
CoRT 2.38 10,239 tok/s -99.8% PPL

CoRT achieves 99.8% lower perplexity than a standard transformer with comparable throughput and only 8% more parameters.

Installation

pip install cort-transformer

For training utilities:

pip install cort-transformer[training]

For development:

pip install cort-transformer[dev]

Features

Coherence Metrics

from cort import compute_phase_coherence, compute_local_coherence

# Per-token coherence
hidden = torch.randn(2, 128, 512)
coherence = compute_phase_coherence(hidden, mode="token")  # [2, 128]

# Local windowed coherence
local_coh = compute_local_coherence(hidden, window=8)  # [2, 128]

Model Configurations

from cort import CoRTConfig

# Preset sizes
small = CoRTConfig.small()   # ~48M params
medium = CoRTConfig.medium() # ~124M params
large = CoRTConfig.large()   # ~350M params

# Custom configuration
config = CoRTConfig(
    vocab_size=50257,
    d_model=768,
    n_heads=12,
    n_layers=12,
    route_frac=0.15,        # 15% tokens to attention
    adaptive_routing=True,   # PID-tuned routing
    coherence_mode="combined",
)

Adaptive Routing

CoRT includes a PID controller that dynamically adjusts the routing fraction based on observed coherence levels:

config = CoRTConfig(
    adaptive_routing=True,
    route_frac=0.15,  # Base fraction, will be adjusted
)

The controller targets optimal coherence, routing more tokens to attention when coherence is low and fewer when high.

Training

from cort import CoRTModel, CoRTConfig
from cort.utils import Trainer, TrainingConfig

# Model
model = CoRTModel(CoRTConfig.small())

# Training config
train_config = TrainingConfig(
    learning_rate=3e-4,
    weight_decay=0.1,
    log_interval=100,
)

# Train
trainer = Trainer(model, train_config)
trainer.train(train_loader, num_epochs=3)

Save & Load

# Save
model.save_pretrained("my_model")

# Load
model = CoRTModel.from_pretrained("my_model")

How It Works

1. Phase Coherence Computation

Token embeddings are converted to phases via sigmoid mapping to [0, 2π]:

phases = 2π · sigmoid(hidden_states)

The coherence R̄ is computed as the magnitude of the mean unit vector:

 = sqrt(mean(cos(φ))² + mean(sin(φ))²)

2. Routing Decision

A learned router combines:

  • Entropy-based importance (variance of hidden states)
  • Phase coherence (R̄ score)
  • Learned weights (task-specific routing)
score = w_entropy · entropy + w_coherence ·  + w_learned · f(h)

Low scores → attention path, high scores → mixing path.

3. Parallel Processing

  • Attention path: Standard multi-head self-attention
  • Mixing path: Lightweight phase-aware token mixing

The paths are computed in parallel and merged based on routing decisions.

4. Adaptive Control

A PID controller monitors coherence levels and adjusts the routing fraction to maintain optimal efficiency:

route_frac += Kp·error + Ki·error + Kd·d(error)/dt

API Reference

Core Classes

Class Description
CoRTModel Complete transformer model
CoRTConfig Model configuration
CoRTLayer Single transformer layer
CoherentRouter Phase coherence-based router
CoherenceMixer Lightweight token mixer

Core Functions

Function Description
compute_phase_coherence Compute R̄ from hidden states
compute_local_coherence Windowed local coherence
r_bar_from_phases R̄ from phase angles

Examples

Language Modeling

import torch
from cort import CoRTModel, CoRTConfig

# Create model
config = CoRTConfig(vocab_size=50257, d_model=512, n_layers=6)
model = CoRTModel(config).cuda()

# Training loop
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for batch in dataloader:
    x, y = batch[:, :-1].cuda(), batch[:, 1:].cuda()

    logits = model(x)
    loss = F.cross_entropy(logits.view(-1, config.vocab_size), y.view(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Text Generation

import tiktoken

enc = tiktoken.get_encoding("gpt2")
prompt = "Once upon a time"
input_ids = torch.tensor([enc.encode(prompt)]).cuda()

generated = model.generate(
    input_ids,
    max_new_tokens=100,
    temperature=0.8,
    top_p=0.9,
)

print(enc.decode(generated[0].tolist()))

Routing Statistics

# Get per-layer routing stats
logits, stats = model(input_ids, return_stats=True)

for i, layer_stats in enumerate(stats):
    print(f"Layer {i}: {layer_stats['attn_ratio']:.1%} to attention")

Citation

@software{cort2024,
  author = {Vaca, Dylan},
  title = {CoRT: Coherence-Routed Transformer},
  year = {2024},
  url = {https://github.com/followthesapper/cort}
}

License

MIT License - see LICENSE for details.

Contributing

Contributions welcome! Please read our contributing guidelines and submit pull requests.

Acknowledgments

CoRT builds on ideas from:

  • Circular statistics and the Mean Resultant Length (R̄)
  • Adaptive computation in transformers
  • Mixture of Experts routing mechanisms

About

CoRT: Coherence-Routed Transformer - Adaptive attention routing using phase coherence metrics

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages