# TRACO example solution
In this Jupyter Notebook we implement a very simple approach to detecting Hexbugs in a frame. The following steps are performed:
- Load all videos and Hexbug positions for training
- Resize all frames to a fixed size (target_shape)
- Create a binary mask from the positions to train a U-Net
- Create U-Net architecture and train it with a learning rate scheduler
- Get the final predictions by finding clusters in the predicted output image and scaling back the points to the original frame size
- Convert the output to fit the ".traco" format that is needed to use our score calculation script

In [201]:
import cv2
import numpy as np
from pathlib import Path
import os
import json
import pandas as pd

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPool2D, Flatten, Dense, Dropout, BatchNormalization, Concatenate, Reshape, GlobalAveragePooling2D, UpSampling2D, Activation
from tensorflow.keras.optimizers import Adam
# NOTE: `tf` is only a local alias, not an importable package, so the full module path is required
from tensorflow.keras.callbacks import LearningRateScheduler
from segmentation_models.metrics import iou_score
from segmentation_models.losses import dice_loss

import matplotlib.pyplot as plt

from sklearn.cluster import DBSCAN

In [162]:
# Directory that contains the training videos (*.mp4) and their annotation files (*.traco)
path_training_vids = Path("training")

# Fixed size every frame is resized to before it is fed to the network.
# NOTE(review): cv2.resize reads this tuple as (width, height); the values are equal here.
target_shape = (256, 256)

In [163]:
def _load_annotated_video(video_path):
    """
    Read one video and build a binary mask per frame from its ``.traco`` annotations.

    Returns a tuple ``(frames, masks)`` of lists with one entry per frame.
    """
    # cv2.resize interprets target_shape as (width, height)
    target_width, target_height = target_shape

    with open(video_path.with_suffix(".traco")) as f:
        annotations = json.load(f)['rois']

    # Group the annotated positions by frame index so that each frame needs a single lookup
    positions_by_frame = {}
    for annot in annotations:
        positions_by_frame.setdefault(annot['z'], []).append(annot['pos'])

    frames = []
    masks = []

    cap = cv2.VideoCapture(str(video_path))
    try:
        ret, frame = cap.read()
        if not ret:
            print(f"Warning: could not read any frame from {video_path}, skipping")
            return frames, masks

        org_height, org_width = frame.shape[:2]

        z = 0  # frame counter
        while ret:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            # The mask has the same (rows, cols) layout as the resized frame
            mask_frame = np.zeros(shape=(target_height, target_width))
            for pos in positions_by_frame.get(z, ()):
                # pos is [x, y] in original pixel coordinates; scale it down to the mask
                col = int(pos[0] * target_width // org_width)
                row = int(pos[1] * target_height // org_height)

                # Positions on (or slightly beyond) the frame border can fall outside
                # the mask after downscaling, so clamp them to the nearest valid pixel
                col = min(max(col, 0), target_width - 1)
                row = min(max(row, 0), target_height - 1)
                mask_frame[row, col] = 1

            # Resize the frame to the target size using bilinear interpolation
            resized_frame = cv2.resize(frame, target_shape, interpolation=cv2.INTER_LINEAR)

            # Scale the pixel values from [0, 255] to [-1, 1]
            normalized_frame = (resized_frame.astype('float32') / 255.0 - 0.5) / 0.5

            frames.append(normalized_frame)
            masks.append(mask_frame)

            ret, frame = cap.read()  # read next frame
            z += 1  # increase frame counter
    finally:
        # Always free the video handle, even if decoding fails half-way
        cap.release()

    return frames, masks


def load_train_videos(path):
    """
    Load all training videos in ``path`` together with their annotations as binary masks.

    For every ``*.mp4`` file the ``.traco`` file with the same base name is read. Each frame
    is converted to RGB, resized to ``target_shape`` and scaled to the range [-1, 1]. For each
    frame a mask of the same size is created that is 1 at the (downscaled) annotated Hexbug
    positions and 0 everywhere else. Videos from which no frame can be read are skipped with
    a printed warning.

    Args:
        path: Directory (``str`` or ``Path``) that contains the videos and annotation files.

    Returns:
        Tuple ``(X, Y)`` of numpy arrays. ``X`` holds the normalized frames and ``Y`` the
        masks, in the same order (videos sorted by file name, frames in playback order).

    Raises:
        FileNotFoundError: If a video has no matching ``.traco`` file.
    """
    path = Path(path)

    X = []
    Y = []

    # Sorted so that the sample order does not depend on the file system
    for video_path in sorted(path.iterdir()):
        if video_path.suffix.lower() != ".mp4":
            continue

        frames, masks = _load_annotated_video(video_path)
        X.extend(frames)
        Y.extend(masks)

    X = np.asarray(X)
    Y = np.asarray(Y)

    return X, Y

## Create and train U-Net

In [149]:
def conv_block(x, num_filters):
    """
    Apply two consecutive stages of 3x3 convolution ('same' padding), batch normalization
    and ReLU activation to ``x``. Both convolutions use ``num_filters`` filters.
    Returns the output tensor of the second stage.
    """
    for _ in range(2):
        x = Conv2D(filters=num_filters, kernel_size=3, padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
    return x


def build_Unet(filters=16, num_classes=1):
    """
    Build a U-Net with four encoder levels, a bottleneck and four decoder levels.

    The number of filters doubles on every encoder level (``filters`` ... ``filters * 8``),
    the bottleneck uses ``filters * 16``, and the decoder mirrors the encoder. Each decoder
    level upsamples by 2 and concatenates the encoder output of the same resolution.

    Args:
        filters: Number of filters on the first level.
        num_classes: Number of output channels of the final 1x1 convolution (sigmoid).

    Returns:
        An uncompiled Keras ``Model`` for RGB inputs of the size given by ``target_shape``.

    Raises:
        ValueError: If ``target_shape`` is not divisible by 16 in both dimensions.
    """
    num_levels = 4

    # The frames are produced by cv2.resize(frame, target_shape), which interprets the
    # tuple as (width, height); the network input must therefore be (height, width, 3).
    width, height = target_shape

    # Each of the four pooling stages halves the resolution; otherwise the upsampled
    # tensors would not match the skip connections in the decoder.
    divisor = 2 ** num_levels
    if height % divisor or width % divisor:
        raise ValueError(
            f"target_shape {target_shape} must be divisible by {divisor} in both dimensions"
        )

    # Input layer
    inputs = Input(shape=(height, width, 3))

    # Encoder: keep the output of every level for the skip connections
    skips = []
    x = inputs
    for level in range(num_levels):
        x = conv_block(x, filters * 2 ** level)
        skips.append(x)
        x = MaxPool2D((2, 2))(x)

    # Bottleneck
    x = conv_block(x, filters * 2 ** num_levels)

    # Decoder: walk the levels in reverse order
    for level in reversed(range(num_levels)):
        x = UpSampling2D()(x)
        x = Concatenate()([x, skips[level]])
        x = conv_block(x, filters * 2 ** level)

    # Output layer
    outputs = Conv2D(filters=num_classes,
                     kernel_size=1,
                     padding='same',
                     activation='sigmoid')(x)

    return Model(inputs, outputs)

In [203]:
# Print a layer-by-layer overview of the network.
# NOTE: `model` is created in the "Build model" cell further down, which has to be run first.
model.summary()

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_19 (Conv2D)             (None, 256, 256, 16  448         ['input_2[0][0]']                
                                )                                                                 
                                                                                                  
 batch_normalization_18 (BatchN  (None, 256, 256, 16  64         ['conv2d_19[0][0]']              
 ormalization)                  )                                                           

                                                                                                  
 activation_26 (Activation)     (None, 16, 16, 256)  0           ['batch_normalization_26[0][0]'] 
                                                                                                  
 conv2d_28 (Conv2D)             (None, 16, 16, 256)  590080      ['activation_26[0][0]']          
                                                                                                  
 batch_normalization_27 (BatchN  (None, 16, 16, 256)  1024       ['conv2d_28[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_27 (Activation)     (None, 16, 16, 256)  0           ['batch_normalization_27[0][0]'] 
                                                                                                  
 up_sampli

                                                                                                  
 activation_34 (Activation)     (None, 256, 256, 16  0           ['batch_normalization_34[0][0]'] 
                                )                                                                 
                                                                                                  
 conv2d_36 (Conv2D)             (None, 256, 256, 16  2320        ['activation_34[0][0]']          
                                )                                                                 
                                                                                                  
 batch_normalization_35 (BatchN  (None, 256, 256, 16  64         ['conv2d_36[0][0]']              
 ormalization)                  )                                                                 
                                                                                                  
 activatio

In [151]:
# Instantiate the U-Net with 16 filters on the first level (the original U-Net uses 64)
model = build_Unet(filters=16)

# Configure training: Adam with a learning rate of 1e-3, Dice loss as objective and
# the Intersection over Union (IoU) score as reported metric
model.compile(
    optimizer=Adam(learning_rate=1e-3),
    loss=dice_loss,
    metrics=[iou_score],
)

In [152]:
def exp_scheduler(epoch, lr):
    """
    Learning rate schedule: fixed steps first, exponential decay afterwards.

    Epochs below 20 use 1e-3, epochs below 40 use 1e-4 and epochs below 60 use 1e-5.
    From epoch 60 on, the current learning rate ``lr`` is multiplied by exp(-0.1).
    """
    # (first epoch that no longer uses the rate, rate)
    fixed_steps = ((20, 1e-3), (40, 1e-4), (60, 1e-5))
    for upper_bound, rate in fixed_steps:
        if epoch < upper_bound:
            return rate
    return lr * np.exp(-0.1)

In [None]:
# Load the training data: X_train receives the normalized frames and Y_train the binary
# Hexbug masks of all videos found in path_training_vids
X_train, Y_train = load_train_videos(path_training_vids)

In [153]:
# Train for 80 epochs; the learning rate of each epoch is set by exp_scheduler and
# 10% of the samples are used for validation instead of training
history = model.fit(
    x=X_train,
    y=Y_train,
    epochs=80,
    validation_split=0.1,
    callbacks=[LearningRateScheduler(schedule=exp_scheduler, verbose=0)],
)

Epoch 1/80
Epoch 2/80
Epoch 3/80
Epoch 4/80
Epoch 5/80
Epoch 6/80
Epoch 7/80
Epoch 8/80
Epoch 9/80
Epoch 10/80
Epoch 11/80
Epoch 12/80
Epoch 13/80
Epoch 14/80
Epoch 15/80
Epoch 16/80
Epoch 17/80
Epoch 18/80
Epoch 19/80
Epoch 20/80
Epoch 21/80
Epoch 22/80
Epoch 23/80
Epoch 24/80
Epoch 25/80
Epoch 26/80
Epoch 27/80
Epoch 28/80
Epoch 29/80
Epoch 30/80
Epoch 31/80
Epoch 32/80
Epoch 33/80
Epoch 34/80
Epoch 35/80
Epoch 36/80
Epoch 37/80
Epoch 38/80
Epoch 39/80
Epoch 40/80
Epoch 41/80
Epoch 42/80
Epoch 43/80
Epoch 44/80
Epoch 45/80
Epoch 46/80
Epoch 47/80
Epoch 48/80
Epoch 49/80
Epoch 50/80


Epoch 51/80
Epoch 52/80
Epoch 53/80
Epoch 54/80
Epoch 55/80
Epoch 56/80
Epoch 57/80
Epoch 58/80
Epoch 59/80
Epoch 60/80
Epoch 61/80
Epoch 62/80
Epoch 63/80
Epoch 64/80
Epoch 65/80
Epoch 66/80
Epoch 67/80
Epoch 68/80
Epoch 69/80
Epoch 70/80
Epoch 71/80
Epoch 72/80
Epoch 73/80
Epoch 74/80
Epoch 75/80
Epoch 76/80
Epoch 77/80
Epoch 78/80
Epoch 79/80
Epoch 80/80


## Apply the model to our test data

In [202]:
def get_final_points_from_predictions(pred, org_shape, threshold=0.01, eps=10,
                                      min_samples=1, max_points=5):
    """
    Turn a predicted probability map into Hexbug positions in the original frame.

    All pixels above ``threshold`` are clustered with DBSCAN. Every cluster yields one
    position, namely its centroid, which is scaled from the prediction size back to the
    original frame size. The centroid makes the result deterministic, whereas a randomly
    drawn cluster pixel would differ from call to call.

    Args:
        pred: 2-D array (rows, cols) with the predicted probability of each pixel.
        org_shape: Shape of the original frame; only ``org_shape[0]`` (height) and
            ``org_shape[1]`` (width) are used.
        threshold: Pixels with a prediction above this value count as detections.
        eps: DBSCAN neighbourhood radius in pixels of the prediction.
        min_samples: DBSCAN minimum number of pixels that form a cluster.
        max_points: Maximum number of positions that are returned. If more clusters are
            found, only the clusters with the most pixels are kept.

    Returns:
        List of integer numpy arrays ``[row, col]`` (i.e. ``[y, x]``) in original frame
        coordinates, ordered by cluster label. Empty list if nothing exceeds ``threshold``.
    """
    pred = np.asarray(pred)
    points = np.transpose(np.where(pred > threshold))
    if len(points) == 0:
        return []

    # Perform clustering
    labels = DBSCAN(eps=eps, min_samples=min_samples).fit(points).labels_

    # Label -1 marks noise points that belong to no cluster
    cluster_ids, cluster_sizes = np.unique(labels[labels != -1], return_counts=True)
    if len(cluster_ids) > max_points:
        # Prefer large clusters over small (more likely spurious) ones
        largest = np.argsort(-cluster_sizes, kind="stable")[:max_points]
        cluster_ids = np.sort(cluster_ids[largest])

    # Scale with the actual size of the prediction instead of a global constant
    pred_height, pred_width = pred.shape[:2]
    scale = np.array([org_shape[0] / pred_height, org_shape[1] / pred_width])

    final_points = []
    for cluster_id in cluster_ids:
        centroid = points[labels == cluster_id].mean(axis=0)
        final_points.append((centroid * scale).astype(int))

    return final_points

In [None]:
def load_validation_data(path):
    """
    Load all videos in ``path`` for inference.

    Every frame of every ``*.mp4`` file is converted to RGB, resized to ``target_shape`` and
    scaled to the range [-1, 1]. No annotation file is needed or read. Videos from which no
    frame can be read are skipped with a printed warning.

    Args:
        path: Directory (``str`` or ``Path``) that contains the videos.

    Returns:
        Tuple ``(X, org_shapes, file_names)`` of three lists with one entry per video
        (videos sorted by file name):
        ``X`` holds one numpy array with all normalized frames of the video,
        ``org_shapes`` the shape of the first, unresized frame, and
        ``file_names`` the path of the video with the suffix replaced by ``.traco``.
    """
    path = Path(path)

    X = []
    org_shapes = []
    file_names = []

    # Sorted so that the video order does not depend on the file system
    for video_path in sorted(path.iterdir()):
        if video_path.suffix.lower() != ".mp4":
            continue

        cap = cv2.VideoCapture(str(video_path))
        try:
            ret, frame = cap.read()
            if not ret:
                print(f"Warning: could not read any frame from {video_path}, skipping")
                continue
            org_shape = frame.shape

            frames = []
            while ret:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

                # Resize the frame to the target size using bilinear interpolation
                resized_frame = cv2.resize(frame, target_shape, interpolation=cv2.INTER_LINEAR)

                # Scale the pixel values from [0, 255] to [-1, 1]
                normalized_frame = (resized_frame.astype('float32') / 255.0 - 0.5) / 0.5

                frames.append(normalized_frame)

                ret, frame = cap.read()  # read next frame
        finally:
            # Always free the video handle, even if decoding fails half-way
            cap.release()

        X.append(np.asarray(frames))
        org_shapes.append(org_shape)
        file_names.append(video_path.with_suffix(".traco"))

    return X, org_shapes, file_names

In [170]:
# Load the videos to predict on; the loader defined above is named load_validation_data
X_test, org_shapes_test, file_names_test = load_validation_data("../")

In [200]:

# Predicted rois of every video, keyed by the .traco path that belongs to the video
all_predictions = {}

for x, org_shape, file_name in zip(X_test, org_shapes_test, file_names_test):
    # Predict all frames of one video
    preds = model.predict(x)

    d = {'rois': []}
    # Get final predicted points for each frame (z is the frame index)
    for z, pred in enumerate(preds):
        coords = get_final_points_from_predictions(np.squeeze(pred), org_shape)
        print(coords)

        # coords are (row, col), whereas the annotation files store 'pos' as [x, y];
        # plain ints keep the entries JSON serializable
        for row, col in coords:
            d['rois'].append({'z': z, 'pos': [int(col), int(row)]})

    all_predictions[file_name] = d

# TODO: write all_predictions to disk in the final submission format.
# NOTE(review): only the keys 'z' and 'pos' are known from the annotation loader; confirm
# whether further fields are required. file_name is the .traco path next to the video, so
# writing there would overwrite an existing annotation file.

[array([ 278, 1016]), array([1025,  312])]
[array([409, 999]), array([1043,  206])]
[array([509, 898]), array([1031,  126])]
[array([592, 826])]
[array([640, 738]), array([895, 126])]
[array([664, 658]), array([806, 156])]
[array([705, 552]), array([747, 219])]
[array([735, 282]), array([735, 476])]
[array([711, 388]), array([776, 350])]
[array([670, 303]), array([871, 392])]
[array([652, 248]), array([984, 367])]
[array([664, 160]), array([1049,  307])]
[array([670,  75]), array([1102,  227])]
[array([652,  42]), array([1114,  113])]
[array([575,  54])]
[array([492, 105]), array([1185,   54])]
[array([432, 143]), array([1274,   59])]
[array([373, 164]), array([1346,  126])]
[array([343, 202])]
[array([314, 227])]
[array([290, 265])]
[array([290, 316])]
[array([332, 388])]
[array([379, 459])]
[array([320, 506])]
[array([320, 561])]
[array([355, 615])]
[array([367, 670]), array([1452,  320])]
[array([290, 700]), array([1464,  350])]
[array([302, 755]), array([1464,  383])]
[array([385, 