# Model Soups

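In this assignment, you will implement [Model Soups](https://arxiv.org/pdf/2203.05482.pdf): averaging the weights of several models fine-tuned from the same pre-trained initialization with different hyperparameters, which can improve accuracy without increasing inference cost.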


In [1]:
# Colab-only dependency: jax-resnet provides pretrained_resnet / slice_variables / Sequential,
# which are imported in the imports cell.
# NOTE(review): pin a version (e.g. `%pip install -q jax-resnet==<ver>`) for reproducibility.
!pip install jax-resnet

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
# Colab only: mount Google Drive and switch into the project folder, so that the
# local `utils` module (imported below) and the relative `models182/` checkpoint
# directory both resolve inside it.
from google.colab import drive
drive.mount('/content/drive/')
%cd /content/drive/My Drive/cs182_final

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).
/content/drive/.shortcut-targets-by-id/1sukHccN68Nh8Ll76UD_M9g2ElGhyGzty/cs182_final


In [15]:
import numpy as np
import torch

import jax
import optax
import flax
import jax.numpy as jnp
from jax import jit
from jax import lax
from jax_resnet import pretrained_resnet, slice_variables, Sequential
from flax.jax_utils import replicate, unreplicate
from flax.training import train_state, checkpoints
from flax import linen as nn
from flax.core import FrozenDict,frozen_dict
from flax.training.common_utils import shard

from functools import partial
from tqdm.notebook import tqdm

from utils import get_model_and_variables, load_datasets
from utils import Config, TrainState, Model, Head
from utils import create_train_state, zero_grads, create_mask, accuracy
from utils import train_step, val_step, test_step

In [4]:
# Directory that create_k_models writes its checkpoints into (models182/model_{i}).
# os.makedirs(..., exist_ok=True) is idempotent, unlike `%mkdir`, which prints an
# error when the directory already exists (e.g. on Restart & Run All).
import os

os.makedirs('models182', exist_ok=True)

In [5]:
# Loading in pre-trained ResNet18 Model
model, variables = get_model_and_variables('resnet18', 0)

# Smoke test: a single inference-mode forward pass on a dummy all-ones image of
# shape (1, IMAGE_SIZE, IMAGE_SIZE, 3), to check that the weights load and apply.
# NOTE(review): `key` is not used anywhere in this notebook.
key = jax.random.PRNGKey(0)
image_size = Config['IMAGE_SIZE']
inputs = jnp.ones((1, image_size, image_size, 3), dtype=jnp.float32)
o = model.apply(variables, inputs, train=False, mutable=False)

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.6.0


In [6]:
# load_datasets() yields the (train, test) pair; note that train() below also
# uses the test split as its "validation" set.
(train_dataset, test_dataset) = load_datasets()

In [7]:
# Number of batches in one pass over the training set (i.e. steps per epoch).
num_train_steps = len(train_dataset)
# Global batch size, taken from the shared config.
batch_size = Config["BATCH_SIZE"]

In [8]:
# Softmax cross-entropy (labels are cast to float32 in train(), presumably one-hot)
# and the accuracy metric provided by utils.
loss_fn, eval_fn = optax.softmax_cross_entropy, accuracy

In [9]:
# Data-parallel (pmap) versions of the step functions from utils.
# Only the training step may donate `state`: it returns the updated state
# (see train()), so XLA can reuse the input buffers for the output. The val/test
# steps do not return the state, and the caller keeps using it afterwards
# (next training epoch), so their state argument must NOT be donated.
parallel_train_step = jax.pmap(train_step, axis_name='batch', donate_argnums=(0,))
parallel_val_step = jax.pmap(val_step, axis_name='batch')
parallel_test_step = jax.pmap(test_step, axis_name='batch')

# control randomness on dropout and update inside train_step
# NOTE(review): train() builds its own dropout RNGs, so it does not use these globals.
rng = jax.random.PRNGKey(0)
dropout_rng = jax.random.split(rng, jax.local_device_count())  # one key per local device

In [10]:
# Training Loop

def _run_train_epoch(state, dropout_rng):
    """Run one pass of `parallel_train_step` over the global `train_dataset`.

    Returns:
      (state, dropout_rng, mean_loss, mean_top1_acc): the updated replicated
      state and RNGs as returned by the step, plus per-batch averages.
    """
    train_loss, train_accuracy = [], []
    iter_n = len(train_dataset)
    with tqdm(total=iter_n, desc=f"{iter_n} iterations", leave=False) as progress_bar:
        for _batch in train_dataset:
            # each element is indexable as (images, labels)
            images = jnp.array(_batch[0], dtype=jnp.float32)
            labels = jnp.array(_batch[1], dtype=jnp.float32)
            # split the leading batch axis across local devices for pmap
            images, labels = shard(images), shard(labels)

            # backprop and update params & batch stats
            state, train_metadata, dropout_rng = parallel_train_step(state, images, labels, dropout_rng)
            train_metadata = unreplicate(train_metadata)

            # the two-name unpack requires 'accuracy' to hold exactly one (top-1) value
            _train_loss, _train_top1_acc = map(float, [train_metadata['loss'], *train_metadata['accuracy']])
            train_loss.append(_train_loss)
            train_accuracy.append(_train_top1_acc)
            progress_bar.update(1)

    mean_loss = sum(train_loss) / len(train_loss)
    mean_acc = sum(train_accuracy) / len(train_accuracy)
    return state, dropout_rng, mean_loss, mean_acc


def _evaluate(state, dataset):
    """Mean per-batch accuracy of `state` on `dataset`, computed with `parallel_val_step`.

    The per-device metrics are averaged over batches, and device 0's value is returned.
    """
    accuracies = []
    iter_n = len(dataset)
    with tqdm(total=iter_n, desc=f"{iter_n} iterations", leave=False) as progress_bar:
        for _batch in dataset:
            images = jnp.array(_batch[0], dtype=jnp.float32)
            labels = jnp.array(_batch[1], dtype=jnp.float32)
            images, labels = shard(images), shard(labels)
            accuracies.append(parallel_val_step(state, images, labels)[0])
            progress_bar.update(1)

    mean_acc = sum(accuracies) / len(accuracies)
    return np.array(mean_acc)[0]


def train(state, epochs, save_path, seed=0):
    """Fine-tune a replicated TrainState and checkpoint it after every epoch.

    Args:
      state: TrainState already replicated across local devices (create_k_models
        calls `replicate` before passing it in).
      epochs: number of passes over `train_dataset`.
      save_path: checkpoint directory; each epoch overwrites the previous checkpoint.
      seed: seed for the per-device dropout RNGs (0 reproduces the old hard-coded value).

    Returns:
      (train_acc, valid_acc): per-epoch mean training accuracy and mean accuracy
      on `test_dataset` (this notebook has no separate validation split).
    """
    dropout_rng = jax.random.split(jax.random.PRNGKey(seed), jax.local_device_count())

    train_acc, valid_acc = [], []
    for epoch_i in range(epochs):
        state, dropout_rng, avg_train_loss, avg_train_acc = _run_train_epoch(state, dropout_rng)
        train_acc.append(avg_train_acc)
        print(f"[{epoch_i+1}/{epochs}] Train Loss: {avg_train_loss:.03} | Train Accuracy: {avg_train_acc:.03}")

        # Save a single host-side copy of the weights. Saving the replicated state
        # would store every leaf with an extra leading device axis, which the soup
        # averaging and restoring code would then have to strip off.
        checkpoints.save_checkpoint(
            ckpt_dir=save_path,
            target=jax.device_get(unreplicate(state)),
            step=epoch_i,
            overwrite=True,
        )

        avg_valid_acc = _evaluate(state, test_dataset)
        valid_acc.append(avg_valid_acc)
        print(f"[{epoch_i+1}/{epochs}] Valid Accuracy: {avg_valid_acc:.03}")
    return train_acc, valid_acc

### Fine-tuning k ResNet models

In [11]:
def get_optimizer(optimizer_type, lr, weight_decay):
    """Build an optax optimizer from an integer type code.

    Args:
      optimizer_type: 0 -> RMSProp, 1 -> Adam, 2 -> AdamW (the same order as the
        `optimizers` name list used for logging).
      lr: learning rate.
      weight_decay: weight-decay coefficient. For RMSProp it is added to the
        gradient as an L2 term before scaling; for AdamW it is decoupled decay;
        Adam ignores it.

    Returns:
      An optax GradientTransformation.

    Raises:
      ValueError: if `optimizer_type` is not 0, 1 or 2 (previously this returned None silently).
    """
    if optimizer_type == 0:
        # optax.rmsprop's `decay` argument is the EMA rate of squared gradients
        # (default 0.9), NOT weight decay. Passing weight_decay there (e.g. 1e-2)
        # all but disabled the running average. Apply weight decay explicitly instead.
        return optax.chain(
            optax.add_decayed_weights(weight_decay),
            optax.rmsprop(learning_rate=lr, eps=1e-6),
        )
    if optimizer_type == 1:
        # Plain Adam: weight_decay is intentionally unused.
        return optax.adam(learning_rate=lr, b1=0.9, b2=0.999, eps=1e-6)
    if optimizer_type == 2:
        return optax.adamw(
            learning_rate=lr,
            b1=0.9, b2=0.999,
            eps=1e-6, weight_decay=weight_decay,
        )
    raise ValueError(f"Unknown optimizer_type {optimizer_type!r}; expected 0, 1 or 2")

In [12]:
# Search space that create_k_models samples from uniformly.
learning_rates = [
    1e-4, 3e-5, 2e-5,
    3e-6, 1e-6, 1e-7,
]
weight_decays = [
    1e-2, 1e-3, 2e-3,
    1e-4, 1e-5,
]
# Position i names optimizer_type == i in get_optimizer.
optimizers = "rmsprop adam adamw".split()

In [13]:
# Utility function for creating and fine-tuning k models with different hyperparameters
def create_k_models(k=16, epochs=1, seed=None):
    """Fine-tune `k` ResNet18 "ingredients" with randomly sampled hyperparameters.

    Each model starts from freshly loaded pre-trained weights. The learning rate,
    weight decay and optimizer are drawn uniformly from `learning_rates`,
    `weight_decays` and `optimizers`. Parameters whose path starts with
    'backbone' are masked by create_mask, presumably onto the 'zero' (frozen)
    transform; confirm against create_mask's labelling.
    Checkpoints are written to models182/model_{i}.

    Args:
      k: number of models to fine-tune.
      epochs: epochs per model (previously hard-coded to 1; the original plan
        sampled 10-20).
      seed: optional seed for hyperparameter sampling; None draws fresh entropy.

    Returns:
      A list with one dict per model, holding its checkpoint path, sampled
      hyperparameters, and the per-epoch train/valid accuracies returned by train().
      (Previously this list was never filled and was always empty.)
    """
    hparam_rng = np.random.default_rng(seed)
    models = []
    for i in range(k):
        # Load the pre-trained model
        model, variables = get_model_and_variables('resnet18', 0)

        # Randomly choose the hyperparameters
        lr = learning_rates[hparam_rng.integers(len(learning_rates))]
        weight_decay = weight_decays[hparam_rng.integers(len(weight_decays))]
        optimizer_type = int(hparam_rng.integers(len(optimizers)))

        optimizer = optax.multi_transform(
            {'optim': get_optimizer(optimizer_type, lr, weight_decay), 'zero': zero_grads()},
            create_mask(variables['params'], lambda s: s.startswith('backbone')),
        )

        # Create the TrainState and replicate it across local devices for pmap
        state = replicate(create_train_state(model, variables, optimizer, loss_fn, accuracy))

        print(f"Model_{i}, Optimizer: {optimizers[optimizer_type]}, LR: {lr}, Weight Decay: {weight_decay}")

        # Fine-tune the model; train() checkpoints it under save_path
        save_path = 'models182/' + f'model_{i}'
        train_acc, valid_acc = train(state, epochs, save_path)

        models.append({
            'path': save_path,
            'optimizer': optimizers[optimizer_type],
            'lr': lr,
            'weight_decay': weight_decay,
            'epochs': epochs,
            'train_acc': train_acc,
            'valid_acc': valid_acc,
        })

    return models

In [16]:
# Fine-tune the ingredient models (16 is also create_k_models' default k).
create_k_models(k=16)

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.6.0


Model_0, Optimizer: rmsprop, LR: 3e-06, Weight Decay: 1e-05
0


1563 iterations:   0%|          | 0/1563 [00:00<?, ?it/s]

KeyboardInterrupt: ignored

In [None]:
# torch.cuda.empty_cache() only returns blocks that PyTorch's caching allocator has
# already freed; it does not release memory held by JAX arrays. Collect garbage
# first so that unreachable PyTorch tensors are actually freed before the cache is emptied.
import gc

gc.collect()
torch.cuda.empty_cache()

In [None]:
# Rebuild a TrainState with the same argument list that create_k_models uses
# (loss_fn and accuracy included), then load model_0's checkpoint for inspection.
optimizer = get_optimizer(0, 2e-5, 0.001)
model, variables = get_model_and_variables('resnet18', 0)
state = create_train_state(model, variables, optimizer, loss_fn, accuracy)

# required for parallelism
state = replicate(state)

# target=None restores a plain nested dict. Passing the model directory (rather than
# a hard-coded 'checkpoint_0' file) loads the latest checkpoint saved there, which
# also works when the model was trained for more than one epoch.
restored_state = checkpoints.restore_checkpoint(ckpt_dir='models182/model_0', target=None)

In [None]:
# restored_state was restored with target=None, so it is a plain dict; list its top-level fields.
print(restored_state.keys())