In [24]:
from typing import List, Dict, Mapping, Tuple, Any

import chex
import jax
import jax.numpy as jnp
import jax.random as jrand
import flax.linen as nn
from flax.training import train_state  # Useful dataclass to keep train state
import optax
import tensorflow as tf
import pdb
import functools

def println(*args):
  for arg in args:
    print(arg)


In [2]:
# import jax.tools.colab_tpu
# jax.tools.colab_tpu.setup_tpu()
# jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [4]:
DEVICE_COUNT = len(jax.devices())
DEVICE_COUNT

1

In [5]:
BATCH_SIZE = 16

## Dataset pipline

In [6]:
import tensorflow as tf

# Load the CIFAR-10 dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Normalize the pixel values
x_train, x_test = ((x_train / 127.5) - 1), ((x_test / 127.5) -1)

# Convert the labels to one-hot encoding
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)

def create_dataset():
  # Create a TensorFlow data pipeline for the training set
  train_dataset = (
      tf.data.Dataset
      .from_tensor_slices((x_train, y_train))
      .repeat()
      .shuffle(buffer_size=5000)
      .batch(BATCH_SIZE)
      .prefetch(tf.data.AUTOTUNE)
      .as_numpy_iterator())

  # Create a TensorFlow data pipeline for the test set.
  test_dataset = (
      tf.data.Dataset
      .from_tensor_slices((x_test, y_test))
      .batch(BATCH_SIZE)
      .prefetch(tf.data.AUTOTUNE)
      .as_numpy_iterator())
  return train_dataset, test_dataset

def get_batch(dataset):
  images, labels = next(dataset)

  images, labels = jnp.array(images), jnp.array(labels)
  # images = jnp.reshape(images, (BATCH_SIZE, -1)) # flatten the images
  return images, labels

train_dataset, test_dataset = create_dataset()

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


### test get_batch

In [10]:
test_images, test_labels = get_batch(train_dataset)
test_images.shape, test_labels.shape

((16, 28, 28), (16, 10))

## Modeling

In [11]:
class Autoencoder(nn.Module):
    num_embeddings: int = 3
    embedding_dim: int = 2
    beta: float = 0.2

    def setup(self):
        self.pre_quant_conv = nn.Conv(features=2, kernel_size=(1,), padding='SAME')
        self.embedding = nn.Embed(num_embeddings=self.num_embeddings, features=self.embedding_dim)
        self.post_quant_conv = nn.Conv(features=4, kernel_size=(1,), padding='SAME')

    @nn.compact
    def __call__(self, x, training: bool = True):
        # Encoder
        x = nn.Conv(features=16, kernel_size=(4,), strides=(2,), padding='SAME')(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.relu(x)
        x = nn.Conv(features=4, kernel_size=(4,), strides=(2,), padding='SAME')(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        encoded_output = nn.relu(x)

        quant_input = self.pre_quant_conv(encoded_output)

        ## Quantization
        B, H, W, C = quant_input.shape
        quant_input = quant_input.reshape((B, H * W, C))

        # Compute pairwise distances and find index of nearest embedding
        min_encoding_indices = self._dist_batch(quant_input, self.embedding.embedding)

        # Select the embedding weights
        quant_out = self.embedding(min_encoding_indices)

        # Compute losses
        commitment_loss = jnp.mean((quant_out - quant_input)**2)
        codebook_loss = jnp.mean((quant_out - jax.lax.stop_gradient(quant_input))**2)
        quantize_losses = codebook_loss + self.beta * commitment_loss

        # Ensure straight through gradient
        quant_out = quant_input + jax.lax.stop_gradient(quant_out - quant_input)

        # Reshaping back to original input shape
        quant_out = quant_out.reshape((B, H, W, C))

        ## Decoder part
        decoder_input = self.post_quant_conv(quant_out)
        x = nn.ConvTranspose(features=16, kernel_size=(4,), strides=(2,), padding='SAME')(decoder_input)
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.relu(x)
        output = nn.ConvTranspose(features=1, kernel_size=(4,), strides=(2,), padding='SAME')(x)
        output = nn.tanh(output)

        return output, quantize_losses

    def _dist(self, quant_input_single_row, embedding_table):
      distances = jnp.sum((quant_input_single_row - embedding_table)**2, axis=-1)
      min_index = jnp.argmin(distances)
      return min_index

    def _dist_batch(self, quant_input_batch, embedding_table):
      quant_input_fn = jax.vmap(self._dist, in_axes=(0, None), out_axes=(0))
      quant_input_batch_fn = jax.vmap(quant_input_fn, in_axes=(0, None), out_axes=(0))
      return quant_input_batch_fn(quant_input_batch, embedding_table)


In [25]:
class TrainState(train_state.TrainState):
  key: jax.random.KeyArray
  batch_stats: Any

random_key = jax.random.PRNGKey(99)
random_key, random_subkey = jax.random.split(random_key)

model = Autoencoder()

test_image, test_label = test_images[0], test_labels[0]
test_image = test_image[jnp.newaxis, :, :, jnp.newaxis] # B, H, W, C
(output, quantize_losses), params = model.init_with_output(jrand.PRNGKey(99), test_image, training=True)
batch_stats = params["batch_stats"]
params = params["params"]

For more information, see https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html
  key: jax.random.KeyArray


In [30]:
def model_apply(params, batch_stats, inputs, training):
  return model.apply({"params": params, "batch_stats": batch_stats}, inputs, training=training)

model_apply_batch = jax.vmap(model_apply, in_axes=(None, None, 0, None), out_axes=(0))

def forward_pass(params, batch_stats, state, batch):
  inputs, _ = batch # you are using reconstruction loss, so NO need of targets.
  reconstructed, new_batch_stats, = state.apply_fn(
      params,
      batch_stats,
      inputs,
      True, # training
      mutable=['batch_stats'],
    )

  PER_HOST_BATCH_SIZE = BATCH_SIZE // jax.device_count()

  reconstructed = jnp.reshape(reconstructed, (PER_HOST_BATCH_SIZE, -1)) # flatten the reconstructed to (32, 784)

  chex.assert_shape(inputs, (PER_HOST_BATCH_SIZE, 784))
  chex.assert_shape(reconstructed, (PER_HOST_BATCH_SIZE, 784))

  loss = (inputs - reconstructed) ** 2
  loss = loss.mean()
  return loss, new_batch_stats

def train_step(state, inputs, targets):
  batch = inputs, targets
  grad_fn = jax.value_and_grad(forward_pass, argnums=(0))  # differentiate wrt 0th pos argument.
  params = state.params
  (loss, new_batch_stats), grads = grad_fn(state.params, state.batch_stats, state, batch)

  loss = jax.lax.pmean(loss, axis_name="devices")
  grads = jax.lax.pmean(grads, axis_name="devices")

  state = state.apply_gradients(grads=grads, batch_stats=new_batch_stats)
  return state, loss

opt = optax.adam(learning_rate=0.001)
state = TrainState.create(apply_fn=model_apply_batch, params=params, tx=opt, key=random_key, batch_stats=batch_stats)

In [31]:
# pmap the train_step.
train_step_pmap = jax.pmap(train_step, in_axes=(0, 0, 0), out_axes=(0), axis_name="devices")

In [32]:
# replicate state
states = jax.device_put_replicated(state, jax.local_devices())

In [33]:
# Training loop
num_epochs = 20
steps_per_epoch = len(x_train) // BATCH_SIZE

for epoch in range(num_epochs):
    print("epoch: ", epoch)
    train_dataset, _ = create_dataset()

    for step in range(steps_per_epoch):
        inputs, targets = get_batch(train_dataset)

        # create device dimension for minibatch
        inputs = inputs.reshape((jax.device_count(), -1, 28, 28, 1))
        targets = targets.reshape((jax.device_count(), -1, 10))

        states, loss = train_step_pmap(states, inputs, targets)
        print("loss", loss[0]) if step%100==0 else None

epoch:  0


ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())

In [None]:
state = jax.tree_map(lambda x: x[0], states)

In [None]:
import matplotlib.pyplot as plt

def plot_reconstructions(model, params, batch, n=10):
    inputs, _ = batch
    reconstructed = state.apply_fn(params, inputs, False)
    fig, axes = plt.subplots(2, n, figsize=(n * 2, 4))
    for i in range(n):
        axes[0, i].imshow(inputs[i].reshape(28, 28), cmap='gray')
        axes[0, i].axis('off')
        axes[1, i].imshow(reconstructed[i].reshape(28, 28), cmap='gray')
        axes[1, i].axis('off')
    plt.show()

# Visualize some reconstructionsmodel_apply_batch
train_dataset, test_dataset = create_dataset()
plot_reconstructions(model, state.params, get_batch(test_dataset))
