<a href="https://colab.research.google.com/github/mhrgroup/course_self_supervised_learning/blob/main/Section%2004%3A%20Self-Supervised%20Learning/ssl_section04_lecture06.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Lecture 06: Self-Supervised Learning**

By the end of this lecture, you will be able to:

1. Describe Self-Supervised Learning (SSL).
2. Describe SSL contrastive learning.
3. Describe SSL generative learning.

# **6.1. Self-Supervised Learning (SSL) in General**
---
Let’s focus again on the previous example: we have ten million images of cats and dogs, with only 10,000 (0.1%) manually labeled.

* ***The solutions mentioned in the previous lecture ignore the unlabeled repository entirely. Although unlabeled, its 9,990,000 images still carry information that a model could learn from.***


* ***The objective of Self-Supervised Learning (SSL), a form of representation learning, is (usually) to use the repository’s labeled and unlabeled data together to develop a base model in a so-called pretext (prx) task, and then to transfer and fine-tune that base model using the limited labeled data in a so-called downstream (dwm) task.***

* ***Because the base model, often referred to as the pretext model in the SSL literature, is trained on the repository itself, SSL keeps it within the distribution of the current data repository.***
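The bookkeeping behind this two-stage workflow can be sketched in plain Python. This is a scaled-down illustration (100,000 images instead of ten million, keeping the 0.1% labeled ratio); `split_repository` is a hypothetical helper, and no actual model is trained.

```python
import random

def split_repository(n_images, labeled_fraction, seed=0):
    """Hypothetical helper: split image indices into labeled and unlabeled pools."""
    rng = random.Random(seed)
    indices = list(range(n_images))
    rng.shuffle(indices)
    n_labeled = round(n_images * labeled_fraction)
    return indices[:n_labeled], indices[n_labeled:]

# Scaled-down stand-in for the ten-million-image repository (0.1% labeled).
labeled_idx, unlabeled_idx = split_repository(n_images=100_000, labeled_fraction=0.001)

# Pretext (prx) task: uses every image, labeled or not; no human labels are needed.
pretext_idx = labeled_idx + unlabeled_idx

# Downstream (dwm) task: fine-tunes the pretext model on the labeled images only.
downstream_idx = labeled_idx

print(len(pretext_idx), len(downstream_idx))  # 100000 100
```

With `n_images=10_000_000`, the same split gives 10,000 labeled and 9,990,000 unlabeled indices, matching the running example.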

> But how can a pretext model be created from both labeled and unlabeled data? Two primary SSL techniques exist: **contrastive learning** and **generative learning**.

# **6.2. Contrastive Learning**
----

* **In contrastive learning**, the input data, labeled or unlabeled, are augmented automatically using an augmentation technique.
* A collection of popular augmentations in the image domain is shown in this figure:

> <img src=	"https://raw.githubusercontent.com/mhrgroup/course_self_supervised_learning/main/images/augmentation.png"	width="650"/>

> Image from [Chen et al. (2020)](http://proceedings.mlr.press/v119/chen20j/chen20j.pdf) (Original image cc-by: Von.grzanka).

* First, an augmentation is selected, say, cutout (also known as masking). Next, that augmentation is automatically applied to the input data, labeled or unlabeled.
* Hence, our repository of cats and dogs grows to twenty million images: ten million originals and ten million augmented ones, regardless of whether each image shows a dog or a cat.
* In other words, there are ten million pairs of original and corresponding augmented images.
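The pairing step can be sketched in NumPy with cutout as the chosen augmentation. This is a minimal sketch under assumptions: random arrays stand in for the images, `cutout` is a hypothetical helper that zeroes one square patch (one common form of cutout/masking), and the patch size of 8 pixels is arbitrary.

```python
import numpy as np

def cutout(image, size, rng):
    """Hypothetical helper: zero out a random size x size square of an H x W x C image."""
    h, w = image.shape[:2]
    top = rng.integers(0, h - size + 1)
    left = rng.integers(0, w - size + 1)
    masked = image.copy()
    masked[top:top + size, left:left + size, :] = 0.0
    return masked

rng = np.random.default_rng(0)
# Toy stand-in for the repository: 6 random 32x32 RGB "images" in [0, 1); no labels are used.
images = rng.random((6, 32, 32, 3), dtype=np.float32)

augmented = np.stack([cutout(img, size=8, rng=rng) for img in images])

# One (original, augmented) pair per image, i.e., twice as many images in total.
pairs = list(zip(images, augmented))
all_images = np.concatenate([images, augmented], axis=0)
print(len(pairs), all_images.shape)  # 6 (12, 32, 32, 3)
```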

* A collection of popular augmentations in the signal (temporal records) domain is shown in this figure:

> <img src=	"https://raw.githubusercontent.com/mhrgroup/course_self_supervised_learning/main/images/signal.png"	width="300"/>
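Three augmentations commonly used for temporal records (jittering, scaling, and time reversal) can be sketched in NumPy as follows. They are illustrative choices and may not match the exact set shown in the figure; the helper names and the noise levels `sigma` are assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

def jitter(signal, sigma=0.05):
    """Add small Gaussian noise to every sample (sigma is an assumed noise level)."""
    return signal + rng.normal(0.0, sigma, size=signal.shape)

def scale(signal, sigma=0.1):
    """Multiply the whole record by one random factor close to 1."""
    return signal * rng.normal(1.0, sigma)

def time_reverse(signal):
    """Play the record backwards along the time axis (axis 0)."""
    return signal[::-1]

t = np.linspace(0.0, 1.0, 200)
signal = np.sin(2 * np.pi * 5 * t)  # toy 5 Hz sine record with 200 samples

for name, fn in [("jitter", jitter), ("scale", scale), ("time reverse", time_reverse)]:
    augmented = fn(signal)
    print(name, augmented.shape)
```

As with images, each augmented record is paired with its original, regardless of any label.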

> ***Let's code some image augmentations using the CIFAR-10 data***

> **Abbreviations:**
*	datain: input data
*	dataou: output data
*	prx: pretext
*	te: testing
*	tf: tensorflow
*	tr: training


In [None]:
#@title Install necessary libraries & restart the session

# Install the required libraries using the `pip` package manager.
!pip install tensorflow==2.15

# Import the time module to add a delay before restarting the session.
import time

# Import `clear_output` from IPython to clear the notebook output, ensuring a clean display for the user.
from IPython.display import clear_output

# Clear the output after the packages are installed to make the notebook cleaner.
clear_output()

# Print a message to let the user know that the libraries are installed & the session will restart.
print("Necessary Libraries are Installed. Restarting the session!")

# Add a short delay (1 second) before restarting to allow the message to be displayed to the user.
time.sleep(1)

# Import the `os` module to access low-level operating system functionality.
import os

# Use `os._exit(0)` to exit the current Python runtime forcefully.
# This effectively simulates a restart in notebook environments like Google Colab or Jupyter.
# After this command, the environment restarts & all the newly installed packages are properly loaded.
os._exit(0)

In [None]:
#@title Import necessary libraries
import tensorflow as tf
import matplotlib as mt
import matplotlib.pyplot  # explicit import so that `mt.pyplot` is guaranteed to be available


In [None]:
#@title Load and process the CIFAR-10 data
(datain_tr, _), (_, _) = tf.keras.datasets.cifar10.load_data()

datain_tr = datain_tr.astype('float32') / 255  # transform uint8 values in [0, 255] to float32 values in [0, 1]

print('Shape of datain_tr: {}'.format(datain_tr.shape))


In [None]:
#@title Plot function for image augmentation
def fun_plot(image_original, image_augmented, general_title):
  '''Show an original image and its augmented version side by side.'''
  titles = ['original', 'augmented']
  images = [image_original, image_augmented]
  fig, ax = mt.pyplot.subplots(1, 2, figsize=(4, 3))
  fig.suptitle(general_title.title())

  fig.tight_layout()
  mt.pyplot.subplots_adjust(wspace=0.05, hspace=0.05)

  for i0 in range(2):
    ax[i0].imshow(images[i0])
    ax[i0].set_title(titles[i0].title())
    ax[i0].set_xticks([])
    ax[i0].set_yticks([])

  mt.pyplot.show()


In [None]:
#@title Image augmentation - random flip
'''
https://www.tensorflow.org/api_docs/python/tf/keras/layers/RandomFlip.
'''

fun_augment     = tf.keras.layers.RandomFlip(mode = 'horizontal_and_vertical')

image_original  = datain_tr[0,:,:,:]
image_augmented = fun_augment(image_original)

fun_plot(image_original, image_augmented, 'Random Flip')


In [None]:
#@title Image augmentation - random rotation
'''
https://www.tensorflow.org/api_docs/python/tf/keras/layers/RandomRotation.
'''

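# factor = 0.2: rotate by a random angle in [-0.2 * 2 * pi, 0.2 * 2 * pi] radians, i.e., up to 72 degrees either way
fun_augment     = tf.keras.layers.RandomRotation(factor = 0.2)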

image_original  = datain_tr[0,:,:,:]
image_augmented = fun_augment(image_original)

fun_plot(image_original, image_augmented, 'Random Rotation')


In [None]:
#@title Image augmentation - random zoom
'''
https://www.tensorflow.org/api_docs/python/tf/keras/layers/RandomZoom.
'''

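# height_factor = .75: zoom in or out by a random amount of up to 75%;
# width_factor is left as None, so the aspect ratio is preserved
fun_augment     = tf.keras.layers.RandomZoom(height_factor = .75)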

image_original  = datain_tr[0,:,:,:]
image_augmented = fun_augment(image_original)

fun_plot(image_original, image_augmented, 'Random Zoom')


In [None]:
#@title Image augmentation - random crop and resize
'''
https://www.tensorflow.org/api_docs/python/tf/keras/layers/RandomCrop.
'''

# randomly crop a 20x20 patch, then resize it back to the original 32x32 resolution
fun_augment_01  = tf.keras.layers.RandomCrop(height = 20, width = 20)
fun_augment_02  = tf.keras.layers.Resizing(height = datain_tr.shape[1],
                                           width = datain_tr.shape[2])

fun_augment     = tf.keras.Sequential([fun_augment_01, fun_augment_02])

image_original  = datain_tr[0,:,:,:]
image_augmented = fun_augment(image_original)

fun_plot(image_original, image_augmented, 'Crop & Resize')


In [None]:
#@title Image augmentation - random saturation (jitter)
'''
https://www.tensorflow.org/api_docs/python/tf/image/random_saturation.
'''

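# randomly multiply the saturation by a factor drawn from [1, 5) (lower = 1, upper = 5),
# working on the [0, 255] scale and mapping the result back to [0, 1]
fun_augment     = lambda image: tf.cast(tf.image.random_saturation(image*255, lower = 1, upper = 5), dtype = tf.float32)/255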
image_original  = datain_tr[0,:,:,:]

image_augmented = fun_augment(image_original)

fun_plot(image_original, image_augmented, 'Random Saturation (Jitter)')
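What saturation jitter does to each pixel can be sketched with Python's standard `colorsys` module: convert RGB to HSV, multiply the saturation by a factor, clip it to [0, 1], and convert back. `tf.image.random_saturation` draws that factor uniformly from `[lower, upper)`; the helper `adjust_pixel_saturation` below is hypothetical and processes one pixel at a time rather than a whole image.

```python
import colorsys

def adjust_pixel_saturation(rgb, factor):
    """Hypothetical helper: scale the HSV saturation of one RGB pixel with channels in [0, 1]."""
    h, s, v = colorsys.rgb_to_hsv(*rgb)
    s = min(max(s * factor, 0.0), 1.0)  # clip saturation to the valid [0, 1] range
    return colorsys.hsv_to_rgb(h, s, v)

gray = (0.5, 0.5, 0.5)
reddish = (0.6, 0.4, 0.4)

print(adjust_pixel_saturation(gray, 3.0))     # (0.5, 0.5, 0.5): gray has no saturation to boost
print(adjust_pixel_saturation(reddish, 3.0))  # close to (0.6, 0.0, 0.0): a fully saturated red
```

Because `lower = 1` in the cell above, the factor never drops below 1, so this augmentation only ever makes colors more vivid.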


In [None]:
#@title Create contrastive training inputs
'''
Let's pick an augmentation method, say, random rotation.
'''

fun_augment     = tf.keras.layers.RandomRotation(factor = 0.2)

datain_tr_augmented = fun_augment(datain_tr)

# Let's plot the first five augmented images
for i0 in range(5):
  fun_plot(datain_tr[i0,:,:,:], datain_tr_augmented[i0,:,:,:], 'Random Rotation')

# concatenate the original and augmented training data for pretext (prx)
datain_tr_prx = tf.concat([datain_tr,datain_tr_augmented], axis = 0)

print("Shape of original training inputs: {}".format(datain_tr.shape))
print("Shape of augmented training inputs: {}".format(datain_tr_augmented.shape))
print("Shape of pretext inputs: {}".format(datain_tr_prx.shape))


## 6.2.1. Supervised Contrastive Learning
* In **supervised contrastive learning**, an augmented image is pseudo-labeled positive (class 1), and the original image is pseudo-labeled negative (class 0).
* The goal of the pretext task is for a pretext model to learn these pseudo-labels from all twenty million images. Next, the pretext model is transferred and fine-tuned using the limited labeled data in the downstream task.

In [None]:
#@title Create pseudo labels
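# augmented images -> positive pseudo-label (class 1)
dataou_tr_prx_positive = tf.ones((datain_tr_augmented.shape[0],1))
# original images -> negative pseudo-label (class 0)
dataou_tr_prx_negative = tf.zeros((datain_tr.shape[0],1))

# same order as datain_tr_prx: original images first, augmented images second
dataou_tr_prx = tf.concat([dataou_tr_prx_negative, dataou_tr_prx_positive], axis = 0)

# one-hot encode the binary pseudo-labels for a two-unit softmax output layer
dataou_tr_prx = tf.keras.utils.to_categorical(dataou_tr_prx)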

print("Shape of training pretext output: {}".format(dataou_tr_prx.shape))
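The same labeling logic can be reproduced in NumPy on toy arrays, which makes it easy to check that the labels follow the concatenation order of `datain_tr_prx`. This is a sketch under assumptions: the zero- and one-valued arrays merely stand in for original and augmented images so that each image's origin stays visible after shuffling, and `np.eye(2)[labels]` plays the role of `to_categorical`.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-ins (assumption): 4 "original" images of zeros and 4 "augmented" images of ones.
originals = np.zeros((4, 32, 32, 3), dtype=np.float32)
augmented = np.ones((4, 32, 32, 3), dtype=np.float32)

# Same order as datain_tr_prx: originals first, augmented second.
inputs_prx = np.concatenate([originals, augmented], axis=0)
labels_prx = np.concatenate([np.zeros(len(originals), dtype=int),   # original -> class 0
                             np.ones(len(augmented), dtype=int)])   # augmented -> class 1

# One-hot encoding for a two-unit softmax output.
onehot_prx = np.eye(2, dtype=np.float32)[labels_prx]

# Shuffle inputs and labels with one shared permutation so they stay aligned.
perm = rng.permutation(len(inputs_prx))
inputs_prx, onehot_prx = inputs_prx[perm], onehot_prx[perm]

print(onehot_prx.shape)  # (8, 2)
```

A shared shuffle matters here: Keras `Model.fit` shuffles NumPy inputs by default, but `validation_split` takes the *last* fraction of the data before shuffling, which for `datain_tr_prx` would contain only augmented (class 1) images.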

## 6.2.2. Unsupervised Contrastive Learning
* In **unsupervised contrastive learning**, no pseudo-labeling is required.
* The goal of the pretext task is for a pretext model to learn the contrast between feature representations of the twenty million images.
* Feature representations are usually retrieved from the pretext model’s last (or near-final) dense layers.
* For example, a pretext model’s loss function can pull the feature representations of an original image and its augmentation together within a training batch, while pushing them away from the feature representations of the other data points in that batch.
* Here is the unsupervised contrastive SSL framework of [Chen et al. (2020)](http://proceedings.mlr.press/v119/chen20j/chen20j.pdf), called SimCLR:

> <img src=	"https://raw.githubusercontent.com/mhrgroup/course_self_supervised_learning/main/images/simclr1.png"	width="500"/>
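The loss described above is SimCLR's normalized temperature-scaled cross-entropy (NT-Xent). A minimal NumPy sketch follows, assuming the embeddings are already computed (in SimCLR they come from a small projection head on top of the encoder; random vectors stand in for them here) and using a temperature of 0.5, which is an arbitrary choice. For each of the 2N embeddings, its partner view is the positive and the remaining 2N - 2 embeddings act as negatives.

```python
import numpy as np

def nt_xent_loss(z_a, z_b, temperature=0.5):
    """NT-Xent loss for N positive pairs (z_a[i], z_b[i]), averaged over all 2N anchors."""
    z = np.concatenate([z_a, z_b], axis=0)                 # (2N, d)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)       # unit vectors -> dot product = cosine
    sim = z @ z.T / temperature                            # (2N, 2N) scaled similarities
    np.fill_diagonal(sim, -np.inf)                         # an anchor is never its own negative
    n = len(z_a)
    partners = np.concatenate([np.arange(n, 2 * n), np.arange(n)])  # index of each positive
    # numerically stable log-softmax over each row, evaluated at the positive partner
    row_max = sim.max(axis=1, keepdims=True)
    log_denom = np.log(np.exp(sim - row_max).sum(axis=1)) + row_max[:, 0]
    log_prob_pos = sim[np.arange(2 * n), partners] - log_denom
    return -log_prob_pos.mean()

rng = np.random.default_rng(0)
z_original = rng.normal(size=(8, 16))
z_close = z_original + 0.05 * rng.normal(size=(8, 16))  # views embedded near their originals
z_random = rng.normal(size=(8, 16))                    # views unrelated to their originals

print(nt_xent_loss(z_original, z_close) < nt_xent_loss(z_original, z_random))  # True
```

Minimizing this loss is what "pulling together" positive pairs and "pushing apart" the rest of the batch means in practice; larger batches supply more negatives per anchor.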

# **6.3. Generative Learning**
----
* In **generative SSL**, the pretext model is obtained from an unsupervised generative model, often without any augmentation.
* For example, the pretext task can be to train a generative adversarial network (GAN) to generate synthetic input data (e.g., images of cats and dogs) close to the distribution of the current repository.
* Next, the GAN’s discriminator is transferred and fine-tuned using the limited labeled data in the downstream task.
* Another example is to extract the encoder of an autoencoder trained to accurately reconstruct the current input distribution; that encoder is then transferred and fine-tuned using the limited labeled data in the downstream task.
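As a simple stand-in for the autoencoder route, a *linear* autoencoder trained with squared reconstruction error is known to span the same subspace as principal component analysis, so it can be fitted in closed form with an SVD. The NumPy sketch below is illustrative only: the toy data, the code size `k = 8`, and the helper `fit_linear_autoencoder` are assumptions, and real generative SSL would use a deep (nonlinear) autoencoder or a GAN trained on the image repository.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy unlabeled repository (assumption): 500 samples with 64 features near an 8-D subspace.
basis = rng.normal(size=(8, 64))
data = rng.normal(size=(500, 8)) @ basis + 0.01 * rng.normal(size=(500, 64))

def fit_linear_autoencoder(x, k):
    """Hypothetical helper: closed-form linear autoencoder using the top-k principal directions."""
    mean = x.mean(axis=0)
    _, _, vt = np.linalg.svd(x - mean, full_matrices=False)
    weights = vt[:k].T  # (n_features, k)

    def encode(v):
        return (v - mean) @ weights  # features -> k-D code (the part transferred downstream)

    def decode(code):
        return code @ weights.T + mean  # k-D code -> reconstructed features

    return encode, decode

encode, decode = fit_linear_autoencoder(data, k=8)
codes = encode(data)
mse = np.mean((decode(codes) - data) ** 2)
print(codes.shape, mse < 1e-3)  # (500, 8) True
```

Only `encode` would be kept for the downstream task, with a small classifier trained on top of its codes using the limited labeled data.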

> <img src=	"https://raw.githubusercontent.com/mhrgroup/course_self_supervised_learning/main/images/gan.png"	width="650"/>

> ***In this course, we focus only on contrastive learning; once you understand how to develop a contrastive pretext model, implementing generative SSL techniques becomes relatively easy.***

# **Lecture 06: Self-Supervised Learning**

In this lecture, you learned about:

1. Self-Supervised Learning (SSL).
2. Contrastive learning.
3. Generative learning.

***In the following lecture, we will see a labeling example using supervised contrastive learning.***