<a href="https://colab.research.google.com/github/mhrgroup/course_self_supervised_learning/blob/main/Section%2004%3A%20Self-Supervised%20Learning/ssl_section04_lecture09.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Lecture 09: SimCLR, An Unsupervised Contrastive Pretext Model**

By the end of this lecture, you will be able to:

1. Describe how SimCLR of [Chen et al. (2020)](http://proceedings.mlr.press/v119/chen20j/chen20j.pdf) works. 
2. Develop SimCLR custom loss and training functions.

# **9.1. Pretext SimCLR**
---
* [Chen et al. (2020)](http://proceedings.mlr.press/v119/chen20j/chen20j.pdf) proposed SimCLR, a simple framework for contrastive learning of visual representations.
* After fine-tuning, this unsupervised contrastive model showed a notable accuracy boost on various standard visual data benchmarks such as CIFAR10, CIFAR100, and Caltech-101.
* The SimCLR model is composed of a multi-layer encoder, also referred to as the regressor, and a projector, which is usually a shallow neural network (i.e., dense layers).
* After pretext training, the encoder is transferred to the downstream task.
* The SimCLR model enriches pretext models by maximizing the agreement between the representations of two augmentations of the same data point while pushing them away from the representations of augmentations of other data points in the batch.
* The idea is not limited to visual data and can be applied to other data types such as temporal records.
* The following figure shows the [Chen et al. (2020)](http://proceedings.mlr.press/v119/chen20j/chen20j.pdf) SimCLR pretext model along with a downstream classification model.

> <img src=	"https://raw.githubusercontent.com/mhrgroup/course_self_supervised_learning/main/images/simclr1.png"	width="500"/>


* SimCLR's advantage is attributed to
 * Using two augmentations instead of one to obtain richer feature representations in the encoder (the encoder's last hidden layer in the figure), and
 * Using a simple projector that transforms the encoder's representation before the pretext loss is applied.

* In a training batch, for a particular data point, the SimCLR loss function pulls the representations of its two augmentations closer together while pushing them away from the representations of other data points' augmentations.

> <img src=	"https://raw.githubusercontent.com/mhrgroup/course_self_supervised_learning/main/images/simclralgorithm0.png"	width="400"/>
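
For reference, the loss of Chen et al. (2020) can be written out as below. This is a sketch in the paper's notation, which is an assumption about symbols not defined in this lecture: $z_i$ is the projector output of the $i$-th augmented sample, $\tau$ is the temperature, $N$ is the batch size (so there are $2N$ augmented samples), and samples $2k-1$ and $2k$ (1-based) are the two augmentations of the same data point.

$$
s_{i,j} = \frac{z_i^{\top} z_j}{\lVert z_i \rVert \, \lVert z_j \rVert}
$$

$$
\ell(i,j) = -\log \frac{\exp\left(s_{i,j}/\tau\right)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\left(s_{i,k}/\tau\right)}
$$

$$
\mathcal{L} = \frac{1}{2N} \sum_{k=1}^{N} \left[ \ell(2k-1,\, 2k) + \ell(2k,\, 2k-1) \right]
$$

The numerator rewards a high similarity between the two augmentations of one data point, and the denominator penalizes high similarity to every other sample in the batch.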


# **9.2. SimCLR Loss and Training Functions**

---

* We first create a custom loss function.
* Next, we create a SimCLR training function.


## **9.2.1. Loss Function**

* Custom loss functions in TensorFlow take two inputs: the real and the estimated output values.

* SimCLR is an unsupervised model, so we create dummy real output values with the same shape as the predicted ones.

* The estimated output values are the projector's representations, denoted as z_estimate.

* Remember that if N is the batch size, the matrices z_real and z_estimate each have 2*N rows, one per augmented sample.

* Our goal is to avoid "for loops" in the loss and benefit from multi-threading and efficient memory use for training speed.

* Custom loss functions should therefore be written with tensor operations from the TensorFlow library.
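
To make the "no for loops" goal concrete, here is a minimal sketch of one way to build the full matrix of pairwise cosine similarities in a single vectorized step. It uses NumPy as a stand-in for TensorFlow tensors, and the helper names `cosine_similarity_matrix` and `cosine_similarity_loops` as well as the toy data are hypothetical, not part of the lecture code. The lecture's loss function computes the same matrix with index gathering instead of a matrix product.

```python
import numpy as np

def cosine_similarity_matrix(z):
    """Return the (n, n) matrix of pairwise cosine similarities of the rows of z."""
    # Scale every row to unit length; the dot product of two unit vectors
    # equals their cosine similarity.
    z_unit = z / np.linalg.norm(z, axis=1, keepdims=True)
    return z_unit @ z_unit.T

def cosine_similarity_loops(z):
    """Same result with explicit loops, kept only as a readable reference."""
    n = z.shape[0]
    s = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            s[i, j] = z[i] @ z[j] / (np.linalg.norm(z[i]) * np.linalg.norm(z[j]))
    return s

rng = np.random.default_rng(0)
z = rng.normal(size=(6, 4))   # toy projector output: 2N = 6 rows, 4 features

s_fast = cosine_similarity_matrix(z)
s_slow = cosine_similarity_loops(z)
print(np.allclose(s_fast, s_slow))
```

The vectorized version replaces the double loop with one normalization and one matrix product, which is the kind of operation that runs efficiently on tensors.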

> **Abbreviations:**
* datain: input data
* ind: index
* tf: TensorFlow

In [None]:
#@title Import necessary libraries
import tensorflow as tf

In [None]:
#@title Custom loss function
@tf.function
def fun_simclr_loss(z_real, z_estimate):
  # z_real is a dummy variable and is not used in the loss
  del z_real

  # Temperature (tau in Chen et al., 2020), a hyper-parameter to be tuned for
  # a particular problem
  toe = .1 

  num = int(z_estimate.shape[0]) # num = 2N (needs a static batch dimension)

  # Create i & j indices against each other
  ind0 = tf.repeat(tf.expand_dims(tf.range(0,num),axis = 0),num, axis = 0)
  ind1 = tf.reshape(ind0, (num**2,1))[:,0]
  ind2 = tf.reshape(tf.transpose(ind0), (num**2,1))[:,0]

  del ind0

  # Arrange the z_estimate values based on ind1
  vector_1   = tf.gather(z_estimate, ind1, axis = 0)
  del ind1

  # Arrange the z_estimate values based on ind2
  vector_2   = tf.gather(z_estimate, ind2, axis = 0)
  del ind2

  # Compute the cosine similarity of vector_1 and vector_2; the Keras function
  # returns the negative cosine similarity (it is a loss), hence the minus sign
  s      = - tf.reshape(tf.keras.losses.cosine_similarity(vector_1, vector_2, axis=1),(num,num))

  del vector_1 
  del vector_2

  # Compute the numerator of l(i,j)
  nom    = tf.exp(s/toe)

  # Compute the denominator of l(i,j)
  x1    = tf.exp(s/toe)

  del s

  x2    = 1-tf.eye(num, dtype = tf.float32)

  
  denom = tf.repeat(tf.expand_dims(tf.math.reduce_sum(x1 * x2, axis = 1), axis = 1), num, axis = 1)
  
  del x1
  del x2

  # Compute l(i,j) for all i and j
  l     = -tf.math.log(nom/denom)

  del nom
  del denom 

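  # Compute L: sum l(2k,2k+1) + l(2k+1,2k) over the N positive pairs
  # (0-based row indices) and divide by num = 2N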
  ind_2k0 = tf.range(0,num,2, dtype=tf.int32) 
  ind_2k1 = tf.range(1,num,2, dtype=tf.int32)

  loss_mat1_1 = tf.gather(l,           ind_2k0, axis = 0)
  loss_mat1_2 = tf.gather(loss_mat1_1, ind_2k1, axis = 1)
  loss_mat1   = tf.linalg.diag_part(loss_mat1_2)

  loss_mat2_1 = tf.gather(l,           ind_2k1, axis = 0)
  loss_mat2_2 = tf.gather(loss_mat2_1, ind_2k0, axis = 1)
  loss_mat2   = tf.linalg.diag_part(loss_mat2_2)

  del l

  loss_mat = loss_mat1 + loss_mat2

  L = tf.math.reduce_sum(loss_mat)/num

  return L
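
One way to gain confidence in a vectorized loss is to compare it with a slow, loop-based reference on toy inputs. The sketch below is such a reference written in NumPy. The name `nt_xent_reference` and the toy inputs are hypothetical, not part of the lecture code. It assumes the rows are interleaved so that rows 2k and 2k+1 (0-based) are the two views of the same sample, and it uses a temperature of 0.1 as in the function above.

```python
import numpy as np

def nt_xent_reference(z, tau=0.1):
    """Loop-based SimCLR loss; rows 2k and 2k+1 (0-based) must be a positive pair."""
    num = z.shape[0]                                  # num = 2N
    z_unit = z / np.linalg.norm(z, axis=1, keepdims=True)
    s = z_unit @ z_unit.T                             # pairwise cosine similarities

    def pair_loss(i, j):
        # The denominator sums over every k except k == i
        denom = sum(np.exp(s[i, k] / tau) for k in range(num) if k != i)
        return -np.log(np.exp(s[i, j] / tau) / denom)

    total = 0.0
    for k in range(0, num, 2):
        total += pair_loss(k, k + 1) + pair_loss(k + 1, k)
    return total / num

basis = np.eye(4)                                     # 4 mutually orthogonal toy projections

# Perfect agreement: both views of every sample share the same projection
z_aligned = np.repeat(basis, 2, axis=0)

# Mismatch: the second view of each sample copies a different sample's projection
z_mismatched = z_aligned.copy()
z_mismatched[1::2] = np.roll(basis, 1, axis=0)

loss_aligned = nt_xent_reference(z_aligned)
loss_mismatched = nt_xent_reference(z_mismatched)
print(loss_aligned, loss_mismatched)
```

With these inputs the aligned loss is close to 0 and the mismatched loss is close to 1/tau = 10. Passing the same arrays to `fun_simclr_loss` as float32 tensors (with any placeholder as the first argument) should reproduce these values up to floating-point error.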

In [None]:
#@title SimCLR training function
def fun_train_simclr(model, datain, fun_augment_01, fun_augment_02, 
                     epochs = 100, batch_size = 32, verbose = 1, 
                     shuffle = True, patience = 3):
  
  num_data  = datain.shape[0]
  num_batch = int(num_data//batch_size) # Reminder: // is divide and floor
  z_size    = model.layers[-1].weights[-1].shape[0]
  loss      = []  

  for i0 in range(epochs):
    counter    = 0
    loss_batch = []

    if shuffle:
      # Random permutation, so every sample is used exactly once per epoch
      ind_shuffle = tf.random.shuffle(tf.range(num_data)).numpy()
      datain = datain[ind_shuffle,:]


    for i1 in range(num_batch):
      if i1 == num_batch - 1:
        # The last batch also takes the leftover samples (up to batch_size - 1 extra)
        ind_case = range(counter, num_data)
      else:
        ind_case = range(counter, counter + batch_size)


      x_tilda_01  = fun_augment_01(datain[ind_case,:])
      x_tilda_02  = fun_augment_02(datain[ind_case,:]) 

      x_tilda     = tf.reshape(tf.concat([x_tilda_01,x_tilda_02], axis = 1), 
                               (x_tilda_01.shape[0] + x_tilda_02.shape[0],  
                                x_tilda_01.shape[1], x_tilda_01.shape[2], 
                                x_tilda_01.shape[3]))
      
      z_real      = tf.random.uniform((x_tilda.shape[0],z_size)) # dummy variable

      # Train on batch
      var = model.train_on_batch(x_tilda, z_real)
      loss_batch.append(var[0]) 
      # var is a list [loss, metric, ...] when the model is compiled with
      # metrics; if you compile the model without any 'metrics', var is a
      # scalar loss, so change this to loss_batch.append(var)


      counter  = counter + batch_size 

      if verbose:
        if i1 == num_batch - 1:
          print("\r SimCLR | Epoch {:04d}/{:04d} - Batch {:04d}/{:04d} - Loss {:8.5F}".format(i0+1, epochs, i1+1, num_batch, sum(loss_batch)/len(loss_batch)), flush=True)
        else:
          print("\r SimCLR | Epoch {:04d}/{:04d} - Batch {:04d}/{:04d} - Loss {:8.5F}".format(i0+1, epochs, i1+1, num_batch, sum(loss_batch)/len(loss_batch)), end="", flush=True)

      

    loss.append(sum(loss_batch)/len(loss_batch))

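    # Early stopping: once the epoch index i0 exceeds `patience`, stop as soon
    # as the latest epoch loss is above the lowest epoch loss recorded so far
    if i0>patience:
      loss_hist_min = min(loss)

      if loss[-1] > loss_hist_min:
        break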
  
  return model, loss
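
The training function concatenates the two augmented batches along axis 1 and then reshapes to `(2N, H, W, C)`. The sketch below illustrates why this interleaves the two views, which is what lets the loss find positive pairs at rows 2k and 2k+1. It uses NumPy, whose `concatenate` and `reshape` follow the same row-major ordering as `tf.concat` and `tf.reshape`. The toy shapes and the variable names are hypothetical.

```python
import numpy as np

n, h, w, c = 3, 2, 2, 1                       # toy batch: 3 tiny "images"

# Fill view 1 of sample k with the value k and view 2 with 100 + k,
# so each row of the result can be traced back to its source.
view_1 = np.stack([np.full((h, w, c), k) for k in range(n)])
view_2 = np.stack([np.full((h, w, c), 100 + k) for k in range(n)])

# Same two steps as in the training function: concatenate along axis 1,
# then reshape to (2N, H, W, C).
stacked = np.concatenate([view_1, view_2], axis=1)       # shape (N, 2H, W, C)
interleaved = stacked.reshape(2 * n, h, w, c)            # shape (2N, H, W, C)

order = [int(img[0, 0, 0]) for img in interleaved]
print(order)   # [0, 100, 1, 101, 2, 102]
```

Each sample's two views end up in adjacent rows, so the batch fed to the model has the ordering that the loss function assumes.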

# **Lecture 09: SimCLR, An Unsupervised Contrastive Model**

In this lecture, you learned about:

1. How SimCLR of [Chen et al. (2020)](http://proceedings.mlr.press/v119/chen20j/chen20j.pdf) works. 
2. SimCLR custom loss and training functions.

> ***In the following lecture, we will see a labeling example using SimCLR unsupervised contrastive learning.*** 