linen: add TernaryDense — ternary-weight linear layer {-1, 0, +1}#5466

Closed
eriirfos-eng wants to merge 2 commits into google:main from eriirfos-eng:feat/linen/ternary-dense

@eriirfos-eng

Summary

Adds nn.TernaryDense to flax/linen/linear.py as a drop-in replacement for nn.Dense for ternary-weight models (BitNet b1.58, et al.).

Motivation

Ternary-weight models train directly in {-γ, 0, +γ} — no post-training quantization step. On every forward pass the kernel is quantized on-the-fly to {-1, 0, +1}; the underlying float parameters remain in the variable collection so that gradient-based optimisers can update them between steps (enabling training, not just inference).

Design

layer = nn.TernaryDense(features=64)
params = layer.init(jax.random.key(0), jnp.ones((batch, dim)))
y = layer.apply(params, x)

The class mirrors nn.Dense exactly — same attributes, same @compact pattern, same lax.dot_general call — with one addition: a threshold attribute and a three-line quantization step before the matmul:

abs_k = jnp.abs(kernel)
t = jnp.mean(abs_k) if self.threshold is None else self.threshold
kernel = jnp.sign(kernel) * (abs_k > t).astype(kernel.dtype)
  • threshold=None (default): mean(|kernel|) per call — the BitNet b1.58 §3.1 rule (https://arxiv.org/abs/2402.17764).
  • threshold=0.0: every nonzero weight becomes ±1; only weights that are exactly 0 stay 0.
  • Large threshold (e.g. 1e9): all-zero effective kernel.
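The quantization step and the three threshold regimes listed above can be sketched as a standalone function. `ternarize` is a hypothetical helper name, not a Flax API, and the sample kernel values are made up for illustration.

```python
import jax.numpy as jnp

def ternarize(kernel, threshold=None):
    """Map a float kernel to {-1, 0, +1} (hypothetical helper, not a Flax API)."""
    abs_k = jnp.abs(kernel)
    # Default: use the mean absolute weight as the cut-off, as in the PR.
    t = jnp.mean(abs_k) if threshold is None else threshold
    # Strict ">" sends entries with |w| <= t to 0, including exact zeros at t = 0.
    return jnp.sign(kernel) * (abs_k > t).astype(kernel.dtype)

kernel = jnp.array([[0.9, -0.1], [-0.6, 0.2]])  # illustrative values only

q_auto = ternarize(kernel)        # cut-off is mean(|kernel|) = 0.45
q_zero = ternarize(kernel, 0.0)   # every nonzero weight becomes +/-1
q_huge = ternarize(kernel, 1e9)   # nothing survives the cut-off
```

One caveat, as far as I can tell from the paper: BitNet b1.58 states its quantizer as RoundClip(W / (γ + ε), -1, 1) with γ = mean(|W|). Rounding puts the effective cut-off near γ/2 rather than γ, so the default above would be a stricter variant of that rule and would zero out more weights.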

The parameter shape and variable collection are identical to nn.Dense, so checkpoints from a TernaryDense layer can be loaded into a Dense of the same shape and vice versa.

Files changed

| File | Change |
| --- | --- |
| flax/linen/linear.py | +96 lines: TernaryDense class after Dense |
| flax/linen/__init__.py | +1 export |

Reference

🤖 Generated with Claude Code

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@google-cla

google-cla Bot commented May 20, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@eriirfos-eng
Author

I signed the CLA.

@eriirfos-eng
Author

Proactively flagging two gaps, plus one possible follow-up, before review:

1. Missing STE — layer is not trainable as-is

The forward pass calls jnp.sign which has zero gradient almost everywhere. Gradients die at the quantization boundary and the float params never update. Fix: jax.custom_vjp straight-through estimator so the backward treats the quantizer as identity.

import jax
import jax.numpy as jnp

@jax.custom_vjp
def ternary_quantize(w, threshold):
    # Forward: map w to {-1, 0, +1} by thresholding |w|.
    return jnp.where(jnp.abs(w) > threshold, jnp.sign(w), 0.0)

def ternary_quantize_fwd(w, threshold):
    return ternary_quantize(w, threshold), (w, threshold)

def ternary_quantize_bwd(res, g):
    w, threshold = res
    # STE: pass gradient straight through (clamp optional);
    # None means no cotangent is produced for threshold.
    return g, None

ternary_quantize.defvjp(ternary_quantize_fwd, ternary_quantize_bwd)
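To make the problem and the fix concrete, the following sketch compares the gradient through the plain quantizer with the gradient through a `jax.custom_vjp` straight-through version. The function names, sample weights, and toy loss are illustrative assumptions. This variant returns an explicit zero cotangent for the threshold instead of `None`.

```python
import jax
import jax.numpy as jnp

def naive_quantize(w, threshold):
    # Piecewise-constant in w, so its gradient is zero wherever it is defined.
    return jnp.where(jnp.abs(w) > threshold, jnp.sign(w), 0.0)

@jax.custom_vjp
def ste_quantize(w, threshold):
    return naive_quantize(w, threshold)

def ste_fwd(w, threshold):
    # Only the threshold is kept as a residual; the backward pass needs its shape.
    return naive_quantize(w, threshold), threshold

def ste_bwd(threshold, g):
    # Straight-through: treat the quantizer as the identity with respect to w.
    # The threshold receives a zero cotangent.
    return g, jnp.zeros_like(threshold)

ste_quantize.defvjp(ste_fwd, ste_bwd)

w = jnp.array([0.8, -0.05, -0.7, 0.3])   # illustrative weights
t = jnp.asarray(0.5)
coeffs = jnp.arange(1.0, 5.0)            # toy loss weights [1, 2, 3, 4]

def loss(quantizer, w):
    return jnp.sum(quantizer(w, t) * coeffs)

grad_naive = jax.grad(lambda w: loss(naive_quantize, w))(w)  # zeros: nothing to learn from
grad_ste = jax.grad(lambda w: loss(ste_quantize, w))(w)      # equals coeffs
```

With the plain quantizer the float parameters receive no learning signal at all. With the straight-through rule they receive the upstream gradient unchanged.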

2. Missing beta scale — deviates from BitNet b1.58

BitNet b1.58 (§3.1) multiplies the ternary output by beta = mean(|W|) to preserve the activation scale:

beta = jnp.mean(jnp.abs(kernel))          # per BitNet b1.58 §3.1
y = x @ (ternary_quantize(kernel, threshold) * beta)

Without this the output magnitude shrinks as the model learns sparser weights, destabilising training.
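A minimal forward-only sketch of the rescaling, assuming the automatic threshold. `ternary_dense_forward` is a hypothetical helper, the sample values are made up, and no straight-through estimator is included here.

```python
import jax.numpy as jnp

def ternary_dense_forward(x, kernel, threshold=None):
    """Ternary matmul with beta rescaling (hypothetical helper, forward only)."""
    abs_k = jnp.abs(kernel)
    beta = jnp.mean(abs_k)                            # scale taken from the float kernel
    t = beta if threshold is None else threshold
    q = jnp.where(abs_k > t, jnp.sign(kernel), 0.0)   # values in {-1, 0, +1}
    # Scaling the output is equivalent to scaling the kernel, but keeps the
    # matmul itself on ternary values.
    return (x @ q) * beta, q, beta

x = jnp.array([[1.0, 2.0]])
kernel = jnp.array([[0.9, -0.1], [-0.6, 0.2]])   # illustrative values only
y, q, beta = ternary_dense_forward(x, kernel)
```

Applying `beta` after the matmul rather than to the kernel is a design choice made for this sketch. It leaves the matrix product free to use a ternary-specialised kernel later.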

3. Storage (inference path)

For inference-only deployments the float32 params are unnecessary. A companion TernaryDenseInference that stores weights as jnp.int8 (or a packed 2-bit representation) would give the memory reduction. Happy to add that as a follow-up if Flax maintainers want it in the same PR.
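As a rough illustration of the packed 2-bit idea, the sketch below stores four ternary weights per `uint8` byte and unpacks them again. The helper names and the bit layout are assumptions made for this example. They are not part of the PR or of Flax.

```python
import jax.numpy as jnp

_SHIFTS = jnp.array([0, 2, 4, 6], dtype=jnp.uint8)  # bit offset of each 2-bit slot

def pack_ternary(q):
    """Pack a flat ternary array (length divisible by 4) into uint8, 2 bits each."""
    codes = (q.astype(jnp.int8) + 1).astype(jnp.uint8)   # {-1, 0, 1} -> {0, 1, 2}
    codes = codes.reshape(-1, 4)
    # The four slots occupy disjoint bits, so summing is the same as OR-ing.
    return jnp.sum(codes << _SHIFTS, axis=1).astype(jnp.uint8)

def unpack_ternary(packed):
    """Inverse of pack_ternary: recover int8 values in {-1, 0, 1}."""
    codes = (packed[:, None] >> _SHIFTS) & jnp.uint8(3)
    return codes.reshape(-1).astype(jnp.int8) - 1

q = jnp.array([1, 0, -1, -1, 0, 0, 1, 1], dtype=jnp.int8)  # illustrative weights
packed = pack_ternary(q)          # 2 bytes for 8 weights
restored = unpack_ternary(packed)
```

At 2 bits per weight this is 16x smaller than float32 storage for the same number of weights. A real inference layer would also need the scale factor and a matmul that consumes the packed form.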

Will push a fix for (1) and (2) shortly.

- STE via lax.stop_gradient: kernel_ste = kernel + sg(kernel_q - kernel)
  so gradients flow straight through the quantization boundary
- beta = mean(|kernel|) multiplied into the output when using auto
  threshold, restoring activation scale as weights become sparse

Without STE the quantization boundary kills gradients and the float
params never update. Without beta the output magnitude shrinks as
sparsity grows, destabilising training.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
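The `stop_gradient` formulation in the commit message can be sketched as below. This is an illustrative reconstruction for the automatic-threshold case only, not the code from the commit. The function name and sample values are assumptions.

```python
import jax
import jax.numpy as jnp

def ternary_dense_ste(kernel, x):
    """Forward pass with stop_gradient STE and beta scaling (illustrative sketch)."""
    abs_k = jnp.abs(kernel)
    beta = jnp.mean(abs_k)
    kernel_q = jnp.where(abs_k > beta, jnp.sign(kernel), 0.0)
    # Forward value equals kernel_q (up to floating-point rounding);
    # the backward pass sees d(kernel_ste)/d(kernel) = 1.
    kernel_ste = kernel + jax.lax.stop_gradient(kernel_q - kernel)
    return (x @ kernel_ste) * beta

x = jnp.array([[1.0, 2.0]])
kernel = jnp.array([[0.9, -0.1], [-0.6, 0.2]])   # illustrative values only

y = ternary_dense_ste(kernel, x)
grads = jax.grad(lambda k: jnp.sum(ternary_dense_ste(k, x)))(kernel)
```

In this sketch `beta` is not wrapped in `stop_gradient`, so the gradient also carries a term from the scale. Whether the scale should be differentiated is a design decision the commit message does not spell out.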
@vfdev-5
Collaborator

vfdev-5 commented May 21, 2026

@eriirfos-eng thanks for the PR! However, please keep in mind the process for contributing new features: first, open an issue to discuss the need; second, if it is approved, write the code and open a PR.

Contributions of new features to linen are unlikely to land. Even for nnx, we should first see whether several people ask for this TernaryDense layer, as we want to keep the library small and limited to essential things.

@eriirfos-eng
Author

@vfdev-5 — understood, thanks for the clear explanation of the process.

I opened #5468 as the proper discussion issue for NNX TernaryDense. I'll wait there for maintainer feedback before writing any code.

Since new linen features are unlikely to land and the right path is the issue-first process, I'll close this PR. If the discussion in #5468 leads to an approved direction I'll open a fresh, correctly-targeted PR then.
