## **Data and code setup**

In [None]:
%%capture
# Install the helper tools for this session; %%capture hides the pip logs.
# gdown is used further down to download the dataset archive from Google Drive.
!pip install gdown
# gpustat is a command-line GPU monitor; no other cell visible in this notebook calls it.
!pip3 install gpustat

In [None]:
%%capture
# Clone the ENet TensorFlow repository; the setup cell below appends
# ./enet_tensorflow to sys.path and imports `utils` and `models` from it.
!git clone https://github.com/gevero/enet_tensorflow.git

In [None]:
%%capture
!gdown https://drive.google.com/uc?id=1gt0nCGft0winZqHBYaTb1EL6zM8lrKPA
!unzip -o camvid.zip

## **Notebook Setup**

In [None]:
# update to tf 2.0
from __future__ import absolute_import, division, print_function, unicode_literals

# Install TensorFlow
try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.1
except Exception:
  pass

# importing standard libraries
import tensorflow as tf
print(tf.__version__)
import matplotlib.pylab as plt
import numpy as np
import os, os.path
from functools import partial
from google.colab import files

# Importing utils and models
import sys
sys.path.append('./enet_tensorflow')
from utils import preprocess_img_label, map_singlehead, map_doublehead, map_label, tf_dataset_generator, get_class_weights
from models import EnetModel

## **Create training, validation and test datasets, and get class weights**

In [None]:
%%capture
# creating datasets
img_pattern = "./dataset/train/images/*.png"
label_pattern = "./dataset/train/labels/*.png"
img_pattern_val = "./dataset/val/images/*.png"
label_pattern_val = "./dataset/val/labels/*.png"
img_pattern_test = "./dataset/test/images/*.png"
label_pattern_test = "./dataset/test/labels/*.png"

# batch size
batch_size = 8

# image size
img_height = 360
img_width = 480
h_enc = img_height // 8
w_enc = img_width // 8
h_dec = img_height
w_dec = img_width

# create (img,label) string tensor lists
filelist_train = preprocess_img_label(img_pattern, label_pattern)
filelist_val = preprocess_img_label(img_pattern_val, label_pattern_val)
filelist_test = preprocess_img_label(img_pattern_test, label_pattern_test)

# training dataset size
n_train = tf.data.experimental.cardinality(filelist_train).numpy()
n_val = tf.data.experimental.cardinality(filelist_val).numpy()
n_test = tf.data.experimental.cardinality(filelist_test).numpy()

# define mapping functions for single and double head nets
map_single = lambda img_file, label_file: map_singlehead(
    img_file, label_file, h_dec, w_dec)
map_double = lambda img_file, label_file: map_doublehead(
    img_file, label_file, h_enc, w_enc, h_dec, w_dec)

# create single head datasets
train_single_ds = filelist_train.shuffle(n_train).map(map_single).cache().batch(batch_size).repeat()
val_single_ds = filelist_val.map(map_single).cache().batch(batch_size).repeat()
test_single_ds = filelist_test.map(map_single).cache().batch(batch_size).repeat()

# create double head datasets
train_double_ds = filelist_train.shuffle(n_train).map(map_double).cache().batch(batch_size).repeat()
val_double_ds = filelist_val.map(map_double).cache().batch(batch_size).repeat()
test_double_ds = filelist_test.map(map_double).cache().batch(batch_size).repeat()

# get class weights
label_filelist = tf.data.Dataset.list_files(label_pattern, shuffle=False)
label_ds = label_filelist.map(lambda x: map_label(x, h_dec, w_dec))
class_weights = get_class_weights(label_ds).tolist()

## **Example (Image,Label) pair from the training set**

In [None]:
# Show the first image of one training batch next to its label map.
for img, iml in train_single_ds.take(1):
  first_image = img.numpy()[0]
  first_label = iml.numpy()[0, :, :, 0]
  fig, (ax_image, ax_label) = plt.subplots(1, 2, figsize=(15, 10))
  ax_image.imshow(first_image)
  ax_label.imshow(first_label)

## **1 - Two stage training: first Encoder then Decoder**

### Training the encoder

In [None]:
# Two-output model (MultiObjective=True) used for the two-stage training.
# C=12 is presumably the number of segmentation classes — confirm against EnetModel.
Enet = EnetModel(
    C=12,
    MultiObjective=True,
    l2=1e-3,
)

In [None]:
# Freeze the last six layers so the encoder stage does not update them
# (presumably the decoder part of the network — confirm against EnetModel).
tail_layers = Enet.layers[-6:]
for layer in tail_layers:
  layer.trainable = False

In [None]:
# Encoder stage: loss weights (1.0, 0.0) make only the first output
# contribute to the optimized loss.
n_epochs = 60
adam_optimizer = tf.keras.optimizers.Adam(learning_rate=5e-4)
Enet.compile(
    optimizer=adam_optimizer,
    loss=['sparse_categorical_crossentropy'] * 2,
    metrics=['accuracy'] * 2,
    loss_weights=[1.0, 0.0],
)

In [None]:
# Train on the double-head dataset. Each epoch covers the training set once and
# validates on a fifth of the validation batches.
train_steps = n_train // batch_size
val_steps = n_val // batch_size // 5
enet_enc_history = Enet.fit(
    x=train_double_ds,
    epochs=n_epochs,
    steps_per_epoch=train_steps,
    validation_data=val_double_ds,
    validation_steps=val_steps,
    class_weight=[class_weights, class_weights],
)

### Training the decoder

In [None]:
# Swap the trainable parts for the decoder stage: the last six layers become
# trainable, every layer before them is frozen.
for layer_group, is_trainable in ((Enet.layers[-6:], True),
                                  (Enet.layers[:-6], False)):
  for layer in layer_group:
    layer.trainable = is_trainable

In [None]:
# Decoder stage: loss weights (0.0, 1.0) make only the second output
# contribute to the optimized loss.
n_epochs = 60
adam_optimizer = tf.keras.optimizers.Adam(learning_rate=5e-4)
Enet.compile(
    optimizer=adam_optimizer,
    loss=['sparse_categorical_crossentropy'] * 2,
    metrics=['accuracy'] * 2,
    loss_weights=[0.0, 1.0],
)

In [None]:
# Train on the double-head dataset. Each epoch covers the training set once and
# validates on a fifth of the validation batches.
train_steps = n_train // batch_size
val_steps = n_val // batch_size // 5
enet_dec_history = Enet.fit(
    x=train_double_ds,
    epochs=n_epochs,
    steps_per_epoch=train_steps,
    validation_data=val_double_ds,
    validation_steps=val_steps,
    class_weight=[class_weights, class_weights],
)

### Check performance

In [None]:
# Evaluate the two-stage model on one pass over the test split.
Enet.evaluate(
    x=test_double_ds,
    steps=n_test // batch_size,
)

In [None]:
# Learning curves of the decoder stage of the two-stage training.
loss = enet_dec_history.history['loss']
val_loss = enet_dec_history.history['val_loss']
acc = enet_dec_history.history['output_2_accuracy']
val_acc = enet_dec_history.history['val_output_2_accuracy']

# Take the x axis from the recorded history rather than the global `n_epochs`:
# other sections reassign that variable, so re-running this cell after them
# would make the x and y lengths disagree.
epochs = range(len(loss))

plt.figure(figsize=(12,8))
# losses are divided by their maximum so they share the [0, 1] axis with accuracy
plt.plot(epochs, np.asarray(loss)/np.max(loss), 'r', label='Training loss')
plt.plot(epochs, np.asarray(val_loss)/np.max(val_loss), 'b', label='Validation loss')
plt.plot(epochs, acc, 'r:', label='Training accuracy')
plt.plot(epochs, val_acc, 'b:', label='Validation accuracy')

plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss Value')
plt.ylim([0, 1])
plt.legend()
plt.show()

## **2 - Training both objectives simultaneously**

### Training

In [None]:
# Fresh two-output model (MultiObjective=True) for training both objectives together.
# C=12 is presumably the number of segmentation classes — confirm against EnetModel.
EnetMulti = EnetModel(
    C=12,
    MultiObjective=True,
    l2=1e-3,
)

In [None]:
# Joint training: loss weights (0.5, 0.5) make both outputs contribute
# equally to the optimized loss.
n_epochs = 80
adam_optimizer = tf.keras.optimizers.Adam(learning_rate=5e-4)
EnetMulti.compile(
    optimizer=adam_optimizer,
    loss=['sparse_categorical_crossentropy'] * 2,
    metrics=['accuracy'] * 2,
    loss_weights=[0.5, 0.5],
)

In [None]:
# Train on the double-head dataset. Each epoch covers the training set once and
# validates on a fifth of the validation batches.
train_steps = n_train // batch_size
val_steps = n_val // batch_size // 5
enet_multi_history = EnetMulti.fit(
    x=train_double_ds,
    epochs=n_epochs,
    steps_per_epoch=train_steps,
    validation_data=val_double_ds,
    validation_steps=val_steps,
    class_weight=[class_weights, class_weights],
)

### Check performance

In [None]:
# Evaluate the jointly trained model on one pass over the test split.
EnetMulti.evaluate(
    x=test_double_ds,
    steps=n_test // batch_size,
)

In [None]:
# Learning curves of the joint (two-objective) training.
loss = enet_multi_history.history['loss']
val_loss = enet_multi_history.history['val_loss']
acc = enet_multi_history.history['output_2_accuracy']
val_acc = enet_multi_history.history['val_output_2_accuracy']

# Take the x axis from the recorded history rather than the global `n_epochs`:
# other sections reassign that variable, so re-running this cell after them
# would make the x and y lengths disagree.
epochs = range(len(loss))

plt.figure(figsize=(12,8))
# losses are divided by their maximum so they share the [0, 1] axis with accuracy
plt.plot(epochs, np.asarray(loss)/np.max(loss), 'r', label='Training loss')
plt.plot(epochs, np.asarray(val_loss)/np.max(val_loss), 'b', label='Validation loss')
plt.plot(epochs, acc, 'r:', label='Training accuracy')
plt.plot(epochs, val_acc, 'b:', label='Validation accuracy')

plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss Value')
plt.ylim([0, 1])
plt.legend()
plt.show()

## **3 - End-to-End Training**

In [None]:
# Single-output model (MultiObjective=False) for end-to-end training.
# C=12 is presumably the number of segmentation classes — confirm against EnetModel.
EnetEndToEnd = EnetModel(
    C=12,
    MultiObjective=False,
    l2=1e-3,
)

In [None]:
# End-to-end training: a single loss and a single metric for the single output.
n_epochs = 80
adam_optimizer = tf.keras.optimizers.Adam(learning_rate=5e-4)
EnetEndToEnd.compile(
    optimizer=adam_optimizer,
    loss=['sparse_categorical_crossentropy'],
    metrics=['accuracy'],
)

In [None]:
# Train on the single-head dataset. Each epoch covers the training set once and
# validates on a fifth of the validation batches.
train_steps = n_train // batch_size
val_steps = n_val // batch_size // 5
enet_endtoend_history = EnetEndToEnd.fit(
    x=train_single_ds,
    epochs=n_epochs,
    steps_per_epoch=train_steps,
    validation_data=val_single_ds,
    validation_steps=val_steps,
    class_weight=class_weights,
)

### Check performance

In [None]:
# Evaluate the end-to-end model on one pass over the test split.
EnetEndToEnd.evaluate(
    x=test_single_ds,
    steps=n_test // batch_size,
)

In [None]:
# Learning curves of the end-to-end training.
loss = enet_endtoend_history.history['loss']
val_loss = enet_endtoend_history.history['val_loss']
acc = enet_endtoend_history.history['accuracy']
val_acc = enet_endtoend_history.history['val_accuracy']

# Take the x axis from the recorded history rather than the global `n_epochs`:
# other sections reassign that variable, so re-running this cell after them
# would make the x and y lengths disagree.
epochs = range(len(loss))

plt.figure(figsize=(12,8))
# losses are divided by their maximum so they share the [0, 1] axis with accuracy
plt.plot(epochs, np.asarray(loss)/np.max(loss), 'r', label='Training loss')
plt.plot(epochs, np.asarray(val_loss)/np.max(val_loss), 'b', label='Validation loss')
plt.plot(epochs, acc, 'r:', label='Training accuracy')
plt.plot(epochs, val_acc, 'b:', label='Validation accuracy')

plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss Value')
plt.ylim([0, 1])
plt.legend()
plt.show()

## **Test Masks**

In [None]:
def create_mask(pred_mask):
  """Turn batched per-class scores into a class-index mask for the first item.

  Takes the argmax over the last axis of `pred_mask`, appends a trailing
  axis of size one, and returns only the first element along the leading
  (batch) axis.
  """
  class_ids = tf.argmax(pred_mask, axis=-1)
  return tf.expand_dims(class_ids, axis=-1)[0]
# Pick a batch to visualise: iterate over ten batches and keep the last one.
# The original loop read `train_dec_ds`, which is not defined anywhere in this
# notebook (NameError). The single-head dataset yields (image, label) batches,
# which is what the indexing below expects.
for img,iml in train_single_ds.take(10):
  img_test = img
  iml_test = iml

# Run the three models on the first image of the batch. The two-output models
# return (encoder output, decoder output); only the decoder output is displayed.
img_enc_probs, img_dec_probs = Enet(img_test[0:1,:,:,:])
img_enc_probs, img_multi_probs = EnetMulti(img_test[0:1,:,:,:])
img_endtoend_probs = EnetEndToEnd(img_test[0:1,:,:,:])
img_dec_out = create_mask(img_dec_probs)
img_multi_out = create_mask(img_multi_probs)
img_endtoend_out = create_mask(img_endtoend_probs)

# Top row: input image and ground truth. Bottom row: one prediction per strategy.
plt.figure(figsize=(20,10))
plt.subplot(2,3,1)
plt.xticks([])
plt.yticks([])
plt.title('Image',fontdict={'fontsize':20})
plt.imshow(img_test.numpy()[0,:,:,:])

plt.subplot(2,3,2)
plt.xticks([])
plt.yticks([])
plt.title('Ground Truth',fontdict={'fontsize':20})
plt.imshow(iml_test.numpy()[0,:,:,0])

plt.subplot(2,3,4)
plt.imshow(img_dec_out[:,:,0])
plt.xticks([])
plt.yticks([])
plt.title('Encoder + Decoder',fontdict={'fontsize':20})

plt.subplot(2,3,5)
plt.xticks([])
plt.yticks([])
plt.title('Multiple Objectives',fontdict={'fontsize':20})
plt.imshow(img_multi_out[:,:,0])

plt.subplot(2,3,6)
plt.xticks([])
plt.yticks([])
plt.title('End to End',fontdict={'fontsize':20})
plt.imshow(img_endtoend_out[:,:,0])

plt.tight_layout()
# keep a copy of the comparison figure next to the notebook
plt.savefig('./segmentation.png')

# **Save models**
You can save them to your Google Drive: mount it with the command below, then drag and drop the weight files.

In [None]:
from google.colab import drive

# Mount Google Drive inside the Colab filesystem so the weight files can be copied there.
drive_mount_point = '/content/drive'
drive.mount(drive_mount_point)

In [None]:
# Write the weights of the three trained models to the working directory.
models_to_save = (
    ('Enet.tf', Enet),
    ('EnetMulti.tf', EnetMulti),
    ('EnetEndToEnd.tf', EnetEndToEnd),
)
for weights_path, trained_model in models_to_save:
  trained_model.save_weights(weights_path)