In [None]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [9]:
# Class for FSQ with corresponding member functions
'''
FSQ Class

Input : Matrix of all encoded vectors of training dataset
input matrix : [batch_size, latent vector size (d)]

Output : Matrix of all quantized vectors
output matrix : [batch_size, 2⌊L/2⌋ + 1 (quantization-level)]

'''

class FSQ:
  def __init__(self, batch_size, encoder_length, L, debug=False):
    self.batch_size = batch_size
    self.z_len = encoder_length
    self.debug = debug
    self.level = L

    if self.debug:
      print(f"The latent vector matrix should contain {self.batch_size} vectors of size {self.z_len}")
      print(f"Quantization level {L}")
      print(f"Each value of z is mapped from {-jnp.floor(L/2)} to {jnp.floor(L/2)}")
      print(f"Size of codebook : {jnp.power(self.level, self.z_len)}")

      if (self.level.shape[1] != self.z_len):
        print("The quantisation levels and the latent vector length are not matching!")


  def round_matrix(self, matrix):
    return jnp.round(matrix)

  def q_func(self, x):
    # Quantisation function according to the FSQ paper
    # Broadcasting works automatically
    return jnp.tanh(x) * jnp.floor(self.level/2)

  def quantise(self, x):
    x_quantised = self.q_func(x)
    x_quatised_rounded = self.round_matrix(x_quantised)
    return x_quatised_rounded




In [10]:
# config
batch_size = 4
z_len = 3   # latent vector length


In [12]:
import jax.numpy as jnp

# -----------------------------------
# 1. Dummy latent vectors
# -----------------------------------

# Random encoded latent vectors
z = jnp.array([
    [0.1, 2.0, -3.0],
    [1.7, -0.5, 0.3],
    [5.0, 0.0, -1.0],
    [-2.3, 4.5, 2.2]
])

# -----------------------------------
# 2. Per-dimension L (now a column vector)
# Example: L = [7, 5, 9]
# -----------------------------------
L = jnp.array([[7, 5, 9]])   # shape = (1, z_len)

# -----------------------------------
# 3. Initialize FSQ
# -----------------------------------
fsq = FSQ(batch_size=batch_size, encoder_length=z_len, L=L, debug=True)

# -----------------------------------
# 4. Quantize
# -----------------------------------
z_hat = fsq.quantise(z)

print("\nInput z:\n", z)
print("\nQuantized z:\n", z_hat)

# -----------------------------------
# 5. Codebook usage stats
# -----------------------------------

# Total codebook size = product of all levels
total_codes = int(jnp.prod(L))
# Unique codes used in this batch
unique_codes = jnp.unique(z_hat, axis=0).shape[0]

usage_percentage = (unique_codes / total_codes) * 100

print("\n--- Codebook Usage Stats ---")
print(f"Total possible codes (|C|): {total_codes}")
print(f"Unique codes used in batch: {unique_codes}")
print(f"Codebook utilization: {usage_percentage:.2f}%")


The latent vector matrix should contain 4 vectors of size 3
Quantization level [[7 5 9]]
Each value of z is mapped from [[-3. -2. -4.]] to [[3. 2. 4.]]
Size of codebook : [[343 125 729]]

Input z:
 [[ 0.1  2.  -3. ]
 [ 1.7 -0.5  0.3]
 [ 5.   0.  -1. ]
 [-2.3  4.5  2.2]]

Quantized z:
 [[ 0.  2. -4.]
 [ 3. -1.  1.]
 [ 3.  0. -3.]
 [-3.  2.  4.]]

--- Codebook Usage Stats ---
Total possible codes (|C|): 315
Unique codes used in batch: 4
Codebook utilization: 1.27%
