In [2]:
from typing import List, Dict, Mapping, Tuple

import chex
import jax
import jax.numpy as jnp
import jax.random as jrand
import flax.linen as nn
from flax.training import train_state  # Useful dataclass to keep train state
import optax
import tensorflow as tf
import pdb
import functools

def println(*args):
  for arg in args:
    print(arg)


In [None]:
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [3]:
DEVICE_COUNT = len(jax.devices())
DEVICE_COUNT

1

In [10]:
BATCH_SIZE = 32

## Dataset pipline

In [11]:
import tensorflow as tf

# Load the CIFAR-10 dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Normalize the pixel values
x_train, x_test = x_train / 255.0, x_test / 255.0

# Convert the labels to one-hot encoding
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)

def create_dataset():
  # Create a TensorFlow data pipeline for the training set
  train_dataset = (
      tf.data.Dataset
      .from_tensor_slices((x_train, y_train))
      .repeat()
      .shuffle(buffer_size=5000)
      .batch(BATCH_SIZE)
      .prefetch(tf.data.AUTOTUNE)
      .as_numpy_iterator())

  # Create a TensorFlow data pipeline for the test set.
  test_dataset = (
      tf.data.Dataset
      .from_tensor_slices((x_test, y_test))
      .batch(BATCH_SIZE)
      .prefetch(tf.data.AUTOTUNE)
      .as_numpy_iterator())
  return train_dataset, test_dataset

def get_batch(dataset):
  images, labels = next(dataset)

  images, labels = jnp.array(images), jnp.array(labels)
  images = jnp.reshape(images, (BATCH_SIZE, -1)) # flatten the images
  return images, labels

train_dataset, test_dataset = create_dataset()

### test get_batch

In [12]:
test_images, test_labels = get_batch(train_dataset)
test_images.shape, test_labels.shape

((32, 784), (32, 10))

In [13]:
test_image, est_label = test_images[0], test_labels[0]
test_image.shape

(784,)

## Modeling

In [52]:
class Autoencoder(nn.Module):

  def setup(self):
    self._encoder = nn.Sequential([
        nn.Dense(128), # 784->128
        nn.relu,
        nn.Dense(64), # 128->64
        nn.relu,
        nn.Dense(12), # 64->12
        nn.relu,
        nn.Dense(3) # 12->3
    ])

    self._decoder = nn.Sequential([
        nn.Dense(12), # 3->12
        nn.relu,
        nn.Dense(64), # 12->64
        nn.relu,
        nn.Dense(128), # 64->128
        nn.relu,
        nn.Dense(784), # 128->784
        nn.sigmoid
    ])

  def __call__(self, x):
    encoded = self._encoder(x)
    decoded = self._decoder(encoded)
    return decoded

In [53]:
model = Autoencoder()

In [91]:
params = model.init(jrand.PRNGKey(99), jnp.zeros(shape=(1, 784)))

In [92]:
train_dataset, test_dataset = create_dataset()

In [97]:
images, labels = get_batch(train_dataset)

In [98]:
opt = optax.adam(learning_rate=0.001)

In [99]:
recons = model.apply(params, images)

In [101]:
state = train_state.TrainState.create(apply_fn=model.apply,
                                      params=params["params"],
                                      tx=opt)

In [114]:
def mse_loss(input, recon):
  return jnp.mean((input - recon)**2)

In [121]:
def train_step(params, model, state, batch):
  def _compute_loss(params):
    inputs, _ = batch
    recons = model.apply({"params": params}, inputs)
    recon_losses = jax.vmap(mse_loss)(inputs, recons)
    loss = jnp.mean(recon_losses) # mean loss across batch
    return loss

  grad_fn = jax.value_and_grad(_compute_loss)
  loss, grads = grad_fn(params)
  state = state.apply_gradients(grads=grads)
  return loss, state

In [116]:
batch = get_batch(train_dataset)

In [122]:
train_step(params["params"], model, state, batch)

(Array(0.23183772, dtype=float32),
 TrainState(step=1, apply_fn=<bound method Module.apply of Autoencoder()>, params={'_decoder': {'layers_0': {'bias': Array([-0.00099997,  0.        , -0.00099993,  0.        , -0.00099999,
         0.00099998, -0.00099987,  0.00099852,  0.        , -0.00099987,
         0.        ,  0.00099999], dtype=float32), 'kernel': Array([[-0.24912775, -0.40926808, -0.2056664 , -0.20552501, -0.24634708,
          0.49280736,  0.23000881, -0.5782363 , -0.66836035,  0.23709968,
         -0.456887  ,  0.55835885],
        [-0.17653255,  0.9846388 ,  0.24080473,  0.02505654, -0.85239697,
          0.09337641,  0.363004  , -0.07419533, -0.50320196,  0.39220592,
         -0.45173275, -0.7583791 ],
        [ 0.26127473,  0.02490273,  0.5945362 , -0.54837614, -0.24334133,
          0.35193622,  0.41233945,  1.0764588 , -0.11943952, -1.2236131 ,
         -0.8682312 ,  0.12213197]], dtype=float32)}, 'layers_2': {'bias': Array([ 0.00099998,  0.        ,  0.        ,  0.000

In [None]:
jnp.mean()