[NNX] TernaryDense — ternary-weight dense layer for BitNet b1.58 style models #5468

@eriirfos-eng

Summary

Proposing a TernaryDense layer for NNX as a drop-in replacement for nnx.Linear in ternary-weight models (BitNet b1.58 and successors).

Motivation

BitNet b1.58 (arXiv:2402.17764) showed that training with weights quantized to {-1, 0, +1} reaches competitive performance versus full-precision models at scale, with significant inference savings at high sparsity (55–65% zero weights). Several follow-on works use the same approach, and community interest is growing as inference cost becomes a primary concern.

The main things a JAX/Flax user needs to run these models that aren't currently in the library:

  1. A trainable ternary layer — float weights stored and trained, quantized to {-1, 0, +1} only during the forward pass, with gradients flowing back through a Straight-Through Estimator (STE). Without STE the layer is untrainable (jnp.sign has zero gradient).

  2. Beta scaling: beta = mean(|W|) multiplied into the output restores the activation scale lost by quantization (BitNet §3.1). Without it, the output scale no longer matches the float baseline, because the ternary weights have unit magnitude regardless of the scale of W.
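To make the STE point concrete, here is a minimal sketch in plain JAX (no Flax) that compares the gradient of a naively ternarized matmul with the same matmul wrapped in a straight-through estimator. The helper names (`ternarize`, `loss_naive`, `loss_ste`) and the toy shapes are assumptions chosen for illustration; the threshold `0.5 * mean(|W|)` mirrors the proposal in this issue.

```python
import jax
import jax.numpy as jnp


def ternarize(w, t):
    # Map weights to {-1, 0, +1}: keep the sign where |w| exceeds the threshold.
    return jnp.sign(w) * (jnp.abs(w) > t)


def loss_naive(w, x):
    # Quantize directly: sign() and the comparison block all gradient flow.
    t = 0.5 * jnp.mean(jnp.abs(w))
    return jnp.sum(x @ ternarize(w, t))


def loss_ste(w, x):
    # STE: the forward value equals the ternary weights (up to float rounding),
    # while the backward pass treats the quantizer as the identity.
    t = 0.5 * jnp.mean(jnp.abs(w))
    w_q = ternarize(w, t)
    w_ste = w + jax.lax.stop_gradient(w_q - w)
    return jnp.sum(x @ w_ste)


w = 0.1 * jax.random.normal(jax.random.PRNGKey(0), (4, 3))  # toy float weights
x = jnp.ones((2, 4))                                        # toy batch of inputs

g_naive = jax.grad(loss_naive)(w, x)  # all zeros: nothing to train on
g_ste = jax.grad(loss_ste)(w, x)      # same gradient as an unquantized x @ w
```

With these inputs the naive gradient is zero everywhere, while the STE gradient equals that of the unquantized `jnp.sum(x @ w)`, which is what lets the float weights keep training.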

Proposed API (NNX)

class TernaryDense(nnx.Module):
    def __init__(self, in_features, out_features, rngs, threshold=None):
        ...

    def __call__(self, x):
        abs_k = jnp.abs(self.kernel.value)
        t = 0.5 * jnp.mean(abs_k) if self.threshold is None else self.threshold
        k_q = jnp.sign(self.kernel.value) * (abs_k > t)
        # STE: forward uses k_q, backward sees identity
        k = self.kernel.value + jax.lax.stop_gradient(k_q - self.kernel.value)
        beta = jnp.mean(abs_k) if self.threshold is None else 1.0
        # Compare against None explicitly; truth-testing a parameter or array is ambiguous
        return beta * (x @ k) + (self.bias.value if self.bias is not None else 0)

What I'd like to know before writing the PR

  • Is TernaryDense a fit for NNX, or is this better served by a standalone recipe/example in the docs?
  • If it belongs in NNX, is there a preferred pattern for custom quantization layers (e.g. via nnx.Variable subclasses)?
  • Any existing precedent in NNX for quantization-aware training layers would help me match the style.

Happy to write the implementation once the approach is agreed — or to contribute it as a documented example if that is preferred.

cc @vfdev-5 (based on feedback on #5466)
