jvoltci/scree

A cross-framework ragged tensor primitive for variable-length sequence data.

📚 Documentation · 📦 PyPI · 💬 Discussions · 🐛 Issues

import scree
import numpy as np

# Three sequences of different lengths.
seqs = [np.random.randn(n, 8).astype(np.float32) for n in [4, 2, 7]]

# Pack them into one scree.Array — no padding.
arr = scree.pack(seqs)
# arr.values: shape (13, 8), arr.offsets: [0, 4, 6, 13]

# Run varlen attention. Each sequence attends only to itself.
from scree.kernels.reference import varlen_attention
out = varlen_attention(arr, arr, arr, causal=True)
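
To make "each sequence attends only to itself" concrete, here is a minimal NumPy sketch of causal varlen attention over a packed `(total_tokens, dim)` array and its offsets. The function name `varlen_attention_sketch` is hypothetical and single-head only; it illustrates the semantics and is not scree's reference kernel.

```python
import numpy as np

def varlen_attention_sketch(q, k, v, offsets, causal=True):
    """Scaled dot-product attention applied independently to each packed sequence."""
    out = np.empty_like(v)
    scale = q.shape[-1] ** -0.5
    for start, end in zip(offsets[:-1], offsets[1:]):
        # Scores only involve tokens of the same sequence: (n, n).
        scores = (q[start:end] @ k[start:end].T) * scale
        if causal:
            n = end - start
            future = np.triu(np.ones((n, n), dtype=bool), k=1)
            scores = np.where(future, -np.inf, scores)
        # Numerically stable softmax over the key axis.
        scores = scores - scores.max(axis=-1, keepdims=True)
        weights = np.exp(scores)
        weights /= weights.sum(axis=-1, keepdims=True)
        out[start:end] = weights @ v[start:end]
    return out

rng = np.random.default_rng(0)
offsets = np.array([0, 4, 6, 13])            # three sequences: 4, 2, 7 tokens
x = rng.standard_normal((13, 8)).astype(np.float32)
out = varlen_attention_sketch(x, x, x, offsets)
```

With `causal=True`, the first token of every sequence can only see itself, so its output row equals its value row. No padding rows or cross-sequence masks are ever materialized.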

Why

Variable-length sequence data is everywhere in modern ML (transformer training, inference batching, multimodal interleaving, MoE routing), yet every team carries its own incompatible representation:

  • torch.nested (PyTorch only, in beta since 2021)
  • TF RaggedTensor (TensorFlow only)
  • FlashAttention cu_seqlens (a convention, not a typed primitive)
  • vLLM / SGLang packed batches (internal data structures)
  • HuggingFace attention_mask (pads, then masks — wasting memory and FLOPs)

scree provides one primitive, a packed values + offsets + ragged_dim array, that bridges across frameworks and comes with reference varlen kernels for attention, layernorm, softmax, and scatter/gather.
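
The packed layout itself is simple enough to sketch in plain NumPy. The helpers `pack_sketch` and `unpack_sketch` below are hypothetical names, and the sketch assumes the ragged dimension is axis 0; it shows the values + offsets idea, not scree's actual `pack` implementation.

```python
import numpy as np

def pack_sketch(seqs):
    """Concatenate sequences along axis 0 and record where each one starts."""
    lengths = [len(s) for s in seqs]
    offsets = np.concatenate([[0], np.cumsum(lengths)]).astype(np.int64)
    values = np.concatenate(seqs, axis=0)
    return values, offsets

def unpack_sketch(values, offsets):
    """Recover the per-sequence views by slicing between consecutive offsets."""
    return [values[s:e] for s, e in zip(offsets[:-1], offsets[1:])]

seqs = [np.ones((n, 8), dtype=np.float32) * i for i, n in enumerate([4, 2, 7])]
values, offsets = pack_sketch(seqs)
print(values.shape)      # (13, 8)
print(offsets.tolist())  # [0, 4, 6, 13]
```

Sequence `i` lives at `values[offsets[i]:offsets[i + 1]]`, so unpacking returns views into the packed buffer rather than copies.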

What you get

Memory savings vs HF padded on realistic LLM length distributions (log-normal):

| Workload | Mean savings | Min – Max |
| --- | --- | --- |
| Training-style (batch 64, mean_len 256, σ=0.6) | 71% | 63% – 84% |
| Inference-style (batch 32, mean_len 1024, σ=1.2) | 85% | 75% – 94% |

Reproduce: python benchmarks/bench_memory.py
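
As a rough illustration of where such savings come from, the sketch below compares token counts for a padded batch (`batch × max_len`) against a packed one (`sum of lengths`). The helper `padding_savings` is hypothetical, and the log-normal parameterization (median 256) is an assumption; `benchmarks/bench_memory.py` may sample lengths differently, so this will not reproduce the table exactly.

```python
import numpy as np

def padding_savings(lengths):
    """Fraction of token slots saved by packing instead of padding to max length."""
    lengths = np.asarray(lengths)
    packed = int(lengths.sum())
    padded = int(len(lengths) * lengths.max())
    return 1.0 - packed / padded

# Lengths 4, 2, 7: 13 packed tokens vs 3 * 7 = 21 padded slots.
print(round(padding_savings([4, 2, 7]), 3))  # 0.381

# Assumed sampling: log-normal lengths with median 256 and sigma 0.6, batch of 64.
rng = np.random.default_rng(0)
lengths = np.maximum(1, rng.lognormal(mean=np.log(256), sigma=0.6, size=64)).astype(int)
print(f"{padding_savings(lengths):.0%}")
```

The heavier the tail of the length distribution, the larger `max_len` is relative to the typical length, and the more a padded batch wastes.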

CPU throughput vs a naive padded attention baseline on a real batch (16 seqs × log-normal lengths, 1980 real / 4464 padded tokens, 4 heads × head_dim 32, fp32, no mask optimization):

| Operation | scree | padded baseline | Speedup |
| --- | --- | --- | --- |
| varlen_attention | 34.7 ms | 228.3 ms | 6.6× |
| varlen_rmsnorm | 0.13 ms | 0.28 ms | 2.1× |

Reproduce: python benchmarks/bench_throughput.py
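
Per-token operations such as RMSNorm show why the packed layout helps: the normalization never needs the offsets, and it touches only real tokens. The sketch below uses the standard RMSNorm formula with an assumed `eps` and unit weight; it is not scree's `varlen_rmsnorm` and says nothing about the timings above.

```python
import numpy as np

def rmsnorm(x, weight, eps=1e-6):
    """Standard RMSNorm over the last axis: x / sqrt(mean(x^2) + eps) * weight."""
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * weight

rng = np.random.default_rng(0)
lengths = [4, 2, 7]
offsets = np.concatenate([[0], np.cumsum(lengths)])
values = rng.standard_normal((13, 8)).astype(np.float32)
weight = np.ones(8, dtype=np.float32)

# Packed: normalize 13 real rows directly.
packed_out = rmsnorm(values, weight)

# Padded baseline: build a (3, 7, 8) batch and normalize all 21 rows, 8 of them padding.
padded = np.zeros((3, 7, 8), dtype=np.float32)
for i, (s, e) in enumerate(zip(offsets[:-1], offsets[1:])):
    padded[i, : e - s] = values[s:e]
padded_out = rmsnorm(padded, weight)
```

The real rows of `padded_out` match `packed_out`; the padded version simply does extra work on rows that are later discarded.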

GPU forward + training step vs FlashAttention-2 on H100. Headline workload: 16 seqs × log-normal lengths, 12160 total tokens, 16 heads × head_dim 64, fp16, causal:

| Operation | FA-2 | scree-Triton | Ratio |
| --- | --- | --- | --- |
| forward only | 0.165 ms | 0.216 ms | 1.30× |
| forward + backward (training step) | 0.688 ms | 1.106 ms | 1.61× |

Correctness: forward max abs diff 4.88e-04; backward dq 9.77e-04, dk 1.95e-03, dv 1.95e-03 vs FA-2 (all PASS within fp16 tolerance). Reproduce: modal run benchmarks/modal_bench.py and modal run benchmarks/modal_autograd_bench.py (~$0.40 of Modal credit total).

Across 27 shapes (head_dim × n_heads × mean_len): scree is closer to FA-2 on large workloads, slower on small ones (wrapper allocation overhead is per-call; it amortizes as the kernel grows). See benchmarks/modal_multishape_sweep.py.

| Workload range | Forward ratio | Training-step ratio |
| --- | --- | --- |
| Best (large: head_dim=64, n_heads=16, mean_len=2048) | 1.21× | 1.45× |
| Median across 27 shapes | 1.95× | 2.01× |
| Worst (toy: head_dim=32, n_heads=4, mean_len=256) | 3.53× | 1.77× |

For production LLM training (head_dim ≥ 64, n_heads ≥ 8, mean_len ≥ 1024), expect scree-Triton to take 1.2–2.0× as long as FA-2.

Zero-copy bridges to the things you already use:

import scree
import scree.bridges as bridges

arr = scree.from_cu_seqlens(values, cu_seqlens)         # FlashAttention
arr = bridges.from_hf_padded(hidden_states, attn_mask)  # HuggingFace
arr = bridges.from_torch_nested(nt)                     # torch.nested

bridges.to_torch_nested(arr)   # → torch.NestedTensor
bridges.to_hf_padded(arr)      # → (hidden_states, attention_mask)
bridges.to_torch(arr)          # numpy values → torch tensors via DLPack
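
For readers unfamiliar with the two conventions being bridged, here is a plain NumPy sketch of turning a HuggingFace-style padded batch plus `attention_mask` into packed values and FlashAttention-style `cu_seqlens` (int32 cumulative lengths, which play the same role as offsets). The helper `hf_padded_to_packed` is hypothetical; boolean indexing copies data, so this shows the layout conversion only and is not how scree's bridge is implemented.

```python
import numpy as np

def hf_padded_to_packed(hidden_states, attention_mask):
    """Drop padded positions and return (values, cu_seqlens)."""
    mask = attention_mask.astype(bool)        # (batch, max_len)
    values = hidden_states[mask]              # (total_tokens, dim), row-major order
    lengths = mask.sum(axis=1)                # real tokens per sequence
    cu_seqlens = np.concatenate([[0], np.cumsum(lengths)]).astype(np.int32)
    return values, cu_seqlens

# Two sequences padded to length 3; the first has one padding slot.
hidden = np.arange(2 * 3 * 2, dtype=np.float32).reshape(2, 3, 2)
mask = np.array([[1, 1, 0],
                 [1, 1, 1]])
values, cu_seqlens = hf_padded_to_packed(hidden, mask)
print(values.shape)         # (5, 2)
print(cu_seqlens.tolist())  # [0, 2, 5]
```

Because the mask is applied row by row, tokens keep their order within each sequence regardless of whether the batch is left- or right-padded.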

One primitive, every framework — values and offsets can be NumPy, PyTorch, MLX (Apple Silicon, via Metal), or JAX. All four backends pass the same correctness suite end-to-end.

The name

A scree is the irregular pile of rock fragments accumulated on a mountain slope. Variable-length sequences pack against each other the same way: irregular shapes, fitted by their irregularity, not despite it.

Status

v0.0.1 on PyPI. Reference Python kernels for all four backends and a full FA-2-style Triton kernel set (forward + backward + RMSNorm + LayerNorm) ship today. See the table below.

| Component | Status |
| --- | --- |
| scree.Array type + invariants | |
| pack / unpack / to_padded / from_padded | |
| Reference varlen attention / layernorm / softmax | |
| Bridges: torch.nested, HF padded, FA cu_seqlens, DLPack | |
| NumPy + PyTorch + MLX + JAX backends | |
| Triton varlen_attention forward (H100) | ✅ 1.30× of FA-2 |
| Triton varlen_attention backward (FA-2 style: preprocess + dKV + dQ) | ✅ 1.61× of FA-2 (full training step) |
| Triton varlen_rmsnorm (H100) | ✅ 13.97× of PyTorch reference |
| Triton varlen_layernorm (H100) | ✅ 1.31× of torch.nn.functional.layer_norm |

Install

pip install scree              # numpy backend
pip install "scree[torch]"     # + PyTorch backend
pip install "scree[mlx]"       # + MLX backend (Apple Silicon, Metal)
pip install "scree[jax]"       # + JAX backend

Examples

Documentation

Full rendered site at https://jvoltci.github.io/scree/.

Contributing

PRs welcome. See CONTRIBUTING.md for the workflow. Open a GitHub Discussion for anything beyond a small fix.

License

Apache-2.0
