<a href="https://colab.research.google.com/github/mhrgroup/course_self_supervised_learning/blob/main/Section%2004%3A%20Self-Supervised%20Learning/ssl_section04_lecture10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Lecture 10: SimCLR Experiment**

By the end of this lecture, you will be able to:

1. Address a labeling problem with SimCLR using a pretrained encoder.

# **10.1. Experiment**
---
* This experiment follows the same setup as the previous ones.
* We assume only 1000 training images in CIFAR-10 are labeled.
* We develop a SimCLR pretext (prx) model using all training inputs and fine-tune it on the 1000 labeled images in the downstream (dwm) task.
* We then label the testing images using the fine-tuned model.
* We assume that an encoder already trained on a similar data distribution is available. Hence, our model starts from pretrained parameters.
* We compare this model with a comparable fully supervised (fsp) model trained only on the 1000 labeled images.
* Note: we have to develop three models here: model_fsp, model_prx, and model_dwm.

> **Abbreviations:**
* acc: accuracy
* datain: input data
* dataou: output data
* dwm: downstream
* fnt: fine-tuning
* fsp: fully supervised learning
* ind: index
* lr: learning rate
* prx: pretext
* te: testing
* tf: TensorFlow
* tr: training
* trf: transfer learning

In [None]:
#@title Import necessary libraries
import tensorflow as tf
import copy
import warnings

from IPython.display import clear_output
# tqdm provides a simple and convenient way to add progress bars to loops.
from tqdm import tqdm

warnings.filterwarnings('ignore')

clear_output()

In [None]:
#@title Hyper-parameters
num_labeled  = 1000

# learning rates
lr_fsp_trf   = 0.01
lr_fsp_fnt   = 0.0001

lr_prx_trf   = 0.01
lr_prx_fnt   = 0.000001

lr_dwm_trf   = 0.01
lr_dwm_fnt   = 0.0001


# batch sizes
batch_fsp_trf  = 64
batch_fsp_fnt  = 64

batch_prx_trf  = 32
batch_prx_fnt  = 32

batch_dwm_trf  = 64
batch_dwm_fnt  = 64


# epochs
epoch_fsp_trf  = 15
epoch_fsp_fnt  = 10

epoch_prx_trf  = 15
epoch_prx_fnt  = 10

epoch_dwm_trf  = 15
epoch_dwm_fnt  = 10
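
To get a feel for what these settings imply, the number of gradient updates per epoch is the dataset size divided by the batch size, rounded up. The sketch below is plain arithmetic with the values from this cell, assuming the pretext stage sees all 50,000 CIFAR-10 training images and the other two stages see only the 1000 labeled ones. `steps_per_epoch` is a helper name introduced here for illustration only.

```python
import math

def steps_per_epoch(num_samples, batch_size):
    # The last batch may be smaller than batch_size, hence the ceiling.
    return math.ceil(num_samples / batch_size)

steps_prx = steps_per_epoch(50_000, 32)  # pretext: all training inputs
steps_fsp = steps_per_epoch(1_000, 64)   # fully supervised: labeled subset
steps_dwm = steps_per_epoch(1_000, 64)   # downstream: labeled subset

print(steps_prx, steps_fsp, steps_dwm)
```

The pretext stage therefore performs far more updates per epoch than the supervised stages, which is one reason it dominates the run time of this notebook.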


In [None]:
#@title Load and process the CIFAR-10 data
(datain_tr, dataou_tr), (datain_te, dataou_te) = tf.keras.datasets.cifar10.load_data()

datain_tr = datain_tr/255 # scale uint8 pixel values to the range [0, 1]
datain_te = datain_te/255 # scale uint8 pixel values to the range [0, 1]

dataou_tr = tf.keras.utils.to_categorical(dataou_tr)
dataou_te = tf.keras.utils.to_categorical(dataou_te)

print('Shape of datain_tr: {}'.format(datain_tr.shape))
print('Shape of datain_te: {}'.format(datain_te.shape))
print('Shape of dataou_tr: {}'.format(dataou_tr.shape))
print('Shape of dataou_te: {}'.format(dataou_te.shape))
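
As a minimal sketch of what the one-hot step produces, assuming NumPy is available: CIFAR-10 labels arrive as integers with shape `(N, 1)`, and one-hot encoding turns them into rows of length 10 with a single 1. The toy labels below are made up for illustration, and the indexing trick with `np.eye` only approximates what `to_categorical` does for this simple case.

```python
import numpy as np

# Toy integer labels with the same (N, 1) layout as CIFAR-10 labels.
labels = np.array([[3], [0], [9]])
num_classes = 10

# Row i of the identity matrix is the one-hot vector for class i.
one_hot = np.eye(num_classes, dtype="float32")[labels.ravel()]

print(one_hot.shape)
print(one_hot.argmax(axis=1))  # recovers the integer labels
```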


In [None]:
#@title Pick two augmentation functions

# Crop and Resize
fun_augment_a  = tf.keras.layers.RandomCrop(height = 20, width = 20)
fun_augment_b  = tf.keras.layers.Resizing(height = datain_tr.shape[1],
                                          width = datain_tr.shape[2])

fun_augment_01 = tf.keras.Sequential([fun_augment_a, fun_augment_b])

# Random rotation
fun_augment_02     = tf.keras.layers.RandomRotation(factor = 0.2)

In [None]:
#@title Limit the labeled training data

# Randomly select num_labeled distinct training images (sampling without
# replacement, so no image is counted twice in the labeled budget)
index_tr = tf.random.shuffle(tf.range(datain_tr.shape[0]))[:num_labeled].numpy()

datain_tr_labeled = datain_tr[index_tr,:,:,:]
dataou_tr_labeled = dataou_tr[index_tr,:]

datain_tr_fsp = copy.deepcopy(datain_tr_labeled)
dataou_tr_fsp = copy.deepcopy(dataou_tr_labeled)

datain_tr_prx = copy.deepcopy(datain_tr)

datain_tr_dwm = copy.deepcopy(datain_tr_labeled)
dataou_tr_dwm = copy.deepcopy(dataou_tr_labeled)

# We have 50,000 training inputs; num_labeled of them are labeled
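
When picking a labeled subset, drawing indices with replacement can return the same image more than once, so the number of distinct labeled images may end up slightly below `num_labeled`. The sketch below uses only the Python standard library to compare the two sampling schemes. The seed and variable names are assumptions made for this illustration.

```python
import random

num_train = 50_000
num_labeled = 1_000

rng = random.Random(0)  # fixed seed, chosen here only for reproducibility

# With replacement: the same index can be drawn more than once.
idx_with = [rng.randrange(num_train) for _ in range(num_labeled)]

# Without replacement: every index is distinct by construction.
idx_without = rng.sample(range(num_train), num_labeled)

print(len(set(idx_with)), len(set(idx_without)))
```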


In [None]:
#@title Custom loss function
# Updated on 20240506 for TensorFlow 2.15.0
@tf.function
def fun_simclr_loss(z_real, z_estimate):
    # The z_real parameter is not used since SimCLR is an unsupervised model.
    # This dummy variable is discarded.
    del z_real

    # Temperature parameter is set to 0.1. This hyperparameter can be tuned
    # for specific problems to control the smoothness of the output distribution.
    toe = .1

    # Calculate the number of projections, which is twice the batch size (2N).
    num = z_estimate.shape[0]

    # Generate two sets of indices for all possible pair combinations within the batch.
    # ind0 is a temporary variable holding the repeated range for generating indices.
    ind0 = tf.repeat(tf.expand_dims(tf.range(0, num), axis=0), num, axis=0)
    # Flatten ind0 to create a single list of indices for the first element of pairs.
    ind1 = tf.reshape(ind0, (num**2, 1))[:, 0]
    # Flatten the transpose of ind0 to create a single list of indices for the second element of pairs.
    ind2 = tf.reshape(tf.transpose(ind0), (num**2, 1))[:, 0]

    # The temporary index tensor is no longer needed after use.
    del ind0

    # Select the projections based on the first set of indices to form the first elements of pairs.
    vector_1 = tf.gather(z_estimate, ind1, axis=0)
    # The first set of indices is no longer needed after use.
    del ind1

    # Select the projections based on the second set of indices to form the second elements of pairs.
    vector_2 = tf.gather(z_estimate, ind2, axis=0)
    # The second set of indices is no longer needed after use.
    del ind2

    # tf.keras.losses.cosine_similarity returns the negative cosine similarity, so negating it gives the pairwise similarity matrix s.
    s = -tf.reshape(tf.keras.losses.cosine_similarity(vector_1, vector_2, axis=1), (num, num))

    # The vector tensors are no longer needed after computing similarities.
    del vector_1
    del vector_2

    # Calculate the numerator of the contrastive loss function for each pair.
    nom = tf.exp(s / toe)

    # Calculate the denominator of the contrastive loss function by excluding the self-similarity term.
    x1 = tf.exp(s / toe)
    x2 = 1 - tf.eye(num, dtype=tf.float32)
    denom = tf.repeat(tf.expand_dims(tf.math.reduce_sum(x1 * x2, axis=1), axis=1), num, axis=1)

    # The intermediate tensors used for the denominator are no longer needed.
    del s
    del x1
    del x2

    # Calculate the loss `l(i, j)` for each pair of projections using the computed numerator and denominator.
    l = -tf.math.log(nom / denom)

    # Cleanup intermediate tensors to save memory.
    del nom
    del denom

    # Positive pairs are assumed to occupy adjacent rows (2k, 2k+1) of z_estimate; build the even and odd index sets to pick out l(2k, 2k+1) and l(2k+1, 2k).
    ind_2k0 = tf.range(0, num, 2, dtype=tf.int32)  # Even indices
    ind_2k1 = tf.range(1, num, 2, dtype=tf.int32)  # Odd indices

    # Extract and sum the losses for the actual pairs using the prepared indices.
    loss_mat1_1 = tf.gather(l, ind_2k0, axis=0)
    loss_mat1_2 = tf.gather(loss_mat1_1, ind_2k1, axis=1)
    loss_mat1 = tf.linalg.diag_part(loss_mat1_2)

    loss_mat2_1 = tf.gather(l, ind_2k1, axis=0)
    loss_mat2_2 = tf.gather(loss_mat2_1, ind_2k0, axis=1)
    loss_mat2 = tf.linalg.diag_part(loss_mat2_2)

    # After extracting the necessary elements, the large loss tensor can be discarded.
    del l

    # Combine the individual loss components into a single tensor.
    loss_mat = loss_mat1 + loss_mat2

    # Compute the final loss by summing the positive-pair losses and dividing by the number of projections (2N).
    L = tf.math.reduce_sum(loss_mat) / num


    # Return the final computed loss.
    return L
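
In formula form, the function above computes, for cosine similarity $s_{i,j}$ and temperature $\tau$,

$$\ell(i,j) = -\log\frac{\exp(s_{i,j}/\tau)}{\sum_{k \neq i}\exp(s_{i,k}/\tau)}, \qquad L = \frac{1}{2N}\sum_{k=0}^{N-1}\big[\ell(2k,2k+1) + \ell(2k+1,2k)\big].$$

The loop-based sketch below is a plain-Python reference for the same quantity, which can help when checking the vectorized version on tiny inputs. It is an illustration under the assumption that rows $2k$ and $2k+1$ hold the two views of the same image. The names `cosine` and `simclr_loss_reference` and the toy vectors are introduced here and are not part of the notebook.

```python
import math

def cosine(u, v):
    # Cosine similarity between two plain Python vectors.
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def simclr_loss_reference(z, tau=0.1):
    # z holds 2N projections; rows 2k and 2k+1 are assumed to be a positive pair.
    num = len(z)
    s = [[cosine(z[i], z[j]) for j in range(num)] for i in range(num)]

    def pair_loss(i, j):
        # The denominator excludes the self-similarity term k == i.
        denom = sum(math.exp(s[i][k] / tau) for k in range(num) if k != i)
        return -math.log(math.exp(s[i][j] / tau) / denom)

    total = 0.0
    for k in range(0, num, 2):
        total += pair_loss(k, k + 1) + pair_loss(k + 1, k)
    return total / num

# Toy projections: adjacent rows agree in `aligned` and disagree in `mixed`.
aligned = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
mixed = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]

loss_aligned = simclr_loss_reference(aligned)
loss_mixed = simclr_loss_reference(mixed)
print(loss_aligned, loss_mixed)
```

The loss is close to zero when adjacent rows are identical and large when the matching rows are not adjacent, which shows how much the result depends on the row ordering of the batch.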

In [None]:
#@title SimCLR training function
# Updated on 20240506 for TensorFlow 2.15.0
# When the videos were recorded, we used TensorFlow 2.9.0.
# The latest version of TensorFlow is 2.15.0 as of 20240506. As such, we updated
# our code significantly to increase its efficiency and align it with
# TensorFlow 2.15.0. The code is commented so that students can
# study the updates.

def fun_train_simclr(model, images, fun_augment_01, fun_augment_02,
                     epochs=100, batch_size = 16, verbose=1, patience=3, learning_rate=0.001):

    # Determine the output size of the model's last layer
    z_size = model.layers[-1].weights[-1].shape[0]

    # Initialize a list to keep track of the loss values for each epoch
    loss = []

    # Initialize the optimizer with the specified learning rate
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    # Create a TensorFlow Dataset
    dataset = tf.data.Dataset.from_tensor_slices(images)
    dataset = dataset.cache()  # Cache before shuffling so each epoch is reshuffled instead of replaying one cached order
    dataset = dataset.shuffle(buffer_size=10000)  # Adjust buffer size to your dataset size/available memory
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)  # Prefetch batches in the background

    # Loop through each epoch
    for epoch in range(epochs):
        # Initialize a variable to keep track of the running loss for the current epoch
        loss_running = 0

        # Initialize a progress bar using tqdm to visualize training progress
        pbar = tqdm(dataset,
                    desc  = f"SimCLR Training: {epoch + 1:03d}/{epochs}", # Description shown in the progress bar
                    ncols = 125, # Width of the progress bar
                    leave = True) # Whether the progress bar should remain after completion

        # Iterate over each batch in the dataset
        for batch_num, batch in enumerate(pbar, 1):
            if isinstance(batch, tuple):
                batch_in = batch[0] # Get the input data from the batch
            else:
                batch_in = batch

            # Apply two different augmentations to the input data
            # (training=True keeps the random layers active outside model.fit)
            x_tilda_01 = fun_augment_01(batch_in, training=True)
            x_tilda_02 = fun_augment_02(batch_in, training=True)

            # Interleave the two views so that rows (2k, 2k+1) hold the two
            # augmentations of image k, which is the pairing fun_simclr_loss expects
            x_tilda = tf.reshape(tf.stack([x_tilda_01, x_tilda_02], axis=1),
                                (x_tilda_01.shape[0] * 2,
                                 x_tilda_01.shape[1], x_tilda_01.shape[2],
                                 x_tilda_01.shape[3]))

            # Create a dummy output tensor of the same size as the model's output
            z_real = tf.random.uniform((x_tilda.shape[0], z_size))

            # Open a GradientTape to record the operations for automatic differentiation
            with tf.GradientTape() as tape:
                # Pass the concatenated and augmented inputs through the model
                z_estimate = model(x_tilda, training=True)

                # Calculate the batch loss using the custom SimCLR loss function
                loss_batch = fun_simclr_loss(z_real, z_estimate)

            # Calculate the gradients of the loss with respect to the model's trainable variables
            gradients = tape.gradient(loss_batch, model.trainable_variables)

            # Apply the calculated gradients to the model's variables to minimize the loss
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))

            # Update the running loss for the epoch
            loss_running += loss_batch.numpy()  # Ensure loss is a scalar by calling .numpy()

            # If verbose, update the progress bar with the current average loss
            if verbose:
                pbar.set_postfix(loss = f"{loss_running/batch_num:.6f}")

        # Append the average loss for this epoch to the loss history
        loss.append(loss_running / batch_num)

        # Early stopping: stop if the latest loss is worse than the best loss recorded at least `patience` epochs earlier
        if epoch >= patience:
            if loss[-1] > min(loss[:-patience]):
                print("Early stopping due to no improvement in loss.")
                break  # Exit the training loop

    # Return the trained model and the history of loss values
    return model, loss
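
Because the loss reads positive pairs from rows $(2k, 2k+1)$, the order in which the two augmented views are merged into one batch matters. The NumPy sketch below, with toy arrays standing in for augmented images, contrasts stacking on a new axis followed by a reshape (which interleaves the views) with plain concatenation along the batch axis (which keeps the views in separate halves). All names and values here are assumptions for illustration.

```python
import numpy as np

n, h, w, c = 3, 2, 2, 1

# Toy "views": image k is filled with the value k (view a) or 100 + k (view b).
view_a = np.arange(n, dtype="float32").reshape(n, 1, 1, 1) * np.ones((1, h, w, c), dtype="float32")
view_b = view_a + 100.0

# Stack on a new axis 1, then merge it into the batch axis: a0, b0, a1, b1, ...
interleaved = np.stack([view_a, view_b], axis=1).reshape(2 * n, h, w, c)
ids_interleaved = interleaved[:, 0, 0, 0]

# Plain concatenation along the batch axis: a0, a1, a2, b0, b1, b2
concatenated = np.concatenate([view_a, view_b], axis=0)
ids_concatenated = concatenated[:, 0, 0, 0]

print(ids_interleaved)
print(ids_concatenated)
```

With the interleaved layout, rows 0 and 1 are the two views of image 0, rows 2 and 3 are the two views of image 1, and so on. With concatenation, the two views of image k sit at rows k and N + k instead.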

In [None]:
#@title Create model_fsp and model_prx based on DenseNet121

layerin = tf.keras.Input(shape=(datain_tr.shape[1],
                                datain_tr.shape[2],
                                datain_tr.shape[3]))

upscale = tf.keras.layers.Lambda(lambda x: tf.image.resize_with_pad(x,
                                                                    160,
                                                                    160,
                                                                    method=tf.image.ResizeMethod.BILINEAR))(layerin)

model_DenseNet121 = tf.keras.applications.DenseNet121(include_top  = False,
                                                      weights      = "imagenet",
                                                      input_shape  = (160,160,3),
                                                      input_tensor = upscale,
                                                      pooling      = 'max')

model_base_fsp =  tf.keras.models.clone_model(model_DenseNet121)
model_base_prx =  tf.keras.models.clone_model(model_DenseNet121) # encoder

model_base_fsp.set_weights(model_DenseNet121.get_weights())
model_base_prx.set_weights(model_DenseNet121.get_weights())


layer_batchnorm_fsp = tf.keras.layers.BatchNormalization()
layer_batchnorm_prx = tf.keras.layers.BatchNormalization()

# Create the SimCLR projector.

layers_dense_prx = [tf.keras.Input(shape=(1024,)),  # size of the pooled DenseNet121 features
                    tf.keras.layers.Dense(512, activation = 'relu'),
                    tf.keras.layers.Dense(128, activation = 'relu')]

model_projector = tf.keras.Sequential(layers_dense_prx)

# Create the output layer of model_fsp.

layerou_fsp = tf.keras.layers.Dense(dataou_tr_fsp.shape[-1], activation = 'softmax')
#layerou_prx = tf.keras.layers.Dense(dataou_tr_prx.shape[-1], activation = 'softmax')


# Create model_fsp and model_prx.

model_fsp   = tf.keras.models.Sequential([model_base_fsp,
                                          layer_batchnorm_fsp,
                                          layerou_fsp])

model_prx   = tf.keras.models.Sequential([model_base_prx,
                                          layer_batchnorm_prx,
                                          model_projector])


In [None]:
#@title Train the fsp model using transfer learning and fine-tuning

# Transfer learning
model_base_fsp.trainable      = False
layer_batchnorm_fsp.trainable = False

model_fsp.compile(optimizer = tf.keras.optimizers.Adam(lr_fsp_trf),
                  loss      = 'categorical_crossentropy',
                  metrics   = ['accuracy'])

# Keep NumPy copies of the untrained output-layer weights for reuse in model_dwm
layerou_fsp_initial_parameters = model_fsp.layers[2].get_weights()

model_fsp.summary()

history_fsp_trf = model_fsp.fit(datain_tr_fsp,
                                dataou_tr_fsp,
                                epochs           = epoch_fsp_trf,
                                batch_size       = batch_fsp_trf,
                                verbose          = 1,
                                shuffle          = True)

# Fine-tuning

model_base_fsp.trainable      = True
layer_batchnorm_fsp.trainable = True

model_fsp.compile(optimizer = tf.keras.optimizers.Adam(lr_fsp_fnt),
                  loss      = 'categorical_crossentropy',
                  metrics   = ['accuracy'])


model_fsp.summary()

history_fsp_fnt = model_fsp.fit(datain_tr_fsp,
                                dataou_tr_fsp,
                                epochs           = epoch_fsp_fnt,
                                batch_size       = batch_fsp_fnt,
                                verbose          = 1,
                                shuffle          = True)

In [None]:
#@title Train the prx model using transfer learning and fine-tuning
# Updated on 20240506
# Transfer learning
model_base_prx.trainable      = False
layer_batchnorm_prx.trainable = False

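# Note: fun_train_simclr builds its own optimizer and calls fun_simclr_loss
# directly, so this compile call does not affect the pretext training itself.
model_prx.compile(optimizer = tf.keras.optimizers.Adam(lr_prx_trf),
                  loss      = fun_simclr_loss,
                  metrics   = ['mean_squared_error'])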

model_prx.summary()

model_prx, _ = fun_train_simclr(model_prx,
                                datain_tr_prx,
                                fun_augment_01,
                                fun_augment_02,
                                epochs     = epoch_prx_trf,
                                batch_size = batch_prx_trf,
                                verbose    = 1,
                                patience   = 1,
                                learning_rate = lr_prx_trf)


# Fine-tuning

model_base_prx.trainable      = True
layer_batchnorm_prx.trainable = True

model_prx.compile(optimizer = tf.keras.optimizers.Adam(lr_prx_fnt),
                  loss      = fun_simclr_loss,
                  metrics   = ['mean_squared_error'])

model_prx.summary()

model_prx, _ = fun_train_simclr(model_prx,
                                datain_tr_prx,
                                fun_augment_01,
                                fun_augment_02,
                                epochs     = epoch_prx_fnt,
                                batch_size = batch_prx_fnt,
                                verbose    = 1,
                                patience   = 1,
                                learning_rate = lr_prx_fnt)


In [None]:
#@title Create and train dwm model using transfer-learning and fine-tuning

layerou_dwm = tf.keras.layers.Dense(dataou_tr_dwm.shape[-1],
                                    activation = 'softmax')

model_base_prx.trainable      = False
layer_batchnorm_prx.trainable = False

model_dwm   = tf.keras.models.Sequential([model_base_prx,
                                          layer_batchnorm_prx,
                                          layerou_dwm])

model_dwm.layers[2].set_weights(layerou_fsp_initial_parameters)

# Transfer learning
model_dwm.compile(optimizer = tf.keras.optimizers.Adam(lr_dwm_trf),
                  loss      = 'categorical_crossentropy',
                  metrics   = ['accuracy'])

model_dwm.summary()

history_dwm_trf = model_dwm.fit(datain_tr_dwm,
                                dataou_tr_dwm,
                                epochs           = epoch_dwm_trf,
                                batch_size       = batch_dwm_trf,
                                verbose          = 1,
                                shuffle          = True)

# Fine-tuning
model_base_prx.trainable      = True
layer_batchnorm_prx.trainable = True

# Optionally, fine-tune only the layers after a chosen model_base_prx layer:
# fine_tune_after = 430
# for layer in model_base_prx.layers[:fine_tune_after]:
#   layer.trainable = False

model_dwm.compile(optimizer = tf.keras.optimizers.Adam(lr_dwm_fnt),
                  loss      = 'categorical_crossentropy',
                  metrics   = ['accuracy'])

model_dwm.summary()

history_dwm_fnt = model_dwm.fit(datain_tr_dwm,
                                dataou_tr_dwm,
                                epochs           = epoch_dwm_fnt,
                                batch_size       = batch_dwm_fnt,
                                verbose          = 1,
                                shuffle          = True)


In [None]:
#@title Compute model_fsp and model_dwm testing accuracies
_, acc_te_fsp = model_fsp.evaluate(datain_te,
                                   dataou_te,
                                   batch_size = 128)

_, acc_te_dwm = model_dwm.evaluate(datain_te,
                                   dataou_te,
                                   batch_size = 128)

print('Accuracy of fsp: {:05.2f}%'.format(acc_te_fsp*100))
print('Accuracy of dwm: {:05.2f}%'.format(acc_te_dwm*100))
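
The `accuracy` metric used with one-hot labels boils down to comparing the index of the largest predicted probability with the index of the 1 in the label. A minimal plain-Python sketch of that computation follows. The helper names and the toy predictions are made up for illustration and are not outputs of the models above.

```python
def argmax(row):
    # Index of the largest entry in a list.
    return max(range(len(row)), key=lambda i: row[i])

def categorical_accuracy(y_true_one_hot, y_prob):
    # Fraction of samples whose predicted class matches the labeled class.
    hits = sum(argmax(t) == argmax(p) for t, p in zip(y_true_one_hot, y_prob))
    return hits / len(y_true_one_hot)

# Toy one-hot labels and predicted class probabilities for 4 samples, 3 classes.
y_true = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0]]
y_prob = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.3, 0.4, 0.3], [0.2, 0.5, 0.3]]

acc = categorical_accuracy(y_true, y_prob)
print('Accuracy: {:05.2f}%'.format(acc * 100))
```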

In [None]:
#@title Clean up memory
%reset

# **Lecture 10: SimCLR, An Unsupervised Contrastive Pretext, Experiment**


In this lecture, you learned about:

1. A labeling problem with SimCLR using a pretrained encoder.


> ***Congratulations on completing this course!***

> * I hope this helps you learn SSL and how to apply it for labeling tasks.
* You can use the same idea in data domains other than the image domain, such as temporal records and natural language processing.
* $\color{red}{\text{Please rate this course and write a review.}}$

Stay safe and sound!

Mohammad H. Rafiei, Ph.D.

* https://ep.jhu.edu/faculty/mohammad-rafiei/
* https://scholar.google.com/citations?user=74pUQ3sAAAAJ&hl=en













