<a href="https://colab.research.google.com/github/jsk245/Resnet_JAX/blob/'main'/Resnet_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install dm-haiku optax

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import haiku as hk
import jax
import optax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds

# TensorFlow is only used for the input pipeline here: switch on v2 behaviour
# and hide every GPU from it.
tf.enable_v2_behavior()
tf.config.set_visible_devices([], device_type='GPU')

# Report the library versions this notebook is running against.
for template, module in (
    ("JAX version {}", jax),
    ("Haiku version {}", hk),
    ("TF version {}", tf),
):
  print(template.format(module.__version__))

JAX version 0.3.14
Haiku version 0.0.7
TF version 2.8.2


In [None]:
# Directory handed to tfds.load as data_dir.
data_dir = '/tmp/tfds'

# Load the CIFAR-100 training split. No batch_size is passed, so per the
# tfds.load documentation this is a tf.data.Dataset of per-example dicts;
# make_dataset below reads its "image" and "label" entries.
cifar100_data = tfds.load("cifar100", split="train", data_dir=data_dir)

def make_dataset(batch_size, seed=1):
  """Build endless, shuffled training batches of CIFAR-100 images and labels.

  Images and labels travel through a single tf.data pipeline as
  (image, label) pairs, so every image batch is matched with its own label
  batch by construction instead of relying on two separately shuffled
  pipelines producing the same order. The data is also mapped and cached
  once rather than twice.

  Args:
    batch_size: Number of examples per batch. The shuffle buffer holds
      10 * batch_size examples.
    seed: Seed passed to the dataset shuffle.

  Returns:
    A tuple (images, labels) of iterators over NumPy batches. Image batches
    are float32 scaled to [-1, 1]; label batches are the dataset's "label"
    values. Batch k of one iterator always belongs to batch k of the other.
  """
  import itertools

  def _preprocess(sample):
    # Convert to floats in [0, 1].
    image = tf.image.convert_image_dtype(sample["image"], tf.float32)
    # Scale the data to [-1, 1] to stabilize training, and keep the label
    # attached to its image.
    return 2.0 * image - 1.0, sample["label"]

  ds = cifar100_data.map(map_func=_preprocess,
                         num_parallel_calls=tf.data.experimental.AUTOTUNE)
  ds = ds.cache()
  ds = ds.shuffle(10 * batch_size, seed=seed).repeat().batch(batch_size)

  # Split the stream of (images, labels) batches into two iterators to keep
  # the original return shape. tee buffers whatever one side has consumed and
  # the other has not, so callers should advance both together.
  image_stream, label_stream = itertools.tee(iter(tfds.as_numpy(ds)))
  images = (batch[0] for batch in image_stream)
  labels = (batch[1] for batch in label_stream)
  return (images, labels)

In [None]:
class ConvAndBatchNormModule(hk.Module):
  """A 2-D convolution immediately followed by batch normalisation.

  Args:
    is_training: Forwarded to the batch-norm layer on every call.
    outchannels: Number of output channels of the convolution.
    kernel_size: Kernel shape of the convolution.
    stride: Stride of the convolution.
    name: Optional Haiku module name.
  """

  def __init__(self, is_training, outchannels, kernel_size, stride, name=None):
    super().__init__(name=name)
    self.is_training = is_training
    self.conv = hk.Conv2D(output_channels=outchannels,
                          kernel_shape=kernel_size,
                          stride=stride)
    # "jax.vmap" is the axis_name that forward() gives to jax.vmap, so the
    # batch-norm layer is told which mapped axis it is running under.
    self.bn = hk.BatchNorm(create_scale=True,
                           create_offset=True,
                           decay_rate=0.9,
                           cross_replica_axis="jax.vmap")

  def __call__(self, x):
    """Apply the convolution, then batch norm, and return the result."""
    return self.bn(self.conv(x), self.is_training)

class ResModule(hk.Module):
  """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 conv+BN plus a skip path.

  The first two convolutions use inchannels // 4 channels and the last one
  returns to inchannels. With adjust_dimension=True both widths are doubled,
  the 3x3 convolution uses stride 2, and the skip path is routed through
  ResDimensionHelper so it can still be added to the main path.

  Args:
    is_training: Forwarded to every ConvAndBatchNormModule.
    inchannels: Channel count of the block's input.
    adjust_dimension: Whether this block downsamples and widens.
    name: Optional Haiku module name.
  """

  def __init__(self, is_training, inchannels, adjust_dimension, name=None):
    super(ResModule, self).__init__(name=name)
    outchannels = inchannels // 4

    self.dimensionHelper = None
    if adjust_dimension:
      self.dimensionHelper = ResDimensionHelper()
      outchannels = 2 * outchannels
      inchannels = 2 * inchannels

    # Only the downsampling variant strides the 3x3 convolution.
    conv2_stride = 2 if adjust_dimension else 1

    self.conv1 = ConvAndBatchNormModule(is_training, outchannels, 1, 1)
    self.conv2 = ConvAndBatchNormModule(is_training, outchannels, 3, conv2_stride)
    self.conv3 = ConvAndBatchNormModule(is_training, inchannels, 1, 1)

  def __call__(self, x):
    """Run the bottleneck path, add the skip connection, and apply ReLU."""
    x_res = x
    x = jax.nn.relu(self.conv1(x))
    x = jax.nn.relu(self.conv2(x))
    x = self.conv3(x)
    # Identity check: the helper is either a module or the None sentinel.
    if self.dimensionHelper is not None:
      x_res = self.dimensionHelper(x_res)
    x = x + x_res
    x = jax.nn.relu(x)
    return x

class ResDimensionHelper(hk.Module):
  """Reshapes a skip connection for a downsampling residual block.

  Max-pools with stride 2 along the first two axes and zero-pads the third
  axis on both sides. The input is indexed as a rank-3 array; forward()
  applies the network under jax.vmap, so this is presumably one
  (height, width, channels) example -- confirm against the caller.
  """

  def __init__(self, name=None):
    super().__init__(name=name)
    self.maxPool = hk.MaxPool(window_shape=2, strides=[2, 2, 1], padding="SAME")

  def __call__(self, x):
    """Return x pooled along axes 0 and 1 and padded along axis 2."""
    # Half the current size of axis 2 is added on each side of that axis.
    side_pad = x.shape[2] // 2
    pooled = self.maxPool(x)
    return jnp.pad(pooled, ((0, 0), (0, 0), (side_pad, side_pad)))

class DownSampleModule(hk.Module):
  """Network stem: explicit padding, a 7x7 stride-2 conv, then max pooling."""

  def __init__(self, name=None):
    super().__init__(name=name)
    self.conv = hk.Conv2D(output_channels=64, kernel_shape=7, stride=2,
                          padding="VALID")
    self.maxPool = hk.MaxPool(window_shape=3, strides=[2, 2, 1], padding="SAME")

  def __call__(self, x):
    """Pad axes 0 and 1 by 3 on each side, convolve, and max-pool."""
    # The convolution uses VALID padding, so the border is added by hand;
    # the last axis is left untouched.
    border = (3, 3)
    padded = jnp.pad(x, (border, border, (0, 0)))
    return self.maxPool(self.conv(padded))

class GlobalPoolAndFCModule(hk.Module):
  """Classifier head: average-pool over axes 0 and 1, flatten, then linear.

  Args:
    goal_num_classes: Output size of the final linear layer.
    name: Optional Haiku module name.
  """

  def __init__(self, goal_num_classes, name=None):
    super().__init__(name=name)
    self.flatten = hk.Flatten(preserve_dims=-3)
    self.linear = hk.Linear(goal_num_classes)

  def __call__(self, x):
    """Return the linear layer's output for the pooled, flattened input."""
    # The pooling window spans the whole of axes 0 and 1, so each of those
    # axes collapses to a single position.
    rows, cols = x.shape[0], x.shape[1]
    pooled = hk.avg_pool(x, window_shape=[rows, cols, 1], strides=1,
                         padding='VALID')
    return self.linear(self.flatten(pooled))

def softmax_cross_entropy(logits, labels):
  """Mean softmax cross-entropy between logits and integer class labels.

  Args:
    logits: Array whose last axis holds one score per class.
    labels: Integer class indices, one-hot encoded here to the size of the
      last axis of logits.

  Returns:
    The cross-entropy averaged over all examples.
  """
  num_classes = logits.shape[-1]
  log_probs = jax.nn.log_softmax(logits)
  targets = jax.nn.one_hot(labels, num_classes)
  per_example = -jnp.sum(log_probs * targets, axis=-1)
  return jnp.mean(per_example)

def accuracy(logits, labels):
  """Fraction of examples whose arg-max along axis 1 equals the label.

  Args:
    logits: Array with class scores along axis 1.
    labels: Integer class indices to compare the predictions against.

  Returns:
    The mean of the element-wise prediction == label comparison.
  """
  predicted = jnp.argmax(logits, axis=1)
  hits = predicted == labels
  return jnp.mean(hits)


In [None]:
def forward(data, labels, is_training):
  """Build the network, run it on a batch, and score it against labels.

  Args:
    data: Batch of inputs; the network is mapped over its leading axis.
    labels: Integer class labels for the batch.
    is_training: Passed to every ResModule.

  Returns:
    A dict with the batch's "loss" and "accuracy".
  """
  # (inchannels, adjust_dimension) for each residual block, in network order.
  block_specs = (
      [(256, False)] * 3 + [(256, True)]
      + [(512, False)] * 3 + [(512, True)]
      + [(1024, False)] * 5 + [(1024, True)]
      + [(2048, False)] * 2
  )

  # Modules are constructed strictly in network order: stem, 1x1 projection,
  # residual blocks, classifier head.
  layers = [DownSampleModule(), hk.Conv2D(256, 1, 1)]
  layers.extend(ResModule(is_training, channels, widen)
                for channels, widen in block_specs)
  layers.append(GlobalPoolAndFCModule(100))
  network = hk.Sequential(layers)

  # The network is written for a single example; vmap maps it over the batch
  # under the axis name that the batch-norm layers refer to.
  logits = jax.vmap(network, axis_name="jax.vmap")(data)
  loss = softmax_cross_entropy(logits, labels)
  acc = accuracy(logits, labels)
  return {"loss": loss, "accuracy": acc}


In [None]:
# Optimiser step size and the endless training batch iterators.
learning_rate = 0.01
input_data, input_labels = make_dataset(1024, seed=1)

# The test split is loaded and normalised in the next cell. The raw,
# unnormalised copy that used to be loaded here was overwritten there before
# any use, so it is no longer fetched.

# Rebinding `forward` to its transformed version is not idempotent: running
# this cell a second time would try to transform an already transformed
# object. Only wrap the plain function.
if not isinstance(forward, hk.TransformedWithState):
  forward = hk.transform_with_state(forward)
optimizer = optax.adam(learning_rate)

Instructions for updating:
Use `tf.data.Dataset.get_single_element()`.


Instructions for updating:
Use `tf.data.Dataset.get_single_element()`.


In [None]:
def preprocess(sample):
    """Map an image tensor to float32 values in [-1, 1].

    This is the same scaling make_dataset applies to the training images.
    """
    # Convert to floats in [0, 1].
    unit_range = tf.image.convert_image_dtype(sample, tf.float32)
    # Stretch [0, 1] to [-1, 1].
    return unit_range * 2.0 - 1.0
# Load the CIFAR-100 test split in one piece (batch_size=-1), scale the images
# with preprocess, and convert images and labels to NumPy.
test_set = tfds.load(name="cifar100", data_dir=data_dir, split="test", batch_size=-1)
test_data = tfds.as_numpy(preprocess(test_set["image"]))
test_labels = tfds.as_numpy(test_set["label"])

In [None]:
@jax.jit
def train_step(params, state, opt_state, data, labels):
  """Run one optimisation step on a batch.

  Args:
    params: Current model parameters.
    state: Current model state, as produced by forward.init / forward.apply.
    opt_state: Current optimizer state.
    data: Batch of inputs.
    labels: Labels for the batch.

  Returns:
    A tuple (params, state, opt_state, model_output) holding the updated
    parameters, model state and optimizer state, plus the dict returned by
    the forward pass for this batch.
  """
  def loss_fn(trainable, net_state):
    # jax.grad differentiates the first return value with respect to the
    # first argument; the metrics and new state ride along as aux output.
    outputs, new_state = forward.apply(trainable, net_state, None, data, labels, True)
    return outputs["loss"], (outputs, new_state)

  grad_fn = jax.grad(loss_fn, has_aux=True)
  grads, (model_output, state) = grad_fn(params, state)

  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)

  return params, state, opt_state, model_output

In [None]:
num_training_updates = 2000
train_losses = []
train_accuracies = []

# Initialise parameters and model state from one training batch, then the
# optimizer state from the parameters.
rng = jax.random.PRNGKey(42)
params, state = forward.init(rng, next(input_data), next(input_labels), True)
opt_state = optimizer.init(params)

for step in range(1, num_training_updates + 1):
  data, labels = next(input_data), next(input_labels)
  params, state, opt_state, train_results = train_step(
      params, state, opt_state, data, labels)

  # Bring the metrics back to the host and record them.
  train_results = jax.device_get(train_results)
  train_losses.append(train_results["loss"])
  train_accuracies.append(train_results["accuracy"])

  # Only every 100th step evaluates and reports.
  if step % 100 != 0:
    continue

  # Evaluate on the whole test set with is_training=False.
  model_output, _ = forward.apply(params, state, None, test_data, test_labels, False)
  # NOTE(review): _clear_cache is a private JAX API; presumably it drops
  # train_step's compiled cache to release memory -- confirm it is needed.
  train_step._clear_cache()

  # Training metrics are averaged over the last 100 recorded steps.
  recent_loss = np.mean(train_losses[-100:])
  recent_accuracy = np.mean(train_accuracies[-100:])
  report = (f'[Step {step}/{num_training_updates}] '
            + ('train loss: %f ' % recent_loss)
            + ('train accuracy: %f ' % recent_accuracy)
            + ('test accuracy: %f ' % model_output["accuracy"]))
  print(report)

[Step 100/2000] train loss: 5.644377 train accuracy: 0.031875 test accuracy: 0.066300 
[Step 200/2000] train loss: 3.879315 train accuracy: 0.104004 test accuracy: 0.138800 
[Step 300/2000] train loss: 3.498480 train accuracy: 0.165078 test accuracy: 0.172900 
[Step 400/2000] train loss: 3.401290 train accuracy: 0.182148 test accuracy: 0.196900 
[Step 500/2000] train loss: 3.136335 train accuracy: 0.229219 test accuracy: 0.235100 
[Step 600/2000] train loss: 2.900311 train accuracy: 0.274678 test accuracy: 0.268500 
[Step 700/2000] train loss: 2.668881 train accuracy: 0.320693 test accuracy: 0.286900 
[Step 800/2000] train loss: 2.419318 train accuracy: 0.373398 test accuracy: 0.292200 
[Step 900/2000] train loss: 2.153525 train accuracy: 0.428926 test accuracy: 0.304300 
[Step 1000/2000] train loss: 1.855360 train accuracy: 0.492744 test accuracy: 0.312300 
[Step 1100/2000] train loss: 1.507508 train accuracy: 0.575742 test accuracy: 0.313300 
[Step 1200/2000] train loss: 1.151454 tra