# Finetuning Tutorial

## Overview
In this tutorial we are going to cover finetuning using FastEstimator. This tutorial is structured as follows:

* [Setting Things Up](#ta15setup)
    * [Define Reusable Methods](#ta15resuable)
* [Tensorflow Workflow](#ta15tfworkflow)
    * [Train Base Model](#ta15tftrain)
    * [Extending Base Model for finetuning](#ta15tfmodify)
        * [Import Pretrained Model](#ta15tffreeze)
        * [Extending Base Model](#ta15tfunfreeze)
        * [Combine Base Model and Finetune Model](#ta15tfcombine)
    * [Start Finetuning](#ta15tffinetune)
* [Pytorch Workflow](#ta15pytorchworkflow)
    * [Train Base Model](#ta15pytorchtrain)
    * [Extending Base Model for finetuning](#ta15pytorchmodify)
        * [Import Pretrained Model](#ta15pytorchfreeze)
        * [Extending Base Model](#ta15pytorchunfreeze)
        * [Combine Base Model and Finetune Model](#ta15torchcombine)
    * [Start Finetuning](#ta15pytorchfinetune)

### Setting Things Up <a id='ta15setup'></a>

#### First let's get some imports out of the way:

In [57]:
import os
import tempfile

import fastestimator as fe
from fastestimator.trace.metric import Accuracy
from fastestimator.op.numpyop.univariate import ChannelTranspose, CoarseDropout, Normalize, Onehot
from fastestimator.op.numpyop.meta import Sometimes
from fastestimator.op.numpyop.multivariate import HorizontalFlip, PadIfNeeded, RandomCrop
from fastestimator.schedule.schedule import EpochScheduler
from fastestimator.op.tensorop.loss import CrossEntropy
from fastestimator.op.tensorop.model import ModelOp, UpdateOp
from fastestimator.dataset.data import cifair100, cifair10
from fastestimator.op.tensorop.model import ModelOp, UpdateOp

from fastestimator.architecture.tensorflow import LeNet as lenet_tf
from tensorflow.python.keras import Sequential, layers
from tensorflow.keras import Model
from fastestimator.architecture.pytorch import LeNet as lenet_torch
import torch.nn as nn
from torch import load, Tensor
import torch.nn.functional as fn

#### Define Reusable Methods <a id='ta15resuable'></a>

In [3]:
def get_pipeline(dataset, num_classes, batch_size, mode='tf', min_height=40, min_width=40):

    train_data, eval_data = dataset.load_data()

    mean_value = (0.4914, 0.4822, 0.4465)
    std_value = (0.2471, 0.2435, 0.2616)

    ops = [ Normalize(inputs="x", outputs="x", mean=mean_value, std=std_value),
            PadIfNeeded(min_height=min_height, min_width=min_width, image_in="x", image_out="x", mode="train"),
            RandomCrop(32, 32, image_in="x", image_out="x", mode="train"),
            Sometimes(HorizontalFlip(image_in="x", image_out="x", mode="train")),
            CoarseDropout(inputs="x", outputs="x", mode="train", max_holes=1),
            Onehot(inputs="y", outputs="y", mode="train", num_classes=num_classes, label_smoothing=0.2)]

    if mode == 'torch':
        ops.append(ChannelTranspose(inputs="x", outputs="x"))
                
    return fe.Pipeline(
                train_data=train_data,
                eval_data=eval_data,
                batch_size=batch_size,
                ops=ops)

def get_network(model):
    return  fe.Network(ops=[
        ModelOp(model=model, inputs="x", outputs="y_pred"),
        CrossEntropy(inputs=("y_pred", "y"), outputs="ce"),
        UpdateOp(model=model, loss_name="ce")])

def get_estimator(pipeline, network, epochs, train_steps_per_epoch=None, eval_steps_per_epoch=None):
    traces = [Accuracy(true_key="y", pred_key="y_pred")]

    return fe.Estimator(pipeline=pipeline,
                                network=network,
                                epochs=epochs,
                                traces=traces,
                                log_steps=0,
                                train_steps_per_epoch=train_steps_per_epoch, 
                                eval_steps_per_epoch=eval_steps_per_epoch)

#### Let's load some default training parameters as well

In [4]:
#training parameters
epochs_pretrain = 10

epochs_finetune = 5

batch_size = 64

base_num_classes = 100 

finetune_num_classes = 10

model_dir = tempfile.gettempdir()

### Tensorflow Workflow <a id='ta15tfworkflow'></a>

#### Train Base Model <a id='ta15tftrain'></a>

Now that boring stuff is done, let's train our first base model. We are using tensorflow LeNet to train on cifar100 with 100 classes. We are training for 10 epochs and saving the model at the end of the training job.

In [5]:
tf_input_shape = (32, 32, 3)

model_tf_pretrain = fe.build(model_fn=lambda: lenet_tf(input_shape=tf_input_shape, classes=base_num_classes), optimizer_fn="adam")

pipeline_tf_pretrain = get_pipeline(cifair100, base_num_classes, batch_size)

network_tf_pretrain = get_network(model_tf_pretrain)

estimator_tf_pretrain = get_estimator(pipeline_tf_pretrain, network_tf_pretrain, epochs_pretrain)

estimator_tf_pretrain.fit(warmup=False)

fe.backend.save_model(model_tf_pretrain, save_dir=model_dir, model_name= "lenet_tf")

FastEstimator-Warn: Pipeline multiprocessing is disabled. OS must support the 'fork' start method.
    ______           __  ______     __  _                 __            
   / ____/___ ______/ /_/ ____/____/ /_(_)___ ___  ____ _/ /_____  _____
  / /_  / __ `/ ___/ __/ __/ / ___/ __/ / __ `__ \/ __ `/ __/ __ \/ ___/
 / __/ / /_/ (__  ) /_/ /___(__  ) /_/ / / / / / / /_/ / /_/ /_/ / /    
/_/    \__,_/____/\__/_____/____/\__/_/_/ /_/ /_/\__,_/\__/\____/_/     
                                                                        

FastEstimator-Warn: No ModelSaver Trace detected. Models will not be saved.
FastEstimator-Start: step: 1; logging_interval: 0; num_device: 0;
FastEstimator-Eval: step: 782; epoch: 1; accuracy: 0.1447; ce: 3.6391912;
FastEstimator-Eval: step: 1564; epoch: 2; accuracy: 0.2173; ce: 3.272597;
FastEstimator-Eval: step: 2346; epoch: 3; accuracy: 0.2578; ce: 3.0625503;
FastEstimator-Eval: step: 3128; epoch: 4; accuracy: 0.2893; ce: 2.9065394;
FastEstimator-Eval: st

'C:\\windows\\TEMP\\lenet_tf.h5'

#### Load a new dataset for finetuning 

For finetuning, We use FastEstimator API to load the ciFAIR-10 dataset. You can use your own dataset by updating `get_pipeline` method.

In [6]:
pipeline_tf_finetune = get_pipeline(cifair10, finetune_num_classes, batch_size)

FastEstimator-Warn: Pipeline multiprocessing is disabled. OS must support the 'fork' start method.


#### Extending Base Model for Finetuning <a id='ta15tfmodify'></a>

##### Import Pretrained Model <a id='ta15tffreeze'></a>

Now we are ready to extend our base model with finetuning task.

Let's load our pretrained weights saved in previous setup. The weights files are saved with `h5` extension, since we have given `lenet_tf` as model_name to the `save_model`  function the model weights are saved as `lenet_tf.h5`. 


In [7]:
weights_path = os.path.join(model_dir, "lenet_tf.h5")

pretrained_lenet_tf = lenet_tf(input_shape=tf_input_shape, classes=base_num_classes)

pretrained_lenet_tf.load_weights(weights_path)

##### Extending Base Model <a id='ta15tfunfreeze'></a>

Let's remove the classification head of pretrained model and build a backbone. We will be using `fe.build` to build a new fe model.

In [8]:
def get_tf_backbone(pretrained_model):

    model = Model(inputs=pretrained_model.inputs, outputs=pretrained_model.layers[-3].output)

    return model

backbone_tf = fe.build(model_fn=lambda: get_tf_backbone(pretrained_lenet_tf), optimizer_fn="adam")

Next, we will define a classification head that can be used for the finetuning task. This is simply two `Dense` layers.

In [71]:
def get_class_head(finetune_num_classes):
    return Sequential([layers.Dense(64, activation='relu', input_shape=(1024,)), 
                       layers.Dense(finetune_num_classes, activation='softmax')])
    
cls_head_tf_finetune = fe.build(model_fn=lambda: get_class_head(finetune_num_classes), optimizer_fn="adam")

##### Combine Base Model and Finetune Model <a id='ta15tfcombine'></a>

If you want to save the finetune model, we can combine the `Backbone Model` and the `Class Head Model` and provide it to ModelSaver later. 

In [75]:
def combined_tf_model(backbone_model, cls_head_finetune):

    backbone_output = backbone_model.layers[-1].output
    x = cls_head_finetune.layers[0](backbone_output)
    x = cls_head_finetune.layers[1](x)
    model = Model(inputs=backbone_model.inputs, outputs=x)
    return model

final_model_tf = fe.build(model_fn=lambda: combined_tf_model(backbone_tf, cls_head_tf_finetune),  optimizer_fn="adam")

#### Start Finetuning <a id='ta15tffinetune'></a>

For Finetuning, we want to train different part of the network in the following manner:
- epoch 1-3: `freeze` backbone, `train` classification head only
- epoch 4-end: `train` backbone and classification head `together`

Let's use EpochScheduler to define when backbone and class head weights are updated. UpdateOp is responsible for weight updating. 

In [11]:
network_tf_finetune = fe.Network(ops=[
                                ModelOp(model=backbone_tf, inputs="x", outputs="feature"),
                                ModelOp(model=cls_head_tf_finetune, inputs="feature", outputs="y_pred"),
                                CrossEntropy(inputs=("y_pred", "y"), outputs="ce", from_logits=True),
                                EpochScheduler({1: None, 4: UpdateOp(model=backbone_tf, loss_name="ce")}),
                                EpochScheduler({1: UpdateOp(model=cls_head_tf_finetune, loss_name="ce")})])

estimator_tf_finetune = get_estimator(pipeline_tf_finetune, network_tf_finetune, epochs_finetune)

Let's train our finetune model using pretrained weights on our new dataset.

In [25]:
estimator_tf_finetune.fit(warmup=False)

    ______           __  ______     __  _                 __            
   / ____/___ ______/ /_/ ____/____/ /_(_)___ ___  ____ _/ /_____  _____
  / /_  / __ `/ ___/ __/ __/ / ___/ __/ / __ `__ \/ __ `/ __/ __ \/ ___/
 / __/ / /_/ (__  ) /_/ /___(__  ) /_/ / / / / / / /_/ / /_/ /_/ / /    
/_/    \__,_/____/\__/_____/____/\__/_/_/ /_/ /_/\__,_/\__/\____/_/     
                                                                        

FastEstimator-Warn: No ModelSaver Trace detected. Models will not be saved.
FastEstimator-Start: step: 1; logging_interval: 0; num_device: 0;
FastEstimator-Eval: step: 782; epoch: 1; accuracy: 0.6993; ce: 0.9487608;
FastEstimator-Eval: step: 1564; epoch: 2; accuracy: 0.7094; ce: 0.93741965;
FastEstimator-Eval: step: 2346; epoch: 3; accuracy: 0.7071; ce: 0.9292741;
FastEstimator-Eval: step: 3128; epoch: 4; accuracy: 0.7089; ce: 0.938162;
FastEstimator-Eval: step: 3910; epoch: 5; accuracy: 0.6997; ce: 0.94834113;
FastEstimator-Finish: step: 3910; model1_lr:

Finally, let's save our finetuned model.

In [55]:
fe.backend.save_model(final_model_tf, save_dir=model_dir, model_name= "final_tf_finetune")

'C:\\windows\\TEMP\\final_tf_finetune.h5'

### Pytorch Workflow <a id='ta15pytorchworkflow'></a>

#### Train Base Model <a id='ta15pytorchtrain'></a>

Now that boring stuff is done, let's train our first base model. We are using pytorch LeNet to train on cifar100 with 100 classes. We are training for 10 epochs and saving the model at the end of the training job.

In [13]:
torch_input_shape = (3, 32, 32)

model_torch_pretrain = fe.build(model_fn=lambda: lenet_torch(input_shape=torch_input_shape, classes=base_num_classes), optimizer_fn="adam")

pipeline_torch_pretrain = get_pipeline(cifair100, base_num_classes, batch_size, 'torch')

network_torch_pretrain = get_network(model_torch_pretrain)

estimator_torch_pretrain = get_estimator(pipeline_torch_pretrain, network_torch_pretrain, epochs_pretrain)

estimator_torch_pretrain.fit()

fe.backend.save_model(model_torch_pretrain, save_dir=model_dir, model_name= "lenet_torch")

FastEstimator-Warn: Pipeline multiprocessing is disabled. OS must support the 'fork' start method.
    ______           __  ______     __  _                 __            
   / ____/___ ______/ /_/ ____/____/ /_(_)___ ___  ____ _/ /_____  _____
  / /_  / __ `/ ___/ __/ __/ / ___/ __/ / __ `__ \/ __ `/ __/ __ \/ ___/
 / __/ / /_/ (__  ) /_/ /___(__  ) /_/ / / / / / / /_/ / /_/ /_/ / /    
/_/    \__,_/____/\__/_____/____/\__/_/_/ /_/ /_/\__,_/\__/\____/_/     
                                                                        

FastEstimator-Warn: No ModelSaver Trace detected. Models will not be saved.
FastEstimator-Start: step: 1; logging_interval: 0; num_device: 0;
FastEstimator-Eval: step: 782; epoch: 1; accuracy: 0.1446; ce: 3.6544788;
FastEstimator-Eval: step: 1564; epoch: 2; accuracy: 0.1972; ce: 3.340201;
FastEstimator-Eval: step: 2346; epoch: 3; accuracy: 0.2335; ce: 3.1667938;
FastEstimator-Eval: step: 3128; epoch: 4; accuracy: 0.2721; ce: 2.971765;
FastEstimator-Eval: ste

'C:\\windows\\TEMP\\lenet_torch.pt'

##### Load a new dataset for finetuning 

For finetuning, We use FastEstimator API to load the ciFAIR-10 dataset. You can use your own dataset by changing `get_pipeline` method.

In [14]:
pipeline_torch_finetune = get_pipeline(cifair10, finetune_num_classes, batch_size, 'torch')

FastEstimator-Warn: Pipeline multiprocessing is disabled. OS must support the 'fork' start method.


#### Extending Base Model for finetuning <a id='ta15pytorchmodify'></a>

##### Import Pretrained Model<a id='ta15pytorchfreeze'></a>

Now we are ready to extend our base model with finetuning task.

Let's load our pretrained weights saved in our previous setup. The weights files are saved with h5 extension, since we have given `lenet_torch` as model_name to the `save_model`  function the model weights are saved as `lenet_torch.pt`. Replace it if you used different model_name in `save_model` method.


In [19]:
model_torch_pretrained = lenet_torch(input_shape=torch_input_shape, classes=base_num_classes)

model_torch_pretrained.load_state_dict(load(os.path.join(model_dir, 'lenet_torch.pt')))

<All keys matched successfully>

##### Extending Base Model <a id='ta15pytorchunfreeze'></a>

Let's remove the last layer of pretrained model and build a new backbone. We will be using fe.build to build a new fe model.

In [65]:
class BackboneTorch(nn.Module):
    def __init__(self, model_torch_pretrained) -> None:
        super().__init__()
        self.pool_kernel = 2
        self.backbone_layers = nn.Sequential(*(list(model_torch_pretrained.children())[:-2]))


    def forward(self, x: Tensor) -> Tensor:
        x = fn.relu(self.backbone_layers[0](x))
        x = fn.max_pool2d(x, self.pool_kernel)
        x = fn.relu(self.backbone_layers[1](x))
        x = fn.max_pool2d(x, self.pool_kernel)
        x = fn.relu(self.backbone_layers[2](x))
        return x

backbone_torch = fe.build(model_fn=lambda: BackboneTorch(model_torch_pretrained), optimizer_fn="adam") 

Next, we will define a classification head that can be used for the finetuning task. This is simply two `nn.Linear` layers.

In [66]:
class ClassifierHead(nn.Module):
    def __init__(self, classes=10):
        super().__init__()
        self.fc1 = nn.Linear(1024, 64)
        self.fc2 = nn.Linear(64, classes)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = fn.relu(self.fc1(x))
        x = fn.softmax(self.fc2(x), dim=-1)
        return x

# dimensions of last layer of backbone
cls_head_torch_finetune = fe.build(model_fn=lambda: ClassifierHead(classes=finetune_num_classes), optimizer_fn="adam")

##### Combine Base Model and Finetune Model <a id='ta15torchcombine'></a>

If you want to save the finetune model, we can combine the `Backbone Model` and the `Class Head Model` and provide it to ModelSaver later. 

In [67]:
class CombinedTorchModel(nn.Module):
    def __init__(self, backbone, cls_head):
        super().__init__()
        self.backbone = backbone
        self.cls_head = cls_head

    def forward(self, x):
        x = self.backbone(x)
        x = self.cls_head(x)
        return x

final_torch_model = fe.build(model_fn=lambda: CombinedTorchModel(backbone_torch, cls_head_torch_finetune), optimizer_fn=None) 

#### Start Finetuning <a id='ta15pytorchfinetune'></a>

For Finetuning, we want to train different part of the network in the following manner:
- epoch 1-3: `freeze` backbone, `train` classification head only
- epoch 4-end: `train` backbone and classification head `together`

Let's use EpochScheduler to define when backbone and class head weights are updated. UpdateOp is responsible for weight updating. 

In [23]:
network_torch_finetune = fe.Network(ops=[
                                ModelOp(model=backbone_torch, inputs="x", outputs="feature"),
                                ModelOp(model=cls_head_torch_finetune, inputs="feature", outputs="y_pred"),
                                CrossEntropy(inputs=("y_pred", "y"), outputs="ce", from_logits=True),
                                EpochScheduler({1: None, 1: UpdateOp(model=backbone_torch, loss_name="ce")}),
                                EpochScheduler({1: UpdateOp(model=cls_head_torch_finetune, loss_name="ce")})])

estimator_torch_finetune = get_estimator(pipeline_torch_finetune, network_torch_finetune, epochs_finetune)

Let's train our finetune model using pretrained weights on our new dataset.

In [24]:
estimator_torch_finetune.fit()

    ______           __  ______     __  _                 __            
   / ____/___ ______/ /_/ ____/____/ /_(_)___ ___  ____ _/ /_____  _____
  / /_  / __ `/ ___/ __/ __/ / ___/ __/ / __ `__ \/ __ `/ __/ __ \/ ___/
 / __/ / /_/ (__  ) /_/ /___(__  ) /_/ / / / / / / /_/ / /_/ /_/ / /    
/_/    \__,_/____/\__/_____/____/\__/_/_/ /_/ /_/\__,_/\__/\____/_/     
                                                                        

FastEstimator-Warn: No ModelSaver Trace detected. Models will not be saved.
FastEstimator-Start: step: 1; logging_interval: 0; num_device: 0;
FastEstimator-Eval: step: 782; epoch: 1; accuracy: 0.575; ce: 1.8876605;
FastEstimator-Eval: step: 1564; epoch: 2; accuracy: 0.608; ce: 1.8521833;
FastEstimator-Eval: step: 2346; epoch: 3; accuracy: 0.6366; ce: 1.8237936;
FastEstimator-Eval: step: 3128; epoch: 4; accuracy: 0.6447; ce: 1.8164377;
FastEstimator-Eval: step: 3910; epoch: 5; accuracy: 0.648; ce: 1.8117273;
FastEstimator-Finish: step: 3910; model7_lr: 0.0

Finally, let's save our finetuned model.

In [56]:
fe.backend.save_model(final_torch_model, save_dir=model_dir, model_name= "final_torch_finetune")

'C:\\windows\\TEMP\\final_torch_finetune.pt'

<a id='ta15finetune'></a>