In [None]:
# Mount Google Drive at /content/drive (requires the Colab runtime).
from google.colab import drive
drive.mount('/content/drive')

# !pip install  tensorflow==2.8
# Clone the repository whose directory is put on sys.path in the import cell
# so that `algorithms` can be imported from it.
!git clone https://github.com/beresandras/contrastive-classification-keras.git

In [None]:
import tensorflow as tf
import keras
import sys
sys.path.insert(0,'/content/contrastive-classification-keras')
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential, Model
from algorithms import SimCLR, NNCLR, DCCLR, BarlowTwins, HSICTwins, TWIST, MoCo, DINO
import tensorflow_datasets as tfds
import os
from sklearn.model_selection import train_test_split
import numpy as np

In [None]:
# hyperparameters
epochs = 1
steps_per_epoch = 200  # not referenced by the visible cells

input_shape = (96,96,3)
# hyperparameters corresponding to each algorithm
# (forwarded as keyword arguments to the algorithm constructor in `ssl_method`)
hyperparams = {
    "SimCLR": {"temperature": 0.1},
    "NNCLR": {"temperature": 0.1, "queue_size": 10000},
    "DCCLR": {"temperature": 0.1},
    "BarlowTwins": {"redundancy_reduction_weight": 10.0},
    "HSICTwins": {"redundancy_reduction_weight": 3.0},
    "TWIST": {},
    "MoCo": {"momentum_coeff": 0.99, "temperature": 0.1, "queue_size": 10000},
    "DINO": {"momentum_coeff": 0.9, "temperature": 0.1, "sharpening": 0.5},
}

# Registry mapping an algorithm name to its class from the `algorithms` module.
ssl = {
    "SimCLR" : SimCLR,
    "NNCLR" : NNCLR,
    "DCCLR" : DCCLR,
    "BarlowTwins" : BarlowTwins,
    "HSICTwins" : HSICTwins,
    "TWIST" : TWIST ,
    "MoCo" : MoCo,
    "DINO" : DINO

}


# Settings passed to the PNNCLR constructor.
temperature = 0.1
queue_size = 10000



# Keyword arguments for the augmenter factories (`augmenter` / `augmenter2`):
# `brightness` feeds RandomBrightness, `scale` feeds RandomResizedCrop.
classification_augmenter = {
    "brightness": 0.2,
    "name": "classification_augmenter",
    "scale": (0.5, 1.0),
}

contrastive_augmenter = {
    "brightness": 0.5,
    "name": "contrastive_augmenter",
    "scale": (0.2, 1.0),
}



AUTOTUNE = tf.data.AUTOTUNE
shuffle_buffer = 5000
# The below two values are taken from https://www.tensorflow.org/datasets/catalog/stl10
labelled_train_images = 5000
unlabelled_images = 100000

# Width of the projection heads and of the linear-probe input. This is the
# single definition of PROJECT_DIM for the notebook.
PROJECT_DIM = 2048
WEIGHT_DECAY = 0.0005  # not referenced by the visible cells
LATENT_DIM = 512  # not referenced by the visible cells
CROP_TO = 96  # not referenced by the visible cells
batch_size = 32
l = 0.001  # initial learning rate of the cosine schedule built in `train_ssl`
dataset_name = "stl10"

In [None]:


def prepare_dataset_stl(size = 32):
    """Build the input pipelines for `dataset_name`, all batched `size` at a time.

    Args:
        size: batch size used for the unlabelled, labelled and test splits.

    Returns:
        A tuple `(train_dataset, labeled_train_dataset, test_dataset)`:
        `train_dataset` zips the unlabelled batches with the labelled batches,
        `labeled_train_dataset` is the labelled batches alone, and
        `test_dataset` is the unshuffled "test" split.
    """

    def shuffled_batches(split):
        # Shuffle at file and example level; incomplete final batches are dropped.
        examples = tfds.load(
            dataset_name, split=split, as_supervised=True, shuffle_files=True
        )
        examples = examples.shuffle(buffer_size=shuffle_buffer)
        return examples.batch(size, drop_remainder=True)

    unlabeled_batches = shuffled_batches("unlabelled")
    labeled_batches = shuffled_batches("train")

    # The evaluation split keeps its order and its (possibly smaller) last batch.
    eval_batches = tfds.load(dataset_name, split="test", as_supervised=True)
    eval_batches = eval_batches.batch(size).prefetch(buffer_size=AUTOTUNE)

    paired_batches = tf.data.Dataset.zip((unlabeled_batches, labeled_batches))
    paired_batches = paired_batches.prefetch(buffer_size=AUTOTUNE)

    return (paired_batches, labeled_batches, eval_batches)



In [None]:

# Batch with the shared `batch_size` constant so the pipelines agree with the
# step count that `train_ssl` derives from the same constant.
train_dataset, labeled_train_dataset, test_dataset = prepare_dataset_stl(batch_size)

In [None]:
class RandomResizedCrop(layers.Layer):
    """Randomly crop each image and resize the crop back to the input size.

    For every image a crop is drawn whose area fraction is uniform in `scale`
    and whose width/height ratio is log-uniform in `ratio`; the crop is placed
    at a random offset and resized to the original height and width.

    Args:
        scale: `(min, max)` range of the crop area as a fraction of the image.
        ratio: `(min, max)` range of the crop aspect ratio (width / height).
        **kwargs: forwarded to `layers.Layer` (e.g. `name`).
    """

    def __init__(self, scale, ratio, **kwargs):
        super().__init__(**kwargs)
        self.scale = scale
        # Stored in log space so aspect ratios are sampled log-uniformly.
        self.log_ratio = (tf.math.log(ratio[0]), tf.math.log(ratio[1]))

    def call(self, images, training=True):
        """Return randomly cropped images when `training`, else `images` unchanged."""
        # Augmentation is skipped when the enclosing model runs with
        # training=False (as in evaluation), so inputs pass through untouched.
        if training:
            batch_size = tf.shape(images)[0]
            height = tf.shape(images)[1]
            width = tf.shape(images)[2]

            random_scales = tf.random.uniform(
                (batch_size,), self.scale[0], self.scale[1]
            )
            random_ratios = tf.exp(
                tf.random.uniform((batch_size,), self.log_ratio[0], self.log_ratio[1])
            )

            # Normalised crop sizes: height * width == scale, width / height == ratio,
            # clipped so a crop never exceeds the image.
            new_heights = tf.clip_by_value(tf.sqrt(random_scales / random_ratios), 0, 1)
            new_widths = tf.clip_by_value(tf.sqrt(random_scales * random_ratios), 0, 1)
            height_offsets = tf.random.uniform((batch_size,), 0, 1 - new_heights)
            width_offsets = tf.random.uniform((batch_size,), 0, 1 - new_widths)

            # Boxes are [y1, x1, y2, x2] in normalised coordinates.
            bounding_boxes = tf.stack(
                [
                    height_offsets,
                    width_offsets,
                    height_offsets + new_heights,
                    width_offsets + new_widths,
                ],
                axis=1,
            )
            images = tf.image.crop_and_resize(
                images, bounding_boxes, tf.range(batch_size), (height, width)
            )
        return images



class RandomBrightness(layers.Layer):
    """Randomly scale the brightness of each image in a batch.

    Args:
        brightness: half-width of the scaling range; each image is multiplied
            by a factor drawn uniformly from `[1 - brightness, 1 + brightness]`.
        **kwargs: forwarded to `layers.Layer` (e.g. `name`).
    """

    def __init__(self, brightness, **kwargs):
        super().__init__(**kwargs)
        self.brightness = brightness

    def blend(self, images_1, images_2, ratios):
        """Mix `images_1` and `images_2` by `ratios` and clip the result to [0, 1]."""
        return tf.clip_by_value(ratios * images_1 + (1.0 - ratios) * images_2, 0, 1)

    def random_brightness(self, images):
        """Blend every image with black using one random ratio per image."""
        # random interpolation/extrapolation between the image and darkness
        return self.blend(
            images,
            0,
            tf.random.uniform(
                (tf.shape(images)[0], 1, 1, 1), 1 - self.brightness, 1 + self.brightness
            ),
        )

    def call(self, images, training=True):
        """Return brightness-jittered images when `training`, else `images` unchanged."""
        # Skipped when the enclosing model runs with training=False so that
        # evaluation inputs are not randomly perturbed.
        if training:
            images = self.random_brightness(images)
        return images

def augmenter(brightness, name, scale):
    """Build the basic augmentation pipeline as a named `keras.Sequential`.

    Args:
        brightness: strength passed to `RandomBrightness`.
        name: name given to the returned model.
        scale: crop-area range passed to `RandomResizedCrop`.
    """
    stages = [
        layers.Input(shape=input_shape),
        # Map 0..255 pixel values to 0..1 before the custom layers run.
        layers.Rescaling(1 / 255),
        layers.RandomFlip("horizontal"),
        RandomResizedCrop(scale=scale, ratio=(3 / 4, 4 / 3)),
        RandomBrightness(brightness=brightness),
    ]
    return keras.Sequential(stages, name=name)


def augmenter2( brightness, name, scale):
    """Build the stronger augmentation pipeline as a named `keras.Sequential`.

    Adds zoom, rotation, contrast and height jitter ahead of the random
    resized crop and brightness steps.

    Args:
        brightness: strength passed to `RandomBrightness`.
        name: name given to the returned model.
        scale: crop-area range passed to `RandomResizedCrop`.
    """
    stages = [
        layers.Input(shape=input_shape),
        # Map 0..255 pixel values to 0..1 before the custom layers run.
        layers.Rescaling(1 / 255),
        layers.RandomFlip("horizontal"),
        layers.RandomZoom(0.2),
        layers.RandomRotation(0.3),
        layers.RandomContrast(0.2),
        # NOTE(review): RandomHeight presumably changes the image height during
        # training, so downstream layers may see non-square inputs - confirm.
        layers.RandomHeight(0.2),
        RandomResizedCrop(scale=scale, ratio=(3 / 4, 4 / 3)),
        RandomBrightness(brightness=brightness),
    ]
    return keras.Sequential(stages, name=name)


def get_encoder2():
    """Build the encoder: an ImageNet-initialised ResNet50 plus a small MLP head.

    Returns:
        A `tf.keras.Model` mapping images of `input_shape` to 2048-d vectors.
    """
    backbone = tf.keras.applications.ResNet50(
        include_top=False, weights="imagenet", input_shape=input_shape
    )
    # Every backbone weight is fine-tuned.
    backbone.trainable = True

    image_input = layers.Input(input_shape)
    # NOTE(review): the explicit training=True presumably keeps the backbone in
    # training mode even when the returned model is called with training=False
    # - confirm this is intended for evaluation.
    feature_map = backbone(image_input, training=True)
    pooled = layers.GlobalAveragePooling2D()(feature_map)
    hidden = layers.Dense(2048, activation='relu', use_bias=False)(pooled)
    hidden = layers.BatchNormalization()(hidden)
    embedding = layers.Dense(2048)(hidden)

    return tf.keras.Model(image_input, embedding)



In [None]:
class PNNCLR(keras.Model):
    """Nearest-neighbour contrastive model with online and target networks.

    The online branch (`f_online` -> `g_online` -> `q_online`) is trained with
    a contrastive loss whose positives are points between a projection and its
    nearest neighbour in a feature queue. The target branch (`f_target` ->
    `g_target`) is updated as an exponential moving average of the online
    branch. A linear probe is trained alongside on the labelled images.

    Args:
        temperature: divisor applied to the similarity logits.
        queue_size: number of rows kept in the feature queue.
        t: moving-average coefficient for the target-network update.
        frac: with `initial` and `points`, selects the interpolation step
            `index = int(points - frac / initial - 1)`.
        points: number of evenly spaced steps between a projection (step 0)
            and its nearest neighbour (step `points - 1`).
        initial: step size used to convert `frac` into a step index.
        noise_val: standard deviation of the optional Gaussian noise.
        noise: when True, Gaussian noise is added to the interpolated point.
        nn: when True the interpolated point is used; when False the raw
            nearest neighbour from the queue is used instead.
    """

    def __init__(
        self, temperature, queue_size,t = 0.99, frac = 0.05 , points = 21, initial = 0.05 , noise_val = 0.1, noise = False, nn = True
    ):
        super(PNNCLR, self).__init__()
        self.probe_accuracy = keras.metrics.SparseCategoricalAccuracy()
        self.correlation_accuracy = keras.metrics.SparseCategoricalAccuracy()
        self.contrastive_accuracy = keras.metrics.SparseCategoricalAccuracy()
        self.probe_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        self.contrastive_augmenter = augmenter2(**contrastive_augmenter)
        self.classification_augmenter = augmenter2(**classification_augmenter)

        self.f_online = get_encoder2()
        self.f_target = get_encoder2()
        self.g_online = self._projection_mlp("projection_head")
        self.q_online = self._projection_mlp("projection_head1")
        self.g_target = self._projection_mlp("projection_head2")
        self.linear_probe = keras.Sequential(
            [layers.Input(shape=(PROJECT_DIM,)), layers.Dense(10)], name="linear_probe"
        )
        self.temperature = temperature

        # The queue starts as random unit vectors and is refilled with
        # projections by `contrastive_loss`.
        feature_dimensions = self.f_online.output_shape[1]
        self.feature_queue = tf.Variable(
            tf.math.l2_normalize(
                tf.random.normal(shape=(queue_size, feature_dimensions)), axis=1
            ),
            trainable=False,
        )

        self.t = t
        self.noise = noise
        self.nn = nn
        self.frac = frac
        self.points = points
        self.initial = initial
        self.noise_val = noise_val
        # Step along the projection -> neighbour path used as the positive.
        self.index = int(points - frac/initial - 1)

        print()
        print()
        print("____config________")
        print("frac " , self.frac)
        print("points ", self.points)
        print("initial ", self.initial)
        print("nn ", self.nn)
        print("noise ", self.noise)
        print("noise val" , self.noise_val)
        print("t ",self.t)
        print("index ", self.index)
        print()
        print()

    @staticmethod
    def _projection_mlp(name):
        """Return a two-layer MLP mapping PROJECT_DIM inputs to PROJECT_DIM outputs."""
        return keras.Sequential(
            [
                layers.Input(shape=(PROJECT_DIM,)),
                layers.Dense(PROJECT_DIM, activation="relu"),
                layers.Dense(PROJECT_DIM),
            ],
            name=name,
        )

    @staticmethod
    def gaussian_noise_layer(input_layer, std = 0.2):
        """Return `input_layer` plus zero-mean Gaussian noise of stddev `std`."""
        noise = tf.random.normal(shape=tf.shape(input_layer), mean=0.0, stddev=std, dtype=tf.float32)
        return input_layer + noise

    def compile(self, contrastive_optimizer, probe_optimizer, **kwargs):
        """Store the optimizers for the online networks and for the linear probe."""
        super(PNNCLR, self).compile(**kwargs)
        self.contrastive_optimizer = contrastive_optimizer
        self.probe_optimizer = probe_optimizer

    def nearest_neighbour(self, projection):
        """Return the positive derived from each projection's nearest queue entry.

        Args:
            projection: `(batch, dim)` projections compared against the
                `(queue_size, dim)` feature queue by dot product.

        Returns:
            A `(batch, dim)` tensor. With `self.nn` False it is the most
            similar queue row for each projection. With `self.nn` True it is
            step `self.index` of `self.points` evenly spaced points running
            from the projection to that row, optionally perturbed by Gaussian
            noise when `self.noise` is set.
        """
        support_similarities = tf.matmul(
            projection, self.feature_queue, transpose_b=True
        )
        nn_projection = tf.gather(
            self.feature_queue, tf.argmax(support_similarities, axis=1), axis=0
        )

        if not self.nn:
            return nn_projection

        new_projections = tf.linspace(projection, nn_projection, num=self.points)
        new_projection = new_projections[self.index]

        if self.noise:
            # Noise is added directly: a GaussianNoise layer only perturbs its
            # input when called in training mode, which this call site never
            # requested, so the option previously had no effect.
            new_projection = self.gaussian_noise_layer(
                new_projection, std=self.noise_val
            )

        return new_projection

    def update_contrastive_accuracy(self, features_1, features_2):
        """Update the accuracy of matching each sample to its other view."""
        features_1 = tf.math.l2_normalize(features_1, axis=1)
        features_2 = tf.math.l2_normalize(features_2, axis=1)
        similarities = tf.matmul(features_1, features_2, transpose_b=True)

        batch_size = tf.shape(features_1)[0]
        contrastive_labels = tf.range(batch_size)
        self.contrastive_accuracy.update_state(
            tf.concat([contrastive_labels, contrastive_labels], axis=0),
            tf.concat([similarities, tf.transpose(similarities)], axis=0),
        )

    def update_correlation_accuracy(self, features_1, features_2):
        """Update the accuracy of matching each feature dimension across views."""
        # Standardise every feature dimension over the batch.
        features_1 = (
            features_1 - tf.reduce_mean(features_1, axis=0)
        ) / tf.math.reduce_std(features_1, axis=0)
        features_2 = (
            features_2 - tf.reduce_mean(features_2, axis=0)
        ) / tf.math.reduce_std(features_2, axis=0)

        batch_size = tf.shape(features_1, out_type=tf.float32)[0]
        cross_correlation = (
            tf.matmul(features_1, features_2, transpose_a=True) / batch_size
        )

        feature_dim = tf.shape(features_1)[1]
        correlation_labels = tf.range(feature_dim)
        self.correlation_accuracy.update_state(
            tf.concat([correlation_labels, correlation_labels], axis=0),
            tf.concat([cross_correlation, tf.transpose(cross_correlation)], axis=0),
        )

    def contrastive_loss(self, projections_1, projections_2):
        """Compute the nearest-neighbour contrastive loss and refresh the queue.

        Args:
            projections_1: `(batch, dim)` projections of the first view; after
                normalisation they are also pushed to the front of the queue.
            projections_2: `(batch, dim)` projections of the second view.

        Returns:
            Per-example cross-entropy values for the four similarity matrices
            (length `4 * batch`), not reduced.
        """
        projections_1 = tf.math.l2_normalize(projections_1, axis=1)
        projections_2 = tf.math.l2_normalize(projections_2, axis=1)

        # Each lookup scans the whole queue, so it is done once per view and
        # reused for both directions.
        neighbours_1 = self.nearest_neighbour(projections_1)
        neighbours_2 = self.nearest_neighbour(projections_2)

        similarities_1_2_1 = (
            tf.matmul(neighbours_1, projections_2, transpose_b=True)
            / self.temperature
        )
        similarities_1_2_2 = (
            tf.matmul(projections_2, neighbours_1, transpose_b=True)
            / self.temperature
        )
        similarities_2_1_1 = (
            tf.matmul(neighbours_2, projections_1, transpose_b=True)
            / self.temperature
        )
        similarities_2_1_2 = (
            tf.matmul(projections_1, neighbours_2, transpose_b=True)
            / self.temperature
        )

        batch_size = tf.shape(projections_1)[0]
        contrastive_labels = tf.range(batch_size)
        loss = keras.losses.sparse_categorical_crossentropy(
            tf.concat(
                [
                    contrastive_labels,
                    contrastive_labels,
                    contrastive_labels,
                    contrastive_labels,
                ],
                axis=0,
            ),
            tf.concat(
                [
                    similarities_1_2_1,
                    similarities_1_2_2,
                    similarities_2_1_1,
                    similarities_2_1_2,
                ],
                axis=0,
            ),
            from_logits=True,
        )

        # Newest projections go to the front; the oldest rows fall off the end.
        self.feature_queue.assign(
            tf.concat([projections_1, self.feature_queue[:-batch_size]], axis=0)
        )
        return loss

    def train_step(self, data):
        """Run one contrastive update, one linear-probe update and the EMA step.

        Args:
            data: `((unlabeled_images, _), (labeled_images, labels))`.

        Returns:
            Dict with the contrastive loss, the contrastive / correlation /
            probe accuracies and the probe loss.
        """
        (unlabeled_images, _), (labeled_images, labels) = data
        images = tf.concat((unlabeled_images, labeled_images), axis=0)

        x1 = self.contrastive_augmenter(images)
        x2 = self.contrastive_augmenter(images)

        # The target branch is evaluated outside the tape: it is updated only
        # by the moving average below, never by gradients.
        h_target_1 = self.f_target(x1)
        z_target_1 = self.g_target(h_target_1)

        h_target_2 = self.f_target(x2)
        z_target_2 = self.g_target(h_target_2)

        with tf.GradientTape() as tape:
            h_online_1 = self.f_online(x1)
            z_online_1 = self.g_online(h_online_1)
            p_online_1 = self.q_online(z_online_1)

            h_online_2 = self.f_online(x2)
            z_online_2 = self.g_online(h_online_2)
            p_online_2 = self.q_online(z_online_2)

            loss = self.contrastive_loss(p_online_1, z_target_2)/2 + self.contrastive_loss(p_online_2, z_target_1)/2

        # Backward pass (update online networks). All online weights go through
        # a single apply_gradients call so the optimizer's step counter, and
        # any learning-rate schedule driven by it, advances once per step.
        online_weights = (
            self.f_online.trainable_weights
            + self.g_online.trainable_weights
            + self.q_online.trainable_weights
        )
        grads = tape.gradient(loss, online_weights)
        self.contrastive_optimizer.apply_gradients(zip(grads, online_weights))

        self.update_contrastive_accuracy(h_online_1, h_online_2)
        self.update_correlation_accuracy(h_online_1, h_online_2)
        preprocessed_images = self.classification_augmenter(labeled_images)

        # Only the probe's weights receive gradients here; the encoder is left
        # unchanged by the probe loss.
        with tf.GradientTape() as tape:
            features = self.f_online(preprocessed_images)
            class_logits = self.linear_probe(features)
            probe_loss = self.probe_loss(labels, class_logits)
        gradients = tape.gradient(probe_loss, self.linear_probe.trainable_weights)
        self.probe_optimizer.apply_gradients(
            zip(gradients, self.linear_probe.trainable_weights)
        )
        self.probe_accuracy.update_state(labels, class_logits)

        # the momentum networks are updated by exponential moving average
        for weight, m_weight in zip(self.f_online.weights, self.f_target.weights):
            m_weight.assign(
                self.t * m_weight + (1 - self.t) * weight
            )
        for weight, m_weight in zip(
            self.g_online.weights, self.g_target.weights
        ):
            m_weight.assign(
                self.t * m_weight + (1 - self.t) * weight
            )

        return {
            "c_loss": loss,
            "c_acc": self.contrastive_accuracy.result(),
            "r_acc": self.correlation_accuracy.result(),
            "p_loss": probe_loss,
            "p_acc": self.probe_accuracy.result(),
        }

    def test_step(self, data):
        """Evaluate the linear probe on one batch of `(images, labels)`."""
        labeled_images, labels = data

        preprocessed_images = self.classification_augmenter(
            labeled_images, training=False
        )
        features = self.f_online(preprocessed_images, training=False)
        class_logits = self.linear_probe(features, training=False)
        probe_loss = self.probe_loss(labels, class_logits)

        self.probe_accuracy.update_state(labels, class_logits)
        return {"p_loss": probe_loss, "p_acc": self.probe_accuracy.result()}




In [None]:

def ssl_method(name = "SimCLR"):
    """Instantiate the algorithm registered under `name` with fresh sub-networks.

    Args:
        name: key into the `ssl` registry and the `hyperparams` table.

    Returns:
        The constructed (uncompiled) model.
    """
    algorithm_cls = ssl[name]

    # architecture
    components = dict(
        contrastive_augmenter=augmenter2(**contrastive_augmenter),
        classification_augmenter=augmenter2(**classification_augmenter),
        encoder=get_encoder2(),
        projection_head=keras.Sequential(
            [
                layers.Input(shape=(PROJECT_DIM,)),
                layers.Dense(PROJECT_DIM, activation="relu"),
                layers.Dense(PROJECT_DIM),
            ],
            name="projection_head",
        ),
        linear_probe=keras.Sequential(
            [
                layers.Input(shape=(PROJECT_DIM,)),
                layers.Dense(10),
            ],
            name="linear_probe",
        ),
    )

    # Algorithm-specific settings are appended after the shared components.
    return algorithm_cls(**components, **hyperparams[name])





In [None]:
def train_ssl(model,train_ssl,val_data):
    """Compile `model` with cosine-decayed SGD optimizers and fit it.

    Args:
        model: a model whose `compile` accepts `contrastive_optimizer` and
            `probe_optimizer`.
        train_ssl: training dataset passed to `model.fit`. Inside this
            function the parameter shadows the function's own name.
        val_data: dataset passed to `model.fit` as `validation_data`.

    Returns:
        `(history, model)` where `history` is the value returned by `model.fit`.
    """
    # Length of the schedule: number of labelled batches per epoch times the
    # number of epochs, taken from the notebook's configuration constants.
    steps = epochs * (labelled_train_images // batch_size)
    lr_decayed_fn = tf.keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=l, decay_steps=steps
    )

    # Each optimizer keeps its own step counter; they share only the schedule.
    model.compile(
        contrastive_optimizer= tf.keras.optimizers.legacy.SGD(lr_decayed_fn, momentum=0.9),
        probe_optimizer= tf.keras.optimizers.legacy.SGD(lr_decayed_fn, momentum=0.9),
    )

    # run training
    history = model.fit(train_ssl, epochs=epochs, validation_data=val_data )

    return history, model

In [None]:


# Build and train the SimCLR baseline. `history` is reassigned by the later
# training cells, so it only reflects the most recently executed run.
model_SimCLR =  ssl_method(name = "SimCLR")
history ,model_SimCLR = train_ssl(model_SimCLR,train_dataset,test_dataset)
# model_SimCLR.encoder.save_weights('required path')



In [None]:
# Build and train the NNCLR baseline on the same datasets; this overwrites
# the `history` produced by the previous training cell.
model_NNCLR=  ssl_method(name = "NNCLR")
history ,model_NNCLR = train_ssl(model_NNCLR,train_dataset,test_dataset)
# model_NNCLR.encoder.save_weights('required path')


In [None]:
# Build PNNCLR with the shared `temperature` / `queue_size` settings; all other
# constructor arguments keep their defaults.
model_PNNCLR = PNNCLR(temperature=temperature, queue_size=queue_size)


# Train with the same routine as the baselines; this overwrites `history`.
history ,model_PNNCLR = train_ssl(model_PNNCLR,train_dataset,test_dataset)
# model_PNNCLR.f_online.save_weights('required path')


