In [1]:
import json
import math
import multiprocessing as mp
import os
from functools import partial, update_wrapper

import numpy as np
import pandas as pd
import PIL
import tensorflow as tf
import tensorflow.keras.backend as K
import tqdm
from annoy import AnnoyIndex
from PIL import Image
from sklearn.preprocessing import MultiLabelBinarizer
from tensorflow import keras
from tensorflow.keras import optimizers
from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input
from tensorflow.keras.layers import Dense, Lambda
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing import image

In [2]:
# Tunable settings shared by every cell below.
img_size = 299     # input image size for network
margin = 0.3       # margin for triplet loss
# Must be a multiple of 3: the loss reads rows as (anchor, positive, negative)
# and the training generator packs batch_size // 3 triplets per batch.
batch_size = 42    # size of mini-batch
num_triplets = 700 # triplets drawn by the sampler on every resample
valid_frequency = 100  # run validation and save weights every this many epochs
num_epoch = 3600   # number of iterations of the training loop
log_dir = './logs/triplet_semihard_v3'  # TensorBoard log directory
checkpoint_dir = 'checkpoints/'  # directory the weight checkpoints are written to
recall_log_file = './logs/recall_triplet_semihard_v3.json'  # JSON history of validation recall
recall_values = [1, 3, 5, 10, 25, 50, 100]  # k values for recall@k
image_dir = '../../../DeepFashion2_Dataset/'  # prefix prepended to every image path before opening

In [3]:
# utility function to freeze some portion of a function's arguments
def wrapped_partial(func, *args, **kwargs):
    """Return ``partial(func, *args, **kwargs)`` carrying ``func``'s metadata.

    A plain ``functools.partial`` object has no ``__name__``; copying the
    wrapped function's attributes onto it with ``update_wrapper`` gives the
    partial the same name and docstring as ``func``.
    """
    # update_wrapper returns the wrapper it was given, so this is the partial.
    return update_wrapper(partial(func, *args, **kwargs), func)

# calculate recall score
def recall(y_true, y_pred):
    """Return 1 if any element of ``y_pred`` is also in ``y_true``, else 0.

    This is a hit/miss indicator for a single query, not a fraction: averaging
    it over queries gives recall@k when ``y_pred`` holds the top-k results.
    """
    return 1 if set(y_true).intersection(y_pred) else 0

# margin triplet loss
def margin_triplet_loss(y_true, y_pred, margin, batch_size):
    """Mean hinge triplet loss over a batch laid out as (a, p, n, a, p, n, ...).

    Rows 0, 3, 6, ... of ``y_pred`` are taken as anchors, rows 1, 4, 7, ... as
    positives and rows 2, 5, 8, ... as negatives, up to ``batch_size`` rows.
    ``y_true`` is not used; it is only present because Keras passes it.

    Per triplet: max(margin + ||a - p||^2 - ||a - n||^2, 0).
    """
    anchors, positives, negatives = [
        tf.gather(y_pred, tf.range(start, batch_size, 3)) for start in (0, 1, 2)
    ]

    pos_sq_dist = K.sum(K.square(anchors - positives), axis=1)
    neg_sq_dist = K.sum(K.square(anchors - negatives), axis=1)

    per_triplet = K.maximum(margin + pos_sq_dist - neg_sq_dist, 0.0)
    return K.mean(per_triplet)

### Define model architecture

In [12]:
# def get_model():
#     no_top_model = InceptionV3(include_top=False, weights='imagenet', pooling='avg')

#     x = no_top_model.output
#     x = Dense(512, activation='elu', name='fc1')(x)
#     x = Dense(128, name='fc2')(x)
#     x = Lambda(lambda x: K.l2_normalize(x, axis=1), name='l2_norm')(x)
#     return Model(no_top_model.inputs, x)

In [13]:
# model = get_model()
# model

In [10]:
def get_trained_model():
    """Load and return the Keras model saved at ./models/triplet_semihard_final.h5."""
    return tf.keras.models.load_model('./models/triplet_semihard_final.h5')

In [11]:
# Load the saved network once. `model` is the module-level object that the
# later cells compile, train, and use for prediction (including the sampler).
model = get_trained_model()
model



<tensorflow.python.keras.engine.training.Model at 0x1a39e526d0>

### Compile model

In [14]:
# `learning_rate` is the current argument name; `lr` is a deprecated alias.
opt = keras.optimizers.Adam(learning_rate=0.0001)
# Keras calls the loss as loss(y_true, y_pred), so margin and batch_size are
# frozen in with wrapped_partial (which also keeps the function's __name__).
model.compile(loss=wrapped_partial(margin_triplet_loss, margin=margin, batch_size=batch_size), optimizer=opt)

In [15]:
model

<tensorflow.python.keras.engine.training.Model at 0x1a39e526d0>

### Load data

I used MongoDB to store the dataset (note that the cells below load the data from a CSV file instead). Document structure:
- type: which split the document belongs to; possible values are 'train', 'val' and 'test'
- seller_img: list of IDs of the seller's images
- user_img: list of IDs of the user's images

Example:
```javascript
{
    'type': 'test',
    'seller_img': ['HTB1s7ZiLFXXXXatXFXXq6xXFXXXu', 
                'HTB1pGaAKXXXXXczXXXXq6xXFXXXN'],
    'user_img': ['UTB8KtbUXXPJXKJkSahVq6xyzFXaE',
                'UTB8OeDUXgnJXKJkSaelq6xUzXXag',
                'UTB8h7HUXnzIXKJkSafVq6yWgXXap',
                'UTB8auL6XevJXKJkSajhq6A7aFXaa',
                'UTB8rrevXevJXKJkSajhq6A7aFXa5',
                'UTB8WUCuXXPJXKJkSahVq6xyzFXa6',
                'UTB8MHmvXXfJXKJkSamHq6zLyVXa1']
}
```

In [16]:
# One row per annotated clothing item. pandas and numpy are imported in the
# first cell, so they are not re-imported here.
dt_all = pd.read_csv('deepfashion_retreival_train_val.csv')
# Path of the image relative to image_dir: "<split>/image/<file name>".
dt_all['image_path'] = dt_all['val'].map(str) + '/image/' + dt_all['image_id'].map(str)
dt_all.tail()

Unnamed: 0,image_id,category_id,category_name,bounding_box,style,pair_id,source,val,pair_unique_id,image_path
364671,022559.jpg,1,short sleeve top,"[0, 0, 433, 505]",1,1318,shop,validation,1318-1-1,validation/image/022559.jpg
364672,022559.jpg,8,trousers,"[145, 373, 466, 701]",0,1318,shop,validation,1318-8-0,validation/image/022559.jpg
364673,016905.jpg,1,short sleeve top,"[304, 3, 586, 324]",2,670,shop,validation,670-1-2,validation/image/016905.jpg
364674,016905.jpg,7,shorts,"[341, 285, 556, 514]",0,670,shop,validation,670-7-0,validation/image/016905.jpg
364675,017617.jpg,12,vest dress,"[275, 227, 527, 823]",1,738,shop,validation,738-12-1,validation/image/017617.jpg


In [24]:
# dt_all.category_name.unique()

In [25]:
# Build two parallel lists: entry i of train_seller_images / train_user_images
# holds the shop / user image paths of the i-th item pair. Only pairs that have
# at least one image from each source are kept.
train_seller_images = []
train_user_images = []
# 221563 rows
categories = ['long sleeve top', 'short sleeve top',
       'vest dress', 'vest', 'short sleeve dress',
       'sling dress', 'long sleeve outwear', 'long sleeve dress', 'sling',
       'short sleeve outwear']
# 145420 rows
# categories = ['shorts', 'skirt', 'trousers', 'sling']

# NOTE(review): dt_all is not restricted to the train split here, so rows whose
# `val` column is 'validation' also contribute to these training pairs.
# Confirm this is intended.
train_pair_ids = dt_all[dt_all.category_name.isin(categories)].pair_unique_id.unique()

# Group the image paths once per source instead of re-scanning the whole frame
# twice for every pair id. groupby keeps the original row order inside a group.
train_candidates = dt_all[dt_all.pair_unique_id.isin(train_pair_ids)]
train_seller_by_pair = {
    pair_id: paths.values
    for pair_id, paths in train_candidates[train_candidates.source == 'shop'].groupby('pair_unique_id')['image_path']
}
train_user_by_pair = {
    pair_id: paths.values
    for pair_id, paths in train_candidates[train_candidates.source == 'user'].groupby('pair_unique_id')['image_path']
}

# Iterate in first-appearance order of the pair ids, as before.
for pair_unique_id in train_pair_ids:
    seller_imgs = train_seller_by_pair.get(pair_unique_id)
    user_imgs = train_user_by_pair.get(pair_unique_id)
    # A missing key means the pair has no image from that source.
    if seller_imgs is not None and user_imgs is not None:
        train_seller_images.append(seller_imgs)
        train_user_images.append(user_imgs)


In [27]:
len(train_user_images)

15597

```javascript
train_seller_images = [
                        ['HTB1s7ZiLFXXXXatXFXXq6xXFXXXu', 
                        'HTB1pGaAKXXXXXczXXXXq6xXFXXXN'],
                        ['HTB1s7ZiLFXXXXatXFXXq6xXFXXXu', 
                        'HTB1pGaAKXXXXXczXXXXq6xXFXXXN'],
                        ...
                      ]

train_user_images = [
                        ['UTB8KtbUXXPJXKJkSahVq6xyzFXaE',
                        'UTB8OeDUXgnJXKJkSaelq6xUzXXag',
                        'UTB8h7HUXnzIXKJkSafVq6yWgXXap',
                        'UTB8auL6XevJXKJkSajhq6A7aFXaa',
                        'UTB8rrevXevJXKJkSajhq6A7aFXa5',
                        'UTB8WUCuXXPJXKJkSahVq6xyzFXa6',
                        'UTB8MHmvXXfJXKJkSamHq6zLyVXa1'],
                        ['UTB8KtbUXXPJXKJkSahVq6xyzFXaE',
                        'UTB8OeDUXgnJXKJkSaelq6xUzXXag',
                        'UTB8h7HUXnzIXKJkSafVq6yWgXXap',
                        'UTB8WUCuXXPJXKJkSahVq6xyzFXa6',
                        'UTB8MHmvXXfJXKJkSamHq6zLyVXa1'],
                        ...
                      ]
```

In [30]:
# train_user_images[0:10]

In [31]:
# Test MultiLabelBinarizer
mlb = MultiLabelBinarizer()
mlb.fit_transform(train_user_images[0:10])

array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 

In [32]:
mlb.classes_

array(['train/image/013850.jpg', 'train/image/015077.jpg',
       'train/image/015078.jpg', 'train/image/025419.jpg',
       'train/image/046108.jpg', 'train/image/046109.jpg',
       'train/image/046110.jpg', 'train/image/046111.jpg',
       'train/image/046112.jpg', 'train/image/046113.jpg',
       'train/image/046114.jpg', 'train/image/046115.jpg',
       'train/image/046117.jpg', 'train/image/046118.jpg',
       'train/image/058095.jpg', 'train/image/092998.jpg',
       'train/image/092999.jpg', 'train/image/093000.jpg',
       'train/image/093002.jpg', 'train/image/093003.jpg',
       'train/image/115610.jpg', 'train/image/115611.jpg',
       'train/image/115612.jpg', 'train/image/115613.jpg',
       'train/image/115614.jpg', 'train/image/115615.jpg',
       'train/image/115616.jpg', 'train/image/115617.jpg',
       'train/image/115618.jpg', 'train/image/115619.jpg',
       'train/image/115620.jpg', 'train/image/115621.jpg',
       'train/image/115622.jpg', 'train/image/115623.jpg

In [33]:
# Sparse output keeps the pair-by-image indicator matrices compact when fitted
# on the full training set. MultiLabelBinarizer is already imported in the
# first cell, so it is not re-imported here.
mlb = MultiLabelBinarizer(sparse_output=True)

In [34]:
# Sparse indicator matrix over seller images: row i is pair i, and the columns
# follow mlb.classes_ as it stands right after this fit.
X_train_seller = mlb.fit_transform(train_seller_images)

In [35]:
# Column labels (image paths) of X_train_seller. This must be read before mlb
# is refit on the user images, because refitting replaces mlb.classes_.
train_images_seller = mlb.classes_

In [36]:
# Refits the same binarizer on the user images; row i is still pair i, but
# mlb.classes_ now holds the user image paths.
X_train_user = mlb.fit_transform(train_user_images)

In [37]:
# Column labels (image paths) of X_train_user.
train_images_user = mlb.classes_

In [38]:
# Number of user images per pair (row sums of the indicator matrix), flattened
# to 1-D. The sampler normalises this into class-sampling probabilities.
num_images_by_class = np.asarray(X_train_user.sum(axis=1)).ravel()

```javascript
X_train_seller = [[0,0,0,1], [1,0,0,0]...]
```

In [39]:
# val_images = [(x['seller_img'], x['user_img']) for x in coll.find({'type': 'val'})]
# val_image_to_class = {}
# for i, item in enumerate(val_images):
#     for x in item[0] + item[1]:
#         val_image_to_class[x] = i

# val_images_clean = [(x['seller_img_clean'], x['user_img_clean']) for x in coll.find({'type': 'val', 'clean': True})]
# val_image_clean_to_class = {}
# for i, item in enumerate(val_images_clean):
#     for x in item[0] + item[1]:
#         val_image_clean_to_class[x] = i

# test_images = [(x['seller_img'], x['user_img']) for x in coll.find({'type': 'test'})]

# seller_images = [x for item in coll.find({}) for x in item['seller_img']]
# val_user_images = [x for item in val_images for x in item[1]]
# val_user_clean_images = [x for item in val_images_clean for x in item[1]]

In [None]:
# Validation pairs: val_images[i] = (seller image paths, user image paths) of
# the i-th validation pair that has at least one image from each source.
val_images = []
# Fixed: the original used `=` (assignment) here, which is a SyntaxError.
dt_val = dt_all[dt_all.val == 'validation']

val_pair_ids = dt_val[dt_val.category_name.isin(categories)].pair_unique_id.unique()

# Group the image paths once per source instead of re-scanning dt_val twice for
# every pair id. groupby keeps the original row order inside a group.
val_candidates = dt_val[dt_val.pair_unique_id.isin(val_pair_ids)]
val_seller_by_pair = {
    pair_id: paths.values
    for pair_id, paths in val_candidates[val_candidates.source == 'shop'].groupby('pair_unique_id')['image_path']
}
val_user_by_pair = {
    pair_id: paths.values
    for pair_id, paths in val_candidates[val_candidates.source == 'user'].groupby('pair_unique_id')['image_path']
}

for pair_unique_id in val_pair_ids:
    seller_imgs = val_seller_by_pair.get(pair_unique_id)
    user_imgs = val_user_by_pair.get(pair_unique_id)
    # A missing key means the pair has no image from that source.
    if seller_imgs is not None and user_imgs is not None:
        val_images.append((seller_imgs, user_imgs))

# Map every validation image path to the index of its pair in val_images.
# Fixed: the original iterated over `item[0] + item[1]`; both are numpy arrays,
# so `+` adds them elementwise (or fails to broadcast) instead of concatenating.
# NOTE(review): a path that occurs in several pairs keeps the last pair index.
val_image_to_class = {}
for i, item in enumerate(val_images):
    for image_group in item:
        for x in image_group:
            val_image_to_class[x] = i

# Retrieval gallery: every shop image path in dt_all.
# NOTE(review): this is taken from all rows of dt_all and is not de-duplicated,
# so a path that appears in several rows is indexed several times -- confirm.
seller_images = dt_all[dt_all.source == 'shop'].image_path.values
# Query set: all user images of the validation pairs, flattened.
val_user_images = [x for item in val_images for x in item[1]]

### Data loading utils

In [40]:
# load image by id without augmentations
def preprocess_image_worker(media_id):
    """Load one image and return it as a preprocessed batch of size 1.

    ``media_id`` is the image path relative to ``image_dir``. The image is
    converted to RGB, resized to ``img_size`` x ``img_size``, turned into an
    array with a leading batch axis and passed through the InceptionV3
    ``preprocess_input``.
    """
    rgb = Image.open(image_dir + media_id).convert('RGB')
    resized = rgb.resize((img_size, img_size))

    pixels = image.img_to_array(resized)
    single_batch = np.expand_dims(pixels, axis=0)
    return preprocess_input(single_batch)

# preprocess_image_worker('validation/image/010779.jpg')

In [41]:
# load image by id with augmentations
def preprocess_image_worker_aug(media_id):
    """Load one image with random augmentation; return a preprocessed batch of 1.

    Augmentations, in order:
    1. random crop that removes up to 5% of the width/height from each border;
    2. with probability 1/2, one randomly chosen flip, rotation or transpose.
    The result is resized to ``img_size`` x ``img_size`` and passed through the
    InceptionV3 ``preprocess_input``.
    """
    img = Image.open(image_dir + media_id).convert('RGB')

    # The four draws happen in box order: left, upper, right, lower.
    left = int(np.random.uniform(0, 0.05)*img.width)
    upper = int(np.random.uniform(0, 0.05)*img.height)
    right = int(np.random.uniform(0.95, 1.)*img.width)
    lower = int(np.random.uniform(0.95, 1.)*img.height)
    img = img.crop((left, upper, right, lower))

    if np.random.randint(2) == 0:
        transpose_methods = [
            PIL.Image.FLIP_LEFT_RIGHT,
            PIL.Image.FLIP_TOP_BOTTOM,
            PIL.Image.ROTATE_90,
            PIL.Image.ROTATE_180,
            PIL.Image.ROTATE_270,
            PIL.Image.TRANSPOSE,
        ]
        img = img.transpose(np.random.choice(transpose_methods))
    img = img.resize((img_size, img_size))

    pixels = image.img_to_array(img)
    single_batch = np.expand_dims(pixels, axis=0)
    return preprocess_input(single_batch)

# preprocess_image_worker_aug('validation/image/010779.jpg')

In [42]:
# data generator, divides list images into mini-batches
def batch_generator_predict(pool, batch_size, images):
    """Yield preprocessed image batches for prediction, without augmentation.

    ``images`` is sliced into consecutive chunks of ``batch_size`` paths; each
    chunk is loaded in parallel through ``pool.map`` and stacked along axis 0.
    The generator never terminates: once ``images`` is exhausted it keeps
    yielding empty arrays of shape (0, img_size, img_size, 3).
    """
    offset = 0
    while True:
        chunk = images[offset:offset + batch_size]
        offset += batch_size
        if len(chunk) > 0:
            loaded = pool.map(preprocess_image_worker, chunk)
            yield np.concatenate(loaded, axis=0)
        else:
            yield np.zeros((0, img_size, img_size, 3))

### Semi-hard negative mining

In [43]:
class SemiHardNegativeSampler:
    """Samples (anchor, positive, negative) image triplets with semi-hard negatives.

    Anchors are user images, positives are seller images of the same pair, and
    negatives are seller images of a different pair chosen from the current
    embedding distances. Reads the module-level ``model``, ``X_train_user``,
    ``X_train_seller``, ``train_images_user``, ``train_images_seller`` and
    ``num_images_by_class``.

    Attributes set by ``resample``:
        triplets: array of shape (num_samples, 3) holding image paths in
            (anchor, positive, negative) order.
        dists: for every triplet, squared anchor-negative distance minus
            squared anchor-positive distance.
    """

    def __init__(self, pool, batch_size, num_samples):
        """Store the settings and draw the first set of triplets.

        Args:
            pool: multiprocessing pool used to load images in parallel.
            batch_size: number of images per training batch; the generator
                packs ``batch_size // 3`` triplets into each batch.
            num_samples: number of triplets drawn on every ``resample``.
        """
        self.pool = pool
        self.batch_size = batch_size
        self.num_samples = num_samples
        self.resample()

    # sample triplets with semi-hard negatives
    def resample(self):
        """Draw ``num_samples`` pairs, embed them and pick a negative for each."""
        # Pairs are drawn with probability proportional to their number of user images.
        sample_classes = np.random.choice(np.arange(X_train_user.shape[0]), p=num_images_by_class/num_images_by_class.sum(), size=self.num_samples)

        # Interleaved layout: [user_0, seller_0, user_1, seller_1, ...].
        sample_images = []
        for i in sample_classes:
            sample_images.append(train_images_user[np.random.choice(X_train_user[i].nonzero()[1], replace=False)])
            sample_images.append(train_images_seller[np.random.choice(X_train_seller[i].nonzero()[1], replace=False)])
        sample_images = np.array(sample_images)

        # Use the pool handed to the constructor rather than the global `pool`.
        pred_sample = model.predict_generator(batch_generator_predict(self.pool, 32, sample_images), math.ceil(len(sample_images)/32),
                                              workers=1)

        a = pred_sample[0::2]   # anchor (user) embeddings
        p = pred_sample[1::2]   # positive (seller) embeddings
        # Row i = (user path, seller path) of sample i; loop-invariant, so built once.
        image_pairs = sample_images.reshape((-1, 2))
        triplets = []
        self.dists = []

        for i in range(self.num_samples):
            d = np.square(a[i] - p[i]).sum()
            # Candidates are the seller images of every sample from a different pair.
            neg_sample_classes = (sample_classes != sample_classes[i]).nonzero()[0]

            neg = p[neg_sample_classes]
            neg_ids = image_pairs[neg_sample_classes, 1]

            d_neg = np.square(neg - a[i]).sum(axis=1)

            # Semi-hard: farther from the anchor than the positive is.
            semihard = np.where(d_neg > d)[0]

            if len(semihard) == 0:
                # Every candidate is at least as close as the positive: take the
                # farthest one, i.e. the one nearest to the positive distance.
                n = np.argmax(d_neg)
            else:
                # Closest of the semi-hard candidates.
                n = semihard[np.argmin(d_neg[semihard])]

            self.dists.append(d_neg[n]-d)

            triplets.append(np.concatenate([image_pairs[i], np.array([neg_ids[n]])]))

        self.triplets = np.array(triplets)

    # data generator for triplets
    def batch_generator(self):
        """Yield ``(X, y)`` training batches built from the current triplets.

        Each batch holds ``batch_size // 3`` triplets flattened to
        (a, p, n, a, p, n, ...), loaded with augmentation. ``y`` is an all-zero
        placeholder because the loss ignores ``y_true``. The generator never
        terminates: once the triplets are exhausted it yields empty batches.
        """
        triplets_per_batch = self.batch_size//3
        i = 0
        while True:
            batch = self.triplets[i:i+triplets_per_batch].ravel()

            i += triplets_per_batch
            if len(batch) == 0:
                # Same (X, y) structure as a regular batch; the original yielded
                # a bare array here, which does not match the other branch.
                yield np.zeros((0, img_size, img_size, 3)), np.zeros(0)
            else:
                result = self.pool.map(preprocess_image_worker_aug, batch)
                X_batch = np.concatenate(result, axis=0)
                yield X_batch, np.zeros(len(batch))

    # return data generator for triplets
    def get_generator(self):
        """Return a fresh batch generator positioned at the first triplet."""
        return self.batch_generator()

### Create pool of processes for parallel data loading

In [44]:
# num_triplets = len(train_user_images) # Need to check
# Worker processes used through pool.map to load and preprocess images in parallel.
# NOTE(review): the pool is never closed or joined in the visible cells.
pool = mp.Pool(processes=8)
# Constructing the sampler immediately calls resample(), which runs a
# prediction pass with `model` over the sampled images.
sampler = SemiHardNegativeSampler(pool, batch_size, num_triplets)

### Attach TensorBoard to monitor the learning process

In [45]:
# TensorBoard callback writing to log_dir: histograms are computed every epoch,
# while graph and image summaries are switched off.
tensorboard = keras.callbacks.TensorBoard(log_dir=log_dir,
                 histogram_freq=1, 
                 write_graph=False, 
                 write_images=False)

### Training process
In each epoch we:
- train the model on triplets with semi-hard negatives from `sampler`
- resample the triplets
- run validation and save the model every `valid_frequency` epochs

I use an Annoy index for nearest-neighbor search to speed up validation.

In [46]:
# Global epoch counter. The training loop passes it to Keras as initial_epoch
# and increments it, so re-running the training cell (without re-running this
# one) continues from the current epoch.
epoch = 0

In [None]:
# https://www.pyimagesearch.com/2018/12/24/how-to-use-keras-fit-and-fit_generator-a-hands-on-tutorial/
# One iteration = one Keras epoch on the current triplets, followed by a
# resample. Every `valid_frequency` epochs (except epoch 0) the model is also
# validated with recall@k and its weights are saved.
# NOTE(review): fit_generator / predict_generator are deprecated in newer
# TensorFlow releases in favour of fit / predict -- confirm the target version.
for _ in range(num_epoch):
    train_gen = sampler.get_generator()
    # epochs / initial_epoch make Keras run exactly one epoch, numbered `epoch`.
    h = model.fit_generator(train_gen,
                        steps_per_epoch=num_triplets//(batch_size//3),
                        epochs=epoch+1,
                        initial_epoch=epoch,
                        verbose=2,
                        callbacks=[tensorboard,])
    sampler.resample()

    if epoch%valid_frequency == 0 and epoch != 0:
        # Embed the gallery (seller images) and the queries (validation user images).
        seller_pred = model.predict_generator(batch_generator_predict(pool, 32, seller_images),
                                              math.ceil(len(seller_images)/32),
                                              workers=1)
        val_user_pred = model.predict_generator(batch_generator_predict(pool, 32, val_user_images),
                                                math.ceil(len(val_user_images)/32),
                                                workers=1)

        # Approximate nearest-neighbour index over the gallery embeddings; the
        # dimension is read from the predictions instead of being hard-coded.
        search_index = AnnoyIndex(seller_pred.shape[1], metric='euclidean')
        for gallery_idx in range(len(seller_pred)):
            search_index.add_item(gallery_idx, seller_pred[gallery_idx])
        search_index.build(50)

        # Fetch as many neighbours as the largest k needs, then slice per k.
        max_k = max(recall_values)
        recall_scores = {k: [] for k in recall_values}
        for query_idx in range(len(val_user_pred)):
            neighbours = search_index.get_nns_by_vector(val_user_pred[query_idx], max_k)
            val_cl = val_image_to_class[val_user_images[query_idx]]
            for k in recall_values:
                retrieved = [seller_images[j] for j in neighbours[:k]]
                recall_scores[k].append(recall(val_images[val_cl][0], retrieved))

        print('val on full')
        for k in recall_values:
            print(k, np.mean(recall_scores[k]))

        val_recall = [(k, np.mean(recall_scores[k])) for k in recall_values]
        # Validation on the "clean" subset is disabled, so there is nothing to
        # report for it. The slot is kept so every log entry stays a triple
        # (the original referenced an undefined name here and raised NameError).
        val_recall_clean = None

        # Start a new log only when the file is missing or not valid JSON;
        # any other error is allowed to surface.
        try:
            with open(recall_log_file, 'r') as f:
                recall_log = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            recall_log = []

        recall_log.append((epoch, val_recall, val_recall_clean))

        with open(recall_log_file, 'w') as f:
            json.dump(recall_log, f)

        # Make sure the checkpoint directory exists before writing into it.
        os.makedirs(checkpoint_dir, exist_ok=True)
        model.save_weights(os.path.join(checkpoint_dir, 'triplet_semihard_%d.keras'%epoch))

    epoch += 1

Instructions for updating:
Use tf.identity instead.
50/50 - 626s - loss: 0.1842
Epoch 2/2
50/50 - 674s - loss: 0.1891
Epoch 3/3
50/50 - 835s - loss: 0.1758
Epoch 4/4
50/50 - 827s - loss: 0.1931
Epoch 5/5
50/50 - 844s - loss: 0.1822
Epoch 6/6
50/50 - 811s - loss: 0.1882
Epoch 7/7
50/50 - 841s - loss: 0.1901
Epoch 8/8


In [None]:
# %tensorboard --logdir logs/triplet_semihard_v3