Summary
Proposing a TernaryDense layer for NNX as a drop-in replacement for nnx.Linear in ternary-weight models (BitNet b1.58 and successors).
Motivation
BitNet b1.58 (arXiv:2402.17764) showed that training with weights quantized to {-1, 0, +1} reaches performance competitive with full-precision models at scale, with significant inference savings at high sparsity (55–65% zero weights). Several follow-on works use the same approach, and community interest is growing as inference cost becomes a primary concern.
The main things a JAX/Flax user needs to run these models that aren't currently in the library:
- A trainable ternary layer: float weights are stored and trained, quantized to {-1, 0, +1} only during the forward pass, with gradients flowing back through a Straight-Through Estimator (STE). Without STE the layer is untrainable, because jnp.sign has zero gradient almost everywhere.
- Beta scaling: beta = mean(|W|) multiplied into the output restores the activation scale lost by quantization (BitNet §3.1). Without it, the output scale is systematically off from the float baseline, since every nonzero weight becomes ±1 regardless of its original magnitude.
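To make the STE point concrete, here is a minimal pure-JAX sketch (not part of any existing Flax API). The helper names `quantize`, `loss_naive`, and `loss_ste` are hypothetical, and the threshold of 0.5 * mean(|w|) is the default assumed in this proposal. It compares the gradient through plain quantization with the gradient through the `stop_gradient` trick.

```python
import jax
import jax.numpy as jnp


def quantize(w):
    """Ternarize w to {-1, 0, +1} using a 0.5 * mean(|w|) threshold (assumed default)."""
    abs_w = jnp.abs(w)
    return jnp.sign(w) * (abs_w > 0.5 * jnp.mean(abs_w))


def loss_naive(w, x):
    # Gradient has to pass through jnp.sign and a comparison, so it vanishes.
    return jnp.sum(x @ quantize(w))


def loss_ste(w, x):
    # Forward value equals quantize(w); backward treats the quantizer as identity.
    w_ste = w + jax.lax.stop_gradient(quantize(w) - w)
    return jnp.sum(x @ w_ste)


w = jax.random.normal(jax.random.PRNGKey(0), (4, 3))
x = jnp.ones((2, 4))

g_naive = jax.grad(loss_naive)(w, x)  # all zeros: nothing to train on
g_ste = jax.grad(loss_ste)(w, x)      # same as the gradient of sum(x @ w)

print("naive grad:", g_naive)
print("STE grad:", g_ste)
```

Both losses evaluate to the same forward value (up to float rounding); only the backward pass differs.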
Proposed API (NNX)
    import jax
    import jax.numpy as jnp
    from flax import nnx

    class TernaryDense(nnx.Module):
        def __init__(self, in_features, out_features, rngs, threshold=None):
            ...

        def __call__(self, x):
            abs_k = jnp.abs(self.kernel.value)
            # Default threshold: half the mean absolute weight.
            t = 0.5 * jnp.mean(abs_k) if self.threshold is None else self.threshold
            k_q = jnp.sign(self.kernel.value) * (abs_k > t)
            # STE: forward uses k_q, backward sees identity
            k = self.kernel.value + jax.lax.stop_gradient(k_q - self.kernel.value)
            beta = jnp.mean(abs_k) if self.threshold is None else 1.0
            y = beta * (x @ k)
            # Test against None explicitly; truth-testing a Variable is ambiguous.
            return y if self.bias is None else y + self.bias.value
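As a way to check the math independently of NNX, the forward pass above can be mirrored as a plain JAX function. This is a sketch under assumptions: `ternarize`, `ternary_dense_apply`, and the `use_beta` flag are hypothetical names introduced only for this comparison, and the kernel is random Gaussian data rather than trained weights.

```python
import jax
import jax.numpy as jnp


def ternarize(kernel, threshold=None):
    """Quantize to {-1, 0, +1}; default threshold is 0.5 * mean(|kernel|)."""
    abs_k = jnp.abs(kernel)
    t = 0.5 * jnp.mean(abs_k) if threshold is None else threshold
    return jnp.sign(kernel) * (abs_k > t)


def ternary_dense_apply(kernel, x, bias=None, threshold=None, use_beta=True):
    """Functional mirror of the proposed __call__ (use_beta is illustration-only)."""
    k_q = ternarize(kernel, threshold)
    # STE: forward uses k_q, backward sees identity.
    k = kernel + jax.lax.stop_gradient(k_q - kernel)
    beta = jnp.mean(jnp.abs(kernel)) if (threshold is None and use_beta) else 1.0
    y = beta * (x @ k)
    return y if bias is None else y + bias


def rms(a):
    return jnp.sqrt(jnp.mean(a ** 2))


key_k, key_x = jax.random.split(jax.random.PRNGKey(0))
kernel = 0.02 * jax.random.normal(key_k, (256, 128))  # small-magnitude float weights
x = jax.random.normal(key_x, (32, 256))

rms_float = rms(x @ kernel)
rms_beta = rms(ternary_dense_apply(kernel, x))
rms_no_beta = rms(ternary_dense_apply(kernel, x, use_beta=False))

print("float baseline RMS:", rms_float)
print("ternary with beta RMS:", rms_beta)
print("ternary without beta RMS:", rms_no_beta)
```

With small-magnitude weights like these, the unscaled ternary output is expected to sit far from the float baseline because each nonzero weight becomes ±1, while the beta-scaled output should land in the same range.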
What I'd like to know before writing the PR
- Is TernaryDense a fit for NNX, or is this better served by a standalone recipe/example in the docs?
- If it belongs in NNX, is there a preferred pattern for custom quantization layers (e.g. via nnx.Variable subclasses)?
- Any existing precedent in NNX for quantization-aware training layers would help me match the style.
Happy to write the implementation once the approach is agreed — or to contribute it as a documented example if that is preferred.
cc @vfdev-5 (based on feedback on #5466)