In [1]:
import tensorflow as tf
import yaml
from model import ResNetSimCLR
from augmentation.augment_helper import crop_and_resize_and_flip, color_distort, rotate, cutout, gaussian_noise, crop_and_resize
from tensorflow.keras.datasets import cifar100

In [2]:
def augment_image(image, training=True):
    """
    Produce one SimCLR view of a CIFAR-100 image.

    Args:
    - image: An image tensor (presumably float in [0, 1] as produced by
      preprocess_for_simclr -- TODO confirm for other callers).
    - training: When true, run the random augmentation chain; otherwise
      only resize to 32x32.

    Returns:
    - The resulting image tensor.
    """
    if not training:
        # Evaluation path: no augmentation, just the native CIFAR-100 size.
        return tf.image.resize(image, [32, 32])

    # Ordered augmentation chain: each step consumes the previous step's
    # output, so the order below is significant. The image is upscaled to
    # 40x40 first and crop_and_resize is asked for a 32x32 result.
    steps = (
        lambda img: tf.image.resize(img, [40, 40]),
        lambda img: crop_and_resize(img, 32, 32),
        color_distort,
        rotate,
        lambda img: cutout(img, 10, 3),  # arguments adjusted for CIFAR-100
        gaussian_noise,
    )
    for step in steps:
        image = step(image)
    return image


In [3]:
def preprocess_for_simclr(image, label, training=True):
    """Scale an image to [0, 1] and build its SimCLR view(s).

    Args:
    - image: Raw image tensor (values presumably in 0..255, e.g. CIFAR uint8).
    - label: Label, passed through unchanged.
    - training: If true, produce two independently augmented views.

    Returns:
    - ((view1, view2), label) when training, otherwise (resized_image, label).
    """
    scaled = tf.cast(image, tf.float32) / 255.0

    if not training:
        # Evaluation: resize to the CIFAR-100 resolution, no augmentation.
        return tf.image.resize(scaled, [32, 32]), label

    # Two separate calls so each view gets its own random augmentations.
    first_view = augment_image(scaled, training=True)
    second_view = augment_image(scaled, training=True)
    return (first_view, second_view), label
    
def load_dataset(images, labels, batch_size, training=True,
                 shuffle_buffer=10000, drop_remainder=None):
    """Create a batched tf.data pipeline for CIFAR-100 with SimCLR preprocessing.

    Args:
    - images, labels: Arrays/tensors sliced along their first axis.
    - batch_size: Number of examples per batch.
    - training: If true, shuffle and emit pairs of augmented views (via
      preprocess_for_simclr); otherwise emit single resized images.
    - shuffle_buffer: Shuffle buffer size, used only when training.
    - drop_remainder: Whether to drop the final partial batch. Defaults to
      `training`, so contrastive training always sees full-size batches
      (a constant number of negatives per anchor).

    Returns:
    - A prefetched tf.data.Dataset.
    """
    if drop_remainder is None:
        drop_remainder = training

    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    if training:
        # Shuffle the raw examples *before* mapping: the buffer then holds
        # small un-augmented images instead of pairs of float32 views, and
        # augmentations are still re-sampled every epoch by the map below.
        dataset = dataset.shuffle(shuffle_buffer)
    dataset = dataset.map(lambda x, y: preprocess_for_simclr(x, y, training),
                          num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
    return dataset.prefetch(tf.data.AUTOTUNE)


In [4]:
def contrastive_loss(batch_size, temperature=0.1):
    """Build the NT-Xent (normalized temperature-scaled cross-entropy) loss.

    Args:
    - batch_size: Nominal batch size. Kept for backward compatibility; the
      returned loss reads the actual batch size from its inputs, so a
      smaller final batch is handled correctly.
    - temperature: Temperature dividing the cosine similarities.

    Returns:
    - A function loss_fn(z_i, z_j) returning the scalar NT-Xent loss.
    """
    def loss_fn(z_i, z_j):
        """
        Calculate the NT-Xent loss.

        Parameters:
        - z_i, z_j: Projections of the two augmented views, each of shape
                    (n, feature_dim). Row k of z_i and row k of z_j must come
                    from the same source image (the positive pair).

        Returns:
        - Scalar loss averaged over all 2n anchors.
        """
        n = tf.shape(z_i)[0]

        # Stack both views: rows [0, n) are view 1, rows [n, 2n) are view 2,
        # then normalize rows to unit length.
        z = tf.math.l2_normalize(tf.concat([z_i, z_j], axis=0), axis=1)

        # Cosine similarity (dot product of unit vectors), scaled by temperature.
        logits = tf.matmul(z, z, transpose_b=True) / temperature

        # Remove each anchor's self-similarity (the diagonal). boolean_mask
        # flattens in row-major order, so reshaping yields one row of 2n-1
        # candidate logits per anchor.
        not_self = tf.logical_not(tf.cast(tf.eye(2 * n), tf.bool))
        logits = tf.reshape(tf.boolean_mask(logits, not_self), (2 * n, 2 * n - 1))

        # Positive for anchor k < n sits at column k + n of the full matrix;
        # dropping the diagonal entry (column k, to its left) shifts it to
        # k + n - 1. Positive for anchor k >= n sits at column k - n, left of
        # the diagonal, so its index is unchanged.
        labels = tf.concat([tf.range(n) + n - 1, tf.range(n)], axis=0)

        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
        return tf.reduce_mean(loss)
    return loss_fn

In [None]:
# Load configuration from YAML file. safe_load is sufficient for a plain
# key/value config and cannot construct arbitrary Python objects; the `with`
# block closes the file handle.
with open('config.yaml', 'r') as config_file:
    config = yaml.safe_load(config_file)
batch_size = config['batch_size']
epochs = config['epochs']
warmup_epochs = config.get('warmup_epochs', 10)  # length of linear LR warmup

# Prepare the training and test datasets
(train_images, train_labels), (test_images, test_labels) = cifar100.load_data()
train_dataset = load_dataset(train_images, train_labels, batch_size, training=True)
test_dataset = load_dataset(test_images, test_labels, batch_size, training=False)

# Initialize the SimCLR model. The augmentation pipeline is hard-wired to
# 32x32 CIFAR images, so reject any other size explicitly instead of failing
# later with an undefined `input_size`.
if config['input_size'] != 32:
    raise ValueError(
        f"Only input_size 32 is supported for CIFAR-100, got {config['input_size']}")
input_size = (32, 32, 3)
model = ResNetSimCLR(input_size, config['output_size'])

# Initialize the contrastive loss function with the configured temperature
loss_fn = contrastive_loss(batch_size, temperature=config['temperature'])

# Cosine decay with linear warmup (the warmup_* arguments need TF >= 2.13).
# decay_steps is clamped to >= 1 so short runs (epochs <= warmup_epochs) do
# not produce a zero or negative decay length.
steps_per_epoch = len(train_dataset)
lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=0.0,
    decay_steps=max(epochs - warmup_epochs, 1) * steps_per_epoch,
    warmup_target=config['learning_rate'],
    warmup_steps=warmup_epochs * steps_per_epoch)

# Set optimizer for training.
# NOTE(review): the original `tf.optimization.lars.LARS(...)` always raised
# AttributeError -- core TensorFlow has no `tf.optimization` module and ships
# no LARS optimizer. A LARS implementation (taking `weight_decay_rate`) is
# available in the TF Model Garden (tf-models-official) if installed; until
# then, SGD with momentum and weight decay is used as a stand-in.
optimizer = tf.keras.optimizers.SGD(
    learning_rate=lr_schedule, momentum=0.9, weight_decay=config['weight_decay'])

In [None]:
for epoch in range(epochs):
    total_loss = 0.0
    num_batches = 0

    for (view1, view2), _ in train_dataset:
        with tf.GradientTape() as tape:
            # Forward pass through the model for both sets of augmented images
            _, proj1 = model(view1, training=True)
            _, proj2 = model(view2, training=True)

            # Calculate loss
            loss = loss_fn(proj1, proj2)

        # Compute and apply gradients
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # Accumulate loss as a tensor (avoids a device sync every step)
        total_loss += loss
        num_batches += 1

    # Calculate and display average loss for the epoch (guard empty dataset)
    avg_loss = float(total_loss) / max(num_batches, 1)
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}")

# Save the trained model.
# NOTE(review): under Keras 3 (TF >= 2.16) `model.save` requires a `.keras`
# filename; a bare directory path only works with the older SavedModel path.
model.save('./saved_models')

# The contrastive loss needs two augmented views per image, while
# `test_dataset` yields single un-augmented images, so evaluate on a paired
# view of the test split instead. Loop variables are named view1/view2 so
# the global `test_images` array is not shadowed.
test_pair_dataset = load_dataset(test_images, test_labels, batch_size, training=True)

total_test_loss = 0.0
num_test_batches = 0

for (view1, view2), _ in test_pair_dataset:
    # Forward pass in inference mode for both views
    _, test_proj1 = model(view1, training=False)
    _, test_proj2 = model(view2, training=False)

    total_test_loss += loss_fn(test_proj1, test_proj2)
    num_test_batches += 1

# Calculate and display average test loss
avg_test_loss = float(total_test_loss) / max(num_test_batches, 1)
print(f"Average Test Loss: {avg_test_loss:.4f}")