In [1]:
import tensorflow_datasets as tfds
import numpy as np
import jax 
import flax.nnx
import jax.numpy as jnp
import tensorflow as tf
import optax
from PIL import Image

2025-03-17 11:34:58.003638: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1742211298.029526   17281 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1742211298.039978   17281 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1742211298.063433   17281 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1742211298.063481   17281 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1742211298.063485   17281 computation_placer.cc:177] computation placer alr

In [2]:
# create the dataset
batchSize = 512
dsimg = tfds.load("beans", split='train', shuffle_files=True, batch_size=-1)['image'].numpy()
reImg = tf.image.resize(dsimg, [256,256])
dataset = tf.data.Dataset.from_tensor_slices(reImg)
# dataset = dataset.batch(batchSize, drop_remainder=True).repeat().shuffle(3, reshuffle_each_iteration=True).prefetch(tf.data.AUTOTUNE)
dataset = dataset.repeat().shuffle(3, reshuffle_each_iteration=True).batch(batchSize, drop_remainder=True).map(lambda x: tf.image.random_crop(value=x, size=(batchSize, 128, 128, 3)), num_parallel_calls=8).prefetch(8)
ds_iter = iter(dataset)

2025-03-17 11:35:01.219764: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:47] Overriding orig_value setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.
I0000 00:00:1742211301.219842   17281 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 5566 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 4060 Laptop GPU, pci bus id: 0000:01:00.0, compute capability: 8.9


In [3]:
class vqvae(flax.nnx.Module):
    def __init__(self, 
                 codebookSize = 64,
                 l2reg = 1e-4,
                 *args, **kwargs):
        super(flax.nnx.Module, self).__init__(*args, **kwargs)
        
        self.rngs = flax.nnx.Rngs(0)
        self.codebookSize = codebookSize
        self.activation = flax.nnx.swish
        self.l2reg = l2reg
        self.layers = []
        
        
        # input size: [-1, 128, 128, 3]
        # encoder
        self.conv1 = flax.nnx.Conv(3, 32, (3,3), strides=2, rngs=self.rngs) # [-1, 64, 64, 32]
        self.layers.append(self.conv1)
        self.conv2 = flax.nnx.Conv(32, 64, (3,3), strides=2, rngs=self.rngs) # [-1, 32, 32, 64]
        self.layers.append(self.conv2)
        self.conv3 = flax.nnx.Conv(64, 128, (3,3), strides=2, rngs=self.rngs) # [-1, 16, 16, 128]
        self.layers.append(self.conv3)
        
        # devoer
        self.tconv1 = flax.nnx.ConvTranspose(128, 64, (3,3), strides=2, rngs=self.rngs) # [-1, 32, 32, 64]
        self.layers.append(self.tconv1)
        self.tconv2 = flax.nnx.ConvTranspose(64, 32, (3,3), strides=2, rngs=self.rngs) # [-1, 64, 64, 32]
        self.layers.append(self.tconv2)
        self.tconv3 = flax.nnx.ConvTranspose(32, 3, (3,3), strides=2, rngs=self.rngs) # [-1, 128, 128, 3]
        self.layers.append(self.tconv3)
        
        # setting code book
        self.codeBook = flax.nnx.Param(
            # jax.nn.initializers.orthogonal()(jax.random.key(0), (1, self.codebookSize, 128))
            jax.nn.initializers.truncated_normal(0.05)(jax.random.key(0), (1, self.codebookSize, 128))
            )
        
    @flax.nnx.jit
    def __call__(self, input):
        
        # encoding
        d1 = self.activation(self.conv1(input))
        d2 = self.activation(self.conv2(d1))
        d3 = self.conv3(d2)
        
        # reshaping for code exchange
        candidateLatents = jnp.reshape(d3, [-1, 1, 128])
        # calculating distances
        euDis = jnp.sum((candidateLatents - self.codeBook) ** 2, axis = -1) # [-1, 16]
        activeIndex = jnp.argmin(euDis, axis=-1) # [-1]
        # replacing codes
        activeIndexOnehot = jax.nn.one_hot(activeIndex, self.codebookSize) # [-1, 16]
        replacedLatents = jnp.sum(
            jnp.reshape(activeIndexOnehot, [-1, self.codebookSize, 1]) * self.codeBook,
            axis = -2
        ) # [-1, 64]
        
        # commit and vq loss
        candidateLatens4Loss = jnp.reshape(candidateLatents, [-1, 128])
        commitLoss = jnp.mean(jnp.sum((jax.lax.stop_gradient(candidateLatens4Loss) - replacedLatents) ** 2, axis=-1) ** .5)  # commit loss
        vqLoss = jnp.mean(jnp.sum((candidateLatens4Loss - jax.lax.stop_gradient(replacedLatents)) ** 2, axis=-1) ** .5) *.25 # vq loss
        self.sow(flax.nnx.Intermediate, "commitLoss", commitLoss, reduce_fn=lambda x, y: y)
        self.sow(flax.nnx.Intermediate, "vqLoss", vqLoss, reduce_fn=lambda x, y: y)
    
    
        # reshaping replaced latents
        replacedLatents = jnp.reshape(replacedLatents, d3.shape)
        # straight throught estimated 
        replacedLatents = jax.lax.stop_gradient(replacedLatents - d3) + d3
        
        d4 = self.activation(self.tconv1(replacedLatents))
        d5 = self.activation(self.tconv2(d4))
        out = self.tconv3(d5)
        
        return out
    
    def l2Reg(self):
        regLoss = 0.
        for layer in self.layers:
            regLoss += jnp.sum(jax.tree_util.tree_leaves(layer.kernel)[0] ** 2)
            regLoss += jnp.sum(jax.tree_util.tree_leaves(layer.bias)[0] ** 2)
        return regLoss * self.l2reg
    
        
model = vqvae()
model(jnp.ones([5, 32, 32, 3]))
model.vqLoss

Intermediate( # 1 (4 B)
  value=Array(0.71752703, dtype=float32)
)

In [4]:
@flax.nnx.jit
def loss_fn(model, x):
    y = (jnp.array(x) / 255. ) - 1.
    y_hat = model(y)
    se = jnp.mean((y_hat - y) ** 2) 
    return (se + model.commitLoss + model.vqLoss + model.l2Reg()) 

In [5]:
learningRate = 1e-4

optChain = optax.chain(
   optax.clip_by_global_norm(1.0),
   optax.adamw(learningRate),
)
opt = flax.nnx.Optimizer(model, optChain)
grad_fn = flax.nnx.value_and_grad(loss_fn)

@flax.nnx.jit
def update_model_weights(model, y):
   loss, grads = grad_fn(model, y)
   opt.update(grads)
   return loss

In [None]:
trainingStep = 50000
for step in range(trainingStep):
    x = jnp.array(next(ds_iter))
    loss = update_model_weights(model, x)
    if step % 1000 == 0 :
        print("step:{}  loss:{}".format(step, loss))
        
        y = (jnp.array(x) / 255. ) - 1.
        y_hat = model(y)
        
        def give_img(x, name):
            pic = x
            # pic = tf.reshape(pic, [128,128,3])
            # pic = tf.reshape(pic, [3,128,128])
            # pic = tf.transpose(pic, [1,2,0])
            pic = (pic + 1) * 128
            pic = Image.fromarray(tf.cast(pic, tf.uint8).numpy())
            pic.save(name)
            
        give_img(y[0], 'bean.jpg')
        give_img(y_hat[0], 'bean_hat.jpg')
    

step:0  loss:3.8211207389831543
step:1000  loss:3.8620593547821045
step:2000  loss:3.9018325805664062
