In [1]:
import tensorflow as tfds
import numpy as np
import jax 
import flax.nnx
import jax.numpy as jnp

2025-03-16 17:15:15.592978: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-03-16 17:15:15.782238: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1742145315.858341   36331 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1742145315.880088   36331 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1742145316.036130   36331 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

In [7]:
class vqvae(flax.nnx.Module):
    """Small convolutional VQ-VAE for 32x32x3 NHWC images.

    Pipeline:
    * Three stride-2 convolutions encode the image to a grid of latent vectors,
      expected to be [B, 4, 4, latentDim] for [B, 32, 32, 3] inputs.
    * Each latent vector is replaced by its nearest codebook vector (squared
      Euclidean distance).
    * Three stride-2 transposed convolutions decode the quantized grid back to
      image space.
    """

    def __init__(self,
                 codebookSize=32,
                 *args,
                 latentDim=64,
                 seed=0,
                 **kwargs):
        """Build the encoder, decoder and codebook.

        Args:
            codebookSize: number of codebook vectors (K).
            *args, **kwargs: forwarded to the parent module initializer.
            latentDim: channel width of the encoder output and of every
                codebook vector (keyword-only; defaults to 64).
            seed: seed for the ``flax.nnx.Rngs`` that initializes every
                parameter, including the codebook (keyword-only; defaults to 0).
        """
        # Zero-arg super() keeps flax.nnx.Module itself in the MRO chain;
        # super(flax.nnx.Module, self) would start the lookup *after* it.
        super().__init__(*args, **kwargs)

        self.rngs = flax.nnx.Rngs(seed)
        self.codebookSize = codebookSize
        self.latentDim = latentDim

        # input size: [B, 32, 32, 3]
        # encoder
        self.conv1 = flax.nnx.Conv(3, 8, (3, 3), strides=2, rngs=self.rngs)  # [B, 16, 16, 8]
        self.conv2 = flax.nnx.Conv(8, 16, (3, 3), strides=2, rngs=self.rngs)  # [B, 8, 8, 16]
        self.conv3 = flax.nnx.Conv(16, latentDim, (3, 3), strides=2, rngs=self.rngs)  # [B, 4, 4, latentDim]

        # decoder
        self.tconv1 = flax.nnx.ConvTranspose(latentDim, 16, (3, 3), strides=2, rngs=self.rngs)  # [B, 8, 8, 16]
        self.tconv2 = flax.nnx.ConvTranspose(16, 8, (3, 3), strides=2, rngs=self.rngs)  # [B, 16, 16, 8]
        self.tconv3 = flax.nnx.ConvTranspose(8, 3, (3, 3), strides=2, rngs=self.rngs)  # [B, 32, 32, 3]

        # Codebook with shape [1, K, latentDim], orthogonally initialized. The key
        # comes from the module's own Rngs so that `seed` controls it as well.
        self.codeBook = flax.nnx.Param(
            jax.nn.initializers.orthogonal()(
                self.rngs.params(), (1, self.codebookSize, self.latentDim)
            )
        )

    def __call__(self, input, return_aux=False):
        """Encode, vector-quantize and decode a batch of images.

        Args:
            input: NHWC image batch (the layer comments assume [B, 32, 32, 3]).
            return_aux: if True, also return the VQ training terms (keyword
                usable; defaults to False, which returns only the reconstruction).

        Returns:
            ``out``, the decoded reconstruction. When ``return_aux`` is True,
            returns ``(out, aux)`` instead. ``aux`` has these keys:
                "codebook_loss":   mean((stop_grad(z_e) - z_q) ** 2), which moves codes.
                "commitment_loss": mean((z_e - stop_grad(z_q)) ** 2), which anchors the encoder.
                "indices":         chosen code index per latent position,
                                   shaped like the latent grid without its last axis.
        """
        act = jax.nn.relu

        # Encoding. The nonlinearities keep the conv stack from collapsing to a
        # single linear map. No activation is applied to the latents themselves.
        h = act(self.conv1(input))
        h = act(self.conv2(h))
        latents = self.conv3(h)  # z_e

        # Nearest-neighbour lookup in the codebook.
        codebook = jnp.reshape(self.codeBook.value, (self.codebookSize, -1))  # [K, D]
        flat = jnp.reshape(latents, (-1, 1, latents.shape[-1]))  # [N, 1, D]
        distances = jnp.sum((flat - codebook[None, :, :]) ** 2, axis=-1)  # [N, K]
        indices = jnp.argmin(distances, axis=-1)  # [N]: closest code, not farthest
        codes = jnp.reshape(jnp.take(codebook, indices, axis=0), latents.shape)  # z_q

        # Straight-through estimator: the forward value equals `codes`, but the
        # gradient w.r.t. `latents` is the identity. This lets the encoder learn
        # through the non-differentiable argmin.
        quantized = latents + jax.lax.stop_gradient(codes - latents)

        # Decoding.
        h = act(self.tconv1(quantized))
        h = act(self.tconv2(h))
        out = self.tconv3(h)

        if not return_aux:
            return out

        aux = {
            "codebook_loss": jnp.mean((jax.lax.stop_gradient(latents) - codes) ** 2),
            "commitment_loss": jnp.mean((latents - jax.lax.stop_gradient(codes)) ** 2),
            "indices": jnp.reshape(indices, latents.shape[:-1]),
        }
        return out, aux
        
# Instantiate the model with an explicit codebook size (32 is also the constructor default).
model = vqvae(codebookSize=32)