In [10]:
import tensorflow as tf
import tensorflow_datasets as tfds

def _normalize_sample(sample):
    """Cast a sample's image to float32 and divide by 255.0; keep the label as-is.

    Shared by the train and test pipelines so both splits are preprocessed
    identically.
    """
    return {
        'image': tf.cast(sample['image'], tf.float32) / 255.0,
        'label': sample['label'],
    }


def get_datasets(num_epochs: int, batch_size: int):
    """Load the MNIST dataset and prepare it for training.

    Args:
        num_epochs: Number of times the training split is repeated.
        batch_size: Number of samples per batch, used for both splits.

    Returns:
        A ``(train_ds, test_ds)`` tuple of batched datasets. Each element is a
        dict with an ``'image'`` entry (float32, pixel values divided by 255.0)
        and a ``'label'`` entry.
    """
    train_ds = tfds.load('mnist', split='train').map(_normalize_sample)
    test_ds = tfds.load('mnist', split='test').map(_normalize_sample)

    # repeat() runs before shuffle(), so the 1024-element shuffle buffer can
    # mix samples that belong to neighbouring epochs.
    train_ds = train_ds.repeat(num_epochs).shuffle(1024)
    # drop_remainder=True keeps every batch at exactly `batch_size` samples;
    # a trailing partial batch is discarded.
    train_ds = train_ds.batch(batch_size, drop_remainder=True)
    # Prefetching lets the next batch be prepared while the current one is
    # being consumed; it does not change the elements or their order.
    train_ds = train_ds.prefetch(tf.data.AUTOTUNE)

    # NOTE: drop_remainder=True also applies here, so evaluation skips the
    # final partial batch whenever the split size is not a multiple of
    # `batch_size`.
    test_ds = test_ds.shuffle(1024).batch(batch_size, drop_remainder=True)
    test_ds = test_ds.prefetch(tf.data.AUTOTUNE)

    return train_ds, test_ds
                            

In [None]:
from flax import linen as nn

class CNN(nn.Module):
    """Small convolutional network with two conv/pool stages and a dense head.

    Calling the module on ``x`` treats ``x.shape[0]`` as the batch dimension
    and returns the raw output of a 10-feature dense layer for each batch
    entry (no softmax is applied here).
    """

    @nn.compact
    def __call__(self, x):
        # Two conv -> ReLU -> 2x2 average-pool stages that differ only in
        # the number of convolution features (32, then 64).
        for width in (32, 64):
            x = nn.Conv(features=width, kernel_size=(3, 3))(x)
            x = nn.avg_pool(nn.relu(x), window_shape=(2, 2), strides=(2, 2))

        # Collapse everything except the leading (batch) dimension.
        batch = x.shape[0]
        flat = x.reshape(batch, -1)

        # Dense head: 256 hidden units with ReLU, then 10 outputs.
        hidden = nn.relu(nn.Dense(features=256)(flat))
        return nn.Dense(features=10)(hidden)
        