<a href="https://colab.research.google.com/github/mohsenh17/jaxLearning/blob/main/flax/MNIST_CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Train and test sets (TensorFlow Datasets)

In [30]:
import tensorflow_datasets as tfds
import tensorflow as tf

from flax import nnx
from functools import partial

import jax.numpy as jnp


In [22]:
# Set TensorFlow's global random seed (the shuffle below carries no seed of its own).
tf.random.set_seed(42)

# Training-loop hyperparameters.
batch_size = 32
train_steps = 1400
eval_every = 200


def _scale_pixels(sample):
  """Return `sample` with its image cast to float32 and divided by 255; the label is passed through."""
  return {'image': tf.cast(sample['image'], tf.float32) / 255., 'label': sample['label']}


mnist_train: tf.data.Dataset = tfds.load('mnist', split='train')
mnist_test: tf.data.Dataset = tfds.load('mnist', split='test')

mnist_train = mnist_train.map(_scale_pixels)
mnist_test = mnist_test.map(_scale_pixels)

# NOTE: with batch_size = 32, one full pass over the training split takes
# 1875 batches, so train_steps = 1400 does not cover every image.
# Pipeline: endless stream -> shuffle with a 1024-element buffer -> fixed-size
# batches -> stop after `train_steps` batches -> prefetch one batch ahead.
mnist_train = (
  mnist_train
  .repeat()
  .shuffle(1024)
  .batch(batch_size, drop_remainder=True)
  .take(train_steps)
  .prefetch(1)
)

# The test split is only batched (incomplete final batch dropped) and prefetched.
mnist_test = (
  mnist_test
  .batch(batch_size, drop_remainder=True)
  .prefetch(1)
)






In [28]:
class CNN(nnx.Module):
  """Small convolutional classifier.

  Two conv -> GELU -> 2x2 average-pool stages feed a flattened
  256-unit hidden layer and a final linear layer producing the logits.

  Args:
    num_classes: Number of outputs of the final linear layer.
    rngs: `nnx.Rngs` handed to every parametrised layer for initialisation.
  """

  def __init__(self, *, num_classes: int, rngs: nnx.Rngs):
    self.num_classes = num_classes
    self.conv1 = nnx.Conv(
      in_features=1,
      out_features=32,
      kernel_size=(3, 3),
      rngs=rngs,
    )
    self.conv2 = nnx.Conv(
      in_features=32,
      out_features=64,
      kernel_size=(3, 3),
      rngs=rngs,
    )
    # Parameter-free 2x2 pooling with stride 2, shared by both conv stages.
    self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
    # 3136 input features = 7 * 7 * 64: for 28x28 inputs the two stride-2
    # pools leave a 7x7 map with 64 channels.
    self.linear1 = nnx.Linear(in_features=7 * 7 * 64, out_features=256, rngs=rngs)
    self.linear2 = nnx.Linear(in_features=256, out_features=num_classes, rngs=rngs)

  def __call__(self, x):
    """Run the forward pass and return un-normalised logits.

    Args:
      x: Batch of images with the batch dimension first; `linear1` assumes
        the flattened features total 3136 per example.

    Returns:
      The output of `linear2`, one row per example.
    """
    # Both convolutional stages share the same conv -> GELU -> pool shape.
    for conv in (self.conv1, self.conv2):
      x = self.avg_pool(nnx.gelu(conv(x)))
    flat = x.reshape(x.shape[0], -1)  # keep the batch axis, flatten the rest
    hidden = nnx.gelu(self.linear1(flat))
    return self.linear2(hidden)

# Build the 10-class network; the parameter RNG stream is created from seed 0.
model = CNN(
  num_classes=10,
  rngs=nnx.Rngs(0),
)
# Render the module tree in the notebook.
nnx.display(model)

CNN(
  num_classes=10,
  conv1=Conv(
    kernel_shape=(3, 3, 1, 32),
    kernel=Param(
      value=Array(shape=(3, 3, 1, 32), dtype=float32)
    ),
    bias=Param(
      value=Array(shape=(32,), dtype=float32)
    ),
    in_features=1,
    out_features=32,
    kernel_size=(3, 3),
    strides=1,
    padding='SAME',
    input_dilation=1,
    kernel_dilation=1,
    feature_group_count=1,
    use_bias=True,
    mask=None,
    dtype=None,
    param_dtype=<class 'jax.numpy.float32'>,
    precision=None,
    kernel_init=<function variance_scaling.<locals>.init at 0x797f231a68c0>,
    bias_init=<function zeros at 0x797f3e206b00>,
    conv_general_dilated=<function conv_general_dilated at 0x797f3ea2b760>
  ),
  conv2=Conv(
    kernel_shape=(3, 3, 32, 64),
    kernel=Param(
      value=Array(shape=(3, 3, 32, 64), dtype=float32)
    ),
    bias=Param(
      value=Array(shape=(64,), dtype=float32)
    ),
    in_features=32,
    out_features=64,
    kernel_size=(3, 3),
    strides=1,
    padding='S

In [31]:
# Smoke-test the forward pass on a single all-ones 28x28x1 input and show the result.
y = model(
  jnp.ones((1, 28, 28, 1)),
)
nnx.display(y)

[[-0.0380987  -0.1663384   0.04239376 -0.1821506   0.05388633  0.03874649
  -0.03934722  0.05054386  0.17043065 -0.06008959]]
