linen: add TernaryDense — ternary-weight linear layer {-1, 0, +1}#5466
eriirfos-eng wants to merge 2 commits into
Conversation
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
|
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request. |
|
I signed the CLA. |
|
Proactively flagging two gaps before review, plus a note on storage:

1. Missing STE: layer is not trainable as-is

Without a straight-through estimator (STE), the quantization boundary in the forward pass kills gradients, so the float params never update. A `jax.custom_vjp` wrapper that passes the gradient straight through:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def ternary_quantize(w, threshold):
    # Zero inside the threshold band, sign(w) outside it.
    return jnp.where(jnp.abs(w) > threshold, jnp.sign(w), 0.0)

def ternary_quantize_fwd(w, threshold):
    return ternary_quantize(w, threshold), (w, threshold)

def ternary_quantize_bwd(res, g):
    w, threshold = res
    # STE: pass gradient straight through (clamp optional); None = zero cotangent for threshold
    return g, None

ternary_quantize.defvjp(ternary_quantize_fwd, ternary_quantize_bwd)
```

2. Missing beta scale: deviates from BitNet b1.58

BitNet b1.58 (§3.1) multiplies the ternary output by a scale `beta`, the mean absolute value of the kernel:

```python
beta = jnp.mean(jnp.abs(kernel))  # per BitNet b1.58 §3.1
y = x @ (ternary_quantize(kernel, threshold) * beta)
```

Without this the output magnitude shrinks as the model learns sparser weights, destabilising training.

3. Storage (inference path)

For inference-only deployments the float32 params are unnecessary. A companion […]

Will push a fix for (1) and (2) shortly.
|
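To make the "not trainable as-is" point concrete, here is a minimal sketch comparing the gradient of plain ternary quantization with the gradient of a `jax.custom_vjp` straight-through version. The names `naive_quantize` and `ste_quantize` and the sample values are illustrative assumptions, not code from this PR.

```python
import jax
import jax.numpy as jnp

def naive_quantize(w, threshold):
    # Plain ternary quantization with no custom gradient rule (hypothetical name).
    return jnp.where(jnp.abs(w) > threshold, jnp.sign(w), 0.0)

@jax.custom_vjp
def ste_quantize(w, threshold):
    # Same forward values as naive_quantize; only the backward rule differs.
    return naive_quantize(w, threshold)

def _ste_fwd(w, threshold):
    return naive_quantize(w, threshold), None  # no residuals needed

def _ste_bwd(_, g):
    return g, None  # identity gradient for w, zero cotangent for threshold

ste_quantize.defvjp(_ste_fwd, _ste_bwd)

w = jnp.array([-0.9, -0.1, 0.05, 0.7])  # illustrative weights
threshold = 0.3                          # illustrative threshold

# sign() and the comparison are piecewise constant, so this gradient is zero.
naive_grad = jax.grad(lambda w: naive_quantize(w, threshold).sum())(w)
# With the straight-through rule the upstream gradient reaches w unchanged.
ste_grad = jax.grad(lambda w: ste_quantize(w, threshold).sum())(w)
```

Under these assumptions `naive_grad` is all zeros, so an optimiser would never move the float parameters, while `ste_grad` is all ones.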
- STE via lax.stop_gradient: kernel_ste = kernel + sg(kernel_q - kernel), so gradients flow straight through the quantization boundary
- beta = mean(|kernel|) multiplied into the output when using auto threshold, restoring activation scale as weights become sparse

Without STE the quantization boundary kills gradients and the float params never update. Without beta the output magnitude shrinks as sparsity grows, destabilising training.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
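The two changes in this commit message can be sketched in pure JAX as follows. This is an illustration under stated assumptions, not the committed code: `ternary_dense_forward` is a hypothetical helper, the auto threshold is taken to be `beta` itself, and `beta` is left differentiable because the commit message does not say otherwise.

```python
import jax
import jax.numpy as jnp
from jax import lax

def ternary_dense_forward(kernel, x):
    # Hypothetical helper: auto threshold and output scale are both mean(|kernel|).
    beta = jnp.mean(jnp.abs(kernel))
    kernel_q = jnp.where(jnp.abs(kernel) > beta, jnp.sign(kernel), 0.0)
    # STE: the forward value equals kernel_q, the gradient is that of kernel.
    kernel_ste = kernel + lax.stop_gradient(kernel_q - kernel)
    # Rescale the ternary matmul so sparser kernels do not shrink the output.
    return (x @ kernel_ste) * beta

kernel = jnp.array([[0.8, -0.1], [-0.6, 0.05]])  # illustrative values
x = jnp.array([[1.0, 2.0]])

y = ternary_dense_forward(kernel, x)
grad_kernel = jax.grad(lambda k: ternary_dense_forward(k, x).sum())(kernel)
```

With these sample values the quantized kernel is `[[1, 0], [-1, 0]]` and `beta` is 0.3875, so `y` is approximately `[[-0.3875, 0.0]]`, and every entry of `grad_kernel` is nonzero, including the entries whose quantized weight is zero.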
|
@eriirfos-eng thanks for the PR! However, please keep in mind the process for contributing new features: first, open an issue to discuss the need; second, if it is approved, write the code and open a PR. Contributions of new features to linen are unlikely to land. Even for an nnx TernaryDense layer, we should first see whether several people ask for it, as we want to keep the library small and limited to essential things. |
|
@vfdev-5 — understood, thanks for the clear explanation of the process. I opened #5468 as the proper discussion issue for NNX TernaryDense. I'll wait there for maintainer feedback before writing any code. Since new linen features are unlikely to land and the right path is the issue-first process, I'll close this PR. If the discussion in #5468 leads to an approved direction I'll open a fresh, correctly-targeted PR then. |
Summary
Adds `nn.TernaryDense` to `flax/linen/linear.py` as a drop-in replacement for `nn.Dense` for ternary-weight models (BitNet b1.58, et al.).

Motivation
Ternary-weight models train directly in `{-γ, 0, +γ}`, with no post-training quantization step. On every forward pass the kernel is quantized on the fly to `{-1, 0, +1}`; the underlying float parameters remain in the variable collection so that gradient-based optimisers can update them between steps (enabling training, not just inference).

Design
The class mirrors `nn.Dense` exactly (same attributes, same `@compact` pattern, same `lax.dot_general` call) with one addition: a `threshold` attribute and a three-line quantization step before the matmul. The `threshold` setting controls the quantization:

- `threshold=None` (default): `mean(|kernel|)` per call, the BitNet b1.58 §3.1 rule (https://arxiv.org/abs/2402.17764).
- `threshold=0.0`: all weights become ±1, no zeros.
- A very large threshold (e.g. `1e9`): all-zero effective kernel.

The parameter shape and variable collection are identical to `nn.Dense`, so checkpoints from a `TernaryDense` layer can be loaded into a `Dense` of the same shape and vice versa.

Files changed
- `flax/linen/linear.py`: `TernaryDense` class after `Dense`
- `flax/linen/__init__.py`

Reference
🤖 Generated with Claude Code
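
To make the threshold modes listed under Design concrete, here is a minimal pure-JAX sketch of the quantization step. `quantize_kernel` is a hypothetical stand-in rather than the PR's method, and the strict `>` comparison is an assumption borrowed from the `ternary_quantize` snippet in the conversation.

```python
import jax.numpy as jnp

def quantize_kernel(kernel, threshold=None):
    # Hypothetical stand-in for the quantization step described under Design.
    if threshold is None:
        # Auto threshold, recomputed from the current kernel on every call.
        threshold = jnp.mean(jnp.abs(kernel))
    return jnp.where(jnp.abs(kernel) > threshold, jnp.sign(kernel), 0.0)

kernel = jnp.array([[0.8, -0.1], [-0.6, 0.05]])  # illustrative values

auto_q = quantize_kernel(kernel)         # threshold = mean(|kernel|) = 0.3875
dense_q = quantize_kernel(kernel, 0.0)   # every nonzero weight maps to ±1
empty_q = quantize_kernel(kernel, 1e9)   # no weight exceeds the threshold
```

For this sample kernel, `auto_q` is `[[1, 0], [-1, 0]]`, `dense_q` is `[[1, -1], [-1, 1]]`, and `empty_q` is all zeros. One edge case under the strict comparison: a weight that is exactly zero stays zero even with `threshold=0.0`.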