<a href="https://colab.research.google.com/github/mhrgroup/course_self_supervised_learning/blob/main/Section%2004%3A%20Self-Supervised%20Learning/ssl_section04_lecture09.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Lecture 09: SimCLR, An Unsupervised Contrastive Pretext Model**

By the end of this lecture, you will be able to:

1. Describe how SimCLR of [Chen et al. (2020)](http://proceedings.mlr.press/v119/chen20j/chen20j.pdf) works.
2. Develop SimCLR custom loss and training functions.

# **9.1. Pretext SimCLR**
---
* [Chen et al. (2020)](http://proceedings.mlr.press/v119/chen20j/chen20j.pdf) proposed a simple framework for contrastive learning of visual representations (SimCLR).
* This unsupervised contrastive model showed a notable accuracy boost on various standard visual data benchmarks such as CIFAR10, CIFAR100, and Caltech-101 after fine-tuning.
* The SimCLR model is composed of a multi-layer encoder, also referred to as the regressor, and a projector, which is usually a shallow neural network (i.e., dense layers).
* The encoder is later transferred for the downstream task.
* The SimCLR model enriches pretext models by maximizing the agreement between the representations of two augmentations of the same data point while pushing them away from the representations of augmentations of other data points in the batch.
* The idea is not limited to visual data and can be applied to data types such as temporal records.
* The following Figure shows the [Chen et al. (2020)](http://proceedings.mlr.press/v119/chen20j/chen20j.pdf) SimCLR pretext model along with a downstream classification model.

> <img src=	"https://raw.githubusercontent.com/mhrgroup/course_self_supervised_learning/main/images/simclr1.png"	width="500"/>
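
The encoder and projector arrangement described above can be sketched in Keras as follows. This is only a minimal illustration under stated assumptions: the function name `build_simclr_model`, the input shape, and all layer sizes are hypothetical and are not taken from the lecture or from Chen et al. (2020).

```python
import tensorflow as tf

def build_simclr_model(input_shape=(32, 32, 3), z_size=64):
    # Encoder: a small convolutional stack (hypothetical sizes). Its output is
    # the representation that is later transferred to the downstream task.
    encoder = tf.keras.Sequential([
        tf.keras.Input(shape=input_shape),
        tf.keras.layers.Conv2D(32, 3, activation="relu", padding="same"),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Conv2D(64, 3, activation="relu", padding="same"),
        tf.keras.layers.GlobalAveragePooling2D(),
    ], name="encoder")

    # Projector: a shallow stack of dense layers used only for the pretext task.
    projector = tf.keras.Sequential([
        tf.keras.Input(shape=(64,)),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(z_size),
    ], name="projector")

    # Pretext model: encoder followed by projector.
    inputs = tf.keras.Input(shape=input_shape)
    outputs = projector(encoder(inputs))
    model = tf.keras.Model(inputs, outputs, name="simclr_pretext")
    return model, encoder

model, encoder = build_simclr_model()
z = model(tf.zeros((4, 32, 32, 3)))        # projections used by the loss
h = encoder(tf.zeros((4, 32, 32, 3)))      # representations kept for transfer
print(z.shape, h.shape)
```

After pretext training, only `encoder` would be reused for the downstream model; the projector is discarded.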


* SimCLR's advantage is attributed to
 * Using two augmentations instead of one to obtain richer feature representations in the encoder (the encoder’s last hidden layer in the figure), and
 * Using a simple projector that applies a learnable transformation to the encoder’s representation for the pretext task.

* In a training batch, for a particular data point, the SimCLR loss function pulls the representations of its two augmentations closer together while pushing them away from the representations of other data points’ augmentations.

> <img src=	"https://raw.githubusercontent.com/mhrgroup/course_self_supervised_learning/main/images/simclralgorithm0.png"	width="400"/>
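
The loss described above can be written compactly, following the formulation in Chen et al. (2020). For a batch of $N$ data points, the projector produces $2N$ projections $z_1, \dots, z_{2N}$, where $z_{2k-1}$ and $z_{2k}$ come from the two augmentations of the same data point. With cosine similarity $\mathrm{sim}(u, v) = u^\top v / (\lVert u \rVert \, \lVert v \rVert)$ and temperature $\tau$:

$$
\ell(i, j) = -\log \frac{\exp\left(\mathrm{sim}(z_i, z_j)/\tau\right)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\left(\mathrm{sim}(z_i, z_k)/\tau\right)}
$$

$$
\mathcal{L} = \frac{1}{2N} \sum_{k=1}^{N} \left[ \ell(2k-1, 2k) + \ell(2k, 2k-1) \right]
$$

The numerator rewards agreement within a positive pair, and the denominator penalizes similarity to every other projection in the batch.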


# **9.2. SimCLR Loss and Training Functions**

---

* We first create a custom loss function.
* Next, we create a SimCLR training function.


## **9.2.1. Loss Function**

* Custom loss functions in TensorFlow (Keras) take two inputs: the real and the estimated output values.

* SimCLR is an unsupervised model, so we create some dummy real output values with the same shape as the predicted ones.  

* The estimated output values are the projector's representation, denoted as z_estimate.

* Remember that if N is the batch size, z_real and z_estimate each have 2*N rows (two augmentations per data point).

* Our goal is to avoid "for loops" in the loss so that TensorFlow can vectorize the computation, which speeds up training.

* The custom loss function should therefore be written with TensorFlow tensor operations.

> **Abbreviations:**
*	datain: input data
*	ind: index
* tf: tensorflow
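
The loss function in this section treats rows `2k` and `2k+1` of `z_estimate` as the two views of the same data point. The short NumPy sketch below illustrates how the 2N rows are ordered under two ways of combining the views. The arrays `view_1` and `view_2` are hypothetical stand-ins for the outputs of the two augmentation functions.

```python
import numpy as np

N = 3
# Hypothetical stand-ins for two augmented views of a batch of N = 3 items.
view_1 = np.array([[10.0], [20.0], [30.0]])
view_2 = np.array([[11.0], [21.0], [31.0]])

# Plain concatenation keeps all first views together, then all second views.
concatenated = np.concatenate([view_1, view_2], axis=0)

# Stacking on a new axis and reshaping interleaves the views, so rows
# (0, 1), (2, 3), (4, 5) each hold the two views of one data point.
interleaved = np.stack([view_1, view_2], axis=1).reshape(2 * N, -1)

print(concatenated[:, 0])
print(interleaved[:, 0])
```

Only the interleaved ordering matches a loss that reads positive pairs from even and odd rows.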

In [None]:
#@title Import necessary libraries
import tensorflow as tf
# tqdm provides a simple and convenient way to add progress bars to loops.
from tqdm import tqdm

In [None]:
#@title Custom loss function
# Updated on 20240506 for TensorFlow 2.15.0
@tf.function
def fun_simclr_loss(z_real, z_estimate):
    # The z_real parameter is not used since SimCLR is an unsupervised model.
    # This dummy variable is discarded.
    del z_real

    # Temperature parameter is set to 0.1. This hyperparameter can be tuned
    # for specific problems to control the smoothness of the output distribution.
    toe = .1

    # Calculate the number of projections, which is twice the batch size (2N).
    num = z_estimate.shape[0]

    # Generate two sets of indices for all possible pair combinations within the batch.
    # ind0 is a temporary variable holding the repeated range for generating indices.
    ind0 = tf.repeat(tf.expand_dims(tf.range(0, num), axis=0), num, axis=0)
    # Flatten ind0 to create a single list of indices for the first element of pairs.
    ind1 = tf.reshape(ind0, (num**2, 1))[:, 0]
    # Flatten the transpose of ind0 to create a single list of indices for the second element of pairs.
    ind2 = tf.reshape(tf.transpose(ind0), (num**2, 1))[:, 0]

    # The temporary index tensor is no longer needed after use.
    del ind0

    # Select the projections based on the first set of indices to form the first elements of pairs.
    vector_1 = tf.gather(z_estimate, ind1, axis=0)
    # The first set of indices is no longer needed after use.
    del ind1

    # Select the projections based on the second set of indices to form the second elements of pairs.
    vector_2 = tf.gather(z_estimate, ind2, axis=0)
    # The second set of indices is no longer needed after use.
    del ind2

    # tf.keras.losses.cosine_similarity returns the NEGATIVE cosine similarity, so negate it
    # to recover the actual cosine similarity between each pair of projections.
    s = -tf.reshape(tf.keras.losses.cosine_similarity(vector_1, vector_2, axis=1), (num, num))

    # The vector tensors are no longer needed after computing similarities.
    del vector_1
    del vector_2

    # Calculate the numerator of the contrastive loss function for each pair.
    nom = tf.exp(s / toe)

    # Calculate the denominator of the contrastive loss function by excluding the self-similarity term.
    x1 = tf.exp(s / toe)
    x2 = 1 - tf.eye(num, dtype=tf.float32)
    denom = tf.repeat(tf.expand_dims(tf.math.reduce_sum(x1 * x2, axis=1), axis=1), num, axis=1)

    # The intermediate tensors used for the denominator are no longer needed.
    del s
    del x1
    del x2

    # Calculate the loss `l(i, j)` for each pair of projections using the computed numerator and denominator.
    l = -tf.math.log(nom / denom)

    # Cleanup intermediate tensors to save memory.
    del nom
    del denom

    # Prepare indices to extract the diagonal elements from the loss matrix, which correspond to the actual pair losses.
    ind_2k0 = tf.range(0, num, 2, dtype=tf.int32)  # Even indices
    ind_2k1 = tf.range(1, num, 2, dtype=tf.int32)  # Odd indices

    # Extract and sum the losses for the actual pairs using the prepared indices.
    loss_mat1_1 = tf.gather(l, ind_2k0, axis=0)
    loss_mat1_2 = tf.gather(loss_mat1_1, ind_2k1, axis=1)
    loss_mat1 = tf.linalg.diag_part(loss_mat1_2)

    loss_mat2_1 = tf.gather(l, ind_2k1, axis=0)
    loss_mat2_2 = tf.gather(loss_mat2_1, ind_2k0, axis=1)
    loss_mat2 = tf.linalg.diag_part(loss_mat2_2)

    # After extracting the necessary elements, the large loss tensor can be discarded.
    del l

    # Combine the individual loss components into a single tensor.
    loss_mat = loss_mat1 + loss_mat2

    # Compute the final loss by summing all individual pair losses and normalizing by the number of projections (2N).
    L = tf.math.reduce_sum(loss_mat) / num


    # Return the final computed loss.
    return L
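
To check the vectorized loss on small inputs, it can help to compare it with a slow, loop-based reference. The NumPy sketch below is one such reference, assuming the same conventions as the function above: rows `2k` and `2k+1` form a positive pair, and the temperature is 0.1. The name `simclr_loss_reference` is hypothetical and not part of the course code.

```python
import numpy as np

def simclr_loss_reference(z, temperature=0.1):
    z = np.asarray(z, dtype=np.float64)
    num = z.shape[0]  # number of projections, 2N
    # Cosine similarity matrix between all projections.
    z_unit = z / np.linalg.norm(z, axis=1, keepdims=True)
    s = z_unit @ z_unit.T

    def pair_loss(i, j):
        # Denominator sums over every projection except i itself.
        denom = sum(np.exp(s[i, k] / temperature) for k in range(num) if k != i)
        return -np.log(np.exp(s[i, j] / temperature) / denom)

    total = 0.0
    for k in range(num // 2):
        # Both directions of each positive pair contribute to the loss.
        total += pair_loss(2 * k, 2 * k + 1) + pair_loss(2 * k + 1, 2 * k)
    return total / num

# Positive pairs aligned (rows 0-1 and rows 2-3 identical) versus misaligned.
z_good = [[1, 0], [1, 0], [0, 1], [0, 1]]
z_bad = [[1, 0], [0, 1], [1, 0], [0, 1]]
loss_good = simclr_loss_reference(z_good)
loss_bad = simclr_loss_reference(z_bad)
print(loss_good, loss_bad)
```

The loss is close to zero when paired rows agree and large when each row's nearest neighbour is outside its pair, which is the behaviour the contrastive objective is meant to produce.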

In [None]:
#@title SimCLR training function
# Updated on 20240506 for TensorFlow 2.15.0
# When the videos were recorded, we used TensorFlow version 2.9.0.
# The latest version of TensorFlow is 2.15.0 as of 20240506. As such, we updated
# our code significantly to increase its efficiency and align it with
# TensorFlow version 2.15.0. The code is commented so that students can
# study the updates.

def fun_train_simclr(model, images, fun_augment_01, fun_augment_02,
                     epochs=100, batch_size = 16, verbose=1, patience=3, learning_rate=0.001):

    # Determine the output size of the model's last layer
    z_size = model.layers[-1].weights[-1].shape[0]

    # Initialize a list to keep track of the loss values for each epoch
    loss = []

    # Initialize the optimizer with the specified learning rate
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    # Create a TensorFlow Dataset
    dataset = tf.data.Dataset.from_tensor_slices(images)
    dataset = dataset.cache()  # Cache before shuffling so that each epoch is reshuffled rather than replaying one cached order
    dataset = dataset.shuffle(buffer_size=10000)  # Adjust buffer size to your dataset size/available memory
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)  # Prefetch batches in the background

    # Loop through each epoch
    for epoch in range(epochs):
        # Initialize a variable to keep track of the running loss for the current epoch
        loss_running = 0

        # Initialize a progress bar using tqdm to visualize training progress
        pbar = tqdm(dataset,
                    desc  = f"SimCLR Training: {epoch + 1:03d}/{epochs}", # Description shown in the progress bar
                    ncols = 125, # Width of the progress bar
                    leave = True) # Whether the progress bar should remain after completion

        # Iterate over each batch in the dataset
        for batch_num, batch in enumerate(pbar, 1):
            if isinstance(batch, tuple):
                batch_in = batch[0] # Get the input data from the batch
            else:
                batch_in = batch

            # Apply two different augmentations to the input data
            x_tilda_01 = fun_augment_01(batch_in)
            x_tilda_02 = fun_augment_02(batch_in)

            # Interleave the two augmented views along the batch axis so that rows (2k, 2k+1)
            # hold the two views of the same image, which is the pairing fun_simclr_loss assumes.
            # (A plain tf.concat on axis 0 would place the views of one image N rows apart.)
            x_tilda = tf.reshape(tf.stack([x_tilda_01, x_tilda_02], axis=1),
                                (x_tilda_01.shape[0] * 2,
                                 x_tilda_01.shape[1], x_tilda_01.shape[2],
                                 x_tilda_01.shape[3]))

            # Create a dummy output tensor of the same size as the model's output
            z_real = tf.random.uniform((x_tilda.shape[0], z_size))

            # Open a GradientTape to record the operations for automatic differentiation
            with tf.GradientTape() as tape:
                # Pass the concatenated and augmented inputs through the model
                z_estimate = model(x_tilda, training=True)

                # Calculate the batch loss using the custom SimCLR loss function
                loss_batch = fun_simclr_loss(z_real, z_estimate)

            # Calculate the gradients of the loss with respect to the model's trainable variables
            gradients = tape.gradient(loss_batch, model.trainable_variables)

            # Apply the calculated gradients to the model's variables to minimize the loss
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))

            # Update the running loss for the epoch
            loss_running += loss_batch.numpy()  # Ensure loss is a scalar by calling .numpy()

            # If verbose, update the progress bar with the current average loss
            if verbose:
                pbar.set_postfix(loss = f"{loss_running/batch_num:.6f}")

        # Append the average loss for this epoch to the loss history
        loss.append(loss_running / batch_num)

        # Check for early stopping: stop if none of the last `patience` epochs improved on the best earlier loss
        if patience > 0 and epoch >= patience:
            if min(loss[-patience:]) >= min(loss[:-patience]):
                print("Early stopping due to no improvement in loss.")
                break  # Exit the training loop

    # Return the trained model and the history of loss values
    return model, loss
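
The patience-based early stopping used in the training function can be isolated as a small helper, which makes the rule easier to test. This is a minimal sketch, assuming that "no improvement" means none of the last `patience` epochs achieved a lower loss than the best earlier epoch. The name `should_stop` is hypothetical.

```python
def should_stop(loss_history, patience=3):
    # Not enough epochs yet to compare recent losses against earlier ones.
    if patience < 1 or len(loss_history) <= patience:
        return False
    best_before = min(loss_history[:-patience])   # best loss before the window
    best_recent = min(loss_history[-patience:])   # best loss inside the window
    # Stop only when the recent window never beat the earlier best.
    return best_recent >= best_before

print(should_stop([1.0, 0.9, 0.8, 0.7]))        # loss still decreasing
print(should_stop([0.5, 0.6, 0.7, 0.6]))        # three epochs without improvement
print(should_stop([1.0, 0.5, 2.0, 0.3, 0.6]))   # improved two epochs ago
```

Comparing the whole recent window, rather than only the latest epoch, avoids stopping right after a single noisy epoch that follows a genuine improvement.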

# **Lecture 09: SimCLR, An Unsupervised Contrastive Model**

In this lecture, you learned about:

1. How SimCLR of [Chen et al. (2020)](http://proceedings.mlr.press/v119/chen20j/chen20j.pdf) works.
2. SimCLR custom loss and training functions.

> ***In the following lecture, we will see a labeling example using SimCLR unsupervised contrastive learning.***