In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import itertools
import math
import os
import pickle
from timeit import default_timer as timer

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# main.py resets stdout, so keep a reference here
import sys
stdout = sys.stdout

In [None]:
import helper

In [None]:
# NOTE(review): names used later that this notebook never defines itself
# (DATA_DIR, IMAGE_SHAPE, NUM_CLASSES, build) presumably come in through this
# star import -- confirm against main.py.
from main import *
# Put back the stdout reference that was saved before importing main.
sys.stdout = stdout

# Setup

## Example Images

In [None]:
def show_example_images(num_images):
    """Plot the first `num_images` training examples, one example per row.

    Each row has three panels: the input image, channel 1 of its label
    (displayed as a grayscale road mask), and the input image multiplied by
    that mask.

    Args:
        num_images: number of examples (rows) to draw.
    """
    get_batches_fn = helper.gen_batch_function(
        os.path.join(DATA_DIR, 'data_road/training'), IMAGE_SHAPE)

    # squeeze=False keeps `axes` two-dimensional even for a single row;
    # with the default, plt.subplots(1, 3) returns a 1-D array and the
    # axes[row][col] indexing below would fail.
    _, axes = plt.subplots(
        num_images, 3,
        figsize=(15, 2*num_images),
        squeeze=False)

    # Batches of size one, so index 0 picks the only example in each batch.
    batches = itertools.islice(get_batches_fn(1), num_images)
    for row, (inputs, labels) in enumerate(batches):
        image = inputs[0]
        road_mask = labels[0,:,:,1]
        axes[row][0].imshow(image)
        axes[row][1].imshow(road_mask, cmap='gray')
        # Broadcast the 2-D mask over the colour channels.
        axes[row][2].imshow(image * road_mask[:,:,np.newaxis])
show_example_images(3)

# Training

## Train / Validation Split

`helper.gen_batch_function` returns a generator that yields the images in shuffled order, so we only need to cut the full training set into a training set and a validation set; we use a roughly 80/20 split. The two sets are pickled for fast access later.

In [None]:
# Pickle files that cache the two halves of the train / validation split.
TRAINING_FILE = os.path.join(DATA_DIR, 'train_1.pickle')
VALIDATION_FILE = os.path.join(DATA_DIR, 'validation_1.pickle')

# NOTE(review): 289 is presumably the number of images under
# data_road/training -- confirm against the dataset.
NUM_TOTAL_TRAINING_IMAGES = 289
# 230 of 289 is the "roughly 80/20" split; the remainder is for validation.
NUM_TRAINING_IMAGES = 230
NUM_VALIDATION_IMAGES = NUM_TOTAL_TRAINING_IMAGES - NUM_TRAINING_IMAGES

def save_input_and_labels(pathname, inputs, labels):
    """Pickle `inputs` and `labels` together into the file at `pathname`.

    The file holds a single dict with the keys 'inputs' and 'labels' and is
    written with the highest pickle protocol available. An existing file at
    `pathname` is overwritten.
    """
    payload = {'inputs': inputs, 'labels': labels}
    with open(pathname, 'wb') as out:
        pickle.dump(payload, out, protocol=pickle.HIGHEST_PROTOCOL)

def generate_training_validation_split():
    """Create the cached train / validation pickles if they are not both present.

    Draws batches of size one from `helper.gen_batch_function`, puts the
    first NUM_TRAINING_IMAGES examples into the training set and all later
    ones into the validation set, then saves each set with
    `save_input_and_labels`.

    Nothing is written when both TRAINING_FILE and VALIDATION_FILE already
    exist. If either one is missing, both are regenerated together so that
    the two files always come from the same pass over the data and never
    share images.
    """
    # Checking only the training file would leave a missing validation file
    # unrepaired and make the later load of VALIDATION_FILE fail.
    if os.path.exists(TRAINING_FILE) and os.path.exists(VALIDATION_FILE):
        print('Training set already exists.')
        return

    get_batches_fn = helper.gen_batch_function(
        os.path.join(DATA_DIR, 'data_road/training'), IMAGE_SHAPE)

    training_inputs, training_labels = [], []
    validation_inputs, validation_labels = [], []
    for i, (inputs, labels) in enumerate(get_batches_fn(1)):
        # Batch size is one, so index 0 is the only example in the batch.
        if i < NUM_TRAINING_IMAGES:
            training_inputs.append(inputs[0])
            training_labels.append(labels[0])
        else:
            validation_inputs.append(inputs[0])
            validation_labels.append(labels[0])

    save_input_and_labels(
        TRAINING_FILE,
        np.array(training_inputs),
        np.array(training_labels))
    save_input_and_labels(
        VALIDATION_FILE,
        np.array(validation_inputs),
        np.array(validation_labels))
        
generate_training_validation_split()

In [None]:
def load_input_and_labels(file):
    """Read back a pickle holding a dict with 'inputs' and 'labels' entries.

    Args:
        file: path of the pickle file to read.

    Returns:
        The tuple (inputs, labels) taken from the stored dict.
    """
    # pickle.load can execute code embedded in the file, so only point this
    # at pickles you produced yourself.
    with open(file, 'rb') as handle:
        stored = pickle.load(handle)
    return stored['inputs'], stored['labels']
        
# Load the cached split into module-level arrays; the grid search below
# reads these globals directly.
TRAINING_INPUTS, TRAINING_LABELS = load_input_and_labels(TRAINING_FILE)
VALIDATION_INPUTS, VALIDATION_LABELS = load_input_and_labels(VALIDATION_FILE)

# for testing
# (uncomment to keep only the first 5 training / 3 validation examples
# for a quick run)
# TRAINING_INPUTS = TRAINING_INPUTS[:5,:,:,:]
# TRAINING_LABELS = TRAINING_LABELS[:5,:,:,:]
# VALIDATION_INPUTS = VALIDATION_INPUTS[:3,:,:,:]
# VALIDATION_LABELS = VALIDATION_LABELS[:3,:,:,:]

In [None]:
[TRAINING_INPUTS.shape, VALIDATION_INPUTS.shape]

## Tune

In [None]:
# From http://stackoverflow.com/a/40623158/2053820
def dict_product(dicts):
    """Yield one dict per combination of the values in `dicts`.

    `dicts` maps each key to an iterable of candidate values. The result is
    a generator over every combination, with the last key varying fastest:

    >>> list(dict_product(dict(number=[1,2], character='ab')))
    [{'number': 1, 'character': 'a'},
     {'number': 1, 'character': 'b'},
     {'number': 2, 'character': 'a'},
     {'number': 2, 'character': 'b'}]
    """
    combos = itertools.product(*dicts.values())
    return (dict(zip(dicts, combo)) for combo in combos)

In [None]:
# Pickle file in which run_grid keeps the results of finished grid points.
GRID_FILE = os.path.join(DATA_DIR, 'grid_2.pickle')

def run_grid_point(params):
    """Train and validate one hyper-parameter combination with early stopping.

    A fresh default graph and session are created, the model is constructed
    with `build`, and training runs over TRAINING_INPUTS / TRAINING_LABELS in
    consecutive slices of `batch_size`. After every epoch the model is scored
    on VALIDATION_INPUTS / VALIDATION_LABELS. Training stops after
    `max_epochs` epochs, or once the validation loss has failed to improve
    for more than `max_epochs_without_progress` epochs in a row.

    Args:
        params: dict describing one grid point. The entries 'batch_size',
            'max_epochs_without_progress', 'max_epochs', 'keep_prob' and
            'learning_rate' are consumed here; whatever remains is handed to
            `build`. The caller's dict is left unmodified.

    Returns:
        dict with
            'best_epoch': index of the epoch with the lowest validation loss,
            'best_validation_loss': that loss (per-image average),
            'best_mean_iou': validation mean IoU measured in that same epoch,
            'time': seconds elapsed since training started.

    Raises:
        ValueError: if VALIDATION_INPUTS is empty, since early stopping
            would have nothing to measure.
    """
    params = params.copy()
    batch_size = params.pop('batch_size')
    max_epochs_without_progress = params.pop('max_epochs_without_progress')
    max_epochs = params.pop('max_epochs')
    keep_prob_value = params.pop('keep_prob')
    learning_rate_value = params.pop('learning_rate')

    # Use the sizes of the arrays actually in memory rather than the
    # NUM_*_IMAGES constants, so averages stay correct when the arrays have
    # been cut down to a subset.
    num_training_images = TRAINING_INPUTS.shape[0]
    num_validation_images = VALIDATION_INPUTS.shape[0]
    if num_validation_images == 0:
        raise ValueError('validation set is empty')

    num_training_batches = \
        int(math.ceil(num_training_images / batch_size))
    num_validation_batches = \
        int(math.ceil(num_validation_images / batch_size))

    tf.reset_default_graph()
    with tf.Session() as sess:
        train_op, cross_entropy_loss, image_input, correct_label, \
            keep_prob, learning_rate, logits = build(sess, params)

        # tf.metrics.mean_iou is a streaming metric: it keeps a running
        # confusion matrix in local variables. Put those variables in their
        # own scope so they can be reset before every validation pass.
        with tf.variable_scope('validation_metrics'):
            mean_iou, mean_iou_update_op = tf.metrics.mean_iou(
                tf.reshape(correct_label[:,:,:,1], [-1]),
                tf.nn.softmax(logits)[:,1] > 0.5,
                NUM_CLASSES)
        reset_mean_iou_op = tf.variables_initializer(tf.get_collection(
            tf.GraphKeys.LOCAL_VARIABLES, scope='validation_metrics'))

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        start = timer()
        epochs_without_progress = 0
        best_epoch = 0
        best_validation_loss = math.inf
        best_mean_iou = 0
        for epoch in range(max_epochs):
            if epochs_without_progress > max_epochs_without_progress:
                break

            for batch in range(num_training_batches):
                batch_start = batch * batch_size
                batch_end = batch_start + batch_size
                batch_inputs = TRAINING_INPUTS[batch_start:batch_end,:,:,:]
                batch_labels = TRAINING_LABELS[batch_start:batch_end,:,:,:]
                training_loss, _ = sess.run(
                    [cross_entropy_loss, train_op], {
                        image_input: batch_inputs,
                        correct_label: batch_labels,
                        keep_prob: keep_prob_value,
                        learning_rate: learning_rate_value
                    }
                )
                print('training', epoch, batch, training_loss)

            # Start this epoch's IoU from an empty confusion matrix so the
            # result is not mixed with earlier epochs.
            sess.run(reset_mean_iou_op)
            validation_loss = 0
            for batch in range(num_validation_batches):
                batch_start = batch * batch_size
                batch_end = batch_start + batch_size
                batch_inputs = VALIDATION_INPUTS[batch_start:batch_end,:,:,:]
                batch_labels = VALIDATION_LABELS[batch_start:batch_end,:,:,:]
                # Only the update op is run per batch. Fetching the metric
                # value in the same call would read it at an undefined point
                # relative to the update.
                batch_loss, _ = sess.run(
                    [cross_entropy_loss, mean_iou_update_op], {
                        image_input: batch_inputs,
                        correct_label: batch_labels,
                        keep_prob: 1.0
                    }
                )
                # The last batch may be short, so weight each batch loss by
                # its actual size before averaging.
                validation_loss += batch_inputs.shape[0] * batch_loss

            validation_loss /= num_validation_images
            # The accumulated confusion matrix now covers the whole
            # validation set; read the metric once. It depends only on the
            # metric's own variables, so no feeds are needed.
            validation_mean_iou = sess.run(mean_iou)

            print('validation', epoch, validation_loss, validation_mean_iou)

            if validation_loss < best_validation_loss:
                best_epoch = epoch
                best_validation_loss = validation_loss
                best_mean_iou = validation_mean_iou
                epochs_without_progress = 0
            else:
                epochs_without_progress += 1

        return {
            'best_epoch': best_epoch,
            'best_validation_loss': best_validation_loss,
            'best_mean_iou': best_mean_iou,
            'time': timer() - start,
        }
            
def run_grid():
    """Run the hyper-parameter grid search, resuming from GRID_FILE if present.

    Every combination of the values in `params_dict` is evaluated with
    `run_grid_point`. Results are kept in a dict keyed by the frozenset of
    the combination's (name, value) items. Combinations already present in
    the cached results are skipped, and the cache is rewritten after each
    newly finished combination so an interrupted search can pick up where
    it left off.
    """
    if os.path.isfile(GRID_FILE):
        # pickle.load can execute code embedded in the file; GRID_FILE is
        # only ever written by this function.
        with open(GRID_FILE, 'rb') as f:
            results = pickle.load(f)
    else:
        results = {}

    params_dict = {
        'batch_size': [13],
        'max_epochs_without_progress': [3],
        'max_epochs': [50],
        'keep_prob': [0.5],
        'learning_rate': [0.001, 0.0001, 0.00001],
        'kernel_size_3': [8, 16],
        'kernel_size_4': [2, 4],
        'kernel_size_7': [2, 4],
        'conv_1x1_depth': [0, 2048, 4096]
    }

    # Written next to GRID_FILE so the final rename stays on one filesystem.
    temp_file = GRID_FILE + '.tmp'

    for params in dict_product(params_dict):
        print(params)

        frozen_key = frozenset(params.items())
        if frozen_key in results:
            continue

        results[frozen_key] = run_grid_point(params)

        # Write to a temporary file and swap it in with os.replace. Writing
        # GRID_FILE in place would leave a truncated pickle if the kernel is
        # interrupted mid-dump, losing every result gathered so far.
        with open(temp_file, 'wb') as f:
            pickle.dump(results, f, pickle.HIGHEST_PROTOCOL)
        os.replace(temp_file, GRID_FILE)
        
run_grid()

In [None]:
def summarize_grid(file):
    """Print the grid-search results stored in the pickle at `file`.

    Args:
        file: path of a pickle file containing the results object.
    """
    with open(file, 'rb') as handle:
        grid_results = pickle.load(handle)

    print(grid_results)
    
summarize_grid(GRID_FILE)