# Model Soups

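In this assignment, you will implement [Model Soups](https://arxiv.org/pdf/2203.05482.pdf), a technique that averages the weights of multiple fine-tuned models to improve accuracy without increasing inference time.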


In [7]:
# Use %pip so the package is installed into the running kernel's environment;
# pinned to the version shown in this cell's saved output for reproducibility.
%pip install -q jax-resnet==0.0.4

Collecting jax-resnet
  Downloading jax_resnet-0.0.4-py2.py3-none-any.whl (11 kB)
Collecting flax
  Downloading flax-0.6.2-py3-none-any.whl (189 kB)
[K     |████████████████████████████████| 189 kB 9.3 MB/s eta 0:00:01
Collecting optax
  Downloading optax-0.1.4-py3-none-any.whl (154 kB)
[K     |████████████████████████████████| 154 kB 30.6 MB/s eta 0:00:01
[?25hCollecting tensorstore
  Downloading tensorstore-0.1.28-cp39-cp39-macosx_10_14_x86_64.whl (9.2 MB)
[K     |████████████████████████████████| 9.2 MB 20.6 MB/s eta 0:00:01
[?25hCollecting rich>=11.1
  Downloading rich-12.6.0-py3-none-any.whl (237 kB)
[K     |████████████████████████████████| 237 kB 25.8 MB/s eta 0:00:01
Collecting commonmark<0.10.0,>=0.9.0
  Downloading commonmark-0.9.1-py2.py3-none-any.whl (51 kB)
[K     |████████████████████████████████| 51 kB 20.5 MB/s eta 0:00:01
Collecting chex>=0.1.5
  Downloading chex-0.1.5-py3-none-any.whl (85 kB)
[K     |████████████████████████████████| 85 kB 7.0 MB/s  eta 0:00:0

In [6]:
import os
import cv2
import numpy as np
import pandas as pd
from PIL import Image
from typing import Callable
from tqdm.notebook import tqdm

from sklearn import preprocessing
from sklearn.model_selection import train_test_split
#import tensorflow.keras as keras
#import tensorflow as tf
#import tensorflow_datasets as tfds 

import jax
import optax
import flax
import jax.numpy as jnp
from jax import jit
from jax import lax
from jax_resnet import pretrained_resnet, slice_variables, Sequential
from flax.jax_utils import replicate, unreplicate
from flax.training import train_state
from flax import linen as nn
from flax.core import FrozenDict,frozen_dict
from flax.training.common_utils import shard

from functools import partial

In [3]:
# Notebook-wide settings. Keys are looked up by name in later cells.
Config = {
    # Output width of the final Dense layer in Head.
    "NUM_LABELS": 10,
    # Not referenced in the visible cells.
    "N_SPLITS": 5,
    # Batch size used when batching the train/test datasets.
    "BATCH_SIZE": 32,
    # Training hyperparameters (not referenced in the visible cells).
    "N_EPOCHS": 10,
    "LR": 0.001,
    # Not referenced in the visible cells -- presumably the native image size; confirm.
    "WIDTH": 32,
    "HEIGHT": 32,
    # Side length images are resized to, and of the dummy input used to build the model.
    "IMAGE_SIZE": 128,
    # Optimiser / fine-tuning options (not referenced in the visible cells).
    "WEIGHT_DECAY": 1e-5,
    "FREEZE_BACKBONE": True,
}

In [4]:
# DATA PREPROCESSING

def transform_images(row, size):
    '''
    Resize one dataset example to a square image.
    INPUT row  : mapping providing an 'image' entry and a 'label' entry
          size : target side length; the image is resized to (size, size)
    RETURNS (resized image, label) tuple

    NOTE(review): relies on a module-level `tf`; the tensorflow import in the
    visible import cell is commented out -- confirm it is in scope before use.
    '''
    target_shape = (size, size)
    resized_image = tf.image.resize(row['image'], target_shape)
    label = row['label']
    return resized_image, label

def load_datasets():
    '''
    Load the CIFAR-10 train and test splits from tfds and prepare them.
    Every example is resized to Config["IMAGE_SIZE"] via transform_images,
    then the split is batched with Config["BATCH_SIZE"] and prefetched.
    RETURNS (train_dataset, test_dataset)

    NOTE(review): relies on module-level `tf` and `tfds`; their imports in the
    visible import cell are commented out -- confirm they are in scope.
    '''

    def _resize(row):
        # Look the size up at mapping time, exactly like the inline lambda did.
        return transform_images(row, Config["IMAGE_SIZE"])

    def _batch_and_prefetch(dataset):
        return dataset.batch(Config["BATCH_SIZE"]).prefetch(tf.data.AUTOTUNE)

    # Construct one tf.data.Dataset per split.
    train_split, test_split = tfds.load(
        'cifar10',
        split=['train', 'test'],
        shuffle_files=True,
    )

    # Resize both splits first, then build the input pipelines.
    train_resized = train_split.map(_resize)
    test_resized = test_split.map(_resize)

    return _batch_and_prefetch(train_resized), _batch_and_prefetch(test_resized)

In [7]:
# DEFINING THE MODEL
"""
reference - https://www.kaggle.com/code/alexlwh/happywhale-flax-jax-tpu-gpu-resnet-baseline
"""
class MarginLayer(nn.Module):
    '''Placeholder layer: calling it always raises NotImplementedError.'''
    @nn.compact
    def __call__(self, inputs):
        # Stub only -- no behaviour is implemented for this layer yet.
        raise NotImplementedError

class Head(nn.Module):
    '''
    Classification head.

    Pipeline: BatchNorm -> Dropout(0.25) -> Dense(same width as the input)
    -> ReLU -> BatchNorm -> Dropout(0.5) -> Dense(Config["NUM_LABELS"]).
    '''
    # Factory for the normalisation layers (BatchNorm with momentum 0.9 by default).
    batch_norm_cls: partial = partial(nn.BatchNorm, momentum=0.9)

    @nn.compact
    def __call__(self, inputs, train: bool):
        '''
        INPUT inputs : feature array; its last axis sets the hidden layer width
              train  : forwarded as use_running_average=not train to the
                       normalisation layers and deterministic=not train to dropout
        RETURNS output of the final Dense layer (Config["NUM_LABELS"] features)
        '''
        inference = not train
        hidden_width = inputs.shape[-1]

        # First stage: normalise, drop out, project to the same width, ReLU.
        hidden = self.batch_norm_cls(use_running_average=inference)(inputs)
        hidden = nn.Dropout(rate=0.25)(hidden, deterministic=inference)
        hidden = nn.relu(nn.Dense(features=hidden_width)(hidden))

        # Second stage: normalise, drop out, project to the label logits.
        hidden = self.batch_norm_cls(use_running_average=inference)(hidden)
        hidden = nn.Dropout(rate=0.5)(hidden, deterministic=inference)
        return nn.Dense(features=Config["NUM_LABELS"])(hidden)

class Model(nn.Module):
    '''Full network: backbone, mean-pool over axes 1 and 2, then the head.'''
    backbone: Sequential
    head: Head

    def __call__(self, inputs, train: bool):
        '''
        INPUT inputs : batch fed to the backbone
              train  : passed through to the head only
        RETURNS output of the head
        '''
        feature_map = self.backbone(inputs)
        # Average pool: collapse axes 1 and 2 of the backbone output.
        pooled_axes = (1, 2)
        pooled = jnp.mean(feature_map, axis=pooled_axes)
        return self.head(pooled, train)

    
def _get_backbone_and_params(model_arch: str):
    '''
    Build a headless pretrained backbone and its variables.
    1. Load the pretrained model (only 'resnet18' is supported)
    2. Keep every layer except the last 2
    3. Slice the pretrained variables down to the kept layers
    INPUT : model_arch
    RETURNS backbone , backbone_params
    RAISES NotImplementedError for any other model_arch
    '''
    # Guard first: reject unsupported architectures before loading anything.
    if model_arch != 'resnet18':
        raise NotImplementedError

    resnet_cls, pretrained_vars = pretrained_resnet(18)
    resnet = resnet_cls()

    # The backbone spans layers [0, len - 2), i.e. the final two layers are dropped.
    first_layer = 0
    stop_layer = len(resnet.layers) - 2
    backbone = Sequential(resnet.layers[first_layer:stop_layer])
    backbone_params = slice_variables(pretrained_vars, first_layer, stop_layer)
    return backbone, backbone_params


def get_model_and_variables(model_arch: str, head_init_key: int):
    '''
    Assemble the full model and its variables.
    1. Create a dummy input of shape (1, image_size, image_size, 3)
    2. Get the backbone and its pretrained params
    3. Run the backbone on the dummy input to learn its output width
    4. Initialise the head with a PRNG key seeded by head_init_key
    5. Combine backbone and head into the final model
    6. Merge the 'params' and 'batch_stats' of backbone and head

    INPUT model_arch, head_init_key
    RETURNS  model, variables
    '''

    # Backbone: run once on an all-ones image to discover the feature width.
    dummy_images = jnp.ones(
        (1, Config['IMAGE_SIZE'], Config['IMAGE_SIZE'], 3), jnp.float32
    )
    backbone, backbone_params = _get_backbone_and_params(model_arch)
    head_rng = jax.random.PRNGKey(head_init_key)
    features = backbone.apply(backbone_params, dummy_images, mutable=False)

    # Head: initialise on a dummy vector matching the backbone's last axis.
    feature_width = features.shape[-1]
    head = Head()
    head_vars = head.init(
        head_rng, jnp.ones((1, feature_width), jnp.float32), train=False
    )

    # Final model: each collection holds one 'backbone' and one 'head' subtree.
    model = Model(backbone, head)
    variables = FrozenDict({
        collection: {
            'backbone': backbone_params[collection],
            'head': head_vars[collection],
        }
        for collection in ('params', 'batch_stats')
    })
    return model, variables

In [10]:
# Loading in pre-trained ResNet18 Model

# Build the ResNet18-backed model; 0 seeds the PRNG key used to initialise the head.
model, variables = get_model_and_variables('resnet18', 0)
# Dummy batch: one all-ones float32 image of IMAGE_SIZE x IMAGE_SIZE x 3.
inputs = jnp.ones((1,Config['IMAGE_SIZE'], Config['IMAGE_SIZE'],3), jnp.float32)
# NOTE(review): `key` is not used in this cell -- presumably consumed by a later cell; confirm.
key = jax.random.PRNGKey(0)
# Smoke-test forward pass with train=False and no mutable collections.
o = model.apply(variables, inputs, train=False, mutable=False)

Downloading: "https://github.com/pytorch/vision/zipball/v0.6.0" to /Users/edanbash/.cache/torch/hub/v0.6.0.zip
Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /Users/edanbash/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

: 

: 

# Background 
Pre-training <br>
Hyperparameter Search <br>
Ensemble Learning <br>
Greedy Algorithms <br> 
Vision Transformers <br>

### Part A. Fine-Tuning a Single ViT Model

### Part B. Fine-Tuning Various Models (and Picking the Best One)


### Part C. Ensemble Learning

### Part D. Uniform Soup

### Part E. Greedy Soup