In [1]:
import tensorflow as tf
from keras.models import load_model
from keras.layers import *
from keras import Model, Sequential, Input
from keras import optimizers
from keras.regularizers import l2
import keras.backend as K
import numpy as np
from pprint import pprint
import pickle


def load(filepath):
    with open(filepath, 'rb') as f:
        return pickle.load(f)

dataTriple = load('obj/triplets.pkl')
dataPairs = load('obj/pairs.pkl')
triplets = dataTriple['triplets'] / 256.
pairs = dataPairs['pairs'] / 256.

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [5]:
def get_darknet():
    model = load_model('darknet.h5')
    for layer in model.layers:
        layer.trainable = False
    return model

def make_style_model():
    darknet = get_darknet()
    
    # pretrained layers
    inputs = darknet.layers[0].input
    x = darknet.layers[16].output
    
    # new layers
    x = Flatten()(x)
    x = Dense(256, activation='relu')(x)
    x = Dense(256, activation='relu')(x)
    x = Lambda(lambda x: tf.nn.l2_normalize(x, axis=0))(x)
    base_model = Model(inputs, x)
    
    # siamese networks
    input_a = Input((256,256,3))
    input_b = Input((256,256,3))
    input_c = Input((256,256,3))
    
    encoding_a = base_model(input_a)
    encoding_b = base_model(input_b)
    encoding_c = base_model(input_c)
    
    dist_pos = Lambda(sqEucl, output_shape=(1,))([encoding_a, encoding_b])
    dist_neg = Lambda(sqEucl, output_shape=(1,))([encoding_a, encoding_c])
    triple_loss = Lambda(triplet_loss, output_shape=(1,))([dist_pos, dist_neg])
    
    triple_model = Model([input_a, input_b, input_c], triple_loss)
    pairModel = Model([input_a, input_b], dist_pos)
    
    return triple_model, pairModel

def triplet_loss(x):
    return K.maximum(x[0] - x[1] + 1, 0)

def sqEucl(x):
    return K.sum(K.square(x[0] - x[1]), axis=-1, keepdims=True)

def norm(x):
    return K.sqrt(sqEucl(x))
    
def identity(y_true, y_pred):
    return K.mean(K.square(y_pred))

In [6]:
tripleModel, pairModel = make_style_model()

sgd = optimizers.SGD(lr=0.001, decay=1e-6, momentum=0.9, nesterov=True)
tripleModel.compile(optimizer=sgd, loss=identity)
pairModel.compile(optimizer=sgd, loss=identity)
alphas = np.zeros(triplets.shape[0])


tripleModel.fit([triplets[:, 0], triplets[:, 1], triplets[:, 2]], alphas, epochs=50)




Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50


<keras.callbacks.History at 0x7f0eeb592e80>

In [8]:

def predict(x, y):
    return pairModel.predict([x, y]).ravel()

for i in range(20):
    begin = i*20
    end = begin + 20
    predictions = predict(pairs[begin:end,0], pairs[begin:end,1]).reshape((-1,2))
    print(predictions.reshape((-1, 2)))

[[2.4950957  5.1264925 ]
 [3.8152025  3.6399188 ]
 [0.9264819  4.3501234 ]
 [1.2057012  5.349226  ]
 [1.242723   2.6187909 ]
 [1.5210639  1.945425  ]
 [0.9375831  2.1345413 ]
 [0.8826487  4.4403906 ]
 [0.91110337 2.4546547 ]
 [2.7326872  3.8913617 ]]
[[1.8716867  3.3427753 ]
 [0.69695514 3.4665778 ]
 [2.8598783  3.8320096 ]
 [1.3541038  4.5497518 ]
 [0.40543768 3.0174813 ]
 [1.4152956  2.8619452 ]
 [1.7945244  2.2570984 ]
 [0.968596   4.809819  ]
 [0.36610913 1.6447008 ]
 [0.32626522 4.873949  ]]
[[2.7475417 3.8120337]
 [1.2986265 1.9739528]
 [0.7770302 1.641121 ]
 [2.4982455 3.863014 ]
 [1.0114396 1.766274 ]
 [1.0013199 3.2714286]
 [2.2748902 6.1354713]
 [2.1386037 3.5767896]
 [1.3487297 2.646108 ]
 [1.4915531 2.3894005]]
[[0.793393   1.9127164 ]
 [2.5631135  6.202421  ]
 [1.4571168  4.110161  ]
 [1.0843265  2.3613775 ]
 [3.0512862  3.8455002 ]
 [2.2754607  2.260673  ]
 [1.0580051  3.600628  ]
 [0.44122577 5.1234956 ]
 [2.361735   1.9579437 ]
 [1.2477882  2.2233624 ]]
[[1.4938849  2.7