The dataset is obtained from the [Kaggle Dogs vs. Cats competition](https://www.kaggle.com/c/dogs-vs-cats).

Click the image below to read the post online.

<a target="_blank" href="https://www.machinelearningnuggets.com/handling-state-in-jax-flax-batchnorm-and-dropout-layers"><img src="https://digitalpress.fra1.cdn.digitaloceanspaces.com/mhujhsj/2022/07/logho-1.png" alt="Open in ML Nuggets"></a>

In [None]:
pip install -U jax flax jaxlib

In [None]:
pip install wget

In [None]:
import wget

# Fetch the zipped training images used throughout this notebook.
DATASET_URL = "https://ml.machinelearningnuggets.com/train.zip"
wget.download(DATASET_URL)

In [None]:
import zipfile

# Unpack train.zip into the current working directory
# (presumably this yields the train/ folder read further below).
with zipfile.ZipFile('train.zip', mode='r') as archive:
    archive.extractall(path='.')

## Perform standard imports

In [1]:
import torch
from torch.utils.data import DataLoader
import os
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset
import numpy as np
import pandas as pd
from typing import Any
import matplotlib.pyplot as plt
%matplotlib inline
# Silence every warning category to keep notebook output clean.
# NOTE(review): a blanket ignore also hides useful warnings (deprecations,
# pandas chained-assignment notices); consider narrowing the filter.
import warnings
warnings.simplefilter("ignore")

In [2]:
import jax
from jax import numpy as jnp
import flax
from flax import linen as nn  
from flax.training import train_state
import optax

In [3]:
class CatsDogsDataset(Dataset):
    """Map-style dataset of cat/dog images listed in a CSV annotation file.

    The CSV is loaded with ``pandas.read_csv``; column 0 holds the image file
    name (relative to ``root_dir``) and column 1 holds the numeric label.

    Args:
        root_dir: Directory containing the image files.
        annotation_file: Path to the CSV annotation file.
        transform: Optional callable applied to each RGB ``PIL.Image``.
    """

    def __init__(self, root_dir, annotation_file, transform=None):
        self.root_dir = root_dir
        self.annotations = pd.read_csv(annotation_file)
        self.transform = transform

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.annotations)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(image, label)``.

        The image is converted to RGB and passed through ``transform`` if one
        was given; the label is returned as a 0-d float ``torch.Tensor``.
        """
        img_id = self.annotations.iloc[index, 0]
        img_path = os.path.join(self.root_dir, img_id)
        # Close the underlying file deterministically; ``convert`` returns a
        # new, fully loaded image that stays valid after the file is closed.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")
        y_label = torch.tensor(float(self.annotations.iloc[index, 1]))

        if self.transform is not None:
            img = self.transform(img)

        return (img, y_label)

In [4]:
def _label_from_filename(filename):
    """Return 1 if ``filename`` contains "dog", 0 if it contains "cat", else None.

    "dog" takes precedence when both substrings occur. Files matching neither
    get a missing (empty) label in the CSV.
    """
    if "dog" in filename:
        return 1
    if "cat" in filename:
        return 0
    return None


# List the directory once (sorted for a deterministic row order) so the file
# names and labels are guaranteed to line up.
image_files = sorted(os.listdir("train/"))

# Build the frame in one shot. The previous ``train_df["label"][idx] = ...``
# chained assignment triggers pandas warnings and silently does nothing under
# copy-on-write.
train_df = pd.DataFrame({
    "img_path": image_files,
    "label": [_label_from_filename(name) for name in image_files],
})

train_df.to_csv('train_csv.csv', index=False, header=True)

In [5]:
def custom_collate_fn(batch):
    """Collate ``(image, label)`` samples into NumPy arrays for JAX.

    Args:
        batch: Sequence of ``(image, label)`` pairs produced by the dataset.

    Returns:
        ``(images, labels)``: the images stacked along a new leading axis and
        the labels gathered into one array (1-D for scalar labels).
    """
    columns = list(zip(*batch))
    labels = np.array(columns[1])
    return np.stack(columns[0]), labels

In [6]:
# Input resolution (pixels per side) and mini-batch size.
size_image = 224
batch_size = 64
# Fraction of the dataset held out for validation. 0.2 reproduces the original
# hard-coded 20000/5000 split, which assumed exactly 25000 images.
validation_fraction = 0.2

# Resize every image to a square, then turn it into a NumPy array so the
# collate function can stack it directly.
transform = transforms.Compose([
    transforms.Resize((size_image, size_image)),
    np.array,
])
dataset = CatsDogsDataset("train", "train_csv.csv", transform=transform)

# Derive the split sizes from the dataset length; random_split requires the
# sizes to sum to len(dataset), so hard-coded numbers break on any other size.
num_validation = round(len(dataset) * validation_fraction)
num_train = len(dataset) - num_validation
train_set, validation_set = torch.utils.data.random_split(dataset, [num_train, num_validation])
train_loader = DataLoader(dataset=train_set, collate_fn=custom_collate_fn, shuffle=True, batch_size=batch_size)
validation_loader = DataLoader(dataset=validation_set, collate_fn=custom_collate_fn, shuffle=False, batch_size=batch_size)

In [None]:
# Peek at one training batch to sanity-check the array shapes.
image_batch, label_batch = next(iter(train_loader))
for batch_array in (image_batch, label_batch):
    print(batch_array.shape)

In [None]:
# Number of samples in the inspected batch (leading axis of the stacked array).
len(image_batch)

In [9]:
class CNN(nn.Module):
    """Small convolutional classifier with a 2-way log-softmax output.

    Three conv -> ReLU -> 2x2 max-pool stages (128, 64 and 32 filters) feed a
    dense head with BatchNorm and dropout.
    """

    @nn.compact
    def __call__(self, x, training):
        """Run the network.

        Args:
            x: Input batch, presumably channel-last images such as
                ``(N, 224, 224, 3)`` (the shape used at init time).
            training: If True, BatchNorm normalises with batch statistics and
                dropout is active; if False, BatchNorm uses its running
                averages and dropout is disabled.

        Returns:
            ``(N, 2)`` log-probabilities over the two classes.
        """
        # Submodules are created in the same order as an unrolled version
        # (Conv_0..2, Dense_0, Dense_1, BatchNorm_0, Dropout_0, Dense_2), so
        # parameter names and checkpoints are unaffected by the loop.
        for num_filters in (128, 64, 32):
            x = nn.Conv(features=num_filters, kernel_size=(3, 3))(x)
            x = nn.relu(x)
            x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten all but the batch axis
        # NOTE(review): no activation between these two Dense layers, so they
        # compose to a single linear map.
        x = nn.Dense(features=256)(x)
        x = nn.Dense(features=128)(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.Dropout(0.2, deterministic=not training)(x)
        x = nn.relu(x)
        x = nn.Dense(features=2)(x)
        return nn.log_softmax(x)

In [10]:
def cross_entropy_loss(*, logits, labels):
    """Mean softmax cross-entropy for two-class labels.

    Args:
        logits: ``(N, 2)`` class scores. Log-probabilities are also fine,
            because softmax is invariant to the per-row shift log_softmax applies.
        labels: ``(N,)`` class indices 0/1 (integer or integral float).

    Returns:
        Scalar mean loss over the batch.
    """
    targets = jax.nn.one_hot(labels, num_classes=2)
    per_example = optax.softmax_cross_entropy(logits=logits, labels=targets)
    return per_example.mean()

In [11]:
def compute_loss(params, batch_stats, images, labels, dropout_rng=None):
    """Training-mode forward pass plus mean cross-entropy loss.

    ``params`` is the first positional argument because ``jax.value_and_grad``
    differentiates with respect to it.

    Args:
        params: Model parameters (the differentiated argument).
        batch_stats: Current BatchNorm running statistics.
        images: Input image batch.
        labels: Class labels for ``images``.
        dropout_rng: PRNG key for dropout. Defaults to ``PRNGKey(0)`` for
            backward compatibility. Reusing one fixed key samples the same
            dropout mask on every step (for equally shaped batches), so
            callers should pass a fresh key per step, e.g.
            ``jax.random.fold_in(base_key, state.step)``.

    Returns:
        ``(loss, (logits, updated_vars))``. ``updated_vars`` is the mutable
        collection dict returned by ``apply``, i.e. ``{'batch_stats': ...}``.
    """
    if dropout_rng is None:
        dropout_rng = jax.random.PRNGKey(0)
    variables = {'params': params, 'batch_stats': batch_stats}
    logits, updated_vars = CNN().apply(
        variables,
        images,
        training=True,
        rngs={'dropout': dropout_rng},
        mutable=['batch_stats'],
    )
    loss = cross_entropy_loss(logits=logits, labels=labels)
    return loss, (logits, updated_vars)

In [12]:
def compute_metrics(*, logits, labels):
    """Return batch loss and accuracy as a dict.

    Args:
        logits: Per-class model outputs, one row per example.
        labels: Class labels matching the rows of ``logits``.

    Returns:
        ``{'loss': mean cross-entropy, 'accuracy': fraction of rows whose
        argmax equals the label}``.
    """
    predictions = jnp.argmax(logits, -1)
    return {
        'loss': cross_entropy_loss(logits=logits, labels=labels),
        'accuracy': jnp.mean(predictions == labels),
    }

In [13]:
# Initialise the model variables by tracing one all-ones dummy batch in
# inference mode (dropout is deterministic, so no dropout RNG is needed).
key = jax.random.PRNGKey(0)
model = CNN()
variables = model.init(
    key,
    jnp.ones([1, size_image, size_image, 3]),
    training=False,
)


class TrainState(train_state.TrainState):
    """Flax ``TrainState`` that also carries the BatchNorm running statistics."""

    # Non-trainable 'batch_stats' collection kept next to params/opt_state.
    batch_stats: flax.core.FrozenDict


# Plain SGD with learning rate 0.01, starting from the initial variables.
state = TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=optax.sgd(0.01),
    batch_stats=variables['batch_stats'],
)

<!-- jax.value_and_grad differentiates the wrapped function with respect to its first positional argument. That is why compute_loss takes the model params as its first argument: the gradients are then taken with respect to the params rather than the input images. -->

In [16]:
@jax.jit
def train_step(state, images, labels):
    """Run one optimisation step; return the updated state and batch metrics.

    Args:
        state: ``TrainState`` with params, optimizer state and BatchNorm
            running statistics.
        images: Image batch.
        labels: Class labels for ``images``.

    Returns:
        ``(new_state, metrics)``. ``metrics`` holds this batch's ``'loss'`` and
        ``'accuracy'``.
    """
    grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
    (_, (logits, updated_vars)), grads = grad_fn(
        state.params, state.batch_stats, images, labels
    )
    # Keep the BatchNorm statistics produced by the training-mode pass. Without
    # this, the running averages used at evaluation time (training=False)
    # would never move from their initial values.
    state = state.apply_gradients(grads=grads, batch_stats=updated_vars['batch_stats'])

    metrics = compute_metrics(logits=logits, labels=labels)
    return state, metrics

In [17]:
def train_one_epoch(state, dataloader):
    """Train for one epoch; return the state and epoch-averaged metrics.

    Args:
        state: Current ``TrainState``.
        dataloader: Iterable of ``(images, labels)`` batches. Images are
            divided by 255.0 before each step (presumably uint8 pixels -> [0, 1]).

    Returns:
        ``(state, metrics)``. ``metrics`` maps each metric name to its
        unweighted mean over batches (a smaller final batch counts the same
        as a full one).

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    batch_metrics = []
    for images, labels in dataloader:
        state, metrics = train_step(state, images / 255.0, labels)
        batch_metrics.append(metrics)

    if not batch_metrics:
        raise ValueError("dataloader yielded no batches; cannot compute epoch metrics")

    # One host transfer for all per-batch metrics, then average per key.
    batch_metrics_np = jax.device_get(batch_metrics)
    epoch_metrics_np = {
        name: np.mean([m[name] for m in batch_metrics_np])
        for name in batch_metrics_np[0]
    }
    return state, epoch_metrics_np

In [18]:
@jax.jit
def eval_step(batch_stats, params, images, labels):
    """Score one batch in inference mode and return its loss/accuracy metrics.

    With ``training=False``, BatchNorm uses the stored running averages and
    dropout is deterministic. The dropout RNG is supplied but not consumed.
    """
    model_vars = {'params': params, 'batch_stats': batch_stats}
    dropout_key = jax.random.PRNGKey(0)
    logits = CNN().apply(model_vars, images, training=False, rngs={'dropout': dropout_key})
    return compute_metrics(logits=logits, labels=labels)

In [19]:
def evaluate_model(state, test_imgs, test_lbls):
    """Evaluate ``state`` on one batch and return metrics as Python scalars.

    Args:
        state: ``TrainState`` providing ``params`` and ``batch_stats``.
        test_imgs: Image batch, expected to be preprocessed like the training
            batches (divided by 255).
        test_lbls: Labels for ``test_imgs``.

    Returns:
        Dict with ``'loss'`` and ``'accuracy'`` as Python floats.
    """
    metrics = eval_step(state.batch_stats, state.params, test_imgs, test_lbls)
    metrics = jax.device_get(metrics)
    # ``jax.tree_map`` is deprecated and removed in recent JAX releases;
    # ``jax.tree_util.tree_map`` is the supported spelling.
    return jax.tree_util.tree_map(lambda x: x.item(), metrics)

In [23]:
# NOTE(review): only the first validation batch (up to batch_size images) is
# used for the "validation" metrics below, not the whole validation split.
first_batch = next(iter(validation_loader))
test_images, test_labels = first_batch[0] / 255.0, first_batch[1]

In [24]:
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms.functional as F
# TensorBoard event files for this run are written under this directory.
logdir = "flax_logs"
writer = SummaryWriter(log_dir=logdir)

In [25]:
# Per-epoch metric history, appended to by train_model and plotted at the end.
training_loss, training_accuracy = [], []
testing_loss, testing_accuracy = [], []

In [26]:
def train_model(epochs, initial_state=None):
    """Train for ``epochs`` epochs, logging metrics, and return the final state.

    Args:
        epochs: Number of passes over ``train_loader``.
        initial_state: State to start from. Defaults to the module-level
            ``state`` built before training.

    Returns:
        The ``TrainState`` after the last epoch (the starting state if
        ``epochs`` is 0).

    Side effects:
        Appends to the module-level metric history lists, writes scalars to
        ``writer`` and prints a summary line every epoch.
    """
    # Carry the state forward between epochs. Restarting each epoch from the
    # initial ``state`` would discard everything learned so far.
    current_state = state if initial_state is None else initial_state
    for epoch in range(1, epochs + 1):
        current_state, train_metrics = train_one_epoch(current_state, train_loader)
        training_loss.append(train_metrics['loss'])
        training_accuracy.append(train_metrics['accuracy'])

        test_metrics = evaluate_model(current_state, test_images, test_labels)
        testing_loss.append(test_metrics['loss'])
        testing_accuracy.append(test_metrics['accuracy'])

        writer.add_scalar('Loss/train', train_metrics['loss'], epoch)
        writer.add_scalar('Loss/test', test_metrics['loss'], epoch)

        writer.add_scalar('Accuracy/train', train_metrics['accuracy'], epoch)
        writer.add_scalar('Accuracy/test', test_metrics['accuracy'], epoch)

        print(f"Epoch: {epoch}, training loss: {train_metrics['loss']}, training accuracy: {train_metrics['accuracy'] * 100}, validation loss: {test_metrics['loss']}, validation accuracy: {test_metrics['accuracy'] * 100}")
    return current_state

In [None]:
# Train for 30 epochs; keep the final state for evaluation and checkpointing.
trained_model_state = train_model(epochs=30)

In [None]:
# Metrics of the trained model on the held-out validation batch.
evaluate_model(trained_model_state, test_imgs=test_images, test_lbls=test_labels)


In [None]:
pip install tensorstore

In [None]:
from flax.training import checkpoints

# Orbax-backed Flax checkpointing (the default in recent Flax releases)
# requires an absolute directory, so resolve it up front.
ckpt_dir = os.path.abspath('model_checkpoint')
checkpoints.save_checkpoint(
    ckpt_dir=ckpt_dir,
    target=trained_model_state,
    # Tag the checkpoint with the real optimizer step rather than a fixed 100.
    step=int(trained_model_state.step),
    prefix='flax_model',
    overwrite=True,
)

In [33]:
# Restore the latest 'flax_model' checkpoint from ckpt_dir into a structure
# shaped like the freshly initialised ``state``.
loaded_model = checkpoints.restore_checkpoint(
    ckpt_dir=ckpt_dir,
    target=state,
    prefix='flax_model',
)

In [28]:
%load_ext tensorboard 

In [None]:
%tensorboard --logdir={logdir}

In [None]:
# Per-epoch accuracy curves for training and the held-out batch.
for series, series_label in ((training_accuracy, "Training"), (testing_accuracy, "Test")):
    plt.plot(series, label=series_label)
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()
plt.show()

In [None]:
# Per-epoch loss curves for training and the held-out batch.
plt.plot(training_loss, label="Training")
plt.plot(testing_loss, label="Test")
plt.xlabel("Epoch")
plt.ylabel("Loss")  # this cell plots loss values, not accuracy
plt.legend()
plt.show()

## Where to go from here
Follow us on [LinkedIn](https://www.linkedin.com/company/mlnuggets), [Twitter](https://twitter.com/ml_nuggets), [GitHub](https://github.com/mlnuggets) and subscribe to our [blog](https://www.machinelearningnuggets.com/#/portal) so that you don't miss a new issue.