# Model Soups

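In this assignment, you will implement [Model Soups](https://arxiv.org/pdf/2203.05482.pdf): instead of ensembling the predictions of several models fine-tuned from the same pre-trained weights (or keeping only the single best one), their weights are averaged into a single model.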


In [1]:
!pip install jax-resnet

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting jax-resnet
  Downloading jax_resnet-0.0.4-py2.py3-none-any.whl (11 kB)
Collecting flax
  Downloading flax-0.6.2-py3-none-any.whl (189 kB)
[K     |████████████████████████████████| 189 kB 7.9 MB/s 
Collecting optax
  Downloading optax-0.1.4-py3-none-any.whl (154 kB)
[K     |████████████████████████████████| 154 kB 46.2 MB/s 
Collecting tensorstore
  Downloading tensorstore-0.1.28-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.3 MB)
[K     |████████████████████████████████| 8.3 MB 35.2 MB/s 
[?25hCollecting rich>=11.1
  Downloading rich-12.6.0-py3-none-any.whl (237 kB)
[K     |████████████████████████████████| 237 kB 47.5 MB/s 
Collecting commonmark<0.10.0,>=0.9.0
  Downloading commonmark-0.9.1-py2.py3-none-any.whl (51 kB)
[K     |████████████████████████████████| 51 kB 3.1 MB/s 
Collecting chex>=0.1.5
  Downloading chex-0.1.5-py3-none-any.whl (85 kB)
[K     

In [8]:
import os
import cv2
import numpy as np
import pandas as pd
from PIL import Image
from typing import Callable
from tqdm.notebook import tqdm

from sklearn import preprocessing
from sklearn.model_selection import train_test_split
#import tensorflow.keras as keras
import tensorflow as tf
import tensorflow_datasets as tfds 

import jax
import optax
import flax
import jax.numpy as jnp
from jax import jit
from jax import lax
from jax_resnet import pretrained_resnet, slice_variables, Sequential
from flax.jax_utils import replicate, unreplicate
from flax.training import train_state
from flax import linen as nn
from flax.core import FrozenDict,frozen_dict
from flax.training.common_utils import shard

from functools import partial

In [3]:
# Experiment hyper-parameters shared by the cells below.
Config = dict(
    NUM_LABELS=10,         # number of CIFAR-10 classes (size of the head's output)
    N_SPLITS=5,            # NOTE(review): not referenced by the visible cells
    BATCH_SIZE=32,         # global batch; `shard` splits it across local devices
    N_EPOCHS=10,
    LR=0.001,
    WIDTH=32,              # NOTE(review): not referenced; images are resized to IMAGE_SIZE
    HEIGHT=32,             # NOTE(review): not referenced; images are resized to IMAGE_SIZE
    IMAGE_SIZE=128,        # square side length images are resized to
    WEIGHT_DECAY=1e-5,     # weight-decay coefficient for the optimizer
    FREEZE_BACKBONE=True,  # flag for keeping the pretrained backbone frozen
)

In [4]:
# DATA PREPROCESSING

def transform_images(row, size):
    '''
    Resize one tfds example to a square image.

    INPUT
        row  : tfds feature dict holding 'image' and 'label'
        size : target side length in pixels
    RETURNS (resized_image, label)

    NOTE(review): no pixel normalisation is applied here; confirm the pretrained
    backbone is meant to receive un-normalised pixel values.
    '''
    image, label = row['image'], row['label']
    resized = tf.image.resize(image, (size, size))
    return resized, label

def load_datasets(shuffle_buffer: int = 10_000, seed=None):
    '''
    Load CIFAR-10 from tfds, resize the images and batch them.

    INPUT
        shuffle_buffer : size of the example-level shuffle buffer for the training
                         split (reshuffled every epoch); 0 disables shuffling
        seed           : optional seed for that shuffle (None -> random)
    RETURNS train_dataset, test_dataset : batched, prefetched tf.data.Datasets
                                          yielding (image, label) tuples
    '''
    # Construct a tf.data.Dataset
    train_ds, test_ds = tfds.load('cifar10', split=['train', 'test'], shuffle_files=True)

    # shuffle_files only permutes the order of the underlying record files; the
    # examples inside each file are still read in a fixed order.  Shuffle the
    # individual training examples as well.  Do it *before* the resize so the
    # buffer holds the small raw images, not IMAGE_SIZE x IMAGE_SIZE float tensors.
    if shuffle_buffer:
        train_ds = train_ds.shuffle(shuffle_buffer, seed=seed, reshuffle_each_iteration=True)

    train_ds = train_ds.map(lambda row: transform_images(row, Config["IMAGE_SIZE"]))
    test_ds = test_ds.map(lambda row: transform_images(row, Config["IMAGE_SIZE"]))

    # Build your input pipeline
    train_dataset = train_ds.batch(Config["BATCH_SIZE"]).prefetch(tf.data.AUTOTUNE)
    test_dataset = test_ds.batch(Config["BATCH_SIZE"]).prefetch(tf.data.AUTOTUNE)

    return train_dataset, test_dataset

In [5]:
# DEFINING THE MODEL
"""
reference - https://www.kaggle.com/code/alexlwh/happywhale-flax-jax-tpu-gpu-resnet-baseline
"""
class MarginLayer(nn.Module):
    '''
    Placeholder layer; not implemented and not used by Head or Model here.

    NOTE(review): the name suggests a margin-based output layer (e.g. ArcFace
    style, as in the referenced Kaggle notebook) - unconfirmed.
    '''
    @nn.compact
    def __call__(self, inputs):
        # Intentionally unimplemented: any call fails loudly.
        raise NotImplementedError

class Head(nn.Module):
    '''
    Classification head:
        BatchNorm -> Dropout(0.25) -> Dense(in_features) -> ReLU
        -> BatchNorm -> Dropout(0.5) -> Dense(Config["NUM_LABELS"])

    Call args
        inputs : features; the last axis is the feature dimension (Model feeds
                 the spatially averaged backbone output)
        train  : True  -> BatchNorm uses batch statistics and dropout is active
                          (the caller must supply a 'dropout' rng);
                 False -> BatchNorm uses running averages and dropout is off
    RETURNS logits with Config["NUM_LABELS"] entries on the last axis
    '''
    # factory for both BatchNorm layers (momentum 0.9 by default)
    batch_norm_cls: partial = partial(nn.BatchNorm, momentum=0.9)
    @nn.compact
    def __call__(self, inputs, train: bool):
        # Submodules are auto-named in creation order (BatchNorm_0, Dense_0, ...),
        # so the order of these lines defines the parameter-tree keys.
        # the hidden Dense keeps the input feature width
        output_n = inputs.shape[-1]
        x = self.batch_norm_cls(use_running_average=not train)(inputs)
        x = nn.Dropout(rate=0.25)(x, deterministic=not train)
        x = nn.Dense(features=output_n)(x)
        x = nn.relu(x)
        x = self.batch_norm_cls(use_running_average=not train)(x)
        x = nn.Dropout(rate=0.5)(x, deterministic=not train)
        # final projection to class logits
        x = nn.Dense(features=Config["NUM_LABELS"])(x)
        return x

class Model(nn.Module):
    '''
    Full classifier: backbone -> global average pool over axes 1 and 2 -> Head.

    Fields (positional order matters: Model(backbone, head))
        backbone : feature extractor (pretrained ResNet layers)
        head     : Head producing class logits
    '''
    backbone: Sequential
    head: Head

    def __call__(self, inputs, train: bool):
        features = self.backbone(inputs)
        # average pool layer: collapse the two spatial axes
        pooled = jnp.mean(features, axis=(1, 2))
        return self.head(pooled, train)

    
def _get_backbone_and_params(model_arch: str):
    '''
    Get a pretrained ResNet backbone (classifier removed) and its variables.

    1. Loads the pretrained ResNet matching `model_arch`
    2. Keeps every layer except the last two (the final pooling and dense
       classifier); Model re-implements pooling and adds its own Head
    3. Extracts the matching subset of the pretrained variables

    INPUT  model_arch : one of 'resnet18', 'resnet34', 'resnet50',
                        'resnet101', 'resnet152'
    RETURNS backbone, backbone_params
    RAISES NotImplementedError for any other architecture name
    '''
    resnet_depths = {
        'resnet18': 18,
        'resnet34': 34,
        'resnet50': 50,
        'resnet101': 101,
        'resnet152': 152,
    }
    if model_arch not in resnet_depths:
        raise NotImplementedError(
            f"unsupported model_arch {model_arch!r}; expected one of {sorted(resnet_depths)}"
        )

    resnet_tmpl, params = pretrained_resnet(resnet_depths[model_arch])
    model = resnet_tmpl()

    # get model & param structure for backbone
    start, end = 0, len(model.layers) - 2
    backbone = Sequential(model.layers[start:end])
    backbone_params = slice_variables(params, start, end)
    return backbone, backbone_params


def get_model_and_variables(model_arch: str, head_init_key: int):
    '''
    Assemble the backbone + head Model and its combined variable collections.

    The backbone is run once on a dummy (1, IMAGE_SIZE, IMAGE_SIZE, 3) batch to
    learn its output width, which sizes the Head's dummy input for init.

    INPUT
        model_arch    : backbone name passed to _get_backbone_and_params
        head_init_key : integer seed for the Head's parameter initialisation
    RETURNS model, variables where variables is a FrozenDict
            {'params': {'backbone', 'head'}, 'batch_stats': {'backbone', 'head'}}
    '''
    side = Config['IMAGE_SIZE']
    dummy_images = jnp.ones((1, side, side, 3), jnp.float32)

    # backbone: pretrained weights, probed once to get the feature width
    backbone, backbone_vars = _get_backbone_and_params(model_arch)
    features = backbone.apply(backbone_vars, dummy_images, mutable=False)
    feature_dim = features.shape[-1]

    # head: freshly initialised from the given seed
    head = Head()
    head_rng = jax.random.PRNGKey(head_init_key)
    dummy_features = jnp.ones((1, feature_dim), jnp.float32)
    head_vars = head.init(head_rng, dummy_features, train=False)

    combined = {
        collection: {
            'backbone': backbone_vars[collection],
            'head': head_vars[collection],
        }
        for collection in ('params', 'batch_stats')
    }
    return Model(backbone, head), FrozenDict(combined)

In [6]:
# Loading in pre-trained ResNet18 Model

# Build the backbone+head model and its combined variables (head initialised from
# PRNG seed 0), then run one inference-mode forward pass on a dummy batch as a
# smoke test.
model, variables = get_model_and_variables('resnet18', 0)
inputs = jnp.ones((1,Config['IMAGE_SIZE'], Config['IMAGE_SIZE'],3), jnp.float32)
# NOTE(review): `key` is not used by any of the cells shown here.
key = jax.random.PRNGKey(0)
# o: logits for the single dummy image (the head's final Dense has NUM_LABELS features)
o = model.apply(variables, inputs, train=False, mutable=False)

Downloading: "https://github.com/pytorch/vision/zipball/v0.6.0" to /root/.cache/torch/hub/v0.6.0.zip
Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

In [9]:
# Build the batched, resized CIFAR-10 train/test pipelines (downloads on first run).
train_dataset,test_dataset = load_datasets()

Downloading and preparing dataset 162.17 MiB (download: 162.17 MiB, generated: 132.40 MiB, total: 294.58 MiB) to ~/tensorflow_datasets/cifar10/3.0.2...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

Generating splits...:   0%|          | 0/2 [00:00<?, ? splits/s]

Generating train examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]

Shuffling ~/tensorflow_datasets/cifar10/3.0.2.incompleteV0J9RR/cifar10-train.tfrecord*...:   0%|          | 0/…

Generating test examples...:   0%|          | 0/10000 [00:00<?, ? examples/s]

Shuffling ~/tensorflow_datasets/cifar10/3.0.2.incompleteV0J9RR/cifar10-test.tfrecord*...:   0%|          | 0/1…

Dataset cifar10 downloaded and prepared to ~/tensorflow_datasets/cifar10/3.0.2. Subsequent calls will reuse this data.


In [14]:
# NOTE(review): neither name below is referenced by the cells shown here.
total_batch_size = Config["BATCH_SIZE"]
# number of batches per epoch (len() of the batched training dataset)
num_train_steps = len(train_dataset)

In [15]:
"""
reference - https://github.com/deepmind/optax/issues/159#issuecomment-896459491
"""
def zero_grads():
    '''
    Stateless optax transformation that replaces every update with zeros.

    Combined with optax.multi_transform it freezes a subset of parameters:
    gradients are still computed for them, but the applied updates are all
    zero, so their values never change.

    RETURNS optax.GradientTransformation
    '''
    def init_fn(_):
        # no optimizer state is needed
        return ()

    def update_fn(updates, state, params=None):
        del params  # unused
        # jax.tree_util.tree_map is the stable spelling; the jax.tree_map alias
        # is deprecated (and later removed) in newer JAX releases.
        return jax.tree_util.tree_map(jnp.zeros_like, updates), ()

    return optax.GradientTransformation(init_fn, update_fn)

"""
reference - https://colab.research.google.com/drive/1g_pt2Rc3bv6H6qchvGHD-BpgF-Pt4vrC#scrollTo=TqDvTL_tIQCH&line=2&uniqifier=1
"""
def create_mask(params, label_fn):
    '''
    Build the label tree consumed by optax.multi_transform.

    Every key for which label_fn(key) is truthy is labelled 'zero' (its whole
    subtree shares that label).  Other values that are FrozenDicts are
    descended into recursively; every remaining value is labelled 'adam'.

    INPUT  params   : (nested) FrozenDict of parameters
           label_fn : predicate applied to each key name, at every depth
    RETURNS FrozenDict with the same key layout as `params`
    '''
    def _label(tree):
        labels = {}
        for name in tree:
            if label_fn(name):
                labels[name] = 'zero'
                continue
            child = tree[name]
            labels[name] = _label(child) if isinstance(child, FrozenDict) else 'adam'
        return labels

    return frozen_dict.freeze(_label(params))

In [16]:
# Define optimizer

# AdamW for the trainable parameters.  Weight decay is read from Config (it was
# previously hard-coded to 1e-2 while Config['WEIGHT_DECAY'] was ignored).
adamw = optax.adamw(
    learning_rate=Config['LR'],
    b1=0.9, b2=0.999,
    eps=1e-6, weight_decay=Config['WEIGHT_DECAY'],
)

# Keys whose name starts with 'backbone' are routed to zero_grads() (frozen)
# only when Config['FREEZE_BACKBONE'] is set; everything else uses AdamW.
optimizer = optax.multi_transform(
    {'adam': adamw, 'zero': zero_grads()},
    create_mask(
        variables['params'],
        lambda s: bool(Config['FREEZE_BACKBONE']) and s.startswith('backbone'),
    ),
)

In [None]:
def accuracy(logits, labels):
    '''
    Top-1 accuracy of `logits` against class-index `labels`.

    INPUT  logits : scores with classes on the last axis
           labels : class indices matching logits' leading axes
    RETURNS a one-element list holding the mean accuracy (the training loop
            unpacks it with `*metadata['accuracy']`)
    '''
    predicted = jnp.argmax(logits, axis=-1)
    hits = predicted == labels
    return [hits.mean()]

In [None]:
#loss function and evaluation function 
# loss_fn is called as loss_fn(logits, one_hot_labels) and returns a per-example
# loss; train_step averages it with .mean().
loss_fn = optax.softmax_cross_entropy
# eval_fn(logits, labels) -> one-element list with the batch accuracy
eval_fn = accuracy

In [None]:
# Instantiate a TrainState.
from typing import Any

from flax import struct


class TrainState(train_state.TrainState):
    '''
    flax TrainState extended with the extra state this notebook uses.

    batch_stats travels with the params as a pytree (it is replicated and
    updated inside the pmapped steps).  loss_fn / eval_fn are plain Python
    callables, so they are static (pytree_node=False) and stay out of
    replicate() / pmap.
    '''
    batch_stats: Any
    loss_fn: Callable = struct.field(pytree_node=False)
    eval_fn: Callable = struct.field(pytree_node=False)


state = TrainState.create(
    apply_fn = model.apply,
    params = variables['params'],
    tx = optimizer,
    batch_stats = variables['batch_stats'],
    loss_fn = loss_fn,
    eval_fn = eval_fn
)

In [None]:
def train_step(state: TrainState, batch, labels, dropout_rng):
    '''
    One data-parallel optimisation step, meant to run under
    jax.pmap(..., axis_name='batch') so the collectives below are bound.

    INPUT
        state       : this device's TrainState replica
        batch       : this device's shard of images
        labels      : this device's shard of class labels
        dropout_rng : this device's PRNG key
    RETURNS (new_state, metadata, new_dropout_rng); metadata holds the
            device-averaged 'loss' and 'accuracy'.
    '''
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

    # params is the explicit argument because that is what we differentiate with
    # respect to; reading state.params inside would hide it from jax.grad.
    def loss_function(params):
        variables = {'params': params, 'batch_stats': state.batch_stats}
        # mutable=['batch_stats'] makes apply_fn also return the updated BN stats
        logits, new_model_state = state.apply_fn(
            variables, batch, train=True,
            mutable=['batch_stats'],
            rngs={'dropout': dropout_rng},
        )
        # logits: (BS, NUM_LABELS), one_hot: (BS, NUM_LABELS)
        one_hot = jax.nn.one_hot(labels, Config["NUM_LABELS"])
        loss = state.loss_fn(logits, one_hot).mean()
        return loss, (logits, new_model_state)

    # backpropagation
    grad_fn = jax.value_and_grad(loss_function, has_aux=True)
    (loss, (logits, new_model_state)), grads = grad_fn(state.params)

    # Average gradients over devices so every replica applies the same update.
    grads = lax.pmean(grads, axis_name='batch')
    # Average the BatchNorm running statistics as well.  Each replica only saw
    # its own shard; without this the replicas' batch_stats drift apart, and
    # evaluation / unreplicate() would depend on which device is read.
    new_batch_stats = lax.pmean(new_model_state['batch_stats'], axis_name='batch')
    # apply the gradients to the weights and store the new BN statistics
    new_state = state.apply_gradients(grads=grads, batch_stats=new_batch_stats)

    # evaluation metrics, averaged over devices
    batch_accuracy = state.eval_fn(logits, labels)
    metadata = lax.pmean(
        {'loss': loss, 'accuracy': batch_accuracy},
        axis_name='batch',
    )
    return new_state, metadata, new_dropout_rng


def val_step(state: TrainState, batch, labels):
    '''
    Evaluate one sharded batch in inference mode; meant to run under
    jax.pmap(..., axis_name='batch').

    RETURNS eval_fn's metrics averaged over all devices, so every device's copy
            holds the accuracy of the whole batch rather than of its own shard
            (shards are equal-sized, so the mean of shard means is the batch mean).
    '''
    variables = {'params': state.params, 'batch_stats': state.batch_stats}
    logits = state.apply_fn(variables, batch, train=False)
    return lax.pmean(state.eval_fn(logits, labels), axis_name='batch')

def test_step(state: TrainState, batch):
    '''Return raw logits for `batch` in inference mode (running BN stats, no dropout).'''
    return state.apply_fn(
        {'params': state.params, 'batch_stats': state.batch_stats},
        batch,
        train=False,
    )

# The train step returns a fresh state, so donating the old one is safe.
parallel_train_step = jax.pmap(train_step, axis_name='batch', donate_argnums=(0,))
# The evaluation steps must NOT donate `state`: the same replicated state is
# passed to them batch after batch, and a donated input buffer is invalidated by
# the call (on GPU/TPU the second validation batch would fail with a
# "deleted or donated buffer" error).
parallel_val_step = jax.pmap(val_step, axis_name='batch')
parallel_test_step = jax.pmap(test_step, axis_name='batch')

# required for parallelism: one copy of the state per local device
state = replicate(state)

# control randomness on dropout and update inside train_step
rng = jax.random.PRNGKey(0)
dropout_rng = jax.random.split(rng, jax.local_device_count())  # one key per device

In [None]:
# Training Loop

epochs = Config['N_EPOCHS']

for epoch_i in tqdm(range(epochs), desc=f"{epochs} epochs", position=0, leave=True):
    # training set
    train_loss, train_accuracy = [], []
    iter_n = len(train_dataset)

    with tqdm(total=iter_n, desc=f"{iter_n} iterations", leave=False) as progress_bar:
        for images, labels in train_dataset:  # each element is an (image, label) tuple
            # Convert on the host with NumPy.  `shard` splits the leading axis into
            # (local_device_count, per_device_batch) and pmap then sends each shard
            # to its device.  jnp.array would first copy the whole batch onto one
            # default device.
            batch = np.asarray(images, dtype=np.float32)
            labels = np.asarray(labels, dtype=np.float32)
            batch, labels = shard(batch), shard(labels)

            # backprop and update params & batch stats
            state, train_metadata, dropout_rng = parallel_train_step(state, batch, labels, dropout_rng)
            # metadata was pmean-ed inside the step, so device 0's copy is representative
            train_metadata = unreplicate(train_metadata)

            # update train statistics
            _train_loss, _train_top1_acc = map(float, [train_metadata['loss'], *train_metadata['accuracy']])
            train_loss.append(_train_loss)
            train_accuracy.append(_train_top1_acc)
            progress_bar.update(1)

    avg_train_loss = sum(train_loss)/len(train_loss)
    avg_train_acc = sum(train_accuracy)/len(train_accuracy)
    print(f"[{epoch_i+1}/{Config['N_EPOCHS']}] Train Loss: {avg_train_loss:.03} | Train Accuracy: {avg_train_acc:.03}")

    # validation set
    valid_accuracy = []
    iter_n = len(test_dataset)
    with tqdm(total=iter_n, desc=f"{iter_n} iterations", leave=False) as progress_bar:
        for images, labels in test_dataset:
            batch = np.asarray(images, dtype=np.float32)
            labels = np.asarray(labels, dtype=np.float32)
            batch, labels = shard(batch), shard(labels)

            # eval_fn returns a one-element list; [0] is the per-device accuracy
            # array with one entry per local device
            metric = parallel_val_step(state, batch, labels)[0]
            valid_accuracy.append(metric)
            progress_bar.update(1)

    avg_valid_acc = sum(valid_accuracy)/len(valid_accuracy)
    # Average over devices instead of reading device 0's entry only, so every
    # device's shard of the test set counts towards the reported accuracy.
    avg_valid_acc = float(np.mean(np.asarray(avg_valid_acc)))
    print(f"[{epoch_i+1}/{Config['N_EPOCHS']}] Valid Accuracy: {avg_valid_acc:.03}")

# Background 
Pre-training <br>
Hyperparameter Search <br>
Ensemble Learning <br>
Greedy Algorithms <br> 
Vision Transformers <br>

### Part A. Fine-Tuning a Single ViT Model

### Part B. Fine-Tuning Various Models (and Picking the Best One)


### Part C. Ensemble Learning

### Part D. Uniform Soup

### Part E. Greedy Soup