# Transfer learning with JAX and Flax

Click the image below to read the post online.

<a target="_blank" href="https://www.machinelearningnuggets.com/transfer-learning-with-jax-flax"><img src="https://www.machinelearningnuggets.com/ezoimgfmt/digitalpress.fra1.cdn.digitaloceanspaces.com/mhujhsj/2022/07/logho-1.png?ezimgfmt=ng:webp/ngcb1" alt="Open in ML Nuggets"></a>

# Imports

In [None]:
# install jax-resnet library
!pip install jax-resnet

In [None]:
pip install wget

In [None]:
pip install flax

In [None]:
import numpy as np
import pandas as pd
from PIL import Image


import jax
import optax
import flax
import jax.numpy as jnp
from jax_resnet import pretrained_resnet, slice_variables, Sequential
from flax.training import train_state
from flax import linen as nn
from flax.core import FrozenDict,frozen_dict
from functools import partial
import os
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
%matplotlib inline
# ignore harmless warnings
import warnings
warnings.filterwarnings("ignore")

In [None]:
# Download the training archive (train.zip) into the current working directory.
import wget 
wget.download("https://ml.machinelearningnuggets.com/train.zip")

In [None]:
# Unpack train.zip into the current directory; later cells read the images
# from the "train/" folder.
import zipfile
with zipfile.ZipFile('train.zip', 'r') as zip_ref:
  zip_ref.extractall('.')

In [None]:
class CatsDogsDataset(Dataset):
    """Image dataset described by a CSV annotation file.

    Column 0 of the annotation file is the image file name (relative to
    ``root_dir``) and column 1 is the numeric label. Each item is returned as
    ``(image, label)`` where ``label`` is a float ``torch`` scalar and
    ``image`` is the RGB image after the optional ``transform``.
    """

    def __init__(self, root_dir, annotation_file, transform=None):
        self.root_dir = root_dir
        self.annotations = pd.read_csv(annotation_file)
        self.transform = transform

    def __len__(self):
        # One sample per row of the annotation table.
        return len(self.annotations)

    def __getitem__(self, index):
        image_path = os.path.join(self.root_dir, self.annotations.iloc[index, 0])
        image = Image.open(image_path).convert("RGB")
        label = torch.tensor(float(self.annotations.iloc[index, 1]))

        if self.transform is None:
            return (image, label)
        return (self.transform(image), label)

In [None]:
# Build the annotation table (image file name -> class id) and save it as CSV.
#
# The directory is listed once so that file names and labels are guaranteed to
# line up, and the labels are computed up front instead of being written one
# cell at a time with chained indexing (train_df["label"][idx] = ...), which
# pandas does not guarantee to write through to the frame.
image_names = os.listdir("train/")


def _label_from_name(file_name):
    """Return 1 if the name contains "dog", 0 if it contains "cat", else None.

    "dog" is checked first so that a name containing both substrings is
    labelled 1, as it was when the "dog" assignment ran after the "cat" one.
    """
    if "dog" in file_name:
        return 1
    if "cat" in file_name:
        return 0
    return None


train_df = pd.DataFrame({
    "img_path": image_names,
    # Nullable integer dtype: labels are written as 0/1 and files matching
    # neither class are written as an empty field.
    "label": pd.array([_label_from_name(name) for name in image_names], dtype="Int64"),
})

train_df.to_csv(r'train_csv.csv', index=False, header=True)

In [None]:
# Read the annotation file back and preview its first rows as a sanity check.
df = pd.read_csv("train_csv.csv")
df.head()

In [None]:
def custom_collate_fn(batch):
    """Collate ``(image, label)`` samples into a pair of NumPy arrays.

    Args:
        batch: sequence of samples; element 0 of each sample is the image and
            element 1 is the label.

    Returns:
        ``(imgs, labels)`` where ``imgs`` is the images stacked along a new
        leading axis and ``labels`` is the labels as a NumPy array.
    """
    # Transpose [(img, lbl), ...] into ((img, ...), (lbl, ...)).
    columns = tuple(zip(*batch))
    label_array = np.array(columns[1])
    image_array = np.stack(columns[0])
    return image_array, label_array

In [None]:

# Hyper-parameters shared by the cells below.
# NOTE(review): 'WEIGHT_DECAY' and 'FREEZE_BACKBONE' are not read by any cell
# visible in this notebook (the optimizer cell passes a literal weight decay)
# — confirm whether they are meant to be wired in.
config = {
    'NUM_LABELS': 2,       # number of output classes (size of the final Dense layer)
    'BATCH_SIZE': 32,      # batch size of both data loaders
    'N_EPOCHS': 5,         # passed to train_model
    'LR': 0.001,           # AdamW learning rate
    'IMAGE_SIZE': 224,     # images are resized to IMAGE_SIZE x IMAGE_SIZE
    'WEIGHT_DECAY': 1e-5,
    'FREEZE_BACKBONE': True,
}

In [None]:
# Resize every image to the configured square size, then convert the PIL image
# to a NumPy array so the collate function can stack the batch with NumPy.
transform = transforms.Compose([
    transforms.Resize((config["IMAGE_SIZE"], config["IMAGE_SIZE"])),
    np.array,
])
dataset = CatsDogsDataset("train", "train_csv.csv", transform=transform)

# 80/20 train/validation split derived from the actual dataset length, so the
# split does not fail when the folder holds a different number of images
# (for 25000 images this is the same 20000/5000 split as before).
n_train = len(dataset) * 4 // 5
n_validation = len(dataset) - n_train
train_set, validation_set = torch.utils.data.random_split(dataset, [n_train, n_validation])

train_loader = DataLoader(
    dataset=train_set,
    collate_fn=custom_collate_fn,
    shuffle=True,
    batch_size=config["BATCH_SIZE"],
)
validation_loader = DataLoader(
    dataset=validation_set,
    collate_fn=custom_collate_fn,
    shuffle=False,
    batch_size=config["BATCH_SIZE"],
)

In [None]:
# Pull one batch from the training loader and print the array shapes as a
# sanity check of the input pipeline.
(image_batch, label_batch) = next(iter(train_loader))
print(image_batch.shape)
print(label_batch.shape)

# Model

In [None]:
"""
reference - https://www.kaggle.com/code/alexlwh/happywhale-flax-jax-tpu-gpu-resnet-baseline
"""
class Head(nn.Module):
    '''Classification head placed on top of the pooled backbone features.

    Layer order: BatchNorm -> Dropout(0.25) -> Dense(same width as the input)
    -> ReLU -> BatchNorm -> Dropout(0.5) -> Dense(config["NUM_LABELS"]).
    The returned values are raw logits (no softmax is applied).

    With train=True, dropout is active and batch norm does not use its
    running averages; with train=False, dropout is deterministic and the
    running averages are used.
    '''
    batch_norm_cls: partial = partial(nn.BatchNorm, momentum=0.9)

    @nn.compact
    def __call__(self, inputs, train: bool):
        inference = not train
        hidden_width = inputs.shape[-1]

        # first block: normalise, regularise, project to the same width
        hidden = self.batch_norm_cls(use_running_average=inference)(inputs)
        hidden = nn.Dropout(rate=0.25)(hidden, deterministic=inference)
        hidden = nn.relu(nn.Dense(features=hidden_width)(hidden))

        # second block: normalise, regularise, project to the class logits
        hidden = self.batch_norm_cls(use_running_average=inference)(hidden)
        hidden = nn.Dropout(rate=0.5)(hidden, deterministic=inference)
        logits = nn.Dense(features=config["NUM_LABELS"])(hidden)
        return logits

class Model(nn.Module):
    '''Full classifier: backbone, mean pooling over axes 1 and 2, then head.'''
    backbone: Sequential
    head: Head

    def __call__(self, inputs, train: bool):
        feature_map = self.backbone(inputs)
        # average pool layer: mean over axes 1 and 2 of the backbone output
        pooled = jnp.mean(feature_map, axis=(1, 2))
        return self.head(pooled, train)

    
def get_backbone_and_params(model_arch: str):
    '''
    Build the feature-extraction backbone and its pretrained variables.

    The pretrained ResNet-50 is loaded, its last two layers are dropped, and
    the variables are sliced to the layers that remain.

    INPUT : model_arch - only 'resnet50' is supported
    RETURNS backbone , backbone_params
    RAISES NotImplementedError for any other model_arch
    '''
    if model_arch != 'resnet50':
        raise NotImplementedError

    resnet_tmpl, params = pretrained_resnet(50)
    full_model = resnet_tmpl()

    # keep layers [0, n_backbone_layers), i.e. everything but the last two
    n_backbone_layers = len(full_model.layers) - 2
    backbone = Sequential(full_model.layers[0:n_backbone_layers])
    backbone_params = slice_variables(params, 0, n_backbone_layers)
    return backbone, backbone_params


def get_model_and_variables(model_arch: str, head_init_key: int):
    '''
    Assemble the full model and its combined variable collections.

    A dummy image batch of shape (1, image_size, image_size, 3) is run
    through the pretrained backbone to find the feature width, a fresh head
    is initialised for that width, and the 'params' and 'batch_stats'
    collections of both parts are merged under the keys 'backbone' and 'head'.

    INPUT model_arch - architecture name passed to get_backbone_and_params
          head_init_key - integer seed for the PRNG key that initialises the head
    RETURNS  model, variables
    '''
    image_size = config['IMAGE_SIZE']

    # backbone: run it once on dummy data to find its output width
    dummy_images = jnp.ones((1, image_size, image_size, 3), jnp.float32)
    backbone, backbone_params = get_backbone_and_params(model_arch)
    head_rng = jax.random.PRNGKey(head_init_key)
    backbone_output = backbone.apply(backbone_params, dummy_images, mutable=False)

    # head: initialised on a dummy feature vector of the backbone's width
    dummy_features = jnp.ones((1, backbone_output.shape[-1]), jnp.float32)
    head = Head()
    head_variables = head.init(head_rng, dummy_features, train=False)

    # final model: same layout for both collections
    model = Model(backbone, head)
    merged = {
        collection: {
            'backbone': backbone_params[collection],
            'head': head_variables[collection],
        }
        for collection in ('params', 'batch_stats')
    }
    return model, FrozenDict(merged)


In [None]:
# Build the model (head initialised with seed 0) and run one inference-mode
# forward pass on a dummy image as a smoke test.
model, variables = get_model_and_variables('resnet50', 0)
inputs = jnp.ones((1,config['IMAGE_SIZE'], config['IMAGE_SIZE'],3), jnp.float32)
# NOTE(review): `key` is not passed to the apply call below — appears unused in this cell.
key = jax.random.PRNGKey(0)
model_apply = model.apply(variables, inputs, train=False, mutable=False)

In [None]:
# Display the output of the smoke-test forward pass.
model_apply

In [None]:
"""
reference - https://github.com/deepmind/optax/issues/159#issuecomment-896459491
"""
def zero_grads():
    '''
    Build an optax gradient transformation that replaces every update with zeros.

    Parameters routed to this transformation therefore receive no update.
    The transformation is stateless: its state is always the empty tuple.

    RETURNS optax.GradientTransformation
    '''
    def init_fn(_):
        # nothing to track between steps
        return ()

    def update_fn(updates, state, params=None):
        # jax.tree_util.tree_map is the stable API; the top-level jax.tree_map
        # alias is deprecated and no longer available in newer JAX releases.
        return jax.tree_util.tree_map(jnp.zeros_like, updates), ()

    return optax.GradientTransformation(init_fn, update_fn)

In [None]:
"""
reference - https://colab.research.google.com/drive/1g_pt2Rc3bv6H6qchvGHD-BpgF-Pt4vrC#scrollTo=TqDvTL_tIQCH&line=2&uniqifier=1
"""
def create_mask(params, label_fn):
    '''
    Build a tree of optimizer labels ('zero' or 'adam') for a parameter tree.

    The tree is walked key by key. A key for which label_fn(key) is true is
    labelled 'zero' and is not descended into. Any other key is descended
    into when its value is a mapping (FrozenDict or plain dict) and labelled
    'adam' otherwise.

    INPUT params - nested mapping of parameters
          label_fn - callable taking a key and returning a truthy value for
                     entries that should be labelled 'zero'
    RETURNS frozen dict of labels
    '''
    # Local import so this cell does not depend on the notebook's import cell.
    from collections.abc import Mapping

    def _map(params, mask, label_fn):
        for k in params:
            if label_fn(k):
                mask[k] = 'zero'
            elif isinstance(params[k], Mapping):
                # Mapping covers both FrozenDict and plain dict sub-trees, so
                # a plain nested dict is descended into instead of being
                # labelled as a single leaf.
                mask[k] = {}
                _map(params[k], mask[k], label_fn)
            else:
                mask[k] = 'adam'

    mask = {}
    _map(params, mask, label_fn)
    return frozen_dict.freeze(mask)

In [None]:
# AdamW is applied to parameters labelled 'adam'; parameters labelled 'zero'
# (every top-level key starting with 'backbone') get zeroed updates.
# NOTE(review): weight_decay is the literal 1e-2 here, while
# config['WEIGHT_DECAY'] (1e-5) is not read — confirm which value is intended.
adamw = optax.adamw(
    learning_rate=config['LR'],
    b1=0.9, b2=0.999, 
    eps=1e-6, weight_decay=1e-2
)
optimizer = optax.multi_transform(
    {'adam': adamw, 'zero': zero_grads()},
    create_mask(variables['params'], lambda s: s.startswith('backbone'))
)

In [None]:
def cross_entropy_loss(*, logits, labels):
    """Return the mean softmax cross-entropy of ``logits`` against ``labels``.

    ``labels`` holds class indices; they are one-hot encoded over
    ``config["NUM_LABELS"]`` classes before the loss is computed.
    """
    targets = jax.nn.one_hot(labels, num_classes=config["NUM_LABELS"])
    per_example = optax.softmax_cross_entropy(logits=logits, labels=targets)
    return per_example.mean()

In [None]:
def compute_loss(params, batch_stats, images, labels, dropout_rng=None):
    """Run a training-mode forward pass and return the loss plus auxiliaries.

    Args:
        params: model parameters (the argument gradients are taken against).
        batch_stats: current ``batch_stats`` collection.
        images: input batch.
        labels: class indices for the batch.
        dropout_rng: optional PRNG key for dropout. When omitted, the fixed
            key ``jax.random.PRNGKey(0)`` is used, which keeps the previous
            behaviour; note that a fixed key means the same dropout key is
            used on every call, so callers that want fresh dropout noise per
            step should pass a new key each time.

    Returns:
        ``(loss, (logits, updates))`` where ``updates`` is the second value
        returned by ``model.apply`` for ``mutable=['batch_stats']``, i.e. the
        mutated collections keyed by collection name, not the bare statistics.
    """
    if dropout_rng is None:
        dropout_rng = jax.random.PRNGKey(0)

    logits, updates = model.apply(
        {'params': params, 'batch_stats': batch_stats},
        images,
        train=True,
        rngs={'dropout': dropout_rng},
        mutable=['batch_stats'],
    )
    loss = cross_entropy_loss(logits=logits, labels=labels)
    return loss, (logits, updates)

In [None]:
def compute_metrics(*, logits, labels):
    """Return a dict with the batch ``'loss'`` and ``'accuracy'``.

    Accuracy is the fraction of rows whose arg-max logit equals the label.
    """
    predicted = jnp.argmax(logits, -1)
    return {
        'loss': cross_entropy_loss(logits=logits, labels=labels),
        'accuracy': jnp.mean(predicted == labels),
    }

In [None]:
class TrainState(train_state.TrainState):
    """Flax ``TrainState`` extended with a ``batch_stats`` field."""
    # batch_stats collection carried alongside params and the optimizer state
    batch_stats: FrozenDict


In [None]:
# Initial training state: the model's apply function, the combined
# parameters, the masked optimizer, and the initial batch statistics.
state = TrainState.create(
    apply_fn = model.apply,
    params = variables['params'],
    tx = optimizer,
    batch_stats = variables['batch_stats'],
)


In [None]:
@jax.jit
def train_step(state: TrainState, images, labels):
    """Train for a single step.

    Computes the loss and its gradients with respect to ``state.params``,
    applies the gradients, and stores the batch statistics updated during the
    forward pass in the new state.

    Args:
        state: current training state.
        images: input batch.
        labels: class indices for the batch.

    Returns:
        ``(state, metrics)``: the updated state and the loss/accuracy dict
        computed from this step's logits.
    """
    grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
    (_, (logits, updates)), grads = grad_fn(
        state.params, state.batch_stats, images, labels
    )
    # compute_loss returns the mutated collections keyed by name. Without
    # passing them on, state.batch_stats would stay at its initial value for
    # the whole run, and evaluation (train=False) would use those stale
    # running statistics.
    state = state.apply_gradients(grads=grads, batch_stats=updates['batch_stats'])
    metrics = compute_metrics(logits=logits, labels=labels)
    return state, metrics

In [None]:
def train_one_epoch(state, dataloader):
    """Train for 1 epoch on the training set.

    Every batch is divided by 255.0 and passed through ``train_step``. The
    per-batch metrics are then fetched to the host and averaged per key.

    Returns:
        ``(state, epoch_metrics_np)``: the state after the last batch and a
        dict with the mean of each metric over the epoch.
    """
    per_batch = []
    for images, labels in dataloader:
        state, step_metrics = train_step(state, images / 255.0, labels)
        per_batch.append(step_metrics)

    # Move the metrics off the device once, after the loop.
    per_batch_host = jax.device_get(per_batch)
    epoch_metrics_np = {}
    for name in per_batch_host[0]:
        epoch_metrics_np[name] = np.mean([m[name] for m in per_batch_host])
    return state, epoch_metrics_np

In [None]:
@jax.jit
def eval_step(batch_stats, params, images, labels):
    """Run an inference-mode (``train=False``) forward pass and return its metrics."""
    model_variables = {'params': params, 'batch_stats': batch_stats}
    logits = model.apply(
        model_variables,
        images,
        train=False,
        rngs={'dropout': jax.random.PRNGKey(0)},
    )
    return compute_metrics(logits=logits, labels=labels)

In [None]:
def evaluate_model(state, test_imgs, test_lbls):
    """Evaluate ``state`` on the given images and labels.

    Args:
        state: training state providing ``params`` and ``batch_stats``.
        test_imgs: input batch.
        test_lbls: class indices for the batch.

    Returns:
        The metrics dict from ``eval_step`` with every value converted to a
        plain Python scalar.
    """
    metrics = eval_step(state.batch_stats, state.params, test_imgs, test_lbls)
    metrics = jax.device_get(metrics)
    # jax.tree_util.tree_map is the stable API; the top-level jax.tree_map
    # alias is deprecated and no longer available in newer JAX releases.
    return jax.tree_util.tree_map(lambda x: x.item(), metrics)

In [None]:
# Take the first batch of the validation loader as the evaluation batch and
# scale it by 1/255, the same scaling train_one_epoch applies.
# NOTE: train_model evaluates on this single batch only, not on the whole
# validation split.
(test_images, test_labels) = next(iter(validation_loader))
test_images = test_images / 255.0

In [None]:
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms.functional as F
logdir = "flax_logs"
writer = SummaryWriter(logdir)

In [None]:
# Per-epoch history; train_model appends one value per epoch to each list and
# the plotting cells read them.
training_loss = []
training_accuracy = []
testing_loss = []
testing_accuracy = []

In [None]:
def train_model(epochs):
    """Train for ``epochs`` epochs, evaluating and logging after each one.

    Starts from the notebook-level ``state`` and carries the updated state
    from one epoch into the next. After every epoch the training metrics and
    the metrics on ``test_images`` / ``test_labels`` are appended to the
    history lists, written to the TensorBoard ``writer`` and printed.

    Args:
        epochs: number of epochs to run.

    Returns:
        The training state after the last epoch (the initial ``state`` when
        ``epochs`` is 0).
    """
    # Carry the state across epochs. Passing the global `state` to every
    # epoch would restart training from the initial weights each time.
    current_state = state
    for epoch in range(1, epochs + 1):
        current_state, train_metrics = train_one_epoch(current_state, train_loader)
        training_loss.append(train_metrics['loss'])
        training_accuracy.append(train_metrics['accuracy'])

        test_metrics = evaluate_model(current_state, test_images, test_labels)
        testing_loss.append(test_metrics['loss'])
        testing_accuracy.append(test_metrics['accuracy'])

        writer.add_scalar('Loss/train', train_metrics['loss'], epoch)
        writer.add_scalar('Loss/test', test_metrics['loss'], epoch)

        writer.add_scalar('Accuracy/train', train_metrics['accuracy'], epoch)
        writer.add_scalar('Accuracy/test', test_metrics['accuracy'], epoch)

        print(f"Epoch: {epoch}, training loss: {train_metrics['loss']}, training accuracy: {train_metrics['accuracy'] * 100}, validation loss: {test_metrics['loss']}, validation accuracy: {test_metrics['accuracy'] * 100}")
    return current_state

In [None]:
# Run training for the configured number of epochs and keep the final state.
trained_model_state = train_model(config["N_EPOCHS"])

In [None]:
pip install tensorstore

In [None]:
# Save the trained state under ckpt_dir with the 'resnet_model' prefix.
# NOTE(review): step is the literal 100 rather than the state's own step
# counter — confirm this is intentional.
from flax.training import checkpoints
ckpt_dir = 'model_checkpoint/'
checkpoints.save_checkpoint(ckpt_dir=ckpt_dir,  
                            target=trained_model_state, 
                            step=100, 
                            prefix='resnet_model',  
                            overwrite=True  
                           )

In [None]:
# Restore the checkpoint saved above, using the initial `state` as the target
# to restore into, then display the result.
loaded_model = checkpoints.restore_checkpoint(
                                             ckpt_dir=ckpt_dir,   
                                             target=state,  
                                             prefix='resnet_model'  
                                            )
                                            
loaded_model

In [None]:
# Evaluate the restored state on the held-out evaluation batch.
evaluate_model(loaded_model,test_images, test_labels)

In [None]:
%load_ext tensorboard

In [None]:
%tensorboard --logdir={logdir}

In [None]:
# Accuracy per epoch for the training set and the evaluation batch.
plt.plot(training_accuracy, label="Training")
plt.plot(testing_accuracy, label="Test")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()
plt.show()

In [None]:
# Loss per epoch for the training set and the evaluation batch.
plt.plot(training_loss, label="Training")
plt.plot(testing_loss, label="Test")
plt.xlabel("Epoch")
# This figure plots the loss curves, so the axis is labelled "Loss".
plt.ylabel("Loss")
plt.legend()
plt.show()

## Where to go from here
Follow us on [LinkedIn](https://www.linkedin.com/company/mlnuggets), [Twitter](https://twitter.com/ml_nuggets), [GitHub](https://github.com/mlnuggets) and subscribe to our [blog](https://www.machinelearningnuggets.com/#/portal) so that you don't miss a new issue.