In [1]:
import tensorflow as tf
import re
import yaml
from tqdm import tqdm
from sklearn.linear_model import LogisticRegression
import numpy as np
from model import ResNetSimCLR
from augmentation.augment_helper import crop_and_resize_and_flip, color_distort, rotate, cutout, gaussian_noise, crop_and_resize
from tensorflow.keras.datasets import cifar100
from tensorflow.keras.layers import *
from tensorflow.keras.models import *

In [2]:
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

Num GPUs Available:  1


In [3]:
@tf.function
def augment_image(image, training=True):
    """
    Apply a series of augmentations for SimCLR suitable for CIFAR-100.

    Args:
    - image: An image tensor.
    - training: Flag to indicate if it's training phase.

    Returns:
    - Augmented image tensor.
    """
    if training:
        # Apply a series of random transformations for training
        # Resize slightly larger for augmentation and then resize back
        image = tf.image.resize(image, [40, 40])  # Resize to 40x40 for augmentation
        # print(image.shape)
        image= crop_and_resize(image, 32, 32)  # Resize back to 32x32 after augmentation
        print(image.shape)
        image = color_distort(image)
        print(image.shape)
        # image = rotate(image)
        # print(image.shape)
        # image = cutout(image, 10, 3)  # Adjusted patch size for CIFAR-100
        # print(image.shape)
        # image = gaussian_noise(image)
        # print(image.shape)
    else:
        # For testing, resize to the original CIFAR-100 size without further augmentation
        image = tf.image.resize(image, [32, 32])
    return image

In [4]:
@tf.function
def preprocess_for_simclr(image, label, training=True):
    """Preprocesses and applies augmentation for SimCLR."""
    image = tf.cast(image, tf.float32) / 255.0  # Normalize the image
    if training:
        # Create two augmented versions of the image for training
        augmented_image1 = augment_image(image, training=True)
        augmented_image2 = augment_image(image, training=True)
        return (augmented_image1, augmented_image2), label
    else:
        # For testing, resize to the original CIFAR-100 size without further augmentation
        augmented_image = tf.image.resize(image, [32, 32])
        return augmented_image, label
    
def load_dataset(images, labels, batch_size, training=True):
    """Creates a tf.data.Dataset object for CIFAR-100 with SimCLR augmentations."""
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    dataset = dataset.map(lambda x, y: preprocess_for_simclr(x, y, training), 
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)
    if training:
        dataset = dataset.shuffle(1024)
    dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
    return dataset


In [5]:
def contrastive_loss(batch_size, temperature=0.1):
    def loss_fn(z_i, z_j):
        """
        Calculate the NT-Xent loss.

        Parameters:
        - z_i, z_j: The outputs from the two augmented views of the images, with shapes (batch_size, feature_dim).

        Returns:
        - Scalar loss value.
        """
        batch_size = tf.shape(z_i)[0]
        z = tf.concat([z_i, z_j], axis=0)
        z = tf.math.l2_normalize(z, axis=1)
        
        similarity_matrix = tf.matmul(z, z, transpose_b=True)
        similarity_matrix = similarity_matrix / temperature

        labels = tf.range(batch_size)
        labels = tf.concat([labels, labels], axis=0)

        # Create a mask to remove positive samples from the diagonals
        mask = tf.one_hot(labels, 2 * batch_size)
        logits_mask = tf.logical_not(tf.eye(2 * batch_size, dtype=tf.bool))
        masked_similarity_matrix = tf.boolean_mask(similarity_matrix, logits_mask)

        # Reshape the masked similarity matrix to the correct shape for cross-entropy
        masked_similarity_matrix = tf.reshape(masked_similarity_matrix, (2 * batch_size, -1))

        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=masked_similarity_matrix, labels=labels)
        loss = tf.reduce_mean(loss)

        return loss
    return loss_fn


In [6]:
# Load configuration from YAML file
config = yaml.load(open('config.yaml', 'r'), Loader=yaml.FullLoader)
batch_size = config['batch_size']

# Prepare the training and  dataset
(train_images, train_labels), (test_images, test_labels) = cifar100.load_data()
(new_train_images, new_train_labels), (valid_images, valid_labels) = split_train_validation(train_images, train_labels)
train_dataset = load_dataset(new_train_images, new_train_labels, batch_size, training=True)
validation_dataset = load_dataset(valid_images, valid_labels, batch_size, training=False)
test_dataset = load_dataset(test_images, test_labels, batch_size, training=False)

# Initialize the SimCLR model with specified input and output dimensions
if config['input_size'] == 32:
    input_size = (32, 32, 3)
model = ResNetSimCLR(input_size, config['output_size'])


# Initialize the contrastive loss function with model and temperature
loss_fn = contrastive_loss(batch_size, temperature=config['temperature'])

# Training loop setup
epochs = config['epochs'] 

# Set optimizer for training
lr_decayed_fn = tf.keras.experimental.CosineDecay(
    initial_learning_rate=config['learning_rate'], decay_steps=1000)
optimizer = tf.keras.optimizers.SGD(lr_decayed_fn)
# optimizer = tf.keras.optimizers.Adam(learning_rate=config['learning_rate'])

In [7]:
for epoch in range(epochs):
    total_loss = 0
    num_batches = 0

    for images, _ in tqdm(train_dataset):
        with tf.GradientTape() as tape:
            # Forward pass through the model for both sets of augmented images
            _, proj1 = model(images[0], training=True)
            _, proj2 = model(images[1], training=True)

            # Calculate loss
            loss = loss_fn(proj1, proj2)

        # Compute and apply gradients
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # Accumulate loss for reporting
        total_loss += loss
        num_batches += 1

    # Calculate and display average loss for the epoch
    avg_loss = total_loss / num_batches
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}")

# Save the trained model
model.save('./saved_models')

100%|██████████| 87/87 [00:27<00:00,  3.16it/s]


Epoch 1/15, Loss: 5.9299


100%|██████████| 87/87 [00:23<00:00,  3.68it/s]


Epoch 2/15, Loss: 4.4038


100%|██████████| 87/87 [00:23<00:00,  3.70it/s]


Epoch 3/15, Loss: 3.7581


100%|██████████| 87/87 [00:23<00:00,  3.68it/s]


Epoch 4/15, Loss: 3.2438


100%|██████████| 87/87 [00:23<00:00,  3.70it/s]


Epoch 5/15, Loss: 2.8405


100%|██████████| 87/87 [00:23<00:00,  3.69it/s]


Epoch 6/15, Loss: 2.6259


100%|██████████| 87/87 [00:23<00:00,  3.70it/s]


Epoch 7/15, Loss: 2.4607


100%|██████████| 87/87 [00:23<00:00,  3.69it/s]


Epoch 8/15, Loss: 2.2994


100%|██████████| 87/87 [00:23<00:00,  3.70it/s]


Epoch 9/15, Loss: 2.2469


100%|██████████| 87/87 [00:23<00:00,  3.70it/s]


Epoch 10/15, Loss: 2.1707


100%|██████████| 87/87 [00:23<00:00,  3.71it/s]


Epoch 11/15, Loss: 2.1441


100%|██████████| 87/87 [00:23<00:00,  3.68it/s]


Epoch 12/15, Loss: 2.1647


100%|██████████| 87/87 [00:23<00:00,  3.72it/s]


Epoch 13/15, Loss: 2.1360


100%|██████████| 87/87 [00:23<00:00,  3.71it/s]


Epoch 14/15, Loss: 2.1307


100%|██████████| 87/87 [00:23<00:00,  3.71it/s]

Epoch 15/15, Loss: 2.1422





INFO:tensorflow:Assets written to: ./saved_models\assets


INFO:tensorflow:Assets written to: ./saved_models\assets


In [19]:
from tensorflow.keras.initializers import GlorotUniform
from tensorflow.keras.optimizers import Adam
# Extract representations
model = tf.keras.models.load_model('./saved_models')

train_x = []
train_y = []
for images, label in validation_dataset:
    rep, _ = model(images, training=False)  # Get the representation
    train_x.append(rep)
    train_y.append(label)

test_x = []
test_y = []
for x, y in test_dataset:
    rep, _ = model(x, training=False)
    test_x.append(rep)
    test_y.append(y)

train_x = tf.concat(train_x, axis=0)
train_y = tf.concat(train_y, axis=0)
test_x = tf.concat(test_x, axis=0)
test_y = tf.concat(test_y, axis=0)
train_x = train_x / np.max(train_x)
test_x = test_x / np.max(test_x)

# Make sure labels are in the correct shape
train_y = tf.squeeze(train_y)
test_y = tf.squeeze(test_y)

# Train a linear classifier on the representations
classifier = Sequential([Dense(5, input_shape=(2048,), activation="softmax", kernel_initializer=GlorotUniform())])
classifier.compile(loss="sparse_categorical_crossentropy", metrics=["accuracy"], optimizer=Adam(learning_rate=1e-4))
classifier.fit(train_x[:100], train_y[:100], epochs=15, batch_size=32)

# Predict on the test set
predicted_labels = classifier.predict(test_x[:500])

# Convert predictions to label indices
predicted_label_indices = np.argmax(predicted_labels, axis=1)

# Calculate accuracy
accuracy = np.mean(predicted_label_indices == test_y[:500].numpy())


Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15
Epoch 13/15
Epoch 14/15
Epoch 15/15
