In [None]:
#Importing useful libraries

import math
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import  models, optimizers, layers, activations
from tensorflow.keras.layers import Dense, Activation, Flatten, Dropout
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.metrics import categorical_crossentropy
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import InceptionV3, InceptionResNetV2, ResNet50, Xception

import wandb
from wandb.keras import WandbCallback


### Setting up a fixed seed value

In [None]:
# One shared seed for every RNG the notebook relies on, so repeated runs
# start from the same random state.
seed = 100
np.random.seed(seed)      # NumPy global RNG
tf.random.set_seed(seed)  # TensorFlow global seed

### Downloading and Unzipping data

In [None]:
# Fetch the nature_12K archive; wget saves it into the current working directory.
!wget https://storage.googleapis.com/wandb_datasets/nature_12K.zip

In [None]:
# Extract the downloaded archive into the current working directory.
# NOTE(review): generate_data reads ./inaturalist_12K/{train,val}, so the archive
# presumably unpacks to an `inaturalist_12K` folder — confirm after extraction.
!unzip "./nature_12K.zip"

### WandB login

In [None]:
# Authenticate this session with Weights & Biases before any sweep or run is created.
wandb.login()

### Preparing data for training

In [None]:
def generate_data(augmentation=True, batch_size=64):
    """Build the training, validation and test image iterators.

    Images are read from ``./inaturalist_12K/train`` (split 90/10 into
    training and validation subsets) and ``./inaturalist_12K/val`` (used as
    the held-out test set). Every image is resized to 128x128 and rescaled
    to the [0, 1] range.

    Args:
        augmentation: When True, random augmentations (zoom, rotation,
            brightness, shear, shifts, flips) are applied to the *training*
            subset only.
        batch_size: Number of images per batch for all three iterators.

    Returns:
        A ``(train, val, test)`` tuple of directory iterators.
    """
    dir_train = './inaturalist_12K/train'
    dir_test = './inaturalist_12K/val'
    target_size = (128, 128)
    val_fraction = 0.1

    # Validation images must not be augmented, otherwise the validation metrics
    # are computed on randomly distorted inputs. A rescale-only generator with
    # the same validation_split is used for the validation subset so that it
    # selects the same files that the training generator holds out.
    plain_datagen = ImageDataGenerator(rescale=1./255, validation_split=val_fraction)
    test_datagen = ImageDataGenerator(rescale=1./255)

    if augmentation:
        train_datagen = ImageDataGenerator(rescale=1./255,
                                           zoom_range=0.3,
                                           rotation_range=50,
                                           brightness_range=(0.2, 0.8),
                                           shear_range=0.2,
                                           width_shift_range=0.1,
                                           height_shift_range=0.2,
                                           horizontal_flip=True,
                                           vertical_flip=True,
                                           validation_split=val_fraction)
    else:
        train_datagen = plain_datagen

    train = train_datagen.flow_from_directory(dir_train, target_size=target_size, batch_size=batch_size, subset="training")
    val = plain_datagen.flow_from_directory(dir_train, target_size=target_size, batch_size=batch_size, subset="validation")
    test = test_datagen.flow_from_directory(dir_test, target_size=target_size, batch_size=batch_size)

    return train, val, test

### Function to build a custom wandb run name

In [None]:
def setWandbName(model_name, dropout, batch_size, n_dense):
    """Compose a readable run name from the key hyperparameters.

    Each value is converted with ``str`` and the pieces are joined with
    underscores, e.g. ``model_IV3_drop_0.2_batch_size_32_n_dense_64``.

    Args:
        model_name: Key of the pre-trained backbone.
        dropout: Dropout rate of the classification head.
        batch_size: Training batch size.
        n_dense: Number of units in the dense layer.

    Returns:
        The assembled name as a string.
    """
    return (
        f"model_{model_name!s}"
        f"_drop_{dropout!s}"
        f"_batch_size_{batch_size!s}"
        f"_n_dense_{n_dense!s}"
    )

### Modifying the existing model

In [None]:
def modified_model(pre_trained_model, n_dense, dropout, freeze_before):
    """Build a 10-class classifier on top of an ImageNet-pretrained backbone.

    Args:
        pre_trained_model: Backbone key, one of ``'IV3'`` (InceptionV3),
            ``'IRNV2'`` (InceptionResNetV2), ``'RN50'`` (ResNet50) or
            ``'XCP'`` (Xception).
        n_dense: Number of units in the hidden dense layer of the new head.
        dropout: Dropout rate applied in the new head.
        freeze_before: Number of trailing backbone layers left trainable;
            every layer before them is frozen. Values larger than the
            backbone depth leave the whole backbone trainable.

    Returns:
        An uncompiled ``Sequential`` model: backbone -> Flatten ->
        Dense(n_dense) -> Dropout -> ReLU -> Dense(10) -> softmax.

    Raises:
        ValueError: If ``pre_trained_model`` is not a supported key.
    """
    input_size = (128, 128, 3)

    # Resolve the backbone first so an unknown key fails with a clear message
    # instead of an UnboundLocalError further down.
    if pre_trained_model == 'IV3':
        backbone_cls = InceptionV3
    elif pre_trained_model == 'IRNV2':
        backbone_cls = InceptionResNetV2
    elif pre_trained_model == 'RN50':
        backbone_cls = ResNet50
    elif pre_trained_model == 'XCP':
        backbone_cls = Xception
    else:
        raise ValueError(
            "Unknown pre_trained_model %r; expected one of 'IV3', 'IRNV2', 'RN50', 'XCP'"
            % (pre_trained_model,)
        )

    # include_top=False leaves out the ImageNet classifier since we have only 10 classes.
    temp_model = backbone_cls(input_shape=input_size, include_top=False, weights='imagenet')

    # Freeze all but the last `freeze_before` backbone layers to make training faster.
    # Clamp at 0: a negative bound would be read as a from-the-end slice index
    # and freeze most of the network instead of none of it.
    freeze_point = max(0, len(temp_model.layers) - freeze_before)
    for layer in temp_model.layers[:freeze_point]:
        layer.trainable = False

    new_model = Sequential()
    new_model.add(temp_model)
    new_model.add(Flatten())
    new_model.add(Dense(n_dense))
    new_model.add(Dropout(dropout))
    new_model.add(Activation("relu"))
    new_model.add(Dense(10))
    new_model.add(Activation("softmax"))

    # Compilation is left to the caller so optimizer and learning rate can be tuned.
    return new_model
       

### Train function to fine-tune the network with a set of hyperparameters

In [None]:
def train(config=None):
    """Run one fine-tuning experiment and log it to Weights & Biases.

    Intended as the target function of ``wandb.agent``: the hyperparameters
    are read from ``wandb.config`` after the run is initialised.

    Args:
        config: Optional default hyperparameters forwarded to ``wandb.init``.
            Expected keys are ``pre_trained_model``, ``dropout``,
            ``batch_size``, ``n_dense``, ``freeze_before``, ``learning_rate``
            and ``epochs``. May be None when an agent supplies the values.
    """
    # Wandb settings. Passing `config` through lets the function also be
    # called directly with a plain dict, outside of a sweep.
    wandb.init(config=config, project="Convolutional Neural Networks", entity="cs21s048-cs21s058")
    config = wandb.config
    wandb.run.name = setWandbName(
        model_name=config.pre_trained_model,
        dropout=config.dropout,
        batch_size=config.batch_size,
        n_dense=config.n_dense,
    )

    # The test iterator is not needed while tuning. Distinct local names avoid
    # shadowing this function's own name.
    train_data, val_data, _ = generate_data(batch_size=config.batch_size)

    new_model = modified_model(
        pre_trained_model=config.pre_trained_model,
        n_dense=config.n_dense,
        dropout=config.dropout,
        freeze_before=config.freeze_before,
    )
    new_model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=config.learning_rate),
        loss='categorical_crossentropy',
        metrics=['categorical_accuracy'],
    )

    # The iterators already yield batches of config.batch_size, so `batch_size`
    # must not be passed to fit() for generator input.
    new_model.fit(
        train_data,
        epochs=config.epochs,
        verbose=1,
        validation_data=val_data,
        callbacks=[
            WandbCallback(),
            keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True),
        ],
    )

### WandB sweep config

In [None]:
# Bayesian hyperparameter search that maximises validation accuracy.
# Every tunable is a discrete list of candidate values.
# 'RN50' and 'XCP' are also handled by modified_model but are left out of this sweep.
sweep_config = dict(
    name='PartB_final_sweep',
    method='bayes',
    metric=dict(name='val_categorical_accuracy', goal='maximize'),
    parameters={
        param: {'values': choices}
        for param, choices in (
            ('pre_trained_model', ['IV3', 'IRNV2']),
            ('freeze_before', [50, 70, 100]),
            ('epochs', [10]),
            ('dropout', [0.2, 0.4, 0.6]),
            ('batch_size', [32, 64]),
            ('n_dense', [64, 128, 256]),
            ('learning_rate', [0.001, 0.0001]),
        )
    },
)

In [None]:
# Register the sweep with W&B and keep its id for the agent.
sweep_id = wandb.sweep(
    sweep_config,
    project="Convolutional Neural Networks",
    entity="cs21s048-cs21s058",
)

### Fine tuning different models with wandb

In [None]:
# Run the sweep that was just registered (sweep_id) rather than a hard-coded id
# from an earlier session, so a fresh Restart & Run All stays consistent.
# To resume an existing sweep instead, assign its id to `sweep_id` first.
wandb.agent(
    sweep_id,
    function=train,
    count=100,
    project="Convolutional Neural Networks",
    entity="cs21s048-cs21s058",
)