# Hexbug detection using U-Net
In this Colab notebook, we will train a U-Net to segment the heads of hexbugs.

## Preparations

**Use Colab's free GPU**

* Click "Runtime" tab
* Select "change runtime type"
* Select GPU

**Get data into Colab**
* Zip the folder "train_data"
* Create and zip the folder "recorded_data"
* Copy the zip files to Google Drive
* Connect to Google Drive and unzip the folders:

In [None]:
# Connect to GoogleDrive
# Mounts your Google Drive at /content/drive so the zip files can be read.
from google.colab import drive
drive.mount('/content/drive')

# ******
# TO DO: Adjust the path (on your GoogleDrive) to the folders to unzip
# ******
# `!` runs a shell command. The paths are relative to the notebook's working
# directory (presumably /content on Colab, the parent of the mount point).
# The archives are extracted into that directory; later cells expect a
# "train_data" folder there, so each zip presumably contains its folder at the
# top level -- verify for your own archives.
!unzip "drive/MyDrive/Colab Notebooks/TRACO_Budapest/train_data.zip"
!unzip "drive/MyDrive/Colab Notebooks/TRACO_Budapest/recorded_data.zip"

In [None]:
# Peek at the unzipped "train_data" folder: show up to five of its entries
# (the last expression of a cell is displayed by the notebook).
import os
first_entries = os.listdir("train_data")[:5]
first_entries

In [None]:
# Install useful packages for image segmentation
# albumentations: image/mask augmentation.
# segmentation-models: provides dice_loss and iou_score, imported further down.
!pip install albumentations
!pip install segmentation-models
# segmentation_models chooses its Keras backend from the SM_FRAMEWORK
# environment variable (see the package README); it is set here, before the
# `import segmentation_models...` lines in the next cell.
%env SM_FRAMEWORK=tf.keras

In [None]:
# Import packages
import os
import cv2
import json
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from ipywidgets import interact

from keras import layers
from keras.models import Model
from keras.optimizers import Adam
from keras.models import load_model
from keras.utils import Sequence

from sklearn.model_selection import train_test_split
import albumentations as A

from segmentation_models.losses import dice_loss, binary_focal_loss
from segmentation_models.metrics import iou_score

## Loading the data

**DataGenerator**

Since we have a lot of training data and the memory in Colab is limited, we created a generator.
The DataGenerator loads one batch of data at a time during the training. It also takes care of preprocessing the images.

In [None]:
class DataGenerator(Sequence):
  """Keras ``Sequence`` that loads (image, mask) batches from disk on demand.

  Only one batch is held in memory at a time. For every path ending in
  ``"img.png"``, the matching mask is the same path ending in ``"mask.png"``.

  Args:
    paths: File paths. Entries that do not end in ``"img.png"`` are ignored.
    image_shape: Target ``(height, width)``. Images and masks of a different
      size are resized to it.
    batch_size: Samples per batch. Samples that do not fill a whole batch at
      the end of an epoch are skipped (see ``__len__``).
    shuffle: If True, the sample order is reshuffled after every epoch.
    augment: Optional albumentations-style transform. It is called as
      ``augment(image=img, mask=mask)`` and must return a dict with
      ``"image"`` and ``"mask"`` entries of the same size. ``None`` (the
      default) disables augmentation.
  """

  _IMAGE_SUFFIX = "img.png"
  _MASK_SUFFIX = "mask.png"

  def __init__(self, paths, image_shape=(128, 128), batch_size=32, shuffle=False, augment=None):
    # The Keras 3 Sequence/PyDataset base class expects its __init__ to run.
    # For older tf.keras Sequence this resolves to object.__init__ (a no-op).
    super().__init__()
    self.paths = paths
    self.image_shape = image_shape
    self.batch_size = batch_size
    self.shuffle = shuffle
    self.augment = augment

    self.image_paths = [p for p in paths if p.endswith(self._IMAGE_SUFFIX)]
    self.on_epoch_end()

  def __len__(self):
    """Return the number of full batches per epoch (the remainder is dropped)."""
    return len(self.image_paths) // self.batch_size

  def __getitem__(self, idx):
    """Return batch number ``idx`` as an ``(X, y)`` tuple."""
    batch_indexes = self.indexes[idx * self.batch_size:(idx + 1) * self.batch_size]
    return self._generate_data(batch_indexes)

  def on_epoch_end(self):
    """Reset the sample order, shuffling it when ``self.shuffle`` is set."""
    self.indexes = np.arange(len(self.image_paths))
    if self.shuffle:
      np.random.shuffle(self.indexes)

  def _load_sample(self, img_path):
    """Load, resize and (optionally) augment one image/mask pair.

    Returns:
      ``(img, mask)``: an RGB uint8 array of shape ``(H, W, 3)`` and a uint8
      array of shape ``(H, W)`` taken from the mask's first channel.

    Raises:
      FileNotFoundError: If the image or its mask cannot be read.
    """
    # Swap only the file-name suffix. A plain str.replace("img", "mask") would
    # also rewrite any "img" that appears in directory names.
    mask_path = img_path[:-len(self._IMAGE_SUFFIX)] + self._MASK_SUFFIX

    # cv2.imread returns None instead of raising on a missing/unreadable file.
    img = cv2.imread(img_path)
    if img is None:
      raise FileNotFoundError(f"Could not read image: {img_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    mask = cv2.imread(mask_path)
    if mask is None:
      raise FileNotFoundError(f"Could not read mask: {mask_path}")
    mask = mask[:, :, 0]

    height, width = self.image_shape
    # cv2.resize takes dsize as (width, height), the reverse of numpy's shape.
    if img.shape[:2] != (height, width):
      img = cv2.resize(img, (width, height), interpolation=cv2.INTER_LINEAR)
    if mask.shape[:2] != (height, width):
      # Nearest-neighbour keeps the mask binary. Linear interpolation would
      # create intermediate values along object borders.
      mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)

    # ***************
    # Optional TO DO: Build an augmentation pipeline with "albumentations" and
    # pass it as `augment` to the training generator,
    # e.g. augment=A.Compose([A.HorizontalFlip(p=0.5)])
    # ***************
    if self.augment is not None:
      augmented = self.augment(image=img, mask=mask)
      img, mask = augmented["image"], augmented["mask"]

    return img, mask

  def _generate_data(self, batch_indexes):
    """Build one batch from the given sample indexes.

    Returns:
      ``(X, y)`` float32 arrays scaled to [0, 1], with shapes
      ``(n, H, W, 3)`` and ``(n, H, W, 1)``, where ``n = len(batch_indexes)``.
    """
    n = len(batch_indexes)
    X = np.zeros((n, *self.image_shape, 3), dtype=np.float32)
    y = np.zeros((n, *self.image_shape, 1), dtype=np.float32)

    for i, batch_idx in enumerate(batch_indexes):
      img, mask = self._load_sample(self.image_paths[batch_idx])
      X[i] = img
      y[i] = mask[..., None]  # add the channel axis expected by the model

    X /= 255.0
    y /= 255.0  # assumes masks are stored as 0/255, so y ends up as 0/1
    return X, y

**Let's try our DataGenerator**

In [None]:
# Collect every file path in "train_data" and wrap the list in a shuffling
# DataGenerator (128x128 inputs, 256 samples per batch) for a quick look.
train_paths = [f"train_data/{name}" for name in os.listdir("train_data")]
gen = DataGenerator(
    train_paths,
    image_shape=(128, 128),
    batch_size=256,
    shuffle=True,
)

In [None]:
# Get the first batch from the generator and display it
x, y = gen[0]
print(x.shape)  # images: (batch, height, width, 3)
print(y.shape)  # masks:  (batch, height, width, 1)

def show_frame_and_mask(i):
    """Show image ``i`` of the current batch next to its ground-truth mask."""
    plt.subplot(121)
    plt.imshow(x[i])
    plt.subplot(122)
    # Drop the trailing channel axis: imshow expects (H, W) for a grayscale map.
    # A fixed 0..1 range keeps an all-background mask black instead of
    # auto-scaled.
    plt.imshow(y[i, ..., 0], cmap="gray", vmin=0, vmax=1)
    plt.show()

# The slider covers every sample in the batch.
interact(show_frame_and_mask, i=(0, len(x) - 1))

##  Build and train U-Net

**Original U-Net paper:**

Ronneberger, O., Fischer, P., & Brox, T. (2015, October). U-net: Convolutional networks for biomedical image segmentation. *In International Conference on Medical image computing and computer-assisted intervention (pp. 234-241)*. Springer, Cham. [Link](https://arxiv.org/abs/1505.04597)

In [None]:
def conv(x, num_filters):
  """Apply two stages of 3x3 'same' Conv2D -> BatchNormalization -> ReLU.

  Args:
    x: Input Keras tensor.
    num_filters: Number of filters used by both convolutions.

  Returns:
    The output tensor of the second stage.
  """
  for _stage in range(2):
    x = layers.Conv2D(filters=num_filters, kernel_size=3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
  return x

def build_unet(filters, input_shape):
  """Return an (uncompiled) U-Net model for binary segmentation.

  The network has:
    * four encoder stages, each a ``conv`` block followed by 2x2 max-pooling;
    * a bottleneck;
    * four decoder stages, each of which upsamples and concatenates the
      matching encoder output (skip connection) before a ``conv`` block.

  Args:
    filters: Number of filters in the first stage. It is doubled after every
      pooling step, so the bottleneck uses ``filters * 16``.
    input_shape: ``(height, width, channels)`` of the input images. Height
      and width must be divisible by 16 (or be ``None``).

  Returns:
    A ``keras.Model`` whose output has one sigmoid channel and the same
    height and width as the input.

  Raises:
    ValueError: If the height or width is not divisible by 16.
  """
  # Four 2x2 poolings divide each spatial size by 16. Any remainder is lost,
  # so the upsampled decoder tensors would not match the encoder skips and
  # Concatenate would fail with a much less helpful shape error.
  for size in input_shape[:2]:
    if size is not None and size % 16 != 0:
      raise ValueError(
          "build_unet: height and width must be divisible by 16, "
          f"got input_shape={tuple(input_shape)}"
      )

  # Input layer
  inputs = layers.Input(shape=input_shape)

  # Encoder
  e1 = conv(inputs, filters)
  p1 = layers.MaxPool2D((2, 2))(e1)

  e2 = conv(p1, filters * 2)
  p2 = layers.MaxPool2D((2, 2))(e2)

  e3 = conv(p2, filters * 4)
  p3 = layers.MaxPool2D((2, 2))(e3)

  e4 = conv(p3, filters * 8)
  p4 = layers.MaxPool2D((2, 2))(e4)

  # Bottleneck
  b1 = conv(p4, filters * 16)

  # Decoder: upsample, then concatenate the encoder output of the same size
  d1 = layers.UpSampling2D()(b1)
  d1 = layers.Concatenate()([d1, e4])
  d1 = conv(d1, filters * 8)

  d2 = layers.UpSampling2D()(d1)
  d2 = layers.Concatenate()([d2, e3])
  d2 = conv(d2, filters * 4)

  d3 = layers.UpSampling2D()(d2)
  d3 = layers.Concatenate()([d3, e2])
  d3 = conv(d3, filters * 2)

  d4 = layers.UpSampling2D()(d3)
  d4 = layers.Concatenate()([d4, e1])
  d4 = conv(d4, filters)

  # Output layer: 1x1 conv to one channel; sigmoid gives a per-pixel probability
  outputs = layers.Conv2D(
      filters=1,
      kernel_size=1,
      padding="same",
      activation="sigmoid"
  )(d4)

  return Model(inputs, outputs)

**Training**

Now, we can set up U-Net for training. We will use a loss function called "dice_loss" and the metric "iou_score". They are from the segmentation-models package, see: https://segmentation-models.readthedocs.io/en/latest/api.html#losses

In [None]:
# Build a U-Net with 8 base filters for 128x128 RGB inputs, then compile it
# with Adam, the Dice loss and the IoU metric from segmentation-models.
unet = build_unet(filters=8, input_shape=(128, 128, 3))
optimizer = Adam(learning_rate=1e-2)
unet.compile(optimizer=optimizer, loss=dice_loss, metrics=[iou_score])

In [None]:
# All files in train_data, sorted because os.listdir order is arbitrary.
all_paths = sorted("train_data/" + f for f in os.listdir("train_data"))
# Split the *images*; the DataGenerator derives each mask path from its image,
# so mask files do not need to be part of the split.
image_paths = [p for p in all_paths if p.endswith("img.png")]

# Split data for training and validation (80/20). A fixed random_state keeps
# the split identical across notebook restarts.
train_paths, val_paths = train_test_split(image_paths, test_size=0.2, random_state=42)

# Set up generators
train_gen = DataGenerator(train_paths, image_shape=(128, 128), batch_size=256, shuffle=True)
val_gen = DataGenerator(val_paths, image_shape=(128, 128), batch_size=256, shuffle=False)

In [None]:
# Train U-Net for 200 epochs, validating after each one. The returned
# `history.history` dict holds the per-epoch loss and metric values.
history = unet.fit(
    train_gen,
    validation_data=val_gen,
    epochs=200,
)

# ******
# TO DO: Adjust the path (on your GoogleDrive) for saving the trained model
# ******
model_path = "/content/drive/MyDrive/Colab Notebooks/TRACO_Budapest/trained_traco_unet.h5"
unet.save(model_path)

In [None]:
# Plot the IoU score per epoch for the training and validation data
plt.plot(history.history["iou_score"], label="training")
plt.plot(history.history["val_iou_score"], label="validation")
plt.xlabel("epochs")
plt.ylabel("IoU")
# The "training"/"validation" labels are only drawn once a legend is added.
plt.legend()
plt.show()

In [None]:
# Display some predictions for the validation images (first validation batch)
x, y = val_gen[0]
predictions = unet.predict(x)

def show_frame_and_prediction(i):
    """Show validation image ``i`` next to the U-Net prediction for it."""
    plt.subplot(121)
    plt.imshow(x[i])
    plt.subplot(122)
    # The model outputs (H, W, 1) sigmoid probabilities. Drop the channel axis
    # and fix the range to 0..1 so faint predictions are not auto-stretched
    # to full contrast.
    plt.imshow(predictions[i, ..., 0], cmap="gray", vmin=0, vmax=1)
    plt.show()

# The slider covers every sample in the batch.
interact(show_frame_and_prediction, i=(0, len(x) - 1))

The number of filters in U-Net, the learning rate, and the number of epochs are examples of **hyperparameters**. To improve the performance of U-Net, you can try changing these and comparing the results.

## Check the predictions for your own recorded video!

In [None]:
# ******
# TO DO: Set up a DataGenerator for your recorded hexbug video
# ******

# ******
# TO DO: Load the trained U-Net model
# ******

# ******
# TO DO: Display the frames with predictions. You can use the function "show_frame_and_prediction"
# ******

## Convert predicted masks back to coordinates

In [None]:
# ******
# TO DO: Find individual heads of hexbugs in the predicted masks
# ******

# ******
# TO DO: Get the center coordinate of each head
# ******