# Poeem Tutorial

This notebook demonstrates how to use *poeem* to learn an embedding index jointly with a retrieval model. Along the way, you will also learn how to:

- Write a simple embedding retrieval model
- Run nearest neighbor search with brute force
- Run approximate nearest neighbor (ANN) search with Faiss, Facebook's open source ANN library, on embeddings that were learned separately beforehand
- Run ANN search with *poeem*

In [1]:
import tensorflow as tf
import numpy as np
import poeem


So far, *poeem* only supports TensorFlow 1.15; other TensorFlow versions have not been tested. Users may need to make minor changes to run it on other versions.

In [2]:
assert tf.__version__[:4] == '1.15'

## Toy data

To demonstrate how *poeem* works, we synthesize a toy data set for a quick tutorial. A tutorial on a larger, real-world data set is given [here]().

In this toy data set, queries and items are both represented by numerical IDs, which are also the row indices into their embedding matrices. Specifically:

- a query is an integer from 0 to *vocab_size* - 1 (*vocab_size* is 10,000 in this tutorial)
- an item is an integer in the same range
- a query whose last two digits are *xy* retrieves items whose last four digits *abcd* contain x and y at two different positions, e.g., a=x, b=y, or c=x, b=y, and so on.

In [3]:
N = 1000000  # number of training examples
VOCAB_SIZE = 10000 

In [4]:
# simulate the data
query = np.random.randint(0, high=VOCAB_SIZE, size=N)
item = np.random.randint(0, high=VOCAB_SIZE, size=N)

d, c, b, a = item % 10, (item // 10) % 10, (item // 100) % 10, (item // 1000) % 10

def get_xy(a, axis):
    idx = np.random.rand(*a.shape).argsort(axis=axis)
    shuffled = np.take_along_axis(a,idx,axis=axis)
    return shuffled[:, 0], shuffled[:, 1]

x, y = get_xy(np.stack([a, b, c, d], axis=1), 1)
query = (query // 100) * 100 + x * 10 + y

Let's take a look at the synthetic data to make sure the pattern is correct.

In [5]:
[query[:10], item[:10]]

[array([ 336, 2328, 9134,  902, 8923, 1868,  168, 9344, 3081, 4668]),
 array([3679, 1282, 5334,  320, 3192, 9683, 6681, 4244, 8616, 9687])]
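
As a quick sanity check of the rule stated earlier, the sketch below tests the ten pairs printed above in plain Python. The helper name `follows_rule` is an assumption made for this illustration and is not part of *poeem*; it treats an item as matching when the query's last two digits can both be drawn, without replacement, from the item's four digits.

```python
from collections import Counter

def follows_rule(query_id, item_id):
    """Return True if the query's last two digits can both be drawn,
    without replacement, from the item's last four digits."""
    x, y = (query_id // 10) % 10, query_id % 10
    item_digits = Counter(f"{item_id % 10000:04d}")  # zero-padded, e.g. 320 -> "0320"
    needed = Counter(f"{x}{y}")
    # Counter subtraction keeps only positive counts, so an empty result
    # means every needed digit is available often enough.
    return not (needed - item_digits)

# The ten (query, item) pairs printed in the output above.
sample_query = [336, 2328, 9134, 902, 8923, 1868, 168, 9344, 3081, 4668]
sample_item = [3679, 1282, 5334, 320, 3192, 9683, 6681, 4244, 8616, 9687]

sample_ok = [follows_rule(q, i) for q, i in zip(sample_query, sample_item)]
print(all(sample_ok))
```

Note the pair (9344, 4244): the query digits are both 4, so the item must contain 4 at least twice, which it does.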

## Training

In this section, we write a very simple embedding retrieval model for demonstration. Please note that this model is for tutorial purposes only; real-world industrial systems need more practical techniques than are shown here.

First, let's define some hyperparameters.

In [6]:
BATCH_SIZE = 128
LEARNING_RATE = 0.1
EPOCH = 3
EMB_DIM = 64

Second, let's use the TensorFlow Estimator API with a custom *model_fn*, so that we can define the model ourselves and reuse the Estimator's utilities for training.

The queries and items are represented as separate embeddings.

In [7]:
def model_fn(features, labels, mode):
    query_column = tf.feature_column.embedding_column(
        tf.feature_column.categorical_column_with_vocabulary_list(
            key='query',
            vocabulary_list=range(VOCAB_SIZE),
            dtype=tf.int32),
        dimension=EMB_DIM)
    item_column = tf.feature_column.embedding_column(
        tf.feature_column.categorical_column_with_vocabulary_list(
            key='item',
            vocabulary_list=range(VOCAB_SIZE),
            dtype=tf.int32),
        dimension=EMB_DIM)
    
    query_emb = tf.feature_column.input_layer(features, [query_column])
    item_emb = tf.feature_column.input_layer(features, [item_column])

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
            mode, predictions={'query': query_emb, 'item': item_emb})

    def cosine(a, b):
        a = tf.nn.l2_normalize(a, axis=1)
        b = tf.nn.l2_normalize(b, axis=1)
        return tf.matmul(a, b, transpose_b=True)

    scores = cosine(query_emb, item_emb)

    batch_size = tf.shape(query_emb)[0]
    loss = tf.reduce_sum(
        tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=tf.eye(batch_size),
            logits=scores * 30))  # 1/30 is the softmax temperature; not carefully tuned.

    optimizer = tf.train.AdagradOptimizer(LEARNING_RATE)
    train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())

    return tf.estimator.EstimatorSpec(
        mode, loss=loss, train_op=train_op, predictions={'query': query_emb, 'item': item_emb})
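
The loss in `model_fn` treats the other items in the same batch as negatives: row *i* of the score matrix should put its largest value on column *i*. The NumPy sketch below mirrors that computation outside TensorFlow. The function name `in_batch_softmax_loss` and the toy inputs are assumptions made for illustration only.

```python
import numpy as np

def in_batch_softmax_loss(query_emb, item_emb, scale=30.0):
    """Summed cross-entropy where the i-th item is the positive for the i-th query."""
    q = query_emb / np.linalg.norm(query_emb, axis=1, keepdims=True)
    v = item_emb / np.linalg.norm(item_emb, axis=1, keepdims=True)
    logits = scale * (q @ v.T)                            # [batch, batch] scaled cosine scores
    logits = logits - logits.max(axis=1, keepdims=True)   # for numerical stability
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    # The labels are the identity matrix, so only the diagonal contributes.
    return -np.trace(log_prob)

eye = np.eye(4)
# Each query is aligned with its own item and orthogonal to the others.
perfect_loss = in_batch_softmax_loss(eye, eye)
# All items are identical, so the model cannot tell them apart.
uniform_loss = in_batch_softmax_loss(eye, np.ones((4, 4)))
print(perfect_loss, uniform_loss)
```

In the second case every row is a uniform distribution over 4 items, so the loss is 4 * log(4); in the first case it is close to zero.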


Third, let's define an input function that feeds data into the *model_fn* defined above.

In [8]:
def input_fn():
    dataset = tf.data.Dataset.from_tensor_slices({'query': query.astype(np.int32), 'item': item.astype(np.int32)})
    dataset = dataset.shuffle(buffer_size=1000).batch(BATCH_SIZE).repeat(EPOCH)
    return dataset
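
Because `batch` is applied before `repeat`, each epoch ends with one smaller final batch. The arithmetic below, using the constants defined earlier, gives the number of training steps to expect. It agrees with the checkpoint name `model.ckpt-23439` that appears in the prediction logs later in this notebook.

```python
import math

N = 1_000_000      # training examples, as defined earlier in this notebook
BATCH_SIZE = 128
EPOCH = 3

# batch() comes before repeat(), so every epoch ends with one partial batch.
steps_per_epoch = math.ceil(N / BATCH_SIZE)
total_steps = steps_per_epoch * EPOCH
print(steps_per_epoch, total_steps)  # 7813 23439
```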

Finally, we are ready to train the model with the *model_fn* and *input_fn* defined above, in just two lines.

In [9]:
retrieval_model = tf.estimator.Estimator(model_fn=model_fn)
retrieval_model.train(input_fn=input_fn)

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpevc_odlj', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f592f328668>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
Instructions for updating:
Use Variable.read_value. Vari

INFO:tensorflow:global_step/sec: 297.189
INFO:tensorflow:loss = 1088.1648, step = 3800 (0.336 sec)
INFO:tensorflow:global_step/sec: 292.949
INFO:tensorflow:loss = 998.91, step = 3900 (0.340 sec)
INFO:tensorflow:global_step/sec: 290.474
INFO:tensorflow:loss = 992.79736, step = 4000 (0.345 sec)
INFO:tensorflow:global_step/sec: 298.805
INFO:tensorflow:loss = 897.68207, step = 4100 (0.335 sec)
INFO:tensorflow:global_step/sec: 282.223
INFO:tensorflow:loss = 938.8969, step = 4200 (0.354 sec)
INFO:tensorflow:global_step/sec: 290.298
INFO:tensorflow:loss = 924.9182, step = 4300 (0.346 sec)
INFO:tensorflow:global_step/sec: 263.531
INFO:tensorflow:loss = 865.07947, step = 4400 (0.379 sec)
INFO:tensorflow:global_step/sec: 270.575
INFO:tensorflow:loss = 924.31116, step = 4500 (0.369 sec)
INFO:tensorflow:global_step/sec: 268.637
INFO:tensorflow:loss = 894.6639, step = 4600 (0.372 sec)
INFO:tensorflow:global_step/sec: 265.829
INFO:tensorflow:loss = 872.739, step = 4700 (0.378 sec)
INFO:tensorflow:gl

INFO:tensorflow:loss = 511.33313, step = 12100 (0.383 sec)
INFO:tensorflow:global_step/sec: 247.255
INFO:tensorflow:loss = 471.68274, step = 12200 (0.406 sec)
INFO:tensorflow:global_step/sec: 267.009
INFO:tensorflow:loss = 453.0332, step = 12300 (0.373 sec)
INFO:tensorflow:global_step/sec: 264.794
INFO:tensorflow:loss = 448.7835, step = 12400 (0.378 sec)
INFO:tensorflow:global_step/sec: 296.925
INFO:tensorflow:loss = 466.60553, step = 12500 (0.339 sec)
INFO:tensorflow:global_step/sec: 294.254
INFO:tensorflow:loss = 435.78006, step = 12600 (0.339 sec)
INFO:tensorflow:global_step/sec: 284.3
INFO:tensorflow:loss = 469.10757, step = 12700 (0.351 sec)
INFO:tensorflow:global_step/sec: 296.869
INFO:tensorflow:loss = 469.70255, step = 12800 (0.337 sec)
INFO:tensorflow:global_step/sec: 278.997
INFO:tensorflow:loss = 485.26648, step = 12900 (0.358 sec)
INFO:tensorflow:global_step/sec: 281.377
INFO:tensorflow:loss = 435.17456, step = 13000 (0.355 sec)
INFO:tensorflow:global_step/sec: 295.571
INFO

INFO:tensorflow:global_step/sec: 271.434
INFO:tensorflow:loss = 378.81793, step = 20400 (0.365 sec)
INFO:tensorflow:global_step/sec: 280.316
INFO:tensorflow:loss = 315.43988, step = 20500 (0.356 sec)
INFO:tensorflow:global_step/sec: 282.267
INFO:tensorflow:loss = 393.28018, step = 20600 (0.355 sec)
INFO:tensorflow:global_step/sec: 257.712
INFO:tensorflow:loss = 343.27853, step = 20700 (0.389 sec)
INFO:tensorflow:global_step/sec: 267.135
INFO:tensorflow:loss = 344.5199, step = 20800 (0.374 sec)
INFO:tensorflow:global_step/sec: 266.948
INFO:tensorflow:loss = 339.63144, step = 20900 (0.374 sec)
INFO:tensorflow:global_step/sec: 253.337
INFO:tensorflow:loss = 366.99655, step = 21000 (0.395 sec)
INFO:tensorflow:global_step/sec: 243.126
INFO:tensorflow:loss = 343.1393, step = 21100 (0.411 sec)
INFO:tensorflow:global_step/sec: 251.195
INFO:tensorflow:loss = 342.42993, step = 21200 (0.398 sec)
INFO:tensorflow:global_step/sec: 258.512
INFO:tensorflow:loss = 363.39047, step = 21300 (0.387 sec)
IN

<tensorflow_estimator.python.estimator.estimator.Estimator at 0x7f583781f080>

## Retrieval

### Export embeddings

After the retrieval model is trained, we need to export all query and item embeddings computed with the trained parameters. With the TensorFlow Estimator framework, this is easily done by constructing another *input_fn* that feeds all the IDs to the model once and collecting the predictions.

In [10]:
def predict_input_fn():
    dataset = tf.data.Dataset.from_tensor_slices({'query': np.arange(VOCAB_SIZE), 'item': np.arange(VOCAB_SIZE)})
    dataset = dataset.batch(VOCAB_SIZE)
    return dataset

results = list(retrieval_model.predict(input_fn=predict_input_fn))
query_emb = np.stack([r['query'] for r in results], axis=0)
item_emb = np.stack([r['item'] for r in results], axis=0)

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpevc_odlj/model.ckpt-23439
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.


### Brute force search

Before we try approximate nearest neighbor (ANN) search algorithms, let's first run a brute force search to get an upper bound on retrieval accuracy. In general, an ANN search algorithm is expected to be much faster than brute force but somewhat less accurate.

In [11]:
def brute_force_search(query_id, items, k=10):
    query = query_emb[query_id:(query_id+1), :]
    query_norm = np.linalg.norm(query, axis=1, keepdims=True)
    item_norm = np.linalg.norm(items, axis=1, keepdims=True)
    cos = np.matmul(query, np.transpose(items)) / query_norm / np.transpose(item_norm)
    cos = cos.flatten()
    sorted_item_id = np.argsort(-cos)
    return sorted_item_id[:k]
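
`brute_force_search` sorts all item scores even though only the top *k* are returned. As an optional variation, and not something this notebook relies on, `np.argpartition` can select the top *k* first and sort only those. The function name `top_k_cosine` and the random data below are assumptions for illustration.

```python
import numpy as np

def top_k_cosine(query_vec, items, k=10):
    """Indices of the k items with the highest cosine similarity to query_vec."""
    items_unit = items / np.linalg.norm(items, axis=1, keepdims=True)
    query_unit = query_vec / np.linalg.norm(query_vec)
    cos = items_unit @ query_unit                  # shape: [num_items]
    top = np.argpartition(-cos, k - 1)[:k]         # unordered top-k
    return top[np.argsort(-cos[top])]              # order the k winners by score

rng = np.random.default_rng(0)
demo_items = rng.normal(size=(1000, 64))
demo_query = rng.normal(size=64)

fast = top_k_cosine(demo_query, demo_items, k=10)

# Reference result from a full sort, as in brute_force_search.
demo_unit = demo_items / np.linalg.norm(demo_items, axis=1, keepdims=True)
full = np.argsort(-(demo_unit @ (demo_query / np.linalg.norm(demo_query))))[:10]
print(np.array_equal(fast, full))
```

For 10,000 items the full sort is cheap enough, so this only matters for much larger item sets.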

Let's have a quick look at the retrieval results for query 98. Note that every retrieved item ID contains both digits 8 and 9.

In [12]:
brute_force_search(98, item_emb)

array([9998, 1998, 8179, 1698, 1978, 3981, 8964, 6948, 7789, 8901])

Now we can compute an overall retrieval accuracy metric called precision@k, with k=100. It measures the fraction of the top 100 retrieved items that are correct.

In [13]:
def precision_at_k(search_fn, query_id, items, k=100):
    nn_items = search_fn(query_id, items, k=k)
    
    d, c, b, a = nn_items % 10, (nn_items // 10) % 10, (nn_items // 100) % 10, (nn_items // 1000) % 10
    abcd = np.stack([a, b, c, d], axis=1)
    y, x = query_id % 10, (query_id // 10) % 10
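    # Count the digit positions that equal x or y and require at least two.
    # Note: when x != y, this also accepts an item that contains x (or y) twice but not the other digit.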
    match = np.sum(np.logical_or(x == abcd, y == abcd), axis=1) >= 2 
    precision = np.sum(match) / k
    return precision

precision_at_100 = [precision_at_k(brute_force_search, i, item_emb, k=100) for i in range(VOCAB_SIZE)]
print("overall precision@100 = ", np.mean(precision_at_100))

overall precision@100 =  0.998862


## Faiss search

Now let's try Faiss, a widely used ANN library developed by Facebook. The index used here is based on product quantization. We set the index type to 'IVF8,PQ8', which means coarse quantization into *8* clusters followed by product quantization into *8* segments. Each segment is encoded in *one* byte, i.e., as one of *2^8 = 256* centroids.
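
To make the 'PQ8' part concrete, the NumPy sketch below encodes 64-dimensional vectors into 8 one-byte codes. It illustrates product quantization in general and is not Faiss's implementation: the codebooks here are random rather than trained with k-means, and the names `pq_encode` and `pq_decode` are hypothetical.

```python
import numpy as np

EMB_DIM, NUM_SEGMENTS, NUM_CENTROIDS = 64, 8, 256
SEG_DIM = EMB_DIM // NUM_SEGMENTS            # 8 dimensions per segment

rng = np.random.default_rng(0)
# One codebook per segment; random here, learned from data in a real index.
codebooks = rng.normal(size=(NUM_SEGMENTS, NUM_CENTROIDS, SEG_DIM))

def pq_encode(vectors):
    """Map each vector to NUM_SEGMENTS centroid ids (one byte each)."""
    segments = vectors.reshape(len(vectors), NUM_SEGMENTS, SEG_DIM)
    # Squared distance from every segment to every centroid of its codebook.
    diff = segments[:, :, None, :] - codebooks[None, :, :, :]
    dist = (diff ** 2).sum(axis=-1)           # [n, NUM_SEGMENTS, NUM_CENTROIDS]
    return dist.argmin(axis=-1).astype(np.uint8)

def pq_decode(codes):
    """Rebuild approximate vectors by concatenating the chosen centroids."""
    picked = codebooks[np.arange(NUM_SEGMENTS), codes]   # [n, NUM_SEGMENTS, SEG_DIM]
    return picked.reshape(len(codes), EMB_DIM)

vectors = rng.normal(size=(5, EMB_DIM)).astype(np.float32)
codes = pq_encode(vectors)
approx = pq_decode(codes)
print(codes.shape, codes.dtype, approx.shape)
print(vectors.nbytes // len(vectors), "bytes ->", codes.nbytes // len(codes), "bytes per vector")
```

Each float32 vector takes 256 bytes, while its PQ code takes 8 bytes, which is where the memory saving of a PQ index comes from.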

First, we need to build an embedding index from all the item embeddings. Note that we first normalize all item embeddings and then build the index with *inner product* as the metric, which is strictly speaking a similarity rather than a distance.

In [14]:
import faiss

index = faiss.index_factory(EMB_DIM, 'IVF8,PQ8', faiss.METRIC_INNER_PRODUCT)
item_emb = item_emb / np.linalg.norm(item_emb, axis=1, keepdims=True)
index.train(item_emb)
index.add(item_emb)

Loading faiss with AVX2 support.
Loading faiss.
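
Normalizing only the items is enough here: for a fixed query, dividing by the query norm rescales every score by the same positive constant, so the inner-product ranking over unit-length items is the cosine ranking. The small NumPy check below uses random data, not the notebook's embeddings.

```python
import numpy as np

rng = np.random.default_rng(0)
items = rng.normal(size=(200, 16))
query = 3.7 * rng.normal(size=16)             # deliberately not unit length

items_unit = items / np.linalg.norm(items, axis=1, keepdims=True)

inner_product = items_unit @ query            # what an inner-product index scores
cosine = inner_product / np.linalg.norm(query)

same_ranking = np.array_equal(np.argsort(-inner_product), np.argsort(-cosine))
print(same_ranking)
```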


In [15]:
def faiss_search(query_id, items, k=10):
    query = query_emb[query_id:(query_id+1), :]
    D, I = index.search(query, k)
    D, I = D.flatten(), I.flatten()
    return I[:k]

In [16]:
precision_at_100 = [precision_at_k(faiss_search, i, item_emb, k=100) for i in range(VOCAB_SIZE)]
print("overall precision@100 = ", np.mean(precision_at_100))   

overall precision@100 =  0.960966


## Poeem search

Unlike the separately built embedding index above, *poeem* learns the embedding index jointly with the embedding model, so there is no extra index building step. We only need to make some simple changes to the above *model_fn* to add the *poeem* indexing layer.

Note that we only need to make changes in three places, marked by `### [poeem code]`.

In [17]:
def poeem_model_fn(features, labels, mode, params):
    query_column = tf.feature_column.embedding_column(
        tf.feature_column.categorical_column_with_vocabulary_list(
            key='query',
            vocabulary_list=range(VOCAB_SIZE),
            dtype=tf.int32),
        dimension=EMB_DIM)
    query_emb = tf.feature_column.input_layer(features, [query_column])
    
    if mode == tf.estimator.ModeKeys.PREDICT:
        ### [poeem code] directly do ANN search in the model as a TensorFlow op.
        if params.get('item_search', False):
            index = poeem.search.index_from_file(params['index_file'])
            neighbors, scores = index.search(
                tf.expand_dims(query_emb, 1),
                params['topk'],
                params['nprobe'],
                params['metric_type'],
                verbose=False)        
            return tf.estimator.EstimatorSpec(
                mode, predictions={'neighbors': neighbors, 'scores': scores})
        ### end [poeem code]
        
    item_column = tf.feature_column.embedding_column(
        tf.feature_column.categorical_column_with_vocabulary_list(
            key='item',
            vocabulary_list=range(VOCAB_SIZE),
            dtype=tf.int32),
        dimension=EMB_DIM)
    item_emb = tf.feature_column.input_layer(features, [item_column])
    item_emb = tf.nn.l2_normalize(item_emb, axis=1)
    

    ### [poeem code] item indexing layer as the last layer in item tower
    hparams = poeem.embedding.PoeemHparam(coarse_K=8,
                                          K=256,
                                          D=8,
                                          rotate=0) # exactly the same parameters as Faiss, specified above.
    item_batch_quantized = poeem.embedding.PoeemEmbed(
        EMB_DIM,
        warmup_steps=16384,
        buffer_size=8192,
        hparams=hparams,
        mode=mode)
    
    # gradient straight-through estimator. For details, check out our paper.
    item_emb_tau, coarse_code, code, regularizer = item_batch_quantized.forward(item_emb)
    item_emb = item_emb - tf.stop_gradient(item_emb - item_emb_tau)
    ### end [poeem code] 

    if mode == tf.estimator.ModeKeys.PREDICT:
        ### [poeem code] export item embeddings/PQ codes for persistence on disk.
        if params.get('item_predict', False):
            return tf.estimator.EstimatorSpec(
                mode, predictions={
                    'item_coarse_code': coarse_code,
                    'item_code': code,
                    'item_norm': tf.norm(item_emb, axis=1)
                })
        ### end [poeem code] 

    def cosine(a, b):
        a = tf.nn.l2_normalize(a, axis=1)
        b = tf.nn.l2_normalize(b, axis=1)
        return tf.matmul(a, b, transpose_b=True)

    scores = cosine(query_emb, item_emb)

    batch_size = tf.shape(query_emb)[0]
    loss = tf.reduce_sum(
        tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=tf.eye(batch_size),
            logits=scores * 30))
    
    loss = loss + regularizer

    optimizer = tf.train.AdagradOptimizer(LEARNING_RATE)
    train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())

    return tf.estimator.EstimatorSpec(
        mode, loss=loss, train_op=train_op, predictions={'query': query_emb, 'item': item_emb})
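
The line `item_emb = item_emb - tf.stop_gradient(item_emb - item_emb_tau)` in the function above is a common way to write a straight-through estimator. Writing $e$ for `item_emb`, $e_\tau$ for the quantized `item_emb_tau`, and $\operatorname{sg}(\cdot)$ for `tf.stop_gradient` (the identity in the forward pass, with its gradient treated as zero), the generic identity can be sketched as follows; the paper remains the reference for how *poeem* uses it.

$$
\tilde{e} = e - \operatorname{sg}(e - e_\tau)
\quad\Longrightarrow\quad
\underbrace{\tilde{e} = e_\tau}_{\text{forward value}},
\qquad
\underbrace{\frac{\partial \tilde{e}}{\partial e} = I}_{\text{backward pass}}
$$

So the loss is computed on the quantized embedding, while its gradient is passed to the unquantized embedding as if no quantization had happened. No gradient reaches $e_\tau$ through this expression.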


In [18]:
MODEL_DIR = './poeem_model'
poeem_model = tf.estimator.Estimator(model_fn=poeem_model_fn, model_dir=MODEL_DIR, params={})
poeem_model.train(input_fn=input_fn)

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': './poeem_model', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f5836487eb8>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
Instructions for updating

INFO:tensorflow:loss = 912.9878, step = 4900 (0.446 sec)
INFO:tensorflow:global_step/sec: 218.018
INFO:tensorflow:loss = 836.3109, step = 5000 (0.457 sec)
INFO:tensorflow:global_step/sec: 233.603
INFO:tensorflow:loss = 876.7561, step = 5100 (0.428 sec)
INFO:tensorflow:global_step/sec: 201.823
INFO:tensorflow:loss = 825.44214, step = 5200 (0.497 sec)
INFO:tensorflow:global_step/sec: 193.2
INFO:tensorflow:loss = 905.7927, step = 5300 (0.516 sec)
INFO:tensorflow:global_step/sec: 183.035
INFO:tensorflow:loss = 903.2189, step = 5400 (0.547 sec)
INFO:tensorflow:global_step/sec: 213.489
INFO:tensorflow:loss = 791.0683, step = 5500 (0.470 sec)
INFO:tensorflow:global_step/sec: 222.781
INFO:tensorflow:loss = 810.9672, step = 5600 (0.447 sec)
INFO:tensorflow:global_step/sec: 212.611
INFO:tensorflow:loss = 776.9982, step = 5700 (0.470 sec)
INFO:tensorflow:global_step/sec: 193.615
INFO:tensorflow:loss = 779.54535, step = 5800 (0.516 sec)
INFO:tensorflow:global_step/sec: 224.86
INFO:tensorflow:loss 

INFO:tensorflow:loss = 429.44733, step = 13200 (0.459 sec)
INFO:tensorflow:global_step/sec: 208.816
INFO:tensorflow:loss = 449.2091, step = 13300 (0.479 sec)
INFO:tensorflow:global_step/sec: 214.046
INFO:tensorflow:loss = 438.45218, step = 13400 (0.468 sec)
INFO:tensorflow:global_step/sec: 202.494
INFO:tensorflow:loss = 473.16522, step = 13500 (0.494 sec)
INFO:tensorflow:global_step/sec: 200.667
INFO:tensorflow:loss = 454.15295, step = 13600 (0.498 sec)
INFO:tensorflow:global_step/sec: 200.06
INFO:tensorflow:loss = 447.38815, step = 13700 (0.501 sec)
INFO:tensorflow:global_step/sec: 200.387
INFO:tensorflow:loss = 451.46802, step = 13800 (0.498 sec)
INFO:tensorflow:global_step/sec: 212.333
INFO:tensorflow:loss = 415.02158, step = 13900 (0.471 sec)
INFO:tensorflow:global_step/sec: 217.942
INFO:tensorflow:loss = 477.11353, step = 14000 (0.460 sec)
INFO:tensorflow:global_step/sec: 234.041
INFO:tensorflow:loss = 442.51498, step = 14100 (0.426 sec)
INFO:tensorflow:global_step/sec: 236.32
INF

INFO:tensorflow:global_step/sec: 183.42
INFO:tensorflow:loss = 445.52228, step = 21500 (0.545 sec)
INFO:tensorflow:global_step/sec: 187.75
INFO:tensorflow:loss = 469.8447, step = 21600 (0.533 sec)
INFO:tensorflow:global_step/sec: 205.683
INFO:tensorflow:loss = 507.37323, step = 21700 (0.486 sec)
INFO:tensorflow:global_step/sec: 204.526
INFO:tensorflow:loss = 490.19412, step = 21800 (0.489 sec)
INFO:tensorflow:global_step/sec: 216.002
INFO:tensorflow:loss = 499.43527, step = 21900 (0.468 sec)
INFO:tensorflow:global_step/sec: 210.865
INFO:tensorflow:loss = 525.61163, step = 22000 (0.470 sec)
INFO:tensorflow:global_step/sec: 210.081
INFO:tensorflow:loss = 503.64816, step = 22100 (0.476 sec)
INFO:tensorflow:global_step/sec: 211.904
INFO:tensorflow:loss = 487.39618, step = 22200 (0.472 sec)
INFO:tensorflow:global_step/sec: 209.978
INFO:tensorflow:loss = 484.3354, step = 22300 (0.476 sec)
INFO:tensorflow:global_step/sec: 209.813
INFO:tensorflow:loss = 453.09445, step = 22400 (0.476 sec)
INFO

<tensorflow_estimator.python.estimator.estimator.Estimator at 0x7f5836487ac8>

### Export item embedding and build index

Although *poeem* in theory does not need a separate index building step, we can optionally build an index to persist the embedding index to disk. Since the *poeem* indexing layer has already performed coarse quantization and product quantization internally, building the index only requires exporting those values into an index file, as follows.

In [19]:
def predict_input_fn():
    dataset = tf.data.Dataset.from_tensor_slices({'query': np.arange(VOCAB_SIZE), 'item': np.arange(VOCAB_SIZE)})
    dataset = dataset.batch(VOCAB_SIZE)
    return dataset

poeem_model = tf.estimator.Estimator(model_fn=poeem_model_fn, model_dir=MODEL_DIR, 
                                     params={'item_predict': True})
results = list(poeem_model.predict(input_fn=predict_input_fn))

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': './poeem_model', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f5836a4a710>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done call

Collect all the data we need to write into an index file.

In [20]:
item_coarse_code = np.array([e['item_coarse_code'] for e in results])
item_code = np.array([e['item_code'] for e in results])
item_norm = np.array([e['item_norm'] for e in results])
item_id = np.arange(VOCAB_SIZE)
coarse_codebook = tf.train.load_variable(MODEL_DIR, 'coarse_centroids')
codebook = tf.train.load_variable(MODEL_DIR, 'centroids_k')

INDEX_FILE = './poeem.idx'
poeem.indexing.write_index_file(INDEX_FILE, codebook, item_id, item_norm, item_code, 
                                coarse_codebook, item_coarse_code, use_residual=True)

### Poeem nearest neighbor search

*poeem* ANN search works end to end, i.e., it takes a query as input and directly outputs its nearest neighbor items. The simplest setup is as follows.

In [21]:
poeem_model = tf.estimator.Estimator(model_fn=poeem_model_fn, model_dir=MODEL_DIR, 
                                     params={'item_search': True, 'index_file': INDEX_FILE, 
                                             'topk': 100, 'nprobe': 1, 'metric_type': 0})

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': './poeem_model', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f5836a4a908>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}


To measure *poeem* retrieval accuracy, we first compute the retrieval results for all queries.

In [22]:
def search_input_fn():
    dataset = tf.data.Dataset.from_tensor_slices({'query': np.arange(VOCAB_SIZE)})
    dataset = dataset.batch(VOCAB_SIZE)
    return dataset

results = list(poeem_model.predict(input_fn=search_input_fn))
neighbors = np.array([e['neighbors'] for e in results])
scores = np.array([e['scores'] for e in results])

INFO:tensorflow:Calling model_fn.


INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from ./poeem_model/model.ckpt-23439
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.


We follow the same search function interface as the brute force and Faiss searches above, so we can reuse the *precision_at_k* utility function.

In [23]:
def poeem_search(query_id, items, k=100):
    return neighbors[query_id, :k]

In [24]:
precision_at_100 = [precision_at_k(poeem_search, i, None, k=100) for i in range(VOCAB_SIZE)]
print("overall precision@100 = ", np.mean(precision_at_100))   

overall precision@100 =  0.9804640000000001


**Observation**: In this example, *poeem* reaches higher retrieval accuracy than Faiss (0.980 vs. 0.961 precision@100) by jointly learning the embedding index and the retrieval model.

This is a simple example intended as a quick ramp-up for beginners. For more rigorous experimental results from which to draw conclusions, please check out our SIGIR paper.