# Semi-supervised image classification using contrastive pretraining with SimCLR

## Setup

In [None]:
import datetime
import os

os.environ["KERAS_BACKEND"] = "tensorflow"


# Raise the soft limit on open file descriptors up to the hard limit
# (presumably so tf.data can keep many image files open while reading
# large directories).
import resource

low, high = resource.getrlimit(resource.RLIMIT_NOFILE)
# Pin both the soft and the hard limit to the previous hard maximum.
resource.setrlimit(resource.RLIMIT_NOFILE, (high,) * 2)

import math

import keras

%matplotlib widget
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf
from keras import layers
from keras import optimizers as ops
from sklearn.metrics import classification_report, confusion_matrix

import tensorflow_datasets as tfds

# Load the TensorBoard notebook extension
%load_ext tensorboard

# Set seed for evaluation purpose (remove in production)
# Every RNG the notebook touches is seeded with 5 so runs are reproducible.
# NOTE(review): keras.utils.set_random_seed presumably already seeds Python's
# `random`, NumPy and TensorFlow, which would make the explicit np/tf calls
# below redundant (harmless) — confirm against the Keras docs.
keras.utils.set_random_seed(5)
# NOTE(review): PYTHONHASHSEED is normally read only at interpreter start-up, so
# setting it here presumably affects child processes only — confirm.
os.environ["PYTHONHASHSEED"] = str(5)
# Request deterministic cuDNN kernels (environment flag presumably read by
# TensorFlow's GPU backend).
os.environ["TF_CUDNN_DETERMINISTIC"] = "1"
np.random.seed(5)
tf.random.set_seed(5)
# Force deterministic TensorFlow op implementations (may cost speed).
tf.config.experimental.enable_op_determinism()

## Hyperparameter setup

In [None]:
# Dataset hyperparameters
# Totals are the number of files found recursively under each directory
# (every file is counted, image or not).
unlabeled_dataset_size = sum(
    len(file_names) for _, _, file_names in os.walk("../data_ssl/unlabeled/")
)
labeled_dataset_size = sum(
    len(file_names) for _, _, file_names in os.walk("../data_ssl/train/")
)
# Spatial size every image is resized to when loaded.
img_height = 224
img_width = 224
# Number of filters / units used by the encoder and the projection head.
width = 224

print(f"Unlabeled Images: {unlabeled_dataset_size}")
print(f"Labeled Images: {labeled_dataset_size}")

# Algorithm hyperparameters
num_epochs = 30
batch_size = 20
temperature = 0.1  # divides the cosine similarities in the contrastive loss

# Stronger augmentations for contrastive, weaker ones for supervised training
contrastive_augmentation = dict(min_area=0.3, brightness=0.6, jitter=0.2)
classification_augmentation = dict(min_area=0.75, brightness=0.3, jitter=0.1)

## Dataset

In [None]:
def prepare_dataset():
    """Load the labeled, unlabeled, validation and test image datasets.

    The labeled training split and the unlabeled images are zipped together so
    that every contrastive step sees one unlabeled batch plus one labeled batch.
    The two batch sizes are chosen proportionally to the number of images in
    each stream. Both streams then span (about) the same number of steps per
    epoch, and the combined batch is roughly ``batch_size`` images.

    Returns:
        tuple: ``(train_dataset, labeled_train_ds, test_dataset, val_ds)``.

        - ``train_dataset`` yields ``(unlabeled_images, (labeled_images, labels))``.
        - The other three datasets are batched with ``batch_size``.
    """
    labeled_train_ds, val_ds = tf.keras.utils.image_dataset_from_directory(
        "../data_ssl/train/",
        validation_split=0.2,
        subset="both",
        seed=5,
        image_size=(img_height, img_width),
        batch_size=batch_size,
        shuffle=True,
        label_mode="int",
    )

    unlabeled_train_ds = tf.keras.utils.image_dataset_from_directory(
        "../data_ssl/unlabeled/",
        label_mode=None,
        seed=5,
        image_size=(img_height, img_width),
        batch_size=batch_size,
        shuffle=True,
    )

    test_dataset = tf.keras.utils.image_dataset_from_directory(
        "../data_ssl/test_labeled/",
        seed=5,
        image_size=(img_height, img_width),
        batch_size=batch_size,
        shuffle=True,
        color_mode="rgb",
        label_mode="int",
    )

    # Count the images that were actually loaded. Only the training subset of
    # the labeled directory is used here, and non-image files are skipped, so
    # these counts can be lower than the raw os.walk() file totals.
    num_labeled_train = len(labeled_train_ds.file_paths)
    num_unlabeled = len(unlabeled_train_ds.file_paths)

    # Labeled and unlabeled samples are loaded synchronously, with batch sizes
    # selected so both streams last for the same number of steps per epoch.
    # max(1, ...) avoids a zero step count (ZeroDivisionError) or a zero batch
    # size when one stream is much smaller than the other.
    steps_per_epoch = max(1, (num_unlabeled + num_labeled_train) // batch_size)
    unlabeled_batch_size = max(1, num_unlabeled // steps_per_epoch)
    labeled_batch_size = max(1, num_labeled_train // steps_per_epoch)
    print(
        f"batch size is {unlabeled_batch_size} (unlabeled) + {labeled_batch_size} (labeled)"
    )

    # Only the zipped contrastive stream is re-batched. labeled_train_ds keeps
    # batch_size for the supervised baseline / finetuning runs. Without the
    # re-batching, zip() would stop at the shorter stream and drop most
    # unlabeled images every epoch.
    train_dataset = tf.data.Dataset.zip(
        (
            unlabeled_train_ds.unbatch().batch(unlabeled_batch_size),
            labeled_train_ds.unbatch().batch(labeled_batch_size),
        )
    ).prefetch(buffer_size=tf.data.AUTOTUNE)

    return train_dataset, labeled_train_ds, test_dataset, val_ds


# Load dataset
train_dataset, labeled_train_dataset, test_dataset, validation_dataset = (
    prepare_dataset()
)

## Image augmentations

In [None]:
# Distorts the color distributions of images
class RandomColorAffine(layers.Layer):
    """Randomly rescale the brightness of each image in a batch.

    During training, each image is multiplied by its own factor drawn uniformly
    from ``[1 - brightness, 1 + brightness]``. The factor is shared by all three
    channels, and the result is clipped to ``[0, 1]``. At inference the images
    are returned unchanged.

    Args:
        brightness: Maximum relative brightness change.
        jitter: Strength of a per-channel color-jitter matrix. The jitter term
            is intentionally disabled, so this value is only stored and
            serialized through ``get_config``.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, brightness=0, jitter=0, **kwargs):
        super().__init__(**kwargs)

        self.brightness = brightness
        self.jitter = jitter

    def get_config(self):
        config = super().get_config()
        config.update({"brightness": self.brightness, "jitter": self.jitter})
        return config

    def call(self, images, training=True):
        # Identity outside of training.
        if not training:
            return images

        batch_size = tf.shape(images)[0]

        # One scale per image, broadcast over height, width and all colors.
        brightness_scales = 1 + tf.random.uniform(
            (batch_size, 1, 1, 1),
            minval=-self.brightness,
            maxval=self.brightness,
        )

        # Per-image 3x3 color transform: a scaled identity. The per-channel
        # jitter matrix was deliberately removed from the sum, so it is no
        # longer sampled (sampling it only wasted compute and RNG draws).
        color_transforms = tf.eye(3, batch_shape=[batch_size, 1]) * brightness_scales

        # Expects channels-last RGB images already rescaled to [0, 1]
        # (get_augmenter places Rescaling(1 / 255) before this layer).
        return tf.clip_by_value(tf.matmul(images, color_transforms), 0, 1)


# Image augmentation module
def get_augmenter(min_area, brightness, jitter):
    zoom_factor = 1.0 - math.sqrt(min_area)
    return keras.Sequential(
        [
            layers.Rescaling(1 / 255),
            layers.RandomFlip("horizontal"),
            layers.RandomTranslation(zoom_factor / 2, zoom_factor / 2),
            layers.RandomZoom((-zoom_factor, 0.0), (-zoom_factor, 0.0)),
            RandomColorAffine(brightness, jitter),
        ]
    )


def get_augmenter_without_pooling(min_area, brightness, jitter):
    """Return a fresh augmentation pipeline identical to ``get_augmenter``'s.

    This function used to be a line-for-line copy of ``get_augmenter``, and
    despite its name it had no pooling-related difference. It now delegates so
    the two cannot drift apart. It is kept because ``visualize_augmentations``
    calls it.
    """
    return get_augmenter(min_area, brightness, jitter)


def visualize_augmentations(num_images):
    """Plot up to ``num_images`` images next to three augmented versions.

    Row 1 shows the rescaled originals and row 2 one weak (classification)
    augmentation. Rows 3-4 show two independent strong (contrastive)
    augmentations. Images come from the unlabeled half of the first
    ``train_dataset`` element.
    """
    # Each train_dataset element is (unlabeled_images, (labeled_images, labels)).
    unlabeled_batch, _ = next(iter(train_dataset))
    images = unlabeled_batch[:num_images]

    # One entry per plotted row; built left to right in the same order as before.
    rows = [
        keras.Sequential([layers.Rescaling(1 / 255)])(images),
        get_augmenter_without_pooling(**classification_augmentation)(images),
        get_augmenter_without_pooling(**contrastive_augmentation)(images),
        get_augmenter_without_pooling(**contrastive_augmentation)(images),
    ]
    row_titles = [
        "Original:",
        "Weakly augmented:",
        "Strongly augmented:",
        "Strongly augmented:",
    ]

    plt.figure(figsize=(num_images * 1.3, 4 * 1.3), dpi=100)
    # Outer loop over images (columns), inner loop over versions (rows).
    for column, versions in enumerate(zip(*rows)):
        for row, (title, image) in enumerate(zip(row_titles, versions)):
            plt.subplot(4, num_images, row * num_images + column + 1)
            plt.imshow(image)
            if column == 0:
                plt.title(title, loc="left")
            plt.axis("off")
    plt.tight_layout()
    plt.show()


visualize_augmentations(num_images=8)

## Encoder architecture

In [None]:
# Load the pretrained models
# ImageNet-pretrained EfficientNetV2S backbone with global average pooling and
# no classification head, marked fully trainable.
# NOTE(review): `base_finetuned_model` is not referenced anywhere else in the
# visible notebook (get_encoder builds its own small CNN), so this download and
# model build looks unused — consider removing it.
base_finetuned_model = tf.keras.applications.efficientnet_v2.EfficientNetV2S(
    input_shape=(224, 224, 3), include_top=False, weights="imagenet", pooling="avg"
)
base_finetuned_model.trainable = True

# Define encoder
def get_encoder():
    """Build the small convolutional encoder used for pretraining and finetuning.

    Architecture:
        - Input: 224x224x3 images.
        - Four 3x3, stride-2 ReLU convolutions with ``width`` filters each.
        - Flatten, then a ``width``-unit ReLU dense layer.

    The output is a feature vector of size ``width``.
    """
    # Layers are created in the same order as they appear in the model.
    stack = [layers.Input(shape=(224, 224, 3))]
    stack += [
        layers.Conv2D(width, kernel_size=3, strides=2, activation="relu")
        for _ in range(4)
    ]
    stack += [layers.Flatten(), layers.Dense(width, activation="relu")]
    return keras.Sequential(stack, name="encoder")

## Supervised baseline model

In [None]:
# Baseline supervised training: an ImageNet-pretrained EfficientNetV2S backbone
# (weights="imagenet") kept frozen (trainable=False) with a small trainable MLP
# head on top. Only the head's weights are learned from the labeled split.
# NOTE(review): images from labeled_train_dataset are fed unscaled ([0, 255]);
# this presumably relies on EfficientNetV2's built-in preprocessing — confirm.
pretrained_model = tf.keras.applications.efficientnet_v2.EfficientNetV2S(
    input_shape=(224, 224, 3), include_top=False, weights="imagenet", pooling="avg"
)
pretrained_model.trainable = False

# Classification head on the pooled backbone features; 7 output classes.
inputs = pretrained_model.input
x = tf.keras.layers.Dense(128, activation="relu")(pretrained_model.output)
x = tf.keras.layers.Dense(128, activation="relu")(x)
x = tf.keras.layers.Dropout(0.4)(x)
x = tf.keras.layers.Dense(64, activation="relu")(x)
# Softmax output -> probabilities, hence from_logits=False in the loss below.
outputs = tf.keras.layers.Dense(7, activation="softmax")(x)
baseline_model = tf.keras.Model(
    inputs,
    outputs,
    name="baseline_model",
)

# Compile baseline_model
baseline_model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
)

# Each run logs to its own timestamped sub-directory of "logs".
logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
callbacks = [
    # Stop after 5 epochs without val_loss improvement and keep the best weights.
    tf.keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=5, restore_best_weights=True
    ),
    tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1),
]

# Fit baseline_model
baseline_history = baseline_model.fit(
    labeled_train_dataset,
    epochs=num_epochs,
    validation_data=validation_dataset,
    callbacks=callbacks,
)

# "val_acc" comes from the metric named "acc" above.
print(
    "Maximal validation accuracy: {:.2f}%".format(
        max(baseline_history.history["val_acc"]) * 100
    )
)

## Self-supervised model for contrastive pretraining

In [None]:
# Define the contrastive model (SimCLR) with model-subclassing
class ContrastiveModel(keras.Model):
    """SimCLR-style contrastive pretraining model with an online linear probe.

    Components:
        - ``encoder``: from ``get_encoder()``.
        - ``projection_head``: a two-layer MLP.
        - ``linear_probe``: a single dense layer with 7 outputs.

    Each ``train_step`` performs two updates:
        1. A contrastive update of encoder + projection head, using two strongly
           augmented views of all (unlabeled + labeled) images.
        2. A linear-probe update on weakly augmented labeled images, with the
           encoder run in inference mode.

    Reads the module-level settings ``temperature``, ``width``,
    ``contrastive_augmentation`` and ``classification_augmentation``.
    """

    def __init__(self):
        super().__init__()

        # Scales the cosine similarities before the softmax in contrastive_loss.
        self.temperature = temperature
        self.contrastive_augmenter = get_augmenter(**contrastive_augmentation)
        self.classification_augmenter = get_augmenter(**classification_augmentation)
        self.encoder = get_encoder()
        # Non-linear MLP as projection head
        self.projection_head = keras.Sequential(
            [
                keras.Input(shape=(width,)),
                layers.Dense(width, activation="relu"),
                layers.Dense(width),
            ],
            name="projection_head",
        )
        # Single dense layer for linear probing (outputs raw logits for 7 classes)
        self.linear_probe = keras.Sequential(
            [layers.Input(shape=(width,)), layers.Dense(7)],
            name="linear_probe",
        )

        self.encoder.summary()
        self.projection_head.summary()
        self.linear_probe.summary()

    def compile(self, contrastive_optimizer, probe_optimizer, **kwargs):
        """Store the two optimizers and create the loss/metric objects.

        Args:
            contrastive_optimizer: Updates the encoder and projection head.
            probe_optimizer: Updates the linear probe.
            **kwargs: Forwarded to ``keras.Model.compile``.
        """
        super().compile(**kwargs)

        self.contrastive_optimizer = contrastive_optimizer
        self.probe_optimizer = probe_optimizer

        # self.contrastive_loss will be defined as a method
        # The probe outputs raw logits (no activation), hence from_logits=True.
        self.probe_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

        self.contrastive_loss_tracker = keras.metrics.Mean(name="c_loss")
        self.contrastive_accuracy = keras.metrics.SparseCategoricalAccuracy(
            name="c_acc"
        )
        self.probe_loss_tracker = keras.metrics.Mean(name="p_loss")
        self.probe_accuracy = keras.metrics.SparseCategoricalAccuracy(name="p_acc")

    @property
    def metrics(self):
        """Metrics tracked by this model.

        Order matters: ``test_step`` reports only ``self.metrics[2:]``, i.e. the
        two probe metrics. Listing them here presumably also lets Keras reset
        them automatically between epochs.
        """
        return [
            self.contrastive_loss_tracker,
            self.contrastive_accuracy,
            self.probe_loss_tracker,
            self.probe_accuracy,
        ]

    def contrastive_loss(self, projections_1, projections_2):
        """Symmetrized NT-Xent loss between two batches of projections.

        Row ``i`` of ``projections_1`` and row ``i`` of ``projections_2`` are
        the two views of the same image (the positive pair). All other rows act
        as negatives. The method also updates ``self.contrastive_accuracy`` in
        both directions.

        Args:
            projections_1: Projection-head outputs for the first views,
                presumably shaped (batch, width).
            projections_2: Projection-head outputs for the second views, same
                shape as ``projections_1``.

        Returns:
            The per-example loss vector, averaged over both directions
            (1 -> 2 and 2 -> 1).
        """
        # InfoNCE loss (information noise-contrastive estimation)
        # NT-Xent loss (normalized temperature-scaled cross entropy)

        # Cosine similarity: the dot product of the l2-normalized feature vectors
        projections_1 = tf.math.l2_normalize(projections_1, axis=1)
        projections_2 = tf.math.l2_normalize(projections_2, axis=1)
        similarities = (
            tf.matmul(projections_1, projections_2, transpose_b=True) / self.temperature
        )

        # The similarity between the representations of two augmented views of the
        # same image should be higher than their similarity with other views,
        # so the "correct class" of row i is column i.
        batch_size = tf.shape(projections_1)[0]
        contrastive_labels = tf.range(batch_size)
        self.contrastive_accuracy.update_state(contrastive_labels, similarities)
        self.contrastive_accuracy.update_state(
            contrastive_labels, tf.transpose(similarities)
        )

        # The temperature-scaled similarities are used as logits for cross-entropy
        # a symmetrized version of the loss is used here
        loss_1_2 = keras.losses.sparse_categorical_crossentropy(
            contrastive_labels, similarities, from_logits=True
        )
        loss_2_1 = keras.losses.sparse_categorical_crossentropy(
            contrastive_labels, tf.transpose(similarities), from_logits=True
        )
        return (loss_1_2 + loss_2_1) / 2

    def train_step(self, data):
        """Run one contrastive update followed by one linear-probe update.

        Args:
            data: ``(unlabeled_images, (labeled_images, labels))``, as yielded
                by the zipped ``train_dataset``.

        Returns:
            dict mapping each metric name in ``self.metrics`` to its current value.
        """
        (unlabeled_images), (labeled_images, labels) = data

        # Both labeled and unlabeled images are used, without labels
        images = tf.concat((unlabeled_images, labeled_images), axis=0)
        # Each image is augmented twice, differently
        augmented_images_1 = self.contrastive_augmenter(images, training=True)
        augmented_images_2 = self.contrastive_augmenter(images, training=True)
        with tf.GradientTape() as tape:
            features_1 = self.encoder(augmented_images_1, training=True)
            features_2 = self.encoder(augmented_images_2, training=True)
            # The representations are passed through a projection mlp
            projections_1 = self.projection_head(features_1, training=True)
            projections_2 = self.projection_head(features_2, training=True)
            contrastive_loss = self.contrastive_loss(projections_1, projections_2)
        # Only the encoder and projection head receive contrastive gradients.
        gradients = tape.gradient(
            contrastive_loss,
            self.encoder.trainable_weights + self.projection_head.trainable_weights,
        )
        self.contrastive_optimizer.apply_gradients(
            zip(
                gradients,
                self.encoder.trainable_weights + self.projection_head.trainable_weights,
            )
        )
        self.contrastive_loss_tracker.update_state(contrastive_loss)

        # Labels are only used in evaluation for an on-the-fly logistic regression
        preprocessed_images = self.classification_augmenter(
            labeled_images, training=True
        )
        with tf.GradientTape() as tape:
            # the encoder is used in inference mode here to avoid regularization
            # and updating the batch normalization parameters if they are used
            features = self.encoder(preprocessed_images, training=False)
            class_logits = self.linear_probe(features, training=True)
            probe_loss = self.probe_loss(labels, class_logits)
        # Only the linear probe is updated by the probe loss.
        gradients = tape.gradient(probe_loss, self.linear_probe.trainable_weights)
        self.probe_optimizer.apply_gradients(
            zip(gradients, self.linear_probe.trainable_weights)
        )
        self.probe_loss_tracker.update_state(probe_loss)
        self.probe_accuracy.update_state(labels, class_logits)

        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        """Evaluate the linear probe on a labeled batch.

        Args:
            data: ``(labeled_images, labels)``, as yielded by the validation
                dataset.

        Returns:
            dict with the probe metrics only (``p_loss`` and ``p_acc``).
        """
        labeled_images, labels = data

        # For testing the components are used with a training=False flag
        preprocessed_images = self.classification_augmenter(
            labeled_images, training=False
        )
        features = self.encoder(preprocessed_images, training=False)
        class_logits = self.linear_probe(features, training=False)
        probe_loss = self.probe_loss(labels, class_logits)
        self.probe_loss_tracker.update_state(probe_loss)
        self.probe_accuracy.update_state(labels, class_logits)

        # Only the probe metrics are logged at test time
        return {m.name: m.result() for m in self.metrics[2:]}


# Contrastive pretraining
# Separate optimizers for the contrastive part (encoder + projection head)
# and for the linear probe; see ContrastiveModel.compile.
pretraining_model = ContrastiveModel()
pretraining_model.compile(
    contrastive_optimizer=keras.optimizers.Adam(),
    probe_optimizer=keras.optimizers.Adam(),
)

# Each run logs to its own timestamped sub-directory of "logs".
logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
callbacks = [
    # "val_p_loss" is the probe loss reported by ContrastiveModel.test_step.
    tf.keras.callbacks.EarlyStopping(
        monitor="val_p_loss", patience=5, restore_best_weights=True
    ),
    tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1),
]

# Fit pretraining_model on the zipped (unlabeled, labeled) stream; validation
# uses only the labeled validation split through the linear probe.
pretraining_history = pretraining_model.fit(
    train_dataset,
    epochs=num_epochs,
    validation_data=validation_dataset,
    callbacks=callbacks,
)

print(
    "Maximal validation accuracy: {:.2f}%".format(
        max(pretraining_history.history["val_p_acc"]) * 100
    )
)

## Supervised finetuning of the pretrained encoder

In [None]:
# Supervised finetuning of the pretrained encoder.
# The encoder was pretrained on images rescaled to [0, 1] by the augmenters
# inside ContrastiveModel, so the same weak classification augmenter is
# prepended here. It always applies Rescaling(1 / 255) and applies its random
# flip/translate/zoom/brightness changes only during training, so evaluation
# and prediction see the same input range as pretraining.
finetuning_model = keras.Sequential(
    [
        layers.Input(shape=(img_height, img_width, 3)),
        get_augmenter(**classification_augmentation),
        pretraining_model.encoder,
        layers.Dense(7),  # raw logits, one per class (no softmax)
    ],
    name="finetuning_model",
)

# Compile finetuning_model. The final Dense layer has no activation, so the
# loss must treat its outputs as logits.
finetuning_model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
)

# Each run logs to its own timestamped sub-directory of "logs".
logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
callbacks = [
    tf.keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=5, restore_best_weights=True
    ),
    tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1),
]

# Fit finetuning_model
finetuning_history = finetuning_model.fit(
    labeled_train_dataset,
    epochs=num_epochs,
    validation_data=validation_dataset,
    callbacks=callbacks,
)

# Save model for later use. The model contains the custom RandomColorAffine
# layer, so reloading it needs
# keras.models.load_model(path, custom_objects={"RandomColorAffine": RandomColorAffine}).
finetuning_model.save("ssl_model_finetuned.keras")

print(
    "Maximal validation accuracy: {:.2f}%".format(
        max(finetuning_history.history["val_acc"]) * 100
    )
)

## Comparison against the baseline

In [None]:
# The classification accuracies of the baseline and the pretraining + finetuning process:
def plot_training_curves(pretraining_history, finetuning_history, baseline_history):
    """Plot per-epoch validation accuracy and loss of the three training runs.

    One figure is drawn per metric. The pretraining run is read from its
    linear-probe metrics (``val_p_acc`` / ``val_p_loss``). The other two runs
    use ``val_acc`` / ``val_loss``.
    """
    metric_names = {"acc": "accuracy", "loss": "loss"}
    for metric_key, metric_name in metric_names.items():
        # (history object, history key, legend label), drawn in this order.
        curves = (
            (baseline_history, f"val_{metric_key}", "supervised baseline"),
            (pretraining_history, f"val_p_{metric_key}", "self-supervised pretraining"),
            (finetuning_history, f"val_{metric_key}", "supervised finetuning"),
        )
        plt.figure(figsize=(8, 5), dpi=100)
        for history, history_key, curve_label in curves:
            plt.plot(history.history[history_key], label=curve_label)
        plt.legend()
        plt.title(f"Classification {metric_name} during training")
        plt.xlabel("epochs")
        plt.ylabel(f"validation {metric_name}")
        plt.show()


plot_training_curves(pretraining_history, finetuning_history, baseline_history)

In [None]:
# Evaluate the base model and the finetuned model on the held-out test split.
# evaluate() returns [loss, acc] for these models, so index 1 is the accuracy.
results_baseline, results_finetuned = (
    model.evaluate(test_dataset, verbose=0)
    for model in (baseline_model, finetuning_model)
)
print(f"Baseline Test Accuracy: {np.round(results_baseline[1] * 100, 2)}%")
print(f"Finetuned Test Accuracy: {np.round(results_finetuned[1] * 100, 2)}%")

In [None]:
# Class names: the sorted sub-directory names of the training folder. These
# should match the label-index order used by image_dataset_from_directory.
folder = "../data_ssl/train/"
labels = sorted(
    [name for name in os.listdir(folder) if os.path.isdir(os.path.join(folder, name))]
)

y_pred_baseline = []  # predicted class indices, one array per batch
y_pred_finetuned = []  # predicted class indices, one array per batch
y_true = []  # true class indices, one array per batch

# test_dataset was built with shuffle=True, so iterating it a second time would
# presumably yield a different order. Predicting inside this single pass keeps
# labels and predictions aligned. predict_on_batch runs one forward pass on the
# given tensor without building a new data pipeline per call, unlike predict().
for image_batch, label_batch in test_dataset:
    y_true.append(label_batch.numpy())
    y_pred_baseline.append(
        np.argmax(baseline_model.predict_on_batch(image_batch), axis=-1)
    )
    y_pred_finetuned.append(
        np.argmax(finetuning_model.predict_on_batch(image_batch), axis=-1)
    )

# Flatten the per-batch arrays into single label vectors.
correct_labels = np.concatenate(y_true)
predicted_labels_baseline = np.concatenate(y_pred_baseline)
predicted_labels_finetuned = np.concatenate(y_pred_finetuned)

# Pass the full class list explicitly. Otherwise a class that appears in
# neither the true nor the predicted labels would shrink the confusion matrix,
# and classification_report would reject the len(labels) target_names.
class_indices = np.arange(len(labels))

# Generate reports
matrix_baseline = confusion_matrix(
    correct_labels, predicted_labels_baseline, labels=class_indices
)
matrix_finetuned = confusion_matrix(
    correct_labels, predicted_labels_finetuned, labels=class_indices
)
report_baseline = classification_report(
    correct_labels,
    predicted_labels_baseline,
    labels=class_indices,
    target_names=labels,
    zero_division=0,
)
report_finetuned = classification_report(
    correct_labels,
    predicted_labels_finetuned,
    labels=class_indices,
    target_names=labels,
    zero_division=0,
)

In [None]:
# Plot confusion matrix
def _plot_confusion_matrix(matrix, title, class_names):
    """Draw ``matrix`` as an annotated heatmap with class names on both axes.

    Rows are actual classes and columns are predicted classes. Returns the new
    matplotlib figure.
    """
    figure = plt.figure(figsize=(10, 10))
    sns.heatmap(matrix, annot=True, cmap="viridis", fmt="g")
    # Place each tick at the centre of its heatmap cell; derived from the number
    # of classes instead of a hard-coded 7.
    tick_positions = np.arange(len(class_names)) + 0.5
    plt.xticks(ticks=tick_positions, labels=class_names, rotation=90)
    plt.yticks(ticks=tick_positions, labels=class_names, rotation=0)
    plt.title(title)
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.show()
    return figure


fig = _plot_confusion_matrix(matrix_baseline, "Confusion Matrix (Baseline)", labels)
fig = _plot_confusion_matrix(matrix_finetuned, "Confusion Matrix (Finetuned)", labels)

In [None]:
# Print classification report for each model (text tables from sklearn).
for header, report_text in (
    ("Classification Report (Baseline):\n", report_baseline),
    ("Classification Report (Finetuned):\n", report_finetuned),
):
    print(header, report_text)

In [None]:
# Load TensorBoard
# IPython magic: shows TensorBoard for the "logs" directory, where each training
# cell above writes its TensorBoard callback output into a timestamped sub-folder.
%tensorboard --logdir logs