In [1]:
import flax
import jax
import numpy as np 
import jax.numpy as jnp
import optax
from flax import nnx, linen as lnn
from flax.training import train_state

import tensorflow as tf
from sklearn.preprocessing import LabelEncoder
from tqdm.auto import tqdm

In [None]:
# --- Run configuration -------------------------------------------------
# Root folder of the images that the data cell reads from.
data_dir = '/kaggle/input/plantvillage-dataset/color'

# Input pipeline: square image side length (pixels) and mini-batch size.
image_size = 200
batch_size = 32

# Optimisation: number of passes over the data and the default step size
# used by init_train_state when no explicit `lr` is supplied.
epochs = 30
learn_rate = 1e-2


In [None]:
# Data
#
# The training and validation subsets are carved out of the same directory.
# Both calls must receive the same `seed` and `validation_split`, so the
# shared arguments live in one dict instead of being repeated by hand.
# The image size comes from the `image_size` config value (rather than a
# literal 200) so it stays in sync with the dummy input that
# init_train_state uses to initialise the model.
split_kwargs = dict(
    validation_split=0.2,
    seed=123,
    image_size=(image_size, image_size),
    batch_size=batch_size,
)

train_data = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    subset="training",
    **split_kwargs,
)

val_data = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    subset="validation",
    **split_kwargs,
)

In [None]:
class Convnet(lnn.Module):
    """Small convolutional image classifier.

    Architecture: two (3x3 Conv -> ReLU -> 2x2 average-pool) stages with 32
    and 64 features, followed by a flatten, a 256-unit Dense + ReLU, and a
    final Dense layer producing one logit per class.

    Attributes:
        num_classes: width of the output layer. Defaults to 38 so the logits
            line up with the 38-wide one-hot labels built in cross_entropy.
    """

    # Number of output logits; must equal the one-hot width used by the loss.
    num_classes: int = 38

    @lnn.compact
    def __call__(self, img):
        """Run the forward pass.

        Args:
            img: batch of images; axis 0 is treated as the batch axis.
                Presumably (batch, height, width, channels), matching the
                dummy input used at initialisation -- confirm with caller.

        Returns:
            Unnormalised class scores of shape (batch, num_classes).
        """
        x = lnn.Conv(features=32, kernel_size=(3, 3))(img)
        x = lnn.relu(x)
        x = lnn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = lnn.Conv(features=64, kernel_size=(3, 3))(x)
        x = lnn.relu(x)
        x = lnn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Flatten everything except the batch axis before the dense head.
        x = x.reshape((x.shape[0], -1))
        x = lnn.Dense(features=256)(x)
        x = lnn.relu(x)
        x = lnn.Dense(features=self.num_classes)(x)

        return x

In [None]:
def cross_entropy(*, logits, labels, num_classes=38):
    """Mean softmax cross-entropy between logits and integer class labels.

    Args:
        logits: unnormalised class scores; the last axis is expected to have
            `num_classes` entries so it matches the one-hot labels.
        labels: integer class ids; they are one-hot encoded here.
        num_classes: width of the one-hot encoding. Defaults to 38, the
            value this notebook previously hard-coded.

    Returns:
        The cross-entropy averaged over all examples.
    """
    encoded_labels = jax.nn.one_hot(labels, num_classes=num_classes)
    ce_loss = optax.softmax_cross_entropy(logits=logits, labels=encoded_labels)

    return ce_loss.mean()


def compute_model_metrics(*, logits, labels):
    """Summarise a batch of predictions as loss and accuracy.

    Args:
        logits: model outputs; the arg-max over the last axis is taken as
            the predicted class.
        labels: integer class ids compared against the predictions.

    Returns:
        Dict with keys "loss" (mean cross-entropy) and "accuracy" (fraction
        of predictions equal to their label).
    """
    batch_loss = cross_entropy(logits=logits, labels=labels)
    predicted_classes = jnp.argmax(logits, -1)
    hit_rate = jnp.mean(predicted_classes == labels)

    return {"loss": batch_loss, "accuracy": hit_rate}


def compute_loss(params, images, labels):
    """Forward pass plus loss, shaped for `jax.value_and_grad(..., has_aux=True)`.

    Args:
        params: Convnet parameters (the 'params' collection).
        images: batch of input images fed to the network.
        labels: integer class ids for the batch.

    Returns:
        A `(loss, logits)` pair. The loss comes first so it is the value that
        gets differentiated; the logits ride along as the auxiliary output,
        which is what train_step unpacks.
    """
    logits = Convnet().apply({"params": params}, images)
    # cross_entropy is keyword-only, so the arguments must be named.
    loss = cross_entropy(logits=logits, labels=labels)

    return loss, logits


def init_train_state(rng, lr=learn_rate):
    """Create a fresh TrainState for Convnet with an Adam optimiser.

    Args:
        rng: PRNG key used to initialise the network parameters.
        lr: Adam learning rate; defaults to the notebook-level `learn_rate`.

    Returns:
        A `train_state.TrainState` holding the model's apply function, its
        initial parameters and the optimiser.
    """
    convnet = Convnet()
    # A dummy single-image batch fixes the parameter shapes at init time.
    params = convnet.init(rng, jnp.ones([1, image_size, image_size, 3]))['params']
    tx = optax.adam(learning_rate=lr)

    # The result must not be bound to a local called `train_state`: that
    # would shadow the imported module for the whole function and make the
    # attribute lookup below fail with UnboundLocalError.
    state = train_state.TrainState.create(apply_fn=convnet.apply, params=params, tx=tx)

    return state

In [None]:
@jax.jit
def train_step(state, batch):
    """Run one optimisation step on a single batch.

    Args:
        state: current TrainState (parameters, optimiser state, apply_fn).
        batch: `(images, labels)` pair.

    Returns:
        `(new_state, metrics)` where `metrics` is the loss/accuracy dict from
        compute_model_metrics, computed from the logits of the forward pass
        made before the parameter update.
    """
    images, labels = batch

    def loss_and_logits(params):
        # Returns (loss, aux): the scalar loss is differentiated and the
        # logits are passed through untouched so metrics can reuse them.
        logits = state.apply_fn({"params": params}, images)
        return cross_entropy(logits=logits, labels=labels), logits

    (_, logits), grads = jax.value_and_grad(loss_and_logits, has_aux=True)(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_model_metrics(logits=logits, labels=labels)

    return state, metrics
    
@jax.jit
def eval_step(state, batch):
    """Score one batch with the current parameters; no gradients, no update.

    Args:
        state: object exposing the model parameters as `state.params`.
        batch: `(images, labels)` pair.

    Returns:
        The loss/accuracy dict produced by compute_model_metrics.
    """
    inputs, targets = batch
    variables = {'params': state.params}
    predictions = Convnet().apply(variables, inputs)

    return compute_model_metrics(logits=predictions, labels=targets)

In [None]:
def evaluate_model(state, batch):
    """Evaluate one batch and return the metrics as plain Python scalars.

    Args:
        state: current TrainState.
        batch: `(images, labels)` pair, forwarded unchanged to eval_step
            (which takes the batch as a single argument).

    Returns:
        Dict with the same keys as eval_step's result, each value converted
        to a Python number via `.item()`.
    """
    metrics = eval_step(state, batch)
    # Pull the results off the accelerator before converting to scalars.
    metrics = jax.device_get(metrics)
    # jax.tree_util.tree_map is the stable spelling; the top-level
    # jax.tree_map alias is deprecated and absent from recent JAX releases.
    metrics = jax.tree_util.tree_map(lambda x: x.item(), metrics)

    return metrics

In [None]:
# Run-level settings. `seed` now actually drives the PRNG key (it used to be
# defined but ignored in favour of a literal 0), and `lr` overrides the
# default learning rate of init_train_state for this run.
seed = 0
lr = 1e-5

rng = jax.random.PRNGKey(seed)
# Keep `rng` for later use and spend `init_rng` on parameter initialisation.
rng, init_rng = jax.random.split(rng)

state = init_train_state(init_rng, lr)

# History buffers for loss and accuracy on the train and test sides.
train_loss, test_loss = [], []
train_acc, test_acc = [], []


In [None]:
def train_clf_convnet(state, train_loader, test_loader, num_epochs=epochs):
    for epoch in tqdm(range(num_epochs)):
        train_batch_loss, train_batch_accuracy = [], []
        val_batch_loss, val_batch_accuracy = [], []
        
        for train_batch in train_loader:
            state, 