# SimCLR on ciFAIR10 Image Classification (TensorFlow Backend)

Labeled datasets are much more expensive than their unlabeled counterparts, so it is often the case that only a small fraction of the total available data can be labeled. Self-supervised learning algorithms, which don't require labeled data during training, have therefore become a major topic in ML research. In 2020, [SimCLR](https://arxiv.org/pdf/2002.05709.pdf) was proposed and achieved 85.8% top-5 accuracy using only 1% of the available labels on the ImageNet dataset.

The idea of SimCLR is to separate visual tasks into two parts: an encoder and a classifier. The encoder projects images to a representation space which is then used by the classifier to make decisions. The encoder doesn't need to know the image class, but it does need to project an "image group" (a group of images generated from the same image with data augmentation) to a cluster. By increasing the similarity of encoded images from the same image group while reducing the similarity between different groups, the encoder can be trained without explicit labels. The process of training the encoder is called "pretraining".
Later, users can attach any classifier after the pretrained encoder and finetune the whole model for a specific visual task. According to the paper, this can achieve good results with only a small fraction of the available data being labeled.

In this tutorial we will demonstrate the implementation of SimCLR with the ciFAIR10 dataset. Some details of this implementation will be different from the original paper. This implementation draws upon the code provided [here](https://github.com/google-research/simclr).

In [1]:
import tempfile

import tensorflow as tf
from tensorflow.keras import layers

import fastestimator as fe
from fastestimator.dataset.data import cifair10
from fastestimator.op.numpyop.meta import Sometimes
from fastestimator.op.numpyop.multivariate import HorizontalFlip, PadIfNeeded, RandomCrop
from fastestimator.op.numpyop.univariate import ColorJitter, GaussianBlur, ToFloat, ToGray
from fastestimator.op.tensorop import LambdaOp, TensorOp
from fastestimator.op.tensorop.loss import CrossEntropy
from fastestimator.op.tensorop.model import ModelOp, UpdateOp
from fastestimator.trace.io import BestModelSaver, ModelSaver
from fastestimator.trace.metric import Accuracy

In [2]:
# training parameters
epochs_pretrain = 50
epochs_finetune = 10
batch_size = 512
train_steps_per_epoch = None
eval_steps_per_epoch = None
save_dir = tempfile.mkdtemp()

## Pre-Training Pipeline
The SimCLR paper emphasizes the importance of data augmentation and how it directly impacts the quality of the pretrained model. The augmentation steps used here are: random cropping, random horizontal flip, random color jitter, random grayscale conversion, and random Gaussian blur. Each image passes through two independent augmentation branches to generate two augmented images, which constitute an image group (or, more specifically, a pair). The batch of augmented image pairs is later used for model pretraining.
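To make the "two views per image" idea concrete, here is a minimal plain-Python sketch of the pad, random crop, and random flip steps. It is only an illustration under simplifying assumptions: the helper names (`pad_image`, `random_crop`, `random_flip`, `make_pair`) are hypothetical, the image is a single-channel list of rows, padding uses a constant zero (the real `PadIfNeeded` op may use a different border mode), and color jitter, grayscale, and blur are omitted.

```python
import random


def pad_image(img, size, fill=0):
    """Center-pad a 2D image (list of rows) to size x size with a constant value."""
    h, w = len(img), len(img[0])
    top = (size - h) // 2
    left = (size - w) // 2
    out = [[fill] * size for _ in range(size)]
    for r in range(h):
        for c in range(w):
            out[top + r][left + c] = img[r][c]
    return out


def random_crop(img, size, rng):
    """Cut a random size x size window out of a 2D image."""
    top = rng.randint(0, len(img) - size)
    left = rng.randint(0, len(img[0]) - size)
    return [row[left:left + size] for row in img[top:top + size]]


def random_flip(img, rng, prob=0.5):
    """Mirror the image horizontally with probability `prob`."""
    return [row[::-1] for row in img] if rng.random() < prob else img


def make_pair(img, rng):
    """Produce two independently augmented views of the same source image."""
    padded = pad_image(img, 40)
    view1 = random_flip(random_crop(padded, 32, rng), rng)
    view2 = random_flip(random_crop(padded, 32, rng), rng)
    return view1, view2


rng = random.Random(0)
image = [[r * 32 + c for c in range(32)] for r in range(32)]  # toy single-channel "image"
x_aug, x_aug2 = make_pair(image, rng)
print(len(x_aug), len(x_aug[0]), len(x_aug2), len(x_aug2[0]))
```

Both views keep the original 32x32 size, but each is drawn with its own random crop offset and flip decision, which is what makes the pair a non-trivial positive example.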

In [3]:
train_data, eval_data = cifair10.load_data()

In [4]:
pipeline_pretrain = fe.Pipeline(
    train_data=train_data,
    batch_size=batch_size,
    ops=[
        PadIfNeeded(min_height=40, min_width=40, image_in="x", image_out="x"),

        # augmentation 1
        RandomCrop(32, 32, image_in="x", image_out="x_aug"),
        Sometimes(HorizontalFlip(image_in="x_aug", image_out="x_aug"), prob=0.5),
        Sometimes(
                ColorJitter(inputs="x_aug", outputs="x_aug", brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2),
                prob=0.8),
        Sometimes(ToGray(inputs="x_aug", outputs="x_aug"), prob=0.2),
        Sometimes(GaussianBlur(inputs="x_aug", outputs="x_aug", blur_limit=(3, 3), sigma_limit=(0.1, 2.0)), prob=0.5),
        ToFloat(inputs="x_aug", outputs="x_aug"),
        
        # augmentation 2
        RandomCrop(32, 32, image_in="x", image_out="x_aug2"),
        Sometimes(HorizontalFlip(image_in="x_aug2", image_out="x_aug2"), prob=0.5),
        Sometimes(
                ColorJitter(inputs="x_aug2", outputs="x_aug2", brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2),
                prob=0.8),
        Sometimes(ToGray(inputs="x_aug2", outputs="x_aug2"), prob=0.2),
        Sometimes(GaussianBlur(inputs="x_aug2", outputs="x_aug2", blur_limit=(3, 3), sigma_limit=(0.1, 2.0)), prob=0.5),
        ToFloat(inputs="x_aug2", outputs="x_aug2")
    ])

## Model
SimCLR training can be separated into two parts: pretraining and finetuning. In the pretraining step, the encoder is attached to a small MLP called the "projection head" (in this implementation, a single dense layer). During finetuning, the encoder is attached to a classifier called the "supervision head". The paper claims that using the projection head helps make data more clustered in the representation space.

Although the original paper uses a ResNet50 architecture, we will use ResNet9 for faster convergence.

In [5]:
def ResNet9(input_size=(32, 32, 3), dims=128, classes=10):
    """A small 9-layer ResNet TensorFlow model for ciFAIR10 image classification.
    The model architecture is from https://github.com/davidcpage/cifar10-fast

    Args:
        input_size: The size of the input tensor (height, width, channels).
        dims: The output dimension of the projection head used during pretraining.
        classes: The number of classes the supervision head should predict.

    Returns:
        A tuple (model_con, model_finetune) of TensorFlow models which share the same encoder.
    """

    # prep layers
    inp = layers.Input(shape=input_size)
    x = layers.Conv2D(64, 3, padding='same')(inp)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.LeakyReLU(alpha=0.1)(x)
    # layer1
    x = layers.Conv2D(128, 3, padding='same')(x)
    x = layers.MaxPool2D()(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.LeakyReLU(alpha=0.1)(x)
    x = layers.Add()([x, residual(x, 128)])
    # layer2
    x = layers.Conv2D(256, 3, padding='same')(x)
    x = layers.MaxPool2D()(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.LeakyReLU(alpha=0.1)(x)
    # layer3
    x = layers.Conv2D(512, 3, padding='same')(x)
    x = layers.MaxPool2D()(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.LeakyReLU(alpha=0.1)(x)
    x = layers.Add()([x, residual(x, 512)])
    # layer4
    x = layers.GlobalMaxPool2D()(x)
    code = layers.Flatten()(x)

    p_head = layers.Dense(dims)(code)
    model_con = tf.keras.Model(inputs=inp, outputs=p_head)

    s_head = layers.Dense(classes)(code)
    s_head = layers.Activation('softmax', dtype='float32')(s_head)
    model_finetune = tf.keras.Model(inputs=inp, outputs=s_head)

    return model_con, model_finetune


def residual(x, num_channel: int):
    """A ResNet unit for ResNet9.

    Args:
        x: Input Keras tensor.
        num_channel: The number of channels in each convolution layer.

    Returns:
        Output Keras tensor.
    """
    x = layers.Conv2D(num_channel, 3, padding='same')(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.LeakyReLU(alpha=0.1)(x)
    x = layers.Conv2D(num_channel, 3, padding='same')(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.LeakyReLU(alpha=0.1)(x)
    return x


model_con, model_finetune = fe.build(model_fn=ResNet9, optimizer_fn=["adam", "adam"])

## Pre-Training Network
SimCLR uses NT-Xent (the normalized temperature-scaled cross entropy loss) to train the encoder. Minimizing this loss increases the similarity of positive augmented pairs and decreases the similarity of negative pairs, as the following GIF demonstrates. For a detailed formula, please refer to the [original paper](https://arxiv.org/pdf/2002.05709.pdf).
 <img src="https://1.bp.blogspot.com/--vH4PKpE9Yo/Xo4a2BYervI/AAAAAAAAFpM/vaFDwPXOyAokAC8Xh852DzOgEs22NhbXwCLcBGAsYHQ/s1600/image4.gif" alt="drawing" width="400"/>
(source: https://ai.googleblog.com/2020/04/advancing-self-supervised-and-semi.html)
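
For a positive pair (i, j), NT-Xent is the negative log of exp(sim(z_i, z_j) / τ) divided by the sum of exp(sim(z_i, z_k) / τ) over every other view k ≠ i in the batch, where sim is cosine similarity and τ is the temperature. The following plain-Python sketch works through that computation on a tiny batch. It is an illustrative reference, not the TensorFlow implementation used in this tutorial; the function names (`l2_normalize`, `dot`, `nt_xent`) and the toy vectors are assumptions made for the example.

```python
import math


def l2_normalize(v):
    """Scale a vector to unit length so that dot products become cosine similarities."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def nt_xent(A, B, temperature=1.0):
    """NT-Xent for a batch in which A[i] and B[i] are two views of sample i."""
    A = [l2_normalize(a) for a in A]
    B = [l2_normalize(b) for b in B]
    n = len(A)

    def one_direction(anchors, others):
        losses = []
        for i in range(n):
            # similarity between the anchor and its positive (same index in the other view)
            positive = dot(anchors[i], others[i]) / temperature
            # candidates: every view in `others`, plus every anchor-side view except the anchor itself
            logits = [dot(anchors[i], o) / temperature for o in others]
            logits += [dot(anchors[i], a) / temperature for j, a in enumerate(anchors) if j != i]
            log_denominator = math.log(sum(math.exp(l) for l in logits))
            losses.append(log_denominator - positive)
        return losses

    loss_a = one_direction(A, B)
    loss_b = one_direction(B, A)
    # average over the batch of the summed losses from both directions
    return sum(la + lb for la, lb in zip(loss_a, loss_b)) / n


views = [[1.0, 0.0], [0.0, 1.0]]
aligned = nt_xent(views, [[1.0, 0.0], [0.0, 1.0]])   # each positive pair points the same way
shuffled = nt_xent(views, [[0.0, 1.0], [1.0, 0.0]])  # each positive pair is orthogonal
print(aligned < shuffled)
```

The loss is lower when the two views of each sample are encoded close together and the views of different samples are far apart, which is the behavior the encoder is trained toward.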
 


In [6]:
class NTXentOp(TensorOp):
    def __init__(self, arg1, arg2, outputs, temperature=1.0, mode=None):
        super().__init__(inputs=(arg1, arg2), outputs=outputs, mode=mode)
        self.temperature = temperature

    def forward(self, data, state):
        arg1, arg2 = data
        loss = NTXent(arg1, arg2, self.temperature)
        return loss


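def NTXent(A, B, temperature):
    # A and B hold the projections of the two augmented views; row i of A and row i of B form a positive pair
    large_number = 1e9
    batch_size = tf.shape(A)[0]
    # normalize so that the dot products below are cosine similarities
    A = tf.math.l2_normalize(A, -1)
    B = tf.math.l2_normalize(B, -1)

    # mask marks each sample's similarity with itself; labels mark the positive pair in the concatenated logits
    mask = tf.one_hot(tf.range(batch_size), batch_size)
    labels = tf.one_hot(tf.range(batch_size), 2 * batch_size)

    aa = tf.matmul(A, A, transpose_b=True) / temperature
    aa = aa - mask * large_number  # exclude self-similarity from the softmax
    ab = tf.matmul(A, B, transpose_b=True) / temperature
    bb = tf.matmul(B, B, transpose_b=True) / temperature
    bb = bb - mask * large_number  # exclude self-similarity from the softmax
    ba = tf.matmul(B, A, transpose_b=True) / temperature
    loss_a = tf.nn.softmax_cross_entropy_with_logits(labels, tf.concat([ab, aa], 1))
    loss_b = tf.nn.softmax_cross_entropy_with_logits(labels, tf.concat([ba, bb], 1))
    loss = tf.reduce_mean(loss_a + loss_b)

    # ab and labels are returned so that a contrastive accuracy can be computed later
    return loss, ab, labels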


network_pretrain = fe.Network(ops=[
    LambdaOp(lambda x, y: tf.concat([x, y], axis=0), inputs=["x_aug", "x_aug2"], outputs="x_com"),
    ModelOp(model=model_con, inputs="x_com", outputs="y_com"),
    LambdaOp(lambda x: tf.split(x, 2, axis=0), inputs="y_com", outputs=["y_pred", "y_pred2"]),
    NTXentOp(arg1="y_pred", arg2="y_pred2", outputs=["NTXent", "logit", "label"]),
    UpdateOp(model=model_con, loss_name="NTXent")
])

## Pre-Training Estimator
Next we are going to combine the pretraining pipeline and network together in the estimator class with an `Accuracy` trace to monitor the contrastive accuracy and a `ModelSaver` trace to save the pretrained model. We can then start the training.
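
As a rough illustration of what the contrastive accuracy measures, the sketch below counts how often the highest cross-view similarity in a row falls on the diagonal, i.e. how often a view's most similar candidate in the other view is its own positive partner. This assumes the `Accuracy` trace compares the argmax of `logit` with the argmax of the one-hot `label`; the function name `contrastive_accuracy` and the toy numbers are hypothetical.

```python
def contrastive_accuracy(logits):
    """Fraction of rows whose largest cross-view similarity lies on the diagonal."""
    hits = 0
    for i, row in enumerate(logits):
        predicted = max(range(len(row)), key=row.__getitem__)  # argmax of the row
        hits += int(predicted == i)
    return hits / len(logits)


toy_logits = [
    [0.9, 0.1, 0.3],  # row 0: the positive (column 0) scores highest, a hit
    [0.2, 0.8, 0.1],  # row 1: the positive (column 1) scores highest, a hit
    [0.7, 0.2, 0.4],  # row 2: column 0 beats the positive (column 2), a miss
]
print(contrastive_accuracy(toy_logits))
```

A value close to 1 means the encoder can almost always pick out the matching augmented view from the rest of the batch.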

In [7]:
traces = [
    Accuracy(true_key="label", pred_key="logit", mode="train", output_name="contrastive_accuracy"),
    ModelSaver(model=model_con, save_dir=save_dir)
]

estimator_pretrain = fe.Estimator(pipeline=pipeline_pretrain,
                                  network=network_pretrain,
                                  epochs=epochs_pretrain,
                                  traces=traces,
                                  train_steps_per_epoch=train_steps_per_epoch)
estimator_pretrain.fit()

    ______           __  ______     __  _                 __            
   / ____/___ ______/ /_/ ____/____/ /_(_)___ ___  ____ _/ /_____  _____
  / /_  / __ `/ ___/ __/ __/ / ___/ __/ / __ `__ \/ __ `/ __/ __ \/ ___/
 / __/ / /_/ (__  ) /_/ /___(__  ) /_/ / / / / / / /_/ / /_/ /_/ / /    
/_/    \__,_/____/\__/_____/____/\__/_/_/ /_/ /_/\__,_/\__/\____/_/     
                                                                        

FastEstimator-Start: step: 1; logging_interval: 100; num_device: 1;
FastEstimator-Train: step: 1; NTXent: 13.829769;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_1.h5
FastEstimator-Train: step: 98; epoch: 1; contrastive_accuracy: 0.191; epoch_time: 26.89 sec;
FastEstimator-Train: step: 100; NTXent: 12.382189; steps/sec: 4.52;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_2.h5
FastEstimator-Train: step: 196; epoch: 2; contrastive_accuracy: 0.5078; epoch_time: 20.97 sec;
FastEstimator-Train: step: 200; NTXent

FastEstimator-Train: step: 3200; NTXent: 11.982931; steps/sec: 4.91;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_33.h5
FastEstimator-Train: step: 3234; epoch: 33; contrastive_accuracy: 0.99192; epoch_time: 19.91 sec;
FastEstimator-Train: step: 3300; NTXent: 11.983704; steps/sec: 4.93;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_34.h5
FastEstimator-Train: step: 3332; epoch: 34; contrastive_accuracy: 0.99288; epoch_time: 19.94 sec;
FastEstimator-Train: step: 3400; NTXent: 11.982264; steps/sec: 4.9;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_35.h5
FastEstimator-Train: step: 3430; epoch: 35; contrastive_accuracy: 0.99274; epoch_time: 19.95 sec;
FastEstimator-Train: step: 3500; NTXent: 11.976917; steps/sec: 4.91;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_36.h5
FastEstimator-Train: step: 3528; epoch: 36; contrastive_accuracy: 0.99184; epoch_time: 19.97 sec;
FastEstimator-Train: step: 36

## Finetune the model on an image classification task

Once the model is pretrained, we can finetune it on a specific task. In this case we are going to use the pretrained model for ciFAIR10 image classification. Remember that in the previous section we built both `model_con` and `model_finetune`. Because those two models share the same encoder object, (pre)training `model_con` also trains the encoder of `model_finetune`. Finetuning the model is simply supervised training with the pretrained encoder. In order to demonstrate the benefit of SimCLR, we are going to finetune the network using only 10% of the labeled training data and compare the result with a model trained from scratch under the same data limitation.
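
The weight sharing described above can be illustrated with a toy example in plain Python (not Keras). The classes `Encoder` and `Model` and the scalar "weights" are hypothetical stand-ins; the point is only that both models hold a reference to the same encoder object, so an update made through one is visible through the other.

```python
class Encoder:
    """Stand-in for the shared convolutional trunk: a single scalar weight."""

    def __init__(self, weight):
        self.weight = weight

    def __call__(self, x):
        return self.weight * x


class Model:
    """Stand-in for a Keras model: a shared encoder followed by its own head weight."""

    def __init__(self, encoder, head_weight):
        self.encoder = encoder  # a reference to the encoder, not a copy
        self.head_weight = head_weight

    def __call__(self, x):
        return self.head_weight * self.encoder(x)


encoder = Encoder(weight=1.0)
toy_con = Model(encoder, head_weight=2.0)       # plays the role of model_con
toy_finetune = Model(encoder, head_weight=3.0)  # plays the role of model_finetune

before = toy_finetune(1.0)
toy_con.encoder.weight = 0.5  # "pretraining" updates the encoder through toy_con
after = toy_finetune(1.0)
print(before, after)
```

The output of `toy_finetune` changes even though only `toy_con` was touched, while each head weight stays private to its own model.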


In [8]:
split_train = train_data.split(0.1)

pipeline_finetune = fe.Pipeline(
    train_data=split_train,
    eval_data=eval_data,
    batch_size=batch_size,
    ops=[
        ToFloat(inputs="x", outputs="x")
    ])

network_finetune = fe.Network(ops=[
    ModelOp(model=model_finetune, inputs="x", outputs="y_pred"),
    CrossEntropy(inputs=["y_pred", "y"], outputs="ce"),
    UpdateOp(model=model_finetune, loss_name="ce")
])

traces = [
    Accuracy(true_key="y", pred_key="y_pred"),
    BestModelSaver(model=model_finetune, save_dir=save_dir, metric="accuracy", save_best_mode="max")
]

est_finetune = fe.Estimator(pipeline=pipeline_finetune,
                            network=network_finetune,
                            epochs=epochs_finetune,
                            traces=traces,
                            train_steps_per_epoch=train_steps_per_epoch)
est_finetune.fit()

    ______           __  ______     __  _                 __            
   / ____/___ ______/ /_/ ____/____/ /_(_)___ ___  ____ _/ /_____  _____
  / /_  / __ `/ ___/ __/ __/ / ___/ __/ / __ `__ \/ __ `/ __/ __ \/ ___/
 / __/ / /_/ (__  ) /_/ /___(__  ) /_/ / / / / / / /_/ / /_/ /_/ / /    
/_/    \__,_/____/\__/_____/____/\__/_/_/ /_/ /_/\__,_/\__/\____/_/     
                                                                        

FastEstimator-Start: step: 1; logging_interval: 100; num_device: 1;
FastEstimator-Train: step: 1; ce: 6.394948;
FastEstimator-Train: step: 9; epoch: 1; epoch_time: 3.6 sec;
FastEstimator-BestModelSaver: Saved model to /tmp/tmp33wb3rot/model1_best_accuracy.h5
FastEstimator-Eval: step: 9; epoch: 1; accuracy: 0.4504; ce: 1.8657482; max_accuracy: 0.4504; since_best_accuracy: 0;
FastEstimator-Train: step: 18; epoch: 2; epoch_time: 0.78 sec;
FastEstimator-BestModelSaver: Saved model to /tmp/tmp33wb3rot/model1_best_accuracy.h5
FastEstimator-Eval: step: 18; epoch

## Results
With SimCLR pretraining, the model achieved 70% accuracy using only 10% of the labeled data. With the same configuration, a vanilla ResNet9 trained from scratch can only achieve around 57%.