In [1]:
import tensorflow as tf
from pyramda import compose, curry
import os
import numpy as np
import sys
from sklearn.model_selection import train_test_split
from visuazlizers.nb_graph_visualizer import show_graph, rename_nodes
from utils.curried_functions import tf_add, tf_cast, tf_multiply, filter_list
from loaders import load_batch_of_images, load_model_pb
from siamese import compute_loss, get_anchor_positive_mask, get_negative_mask, train_siamese_model
from utils.metrics import cosine_distance

### Batch

In [2]:
def classes_to_labels(classes):
    """Assign one integer label per class.

    Args:
        classes: a sized sequence of class entries.

    Returns:
        The list ``[0, 1, ..., len(classes) - 1]``: the class at position ``i``
        in ``classes`` gets label ``i``.
    """
    class_count = len(classes)
    return list(range(class_count))

### Model

In [3]:
import sys
# Make the parent of the kernel's working directory importable so the
# `input.models...` import below resolves. The membership check keeps this
# cell idempotent: re-running it no longer appends duplicate ".." entries.
if ".." not in sys.path:
    sys.path.append("..")
from input.models.deep_sort_cnn.freeze_model import _preprocess, _network_factory

def create_deep_sort_cnn_graph():
    """Build the deep_sort CNN feature extractor in the current default graph.

    Creates a ``uint8`` placeholder named ``"images"`` and casts it to
    ``float32``. Each image is passed through ``_preprocess``, the batch is
    fed to the network returned by ``_network_factory()``, and the network's
    first output is exposed as an identity op named ``"features"``.

    Returns:
        tuple: ``(input_var, features)``, where ``input_var`` is the
        ``"images"`` placeholder of shape ``(None, 128, 64, 3)`` and
        ``features`` is the ``"features"`` output tensor.
    """
    input_var = tf.placeholder(tf.uint8, (None, 128, 64, 3), name="images")

    # tf.map_fn applies _preprocess to each image of the batch (along axis 0).
    # Its input comes straight from the placeholder, so nothing upstream needs
    # gradients and back-propagation through this step is disabled.
    image_var = tf.map_fn(
        _preprocess,
        tf.cast(input_var, tf.float32),
        back_prop=False,
    )

    # NOTE(review): with default arguments the factory appears to create
    # activation summaries (see the "Summary name ... is illegal" INFO logs
    # in the training cell output) — confirm that is intended.
    factory_fn = _network_factory()
    # The network's second return value is not used here.
    features, _ = factory_fn(image_var, reuse=None)
    # Expose the embedding under a stable name so it can be fetched as "features".
    features = tf.identity(features, name="features")

    return input_var, features

### Training

In [4]:
# Start from an empty default graph so re-running this cell does not add a
# second copy of the model's ops to the previously built graph.
tf.reset_default_graph()

# Run configuration (previously inline literals).
NUM_CLASSES = 20            # only the first NUM_CLASSES class entries are used (quick run)
VAL_FRACTION = 0.1          # fraction of those classes held out for validation
SPLIT_SEED = 42             # makes the train/validation split reproducible across re-runs
LEARNING_RATE = 0.00001
NUM_ITER = 2
IMAGE_SHAPE = (128, 64, 3)  # must match the "images" placeholder in create_deep_sort_cnn_graph

source_path = '../input/mars/bbox_train/'
# os.listdir returns entries in arbitrary order; sort them so that the
# entry -> integer label mapping and the NUM_CLASSES subset are the same on
# every machine and every run.
# filter_list(['.DS_Store'], False) presumably drops '.DS_Store' entries —
# it is project-defined, confirm in utils.curried_functions.
dirs = sorted(compose(
    filter_list(['.DS_Store'], False),
    os.listdir,
)(source_path))

labels = classes_to_labels(dirs)
# The split is over class entries, so validation classes are unseen in training.
train_dirs, val_dirs, train_labels, val_labels = train_test_split(
    dirs[:NUM_CLASSES],
    labels[:NUM_CLASSES],
    test_size=VAL_FRACTION,
    random_state=SPLIT_SEED,
)

# load_model_pb is project-defined; its third return value is not used here.
inputs, outputs, _ = load_model_pb(
    '../input/models/deep_sort_cnn/mars-small128.pb',
    input_name="images",
    output_name="features",
    graph_creator=create_deep_sort_cnn_graph,
)

# NOTE(review): this session is never closed; call session.close() once no
# later cell needs it.
session = tf.Session()

train_siamese_model(
    session=session,
    model=[inputs, outputs],
    source_path=source_path,
    dirs=(train_dirs, val_dirs),
    class_labels=(train_labels, val_labels),
    metric=cosine_distance,
    optimizer=tf.train.AdamOptimizer(learning_rate=LEARNING_RATE),
    batch_loader=load_batch_of_images(image_shape=IMAGE_SHAPE),
    num_iter=NUM_ITER,
)

INFO:tensorflow:Summary name conv2_1/1/Elu:0/activations is illegal; using conv2_1/1/Elu_0/activations instead.
INFO:tensorflow:Summary name conv2_3/1/Elu:0/activations is illegal; using conv2_3/1/Elu_0/activations instead.
INFO:tensorflow:Summary name conv3_1/1/Elu:0/activations is illegal; using conv3_1/1/Elu_0/activations instead.
INFO:tensorflow:Summary name conv3_3/1/Elu:0/activations is illegal; using conv3_3/1/Elu_0/activations instead.
INFO:tensorflow:Summary name conv4_1/1/Elu:0/activations is illegal; using conv4_1/1/Elu_0/activations instead.
INFO:tensorflow:Summary name conv4_3/1/Elu:0/activations is illegal; using conv4_3/1/Elu_0/activations instead.


100%|██████████| 18/18 [00:00<00:00, 20.65it/s]
100%|██████████| 2/2 [00:00<00:00, 28.96it/s]

Training loss 0.199428



 22%|██▏       | 4/18 [00:00<00:00, 30.14it/s]

Validation loss 0.2


100%|██████████| 18/18 [00:00<00:00, 20.51it/s]
100%|██████████| 2/2 [00:00<00:00, 21.26it/s]

Training loss 0.200007
Validation loss 0.2



