<a href="https://colab.research.google.com/github/mhrgroup/course_self_supervised_learning/blob/main/Section%2004%3A%20Self-Supervised%20Learning/ssl_section04_lecture08.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Lecture 08: Supervised Contrastive Pretext, Experiment 2**

By the end of this lecture, you will be able to:

1. Address a labeling problem using a supervised contrastive pretext task with pretrained parameters.

# **8.1. Experiment 2**
---
* This experiment is similar to the previous one!
* We assume only 1000 training images are labeled in CIFAR-10.
* We develop a supervised contrastive pretext (prx) model using all training inputs and fine-tune it on the 1000 labeled images in the downstream (dwm) task.
* We then label the testing images using the fine-tuned model.
* We assume that a network trained on a similar data distribution is available, so our model starts from pretrained parameters.
* We compare this model with a comparable fully supervised (fsp) model trained on the same 1000 labeled images.
* Note: we have to develop three models here, model_fsp, model_prx, and model_dwm.
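
As a rough sketch of the data budget this setup implies, the arithmetic below compares the labeled subset with what the pretext task can use. The variable names are illustrative and not part of the notebook, and the doubling assumes one augmented copy per training image, as built later in this notebook.

```python
# Illustrative arithmetic only; names are hypothetical, not part of the notebook.
num_train = 50_000     # CIFAR-10 training images
num_labeled = 1_000    # images assumed to carry a class label

# Share of the training set available to fully supervised and downstream training
labeled_fraction = num_labeled / num_train

# The pretext task needs no class labels: every image contributes
# itself plus one augmented copy.
num_prx_samples = 2 * num_train

print(f"labeled fraction: {labeled_fraction:.1%}")
print(f"pretext samples:  {num_prx_samples}")
```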

> **Abbreviations:**
* acc: accuracy
* datain: input data
* dataou: output data
* dwm: downstream
* fnt: fine-tuning
* fsp: fully supervised learning
* lr: learning rate
* prx: pretext
* te: testing
* tf: TensorFlow
* tr: training
* trf: transfer learning

In [None]:
#@title Install necessary libraries & restart the session

# Install the required libraries using the `pip` package manager.
!pip install tensorflow==2.15

# Import the time module to add a delay before restarting the session.
import time

# Import `clear_output` from IPython to clear the notebook output, ensuring a clean display for the user.
from IPython.display import clear_output

# Clear the output after the packages are installed to make the notebook cleaner.
clear_output()

# Print a message to let the user know that the libraries are installed & the session will restart.
print("Necessary Libraries are Installed. Restarting the session!")

# Add a short delay (1 second) before restarting to allow the message to be displayed to the user.
time.sleep(1)

# Import the `os` module to access low-level operating system functionality.
import os

# Use `os._exit(00)` to exit the current Python runtime environment forcefully.
# This effectively simulates a restart in notebook environments like Google Colab or Jupyter.
# After this command, the environment will be restarted & all the packages installed will be properly loaded.
os._exit(00)

In [None]:
#@title Import necessary libraries
import tensorflow as tf
import copy

In [None]:
#@title Hyper-parameters
num_labeled  = 1000

# learning rates
lr_fsp_trf   = 0.01
lr_fsp_fnt   = 0.0001

lr_prx_trf   = 0.01
lr_prx_fnt   = 0.00001

lr_dwm_trf   = 0.01
lr_dwm_fnt   = 0.0001


# batch sizes
batch_fsp_trf  = 128
batch_fsp_fnt  = 128

batch_prx_trf  = 128
batch_prx_fnt  = 128

batch_dwm_trf  = 128
batch_dwm_fnt  = 128


# epochs
epoch_fsp_trf  = 15
epoch_fsp_fnt  = 10

epoch_prx_trf  = 15
epoch_prx_fnt  = 10

epoch_dwm_trf  = 15
epoch_dwm_fnt  = 10
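
To get a feel for what these settings imply, the sketch below counts gradient updates per epoch. It is only an illustration: the helper name is hypothetical, and the 5% validation hold-out for the pretext data is an assumption taken from the pretext training cell later in this notebook.

```python
import math

def updates_per_epoch(num_samples, batch_size):
    """Number of mini-batches (gradient updates) in one epoch."""
    return math.ceil(num_samples / batch_size)

# fsp and dwm train on the 1000 labeled images with batch size 128
steps_labeled = updates_per_epoch(1000, 128)

# prx trains on 2 x 50,000 images; assume 5% is held out for validation
num_prx_train = 100_000 - (100_000 * 5) // 100
steps_prx = updates_per_epoch(num_prx_train, 128)

# Total updates of the labeled-data transfer-learning phase (15 epochs)
total_labeled_trf = steps_labeled * 15

print(steps_labeled, steps_prx, total_labeled_trf)
```

The pretext task performs far more updates per epoch than the labeled-data models, which is why its training dominates the run time.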


In [None]:
#@title Load and process the CIFAR-10 data
(datain_tr, dataou_tr), (datain_te, dataou_te) = tf.keras.datasets.cifar10.load_data()

datain_tr = datain_tr/255 # transform uint8 values to the range [0, 1]
datain_te = datain_te/255 # transform uint8 values to the range [0, 1]

dataou_tr = tf.keras.utils.to_categorical(dataou_tr)
dataou_te = tf.keras.utils.to_categorical(dataou_te)

# print shapes of data

print('Shape of datain_tr: {}'.format(datain_tr.shape))
print('Shape of datain_te: {}'.format(datain_te.shape))
print('Shape of dataou_tr: {}'.format(dataou_tr.shape))
print('Shape of dataou_te: {}'.format(dataou_te.shape))


In [None]:
#@title Create contrastive training inputs and outputs
'''
Let's pick an augmentation method, say, random rotation.
'''

fun_augment     = tf.keras.layers.RandomRotation(factor = 0.2)

# training=True makes it explicit that the random rotation is applied here
datain_tr_augmented = fun_augment(datain_tr, training=True)

#concatenate the original and augmented training data for pretext (prx)
datain_tr_prx = tf.concat([datain_tr,datain_tr_augmented], axis = 0)

#contrastive outputs: 1 for original images, 0 for augmented images
dataou_tr_prx_positive = tf.ones((datain_tr.shape[0],1))
dataou_tr_prx_negative = tf.zeros((datain_tr_augmented.shape[0],1))

#Output concatenation in the same order as the data points in datain_tr_prx
#(originals first, augmented second)
dataou_tr_prx = tf.concat([dataou_tr_prx_positive, dataou_tr_prx_negative], axis = 0)

#binary categorical for SoftMax:
dataou_tr_prx = tf.keras.utils.to_categorical(dataou_tr_prx)
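
One thing to keep in mind about this layout: Keras' `validation_split` takes the validation samples from the end of the arrays, before any shuffling. Because the pretext data is concatenated as originals followed by augmented copies, a tail split contains only augmented images. The toy sketch below (plain Python, with hypothetical names and a made-up size) illustrates the effect and one possible remedy, a single up-front shuffle of inputs and outputs together.

```python
import random

num_images = 10  # toy stand-in for the 50,000 training images

# Same layout as datain_tr_prx: originals first, augmented copies second
samples = [("original", i) for i in range(num_images)] + \
          [("augmented", i) for i in range(num_images)]

def held_out_tail(items, num_val):
    """Mimic a split that holds out the last num_val items."""
    return items[len(items) - num_val:]

tail_unshuffled = held_out_tail(samples, 4)

# Shuffle once, up front, so the tail is no longer tied to the concatenation order
shuffled = samples[:]
random.Random(0).shuffle(shuffled)
tail_shuffled = held_out_tail(shuffled, 4)

print({kind for kind, _ in tail_unshuffled})
print({kind for kind, _ in tail_shuffled})
```

In the notebook, the same idea could be applied by drawing one permutation with `tf.random.shuffle(tf.range(n))` and using `tf.gather` on both `datain_tr_prx` and `dataou_tr_prx`, so that inputs and outputs stay aligned.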


In [None]:
#@title Limit the labeled training data

# randomly select num_labeled distinct training images (sampling without replacement)
index_tr = tf.random.shuffle(tf.range(datain_tr.shape[0]))[:num_labeled].numpy()

datain_tr_labeled = datain_tr[index_tr,:,:,:]
dataou_tr_labeled = dataou_tr[index_tr,:]

datain_tr_fsp = copy.deepcopy(datain_tr_labeled)
dataou_tr_fsp = copy.deepcopy(dataou_tr_labeled)

datain_tr_dwm = copy.deepcopy(datain_tr_labeled)
dataou_tr_dwm = copy.deepcopy(dataou_tr_labeled)

# we have 50,000 training inputs; num_labeled of them are labeled
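
When picking the labeled subset, drawing indices with replacement can select the same image more than once, so fewer than `num_labeled` distinct images end up labeled. The sketch below contrasts the two sampling schemes using Python's `random` module. This is an assumption made for illustration; the notebook itself works with TensorFlow, and the variable names are hypothetical.

```python
import random

num_train = 50_000
num_labeled = 1_000
rng = random.Random(0)  # fixed seed, chosen only for repeatability

# With replacement: the same index can be drawn more than once
index_with_repl = [rng.randrange(num_train) for _ in range(num_labeled)]

# Without replacement: every index is distinct by construction
index_without_repl = rng.sample(range(num_train), num_labeled)

print(len(set(index_with_repl)), len(set(index_without_repl)))
```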


In [None]:
#@title Create model_fsp and model_prx based on DenseNet121

layerin = tf.keras.Input(shape=(datain_tr.shape[1],
                                datain_tr.shape[2],
                                datain_tr.shape[3]))

upscale = tf.keras.layers.Lambda(lambda x: tf.image.resize_with_pad(x,
                                                                    160,
                                                                    160,
                                                                    method=tf.image.ResizeMethod.BILINEAR))(layerin)

# load with ImageNet Weights
model_DenseNet121 = tf.keras.applications.DenseNet121(include_top  = False,
                                                      weights      = "imagenet",
                                                      input_shape  = (160,160,3),
                                                      input_tensor = upscale,
                                                      pooling      = 'max')

'''
We clone model_DenseNet121 to create the base models (model_base) of our fully
supervised and pretext models.

P.S. The downstream model is built after the pretext task and reuses the
trained pretext base model.
'''

model_base_fsp =  tf.keras.models.clone_model(model_DenseNet121)
model_base_prx =  tf.keras.models.clone_model(model_DenseNet121)

'''
clone_model creates new layers with freshly initialized parameters, so we copy
the pretrained ImageNet parameters of model_DenseNet121 into model_base_fsp and
model_base_prx. Both models then start learning from the same point, which
makes the comparison fair.
'''

model_base_fsp.set_weights(model_DenseNet121.get_weights())
model_base_prx.set_weights(model_DenseNet121.get_weights())


'''
Now we create batch norm layers of model_fsp and model_prx.
'''
layer_batchnorm_fsp = tf.keras.layers.BatchNormalization()
layer_batchnorm_prx = tf.keras.layers.BatchNormalization()

'''
Now we create output layers of model_fsp and model_prx.
'''
layerou_fsp = tf.keras.layers.Dense(dataou_tr_fsp.shape[-1],
                                    activation = 'softmax')

layerou_prx = tf.keras.layers.Dense(dataou_tr_prx.shape[-1],
                                    activation = 'softmax')


'''
Now we create model_fsp and model_prx.
'''
model_fsp   = tf.keras.models.Sequential([model_base_fsp,
                                          layer_batchnorm_fsp,
                                          layerou_fsp])

model_prx   = tf.keras.models.Sequential([model_base_prx,
                                          layer_batchnorm_prx,
                                          layerou_prx])


In [None]:
#@title Train the fsp model using transfer learning and fine-tuning

'''
We do not use a validation split here because we can check for overfitting
later with the testing data.
'''

# transfer learning
model_base_fsp.trainable      = False
layer_batchnorm_fsp.trainable = False

model_fsp.compile(optimizer = tf.keras.optimizers.Adam(lr_fsp_trf),
                  loss      = 'categorical_crossentropy',
                  metrics   = ['accuracy'])

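'''
We save the initial output-layer parameters of the fsp model so that the dwm
model can start from the same output-layer parameters.
'''

# get_weights() returns NumPy copies, so later training does not change them
layerou_fsp_initial_parameters = model_fsp.layers[2].get_weights()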

model_fsp.summary()

history_fsp_trf = model_fsp.fit(datain_tr_fsp,
                                dataou_tr_fsp,
                                epochs           = epoch_fsp_trf,
                                batch_size       = batch_fsp_trf,
                                verbose          = 1,
                                shuffle          = True)

# fine-tuning
model_base_fsp.trainable      = True
layer_batchnorm_fsp.trainable = True

model_fsp.compile(optimizer = tf.keras.optimizers.Adam(lr_fsp_fnt),
                  loss      = 'categorical_crossentropy',
                  metrics   = ['accuracy'])


model_fsp.summary()

history_fsp_fnt = model_fsp.fit(datain_tr_fsp,
                                dataou_tr_fsp,
                                epochs           = epoch_fsp_fnt,
                                batch_size       = batch_fsp_fnt,
                                verbose          = 1,
                                shuffle          = True)

In [None]:
#@title Train the prx model using transfer learning and fine-tuning
'''
Note that we have no testing data for the pretext task, so we track overfitting
with a validation split. Also, since pretext training takes time, we define an
early stopping criterion using "callbacks."
'''

callback = tf.keras.callbacks.EarlyStopping(monitor              = 'val_loss',
                                            patience             = 1,
                                            restore_best_weights = True)

# transfer learning
model_base_prx.trainable      = False
layer_batchnorm_prx.trainable = False

model_prx.compile(optimizer = tf.keras.optimizers.Adam(lr_prx_trf),
                  loss      = 'categorical_crossentropy',
                  metrics   = ['accuracy'])

model_prx.summary()

history_prx_trf = model_prx.fit(datain_tr_prx,
                                dataou_tr_prx,
                                epochs           = epoch_prx_trf,
                                batch_size       = batch_prx_trf,
                                verbose          = 1,
                                shuffle          = True,
                                validation_split = 0.05,
                                callbacks        = [callback])

# fine-tuning

model_base_prx.trainable      = True
layer_batchnorm_prx.trainable = True

model_prx.compile(optimizer = tf.keras.optimizers.Adam(lr_prx_fnt),
                  loss      = 'categorical_crossentropy',
                  metrics   = ['accuracy'])

model_prx.summary()

history_prx_fnt = model_prx.fit(datain_tr_prx,
                                dataou_tr_prx,
                                epochs           = epoch_prx_fnt,
                                batch_size       = batch_prx_fnt,
                                verbose          = 1,
                                shuffle          = True,
                                validation_split = 0.05,
                                callbacks        = [callback])
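
The early stopping behaviour used above can be pictured with a simplified sketch. It is not Keras' actual implementation, and the function name and loss values are made up for illustration. It only mirrors the idea that training stops once the validation loss fails to improve for `patience` consecutive epochs and that the parameters of the best epoch are the ones kept.

```python
def early_stopping_epoch(val_losses, patience=1):
    """Return (stop_epoch, best_epoch) for a simplified early-stopping rule.

    Training stops once `patience` consecutive epochs fail to improve on
    the best validation loss seen so far; best_epoch is the epoch whose
    weights would be restored.
    """
    best_loss = float("inf")
    best_epoch = None
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    # No stop triggered: training ran for all epochs
    return len(val_losses) - 1, best_epoch

val_losses = [0.90, 0.70, 0.75, 0.60]  # made-up values for illustration
stop_epoch, best_epoch = early_stopping_epoch(val_losses, patience=1)
print(stop_epoch, best_epoch)
```

With `patience = 1`, a single non-improving epoch ends training, so the later, lower loss in this made-up sequence is never reached. A larger patience trades run time for a chance to recover from noisy epochs.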

In [None]:
#@title Create and train dwm model using transfer-learning and fine-tuning
'''
Downstream output layer
'''
layerou_dwm = tf.keras.layers.Dense(dataou_tr_dwm.shape[-1],
                                    activation = 'softmax')

'''
Now we create model_dwm using model_base_prx (yes, it is trained!),
layer_batchnorm_prx (yes, it is trained too!), and layerou_dwm. For a fair
fsp and dwm comparison, we initialize the layerou_dwm parameters with
layerou_fsp_initial_parameters.

P.S. Note that during transfer learning, we don't update the model_base_prx
and layer_batchnorm_prx parameters and focus on learning layerou_dwm only.
'''

# Transfer learning
model_base_prx.trainable      = False
layer_batchnorm_prx.trainable = False

model_dwm   = tf.keras.models.Sequential([model_base_prx,
                                          layer_batchnorm_prx,
                                          layerou_dwm])

model_dwm.layers[2].set_weights(layerou_fsp_initial_parameters)

model_dwm.compile(optimizer = tf.keras.optimizers.Adam(lr_dwm_trf),
                  loss      = 'categorical_crossentropy',
                  metrics   = ['accuracy'])

model_dwm.summary()

history_dwm_trf = model_dwm.fit(datain_tr_dwm,
                                dataou_tr_dwm,
                                epochs           = epoch_dwm_trf,
                                batch_size       = batch_dwm_trf,
                                verbose          = 1,
                                shuffle          = True)

# Fine-tuning
model_base_prx.trainable      = True
layer_batchnorm_prx.trainable = True

# Optionally, we can fine-tune only the layers after a certain model_base_prx layer:
# fine_tune_after = 430
# for layer in model_base_prx.layers[:fine_tune_after]:
#   layer.trainable = False

model_dwm.compile(optimizer = tf.keras.optimizers.Adam(lr_dwm_fnt),
                  loss      = 'categorical_crossentropy',
                  metrics   = ['accuracy'])

model_dwm.summary()

history_dwm_fnt = model_dwm.fit(datain_tr_dwm,
                                dataou_tr_dwm,
                                epochs           = epoch_dwm_fnt,
                                batch_size       = batch_dwm_fnt,
                                verbose          = 1,
                                shuffle          = True)


In [None]:
#@title Compute model_fsp and model_dwm testing accuracies
_, acc_te_fsp = model_fsp.evaluate(datain_te,
                                   dataou_te,
                                   batch_size = 128)

_, acc_te_dwm = model_dwm.evaluate(datain_te,
                                   dataou_te,
                                   batch_size = 128)

print('Accuracy of fsp: {:05.2f}%'.format(acc_te_fsp*100))
print('Accuracy of dwm: {:05.2f}%'.format(acc_te_dwm*100))
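
For reference, with one-hot labels and a softmax output, the reported accuracy is the fraction of samples whose highest-scoring class matches the labeled class. The plain-Python sketch below illustrates that computation on made-up values; the function name is hypothetical and this is not how Keras implements the metric internally.

```python
def categorical_accuracy(labels_onehot, predictions):
    """Fraction of samples whose highest-scoring class matches the label."""
    correct = 0
    for label, pred in zip(labels_onehot, predictions):
        true_class = label.index(max(label))
        pred_class = pred.index(max(pred))
        correct += int(true_class == pred_class)
    return correct / len(labels_onehot)

# Made-up one-hot labels and softmax-like outputs for four samples
labels = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0]]
preds = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.3, 0.4, 0.3], [0.2, 0.5, 0.3]]

acc = categorical_accuracy(labels, preds)
print('Accuracy: {:05.2f}%'.format(acc * 100))
```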

In [None]:
#@title Clean up memory
# -f clears all variables without asking for confirmation
%reset -f

# **Lecture 08: Supervised Contrastive Pretext, Experiment 2**

In this lecture, you learned about:

1. How to address a labeling problem using a supervised contrastive pretext task with pretrained parameters.

> ***In the following lecture, we will learn about SimCLR of [Chen et al. (2020)](http://proceedings.mlr.press/v119/chen20j/chen20j.pdf), a powerful unsupervised contrastive learning technique.***