# Build the FOMO model with a pre-trained backbone

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.optimizers import Adam
import cv2
import os
import matplotlib.pyplot as plt
import numpy as np

def build_fomo_model(input_shape=(224, 224, 3), num_classes=2):
    """
    Assemble a FOMO (Faster Objects, More Objects) detector.

    A frozen, ImageNet-pretrained MobileNetV2 truncated at ``block_6_expand_relu``
    acts as feature extractor. Three 3x3 convolutions on top of it classify every
    cell of the resulting feature grid.

    Parameters:
    -----------
    input_shape : tuple, optional
        Shape of the input image as (height, width, channels). Defaults to
        (224, 224, 3), i.e. a 224x224 RGB image.

    num_classes : int, optional
        Number of channels of the output layer, i.e. the number of classes
        predicted per grid cell. A background class must be included in this
        count, so the minimum useful value is 2.

    Returns:
    --------
    model : tf.keras.Model
        Uncompiled Keras model named 'FOMO'. Its output holds, for every grid
        cell, softmax class probabilities (not raw logits).
    """
    # ImageNet-pretrained MobileNetV2 without its classification top
    backbone = MobileNetV2(input_shape=input_shape, include_top=False, weights='imagenet')

    # Cut the network at an early layer to keep the model small.
    # block_6_expand_relu sits at 1/8 of the input resolution, so for the default
    # 224x224 input its output shape is (None, 28, 28, 192).
    cut_layer = backbone.get_layer('block_6_expand_relu')
    feature_extractor = models.Model(inputs=backbone.input,
                                     outputs=cut_layer.output,
                                     name='MobileNetV2_Block6')

    # Only the head is trained; the backbone weights stay fixed
    feature_extractor.trainable = False

    inputs = tf.keras.Input(shape=input_shape, name='FOMO_input')
    features = feature_extractor(inputs)

    #-----------FOMO head grid based classifier----------
    # 'same' padding keeps the spatial size, so the output has one prediction per feature cell
    for layer_name in ('conv_1', 'conv_2'):
        features = layers.Conv2D(16, (3, 3), padding='same', activation='relu', name=layer_name)(features)

    # Per-cell class probabilities
    class_probs = layers.Conv2D(num_classes, (3, 3), padding='same', activation='softmax', name='output')(features)

    return models.Model(inputs=inputs, outputs=class_probs, name='FOMO')

### Compile the model

In [None]:
# Adam optimizer with a learning rate of 1e-3
optimizer = Adam(learning_rate=0.001)
# Two output classes: background (0) and chess piece (1)
model = build_fomo_model(num_classes=2)

# The model's output layer already applies softmax, so the string loss
# (which expects probabilities rather than logits) matches it. The targets are
# integer class indices, one per grid cell.
model.compile(optimizer=optimizer, 
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.summary()

### Download the chess pieces dataset from Roboflow and create a TensorFlow dataset generator

Link: https://public.roboflow.com/object-detection/chess-full/

**Important Note**: Make sure you select the YOLO v5 PyTorch format when exporting the dataset!

In [None]:
# Replace "CURL LINK" with the download URL Roboflow generates for your export
!curl -L "CURL LINK" -o roboflow.zip
# Extract into chess_ds/ and remove the archive afterwards
!unzip roboflow.zip -d chess_ds
!rm roboflow.zip

### Create Tensorflow Dataset generator

In [3]:
# Side length in pixels of the square images fed to the model; every image is resized to MODEL_SIZE x MODEL_SIZE
MODEL_SIZE = 224

def load_image_and_bboxes(image_path, label_path):
    """
    Load one image together with its YOLO-style annotations.

    Parameters:
    -----------
    image_path : str
        Path of the image file. It is read with OpenCV, converted from BGR to
        RGB and resized to MODEL_SIZE x MODEL_SIZE.

    label_path : str
        Path of the text label file holding one object per line in the form
        ``class x_center y_center width height``. Blank lines are ignored.

    Returns:
    --------
    image : np.ndarray
        The resized RGB image.
    bboxes : list
        One ``[x_center, y_center, width, height]`` list of floats per object,
        exactly as written in the label file (not rescaled here).
    classes : list
        One ``[1]`` entry per object: every piece is mapped to class index 1,
        while 0 is reserved for background.

    Raises:
    -------
    OSError
        If the image is missing or cannot be decoded.
    ValueError
        If a non-blank label line has fewer than five fields or a
        non-numeric coordinate.
    """
    # cv2.imread reports failure by returning None instead of raising
    image = cv2.imread(image_path)
    if image is None:
        raise OSError(f"Could not read image (missing or unreadable): {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (MODEL_SIZE, MODEL_SIZE))

    # Load bounding boxes
    bboxes = []
    classes = []
    with open(label_path, 'r') as file:
        for line in file:
            fields = line.split()
            if not fields:
                continue  # tolerate blank / trailing empty lines

            # We just treat every chess piece as one class for simplicity, but feel
            # free to use the original classes (fields[0]).
            # Set class index to 1 for single class detection, 0 as background
            class_idx = 1
            x_center, y_center, width, height = map(float, fields[1:5])
            bboxes.append([x_center, y_center, width, height])
            classes.append([class_idx])

    return image, bboxes, classes

def create_tf_dataset(image_dir, label_dir):
    """
    Build a tf.data.Dataset that lazily loads images and their annotations.

    Parameters:
    -----------
    image_dir : str
        Directory containing the ``.jpg`` images.
    label_dir : str
        Directory containing one ``.txt`` label file per image, named after
        the image file with its extension replaced.

    Returns:
    --------
    dataset : tf.data.Dataset
        Dataset of dicts with the keys ``'image'`` (float32 tensor of shape
        (MODEL_SIZE, MODEL_SIZE, 3), scaled to [0, 1]), ``'bboxes'`` (ragged
        float32) and ``'classes'`` (ragged int64). Images without any bounding
        box are skipped.
    """
    # Scan the directory once and sort: os.listdir returns names in arbitrary
    # order, and a single sorted list keeps the pairing and sample order reproducible.
    image_names = sorted(name for name in os.listdir(image_dir) if name.endswith('.jpg'))
    image_paths = [os.path.join(image_dir, name) for name in image_names]
    label_paths = [os.path.join(label_dir, os.path.splitext(name)[0] + '.txt') for name in image_names]

    def generator():
        for img_path, lbl_path in zip(image_paths, label_paths):
            image, bboxes, classes = load_image_and_bboxes(img_path, lbl_path)
            if not bboxes:
                continue  # Skip if no bounding boxes
            # Scale pixel values from 0-255 to [0, 1]
            image = tf.convert_to_tensor(image, dtype=tf.float32) / 255.0
            # The number of boxes differs per image, hence ragged tensors
            bboxes = tf.ragged.constant(bboxes, dtype=tf.float32)
            classes = tf.ragged.constant(classes, dtype=tf.int64)
            yield {'image': image, 'bboxes': bboxes, 'classes': classes}

    # Create TensorFlow Dataset from generator
    dataset = tf.data.Dataset.from_generator(
        generator,
        output_signature={
            'image': tf.TensorSpec(shape=(MODEL_SIZE, MODEL_SIZE, 3), dtype=tf.float32),
            'bboxes': tf.RaggedTensorSpec(shape=(None, None), dtype=tf.float32),
            'classes': tf.RaggedTensorSpec(shape=(None, None), dtype=tf.int64),
        }
    )

    return dataset

# Load training and validation datasets from the extracted archive
# (images and label files live in sibling "images" / "labels" directories per split)
train_dataset = create_tf_dataset("chess_ds/train/images", "chess_ds/train/labels")
val_dataset = create_tf_dataset("chess_ds/valid/images", "chess_ds/valid/labels")

### Verify dataset preparation

In [None]:
def visualize_bboxes(image, bboxes, labels):
    """
    Show an image with its bounding boxes drawn on top.

    Args:
        image (tf.Tensor): The image tensor (height, width, channels). It is converted
            with ``.numpy()`` and handed to ``plt.imshow`` unchanged.
        bboxes (tf.Tensor): The bounding boxes (num_boxes, 4), each box being
            [cx, cy, w, h]. Values are multiplied by MODEL_SIZE to obtain pixels.
        labels (tf.Tensor): The labels (num_boxes, 1) or (num_boxes,). Currently
            unused: every box is captioned "chess_piece".
    """
    plt.figure(figsize=(10, 10))
    plt.imshow(image.numpy())
    axes = plt.gca()

    box_count = tf.shape(bboxes)[0]
    for idx in range(box_count):
        # Box centre and size in pixels
        cx, cy, w, h = bboxes[idx].numpy() * MODEL_SIZE

        # Rectangle patches are anchored at their top-left corner
        left = int(cx - w / 2)
        top = int(cy - h / 2)

        outline = plt.Rectangle((left, top), w, h, fill=False, color='blue', linewidth=4)
        axes.add_patch(outline)

        # Caption at the anchor corner
        plt.text(left, top, "chess_piece", color='white', fontsize=12,
                 bbox=dict(facecolor='blue', alpha=0.5))

    plt.axis('off')  # Hide the axis
    plt.show()

# Example usage: visualize the first training image with its bounding boxes.
# take(1) yields at most one element, so the loop body runs at most once.
for sample in train_dataset.take(1):
    visualize_bboxes(
        image=sample['image'],
        bboxes=sample['bboxes'],
        labels=sample['classes'],
    )


### Unpack dataset from dict version and convert RaggedTensors to dense

In [11]:
# Function to unpack image, bounding boxes, and labels from the dataset
def unpack_info(sample):
    """
    Turn a dataset dict into an (image, bboxes, labels) tuple.

    The ragged 'classes' and 'bboxes' entries are converted to dense tensors
    with ``to_tensor()``; the image is passed through unchanged.
    """
    dense_labels = sample['classes'].to_tensor()
    dense_bboxes = sample['bboxes'].to_tensor()
    return sample['image'], dense_bboxes, dense_labels

# Prepare the dataset and unpack the data: each element becomes an
# (image, bboxes, labels) tuple of dense tensors.
# NOTE: these names are rebound, so re-running this cell on the already
# unpacked datasets will fail (the elements are no longer dicts).
train_dataset = train_dataset.map(unpack_info, num_parallel_calls=tf.data.AUTOTUNE)
val_dataset = val_dataset.map(unpack_info, num_parallel_calls=tf.data.AUTOTUNE)

### Convert to grid-styled label annotations

In [None]:
def assign_grid_labels(image, bboxes, labels, grid_size=28):
    """
    Convert bounding boxes to a grid-style centroid label map.

    Parameters:
    -----------
    image : tf.Tensor
        Image of static shape (height, width, channels). Only its shape is
        used; the tensor is returned unchanged.
    bboxes : tf.Tensor
        Boxes of shape (num_boxes, 4+); column 0 is the centre x and column 1
        the centre y, both expressed as a fraction of the image size.
    labels : tf.Tensor
        Class indices of shape (num_boxes, 1).
    grid_size : int, optional
        Number of cells along each side of the square label grid.

    Returns:
    --------
    image : tf.Tensor
        The input image.
    grid_labels : tf.Tensor
        int32 tensor of shape (grid_size, grid_size): the class index in every
        cell that contains a box centre, 0 (background) everywhere else. If
        several centres fall into one cell, TensorFlow does not guarantee
        which of their labels is kept.
    """
    height, width, _ = image.shape
    cell_height = height / grid_size
    cell_width = width / grid_size

    # Scale the fractional centres with the image's own size so they live in
    # the same pixel space as the cell dimensions computed above.
    center_x = bboxes[:, 0] * width
    center_y = bboxes[:, 1] * height

    # Cell index of each centre; a centre exactly on the right/bottom border
    # would land on index grid_size, so clip into the valid range.
    grid_y = tf.clip_by_value(tf.floor(center_y / cell_height), 0, grid_size - 1)
    grid_x = tf.clip_by_value(tf.floor(center_x / cell_width), 0, grid_size - 1)

    # (row, column) index pairs and the class value to write at each of them
    indices = tf.cast(tf.stack([grid_y, grid_x], axis=-1), dtype=tf.int32)
    updates = tf.cast(tf.squeeze(labels, axis=-1), dtype=tf.int32)

    # Start from an all-background grid and scatter the labels into it
    background = tf.zeros((grid_size, grid_size), dtype=tf.int32)
    grid_labels = tf.tensor_scatter_nd_update(background, indices, updates)

    return image, grid_labels

# Convert boxes to the 28x28 grid style centroid labels representation
# (28 is the default grid_size). Dataset elements are (image, bboxes, labels)
# tuples, which map() unpacks into the positional arguments of the function.
labeled_train = train_dataset.map(assign_grid_labels, num_parallel_calls=tf.data.AUTOTUNE)
labeled_val = val_dataset.map(assign_grid_labels, num_parallel_calls=tf.data.AUTOTUNE)

# Preview the first few labeled examples
# for img, grid in labeled_train.take(1):
#     print("Image shape:", img.shape)
#     print("Grid labels:", grid.numpy())
#     break


### Verify grid-styled labels after conversion

In [None]:
def plot_non_zero_centers(image, grid, grid_size=28):
    """
    Show the image with a green dot at the centre of every non-background cell.

    ``image`` is an array of shape (height, width, channels) and is not
    modified (drawing happens on a copy). ``grid`` is indexed as
    ``grid[row, column]`` for rows and columns in ``range(grid_size)``; every
    cell whose value differs from 0 gets a marker.
    """
    img_height, img_width, _ = image.shape
    cell_h = img_height / grid_size
    cell_w = img_width / grid_size

    # Draw on a copy so the caller's array stays untouched
    canvas = image.copy()

    # Visit the cells row by row
    for row, col in np.ndindex(grid_size, grid_size):
        if grid[row, col] == 0:
            continue  # background cell
        # OpenCV expects the point as (x, y)
        center = (int((col + 0.5) * cell_w), int((row + 0.5) * cell_h))
        cv2.circle(canvas, center, radius=5, color=(0, 255, 0), thickness=-1)

    plt.imshow(canvas.astype(np.uint8))
    plt.axis('off')  # Hide axes
    plt.title('Ground truth object centers')
    plt.show()

num_images_to_display = 5

# Show the first few training images with their ground-truth centers
for img, grid in labeled_train.take(num_images_to_display):
    # Images are stored as floats in [0, 1]; scale back to the 0-255 range
    image_for_plot = img.numpy() * 255.0

    # Draw on an 8-bit copy of the image
    plot_non_zero_centers(image_for_plot.astype(np.uint8), grid.numpy())

## Train the model

In [None]:
# Batch the labeled datasets. drop_remainder=True discards a final incomplete
# batch, so every batch holds exactly batch_size samples.
batch_size = 4
train_dataset = labeled_train.batch(batch_size, drop_remainder= True).prefetch(tf.data.AUTOTUNE)
val_dataset = labeled_val.batch(batch_size, drop_remainder= True).prefetch(tf.data.AUTOTUNE)

# Train for 10 epochs, evaluating on the validation batches after each epoch
history = model.fit(train_dataset, epochs=10, validation_data=val_dataset)

### Convert grid-styled detections to centroids and test the model

In [15]:
def get_centroids_from_heatmap(heatmap, threshold=0.5, rescale=8.0):
    """
    Convert output heatmap to non-background object centroids.

    Parameters:
    -----------
    heatmap : array-like
        Per-cell class scores of shape (rows, cols, num_classes), optionally
        preceded by singleton batch axes, e.g. (1, rows, cols, num_classes).
        Class index 0 is background. The grid does not have to be square.
    threshold : float, optional
        Minimum score the winning class of a cell needs to count as a detection.
    rescale : float, optional
        Size of one grid cell in image pixels.

    Returns:
    --------
    centroids : list of (int, int)
        ``(cx, cy)`` pixel coordinates of the centre of every detected cell,
        ordered row by row.

    Raises:
    -------
    ValueError
        If the heatmap is not 3-dimensional after removing the batch axes.
    """
    print(f'Rescaling by {rescale}...')

    heatmap = np.asarray(heatmap)
    # Remove only leading singleton (batch) axes; a blanket squeeze() would
    # also collapse a grid axis of size 1.
    while heatmap.ndim > 3 and heatmap.shape[0] == 1:
        heatmap = heatmap[0]
    if heatmap.ndim != 3:
        raise ValueError(
            f'Expected a heatmap of shape (rows, cols, num_classes), got {heatmap.shape}'
        )

    # Winning class and its score for every cell
    best_class = np.argmax(heatmap, axis=-1)
    best_score = np.max(heatmap, axis=-1)

    # A detection is a confident, non-background cell
    detected = (best_class != 0) & (best_score >= threshold)

    # np.argwhere lists the (row, col) hits in row-major order
    return [
        (int((col + 0.5) * rescale), int((row + 0.5) * rescale))
        for row, col in np.argwhere(detected)
    ]

def plot_detections(image, centroids):
    """
    Plot the predicted object centers on top of the image.

    Parameters:
    -----------
    image : array-like
        Image whose values are multiplied by 255, clipped to [0, 255] and
        converted to uint8 for display.
    centroids : sequence of (cx, cy)
        Pixel coordinates of the detections; each is drawn as a red dot.
    """
    # Scale to displayable 8-bit pixels, clipping any overshoot
    pixels = np.clip(image * 255, 0, 255).astype(np.uint8)

    plt.figure(figsize=(8, 8))
    plt.imshow(pixels)

    # Draw all markers with a single call
    if len(centroids) > 0:
        xs, ys = zip(*centroids)
        plt.plot(xs, ys, 'ro', markersize=10)

    plt.title('Predicted object centers')
    plt.axis('off')
    plt.show()

### Try the model on the validation set

In [None]:
# Run inference on a single batch of the validation dataset and plot the detections
for images, labels in val_dataset.take(1):
    predictions = model.predict(images)
    
    # Get the batch size
    # NOTE(review): this rebinds the notebook-level `batch_size` set in the training cell
    batch_size = predictions.shape[0]  

    # Iterate over each image and its corresponding prediction in the batch
    for i in range(batch_size):
        print(f"Processing image {i + 1}/{batch_size}")
        img = images[i]
        pred = predictions[i]
        
        # Size of one prediction grid cell in image pixels: image height divided by the number of grid rows
        rescale_factor = images.shape[1] / pred.shape[0]
        
        # Get the predicted (cx, cy) centroids in image pixel coordinates
        centroids = get_centroids_from_heatmap(pred, threshold=0.5, rescale=rescale_factor)
        print(f"Detected centroids for image {i + 1}: {centroids}")
        print(f"Number of detected centroids for image {i + 1}: {len(centroids)}")
        
        # Plot the centroids on the image
        plot_detections(img, centroids)
