In [37]:
from typing import List, Dict, Mapping, Tuple

import chex
import jax
import jax.numpy as jnp
import jax.random as jrand
import flax.linen as nn
from flax.training import train_state  # Useful dataclass to keep train state
import optax
import tensorflow as tf
import pdb
import functools

def println(*args):
  """Print every positional argument on its own line.

  Prints nothing at all when called without arguments.
  """
  if args:
    print(*args, sep="\n")


In [38]:
# import jax.tools.colab_tpu
# jax.tools.colab_tpu.setup_tpu()
# jax.devices()

In [39]:
# Number of devices JAX reports; the bare name on the last line displays it.
DEVICE_COUNT = len(jax.devices())
DEVICE_COUNT

1

In [40]:
# Examples per batch produced by the tf.data pipelines; also used to derive
# the number of training steps per epoch.
BATCH_SIZE = 96

## Dataset pipeline

In [41]:
import tensorflow as tf

# Load the MNIST dataset (train/test splits of images and integer labels)
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Normalize the pixel values by dividing by 255
x_train, x_test = x_train / 255.0, x_test / 255.0

# Convert the labels to one-hot encoding with 10 classes
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)

def create_dataset():
  """Build fresh NumPy iterators over the training and test data.

  Reads the module-level `x_train`, `y_train`, `x_test`, `y_test` arrays and
  `BATCH_SIZE`.

  Returns:
    A `(train_iterator, test_iterator)` tuple. The training iterator repeats
    endlessly and shuffles through a 5000-element buffer before batching; the
    test iterator makes a single in-order pass. Both yield `(images, labels)`
    batches of up to `BATCH_SIZE` examples.
  """
  def _batched_iterator(pipeline):
    # Shared tail of both pipelines: batch, prefetch, expose as NumPy iterator.
    batched = pipeline.batch(BATCH_SIZE)
    prefetched = batched.prefetch(tf.data.AUTOTUNE)
    return prefetched.as_numpy_iterator()

  # Training stream: repeat first, then shuffle, exactly in that order.
  train_source = tf.data.Dataset.from_tensor_slices((x_train, y_train))
  train_stream = train_source.repeat().shuffle(buffer_size=5000)
  train_iterator = _batched_iterator(train_stream)

  # Test stream: no repeat and no shuffle.
  test_source = tf.data.Dataset.from_tensor_slices((x_test, y_test))
  test_iterator = _batched_iterator(test_source)

  return train_iterator, test_iterator

def get_batch(dataset):
  """Fetch the next batch from an iterator and flatten every image to a vector.

  Args:
    dataset: Iterator yielding `(images, labels)` pairs of array-likes whose
      leading axis is the batch axis.

  Returns:
    A tuple `(images, labels)` of JAX arrays. `images` has shape
    `(batch, features)` with all non-batch axes flattened; `labels` is only
    converted to a JAX array.

  Raises:
    StopIteration: If `dataset` has no more batches.
  """
  images, labels = next(dataset)

  images, labels = jnp.array(images), jnp.array(labels)
  # Flatten using the batch's own leading dimension instead of a fixed global
  # batch size, so a smaller final batch (the test pipeline does not drop the
  # remainder) reshapes correctly instead of raising.
  images = jnp.reshape(images, (images.shape[0], -1))
  return images, labels

# Build the train/test iterators used by the cells below.
train_dataset, test_dataset = create_dataset()

### test get_batch

In [42]:
# Pull one training batch and display the shapes of its images and labels.
test_images, test_labels = get_batch(train_dataset)
test_images.shape, test_labels.shape

((96, 784), (96, 10))

In [43]:
# Take the first example of the batch and display the shape of its image.
# (`test_label` was previously misspelled as `est_label`.)
test_image, test_label = test_images[0], test_labels[0]
test_image.shape

(784,)

## Modeling

In [44]:
class Autoencoder(nn.Module):
  """Fully connected autoencoder with a 3-unit bottleneck.

  The encoder is a stack of Dense layers of width 128, 64, 12 and 3 with ReLU
  between them; the decoder mirrors it with widths 12, 64, 128 and 784 and
  ends in a sigmoid.
  """

  def setup(self):
    """Create the encoder and decoder sub-networks."""
    # Encoder: input -> 128 -> 64 -> 12 -> 3 (no activation on the code).
    encoder_layers = [
        nn.Dense(128),
        nn.relu,
        nn.Dense(64),
        nn.relu,
        nn.Dense(12),
        nn.relu,
        nn.Dense(3),
    ]
    self._encoder = nn.Sequential(encoder_layers)

    # Decoder: 3 -> 12 -> 64 -> 128 -> 784, squashed by a sigmoid.
    decoder_layers = [
        nn.Dense(12),
        nn.relu,
        nn.Dense(64),
        nn.relu,
        nn.Dense(128),
        nn.relu,
        nn.Dense(784),
        nn.sigmoid,
    ]
    self._decoder = nn.Sequential(decoder_layers)

  def __call__(self, x):
    """Encode `x` to the bottleneck and return its reconstruction."""
    return self._decoder(self._encoder(x))

In [45]:
# Instantiate the model and initialise its parameters from a fixed PRNG seed,
# using an all-zeros dummy input of shape (1, 784).
model = Autoencoder()
params = model.init(jrand.PRNGKey(99), jnp.zeros(shape=(1, 784)))["params"]
# Adam optimiser with a learning rate of 1e-3.
opt = optax.adam(learning_rate=0.001)

In [46]:
# Rebuild the data iterators and draw one training batch to try the model on.
train_dataset, test_dataset = create_dataset()
images, labels = get_batch(train_dataset)

In [47]:
# Forward pass of the sampled batch through the model with the initial params.
recons = model.apply({"params": params}, images)

In [48]:
# Bundle the model's apply function, the parameters and the optimiser into a
# single TrainState object that the training step updates.
state = train_state.TrainState.create(apply_fn=model.apply,
                                      params=params,
                                      tx=opt)

In [49]:
def mse_loss(input, recon):
  """Return the mean squared error between `input` and `recon`.

  The squared difference is averaged over every element, giving a scalar.
  """
  residual = input - recon
  return jnp.mean(jnp.square(residual))

In [50]:
@jax.jit
def train_step(params, state, batch):
  """Run one gradient-descent update on a single batch.

  The function is compiled with `jax.jit` so the forward pass, backward pass
  and optimiser update run as one compiled computation instead of op by op.
  This requires `state` to be a JAX pytree (as `flax.training.train_state.
  TrainState` is).

  Args:
    params: Parameter pytree at which the loss and its gradients are
      evaluated. Callers in this notebook pass `state.params`.
    state: Train state providing `apply_fn` and `apply_gradients`.
    batch: `(inputs, labels)` pair; the labels are ignored because the
      reconstruction target is the input itself.

  Returns:
    A `(new_state, loss)` tuple: the state after applying the gradients and
    the scalar reconstruction loss computed before the update.
  """
  inputs, _ = batch

  def _compute_loss(params):
    recons = state.apply_fn({"params": params}, inputs)
    # Per-example MSE, then averaged across the batch.
    recon_losses = jax.vmap(mse_loss)(inputs, recons)
    return jnp.mean(recon_losses)

  loss, grads = jax.value_and_grad(_compute_loss)(params)
  state = state.apply_gradients(grads=grads)
  return state, loss

In [51]:
# Run a single training step on the sampled batch.
state, loss = train_step(params, state, (images, labels))

In [None]:
# Training loop
num_epochs = 20
steps_per_epoch = len(x_train) // BATCH_SIZE

# The training iterator repeats indefinitely, so build it once instead of
# re-creating both pipelines (and re-copying the arrays into TensorFlow) at the
# start of every epoch.
train_dataset, _ = create_dataset()

for epoch in range(num_epochs):
    print("epoch: ", epoch)

    for step in range(steps_per_epoch):
        batch = get_batch(train_dataset)

        # Gradients are evaluated at the state's current parameters.
        params = state.params
        state, loss = train_step(params, state, batch)
        # Report the loss every 100 steps.
        if step % 100 == 0:
            print("loss", loss)


epoch:  0
loss 0.22952096
loss 0.064420894
loss 0.06163464
loss 0.052448533
loss 0.048550844
loss 0.045323335
loss 0.04582125
epoch:  1
loss 0.04300247
loss 0.04237263
loss 0.044522047
loss 0.04462269
loss 0.039602436
loss 0.042121753
loss 0.0379778
epoch:  2
loss 0.038252622
loss 0.042860053
