Dataset is obtained from [Kaggle](https://www.kaggle.com/c/dogs-vs-cats).

Click the image below to read the post online.

<a target="_blank" href="https://www.machinelearningnuggets.com/resnet-flax
"><img src="https://digitalpress.fra1.cdn.digitaloceanspaces.com/mhujhsj/2022/07/logho-1.png" alt="Open in ML Nuggets"></a>

In [2]:
pip install -U jax flax jaxlib

In [None]:
!git clone https://github.com/matthias-wright/flaxmodels.git

In [None]:
pip install -r flaxmodels/training/resnet/requirements.txt

In [None]:
pip install wget

In [None]:
import wget 
wget.download("https://ml.machinelearningnuggets.com/train.zip")

In [7]:
import zipfile
with zipfile.ZipFile('train.zip', 'r') as zip_ref:
  zip_ref.extractall('.')

## Perform standard imports

In [8]:
import torch
from torch.utils.data import DataLoader
import os
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
# ignore harmless warnings
import warnings
warnings.filterwarnings("ignore")

In [9]:
import jax
from jax import numpy as jnp
import flax
from flax import linen as nn  
from flax.training import train_state
import optax
import time
from tqdm.notebook import tqdm
import math
from flax import jax_utils

In [13]:
class CatsDogsDataset(Dataset):
    def __init__(self, root_dir, annotation_file, transform=None):
        self.root_dir = root_dir
        self.annotations = pd.read_csv(annotation_file)
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        img_id = self.annotations.iloc[index, 0]
        img = Image.open(os.path.join(self.root_dir, img_id)).convert("RGB")
        y_label = torch.tensor(float(self.annotations.iloc[index, 1]))

        if self.transform is not None:
            img = self.transform(img)

        return (img, y_label)

In [14]:
train_df = pd.DataFrame(columns=["img_path","label"])
train_df["img_path"] = os.listdir("train/")
for idx, i in enumerate(os.listdir("train/")):
    if "cat" in i:
        train_df["label"][idx] = 0
    if "dog" in i:
        train_df["label"][idx] = 1

train_df.to_csv (r'train_csv.csv', index = False, header=True)

In [15]:
def custom_collate_fn(batch):
    transposed_data = list(zip(*batch))
    labels = np.array(transposed_data[1])
    imgs = np.stack(transposed_data[0])
    return imgs, labels

In [16]:
size_image = 224
batch_size = 32

transform = transforms.Compose([
    transforms.Resize((size_image,size_image)),
    np.array])
dataset = CatsDogsDataset("train","train_csv.csv",transform=transform)
train_set, validation_set = torch.utils.data.random_split(dataset,[20000,5000])
train_loader = DataLoader(dataset=train_set, collate_fn=custom_collate_fn,shuffle=True, batch_size=batch_size)
validation_loader = DataLoader(dataset=validation_set,collate_fn=custom_collate_fn, shuffle=False, batch_size=batch_size)

In [17]:
(image_batch, label_batch) = next(iter(train_loader))
print(image_batch.shape)
print(label_batch.shape)

(32, 224, 224, 3)
(32,)


In [18]:
image_batch.shape[0]

32

In [19]:
import jax.numpy as jnp
import flaxmodels as fm

num_classes = 2
dtype = jnp.float32
model = fm.ResNet50(output='log_softmax', pretrained=None, num_classes=num_classes, dtype=dtype)

In [20]:
def cross_entropy_loss(*, logits, labels):
  labels_onehot = jax.nn.one_hot(labels, num_classes=num_classes)
  return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()

In [21]:
def compute_metrics(*, logits, labels):
  loss = cross_entropy_loss(logits=logits, labels=labels)
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  metrics = {
      'loss': loss,
      'accuracy': accuracy,
  }
  return metrics

In [22]:

class TrainState(train_state.TrainState):
    batch_stats: flax.core.FrozenDict

In [23]:
key = jax.random.PRNGKey(0)
variables = model.init(key, jnp.ones([1, size_image, size_image, 3]), train=False)

In [24]:
import functools
@functools.partial(jax.pmap)
def create_train_state(rng):
  """Creates initial `TrainState`."""
  return TrainState.create(apply_fn = model.apply,params = variables['params'],tx = optax.adam(0.01,0.9),batch_stats = variables['batch_stats'])

In [25]:
@functools.partial(jax.pmap, axis_name='ensemble')
def apply_model(state, images, labels):
  def loss_fn(params,batch_stats):
    logits,batch_stats = model.apply({'params': params,'batch_stats': batch_stats},images, train=True,rngs={'dropout': jax.random.PRNGKey(0)}, mutable=['batch_stats'])
    one_hot = jax.nn.one_hot(labels, num_classes)
    loss = optax.softmax_cross_entropy(logits=logits, labels=one_hot).mean()
    return loss, (logits, batch_stats)

  (loss, (logits, batch_stats)), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params,state.batch_stats)
  probs = jax.lax.pmean(jax.nn.softmax(logits), axis_name='ensemble')

  accuracy = jnp.mean(jnp.argmax(probs, -1) == labels)
  return grads,loss, accuracy

@jax.pmap
def update_model(state, grads):
  return state.apply_gradients(grads=grads)

In [26]:
rm -rf ./flax_logs/     

In [27]:
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms.functional as F
logdir = "flax_logs"
writer = SummaryWriter(logdir)

In [28]:
(test_images, test_labels) = next(iter(validation_loader))
test_images = test_images / 255.0

In [29]:
test_images = np.array(jax_utils.replicate(test_images))
test_labels = np.array(jax_utils.replicate(test_labels))

In [30]:
epoch_loss = []
epoch_accuracy = []
testing_accuracy = []
testing_loss = []

In [31]:
def train_one_epoch(state, dataloader,num_epochs):
    """Train for 1 epoch on the training set."""
    
    for epoch in range(num_epochs):
        for cnt, (images, labels) in tqdm(enumerate(dataloader), total=(math.ceil(len(train_set)/batch_size))):
            images = images / 255.0
            images = jax_utils.replicate(images)
            labels = jax_utils.replicate(labels)
            grads, loss, accuracy = apply_model(state, images, labels)
            state = update_model(state, grads)
            
        epoch_loss.append(jax_utils.unreplicate(loss))
        epoch_accuracy.append(jax_utils.unreplicate(accuracy))
        train_loss = np.mean(epoch_loss)
        train_accuracy = np.mean(epoch_accuracy)
        
        _, test_loss, test_accuracy = jax_utils.unreplicate(apply_model(state, test_images, test_labels))
        testing_accuracy.append(test_accuracy)
        testing_loss.append(test_loss)
        
        writer.add_scalar('Loss/train', np.array(train_loss), epoch)
        writer.add_scalar('Loss/test', np.array(test_loss), epoch)
    
        writer.add_scalar('Accuracy/train', np.array(train_accuracy), epoch)
        writer.add_scalar('Accuracy/test', np.array(test_accuracy), epoch)
    
        print(f"Epoch: {epoch + 1}, train loss: {train_loss:.4f}, train accuracy: {train_accuracy * 100:.4f}, test loss: {test_loss:.4f}, test accuracy: {test_accuracy* 100:.4f}", flush=True)
    return state, epoch_loss, epoch_accuracy, testing_accuracy, testing_loss

In [32]:
seed = 0 
rng = jax.random.PRNGKey(seed)
rng, init_rng = jax.random.split(rng)
state = create_train_state(jax.random.split(init_rng, jax.device_count()))
del init_rng  # Must not be used anymore.

In [None]:
start = time.time()
num_epochs = 30
state, epoch_loss, epoch_accuracy, testing_accuracy, testing_loss = train_one_epoch(state, train_loader,num_epochs)
print("Total time: ", time.time() - start, "seconds")

  0%|          | 0/625 [00:00<?, ?it/s]

Epoch: 1, train loss: 0.7044, train accuracy: 50.0000, test loss: 0.6610, test accuracy: 56.2500


  0%|          | 0/625 [00:00<?, ?it/s]

Epoch: 2, train loss: 0.6756, train accuracy: 56.2500, test loss: 0.6459, test accuracy: 68.7500


  0%|          | 0/625 [00:00<?, ?it/s]

Epoch: 3, train loss: 0.6585, train accuracy: 56.2500, test loss: 0.6377, test accuracy: 65.6250


  0%|          | 0/625 [00:00<?, ?it/s]

Epoch: 4, train loss: 0.6137, train accuracy: 63.2812, test loss: 0.5754, test accuracy: 75.0000


  0%|          | 0/625 [00:00<?, ?it/s]

Epoch: 5, train loss: 0.5932, train accuracy: 66.8750, test loss: 0.5869, test accuracy: 62.5000


  0%|          | 0/625 [00:00<?, ?it/s]

Epoch: 6, train loss: 0.5743, train accuracy: 68.7500, test loss: 0.6419, test accuracy: 62.5000


  0%|          | 0/625 [00:00<?, ?it/s]

Epoch: 7, train loss: 0.5525, train accuracy: 70.0893, test loss: 0.6790, test accuracy: 65.6250


  0%|          | 0/625 [00:00<?, ?it/s]

Epoch: 8, train loss: 0.5485, train accuracy: 70.7031, test loss: 0.4403, test accuracy: 75.0000


  0%|          | 0/625 [00:00<?, ?it/s]

Epoch: 9, train loss: 0.5312, train accuracy: 71.5278, test loss: 0.5296, test accuracy: 75.0000


  0%|          | 0/625 [00:00<?, ?it/s]

Epoch: 10, train loss: 0.5117, train accuracy: 72.8125, test loss: 0.4238, test accuracy: 81.2500


  0%|          | 0/625 [00:00<?, ?it/s]

In [None]:
pip install tensorstore

In [None]:
from flax.training import checkpoints
ckpt_dir = 'model_checkpoint/'
checkpoints.save_checkpoint(ckpt_dir=ckpt_dir,  
                            target=state, 
                            step=100, 
                            prefix='flax_model',  
                            overwrite=True  
                           )

In [None]:
loaded_model = checkpoints.restore_checkpoint(
                                             ckpt_dir=ckpt_dir,   
                                             target=state,  
                                             prefix='flax_model'  
                                            )

In [None]:
loaded_model

In [None]:
%load_ext tensorboard

In [None]:
%tensorboard --logdir={logdir}

In [None]:
plt.plot(epoch_accuracy, label="Training")
plt.plot(testing_accuracy, label="Test")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()
plt.show()

In [None]:
plt.plot(epoch_loss, label="Training")
plt.plot(testing_loss, label="Test")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()
plt.show()

## Where to go from here
Follow us on [LinkedIn](https://www.linkedin.com/company/mlnuggets), [Twitter](https://twitter.com/ml_nuggets), [GitHub](https://github.com/mlnuggets) and subscribe to our [blog](https://www.machinelearningnuggets.com/#/portal) so that you don't miss a new issue.