A PyTorch implementation of Multiscreen, the screening-based language model architecture from "Screening Is Enough" (Nakanishi, 2026).
Multiscreen replaces softmax attention with screening: each key is evaluated independently against a learned threshold, removing global competition between keys. This enables absolute query-key relevance, stable training at large learning rates, and strong long-context retrieval.
- Pure PyTorch reference implementation of the Multiscreen model
- 2.6x faster training than naive PyTorch via `torch.compile` (~41k tok/s on RTX 5070 Ti for a 154M model)
- KV cache inference with up to 5x faster decoding via `torch.compile` + cache
- Generic training script using HuggingFace `datasets` + `tokenizers`
- Gradient checkpointing for low-VRAM training (-75% VRAM)
- CPU-friendly tests (20 unit tests, no GPU required)
git clone https://github.com/dieOD/multiscreen-pytorch
cd multiscreen-pytorch
pip install -e ".[train]"

For the optional torch.compile speedup:

pip install -e ".[train,perf]"

On Windows, you also need MSVC (Visual Studio Build Tools with the C++ workload). See docs/setup.md for details.
Train a tiny Multiscreen model on TinyStories with the GPT-2 tokenizer:
python scripts/train.py \
--dataset roneneldan/TinyStories \
--psi 8 \
--max-steps 1000 \
--micro-batch 16

This builds an ~8M parameter model (Psi=8 -> 8 layers, 8 heads, hidden_dim=64).
For a paper-comparable 154M run on Wikitext-103:
python scripts/train.py \
--dataset wikitext --config wikitext-103-raw-v1 \
--hidden-dim 1024 --num-layers 18 --num-heads 18 \
--key-dim 32 --value-dim 128 --seq-len 256 \
--max-steps 17000 --peak-lr 1e-2 \
--micro-batch 32 --grad-accum 16 \
--compile

Greedy generation with the KV cache (see examples/generation.py):
import torch
from multiscreen import MultiscreenConfig, MultiscreenModel
model = MultiscreenModel(MultiscreenConfig(...)).eval()
# 1. Prefill the prompt in one forward pass
prompt_ids = torch.tensor([[1, 2, 3, 4]])
logits, kv_caches = model(prompt_ids)
next_logits = logits[:, -1, :]
# 2. Incremental decode: feed one new token per step and reuse the cache
max_new_tokens = 32  # illustrative value; choose your own generation length
for step in range(max_new_tokens):
    next_id = next_logits.argmax(dim=-1, keepdim=True)
    logits, kv_caches = model(
        next_id,
        start_pos=prompt_ids.shape[1] + step,
        kv_caches=kv_caches,
    )
    next_logits = logits[:, -1, :]

See docs/inference.md for the design (what's cached, how the softmask is rebuilt per step, caveats).
Training throughput / VRAM:
# Default 154M config, B=16
python scripts/benchmark.py
# With torch.compile (~2.6x faster)
python scripts/benchmark.py --compile --batch-size 32
# Export Chrome trace for kernel-level inspection
python scripts/benchmark.py --trace

Inference throughput:
# Compare eager/compile × full-reforward/KV-cache on 200M model
python scripts/bench_inference.py --prompt 256 --generate 128 --compile

Each Multiscreen layer contains N_H parallel gated screening tiles. A tile:
- Projects input into Q, K, V, G
- Normalizes Q, K, V to unit length
- Applies MiPE (RoPE-like rotation, only the first 2 dims, only when window is short)
- Computes bounded similarity: `s = q . k^T` in [-1, 1]
- Trim-and-Square: `rho = max(1 - r(1-s), 0)^2`
- Softmask: causal + distance-aware cosine window of width `w`
- Aggregates: `h = sum_j rho_d_ij * v_j`
- TanhNorm: `tanh(||h||) / ||h|| * h` (bounds output norm by 1)
- Gates with `tanh(silu(g))` and projects back to model dim
r, w are per-head learned parameters. See docs/architecture.md for the math.
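The screening steps above can be sketched for a single query in plain Python, so the example runs without PyTorch or a GPU. This is a minimal illustration, not the repo's implementation: the helper names (`normalize`, `trim_and_square`, `tanh_norm`, `screen`) are hypothetical, and MiPE, the distance-aware softmask, and the gate are left out. The caller is assumed to pass only keys at or before the query position.

```python
import math


def normalize(x):
    """Scale a vector to unit length."""
    n = math.sqrt(sum(c * c for c in x))
    return [c / n for c in x]


def trim_and_square(s, r):
    """rho = max(1 - r * (1 - s), 0) ** 2 for a similarity s in [-1, 1]."""
    return max(1.0 - r * (1.0 - s), 0.0) ** 2


def tanh_norm(h):
    """Rescale h so its norm becomes tanh(||h||), which is below 1."""
    n = math.sqrt(sum(c * c for c in h))
    if n == 0.0:
        return list(h)
    scale = math.tanh(n) / n
    return [scale * c for c in h]


def screen(query, keys, values, r):
    """Screen one query against the given keys (hypothetical helper).

    Each key gets its own weight, computed without reference to any
    other key, so there is no softmax-style normalization across keys.
    """
    q = normalize(query)
    weights = []
    for k in keys:
        s = sum(a * b for a, b in zip(q, normalize(k)))  # cosine similarity
        weights.append(trim_and_square(s, r))
    unit_values = [normalize(v) for v in values]
    dim = len(unit_values[0])
    h = [sum(w * v[d] for w, v in zip(weights, unit_values)) for d in range(dim)]
    return weights, tanh_norm(h)


query = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[1.0, 0.0], [0.0, 1.0], [0.0, 2.0]]
r = 2.0  # stands in for the per-head learned parameter

weights_two, _ = screen(query, keys[:2], values[:2], r)
weights_three, out = screen(query, keys, values, r)
print(weights_two, weights_three, out)
```

Adding the third key leaves the first two weights unchanged, which is the "no global competition" property described at the top of this README. With `r = 2.0`, any key whose similarity is 0.5 or lower is trimmed to exactly zero.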
multiscreen-pytorch/
├── multiscreen/
│   ├── config.py            # MultiscreenConfig
│   ├── model.py             # MultiscreenModel + GatedScreeningBlock (with KV cache)
│   ├── data.py              # PackedTextDataset (HF datasets loader)
│   ├── trainer.py           # Trainer with AMP, grad accum, checkpointing
│   └── compile_utils.py     # torch.compile / MSVC environment helpers
├── scripts/
│   ├── train.py             # End-to-end training script
│   ├── benchmark.py         # Training throughput / VRAM benchmark
│   └── bench_inference.py   # Inference throughput benchmark
├── examples/
│   ├── quickstart.py        # Minimal forward + backward
│   └── generation.py        # Greedy decode with KV cache
├── tests/
│   ├── test_model.py        # 14 unit tests for model / gradients / MiPE
│   └── test_kv_cache.py     # 6 unit tests for incremental decode correctness
└── docs/
    ├── architecture.md
    ├── inference.md         # KV cache design and performance
    ├── setup.md
    └── speedup.md
The default model implementation includes several optimizations beyond the naive paper transcription.
Training (154M model, batch=32, seq_len=256, bf16, RTX 5070 Ti):
| Optimization | What changed | Speedup |
|---|---|---|
| MiPE in-place rotation | Replace torch.cat with index assignment | ~2-3% |
| Fused trim-square-mask | Reduce T×T intermediates from 3 to 2 via in-place ops | ~10-15% |
| `torch.compile` | inductor backend fuses element-wise ops | 2.4x |
| Gradient checkpointing | Trade compute for VRAM (-75% VRAM, enables larger batch) | (compute-bound) |
See docs/speedup.md for the full training optimization journey, including a CUDA-time profile.
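For a rough sense of scale, the flags of the 154M Wikitext-103 command can be combined with the ~41k tok/s figure from the feature list. The sketch below is only an estimate under stated assumptions: `--grad-accum` is taken to mean micro-batches per optimizer step, `--max-steps` is taken to count optimizer steps, every micro-batch is assumed full, and the quoted throughput is assumed to hold for the whole run.

```python
# Flags copied from the 154M Wikitext-103 command in this README.
micro_batch = 32
grad_accum = 16
seq_len = 256
max_steps = 17_000

# Approximate throughput quoted in the feature list (RTX 5070 Ti, torch.compile).
throughput_tok_s = 41_000

# Assumption: one optimizer step consumes micro_batch * grad_accum sequences.
tokens_per_step = micro_batch * grad_accum * seq_len
total_tokens = tokens_per_step * max_steps
est_hours = total_tokens / throughput_tok_s / 3600

print(f"tokens per step: {tokens_per_step:,}")
print(f"total tokens:    {total_tokens:,}")
print(f"estimated time:  {est_hours:.1f} h")
```

Under these assumptions the run processes about 2.2B tokens and takes on the order of 15 hours; actual wall-clock time depends on data loading, evaluation, and checkpointing overhead.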
Inference (200M model, bf16, prompt=256, generate=128, RTX 5070 Ti):
| Configuration | per-token (ms) | tok/s | Speedup |
|---|---|---|---|
| eager + full re-forward | 16.83 | 59 | 1.0x |
| eager + KV cache | 16.85 | 59 | 1.0x |
| compile + full re-forward | 4.77 | 210 | 3.5x |
| compile + KV cache | 3.41 | 293 | 5.0x |
KV cache alone is a tie in eager mode (the screening matmul is only ~4% of
total work) but stacks cleanly on top of torch.compile, adding another
~40% throughput for a total 5x speedup over the eager baseline. See
docs/inference.md for the design and when each
configuration matters.
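The derived columns of the inference table can be rechecked from the per-token latencies alone. The sketch below only re-does that arithmetic on the rounded figures printed above; it does not run any benchmark.

```python
# Per-token latencies in milliseconds, copied from the inference table.
latency_ms = {
    "eager + full re-forward": 16.83,
    "eager + KV cache": 16.85,
    "compile + full re-forward": 4.77,
    "compile + KV cache": 3.41,
}

baseline = latency_ms["eager + full re-forward"]
for name, ms in latency_ms.items():
    tok_per_s = 1000.0 / ms   # tokens per second
    speedup = baseline / ms   # relative to the eager baseline
    print(f"{name:27s} {tok_per_s:6.0f} tok/s  {speedup:.1f}x")

# Extra throughput from the KV cache once torch.compile is already on.
cache_gain = (
    latency_ms["compile + full re-forward"] / latency_ms["compile + KV cache"] - 1.0
)
print(f"KV cache on top of compile: +{cache_gain:.0%}")
```

From the rounded latencies the final ratio comes out at about 4.9x, which the table reports as 5.0x; the small difference presumably comes from rounding of the measured values.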
This is an unofficial third-party implementation. The original paper authors have a custom Triton implementation (Section 4.5) which is not yet publicly available as far as we know. This repo aims to be the most complete pure-PyTorch reference for researchers wanting to experiment with screening.
If you use this code, please cite the original paper:
@article{nakanishi2026screening,
title={Screening Is Enough},
author={Nakanishi, Ken M.},
journal={arXiv preprint arXiv:2604.01178},
year={2026}
}

Apache 2.0. See LICENSE.