In [1]:
import flax
import jax
import numpy as np 
import jax.numpy as jnp
import optax
from flax import nnx, linen as lnn
from flax.training import train_state
from sklearn.preprocessing import LabelEncoder

In [None]:
class Convnet(lnn.Module):
    """Small two-stage convolutional classifier.

    Architecture: two (3x3 Conv -> ReLU -> 2x2/stride-2 average pool) stages,
    then flatten, a 256-unit ReLU dense layer, and a linear output layer that
    emits one logit per class.

    Attributes:
        num_classes: Width of the output layer. Defaults to 38 so the logits
            line up with the ``num_classes=38`` one-hot encoding used by the
            notebook's ``cross_entropy`` loss (a hard-coded 10 here made the
            loss fail with a shape mismatch).
    """

    num_classes: int = 38

    @lnn.compact
    def __call__(self, img):
        """Return unnormalized class logits for a batch of images.

        Args:
            img: Image batch. Presumably NHWC — the init code feeds
                ``(1, image_size, image_size, 3)``; TODO confirm for real data.

        Returns:
            Array of shape ``(batch, num_classes)``.
        """
        # Feature extractor: each stage halves the spatial resolution.
        x = lnn.Conv(features=32, kernel_size=(3, 3))(img)
        x = lnn.relu(x)
        x = lnn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = lnn.Conv(features=64, kernel_size=(3, 3))(x)
        x = lnn.relu(x)
        x = lnn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))

        # Flatten everything except the batch axis for the dense head.
        x = x.reshape((x.shape[0], -1))
        x = lnn.Dense(features=256)(x)
        x = lnn.relu(x)
        x = lnn.Dense(features=self.num_classes)(x)

        return x

In [None]:
# Hyper-parameters shared by the cells below.
#   image_size: side length of the square, 3-channel dummy input that
#               init_train_state uses to initialise the model parameters.
#   batch_size: number of examples per batch.
#   learn_rate: default Adam learning rate picked up by init_train_state.
image_size, batch_size, learn_rate = 200, 32, 1e-2

In [None]:
def cross_entropy(*, logits, labels, num_classes=38):
    """Mean softmax cross-entropy between logits and integer class labels.

    All arguments are keyword-only.

    Args:
        logits: Unnormalized class scores; the last axis must have length
            ``num_classes`` so it matches the one-hot targets.
        labels: Integer class ids, one-hot encoded here to ``num_classes``.
            NOTE(review): ``jax.nn.one_hot`` maps out-of-range ids to an
            all-zero row, which silently contributes zero loss.
        num_classes: Number of classes for the one-hot encoding. Defaults to
            38, the value that was previously hard-coded.

    Returns:
        Scalar: the mean of the per-example cross-entropy values.
    """
    one_hot_labels = jax.nn.one_hot(labels, num_classes=num_classes)
    per_example_loss = optax.softmax_cross_entropy(logits=logits, labels=one_hot_labels)
    return per_example_loss.mean()


def compute_model_metrics(images, labels, params):
    """Run the Convnet on one batch and return its loss and accuracy.

    Args:
        images: Input batch, passed straight to ``Convnet().apply``.
        labels: Integer class ids, one per image.
        params: The model's ``"params"`` collection (as returned by ``init``).

    Returns:
        dict with ``"loss"`` (mean softmax cross-entropy) and ``"accuracy"``
        (fraction of examples whose argmax prediction equals ``labels``).
    """
    logits = Convnet().apply({"params": params}, images)
    # cross_entropy declares keyword-only arguments (`*`); calling it
    # positionally raises TypeError.
    loss = cross_entropy(logits=logits, labels=labels)
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)

    return {"loss": loss, "accuracy": accuracy}


def init_train_state(rng, lr=learn_rate):
    """Initialise Convnet parameters and wrap them in a Flax ``TrainState``.

    Args:
        rng: JAX PRNG key used to initialise the parameters.
        lr: Adam learning rate. The default is the notebook-level
            ``learn_rate`` as it was when this cell ran (defaults are bound at
            definition time, so later edits to ``learn_rate`` are not seen).

    Returns:
        A ``flax.training.train_state.TrainState`` holding ``Convnet.apply``,
        the initial parameters, and an Adam optimizer.
    """
    convnet = Convnet()
    # A single all-ones image is enough to fix every parameter shape.
    dummy_batch = jnp.ones([1, image_size, image_size, 3])
    params = convnet.init(rng, dummy_batch)["params"]
    tx = optax.adam(learning_rate=lr)

    # Deliberately not named `train_state`: assigning to that name would make
    # it local to this function, shadowing the imported
    # `flax.training.train_state` module and raising UnboundLocalError on the
    # `.TrainState` lookup below.
    state = train_state.TrainState.create(apply_fn=convnet.apply, params=params, tx=tx)

    return state

In [None]:
@jax.jit
def train_step(state, batch):
    
    