### SimCLR attempt using Vision Transformer

In [None]:
import os
# Set the library search path and the XLA CUDA data directory to the conda
# locations before TensorFlow is imported.
# NOTE(review): presumably needed so TensorFlow/XLA can find CUDA on this
# machine -- the paths are environment-specific; confirm before reuse.
os.environ.update({
    "LD_LIBRARY_PATH": '/opt/conda/lib',
    "XLA_FLAGS": "--xla_gpu_cuda_data_dir=/opt/conda/pkgs/cuda-toolkit",
})

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import tensorflow as tf
import tensorflow.keras as keras

In [None]:
# Load MNIST, then rescale the pixel values by 1/255 (both splits).
(x_train, y_train), (x_val, y_val) = keras.datasets.mnist.load_data()
x_train, x_val = (pixels / 255.0 for pixels in (x_train, x_val))

In [None]:
# Inspect the shape of the training image array.
x_train.shape

In [None]:
# Preview the first training image in grayscale; the cell's value is its label.
plt.figure(figsize=(3, 3))
plt.imshow(x_train[0], cmap="gray")
plt.show()
y_train[0]

In [None]:
# Keep the first training image as the sample used by the augmentation
# experiments below; the cell's value is its shape.
a = x_train[0]
plt.imshow(a)
plt.show()
a.shape

In [None]:
def crop_and_resize(image, min_crop=0.4, max_crop=1.0):
    """Randomly crop a 2-D image and resize the crop back to the original size.

    Args:
        image: 2-D array-like of shape ``(rows, cols)``.
        min_crop: Smallest fraction of each side kept by the crop.
        max_crop: Largest fraction of each side kept by the crop.

    Returns:
        A float32 tensor of shape ``(rows, cols, 1)``. The result always has
        this shape, including when the sampled crop covers the whole image.
    """
    rows, cols = image.shape[0], image.shape[1]

    # One fraction is applied to both sides, so the crop keeps the aspect ratio.
    crop_percent = np.random.uniform(min_crop, max_crop)
    crop_rows = max(1, int(rows * crop_percent))
    crop_cols = max(1, int(cols * crop_percent))

    # tf.image.resize needs a channel axis, so add one before cropping.
    cropped = tf.image.random_crop(tf.expand_dims(image, -1), (crop_rows, crop_cols, 1))
    return tf.image.resize(cropped, (rows, cols))

# Try the augmentation on the sample image and compare input/output shapes.
b = crop_and_resize(a)
plt.imshow(b)
plt.show()
a.shape, b.shape

In [None]:
class DataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding two independently cropped views of each image.

    Every item is a tuple ``(views_a, views_b)`` of two float32 arrays with
    identical shape ``(batch_size, *image_shape)``. Row ``i`` of both arrays
    is a different random crop (resized back to the original size) of the
    same source image.

    Args:
        x: Array of images, each of shape ``(rows, cols)`` or
            ``(rows, cols, channels)``.
        y: One label per image. Labels are only used to group the images so
            that every class is drawn with equal probability.
        batch_size: Number of image pairs per batch.
        num_batches: Number of batches per epoch.
        rng: NumPy generator driving every random choice (sampling and
            cropping). A fresh generator is created when omitted.
        min_crop: Smallest fraction of each side kept by a crop, in ``(0, 1]``.
        max_crop: Largest fraction of each side kept by a crop,
            in ``[min_crop, 1]``.

    Raises:
        ValueError: If the crop bounds are invalid or ``x`` and ``y`` differ
            in length.
    """

    def __init__(self, x, y, batch_size=32, num_batches=100, rng: "np.random.Generator | None" = None, min_crop = .5, max_crop = 1):
        super().__init__()

        if not 0 < min_crop <= max_crop <= 1:
            raise ValueError(
                f"Expected 0 < min_crop <= max_crop <= 1, got {min_crop} and {max_crop}")

        images = np.asarray(x)
        labels = np.asarray(y).ravel()
        if len(images) != len(labels):
            raise ValueError(
                f"x and y must have the same length, got {len(images)} and {len(labels)}")

        # One array of images per distinct label, ordered by label value.
        self.groups = [images[labels == label] for label in np.unique(labels)]
        self.batch_size = batch_size
        self.num_batches = num_batches
        # Create the generator per instance; a default built in the signature
        # would be shared by every DataGenerator.
        self.rng = rng if rng is not None else np.random.default_rng()

        self.min_crop = min_crop
        self.max_crop = max_crop

        self.shuffle()

    def shuffle(self):
        """Draw the class and the within-class position of every sample of the epoch."""
        # Which group (class) each sample is drawn from.
        self.labels = self.rng.integers(len(self.groups), size=(self.num_batches, self.batch_size))
        # Relative position inside the group, in [0, 1). Only row 0 of the
        # middle axis is used; the shape is kept for compatibility.
        self.variants = self.rng.uniform(size=(self.num_batches, 2, self.batch_size))

    def crop_and_resize(self, image):
        """Return a random crop of ``image`` resized back to the original size.

        The crop keeps a fraction of each side drawn uniformly from
        ``[min_crop, max_crop]`` and is placed uniformly at random, both using
        ``self.rng``. The result is a float32 array with the same shape as
        ``image``; the image is returned uncropped when the crop covers it
        entirely.
        """
        image = np.asarray(image, dtype=np.float32)
        has_channels = image.ndim == 3
        if not has_channels:
            # tf.image.resize needs an explicit channel axis.
            image = image[..., np.newaxis]
        rows, cols = image.shape[0], image.shape[1]

        crop_percent = self.rng.uniform(self.min_crop, self.max_crop)
        crop_rows = max(1, int(rows * crop_percent))
        crop_cols = max(1, int(cols * crop_percent))

        if crop_rows < rows or crop_cols < cols:
            top = int(self.rng.integers(rows - crop_rows + 1))
            left = int(self.rng.integers(cols - crop_cols + 1))
            window = image[top:top + crop_rows, left:left + crop_cols]
            # Resize back so every view in a batch has the same shape.
            image = np.asarray(tf.image.resize(window, (rows, cols)), dtype=np.float32)

        return image if has_channels else image[..., 0]

    def __getitem__(self, batch_index):
        """Build batch ``batch_index`` as a tuple ``(views_a, views_b)``."""
        views_a, views_b = [], []
        positions = self.variants[batch_index][0]
        for label, position in zip(self.labels[batch_index], positions):
            group = self.groups[label]
            source = group[int(len(group) * position)]
            # Two independent augmentations of the same source image.
            views_a.append(self.crop_and_resize(source))
            views_b.append(self.crop_and_resize(source))

        return np.stack(views_a), np.stack(views_b)

    def __len__(self):
        """Number of batches per epoch."""
        return self.num_batches

    def on_epoch_end(self):
        """Resample which images make up the next epoch."""
        self.shuffle()

In [None]:
# Generator Testing/Debugging
# Each item of the generator is a (views_a, views_b) tuple of image batches.
data = DataGenerator(x_train, y_train, rng=np.random.default_rng(0), min_crop = 0.4, max_crop = 1)
batch_index = 0
index = 0
# Fetch the batch once: every data[batch_index] call draws new random crops,
# so indexing repeatedly would show and measure different augmentations.
views_a, views_b = data[batch_index]
plt.subplot(121)
plt.imshow(views_a[index])
plt.subplot(122)
plt.imshow(views_b[index])
plt.show()
views_a[index].shape

### ViT Encoders

In [None]:
## Add a trailing channel axis so the images are (n, rows, cols, 1).
## (The 0.0-1.0 rescaling was already applied when the data was loaded.)
## The ndim guard keeps the cell idempotent: re-running it must not keep
## appending axes.
if x_train.ndim == 3:
    x_train = np.expand_dims(x_train,-1)
if x_val.ndim == 3:
    x_val = np.expand_dims(x_val,-1)

display(x_train.shape)
display(y_train.shape)
display(x_val.shape)
display(y_val.shape)

In [None]:
class TransformerBlock(keras.layers.Layer):
    """Pre-norm Transformer encoder block.

    Applies layer norm -> multi-head self-attention -> dropout with a residual
    connection, then layer norm -> two-layer GELU feed-forward -> dropout with
    a second residual connection.

    Args:
        embed_dim: Width of the token embeddings (also used as the per-head
            ``key_dim`` of the attention layer).
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward network.
        rate: Dropout rate applied after attention and after the feed-forward.
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``).

    Calling the layer returns a tuple ``(tokens, attention_scores)``.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.att = keras.layers.MultiHeadAttention(num_heads=num_heads,
                                                   key_dim=embed_dim)
        self.ffn = keras.Sequential(
            [keras.layers.Dense(ff_dim, activation="gelu"),
             keras.layers.Dense(embed_dim),]
        )
        self.layernorm1 = keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = keras.layers.Dropout(rate)
        self.dropout2 = keras.layers.Dropout(rate)

    def compute_output_shape(self, input_shape):
        """Shapes of the two outputs of ``call``: tokens and attention scores."""
        # call() returns a pair, so report a pair of shapes rather than only
        # the token shape. Assumes a (batch, tokens, embed_dim) input.
        token_shape = tuple(tf.TensorShape(input_shape).as_list())
        batch, tokens = token_shape[0], token_shape[1]
        return (token_shape, (batch, self.num_heads, tokens, tokens))

    def call(self, x, training=None):
        """Run the block on ``x`` and return ``(output_tokens, attention_scores)``."""
        # Self-attention sub-layer with residual connection.
        y = self.layernorm1(x, training=training)
        y, scores = self.att(y, y, return_attention_scores=True, training=training)
        y = self.dropout1(y, training=training)
        x = x + y

        # Feed-forward sub-layer with residual connection.
        y = self.layernorm2(x, training=training)
        y = self.ffn(y, training=training)
        y = self.dropout2(y, training=training)
        return (x + y, scores)


In [None]:
class PositionEmbedding(keras.layers.Layer):
    """Adds a learned embedding of each token's index to an embedded sequence.

    Args:
        maxlen: Number of distinct positions the embedding table holds.
        embed_dim: Width of the position embeddings.
    """

    def __init__(self, maxlen, embed_dim):
        super().__init__()
        self.pos_emb = keras.layers.Embedding(input_dim=maxlen, output_dim=embed_dim)

    def call(self, x):
        """Return ``x`` plus the embeddings of positions ``0 .. tokens-1``."""
        # x is already embedded, so the token axis is the second-to-last one.
        num_tokens = tf.shape(x)[-2]
        token_indices = tf.range(num_tokens)
        return x + self.pos_emb(token_indices)

In [None]:
# Encoder hyperparameters.
embed_dim = 128   # width of every token embedding
kernel = 4        # patch size; also used as the Conv2D stride (non-overlapping patches)
num_heads = 6     # attention heads per TransformerBlock
ff_dim = 512      # hidden width of each block's feed-forward network
stack = 5         # number of TransformerBlocks

# The input shape is taken from x_train, which has a single intensity
# channel here. For colour images the data would need a trailing dimension
# of 3, and that would carry over into this cell automatically.
y = x = keras.layers.Input(shape=x_train.shape[1:])
# This layer will just be passed a constant integer for
# embedding (class token - see the ViT paper)
# input_class = keras.layers.Input(shape=(1,))

# Patches: a strided convolution embeds each kernel x kernel patch
y = keras.layers.Conv2D(embed_dim,
                        kernel_size=(kernel,kernel),
                        strides=(kernel,kernel))(y)
# Hybrid CNN
# y = keras.layers.Conv2D(embed_dim*2,
#                         kernel_size=(5, 5),
#                         strides=2,
#                         activation='gelu')(y)
# y = keras.layers.Conv2D(embed_dim,
#                         kernel_size=(5, 5),
#                         strides=2,
#                         activation='gelu')(y)

# Flatten 2D arrangement to 1D arrangement of tokens
y = keras.layers.Reshape((-1,embed_dim))(y)
y = PositionEmbedding(y.shape[-2],embed_dim)(y)

# Create class token: a constant index 0 per sample, looked up in a
# one-entry embedding table so the token itself is learned
c = keras.layers.Lambda(lambda x: tf.tile(tf.constant([[0]]),(tf.shape(x)[0],1)))(y)
c = keras.layers.Embedding(input_dim=1,output_dim=embed_dim)(c)

# Prepend class token (it is added after the position embedding, so it
# carries no position embedding of its own)
y = keras.layers.Concatenate(axis=1)([c,y])

# TransformerBlock returns (tokens, attention scores); keep only the tokens
for _ in range(stack):
    y = TransformerBlock(embed_dim, num_heads, ff_dim)(y)[0]
# Lambda layer is like the ViT the paper...
# y = keras.layers.Lambda(lambda x: x[:,0,:])(y)
# Averaging over all tokens (class token included) is used -instead- of the
# Lambda above; this is common in other models... probably not important though
y = keras.layers.GlobalAveragePooling1D()(y)
y = keras.layers.Dropout(0.1)(y)
y = keras.layers.Dense(embed_dim,activation='gelu')(y)
#    y = keras.layers.Dropout(0.5)(y)
# NOTE(review): this 10-unit output is what ContrastiveModel receives as the
# encoder "features" -- confirm the classification head is meant to be kept.
y = keras.layers.Dense(10)(y)

model = keras.Model(x,y)
model.compile(optimizer=keras.optimizers.Nadam(learning_rate=0.0001),
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=keras.metrics.SparseCategoricalAccuracy())
model.summary()
keras.utils.plot_model(model,show_shapes=True,expand_nested=True)

# Both names refer to the SAME model object, so the two encoders share
# all of their weights.
masked_encoder = model
unmasked_encoder = model

### Note: This metric only *correlates* with accuracy

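This is an in-batch contrastive accuracy: the fraction of images whose matching augmented view receives the highest similarity within the batch. It does not measure the model's classification performance and is used here only as a metric to show training progress.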

In [None]:
def ContrastiveAccuracy(y_true, y_pred):
    """In-batch matching accuracy of a square similarity matrix.

    Args:
        y_true: Index of the matching column for every row (``0 .. n-1``).
        y_pred: ``(n, n)`` logits; entry ``[i, j]`` scores row-view ``i``
            against column-view ``j``.

    Returns:
        Scalar tensor: the mean of the row-wise and the column-wise top-1
        accuracy for this batch only.
    """
    # Use the stateless function form. A shared stateful metric object would
    # accumulate across batches, epochs and train/validation, and would also
    # depend on a global name that later cells rebind.
    acc_masked = tf.reduce_mean(
        keras.metrics.sparse_categorical_accuracy(y_true, y_pred))
    acc_unmasked = tf.reduce_mean(
        keras.metrics.sparse_categorical_accuracy(y_true, tf.transpose(y_pred)))
    return (acc_masked + acc_unmasked) / 2.0

In [None]:
class ContrastiveModel(tf.keras.Model):
    """Contrastive model over two views of the same batch of images.

    Each view is encoded, linearly projected into a joint ``embed_dim``
    space and L2-normalised per sample. The model outputs the matrix of
    pairwise cosine similarities scaled by a learned temperature; row ``i``
    should peak at column ``i``.

    Args:
        masked_encoder: Encoder applied to the first view.
        unmasked_encoder: Encoder applied to the second view.
        embed_dim: Width of the joint embedding space.
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(
        self,
        masked_encoder: tf.keras.Model,
        unmasked_encoder: tf.keras.Model,
        embed_dim: int = 512,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.masked_encoder = masked_encoder
        self.unmasked_encoder = unmasked_encoder

        self.embed_dim = embed_dim
        # Bias-free linear projections into the joint embedding space.
        self.W_masked = tf.keras.layers.Dense(
            self.embed_dim,
            use_bias=False)
        self.W_unmasked = tf.keras.layers.Dense(
            self.embed_dim,
            use_bias=False)
        # Learned log-temperature (logits are multiplied by exp(t)). It is
        # initialised to log(1 / 0.07), as in CLIP, instead of relying on the
        # default random initializer.
        self.t = self.add_weight(
            name="Temperature",
            shape=(),
            initializer=tf.keras.initializers.Constant(np.log(1.0 / 0.07)),
            trainable=True
        )

    def compile(self, *args, **kwargs):
        """Compile with the fixed sparse categorical cross-entropy loss on logits."""
        return super().compile(
            *args,
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            **kwargs)

    def _symmetric_loss(self, y_true, y_pred):
        """Average the compiled loss over the rows and the columns of the logits."""
        loss_masked = self.compiled_loss(y_true, y_pred)
        loss_unmasked = self.compiled_loss(y_true, tf.transpose(y_pred))
        return (loss_masked + loss_unmasked) / 2.0

    def test_step(self, data):
        """Evaluate one batch; the target of row ``i`` is column ``i``."""
        n = tf.shape(data[0])[0]
        y_true = tf.range(n)
        y_pred = self(data, training=False)
        # Called so that the compiled loss is evaluated for this batch.
        self._symmetric_loss(y_true, y_pred)
        self.compiled_metrics.update_state(y_true, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def train_step(self, data):
        """Run one optimisation step on a batch of paired views."""
        n = tf.shape(data[0])[0]
        y_true = tf.range(n)
        with tf.GradientTape() as tape:
            y_pred = self(data, training=True)
            loss = self._symmetric_loss(y_true, y_pred)
        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.compiled_metrics.update_state(y_true, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def call(self, inputs, training=False):
        """Return the ``(n_first, n_second)`` temperature-scaled similarity logits.

        Args:
            inputs: Pair ``(first_view_images, second_view_images)``.
            training: Forwarded to the encoders so that their dropout layers
                are active during training.
        """
        # Get the images from input
        masked_images, unmasked_images = inputs[0], inputs[1]

        # Embed them using the encoders
        masked_features = self.masked_encoder(masked_images, training=training)
        unmasked_features = self.unmasked_encoder(unmasked_images, training=training)

        # Joint embedding, normalised per sample so that the dot products
        # below are cosine similarities.
        masked_embeddings = tf.math.l2_normalize(
            self.W_masked(masked_features), axis=-1)
        unmasked_embeddings = tf.math.l2_normalize(
            self.W_unmasked(unmasked_features), axis=-1)

        logits = tf.matmul(masked_embeddings, unmasked_embeddings, transpose_b=True) * tf.exp(self.t)

        return logits

In [None]:
# Build the contrastive model around the encoders and compile it with Adam.
# The loss is fixed inside ContrastiveModel.compile, so only the optimizer
# and the progress metric are given here.
cm = ContrastiveModel(masked_encoder, unmasked_encoder)
cm.compile(
    optimizer=tf.keras.optimizers.Adam(),
    metrics=[
        ContrastiveAccuracy
    ],
    run_eagerly=False # Set to true to debug
)

In [None]:
# Shape of the first four second-view images of batch 0.
data[0][1][:4].shape

In [None]:
# Smoke test: forward four image pairs from batch 0 through the contrastive
# model; the cell's value is the returned logits. Each data[0] access draws
# a fresh set of random crops.
cm((data[0][0][:4],data[0][1][:4]))

In [None]:
# Training pairs: two random crops keeping 40%-100% of each side.
training_data = DataGenerator(x_train, y_train, batch_size=32, rng=np.random.default_rng(0), min_crop = 0.4, max_crop = 1)
# No cropping
# NOTE(review): with min_crop = max_crop = 1 both validation views of an
# image are the same image -- confirm that this is the intended validation task.
validation_data = DataGenerator(x_val, y_val, batch_size=32, rng=np.random.default_rng(0), min_crop = 1, max_crop = 1)

In [None]:
# Number of images in the first view of training batch 0.
len(training_data[0][0])

In [None]:
# Logits of the contrastive model on validation batch 0.
cm(validation_data[0])

In [None]:
# Contrastive accuracy on validation batch 0, with 0..n-1 as the target indices.
ContrastiveAccuracy(np.arange(validation_data[0][0].shape[0]),cm(validation_data[0]))

In [None]:
# Train for up to 300 epochs; stop once the validation loss has not
# improved for 15 consecutive epochs.
callback = tf.keras.callbacks.EarlyStopping(monitor="val_loss", mode='min', patience=15)
history = cm.fit(
    training_data,
    validation_data=validation_data,
    epochs=300,
    callbacks=[callback],
    verbose=1,
)

In [None]:
# Print the layer/parameter summary of the contrastive model.
cm.summary()

In [None]:
# Record the number of epochs run and the final-epoch loss/accuracy values.
with open("training_history.txt", 'w') as file:
    file.write(
        f"Epochs: {len(history.history['loss'])}\n"
        f"Loss: {history.history['loss'][-1]}\n"
        f"Validation Loss: {history.history['val_loss'][-1]}\n"
        f"Contrastive Accuracy: {history.history['ContrastiveAccuracy'][-1] * 100}%\n"
        f"Validation Contrastive Accuracy: {history.history['val_ContrastiveAccuracy'][-1] * 100}%\n"
    )

In [None]:
# Save the trained weights of the contrastive model under the Weights/ prefix.
cm.save_weights("Weights/ViT-11-18-23")

In [None]:
# Plot History in one image (loss on top, contrastive accuracy below)
plt.figure(1)

# Loss history
plt.subplot(211)
plt.plot(history.history['loss'],label='Training')
plt.plot(history.history['val_loss'],label='Validation')
plt.legend()

plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('ViT Loss')

# Accuracy history
plt.subplot(212)
plt.plot(history.history['ContrastiveAccuracy'],label='Training')
plt.plot(history.history['val_ContrastiveAccuracy'],label='Validation')
plt.legend()

plt.xlabel('Epoch')
# This panel shows the accuracy metric, so label the axis accordingly.
plt.ylabel('Accuracy')
plt.title('ViT Accuracy')

plt.tight_layout(h_pad = 5.0)
plt.savefig('history.png')
plt.show()

In [None]:
# Standalone loss curves (training and validation), saved to history_loss.png
for _key, _legend in (('loss', 'Training'), ('val_loss', 'Validation')):
    plt.plot(history.history[_key], label=_legend)
plt.legend()

plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('ViT Loss')

plt.savefig('history_loss.png')
plt.show()

In [None]:
# Plot Accuracy in individual image
plt.plot(history.history['ContrastiveAccuracy'],label='Training')
plt.plot(history.history['val_ContrastiveAccuracy'],label='Validation')
plt.legend()

plt.xlabel('Epochs')
# The curves are accuracies, so label the axis accordingly.
plt.ylabel('Accuracy')
plt.title('ViT Accuracy')

plt.savefig('history_acc.png')
plt.show()

### Obtaining Realistic Accuracies

In [None]:
# Training images x_train[offset : offset + n_comp] serve as the labelled
# reference set that validation images are compared against.
offset, n_comp = 0, 100

In [None]:
# Similarity logits between the first validation image and the n_comp
# reference training images starting at offset.
cm((x_val[0:1],x_train[offset:offset+n_comp]))

In [None]:
# True label of the first validation image.
y_val[0:1]

In [None]:
# Predicted label: the label of the reference training image with the
# highest logit for the first validation image.
y_train[np.argmax(cm((x_val[0:1],x_train[offset:offset+n_comp])))+offset]

In [None]:
# Nearest-reference classification of the first n validation images: each
# image gets the label of the reference training image with the highest logit.
accuracy = 0.0
n = 10
# The report contains a check mark (U+2714), so write it as UTF-8 rather
# than with the platform default encoding.
with open("10_accuracy.txt", 'w', encoding="utf-8") as file:
    print("True\tPred")
    file.write("True\tPred" + '\n')
    for i in range(n):
        true = y_val[i]
        pred = y_train[np.argmax(cm((x_val[i:i+1], x_train[offset:offset+n_comp])))+offset]
        print(true, '\t', pred, end = '')
        file.write(str(true) + '\t\t' + str(pred))
        # Reuse the prediction instead of running the model a second time.
        if true == pred:
            accuracy+= 1.0
            print("\t\u2714\n")
            file.write("\t\t\u2714\n")
        else:
            print('\n')
            file.write('\n')

    accuracy /= n
    print(f"Accuracy: {accuracy*100}%")
    file.write(f"Accuracy: {accuracy*100}%\n")

In [None]:
# Nearest-reference accuracy over the whole validation set. Validation images
# are scored in batches: one model call yields a (batch, n_comp) logit matrix,
# and the row-wise argmax picks each image's best reference. This avoids one
# model call (and one re-encoding of the reference images) per validation image.
accuracy = 0.0
n = y_val.shape[0]
eval_batch_size = 500
print(n, "computations to make . . .")

reference_images = x_train[offset:offset+n_comp]
reference_labels = y_train[offset:offset+n_comp]

correct = 0
for start in range(0, n, eval_batch_size):
    stop = min(start + eval_batch_size, n)
    logits = cm((x_val[start:stop], reference_images))
    nearest = np.argmax(logits, axis=1)
    correct += int(np.sum(reference_labels[nearest] == y_val[start:stop]))
    print(f"\rProgress: {stop}/{n} ({(stop/n)*100:.2f}%)", end="")  # Update the progress

# Guard against an empty validation set.
accuracy = correct / n if n else 0.0
print(f"\nFinal Accuracy: {accuracy*100}%")
with open("final_accuracy.txt", 'w') as file:
    file.write(f"Final Accuracy: {accuracy*100}%\n")