In [None]:
import json
import math
import multiprocessing as mp
import os
from functools import partial, update_wrapper

import keras
import keras.backend as K
import numpy as np
import PIL
import tensorflow as tf
import tqdm
from annoy import AnnoyIndex
from keras import optimizers
from keras.applications.inception_v3 import InceptionV3, preprocess_input
from keras.layers import Dense, Lambda
from keras.models import Model
from keras.preprocessing import image
from PIL import Image

In [None]:
img_size = 299     # input image size for network
margin = 0.3       # margin for triplet loss
batch_size = 42    # size of mini-batch; must be a multiple of 3 (anchor, positive, negative)
num_triplets = 700 # number of triplets sampled for each epoch
num_triplet = num_triplets  # alias: the original cell used this spelling, the rest of the notebook reads `num_triplets`
valid_frequency = 100  # validate and save a checkpoint every `valid_frequency` epochs
num_epoch = 3600   # number of iterations of the training loop
log_dir = './logs/triplet_semihard_v3'
checkpoint_dir = 'checkpoints/'
recall_log_file = './logs/recall_triplet_semihard_v3.json'
recall_values = [1, 3, 5, 10, 25, 50, 100]  # the k values for which recall@k is reported

In [None]:
# utility function to freeze some portion of a function's arguments
def wrapped_partial(func, *args, **kwargs):
    """Bind ``args``/``kwargs`` to ``func`` and keep ``func``'s metadata.

    A bare ``functools.partial`` object has no ``__name__``; ``update_wrapper``
    copies ``__name__``, ``__doc__`` and friends from ``func`` onto the partial
    so it can be used where a named function is expected.

    Returns:
        The ``functools.partial`` object, updated in place by ``update_wrapper``.
    """
    # update_wrapper returns the wrapper it was given, so this can be one expression
    return update_wrapper(partial(func, *args, **kwargs), func)

# calculate recall score
def recall(y_true, y_pred):
    """Return 1 if at least one element of ``y_pred`` occurs in ``y_true``, else 0.

    Both arguments are iterables of hashable items; duplicates are ignored.
    """
    hits = set(y_true).intersection(set(y_pred))
    return 1 if hits else 0

# margin triplet loss
def margin_triplet_loss(y_true, y_pred, margin, batch_size):
    """Mean hinge triplet loss over a batch laid out as a, p, n, a, p, n, ...

    Rows 0, 3, 6, ... of ``y_pred`` are taken as anchors, rows 1, 4, 7, ... as
    positives and rows 2, 5, 8, ... as negatives. Distances are squared
    Euclidean distances summed over axis 1.

    Args:
        y_true: unused; present because the function is used as a Keras loss.
        y_pred: batch of embeddings, indexed along axis 0.
        margin: value added to (positive distance - negative distance).
        batch_size: number of rows in ``y_pred`` used to build the row indices.

    Returns:
        Scalar tensor: mean over triplets of max(margin + d(a, p) - d(a, n), 0).
    """
    def every_third(offset):
        # rows offset, offset + 3, offset + 6, ... below batch_size
        return tf.gather(y_pred, tf.range(offset, batch_size, 3))

    anchors = every_third(0)
    positives = every_third(1)
    negatives = every_third(2)

    pos_dist = K.sum(K.square(anchors - positives), axis=1)
    neg_dist = K.sum(K.square(anchors - negatives), axis=1)

    hinge = K.maximum(margin + pos_dist - neg_dist, 0.0)
    return K.mean(hinge)

### Define model architecture

In [None]:
def get_model():
    """Build the embedding network.

    An InceptionV3 without its classification top (ImageNet weights, average
    pooling) is followed by a 512-unit ELU dense layer (``fc1``), a 128-unit
    linear dense layer (``fc2``) and an L2 normalisation along axis 1
    (``l2_norm``).

    Returns:
        A Keras ``Model`` mapping the InceptionV3 inputs to the normalised
        128-dimensional output.
    """
    backbone = InceptionV3(include_top=False, weights='imagenet', pooling='avg')

    hidden = Dense(512, activation='elu', name='fc1')(backbone.output)
    embedding = Dense(128, name='fc2')(hidden)
    normalized = Lambda(lambda x: K.l2_normalize(x, axis=1), name='l2_norm')(embedding)
    return Model(backbone.inputs, normalized)

In [None]:
# Instantiate the embedding network built by get_model(); `model` is used as a global
# by the sampler and by the training/validation loop further down.
model = get_model()

### Compile model

In [None]:
# Adam optimizer with learning rate 1e-4
opt = optimizers.Adam(lr=1e-4)

# `margin` and `batch_size` are frozen into the loss with wrapped_partial, leaving
# the (y_true, y_pred) signature that compile() is given as the loss.
model.compile(
    optimizer=opt,
    loss=wrapped_partial(margin_triplet_loss, margin=margin, batch_size=batch_size),
)

### Load data

I used MongoDB to store the dataset. Each document has the following structure:
- type: the split the item belongs to; one of 'train', 'val', 'test'
- seller_img: list of IDs of seller's images
- user_img: list of IDs of user's images

Example:
```javascript
{
    'type': 'test',
    'seller_img': ['HTB1s7ZiLFXXXXatXFXXq6xXFXXXu', 
                'HTB1pGaAKXXXXXczXXXXq6xXFXXXN'],
    'user_img': ['UTB8KtbUXXPJXKJkSahVq6xyzFXaE',
                'UTB8OeDUXgnJXKJkSaelq6xUzXXag',
                'UTB8h7HUXnzIXKJkSafVq6yWgXXap',
                'UTB8auL6XevJXKJkSajhq6A7aFXaa',
                'UTB8rrevXevJXKJkSajhq6A7aFXa5',
                'UTB8WUCuXXPJXKJkSahVq6xyzFXa6',
                'UTB8MHmvXXfJXKJkSamHq6zLyVXa1']
}
```

In [None]:
from pymongo import MongoClient
# Handle to the MongoDB collection holding the dataset items.
# connect=False is passed through to MongoClient — presumably to defer connecting until
# first use, which would matter because a multiprocessing pool is created later in this
# notebook (TODO confirm against the pymongo documentation).
client = MongoClient(connect=False)
db = client['aliexpress']
coll = db['items']  # queried below for the 'train', 'val' and 'test' items

In [None]:
# For every training item: its list of seller image ids and its list of user image ids.
# Both lists are filled in the same pass, so position i in each refers to the same item.
train_seller_images, train_user_images = [], []
for doc in coll.find({'type': 'train'}):
    train_seller_images.append(doc['seller_img'])
    train_user_images.append(doc['user_img'])

In [None]:
from sklearn.preprocessing import MultiLabelBinarizer
# Binarizer that turns per-item lists of image ids into a sparse indicator matrix.
# The same instance is fit twice below (seller images, then user images), so
# `classes_` has to be read right after each fit.
mlb = MultiLabelBinarizer(sparse_output=True)

In [None]:
# Sparse indicator matrix with one row per training item; the sampler reads the
# non-zero columns of a row to get that item's seller images.
X_train_seller = mlb.fit_transform(train_seller_images)

In [None]:
# Seller image id for each column of X_train_seller.
# Must be captured before `mlb` is fit again on the user images.
train_images_seller = mlb.classes_

In [None]:
# Same encoding for the user images: one row per training item (same item order as
# X_train_seller, since both come from the same pass over the training items).
X_train_user = mlb.fit_transform(train_user_images)

In [None]:
# User image id for each column of X_train_user (classes_ now reflects the second fit).
train_images_user = mlb.classes_

In [None]:
# Row sums of X_train_user flattened to 1-D: the number of user images per training item.
# The sampler normalises this vector and uses it as the probability of drawing each item.
num_images_by_class = np.asarray(X_train_user.sum(axis=1)).ravel()

In [None]:
# Validation items as (seller image ids, user image ids); an item's list position is its class id.
val_images = [(doc['seller_img'], doc['user_img']) for doc in coll.find({'type': 'val'})]
# image id -> class id, covering both the seller and the user images of every validation item
val_image_to_class = {
    image_id: class_id
    for class_id, (seller_ids, user_ids) in enumerate(val_images)
    for image_id in seller_ids + user_ids
}

# Same structures for the validation items flagged 'clean', built from their *_clean image lists.
val_images_clean = [(doc['seller_img_clean'], doc['user_img_clean'])
                    for doc in coll.find({'type': 'val', 'clean': True})]
val_image_clean_to_class = {
    image_id: class_id
    for class_id, (seller_ids, user_ids) in enumerate(val_images_clean)
    for image_id in seller_ids + user_ids
}

test_images = [(doc['seller_img'], doc['user_img']) for doc in coll.find({'type': 'test'})]

# Flat id lists: seller images of ALL items (every split), and the user images
# of the full and clean validation sets.
seller_images = [image_id for doc in coll.find({}) for image_id in doc['seller_img']]
val_user_images = [image_id for _, user_ids in val_images for image_id in user_ids]
val_user_clean_images = [image_id for _, user_ids in val_images_clean for image_id in user_ids]

### Data loading utils

In [None]:
# load image by id without augmentations
def preprocess_image_worker(media_id):
    """Load ``./images/<media_id>.jpg`` as a network-ready array.

    The image is converted to RGB, resized to ``img_size`` x ``img_size``,
    turned into an array with a leading axis of length 1 and passed through
    the InceptionV3 ``preprocess_input``.

    Returns:
        The preprocessed array for a single image (leading axis of length 1).
    """
    path = './images/' + media_id + '.jpg'
    rgb = Image.open(path).convert('RGB').resize((img_size, img_size))

    pixels = np.expand_dims(image.img_to_array(rgb), axis=0)
    return preprocess_input(pixels)

In [None]:
# load image by id with augmentations
def preprocess_image_worker_aug(media_id):
    """Load ``./images/<media_id>.jpg`` with random augmentation.

    Augmentations, in order:
      1. random crop that trims up to 5% of the width/height from each side;
      2. with probability 1/2, one randomly chosen flip / rotation / transpose;
      3. resize to ``img_size`` x ``img_size``.

    The result is converted to an array with a leading axis of length 1 and
    passed through the InceptionV3 ``preprocess_input``.

    NOTE(review): all randomness comes from the global ``np.random`` state; if
    this runs inside forked pool workers, confirm each worker is seeded differently.
    """
    img = Image.open('./images/' + media_id + '.jpg').convert('RGB')

    # crop box, drawn in the order left, upper, right, lower
    left = int(np.random.uniform(0, 0.05) * img.width)
    upper = int(np.random.uniform(0, 0.05) * img.height)
    right = int(np.random.uniform(0.95, 1.) * img.width)
    lower = int(np.random.uniform(0.95, 1.) * img.height)
    img = img.crop((left, upper, right, lower))

    if np.random.randint(2) == 0:
        transforms = [
            PIL.Image.FLIP_LEFT_RIGHT,
            PIL.Image.FLIP_TOP_BOTTOM,
            PIL.Image.ROTATE_90,
            PIL.Image.ROTATE_180,
            PIL.Image.ROTATE_270,
            PIL.Image.TRANSPOSE,
        ]
        img = img.transpose(np.random.choice(transforms))
    img = img.resize((img_size, img_size))

    pixels = np.expand_dims(image.img_to_array(img), axis=0)
    return preprocess_input(pixels)

In [None]:
# data generator, divides list images into mini-batches
def batch_generator_predict(pool, batch_size, images):
    """Endlessly yield preprocessed mini-batches of ``images`` (no augmentation).

    Args:
        pool: object with a ``map(func, iterable)`` method, used to load the
            images of one batch with ``preprocess_image_worker``.
        batch_size: number of image ids per batch.
        images: sliceable sequence of image ids.

    Yields:
        One array per batch with the loaded images concatenated along axis 0.
        Once ``images`` is exhausted the generator does not stop; it keeps
        yielding empty arrays of shape ``(0, img_size, img_size, 3)``.
    """
    offset = 0
    while True:
        chunk = images[offset:offset + batch_size]
        offset += batch_size
        if len(chunk) > 0:
            loaded = pool.map(preprocess_image_worker, chunk)
            yield np.concatenate(loaded, axis=0)
        else:
            yield np.zeros((0, img_size, img_size, 3))

### Semi-hard negative mining

In [None]:
class SemiHardNegativeSampler:
    """Samples (anchor, positive, negative) image-id triplets with semi-hard negatives.

    Anchors are user images, positives are seller images of the same training
    item, and negatives are seller images of other sampled items, chosen using
    the current embeddings of the global ``model``.

    Reads these notebook globals: ``model``, ``X_train_user``, ``X_train_seller``,
    ``train_images_user``, ``train_images_seller``, ``num_images_by_class``,
    ``batch_generator_predict``, ``preprocess_image_worker_aug`` and ``img_size``.

    Attributes set by the class:
        pool: worker pool (anything with ``map``) used for all image loading.
        batch_size: number of images per training batch (3 per triplet).
        num_samples: number of triplets drawn by each ``resample()``.
        triplets: array of shape (num_samples, 3) holding image ids as
            (anchor, positive, negative).
        dists: for each triplet, d(anchor, negative) - d(anchor, positive)
            in squared Euclidean distance.
    """

    def __init__(self, pool, batch_size, num_samples):
        """Store the settings and draw the first set of triplets."""
        self.pool = pool
        self.batch_size = batch_size
        self.num_samples = num_samples
        self.resample()

    # sample triplets with semi-hard negatives
    def resample(self):
        """Draw ``num_samples`` new triplets and store them in ``self.triplets``."""
        # training items are drawn proportionally to their number of user images
        class_probs = num_images_by_class / num_images_by_class.sum()
        sample_classes = np.random.choice(np.arange(X_train_user.shape[0]), p=class_probs, size=self.num_samples)

        # for every drawn item: one random user image (anchor), then one random seller image (positive)
        sample_images = []
        for cls in sample_classes:
            sample_images.append(train_images_user[np.random.choice(X_train_user[cls].nonzero()[1], replace=False)])
            sample_images.append(train_images_seller[np.random.choice(X_train_seller[cls].nonzero()[1], replace=False)])
        sample_images = np.array(sample_images)

        # embed all sampled images with the current model, loading them through the
        # pool this sampler was constructed with (not a module-level global)
        predict_batch = 32
        pred_sample = model.predict_generator(
            batch_generator_predict(self.pool, predict_batch, sample_images),
            math.ceil(len(sample_images) / predict_batch),
            max_q_size=1, workers=1)

        a = pred_sample[0::2]  # anchor embeddings (even rows)
        p = pred_sample[1::2]  # positive embeddings (odd rows)
        pairs = sample_images.reshape((-1, 2))  # row i = (anchor id, positive id) of sample i

        triplets = []
        self.dists = []

        for i in range(self.num_samples):
            d = np.square(a[i] - p[i]).sum()
            # candidates: positives of the samples that belong to a different item
            neg_sample_classes = (sample_classes != sample_classes[i]).nonzero()[0]
            neg = p[neg_sample_classes]
            neg_ids = pairs[neg_sample_classes, 1]

            d_neg = np.square(neg - a[i]).sum(axis=1)

            # among the candidates farther from the anchor than its positive, take the
            # closest one; if there is none, fall back to the farthest candidate
            semihard = np.where(d_neg > d)[0]
            if len(semihard) == 0:
                n = np.argmax(d_neg)
            else:
                n = semihard[np.argmin(d_neg[semihard])]

            self.dists.append(d_neg[n] - d)
            triplets.append(np.concatenate([pairs[i], np.array([neg_ids[n]])]))

        self.triplets = np.array(triplets)

    # data generator for triplets
    def batch_generator(self):
        """Endlessly yield training batches built from ``self.triplets``.

        Yields:
            ``(X_batch, y_batch)`` where ``X_batch`` holds the augmented images in
            a, p, n, a, p, n, ... order and ``y_batch`` is a dummy vector of zeros
            with one entry per image. After the triplets are exhausted it yields
            a bare empty array of shape ``(0, img_size, img_size, 3)``.
        """
        triplets_per_batch = self.batch_size // 3
        i = 0
        while True:
            batch = self.triplets[i:i + triplets_per_batch].ravel()

            i += triplets_per_batch
            if len(batch) == 0:
                yield np.zeros((0, img_size, img_size, 3))
            else:
                # load through the injected pool rather than a module-level global
                result = self.pool.map(preprocess_image_worker_aug, batch)
                X_batch = np.concatenate(result, axis=0)
                yield X_batch, np.zeros(len(batch))

    # return data generator for triplets
    def get_generator(self):
        """Return a fresh generator over the current triplets."""
        return self.batch_generator()

In [None]:
sampler = SemiHardNegativeSampler(pool, batch_size, num_triplets)

### Create pool of processes for parallel data loading

In [None]:
pool = mp.Pool(processes=8)

### Attach TensorBoard to monitor the learning process

In [None]:
# TensorBoard callback writing to log_dir; it is passed to fit_generator in the training loop.
# Graph and image summaries are switched off.
# NOTE(review): histogram_freq=1 asks for histograms, but the training loop passes no
# validation data to fit_generator — confirm how the Keras version in use handles that.
tensorboard = keras.callbacks.TensorBoard(log_dir=log_dir,
                 histogram_freq=1, 
                 write_graph=False, 
                 write_images=False)

### Training process
Each epoch we 
- train model on triplets with semi-hard negatives from `sampler`
- resample triplets
- do validation and save model with frequency `valid_frequency`

I use an Annoy index for nearest-neighbor search to speed up validation.

In [None]:
# Global epoch counter: the training loop passes it to fit_generator as initial_epoch and
# increments it after every iteration. It lives in its own cell, presumably so that
# re-running the training cell continues from the current epoch instead of restarting at 0.
epoch = 0

In [None]:
def _embed_images(image_ids):
    """Embed ``image_ids`` with the current global ``model``, 32 images per batch."""
    return model.predict_generator(batch_generator_predict(pool, 32, image_ids),
                                   math.ceil(len(image_ids)/32), max_q_size=1, workers=1)


def _recall_at_k(search_index, query_pred, query_images, image_to_class, class_images):
    """Mean recall@k of user-image queries against the seller-image index.

    Args:
        search_index: Annoy index whose item i is the embedding of ``seller_images[i]``.
        query_pred: embeddings of the query (user) images.
        query_images: image ids of the queries, aligned with ``query_pred``.
        image_to_class: mapping image id -> class id.
        class_images: list indexed by class id of (seller ids, user ids) tuples.

    Returns:
        List of ``(k, mean recall@k)`` for every k in ``recall_values``. A query
        scores 1 at k if any of its class's seller images is among its k nearest
        neighbours, else 0.
    """
    num_neighbors = max(recall_values)  # one search per query serves every k
    scores = {k: [] for k in recall_values}
    for query_idx in range(len(query_pred)):
        neighbors = search_index.get_nns_by_vector(query_pred[query_idx], num_neighbors)
        true_sellers = class_images[image_to_class[query_images[query_idx]]][0]
        for k in recall_values:
            retrieved = [seller_images[j] for j in neighbors[:k]]
            scores[k].append(recall(true_sellers, retrieved))
    return [(k, np.mean(scores[k])) for k in recall_values]


# One iteration = one Keras epoch on the current triplets, followed by resampling.
# `epoch` is the global counter, so re-running this cell continues where it stopped.
for _ in range(num_epoch):
    train_gen = sampler.get_generator()
    h = model.fit_generator(train_gen, steps_per_epoch=num_triplets//(batch_size//3), epochs=epoch+1, initial_epoch=epoch, verbose=2,
                        max_q_size=1, callbacks=[tensorboard,])
    sampler.resample()

    if epoch%valid_frequency == 0 and epoch != 0:
        seller_pred = _embed_images(seller_images)
        val_user_pred = _embed_images(val_user_images)
        val_user_clean_pred = _embed_images(val_user_clean_images)

        # index over all seller images; item ids are positions in seller_images
        search_index = AnnoyIndex(128, metric='euclidean')
        for item_idx in range(len(seller_pred)):
            search_index.add_item(item_idx, seller_pred[item_idx])
        search_index.build(50)

        val_recall = _recall_at_k(search_index, val_user_pred, val_user_images,
                                  val_image_to_class, val_images)
        print ('val on full')
        for k, score in val_recall:
            print (k, score)

        val_recall_clean = _recall_at_k(search_index, val_user_clean_pred, val_user_clean_images,
                                        val_image_clean_to_class, val_images_clean)
        print ('val on clean')
        for k, score in val_recall_clean:
            print (k, score)

        # append to the recall log; start a new log only when the file is missing or
        # unreadable as JSON, instead of swallowing every possible exception
        try:
            with open(recall_log_file, 'r') as f:
                recall_log = json.load(f)
        except (IOError, ValueError):
            recall_log = []

        recall_log.append((epoch, val_recall, val_recall_clean))

        with open(recall_log_file, 'w') as f:
            json.dump(recall_log, f)

        # make sure the checkpoint directory exists before writing into it
        os.makedirs(checkpoint_dir, exist_ok=True)
        model.save_weights(os.path.join(checkpoint_dir, 'triplet_semihard_%d.keras'%epoch))

    epoch += 1