In [None]:
import os
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.optimizers import Adam
from PIL import Image
import numpy as np

# SimCLR Augmentation function
def simclr_augment(image, output_size=(224, 224), crop_fraction=0.9, max_brightness_delta=0.5):
    """Produce one random SimCLR-style view of a single image.

    Applies, in order: a random horizontal flip, a random crop that keeps
    ``crop_fraction`` of each spatial side, a resize to ``output_size``, and a
    random brightness shift. The result is then clipped back into [0, 1].

    Args:
        image: one image tensor of shape (height, width, 3). Its static height
            and width must be known, because they are read via ``image.shape``.
            Values are expected in [0, 1], as produced by
            ``load_images_from_directory``.
        output_size: (height, width) of the returned view.
        crop_fraction: fraction of each side kept by the random crop.
        max_brightness_delta: bound of the uniform brightness offset.

    Returns:
        float32 tensor of shape (output_size[0], output_size[1], 3) with
        values in [0, 1].
    """
    image = tf.cast(image, tf.float32)
    image = tf.image.random_flip_left_right(image)
    crop_size = [
        int(image.shape[0] * crop_fraction),
        int(image.shape[1] * crop_fraction),
        3,
    ]
    image = tf.image.random_crop(image, size=crop_size)
    image = tf.image.resize(image, output_size)
    image = tf.image.random_brightness(image, max_delta=max_brightness_delta)
    # random_brightness adds an unbounded offset. Without clipping, a view can
    # contain negative or >1 "pixels" that real [0, 1] inputs never have.
    return tf.clip_by_value(image, 0.0, 1.0)


# Load all images from a directory, including subdirectories
def load_images_from_directory(data_dir, image_size=(224, 224)):
    """Load every .png/.jpg/.jpeg file under ``data_dir`` as float32 RGB.

    Subdirectories are searched recursively. Directories and files are visited
    in sorted order, so the row order of the result is reproducible.

    Args:
        data_dir: root directory to search.
        image_size: (width, height) passed to PIL's ``Image.resize``.

    Returns:
        float32 array of shape (N, height, width, 3) with values in [0, 1].
        When nothing could be loaded, N is 0 and the array is still 4-D.

    Files that fail to open or decode are reported with their full path and
    skipped (best effort). The function does not raise for them.
    """
    width, height = image_size
    image_list = []
    for root, dirs, files in os.walk(data_dir):
        dirs.sort()  # in-place sort makes os.walk descend in a stable order
        for file in sorted(files):
            if not file.lower().endswith(('.png', '.jpg', '.jpeg')):
                continue
            img_path = os.path.join(root, file)
            try:
                # The context manager closes the file handle PIL keeps open.
                with Image.open(img_path) as img:
                    rgb = img.convert("RGB").resize((width, height))
                    pixels = (np.asarray(rgb) / 255.0).astype(np.float32)
            except Exception as e:  # deliberate best effort: skip bad files
                print(f"Error loading image {img_path}: {e}")
                continue
            image_list.append(pixels)
    if not image_list:
        # Keep the (N, H, W, 3) layout so downstream code sees a 4-D array.
        return np.empty((0, height, width, 3), dtype=np.float32)
    return np.stack(image_list).astype(np.float32, copy=False)


# Build the SimCLR model
def build_simclr_model(input_shape=(224, 224, 3), projection_dim=128, hidden_dim=512, weights='imagenet'):
    """Build a ResNet50 encoder followed by a two-layer projection head.

    Args:
        input_shape: (height, width, channels) of the model input.
        projection_dim: size of the output projection ``z``.
        hidden_dim: width of the hidden ReLU layer in the projection head.
        weights: initial encoder weights passed to ``ResNet50``
            (``'imagenet'`` or ``None``, or a path to a weights file).

    Returns:
        A trainable ``tf.keras`` Model mapping images to (batch, projection_dim)
        projections.
    """
    base_model = ResNet50(include_top=False, weights=weights, input_shape=input_shape, pooling='avg')
    base_model.trainable = True
    # NOTE(review): ImageNet ResNet50 weights were trained on inputs prepared
    # by resnet50.preprocess_input, but the loader feeds [0, 1] pixels here.
    # Confirm this mismatch is intended.

    inputs = layers.Input(shape=input_shape)
    # Do not hard-code training=True. Leaving it unset lets the flag passed to
    # model(..., training=...) reach the BatchNorm layers, so inference with
    # training=False uses moving statistics and does not update them.
    features = base_model(inputs)
    hidden = layers.Dense(hidden_dim, activation='relu')(features)
    outputs = layers.Dense(projection_dim)(hidden)
    return models.Model(inputs=inputs, outputs=outputs)

# Contrastive loss function
def nt_xent_loss(z_i, z_j, temperature=0.5):
    """Normalized temperature-scaled cross-entropy (NT-Xent) loss.

    Row k of ``z_i`` and row k of ``z_j`` are the two views of the same image
    (the positive pair). Every other row of the concatenated 2B-row batch is a
    negative for that anchor.

    Args:
        z_i: (B, D) projections of the first views.
        z_j: (B, D) projections of the second views.
        temperature: softmax temperature applied to the cosine similarities.

    Returns:
        Scalar tensor: mean cross-entropy over all 2B anchors.
    """
    z_i = tf.math.l2_normalize(z_i, axis=1)
    z_j = tf.math.l2_normalize(z_j, axis=1)

    z = tf.concat([z_i, z_j], axis=0)  # (2B, D)
    logits = tf.matmul(z, z, transpose_b=True) / temperature  # (2B, 2B)

    batch_size = tf.shape(z_i)[0]
    num_anchors = 2 * batch_size
    # Remove self-similarity from the softmax. Each diagonal entry equals
    # 1/temperature, the maximum possible value, and would otherwise dominate.
    logits = logits - tf.eye(num_anchors, dtype=logits.dtype) * 1e9

    # The positive for anchor k (first view) is k + B, and the positive for
    # anchor k + B (second view) is k.
    labels = tf.concat(
        [tf.range(batch_size, num_anchors), tf.range(0, batch_size)], axis=0
    )
    per_anchor = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
    return tf.reduce_mean(per_anchor)

# Train SimCLR model
def train_simclr(image_data, epochs=10, batch_size=32, learning_rate=0.001, temperature=0.5):
    """Train a SimCLR model on ``image_data`` and return it.

    Each step builds two augmented views of every image in the batch with
    ``simclr_augment``, embeds both views with the model, and minimizes
    ``nt_xent_loss`` between them using Adam. The mean loss of each epoch is
    printed.

    Args:
        image_data: array-like of images, (N, H, W, 3); cast to float32.
        epochs: number of passes over the data.
        batch_size: images per training step.
        learning_rate: Adam learning rate.
        temperature: NT-Xent temperature.

    Returns:
        The trained model from ``build_simclr_model``.

    Raises:
        ValueError: if ``image_data`` is empty. Otherwise ``shuffle`` would
            fail with a zero buffer size.
    """
    num_images = len(image_data)
    if num_images == 0:
        raise ValueError("train_simclr needs at least one image; got an empty dataset.")

    model = build_simclr_model()
    optimizer = Adam(learning_rate=learning_rate)

    dataset = (
        tf.data.Dataset.from_tensor_slices(tf.cast(image_data, tf.float32))
        .shuffle(buffer_size=num_images)
        .batch(batch_size)
    )

    # Compiling the step as a graph avoids re-running Python-level op dispatch
    # for the augmentations and the ResNet forward/backward pass every batch.
    @tf.function
    def train_step(batch):
        view_1 = tf.map_fn(simclr_augment, batch)
        view_2 = tf.map_fn(simclr_augment, batch)
        with tf.GradientTape() as tape:
            z_i = model(view_1, training=True)
            z_j = model(view_2, training=True)
            loss = nt_xent_loss(z_i, z_j, temperature=temperature)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        return loss

    for epoch in range(epochs):
        epoch_loss = [float(train_step(batch)) for batch in dataset]
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {np.mean(epoch_loss)}")

    return model



# Extract embeddings from the trained model
def extract_embeddings(model, image_data, batch_size=32):
    """Run ``model`` in inference mode over ``image_data`` in batches.

    Args:
        model: callable Keras model. It is invoked as
            ``model(batch, training=False)``; ``model.output_shape`` is read
            only when ``image_data`` is empty.
        image_data: array-like of images, (N, H, W, C); converted to float32.
        batch_size: number of images per forward pass (must be >= 1).

    Returns:
        NumPy array with one output row per image, in input order. For empty
        input, an array of shape (0, *model.output_shape[1:]) is returned
        instead of raising.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    images = np.asarray(image_data, dtype=np.float32)
    if len(images) == 0:
        # np.vstack([]) raises, so return a correctly shaped empty array.
        return np.empty((0,) + tuple(model.output_shape[1:]), dtype=np.float32)
    batches = [
        np.asarray(model(images[start:start + batch_size], training=False))
        for start in range(0, len(images), batch_size)
    ]
    return np.vstack(batches)

# Main function to run SimCLR
def main(data_dir="London_UK/images", epochs=10, batch_size=32):
    """Run the pipeline: load images, train SimCLR, extract embeddings.

    Args:
        data_dir: folder searched recursively for .png/.jpg/.jpeg images.
            The default is the notebook's test-images folder.
        epochs: training epochs passed to ``train_simclr``.
        batch_size: batch size used for both training and embedding.

    Returns:
        The embeddings array, or None when no images were found.
    """
    print("Loading images...")
    image_data = load_images_from_directory(data_dir)
    print(f"Loaded {len(image_data)} images.")
    if len(image_data) == 0:
        print(f"No images found under {data_dir!r}; nothing to train on.")
        return None

    print("Training SimCLR model...")
    model = train_simclr(image_data, epochs=epochs, batch_size=batch_size)

    print("Extracting embeddings...")
    embeddings = extract_embeddings(model, image_data, batch_size=batch_size)
    print(f"Generated embeddings with shape: {embeddings.shape}")
    return embeddings

if __name__ == "__main__":
    # Script entry point: runs main() only when executed directly, not on import.
    main()


Loading images...
Loaded 52 images.
Training SimCLR model...
