In [1]:
import fiftyone as fo
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Input, Reshape, Lambda, Concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, TensorBoard
from functools import partial
import os
import time
import signal
import numpy as np
from sklearn.cluster import KMeans

# --- Anchor Box Calculation Function ---
def calculate_anchor_boxes(fo_dataset, num_anchors=5, image_width=416, image_height=416):
    all_boxes = []
    for sample in fo_dataset:
        if sample["ground_truth"] is None:
            continue
        for detection in sample["ground_truth"].detections:
            bbox = detection.bounding_box
            width = bbox[2]
            height = bbox[3]
            all_boxes.append([width, height])

    if not all_boxes:
        raise ValueError("No bounding boxes found in the dataset.")

    boxes = np.array(all_boxes)

    def iou_distance(box1, box2):
        intersection_width = min(box1[0], box2[0])
        intersection_height = min(box1[1], box2[1])
        if intersection_width <=0 or intersection_height <=0:
             return 1.0
        intersection_area = intersection_width * intersection_height
        box1_area = box1[0] * box1[1]
        box2_area = box2[0] * box2[1]
        union_area = box1_area + box2_area - intersection_area
        iou = intersection_area / (union_area + 1e-16)
        return 1.0 - iou

    def kmeans_iou(boxes, k):
        np.random.seed(42)
        clusters = boxes[np.random.choice(boxes.shape[0], k, replace=False)]
        last_clusters = np.zeros((boxes.shape[0],))
        while True:
            distances = np.array([[iou_distance(box, cluster) for cluster in clusters] for box in boxes])
            nearest_clusters = np.argmin(distances, axis=1)
            if (last_clusters == nearest_clusters).all():
                break
            for j in range(k):
                clusters[j] = np.median(boxes[nearest_clusters == j], axis=0)
            last_clusters = nearest_clusters
        return clusters

    anchors = kmeans_iou(boxes, num_anchors)
    anchors = anchors[np.argsort(anchors[:, 0] * anchors[:, 1])]
    anchors[:, 0] *= image_width
    anchors[:, 1] *= image_height
    return anchors

# --- Constants and Dataset Loading ---
classes_to_download = [
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
    "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat",
    "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack",
    "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball",
    "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket",
    "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
    "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair",
    "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote",
    "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book",
    "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"
]

NUM_CLASSES = len(classes_to_download)
IMAGE_SIZE = (416, 416)
BATCH_SIZE = 64
GRID_SIZE = 13
NUM_BOXES = 5
INPUT_SHAPE = (416, 416, 3)
EPOCHS = 20

# Load datasets
train_fo_dataset = fo.zoo.load_zoo_dataset("coco-2017", split="train", label_types=["detections"], classes=classes_to_download, max_samples=1000)
val_fo_dataset = fo.zoo.load_zoo_dataset("coco-2017", split="validation", label_types=["detections"], classes=classes_to_download, max_samples=500)

# Calculate anchors
ANCHORS = calculate_anchor_boxes(train_fo_dataset, num_anchors=NUM_BOXES, image_width=IMAGE_SIZE[1], image_height=IMAGE_SIZE[0])
ANCHORS = tf.constant(ANCHORS, dtype=tf.float32) / tf.constant([IMAGE_SIZE[1], IMAGE_SIZE[0]], dtype=tf.float32)  # Normalize to [0,1]

# --- Signal Handling ---
class GracefulKiller:
    def __init__(self, model, checkpoint_filepath):
        self.kill_now = False
        self.model = model
        self.checkpoint_filepath = checkpoint_filepath
        signal.signal(signal.SIGINT, self.exit_gracefully)
        signal.signal(signal.SIGTERM, self.exit_gracefully)
    def exit_gracefully(self, *args):
        print("\nSaving model...")
        self.model.save(self.checkpoint_filepath)
        print("Model saved. Exiting.")
        self.kill_now = True

# --- Data Pipeline ---
def get_detections(sample_id, fo_dataset):
    sample_id_str = sample_id.numpy().decode("utf-8")
    sample = fo_dataset[sample_id_str]
    original_width = sample.metadata.width
    original_height = sample.metadata.height
    new_width, new_height = IMAGE_SIZE[1], IMAGE_SIZE[0]
    labels, boxes = [], []
    for detection in sample["ground_truth"].detections:
        bbox = detection.bounding_box
        x_center_orig = (bbox[0] + bbox[2]/2) * original_width
        y_center_orig = (bbox[1] + bbox[3]/2) * original_height
        width_orig = bbox[2] * original_width
        height_orig = bbox[3] * original_height
        x_center_new = x_center_orig * (new_width / original_width)
        y_center_new = y_center_orig * (new_height / original_height)
        width_new = width_orig * (new_width / original_width)
        height_new = height_orig * (new_height / original_height)
        x_center_rel = x_center_new / new_width
        y_center_rel = y_center_new / new_height
        width_rel = width_new / new_width
        height_rel = height_new / new_height
        grid_x = int(x_center_rel * GRID_SIZE)
        grid_y = int(y_center_rel * GRID_SIZE)
        x_center_grid = (x_center_rel * GRID_SIZE) - grid_x
        y_center_grid = (y_center_rel * GRID_SIZE) - grid_y
        boxes.append([grid_x, grid_y, x_center_grid, y_center_grid, width_rel, height_rel])
        labels.append(classes_to_download.index(detection.label))
    return tf.cast(labels, tf.int64), tf.cast(boxes, tf.float32)

def load_and_preprocess_data(sample, fo_dataset, augment=True):
    image = tf.io.read_file(sample["filepath"])
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, IMAGE_SIZE) / 255.0
    if augment:
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_brightness(image, max_delta=0.1)
        image = tf.image.random_contrast(image, 0.8, 1.2)
    labels, boxes = tf.py_function(
        func=partial(get_detections, fo_dataset=fo_dataset),
        inp=[sample["id"]],
        Tout=(tf.int64, tf.float32)
    )

    @tf.function
    def process_boxes(boxes, labels):
        target = tf.zeros((GRID_SIZE, GRID_SIZE, NUM_BOXES, 5 + NUM_CLASSES))
        for i in tf.range(tf.shape(boxes)[0]):
            box = boxes[i]
            label = labels[i]
            grid_x, grid_y = tf.cast(box[0], tf.int32), tf.cast(box[1], tf.int32)
            x_center, y_center = box[2], box[3]
            width, height = box[4], box[5]

            gt_wh = tf.stack([width, height])
            anchor_wh = ANCHORS
            min_wh = tf.minimum(gt_wh, anchor_wh)
            intersection = min_wh[..., 0] * min_wh[..., 1]
            union = width * height + anchor_wh[..., 0] * anchor_wh[..., 1] - intersection
            ious = intersection / (union + 1e-9)
            best_anchor = tf.argmax(ious)

            tx = x_center
            ty = y_center
            tw = tf.math.log(width / ANCHORS[best_anchor][0] + 1e-9)
            th = tf.math.log(height / ANCHORS[best_anchor][1] + 1e-9)

            indices = [grid_y, grid_x, tf.cast(best_anchor, tf.int32)]
            updates = tf.concat([[tx, ty, tw, th], [1.0], tf.one_hot(label, NUM_CLASSES)], axis=0)
            target = tf.tensor_scatter_nd_update(target, [indices], [updates])
        return target

    target = tf.cond(
        tf.shape(boxes)[0] > 0,
        lambda: process_boxes(boxes, labels),
        lambda: tf.zeros((GRID_SIZE, GRID_SIZE, NUM_BOXES, 5 + NUM_CLASSES))
    )
    return image, target

def fiftyone_dataset_generator(fo_dataset):
    for sample in fo_dataset:
        yield {"filepath": sample.filepath, "id": str(sample.id)}

# Create TensorFlow datasets
train_tf_dataset = tf.data.Dataset.from_generator(
    lambda: fiftyone_dataset_generator(train_fo_dataset),
    output_types={"filepath": tf.string, "id": tf.string}
).map(
    lambda x: load_and_preprocess_data(x, train_fo_dataset, augment=True),
    num_parallel_calls=tf.data.AUTOTUNE
).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

val_tf_dataset = tf.data.Dataset.from_generator(
    lambda: fiftyone_dataset_generator(val_fo_dataset),
    output_types={"filepath": tf.string, "id": tf.string}
).map(
    lambda x: load_and_preprocess_data(x, val_fo_dataset, augment=False),
    num_parallel_calls=tf.data.AUTOTUNE
).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

# --- Model Construction ---
def build_yolo_model(input_shape, num_classes, num_boxes):
    inputs = Input(input_shape)
    base_model = ResNet50(include_top=False, weights="imagenet", input_shape=input_shape)
    base_model.trainable = False
    backbone = base_model(inputs)
    x = Conv2D(num_boxes*(5 + num_classes), 3, padding='same')(backbone)
    x = Reshape((GRID_SIZE, GRID_SIZE, num_boxes, 5 + num_classes))(x)
    txy = Lambda(lambda x: tf.sigmoid(x[..., 0:2]))(x)
    twh = Lambda(lambda x: x[..., 2:4])(x)
    obj = Lambda(lambda x: tf.sigmoid(x[..., 4:5]))(x)
    class_probs = Lambda(lambda x: tf.sigmoid(x[..., 5:]))(x)
    outputs = Concatenate(axis=-1)([txy, twh, obj, class_probs])
    return Model(inputs, outputs)

# --- Loss Function ---
def yolo_loss(y_true, y_pred):
    pred_txy = y_pred[..., 0:2]
    pred_twh = y_pred[..., 2:4]
    pred_obj = y_pred[..., 4:5]
    pred_class = y_pred[..., 5:]
    
    true_txy = y_true[..., 0:2]
    true_twh = y_true[..., 2:4]
    true_obj = y_true[..., 4:5]
    true_class = y_true[..., 5:]
    
    obj_mask = tf.squeeze(true_obj, axis=-1)  # Fix dimension mismatch
    
    xy_loss = obj_mask * tf.reduce_sum(tf.square(true_txy - pred_txy), axis=-1)
    wh_loss = obj_mask * 0.5 * tf.reduce_sum(tf.square(true_twh - pred_twh), axis=-1)
    obj_loss = tf.keras.losses.binary_crossentropy(true_obj, pred_obj)
    class_loss = obj_mask * tf.keras.losses.binary_crossentropy(true_class, pred_class)
    
    return tf.reduce_sum(xy_loss) + tf.reduce_sum(wh_loss) + tf.reduce_sum(obj_loss) + tf.reduce_sum(class_loss)

# --- Custom Metrics ---
# Add these custom metrics to your model compilation
def objectness_accuracy(y_true, y_pred):
    obj_mask = y_true[..., 4:5]  # Ground truth object presence (0 or 1)
    pred_obj = y_pred[..., 4:5]  # Predicted objectness score
    return tf.keras.metrics.binary_accuracy(obj_mask, pred_obj)

def class_accuracy(y_true, y_pred):
    obj_mask = y_true[..., 4:5]  # Only consider cells with objects
    true_class = tf.argmax(y_true[..., 5:], axis=-1)
    pred_class = tf.argmax(y_pred[..., 5:], axis=-1)
    matches = tf.cast(tf.equal(true_class, pred_class), tf.float32)
    return tf.reduce_sum(matches * tf.squeeze(obj_mask, -1)) / (tf.reduce_sum(obj_mask) + 1e-8)

# --- Training Setup ---
checkpoint_filepath = "models/checkpoints/best_model.keras"
os.makedirs(os.path.dirname(checkpoint_filepath), exist_ok=True)
model = build_yolo_model(INPUT_SHAPE, NUM_CLASSES, NUM_BOXES)

# Modify model compilation to include metrics
model.compile(
    optimizer='adam',
    loss=yolo_loss,
    metrics=[
        objectness_accuracy,
        class_accuracy
        # Removed MeanIoU metric as it causes scatter indexing error.
    ]
)

callbacks = [
    ModelCheckpoint(checkpoint_filepath, save_best_only=True, monitor='val_loss'),
    EarlyStopping(monitor='val_loss', patience=5),
    TensorBoard(log_dir='./logs')
]

killer = GracefulKiller(model, checkpoint_filepath)

# --- Training Loop ---
for epoch in range(EPOCHS):
    if killer.kill_now:
        break
    print(f"Epoch {epoch+1}/{EPOCHS}")
    try:
        model.fit(train_tf_dataset,
                   validation_data=val_tf_dataset, 
                   epochs=1, 
                   callbacks=callbacks)
    except KeyboardInterrupt:
        killer.exit_gracefully()
        break

print("Training complete.")

Downloading split 'train' to 'C:\Users\watts\fiftyone\coco-2017\train' if necessary
Found annotations at 'C:\Users\watts\fiftyone\coco-2017\raw\instances_train2017.json'
Sufficient images already downloaded
Existing download of split 'train' is sufficient
Loading existing dataset 'coco-2017-train-1000'. To reload from disk, either delete the existing dataset or provide a custom `dataset_name` to use
Downloading split 'validation' to 'C:\Users\watts\fiftyone\coco-2017\validation' if necessary
Found annotations at 'C:\Users\watts\fiftyone\coco-2017\raw\instances_val2017.json'
Sufficient images already downloaded
Existing download of split 'validation' is sufficient
Loading existing dataset 'coco-2017-validation-500'. To reload from disk, either delete the existing dataset or provide a custom `dataset_name` to use

Epoch 1/20
     16/Unknown [1m43s[0m 3s/step - class_accuracy: 0.1424 - loss: 25099.6699 - objectness_accuracy: 0.8681



[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m65s[0m 4s/step - class_accuracy: 0.1461 - loss: 24586.1328 - objectness_accuracy: 0.8731 - val_class_accuracy: 0.2648 - val_loss: 9531.3398 - val_objectness_accuracy: 0.9922
Epoch 2/20
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m61s[0m 4s/step - class_accuracy: 0.2750 - loss: 8881.7666 - objectness_accuracy: 0.9920 - val_class_accuracy: 0.2532 - val_loss: 7993.1606 - val_objectness_accuracy: 0.9922
Epoch 3/20
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m62s[0m 4s/step - class_accuracy: 0.2821 - loss: 7855.5444 - objectness_accuracy: 0.9920 - val_class_accuracy: 0.2611 - val_loss: 7214.0400 - val_objectness_accuracy: 0.9922
Epoch 4/20
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m61s[0m 4s/step - class_accuracy: 0.2833 - loss: 7371.8066 - objectness_accuracy: 0.9920 - val_class_accuracy: 0.2559 - val_loss: 6882.6914 - val_objectness_accuracy: 0.9922
Epoch 5/20
[1m16/16[0m [32m━━━━━━━━━━━━