In [None]:
%load_ext autoreload
%autoreload 2

import os
os.sys.path += ['slim']

import tensorflow as tf
import slim.nets.mobilenet.mobilenet_v2 as mobilenet_v2
import numpy as np
import matplotlib.pyplot as plt

from my_dataset import MyDataset
from model import mobilenet_backbone, load_mobilenet_weights, segmentation_head

## Option 1: bird's-eye-view images

In [None]:
# Configuration for the bird's-eye-view variant of the dataset.
# `shape` gives the two spatial dimensions the input batch is reshaped to
# further down (rows first, then columns).
shape = (30, 100)

# Directory containing the TFRecord files. The default is the original
# location; set STEERING_DATASET_DIR to run the notebook on another machine
# without editing this cell.
dataset_dir = os.environ.get('STEERING_DATASET_DIR',
                             '/home/ilya/random steering/dataset-utils')
train_file = os.path.join(dataset_dir, 'trial_birdeye_train.tfrecord')
val_file = os.path.join(dataset_dir, 'trial_birdeye_val.tfrecord')

## Option 2: untransformed images

In [None]:
# Configuration for the untransformed (camera-view) variant of the dataset.
# `shape` gives the two spatial dimensions the input batch is reshaped to
# further down (rows first, then columns).
shape = (256, 640)

# Directory containing the TFRecord files. The default is the original
# location; set STEERING_DATASET_DIR to run the notebook on another machine
# without editing this cell.
dataset_dir = os.environ.get('STEERING_DATASET_DIR',
                             '/home/ilya/random steering/dataset-utils')
train_file = os.path.join(dataset_dir, 'trial_train.tfrecord')
val_file = os.path.join(dataset_dir, 'trial_val.tfrecord')

## Load the filenames of the validation data.

In [None]:
# Read the 'name' feature of every record in the validation TFRecord, in file
# order, so that validation plots can be titled with the source file name.
#
# The parse op is built exactly once and fed through a placeholder. Building
# it inside the record loop would add a new set of nodes to the graph for
# every record. A private graph and a context-managed session keep this
# helper work out of the default graph and release the session afterwards.
name_graph = tf.Graph()
with name_graph.as_default():
    serialized_example = tf.placeholder(tf.string, shape=[])
    # Only 'name' is needed here; a missing feature falls back to ''.
    name_feature = {
        'name': tf.io.FixedLenFeature((), tf.string, default_value=''),
    }
    parsed_name = tf.io.parse_single_example(serialized_example, name_feature)['name']

with tf.Session(graph=name_graph) as name_sess:
    raw_names = [name_sess.run(parsed_name, {serialized_example: record})
                 for record in tf.python_io.tf_record_iterator(val_file)]

# Names come back as bytes; flatten path separators so each name is a single
# path-free label.
filenames = [raw.decode('ascii').replace('/', '_') for raw in raw_names]

## Create the `tf.data` input pipeline.

In [None]:
# Start from an empty default graph and a fresh session; every later cell
# builds on this graph and uses this `sess`.
tf.reset_default_graph()
sess = tf.Session()

train_dataset = MyDataset(filename=train_file,
                          batch_size=1,
                          shape=shape,
                          num_readers=2,
                          num_classes=2,
                          is_training=True,
                          should_shuffle=True,
                          should_repeat=False,
                          should_augment=False).get()

# Validation data is read without shuffling: evaluation should be
# deterministic, and the training loop titles validation plots from the
# `filenames` list, which is in file order.
# NOTE(review): `is_training=True` is kept as in the original because the
# effect of that flag inside MyDataset is not visible here — confirm whether
# the validation split should pass False.
val_dataset = MyDataset(filename=val_file,
                        batch_size=1,
                        shape=shape,
                        num_readers=2,
                        num_classes=2,
                        is_training=True,
                        should_shuffle=False,
                        should_repeat=False,
                        should_augment=False).get()

# Feedable iterator: one `get_next()` op serves both splits, and the split is
# selected at run time by feeding the matching string handle.
train_iterator = train_dataset.make_initializable_iterator()
validation_iterator = val_dataset.make_initializable_iterator()
dataset_handle = tf.placeholder(tf.string, shape=[], name='dataset_handle')
iterator = tf.data.Iterator.from_string_handle(dataset_handle,
                                               train_dataset.output_types,
                                               train_dataset.output_shapes)
samples = iterator.get_next()

# samples[0] is used as the image batch and samples[1] as the target batch by
# the cells below. The reshape pins the static shape to (batch, rows, cols, 3).
input_tensor = tf.reshape(samples[0], shape=(-1, shape[0], shape[1], 3))

# Fed by the training loop. NOTE(review): the model-building cells pass the
# literal `is_training=True` rather than this placeholder, so feeding it does
# not currently switch the network between train and eval behaviour.
is_training = tf.placeholder(tf.bool)

## Option 1: Start a new training run: create the model and load the pretrained ImageNet weights.

In [None]:
# Build the backbone first and load its pretrained weights before the head
# exists, then stack the segmentation head on top. The order of these three
# statements is kept exactly as in the original.
# NOTE(review): `is_training` is the literal True here, not the `is_training`
# placeholder defined with the input pipeline — confirm this is intended for
# the validation pass.
net = mobilenet_backbone(input_tensor, 0.35, output_stride=2, is_training=True, weight_decay=0.00001)
load_mobilenet_weights(sess, checkpoint='mobilenet_v2_0.35_224/mobilenet_v2_0.35_224.ckpt')
net = segmentation_head(input_tensor, net, is_training=True, weight_decay=0.00001, dropout=0)

# Variables that still need initialization: everything outside the
# 'MobilenetV2' scope, plus any variable with 'quant' in its name.
# `tf.global_variables()` is the non-deprecated equivalent of
# `tf.all_variables()`.
head_variables = [t for t in tf.global_variables()
                  if ('MobilenetV2' not in t.name or 'quant' in t.name)]

## Option 2: Continuing training from a checkpoint with quantization aware training.

In [None]:
# Rebuild the same network as for a fresh run, then restore its weights from
# the training checkpoint instead of the pretrained backbone checkpoint.
# NOTE(review): `is_training` is the literal True here, not the `is_training`
# placeholder defined with the input pipeline — confirm this is intended for
# the validation pass.
net = mobilenet_backbone(input_tensor, 0.35, output_stride=2, is_training=True, weight_decay=0.00001)
net = segmentation_head(input_tensor, net, is_training=True, weight_decay=0.00001, dropout=0)

# Restore everything built so far except variables whose name contains
# 'Conv_1' or 'Logits'. `tf.global_variables()` is the non-deprecated
# equivalent of `tf.all_variables()`.
restored_variables = [t for t in tf.global_variables()
                      if 'Conv_1' not in t.name and 'Logits' not in t.name]
saver = tf.train.Saver(var_list=restored_variables)
saver.restore(sess, 'checkpoints/train-checkpoint')

# Rewrite the graph for quantization-aware training. This creates additional
# variables after the restore above, so they are not covered by it.
g = tf.get_default_graph()
tf.contrib.quantize.create_training_graph(input_graph=g,
                                          quant_delay=0)

# Variables that still need initialization. Anything just restored is left
# out: otherwise the initialization cell would overwrite the restored head
# weights with fresh initial values, and training would not actually continue
# from the checkpoint.
restored_names = {t.name for t in restored_variables}
head_variables = [t for t in tf.global_variables()
                  if t.name not in restored_names
                  and ('MobilenetV2' not in t.name or 'quant' in t.name)]

## Set up the optimizer.

In [None]:
# Data term: weighted sigmoid cross-entropy between the targets (samples[1])
# and the logits (net), with positive targets weighted 20x, averaged over all
# elements.
data_loss = tf.reduce_mean(tf.nn.weighted_cross_entropy_with_logits(samples[1], net,
                                                                    pos_weight=20))

# Regularization terms registered in the graph, skipping those whose name
# contains 'Conv_1' or 'Logits' (the same variables the checkpoint cells skip).
regularization_losses = [t for t in tf.losses.get_regularization_losses()
                         if 'Conv_1' not in t.name and 'Logits' not in t.name]
loss = tf.add_n([data_loss] + regularization_losses)

optimizer = tf.train.AdamOptimizer(learning_rate=0.001)

# Ops registered under UPDATE_OPS (for example batch-norm moving-average
# updates) only run if something depends on them. Tying them to the train step
# keeps those statistics current; with an empty collection this is a no-op.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)

## Initialize everything.

In [None]:
# Reduce targets and network output to per-pixel class ids by taking the
# argmax over the last axis (presumably the class axis — confirm against
# MyDataset / segmentation_head).
labels = tf.argmax(samples[1], axis=-1)
class_probabilities = tf.nn.softmax(net, axis=-1)
predictions = tf.argmax(class_probabilities, axis=-1)

# Two independent streaming mean-IoU metrics over the same tensors, one per
# split. Each call returns (value tensor, update op).
train_miou, train_miou_update = tf.metrics.mean_iou(labels, predictions, num_classes=2)
val_miou, val_miou_update = tf.metrics.mean_iou(labels, predictions, num_classes=2)

# Initialize whatever the model-building cell marked as not yet initialized.
if head_variables:
    sess.run(tf.variables_initializer(head_variables))

# Local variables, then the optimizer's own variables.
sess.run(tf.local_variables_initializer())
sess.run(tf.variables_initializer(optimizer.variables()))

# Concrete string handles to feed into `dataset_handle`, and a primed
# training iterator so the training loop can start pulling batches.
train_handle = sess.run(train_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())
sess.run(train_iterator.initializer)
global_step = tf.train.get_global_step()

## Train.

In [None]:
show_images_every = 5   # plot validation overlays every N-th epoch
num_epochs = 100
overlay_alpha = 0.2     # blend weight of the green mask tint


def _overlay_mask(image, mask, alpha):
    """Return a uint8 copy of `image` with the pixels where `mask == 1` tinted green.

    `image` is rescaled with (x + 1) * 255 / 2 before blending, i.e. it is
    presumably normalized to [-1, 1] — confirm against MyDataset.
    """
    canvas = (image.copy() + 1.) * 255. / 2.
    selected = mask == 1
    canvas[selected] = (alpha * np.array([0, 255, 0]) + (1 - alpha) * canvas[selected]).round()
    return canvas.astype(np.uint8)


# The streaming mean-IoU metrics accumulate in local variables created under
# the default 'mean_iou' scopes. Resetting them at the start of every epoch
# makes the printed mIoU a per-epoch figure, like the per-epoch loss average.
metric_variables = [v for v in tf.local_variables() if 'mean_iou' in v.name]
reset_metrics = tf.variables_initializer(metric_variables)

for epoch in range(num_epochs):
    # (Re)start the training data and clear the metric accumulators.
    sess.run([train_iterator.initializer, reset_metrics])

    # ---- Training pass: run until the training iterator is exhausted. ----
    losses = []
    batch = 0
    while True:
        try:
            # t_miou is read in the same run as its update op, so it may lag
            # by one batch; it is only used for the progress line.
            _, t_loss, t_miou, _ = sess.run([train_op, loss, train_miou, train_miou_update],
                                            {dataset_handle: train_handle, is_training: True})
        except tf.errors.OutOfRangeError:
            break
        losses.append(t_loss)
        print('\rEpoch: {}\tBatch: {}\tTrain loss: {:06.4f}\tMiou: {:06.4f}'.format(epoch, batch, np.mean(losses), t_miou), end='')
        batch += 1
    print('')

    # ---- Validation pass. ----
    print('Epoch {}:'.format(epoch))
    sess.run(validation_iterator.initializer)
    show_images = epoch % show_images_every == 0
    losses = []
    sample_index = 0  # running index of the first sample of the current batch
    while True:
        try:
            v_loss, _, img, label, pred = sess.run([loss, val_miou_update, input_tensor, labels, predictions],
                                                   {dataset_handle: validation_handle, is_training: False})
        except tf.errors.OutOfRangeError:
            break
        losses.append(v_loss)

        if show_images:
            for i in range(img.shape[0]):
                fig, axes = plt.subplots(1, 2, figsize=(15, 7))
                # Left: ground-truth mask. Right: predicted mask.
                axes[0].imshow(_overlay_mask(img[i], label[i], overlay_alpha))
                axes[1].imshow(_overlay_mask(img[i], pred[i], overlay_alpha))

                # Title with the record name of this validation sample. This
                # assumes the validation pipeline yields records in file
                # order, which is how `filenames` was collected.
                title_index = sample_index + i
                if title_index < len(filenames):
                    fig.suptitle(filenames[title_index])
                plt.show()
                plt.close(fig)
        sample_index += img.shape[0]

    # Read the metric only after all updates have been applied, and average
    # the collected per-batch losses (not the loss tensor).
    v_miou = sess.run(val_miou)
    val_loss = np.mean(losses) if losses else float('nan')
    print('Val loss: {:06.4f}\tMiou: {:06.4f}'.format(val_loss, v_miou))

## Saving the train checkpoint

In [None]:
# Save every variable except those whose name contains 'Conv_1' or 'Logits',
# matching the filter used when this checkpoint is restored.
# `tf.global_variables()` is the non-deprecated equivalent of
# `tf.all_variables()`.
to_save = [t for t in tf.global_variables()
           if ('Conv_1' not in t.name and 'Logits' not in t.name)]

# Make sure the target directory exists so the first save on a clean checkout
# does not fail.
os.makedirs('checkpoints', exist_ok=True)
saver = tf.train.Saver(to_save)
saver.save(sess, 'checkpoints/train-checkpoint')

## Convert the train checkpoint to the eval checkpoint for TFLite (the model should have been trained with quantization).

In [None]:
# Build an inference-mode copy of the network in a fresh graph, apply the
# quantization eval rewrite, load the trained weights, and write out the graph
# definition together with a matching eval checkpoint.
tf.reset_default_graph()
sess = tf.Session()

# NOTE(review): the input size is fixed to 256x640 (the "untransformed images"
# option) regardless of `shape`, and `output_stride=8` differs from the
# `output_stride=2` used in the training cells — confirm both are intended.
input_tensor = tf.placeholder(tf.float32, (1, 256, 640, 3))
net = mobilenet_backbone(input_tensor, 0.35, output_stride=8, weight_decay=0., is_training=False)
net = segmentation_head(input_tensor, net, weight_decay=0., is_training=False, dropout=0)

g = tf.get_default_graph()
tf.contrib.quantize.create_eval_graph(input_graph=g)

# Same variable filter as the training checkpoint. It is evaluated after the
# rewrite above, so variables created by the rewrite are included.
# `tf.global_variables()` is the non-deprecated equivalent of
# `tf.all_variables()`.
to_save = [t for t in tf.global_variables()
           if 'Conv_1' not in t.name and 'Logits' not in t.name]
saver = tf.train.Saver(var_list=to_save)
saver.restore(sess, 'checkpoints/train-checkpoint')

# The graph definition is written in text form (str() of the GraphDef proto).
with open('eval_graph.pb', 'w') as f:
    f.write(str(g.as_graph_def()))

# One Saver serves both restore and save; the variable list is identical, so
# a second Saver would only add duplicate save/restore ops to the graph.
saver.save(sess, 'checkpoints/eval-checkpoint')