In [39]:
from PIL import Image
from keras.layers import *
from keras.applications import *
from keras import *
import numpy as np
import tensorflow as tf

In [40]:
# Square resolution every image is resized to; also the encoder's input size.
IMG_HEIGHT, IMG_WIDTH = 96, 96

In [41]:
def load_image(path):
    """Read ``food/<path>.jpg`` and return a normalised, resized image tensor.

    NOTE: a later cell redefines ``load_image(img, training)``, which shadows
    this version once that cell has run.

    Args:
        path: image id (str or scalar string tensor) without directory/extension.

    Returns:
        float32 tensor of shape (IMG_HEIGHT, IMG_WIDTH, 3) scaled to [-1, 1].
    """
    path = "food/" + path + ".jpg"
    # decode_jpeg expects the encoded file *contents*, not a filename.
    image = tf.image.decode_jpeg(tf.io.read_file(path), channels=3)
    image = tf.cast(image, tf.float32)
    # Map pixel range [0, 255] -> [-1, 1].
    image = image / 127.5 - 1
    image = tf.image.resize(image, (IMG_HEIGHT, IMG_WIDTH))
    return image

def load_triplet(sample):
    """Parse one whitespace-separated 'anchor positive negative' line into images.

    Args:
        sample: scalar string tensor holding three image ids.

    Returns:
        Tensor of the three loaded images stacked on a new leading axis
        (anchor, positive, negative).

    NOTE(review): this calls the one-argument ``load_image(path)``. A later cell
    redefines ``load_image(img, training)``, and globals are resolved at call
    time. After that cell has run, this function raises TypeError.
    """
    ids = tf.strings.split(sample)
    base_image = load_image(ids[0])
    true_image = load_image(ids[1])
    false_image = load_image(ids[2])
    return tf.stack([base_image, true_image, false_image], axis=0)

def create_dataset(filename):
    """Build a dataset of stacked image triplets from a triplet text file.

    NOTE: a later cell redefines ``create_dataset(dataset_filename, training)``,
    which shadows this version once that cell has run.

    Args:
        filename: path to a text file with one 'anchor positive negative' line
            per triplet.

    Returns:
        tf.data.Dataset yielding the output of ``load_triplet`` for each line.
    """
    dataset = tf.data.TextLineDataset(filename)
    # Dataset.map returns a NEW dataset; it must be reassigned or it is lost.
    dataset = dataset.map(load_triplet,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return dataset

In [45]:
def extract(predictions):
    """Split stacked embeddings and compute squared Euclidean distances.

    Args:
        predictions: tensor whose last axis indexes (anchor, correct, wrong),
            as produced by ``tf.stack(..., axis=-1)`` in ``create_model``.

    Returns:
        (distance_correct, distance_false): squared L2 distances from the
        anchor to the correct and to the wrong embedding, summed over axis 1.
    """
    anchor = predictions[..., 0]
    positive = predictions[..., 1]
    negative = predictions[..., 2]

    def _squared_distance(other):
        return tf.reduce_sum(tf.square(anchor - other), 1)

    return _squared_distance(positive), _squared_distance(negative)

def triplet_loss(_, predictions):
    """FaceNet-style triplet loss with a softplus instead of a hard margin.

    The first argument (Keras' y_true) is ignored; only the stacked embeddings
    are used.

    Returns:
        Batch mean of softplus(d(anchor, correct) - d(anchor, wrong)).
    """
    d_pos, d_neg = extract(predictions)
    margin_violation = d_pos - d_neg
    return tf.reduce_mean(tf.math.softplus(margin_violation))
   

def accuracy(_, predictions):
    """Fraction of triplets whose anchor is strictly closer to the correct image.

    The first argument (Keras' y_true) is ignored.

    Returns:
        Scalar float32: mean over the batch of [d(anchor, correct) < d(anchor, wrong)].
    """
    distance_correct, distance_false = extract(predictions)
    # A triplet is right when the positive is closer than the negative.
    # (greater_equal would count the failures, i.e. the error rate.)
    is_right = tf.less(distance_correct, distance_false)
    return tf.reduce_mean(tf.cast(is_right, tf.float32))

In [46]:
def load_image(img, training):
    """Decode JPEG bytes into a normalised, resized float32 image.

    Args:
        img: scalar string tensor holding the encoded JPEG file contents.
        training: when truthy, apply random left-right and up-down flips.

    Returns:
        float32 tensor of shape (IMG_HEIGHT, IMG_WIDTH, 3) scaled to [-1, 1].
    """
    decoded = tf.image.decode_jpeg(img, channels=3)
    scaled = tf.cast(decoded, tf.float32) / 127.5 - 1
    resized = tf.image.resize(scaled, (IMG_HEIGHT, IMG_WIDTH))
    if not training:
        return resized
    flipped = tf.image.random_flip_left_right(resized)
    return tf.image.random_flip_up_down(flipped)


def load_triplets(triplet, training):
    """Turn one 'anchor positive negative' text line into a stacked image triplet.

    Args:
        triplet: scalar string tensor with three whitespace-separated image ids.
        training: forwarded to ``load_image`` (enables augmentation) and also
            selects the return format.

    Returns:
        When training: ``(images, 1)``. The label is a dummy because the loss
        ignores y_true. Otherwise: ``images`` only. ``images`` stacks
        anchor/positive/negative along axis 0.
    """
    ids = tf.strings.split(triplet)

    def _read(image_id):
        return load_image(tf.io.read_file('food/' + image_id + '.jpg'), training)

    images = tf.stack([_read(ids[0]), _read(ids[1]), _read(ids[2])], axis=0)
    return (images, 1) if training else images
    
def create_dataset(dataset_filename, training=True):
    """Build a tf.data pipeline of image triplets from a triplet text file.

    Args:
        dataset_filename: text file with one 'anchor positive negative' line
            per triplet.
        training: forwarded to ``load_triplets`` (augmentation + dummy label).

    Returns:
        tf.data.Dataset mapped (in parallel) through ``load_triplets``.
    """
    def _parse(line):
        return load_triplets(line, training)

    lines = tf.data.TextLineDataset(dataset_filename)
    return lines.map(_parse, num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [47]:
def create_model():
    """Build and compile a weight-shared triplet network on a frozen MobileNetV2.

    Input: stacked triplets shaped (batch, 3, IMG_HEIGHT, IMG_WIDTH, 3).
    Each of the three images is embedded by the same frozen MobileNetV2
    backbone followed by a small trainable dense head. The three 256-d
    embeddings are stacked on the last axis, giving output (batch, 256, 3).

    Returns:
        A compiled ``Model``: Adam with lr=0.001, ``triplet_loss`` as the
        loss, and ``accuracy`` as the metric. It prints the model summary as a
        side effect.
    """
    triplets = Input(shape=(3, IMG_HEIGHT, IMG_WIDTH, 3))

    # Slice out the anchor / correct / wrong images, in that order.
    branches = [triplets[:, idx, ...] for idx in range(3)]

    backbone = MobileNetV2(include_top=False, input_shape=(IMG_HEIGHT, IMG_WIDTH, 3))
    backbone.trainable = False  # only the dense head below is trained

    head = Sequential([
        GlobalAveragePooling2D(),
        Dense(512, activation="relu"),
        BatchNormalization(),
        Dense(256, activation="relu"),
        BatchNormalization(),
        Dense(256),
    ])

    # The same backbone and head (shared weights) embed all three images.
    embeddings = [head(backbone(branch)) for branch in branches]
    stacked = tf.stack(embeddings, axis=-1)

    model = Model(inputs=triplets, outputs=stacked)
    model.summary()
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss=triplet_loss,
        metrics=[accuracy],
    )
    return model

In [None]:
def main(filename='train_triplets.txt', batch_size=32, epochs=10):
    """Train the triplet model and return it together with its training history.

    Args:
        filename: triplet text file (one 'anchor positive negative' per line).
        batch_size: triplets per training batch.
        epochs: number of passes over the data.

    Returns:
        (model, history): the trained Keras model and the ``History`` from ``fit``.
    """
    # The dataset is repeated indefinitely, so Keras needs an explicit epoch
    # length. Derive it from the data size instead of a hard-coded step count.
    with open(filename) as f:
        n_triplets = sum(1 for line in f if line.strip())
    steps_per_epoch = int(np.ceil(n_triplets / batch_size))

    # NOTE(review): shuffling after decoding buffers 1024 decoded image
    # triplets in memory; shuffling the text lines first would be cheaper.
    train_dataset = (
        create_dataset(filename)
        .shuffle(1024, reshuffle_each_iteration=True)
        .repeat()
        .batch(batch_size)
        .prefetch(tf.data.experimental.AUTOTUNE)  # overlap input I/O with training
    )
    model = create_model()

    history = model.fit(
        train_dataset,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
    )
    return model, history

model, history = main()

Model: "model_12"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_25 (InputLayer)           [(None, 3, 96, 96, 3 0                                            
__________________________________________________________________________________________________
tf.__operators__.getitem_36 (Sl (None, 96, 96, 3)    0           input_25[0][0]                   
__________________________________________________________________________________________________
tf.__operators__.getitem_37 (Sl (None, 96, 96, 3)    0           input_25[0][0]                   
__________________________________________________________________________________________________
tf.__operators__.getitem_38 (Sl (None, 96, 96, 3)    0           input_25[0][0]                   
___________________________________________________________________________________________