In [None]:
## Mount Google Drive and switch to the project folder -- Colab only.
## Outside Colab the google.colab import fails, nothing is mounted and the
## current working directory is kept, so "Restart & Run All" also works locally.
import os

try:
    from google.colab import drive
except ImportError:
    drive = None

if drive is not None:
    drive.mount('/content/drive')
    os.chdir('/content/drive/MyDrive/FCN')

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import pickle

import utils, augment, models

# for auto-reloading external modules
%load_ext autoreload
%autoreload 2

In [None]:
## Read serialized examples from the TFRecord files; each record is decoded by utils.parse_example
train_path = 'data/train.tfrecords'
val_path = 'data/val.tfrecords'
train_dataset = tf.data.TFRecordDataset(train_path).map(utils.parse_example)
val_dataset = tf.data.TFRecordDataset(val_path).map(utils.parse_example)

In [None]:
## No data augmentation: every example only goes through augment.resize_with_pad.
## NOTE: this rebinds the dataset names, so re-running the cell maps twice.
train_dataset, val_dataset = (
    ds.map(augment.resize_with_pad) for ds in (train_dataset, val_dataset)
)

In [None]:
## Sanity check: show one random training image, then its label rendered by utils.label_to_image
for image, onehot in train_dataset.shuffle(100).take(1):
    label_rgb = utils.label_to_image(utils.onehot_to_label(onehot))
    for picture in (image, label_rgb):
        plt.figure()
        plt.imshow(picture)
        plt.show()

# Models

In [None]:
## VGG16 base model
base_model = models.vgg16(l2=1e-6, dropout=0.2)

## Load ImageNet weights
## https://keras.io/api/applications/
## The Keras VGG16 classifier's dense kernels are reshaped into convolution
## kernels so they can be loaded into base_model: the first one becomes a
## 7x7x512 -> 4096 kernel, the other two become 1x1 kernels. Keys are positions
## in VGG16.get_weights(); presumably the kernels of fc1 / fc2 / predictions
## (the biases in between need no reshape) -- TODO confirm against models.vgg16.
DENSE_TO_CONV_SHAPES = {
    26: (7, 7, 512, 4096),
    28: (1, 1, 4096, 4096),
    30: (1, 1, 4096, 1000),
}
vgg16 = keras.applications.vgg16.VGG16(weights='imagenet')
weight_list = vgg16.get_weights()
for idx, conv_shape in DENSE_TO_CONV_SHAPES.items():
    weight_list[idx] = weight_list[idx].reshape(conv_shape)
base_model.set_weights(weight_list)
## The full Keras VGG16 is no longer needed once its weights are copied; drop it
## together with the weight list so the kernel does not hold two copies.
del weight_list, vgg16

In [None]:
## Test base model on a sample image and print its three most likely ImageNet classes
from class_names.imagenet import class_names

img = utils.get_image('assets/laska.png')
plt.figure()
plt.imshow(img)
plt.show()

print("Top three guesses with probabilities:")
# [0, 3, 3]: first (only) image of the batch, centre cell of the 7 x 7 output grid
probs = base_model.predict(img[None])[0, 3, 3]
top3 = np.argsort(probs)[::-1][:3]
for cls in top3:
    print("%s (p=%.3f)" % (class_names[cls], probs[cls]))

In [None]:
## FCN32 on top of the VGG16 base
fcn32 = models.fcn32(base_model, l2=1e-6)

## Fine-tuning: freeze the convolutions of blocks 1-4
frozen_conv_layers = [
    'block1_conv1', 'block1_conv2',
    'block2_conv1', 'block2_conv2',
    'block3_conv1', 'block3_conv2', 'block3_conv3',
    'block4_conv1', 'block4_conv2', 'block4_conv3',
]
for layer_name in frozen_conv_layers:
    fcn32.get_layer(layer_name).trainable = False

## The upsampling layer is kept fixed as well
fcn32.get_layer('fcn32').trainable = False

In [None]:
## Load previously trained FCN32 weights if the file exists; otherwise keep the
## current weights so "Run All" does not stop here on a fresh checkout.
import os

fcn32_weights_path = 'weights/fcn32_10e.h5'
if os.path.exists(fcn32_weights_path):
    fcn32.load_weights(fcn32_weights_path)
else:
    print('No trained weights found at %s; skipping.' % fcn32_weights_path)

In [None]:
## FCN16, built from base_model and fcn32
fcn16 = models.fcn16(base_model, fcn32, l2=1e-5)

## Unfreeze the block-4 convolutions for this stage.
## NOTE(review): if these layer objects are shared with fcn32 (likely, since
## fcn16 is built from it), the flag also changes there -- confirm in models.py.
for layer_name in ('block4_conv1', 'block4_conv2', 'block4_conv3'):
    fcn16.get_layer(layer_name).trainable = True

## Keep both upsampling layers fixed
for layer_name in ('score7_upsample', 'fcn16'):
    fcn16.get_layer(layer_name).trainable = False

In [None]:
## Load previously trained FCN16 weights if the file exists; otherwise keep the
## current weights so "Run All" does not stop here on a fresh checkout.
import os

fcn16_weights_path = 'weights/fcn16_10e.h5'
if os.path.exists(fcn16_weights_path):
    fcn16.load_weights(fcn16_weights_path)
else:
    print('No trained weights found at %s; skipping.' % fcn16_weights_path)

In [None]:
## FCN8, built from base_model and fcn16
fcn8 = models.fcn8(base_model, fcn16, l2=1e-5)

## Unfreeze the block 1-3 convolutions for this stage.
## NOTE(review): if these layer objects are shared with fcn32/fcn16, the flag
## also changes there -- confirm in models.py.
unfrozen_conv_layers = [
    'block1_conv1', 'block1_conv2',
    'block2_conv1', 'block2_conv2',
    'block3_conv1', 'block3_conv2', 'block3_conv3',
]
for layer_name in unfrozen_conv_layers:
    fcn8.get_layer(layer_name).trainable = True

## Keep the upsampling layers fixed
for layer_name in ('skip4_upsample', 'fcn8'):
    fcn8.get_layer(layer_name).trainable = False



# Training

In [None]:
## Choose the network that the training / testing cells below operate on
model = fcn32  # one of: fcn32, fcn16, fcn8
model.summary()

In [None]:
## compile

## Custom categorical cross-entropy (the built-in CategoricalCrossentropy did not work here)
def crossentropy(y_true, y_pred):
    """Pixel-wise categorical cross-entropy, averaged over all pixels.

    Sums -y_true * log(y_pred + 1e-7) over the last (class) axis, then takes
    the mean over every remaining axis. For C classes this equals
    -C * reduce_mean(y_true * log(y_pred + 1e-7)), but the class count is
    taken from the tensors instead of being hard-coded.

    Args:
        y_true: target tensor (assumed one-hot over the last axis); cast to float32.
        y_pred: predicted per-class probabilities, same shape as y_true
            (assumed non-negative -- the 1e-7 only guards log(0)).

    Returns:
        Scalar loss tensor.
    """
    per_pixel = -tf.math.reduce_sum(
        tf.cast(y_true, tf.float32) * tf.math.log(y_pred + 1e-7), axis=-1)
    return tf.math.reduce_mean(per_pixel)
    
opt = keras.optimizers.Adam(learning_rate=1e-4)


class OneHotArgmaxMeanIoU(keras.metrics.MeanIoU):
    """Mean IoU for one-hot targets and per-class score predictions.

    keras.metrics.MeanIoU expects integer class ids. Given one-hot targets and
    probabilities, it casts them straight to integers and builds a meaningless
    confusion matrix. Taking the argmax over the last (class) axis first turns
    both into class ids.
    """

    def update_state(self, y_true, y_pred, sample_weight=None):
        return super().update_state(
            tf.math.argmax(y_true, axis=-1),
            tf.math.argmax(y_pred, axis=-1),
            sample_weight=sample_weight,
        )


metrics = [crossentropy,
           keras.metrics.CategoricalAccuracy(name='pixelacc'),
           OneHotArgmaxMeanIoU(num_classes=21, name='meanIoU')]
model.compile(optimizer=opt, loss=crossentropy, metrics=metrics)

In [None]:
## Quick check on one random training image: evaluate it, then show the input,
## the model's prediction and the ground-truth label as three figures.
for batch_x, batch_y in train_dataset.shuffle(100).batch(1).take(1):
    model.evaluate(batch_x, batch_y)
    panels = (
        batch_x[0],
        utils.label_to_image(utils.onehot_to_label(model(batch_x)[0])),
        utils.label_to_image(utils.onehot_to_label(batch_y[0])),
    )
    for panel in panels:
        plt.figure()
        plt.imshow(panel)
        plt.show()

In [None]:
## Check which GPU TensorFlow uses; an empty name means none is visible.
gpu_name = tf.test.gpu_device_name()
if not gpu_name:
    print('WARNING: no GPU found; training will run on the CPU.')
gpu_name

In [None]:
## Training and validation input pipelines.
## The training set is reshuffled every epoch. The validation set is left
## unshuffled: its loss/metrics are aggregated over the whole set, so shuffling
## it only costs a large buffer and extra work. prefetch overlaps input
## preparation with the training step.
BATCH_SIZE = 20
train = train_dataset.shuffle(2000).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
val = val_dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

In [None]:
## Train; history.history keeps per-epoch training and val_* metrics
EPOCHS = 10
history = model.fit(train, validation_data=val, epochs=EPOCHS)

In [None]:
## Save weights and training history.
## The file tag is derived from which network `model` is and how many epochs
## `history` recorded, giving '<net>_<N>e' -- the naming used when weights and
## logs are loaded (e.g. 'weights/fcn32_10e.h5', 'logs/fcn32_10e').
net_tag = next((tag for tag, net in (('fcn32', fcn32), ('fcn16', fcn16), ('fcn8', fcn8))
                if net is model), model.name)
run_name = '%s_%de' % (net_tag, len(history.history['loss']))

model.save_weights('weights/%s.h5' % run_name)

## NOTE: pickled logs should only be loaded back from trusted locations.
with open('logs/%s' % run_name, 'wb') as f:
    pickle.dump(history.history, f)

# Testing

In [None]:
## Test model: for 10 validation examples show the input, its label and the prediction
for image, target in val_dataset.take(10):
    utils.display_image(image)
    utils.display_image(target)

    batch_pred = model.predict(tf.expand_dims(image, axis=0))
    utils.display_image(batch_pred[0])

In [None]:
## Example image: show it, its ground-truth label and the model's prediction
image = utils.get_image('assets/biker.jpg')
label = utils.get_label_png('assets/biker_label.png')

y_pred = model.predict(np.expand_dims(image, axis=0))

for panel in (image, label, y_pred[0]):
    utils.display_image(panel)

In [None]:
## Evaluate the loss and compiled metrics on the whole validation set, 32 images per batch
val_test = val_dataset.batch(32)
val_scores = model.evaluate(val_test)
val_scores

# Plots

In [None]:
## Load the pickled training histories (history.history dicts) of the three models.
## NOTE: pickle.load can execute arbitrary code -- only open trusted log files.
def _load_history(path):
    """Return the object pickled at `path`."""
    with open(path, 'rb') as f:
        return pickle.load(f)


h1, h2, h3 = (_load_history(p) for p in ('logs/fcn32_10e', 'logs/fcn16_10e', 'logs/fcn8_10e'))

In [None]:
## Training (solid) vs validation (dashed) cross-entropy for FCN32 / FCN16 / FCN8.
## Keras names a function metric after its __name__, i.e. 'crossentropy'; logs
## written by an earlier version may use 'cross_entropy', so accept either.
## The x axis follows each history's length instead of assuming 10 epochs.
for hist, color, tag in ((h1, 'blue', 'FCN32'), (h2, 'red', 'FCN16'), (h3, 'green', 'FCN8')):
    key = 'cross_entropy' if 'cross_entropy' in hist else 'crossentropy'
    epoch_axis = range(1, len(hist[key]) + 1)
    plt.plot(epoch_axis, hist[key], '-', color=color, label=tag + ' training loss')
    plt.plot(epoch_axis, hist['val_' + key], '--', color=color, label=tag + ' validation loss')
plt.xlabel('epoch')
plt.ylabel('cross-entropy')
plt.legend()
plt.show()


In [None]:
## Training (solid) vs validation (dashed) mean IoU for FCN32 / FCN16 / FCN8.
## The x axis follows each history's length instead of assuming 10 epochs.
for hist, color, tag in ((h1, 'blue', 'FCN32'), (h2, 'red', 'FCN16'), (h3, 'green', 'FCN8')):
    epoch_axis = range(1, len(hist['meanIoU']) + 1)
    plt.plot(epoch_axis, hist['meanIoU'], '-', color=color, label=tag + ' training meanIoU')
    plt.plot(epoch_axis, hist['val_meanIoU'], '--', color=color, label=tag + ' validation meanIoU')
plt.xlabel('epoch')
plt.ylabel('mean IoU')
plt.legend()
plt.show()