In [1]:
import tensorflow as tf
import re
import yaml
from tqdm import tqdm
from model import ResNetSimCLR
from augmentation.augment_helper import crop_and_resize_and_flip, color_distort, rotate, cutout, gaussian_noise, crop_and_resize
from tensorflow.keras.datasets import cifar100

In [2]:
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

Num GPUs Available:  1


In [3]:
def augment_image(image, training=True):
    """
    Apply a series of augmentations for SimCLR suitable for CIFAR-100.

    Args:
    - image: An image tensor.
    - training: Flag to indicate if it's training phase.

    Returns:
    - Augmented image tensor.
    """
    if training:
        # Apply a series of random transformations for training
        # Resize slightly larger for augmentation and then resize back
        image = tf.image.resize(image, [40, 40])  # Resize to 40x40 for augmentation
        # print(image.shape)
        image = crop_and_resize_and_flip(image, 32, 32)  # Resize back to 32x32 after augmentation
        # print(image.shape)
        image = color_distort(image)
        # print(image.shape)
        image = rotate(image)
        image = cutout(image, 10, 3)  # Adjusted patch size for CIFAR-100
        image = gaussian_noise(image)
    else:
        # For testing, resize to the original CIFAR-100 size without further augmentation
        image = tf.image.resize(image, [32, 32])
    return image


In [4]:
def preprocess_for_simclr(image, label, training=True):
    """Preprocesses and applies augmentation for SimCLR."""
    image = tf.cast(image, tf.float32) / 255.0  # Normalize the image
    if training:
        # Create two augmented versions of the image for training
        augmented_image1 = augment_image(image, training=True)
        augmented_image2 = augment_image(image, training=True)
        return (augmented_image1, augmented_image2), label
    else:
        # For testing, resize to the original CIFAR-100 size without further augmentation
        augmented_image = tf.image.resize(image, [32, 32])
        return augmented_image, label
    
def load_dataset(images, labels, batch_size, training=True):
    """Creates a tf.data.Dataset object for CIFAR-100 with SimCLR augmentations."""
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    dataset = dataset.map(lambda x, y: preprocess_for_simclr(x, y, training), 
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)
    if training:
        dataset = dataset.shuffle(10000)
    dataset = dataset.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)
    return dataset


In [5]:
def contrastive_loss(temperature=0.1):
    def loss_fn(z_i, z_j):
        """
        Calculate the NT-Xent loss.

        Parameters:
        - z_i, z_j: The outputs from the two augmented views of the images, with shapes (batch_size, feature_dim).

        Returns:
        - Scalar loss value.
        """
        batch_size = tf.shape(z_i)[0]
        z = tf.concat([z_i, z_j], axis=0)
        z = tf.math.l2_normalize(z, axis=1)
        
        similarity_matrix = tf.matmul(z, z, transpose_b=True)
        similarity_matrix = similarity_matrix / temperature

        labels = tf.range(batch_size)
        labels = tf.concat([labels, labels], axis=0)

        # Create a mask to remove positive samples from the diagonals
        mask = tf.one_hot(labels, 2 * batch_size)
        logits_mask = tf.logical_not(tf.eye(2 * batch_size, dtype=tf.bool))
        masked_similarity_matrix = tf.boolean_mask(similarity_matrix, logits_mask)

        # Reshape the masked similarity matrix to the correct shape for cross-entropy
        masked_similarity_matrix = tf.reshape(masked_similarity_matrix, (2 * batch_size, -1))

        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=masked_similarity_matrix, labels=labels)
        loss = tf.reduce_mean(loss)

        return loss
    return loss_fn


In [6]:
class LARSOptimizer(tf.keras.optimizers.Optimizer):
    def __init__(self, learning_rate, momentum=0.9, use_nesterov=False,
                 weight_decay=0.0, exclude_from_weight_decay=None,
                 exclude_from_layer_adaptation=None, classic_momentum=True,
                 eeta=0.001, name="LARSOptimizer", **kwargs):
        """Constructs a LARSOptimizer."""
        super(LARSOptimizer, self).__init__(name, **kwargs)

        self.learning_rate = learning_rate
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.use_nesterov = use_nesterov
        self.classic_momentum = classic_momentum
        self.eeta = eeta
        self.exclude_from_weight_decay = exclude_from_weight_decay
        if exclude_from_layer_adaptation:
            self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
        else:
            self.exclude_from_layer_adaptation = exclude_from_weight_decay

    def _create_slots(self, var_list):
        for var in var_list:
            self.add_slot(var, "momentum")

    def _resource_apply_dense(self, grad, var):
        var_dtype = var.dtype.base_dtype
        lr_t = self.learning_rate
        momentum = self.get_slot(var, "momentum")

        if self._use_weight_decay(var.name):
            grad += self.weight_decay * var

        # Update logic
        # ...

        return None  # Return value is not used

    def _use_weight_decay(self, param_name):
        """Whether to use L2 weight decay for `param_name`."""
        if not self.weight_decay:
            return False
        if self.exclude_from_weight_decay:
            for r in self.exclude_from_weight_decay:
                if re.search(r, param_name) is not None:
                    return False
        return True

    def _do_layer_adaptation(self, param_name):
        """Whether to do layer-wise learning rate adaptation for `param_name`."""
        if self.exclude_from_layer_adaptation:
            for r in self.exclude_from_layer_adaptation:
                if re.search(r, param_name) is not None:
                    return False
        return True

    def get_config(self):
        config = super(LARSOptimizer, self).get_config()
        config.update({
            "learning_rate": self.learning_rate,
            "momentum": self.momentum,
            "weight_decay": self.weight_decay,
            "use_nesterov": self.use_nesterov,
            "classic_momentum": self.classic_momentum,
            "eeta": self.eeta,
            "exclude_from_weight_decay": self.exclude_from_weight_decay,
            "exclude_from_layer_adaptation": self.exclude_from_layer_adaptation
        })
        return config

In [7]:
# Load configuration from YAML file
config = yaml.load(open('config.yaml', 'r'), Loader=yaml.FullLoader)
batch_size = config['batch_size']

# Prepare the training and  dataset
(train_images, train_labels), (test_images, test_labels) = cifar100.load_data()
train_dataset = load_dataset(train_images, train_labels, batch_size, training=True)
test_dataset = load_dataset(test_images, test_labels, batch_size, training=False)

# Initialize the SimCLR model with specified input and output dimensions
if config['input_size'] == 32:
    input_size = (32, 32, 3)
model = ResNetSimCLR(input_size, config['output_size'])

# Initialize the contrastive loss function with model and temperature
loss_fn = contrastive_loss(temperature=config['temperature'])

# Training loop setup
epochs = config['epochs'] 

# Set optimizer for training
learning_rate_t = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=config['learning_rate'],
    decay_steps=epochs,
    alpha=0.0
)
optimizer = LARSOptimizer(
    config['learning_rate'],
    momentum=0.9,
    weight_decay=config['weight_decay'])
# optimizer = tf.keras.optimizers.SGD(learning_rate=config['learning_rate'])

In [8]:
for epoch in range(epochs):
    total_loss = 0
    num_batches = 0

    for images, _ in tqdm(train_dataset):
        with tf.GradientTape() as tape:
            # Forward pass through the model for both sets of augmented images
            _, proj1 = model(images[0], training=True)
            _, proj2 = model(images[1], training=True)

            # Calculate loss
            loss = loss_fn(proj1, proj2)

        # Compute and apply gradients
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # Accumulate loss for reporting
        total_loss += loss
        num_batches += 1

    # Calculate and display average loss for the epoch
    avg_loss = total_loss / num_batches
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}")

# Save the trained model
model.save('./saved_models')

total_test_loss = 0
num_test_batches = 0

for test_images, _ in tqdm(test_dataset):
    # Forward pass through the model for test images
    _, test_proj = model(test_images, training=False)

    # Assuming a similar loss function for testing, or you can use a different metric
    test_loss = loss_fn(test_proj)

    total_test_loss += test_loss
    num_test_batches += 1

# Calculate and display average test loss
avg_test_loss = total_test_loss / num_test_batches
print(f"Average Test Loss: {avg_test_loss:.4f}")

100%|██████████| 391/391 [02:05<00:00,  3.13it/s]


Epoch 1/100, Loss: 5.5518


100%|██████████| 391/391 [02:02<00:00,  3.18it/s]


Epoch 2/100, Loss: 5.5504


100%|██████████| 391/391 [02:00<00:00,  3.24it/s]


Epoch 3/100, Loss: 5.5507


100%|██████████| 391/391 [01:57<00:00,  3.31it/s]


Epoch 4/100, Loss: 5.5508


100%|██████████| 391/391 [01:57<00:00,  3.32it/s]


Epoch 5/100, Loss: 5.5504


100%|██████████| 391/391 [01:58<00:00,  3.31it/s]


Epoch 6/100, Loss: 5.5506


100%|██████████| 391/391 [01:57<00:00,  3.32it/s]


Epoch 7/100, Loss: 5.5508


100%|██████████| 391/391 [01:57<00:00,  3.32it/s]


Epoch 8/100, Loss: 5.5504


100%|██████████| 391/391 [01:57<00:00,  3.31it/s]


Epoch 9/100, Loss: 5.5512


100%|██████████| 391/391 [01:58<00:00,  3.31it/s]


Epoch 10/100, Loss: 5.5497


100%|██████████| 391/391 [01:57<00:00,  3.32it/s]


Epoch 11/100, Loss: 5.5500


100%|██████████| 391/391 [01:57<00:00,  3.32it/s]


Epoch 12/100, Loss: 5.5509


100%|██████████| 391/391 [01:58<00:00,  3.31it/s]


Epoch 13/100, Loss: 5.5499


100%|██████████| 391/391 [01:57<00:00,  3.32it/s]


Epoch 14/100, Loss: 5.5499


100%|██████████| 391/391 [01:58<00:00,  3.31it/s]


Epoch 15/100, Loss: 5.5509


100%|██████████| 391/391 [01:57<00:00,  3.32it/s]


Epoch 16/100, Loss: 5.5505


100%|██████████| 391/391 [01:58<00:00,  3.31it/s]


Epoch 17/100, Loss: 5.5511


100%|██████████| 391/391 [01:57<00:00,  3.32it/s]


Epoch 18/100, Loss: 5.5501


100%|██████████| 391/391 [01:58<00:00,  3.31it/s]


Epoch 19/100, Loss: 5.5503


100%|██████████| 391/391 [01:57<00:00,  3.32it/s]


Epoch 20/100, Loss: 5.5504


100%|██████████| 391/391 [01:57<00:00,  3.32it/s]


Epoch 21/100, Loss: 5.5506


100%|██████████| 391/391 [01:57<00:00,  3.32it/s]


Epoch 22/100, Loss: 5.5509


100%|██████████| 391/391 [01:58<00:00,  3.31it/s]


Epoch 23/100, Loss: 5.5500


100%|██████████| 391/391 [01:57<00:00,  3.32it/s]


Epoch 24/100, Loss: 5.5505


100%|██████████| 391/391 [01:57<00:00,  3.32it/s]


Epoch 25/100, Loss: 5.5506


100%|██████████| 391/391 [01:57<00:00,  3.32it/s]


Epoch 26/100, Loss: 5.5513


100%|██████████| 391/391 [01:57<00:00,  3.32it/s]


Epoch 27/100, Loss: 5.5507


100%|██████████| 391/391 [01:57<00:00,  3.32it/s]


Epoch 28/100, Loss: 5.5500


100%|██████████| 391/391 [01:57<00:00,  3.32it/s]


Epoch 29/100, Loss: 5.5504


100%|██████████| 391/391 [01:57<00:00,  3.32it/s]


Epoch 30/100, Loss: 5.5504


100%|██████████| 391/391 [01:57<00:00,  3.31it/s]


Epoch 31/100, Loss: 5.5512


100%|██████████| 391/391 [01:57<00:00,  3.32it/s]


Epoch 32/100, Loss: 5.5497


100%|██████████| 391/391 [01:57<00:00,  3.32it/s]


Epoch 33/100, Loss: 5.5500


100%|██████████| 391/391 [01:58<00:00,  3.31it/s]


Epoch 34/100, Loss: 5.5498


100%|██████████| 391/391 [01:57<00:00,  3.32it/s]


Epoch 35/100, Loss: 5.5507


100%|██████████| 391/391 [01:58<00:00,  3.31it/s]


Epoch 36/100, Loss: 5.5512


100%|██████████| 391/391 [01:57<00:00,  3.32it/s]


Epoch 37/100, Loss: 5.5517


100%|██████████| 391/391 [01:57<00:00,  3.32it/s]


Epoch 38/100, Loss: 5.5509


100%|██████████| 391/391 [02:02<00:00,  3.19it/s]


Epoch 39/100, Loss: 5.5508


100%|██████████| 391/391 [01:58<00:00,  3.31it/s]


Epoch 40/100, Loss: 5.5516


100%|██████████| 391/391 [01:57<00:00,  3.31it/s]


Epoch 41/100, Loss: 5.5509


100%|██████████| 391/391 [01:57<00:00,  3.32it/s]


Epoch 42/100, Loss: 5.5506


100%|██████████| 391/391 [02:01<00:00,  3.22it/s]


Epoch 43/100, Loss: 5.5499


100%|██████████| 391/391 [02:01<00:00,  3.22it/s]


Epoch 44/100, Loss: 5.5505


100%|██████████| 391/391 [02:00<00:00,  3.25it/s]


Epoch 45/100, Loss: 5.5501


100%|██████████| 391/391 [01:58<00:00,  3.31it/s]


Epoch 46/100, Loss: 5.5499


100%|██████████| 391/391 [01:57<00:00,  3.32it/s]


Epoch 47/100, Loss: 5.5511


100%|██████████| 391/391 [01:57<00:00,  3.32it/s]


Epoch 48/100, Loss: 5.5512


100%|██████████| 391/391 [01:57<00:00,  3.32it/s]


Epoch 49/100, Loss: 5.5513


100%|██████████| 391/391 [01:58<00:00,  3.31it/s]


Epoch 50/100, Loss: 5.5502


100%|██████████| 391/391 [01:57<00:00,  3.32it/s]


Epoch 51/100, Loss: 5.5514


100%|██████████| 391/391 [01:57<00:00,  3.33it/s]


Epoch 52/100, Loss: 5.5506


100%|██████████| 391/391 [01:58<00:00,  3.30it/s]


Epoch 53/100, Loss: 5.5502


100%|██████████| 391/391 [01:57<00:00,  3.32it/s]


Epoch 54/100, Loss: 5.5509


100%|██████████| 391/391 [01:58<00:00,  3.31it/s]


Epoch 55/100, Loss: 5.5509


100%|██████████| 391/391 [01:57<00:00,  3.32it/s]


Epoch 56/100, Loss: 5.5507


 92%|█████████▏| 360/391 [01:49<00:09,  3.29it/s]