In [None]:
import jax.numpy as jnp
import numpy as np
import flax.linen as nn  # or nnx.Module if you prefer


class FiniteScalarQuantizer(nn.Module):
    def __init__(self, latent_dim, L, debug=False):
        """
        latent_dim : D (number of channels)
        L : iterable of ints, length D, number of levels per dimension
        """
        self.latent_dim = latent_dim
        self.L = jnp.array(L, dtype=jnp.int32)
        self.debug = debug

        # Compute FSQ basis (mixed radix)
        # basis[d] = product_{k>d} L[k]
        basis_list = []
        acc = 1
        for l in reversed(L):
            basis_list.append(acc)
            acc *= l
        self._basis = jnp.array(list(reversed(basis_list)), dtype=jnp.int32)

        # Store numpy version for easy mod in inverse
        self._levels_np = np.array(L, dtype=np.int32)

        if self.debug:
            print("Levels per dim:", self.L)
            print("Basis:", self._basis)

    # ----------------------------------------------------------------------
    # Scale & shift (FSQ uses centered integers)
    # ----------------------------------------------------------------------
    def _scale_and_shift(self, zhat):
        """
        zhat has integer quantized values in the range [-floor(L/2), floor(L/2)]
        Shift them to [0, L-1].
        """
        return zhat + jnp.floor(self.L / 2)

    def _scale_and_shift_inverse(self, zhat_shifted):
        """
        Reverse of scale_and_shift:
        input in [0, L-1]
        output in [-floor(L/2), floor(L/2)]
        """
        return zhat_shifted - jnp.floor(self.L / 2)

    # ----------------------------------------------------------------------
    # Your quantization function
    # ----------------------------------------------------------------------
    def round_matrix(self, x):
        # Straight-through estimator version (if needed)
        # return x + jax.lax.stop_gradient(jnp.round(x) - x)
        return jnp.round(x)

    def q_func(self, x):
        """
        x: [N, D]
        Returns unrounded quantization targets
        """
        L_broadcast = self.L[None, :]
        return jnp.tanh(x) * jnp.floor(L_broadcast / 2)

    def __call__(self, z):
        """
        z : [B, H, W, D]
        returns quantized z_q : [B, H, W, D]
        """
        B, H, W, D = z.shape
        assert D == self.latent_dim

        z_flat = z.reshape((B * H * W, D))
        z_q_flat = self.round_matrix(self.q_func(z_flat))
        z_q = z_q_flat.reshape((B, H, W, D))
        return z_q

    def codes_to_indexes(self, zhat):
        """
        zhat: [..., D] quantized integer codes (e.g., output of your quantizer)
        Returns: [...,] integer index for each vector
        """
        assert zhat.shape[-1] == self.latent_dim

        # shift from centered integers to [0, L-1]
        shifted = self._scale_and_shift(zhat)

        # mixed-radix flattening
        return (shifted * self._basis).sum(axis=-1).astype(jnp.uint32)

    def indexes_to_codes(self, indices):
        """
        indices: [...,] integer indices
        Returns: [..., D] quantized integer codes
        """
        # add trailing dimension
        idx = indices[..., jnp.newaxis]  # [..., 1]

        # mod & floor divide to extract mixed radix digits
        codes_non_centered = np.mod(
            np.floor_divide(idx, self._basis), 
            self._levels_np
        )

        # Shift back to centered integer grid
        return self._scale_and_shift_inverse(codes_non_centered)


In [10]:
# config
batch_size = 4
z_len = 3   # latent vector length


In [12]:
import jax.numpy as jnp

# -----------------------------------
# 1. Dummy latent vectors
# -----------------------------------

# Random encoded latent vectors
z = jnp.array([
    [0.1, 2.0, -3.0],
    [1.7, -0.5, 0.3],
    [5.0, 0.0, -1.0],
    [-2.3, 4.5, 2.2]
])

# -----------------------------------
# 2. Per-dimension L (now a column vector)
# Example: L = [7, 5, 9]
# -----------------------------------
L = jnp.array([[7, 5, 9]])   # shape = (1, z_len)

# -----------------------------------
# 3. Initialize FSQ
# -----------------------------------
fsq = FSQ(batch_size=batch_size, encoder_length=z_len, L=L, debug=True)

# -----------------------------------
# 4. Quantize
# -----------------------------------
z_hat = fsq.quantise(z)

print("\nInput z:\n", z)
print("\nQuantized z:\n", z_hat)

# -----------------------------------
# 5. Codebook usage stats
# -----------------------------------

# Total codebook size = product of all levels
total_codes = int(jnp.prod(L))
# Unique codes used in this batch
unique_codes = jnp.unique(z_hat, axis=0).shape[0]

usage_percentage = (unique_codes / total_codes) * 100

print("\n--- Codebook Usage Stats ---")
print(f"Total possible codes (|C|): {total_codes}")
print(f"Unique codes used in batch: {unique_codes}")
print(f"Codebook utilization: {usage_percentage:.2f}%")


The latent vector matrix should contain 4 vectors of size 3
Quantization level [[7 5 9]]
Each value of z is mapped from [[-3. -2. -4.]] to [[3. 2. 4.]]
Size of codebook : [[343 125 729]]

Input z:
 [[ 0.1  2.  -3. ]
 [ 1.7 -0.5  0.3]
 [ 5.   0.  -1. ]
 [-2.3  4.5  2.2]]

Quantized z:
 [[ 0.  2. -4.]
 [ 3. -1.  1.]
 [ 3.  0. -3.]
 [-3.  2.  4.]]

--- Codebook Usage Stats ---
Total possible codes (|C|): 315
Unique codes used in batch: 4
Codebook utilization: 1.27%
