In [1]:
import tensorflow as tf
import numpy as np
import sys

from datetime import datetime

python_root = '../'
sys.path.insert(0, python_root)
from model.alexnet import AlexNet
from model.triplet_loss import batch_all_triplet_loss, batch_hardest_triplet_loss

In [2]:
# Set a seed for numpy.
np.random.seed(1)

## Model

The virtual-camera images as well as the line data can be read either directly from their location on the disk or from compressed _pickle_ files, that store all these info in a compact way, in a dict-like structure. This representation is particularly useful when the time required to access the file from the disk constitutes the bottleneck of the process (e.g., when training on ETH's cluster Euler and having the data files in a `scratch` folder and the training scripts in the `home` folder): the data is directly loaded in the memory, therefore allowing faster access. To enable this modality, set `read_as_pickle` to `True`. To use regular textfiles and images (generated by the previous steps of the pipeline), set `read_as_pickle` to `False` (make sure to place the data in the correct locations, in particular check that the path `LINESANDIMAGESFOLDER_PATH` in the config file `python/config_paths_and_variables.sh` is set to right value).

__Configuration settings (folders, learning parameters, name of the job, etc.)__

In [3]:
# Configuration settings.
from datetime import datetime
import os

from tools import pathconfig

# Name of the training job (used to properly name the folders when the results are
# stored.)
job_name = datetime.now().strftime("%d%m%y_%H%M")

# Set read_as_pickle to True to interpret the train/val files below as pickle
# files, False to interpret them as regular text files with the format outputted
# by split_dataset_with_labels_world.py.
read_as_pickle = True

if (read_as_pickle):
    pickleandsplit_path = pathconfig.obtain_paths_and_variables(
            "PICKLEANDSPLIT_PATH")
    # * Pickle-files version: path of the pickle files to use for training
    #       and validation. Note: more than one pickle file at a time can be
    #       used for both training and validation. Therefore, train_files
    #       and val_files should both be lists.
    train_files = [os.path.join(pickleandsplit_path, 'train_0/traj_1/pickled_train.pkl')]
    val_files = [os.path.join(pickleandsplit_path, 'train_0/traj_1/pickled_val.pkl')]
else:
    linesandimagesfolder_path = pathconfig.obtain_paths_and_variables(
            "LINESANDIMAGESFOLDER_PATH")
    # * Textfile version: path to the textfiles for the trainings and
    #       validation set.
    # TODO: fix to use this non-pickle version. The paths in train_files
    # and val_files below should work, but the current version does not
    # allow to train on several sets at a time with textfiles. Also,
    # textfiles do not contain the endpoints of the lines, but only their
    # center point, making it therefore not possible to use the line
    # direction/orthonormal representation required by the last version of
    # the network.
    train_files = [
        os.path.join(linesandimagesfolder_path, dataset_name, 'traj_{}'.format(
            trajectory), 'train.txt')
    ]
    val_files = [
        os.path.join(linesandimagesfolder_path, dataset_name, 'traj_{}'.format(
            trajectory), 'val.txt')
    ]
    
# Either 'bgr' or 'bgr-d': type of the image fed to the network.
image_type = 'bgr-d'

# Type of line parametrization can be:
# * 'direction_and_centerpoint':
#      Each line segment is parametrized by its center point and by its unit
#      direction vector.  To obtain invariance on the orientation of the
#      line (i.e., given the two endpoints we do NOT want to consider one of
#      them as the start and the other one as the end of the line segment),
#      we enforce that the first entry should be non-negative. => 6
#      parameters per line.
# * 'orthonormal':
#      A line segment is parametrized with a minimum-DOF parametrization
#      (4 degrees of freedom) of the infinite line that it belongs to. The
#      representation is called orthonormal. => 4 parameters per line.
line_parametrization = 'direction_and_centerpoint'

# Folder that store the output logs.
log_files_folder = "./logs/"

# Learning parameters.
learning_rate = 0.0001
num_epochs = 90
batch_size = 128
# Margin of the triplet loss.
margin = 0.2
# Either "batch_all" or "batch_hard". Strategy for triplets selection.
triplet_strategy = "batch_all"

# Network parameters.
dropout_rate = 0.5

# How often we want to write the tf.summary data to disk.
display_step = 1

# Path for tf.summary.FileWriter and to store model checkpoints.
filewriter_path = os.path.join(log_files_folder, job_name,
                               "triplet_loss_{}".format(triplet_strategy))
checkpoint_path = os.path.join(
    log_files_folder, job_name,
    "triplet_loss_{}_ckpt".format(triplet_strategy))
    
# Create parent path if it does not exist.
if not os.path.isdir(checkpoint_path):
    os.makedirs(checkpoint_path)

__Create network structure__

In [4]:
from tools.train_utils import get_train_set_mean

# Check if checkpoints already exist.
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)

# Input placeholder.
if image_type == 'bgr':
    input_img = tf.placeholder(
        tf.float32, [None, 227, 227, 3], name="input_img")
elif image_type == 'bgr-d':
    input_img = tf.placeholder(
        tf.float32, [None, 227, 227, 4], name="input_img")

# For each line, labels is in the format
#   [line_center (3x)] [instance label (1x)]
labels = tf.placeholder(tf.float32, [None, 4], name="labels")
# Dropout probability.
keep_prob = tf.placeholder(tf.float32)
# Line types.
line_types = tf.placeholder(tf.float32, [None, 1], name="line_types")
# Geometric information.
if line_parametrization == 'direction_and_centerpoint':
    geometric_info = tf.placeholder(
        tf.float32, [None, 6], name="geometric_info")
elif line_parametrization == 'orthonormal':
    geometric_info = tf.placeholder(
        tf.float32, [None, 4], name="geometric_info")
else:
    raise ValueError("Line parametrization should be "
                     "'direction_and_centerpoint' or 'orthonormal'.")

# Layers for which weights should not be trained.
no_train_layers = ['conv1', 'pool1', 'norm1', 'conv2', 'pool2', 'norm2']
# Layers for which ImageNet weights should not be loaded.
skip_layers = ['fc8', 'fc9']

# Initialize model.
model = AlexNet(
    x=input_img,
    line_types=line_types,
    geometric_info=geometric_info,
    keep_prob=keep_prob,
    skip_layer=skip_layers,
    input_images=image_type)
    
# Retrieve embeddings (cluster descriptors) from model output.
embeddings = tf.nn.l2_normalize(model.fc9, axis=1)

# Get mean of training set if the training is just starting (i.e., if no
# previous checkpoints are found).
if latest_checkpoint is None:
    train_set_mean = get_train_set_mean(
        train_files, image_type, read_as_pickle=read_as_pickle)
    train_set_mean_tensor = tf.convert_to_tensor(
        train_set_mean, dtype=np.float64)
    train_set_mean_variable = tf.Variable(
        initial_value=train_set_mean_tensor,
        trainable=False,
        name="train_set_mean")
else:
    if image_type == 'bgr':
        train_set_mean_shape = (3,)
    elif image_type == 'bgr-d':
        train_set_mean_shape = (4,)
    # The value will be restored from the checkpoint.
    train_set_mean_variable = tf.get_variable(
        name="train_set_mean",
        shape=train_set_mean_shape,
        dtype=tf.float64,
        trainable=False)

# List of trainable variables of the layers we want to train.
var_list = [
    v for v in tf.trainable_variables()
    if v.name.split('/')[0] not in no_train_layers
]

# List of parameters trained.
total_parameters = 0
print("**** List of variables used for training ****")
for var in var_list:
    shape = var.get_shape()
    var_parameters = 1
    for dim in shape:
        var_parameters *= dim.value
    print("{0} --- {1} parameters".format(var.name, var_parameters))
    total_parameters += var_parameters
print("Total number of parameters is {}".format(total_parameters))

**** List of variables used for training ****
conv3/weights:0 --- 884736 parameters
conv3/biases:0 --- 384 parameters
conv4/weights:0 --- 663552 parameters
conv4/biases:0 --- 384 parameters
conv5/weights:0 --- 442368 parameters
conv5/biases:0 --- 256 parameters
fc6/weights:0 --- 37748736 parameters
fc6/biases:0 --- 4096 parameters
fc7/weights:0 --- 16777216 parameters
fc7/biases:0 --- 4096 parameters
fc8/weights:0 --- 16834609 parameters
fc8/biases:0 --- 4103 parameters
fc9/weights:0 --- 262592 parameters
fc9/biases:0 --- 64 parameters
Total number of parameters is 73627192


__Define loss and train operation. Also retrieve statistics about the triplets selected while training.__

In [5]:
# Define loss.
with tf.name_scope("triplet_loss"):
    if triplet_strategy == "batch_all":
        (loss, fraction, valid_triplets,
         pairwise_dist) = batch_all_triplet_loss(
             labels, embeddings, margin=margin, squared=False)
    elif triplet_strategy == "batch_hard":
        (loss, mask_anchor_positive, mask_anchor_negative,
         hardest_positive_dist, hardest_negative_dist,
         hardest_positive_element, hardest_negative_element,
         pairwise_dist) = batch_hardest_triplet_loss(
             labels, embeddings, margin=margin, squared=False)
    else:
        raise ValueError(
            "Triplet strategy not recognized: {}".format(triplet_strategy))
    # The following only to assign a name to the tensor.
    loss = tf.identity(loss, name="train_loss")
# Train operation.
with tf.name_scope("train"):
    # Get gradients of all trainable variables.
    gradients = tf.gradients(loss, var_list)
    gradients = list(zip(gradients, var_list))

    # Create optimizer and apply gradient descent to the trainable
    # variables.
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    train_op = optimizer.apply_gradients(grads_and_vars=gradients)

__Add variables to summary and initialize the Tensorflow saver__

In [6]:
# Add gradients to summary.
for gradient, var in gradients:
    tf.summary.histogram(var.name + '/gradient', gradient)

# Add the variables we train to the summary.
for var in var_list:
    tf.summary.histogram(var.name, var)

# Add the loss to summary.
if triplet_strategy == "batch_all":
    tf.summary.scalar('triplet_loss', loss)
    tf.summary.scalar('fraction_positive_triplets', fraction)
elif triplet_strategy == "batch_hard":
    tf.summary.scalar('triplet_loss', loss)

# Add embedding_mean_norm (should always be 1) to summary.
embedding_mean_norm = tf.reduce_mean(tf.norm(embeddings, axis=1))
tf.summary.scalar("embedding_mean_norm", embedding_mean_norm)

merged_summary = tf.summary.merge_all()

# Initialize the FileWriter.
writer = tf.summary.FileWriter(filewriter_path)

# Initialize an saver to store model checkpoints.
saver = tf.train.Saver()

INFO:tensorflow:Summary name conv3/weights:0/gradient is illegal; using conv3/weights_0/gradient instead.
INFO:tensorflow:Summary name conv3/biases:0/gradient is illegal; using conv3/biases_0/gradient instead.
INFO:tensorflow:Summary name conv4/weights:0/gradient is illegal; using conv4/weights_0/gradient instead.
INFO:tensorflow:Summary name conv4/biases:0/gradient is illegal; using conv4/biases_0/gradient instead.
INFO:tensorflow:Summary name conv5/weights:0/gradient is illegal; using conv5/weights_0/gradient instead.
INFO:tensorflow:Summary name conv5/biases:0/gradient is illegal; using conv5/biases_0/gradient instead.
INFO:tensorflow:Summary name fc6/weights:0/gradient is illegal; using fc6/weights_0/gradient instead.
INFO:tensorflow:Summary name fc6/biases:0/gradient is illegal; using fc6/biases_0/gradient instead.
INFO:tensorflow:Summary name fc7/weights:0/gradient is illegal; using fc7/weights_0/gradient instead.
INFO:tensorflow:Summary name fc7/biases:0/gradient is illegal; usi

# Run model

In [None]:
from model.datagenerator import ImageDataGenerator
from tools.train_utils import print_batch_triplets_statistics
from tools.lines_utils import get_label_with_line_center, get_geometric_info

# Set to True to display statistics about the triplets selected while training.
display_triplets_statistics = False

with tf.Session() as sess:
    if latest_checkpoint is None:
        # No previous checkpoint to start training from => Train from
        # scratch.
        
        # Initialize all variables.
        sess.run(tf.global_variables_initializer())
        # Load the pretrained weights into the layers which are not in
        # skip_layers.
        model.load_initial_weights(sess)
        # Set first epoch to use for training as 0.
        starting_epoch = 0
    else:
        print("Found checkpoint {}".format(latest_checkpoint))
        # Load values of variables from checkpoint.
        saver.restore(sess, latest_checkpoint)
        # Set first epoch to use for training as the number in the last
        # checkpoint (note that the number saved in this filename
        # corresponds to the number of the last epoch + 1, cf. lines where
        # the checkpoints are saved).
        start_char = latest_checkpoint.find("epoch")
        if start_char == -1:
            print(
                "File name of checkpoint is in unexpected format: did not "
                "find ''epoch''. Exiting.")
            exit()
        else:
            start_char += 5  # Length of the string 'epoch'
            end_char = latest_checkpoint.find(".ckpt", start_char)
            if end_char == -1:
                print(
                    "File name of checkpoint is in unexpected format: did "
                    "not find ''.ckpt''. Exiting.")
                exit()
            else:
                starting_epoch = int(latest_checkpoint[start_char:end_char])

    # Obtain training set mean.
    train_set_mean = sess.run(train_set_mean_variable)
    print("Mean of train set: {}".format(train_set_mean))

    # Initialize generators for image data.
    train_generator = ImageDataGenerator(
        files_list=train_files,
        horizontal_flip=False,
        shuffle=True,
        image_type=image_type,
        mean=train_set_mean,
        read_as_pickle=read_as_pickle)
    val_generator = ImageDataGenerator(
        files_list=val_files,
        shuffle=True,
        image_type=image_type,
        mean=train_set_mean,
        read_as_pickle=read_as_pickle)

    # Get the number of training/validation steps per epoch.
    train_batches_per_epoch = np.floor(
        train_generator.data_size / batch_size).astype(np.int16)
    val_batches_per_epoch = np.floor(
        val_generator.data_size / batch_size).astype(np.int16)

    # Add the model graph to TensorBoard.
    writer.add_graph(sess.graph)

    print("{} Start training...".format(datetime.now()))
    print("{} Open Tensorboard at --logdir {}".format(
        datetime.now(), filewriter_path))

    print("Starting epoch is {0}, num_epochs is {1}".format(
        starting_epoch, num_epochs))

    # Loop over number of epochs.
    for epoch in range(starting_epoch, num_epochs):
        print("{} Epoch number: {}".format(datetime.now(), epoch + 1))
        step = 1

        while step < train_batches_per_epoch:
            # Get a batch of images and labels.
            (batch_input_img_train, batch_labels_train,
             batch_line_types_train
            ) = train_generator.next_batch(batch_size)
            # Pickled files have labels in the endpoints format -> convert
            # them to center format.
            labels_for_stats = batch_labels_train
            if read_as_pickle:
                batch_start_points_train = batch_labels_train[:, :3]
                batch_end_points_train = batch_labels_train[:, 3:6]
                batch_geometric_info_train = get_geometric_info(
                    start_points=batch_start_points_train,
                    end_points=batch_end_points_train,
                    line_parametrization=line_parametrization)
                batch_labels_train = get_label_with_line_center(
                    labels_batch=batch_labels_train)

            # Display statistics about triplets.
            if (display_triplets_statistics):
                if (triplet_strategy == 'batch_all'):
                    (pairwise_dist_for_stats,
                     valid_triplets_for_stats) = sess.run(
                            [pairwise_dist, valid_triplets],
                            feed_dict={
                                input_img: batch_input_img_train,
                                labels: batch_labels_train,
                                line_types: batch_line_types_train,
                                geometric_info: batch_geometric_info_train,
                                keep_prob: dropout_rate
                            })
                    print_batch_triplets_statistics(
                        triplet_strategy=triplet_strategy,
                        images=batch_input_img_train,
                        set_mean=train_set_mean,
                        batch_index=step,
                        epoch_index=epoch,
                        write_folder='{}_logs/'.format(job_name),
                        labels=labels_for_stats,
                        pairwise_dist=pairwise_dist_for_stats,
                        valid_triplets=valid_triplets_for_stats)
                elif (triplet_strategy == 'batch_hard'):
                    (pairwise_dist_for_stats,
                     mask_anchor_positive_for_stats,
                     mask_anchor_negative_for_stats,
                     hardest_positive_dist_for_stats,
                     hardest_negative_dist_for_stats,
                     hardest_positive_element_for_stats,
                     hardest_negative_element_for_stats) = sess.run(
                         [
                             pairwise_dist, mask_anchor_positive,
                             mask_anchor_negative, hardest_positive_dist,
                             hardest_negative_dist,
                             hardest_positive_element,
                             hardest_negative_element
                         ],
                         feed_dict={
                             input_img: batch_input_img_train,
                             labels: batch_labels_train,
                             line_types: batch_line_types_train,
                             geometric_info: batch_geometric_info_train,
                             keep_prob: dropout_rate
                         })
                    print_batch_triplets_statistics(
                        triplet_strategy=triplet_strategy,
                        images=batch_input_img_train,
                        set_mean=train_set_mean,
                        batch_index=step,
                        epoch_index=epoch,
                        write_folder='{}_logs/'.format(job_name),
                        labels=labels_for_stats,
                        pairwise_dist=pairwise_dist_for_stats,
                        mask_anchor_positive=mask_anchor_positive_for_stats,
                        mask_anchor_negative=mask_anchor_negative_for_stats,
                        hardest_positive_dist=
                        hardest_positive_dist_for_stats,
                        hardest_negative_dist=
                        hardest_negative_dist_for_stats,
                        hardest_positive_element=
                        hardest_positive_element_for_stats,
                        hardest_negative_element=
                        hardest_negative_element_for_stats)
            # Run the training operation.
            sess.run(
                train_op,
                feed_dict={
                    input_img: batch_input_img_train,
                    labels: batch_labels_train,
                    line_types: batch_line_types_train,
                    geometric_info: batch_geometric_info_train,
                    keep_prob: dropout_rate
                })
            # Generate summary with the current batch of data and write it
            # to file.
            if step % display_step == 0:
                s = sess.run(
                    merged_summary,
                    feed_dict={
                        input_img: batch_input_img_train,
                        labels: batch_labels_train,
                        line_types: batch_line_types_train,
                        geometric_info: batch_geometric_info_train,
                        keep_prob: 1.
                    })
                writer.add_summary(s,
                                    epoch * train_batches_per_epoch + step)

            step += 1

        # Validate the model on the entire validation set.
        print("{} Start validation".format(datetime.now()))
        loss_val = 0.
        val_count = 0
        for _ in range(val_batches_per_epoch):
            (batch_input_img_val, batch_labels_val,
             batch_line_types_val) = val_generator.next_batch(batch_size)
            # Pickled files have labels in the endpoints format -> convert
            # them to center format.
            if read_as_pickle:
                batch_start_points_val = batch_labels_val[:, :3]
                batch_end_points_val = batch_labels_val[:, 3:6]
                batch_geometric_info_val = get_geometric_info(
                    start_points=batch_start_points_val,
                    end_points=batch_end_points_val,
                    line_parametrization=line_parametrization)
                batch_labels_val = get_label_with_line_center(
                    labels_batch=batch_labels_val)
            # Obtain validation loss.
            loss_current = sess.run(
                loss,
                feed_dict={
                    input_img: batch_input_img_val,
                    labels: batch_labels_val,
                    line_types: batch_line_types_val,
                    geometric_info: batch_geometric_info_val,
                    keep_prob: 1.
                })
            loss_val += loss_current
            val_count += 1
        if val_count != 0:
            loss_val = loss_val / val_count
            print("{} Average loss for validation set = {:.4f}".format(
                datetime.now(), loss_val))
            
        # Reset the file pointer of the image data generator at the end of
        # each epoch.
        val_generator.reset_pointer()
        train_generator.reset_pointer()

        print("{} Saving checkpoint of model...".format(datetime.now()))
        # Save checkpoint of the model.
        checkpoint_name = os.path.join(
            checkpoint_path,
            image_type + '_model_epoch' + str(epoch + 1) + '.ckpt')
        save_path = saver.save(sess, checkpoint_name)

        print("{} Model checkpoint saved at {}".format(
              datetime.now(), checkpoint_name))

        # The following is useful if one has no access to standard output.
        with open(
                os.path.join(log_files_folder, job_name,
                                "epochs_completed"), "aw") as f:
            f.write("Completed epoch {}\n".format(epoch + 1))

Mean of train set: [ 33.6551056   34.38172304  32.57309097 634.94343205]
Just set data_size to be 9103
Just set data_size to be 0
2019-01-29 15:50:54.364741 Start training...
2019-01-29 15:50:55.247954 Open Tensorboard at --logdir ./logs/290119_1550/triplet_loss_batch_all
Starting epoch is 0, num_epochs is 10
2019-01-29 15:50:55.248908 Epoch number: 1
