In [None]:
#default_exp ranker

In [None]:
#hide
from nbdev.showdoc import *

# Ranker

Takes an encoded query and an index and returns the index items ranked by similarity to the query (nearest neighbours first). Ideally this is just a single Annoy `get_nns_by_vector` call; in the simplest case it could instead be a similarity score computed across all the vectors.

In [None]:
import torch


from pathlib import Path

from memery.loader import treemap_loader, db_loader
from memery.encoder import text_encoder

In [None]:
# Load the prebuilt nearest-neighbour index that `ranker` searches.
treemap = treemap_loader(Path('images') / 'memery.ann')

In [None]:
# Number of items stored in the index (ranker requests this many neighbours by default).
treemap.get_n_items()

80

In [None]:
#export
def ranker(query_vec, treemap, n=None):
    """Return the ids of the items in `treemap` nearest to the query.

    Args:
        query_vec: batched query embedding; only the first row
            (`query_vec[0]`) is used as the search vector. Presumably the
            (1, dim) output of `text_encoder` -- TODO confirm.
        treemap: nearest-neighbour index exposing Annoy's
            `get_nns_by_vector(vector, n)` and `get_n_items()`.
        n: number of neighbours to return. None (the default) means every
            item in the index, which is the original behaviour.

    Returns:
        The list of item ids returned by `treemap.get_nns_by_vector`.
    """
    if n is None:
        # Rank the whole index unless the caller asked for a top-n.
        n = treemap.get_n_items()
    return treemap.get_nns_by_vector(query_vec[0], n)

In [None]:
#export
def nns_to_files(db, indexes):
    """Map neighbour item ids to the file paths stored in `db`.

    Args:
        db: dict whose values are records with at least the keys 'index'
            (the item id used in the nearest-neighbour index) and 'fpath'.
        indexes: iterable of item ids, e.g. the output of `ranker`.

    Returns:
        List of file paths, in the same order as `indexes`.

    Raises:
        IndexError: if an id in `indexes` has no matching record in `db`.
    """
    # Build the id -> path lookup once (O(len(db))) instead of rescanning the
    # whole db for every requested id (O(len(db) * len(indexes))).
    # setdefault keeps the FIRST record per id, i.e. "first match wins".
    fpath_by_index = {}
    for record in db.values():
        fpath_by_index.setdefault(record['index'], record['fpath'])
    try:
        return [fpath_by_index[ind] for ind in indexes]
    except KeyError as err:
        # Keep raising IndexError, as the previous `[...][0]` lookup did.
        raise IndexError(f"no db record with index {err.args[0]!r}") from err

In [None]:
# Use the GPU when available, otherwise the CPU, and load the image database onto it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
db = db_loader(Path('images') / 'memery.pt', device)

In [None]:
query = 'dog'  # example text query, encoded and ranked in the next cell

In [None]:
# Encode the text query, rank every item of the index against it, and map the
# ranked item ids back to image file paths.
query_vec = text_encoder(query, device)
indexes = ranker(query_vec, treemap)
ranked_files = nns_to_files(db, indexes)

In [None]:
def n2flong(db, indexes):
    """Explicit-loop counterpart of `nns_to_files`, kept for profiling comparison.

    For each id in `indexes`, returns the file path of the first `db` record
    whose 'index' equals that id, preserving the order of `indexes`.

    Raises:
        IndexError: if an id in `indexes` has no matching record in `db`.
    """
    ranked = []
    for ind in indexes:
        for record in db.values():
            if record['index'] == ind:
                ranked.append(record['fpath'])
                # First match only, like nns_to_files; also stops the useless rescan.
                break
        else:
            # No record matched: nns_to_files raises IndexError here too.
            raise IndexError(f"no db record with index {ind!r}")
    return ranked

In [None]:
# NOTE(review): this extension provides %lprun (line-by-line profiling), but the
# cells below only use the built-in %prun (cProfile), so it is not exercised here.
%load_ext line_profiler

In [None]:
# cProfile of the lookup-based mapping. With ~80 items every timing rounds to
# 0.000 s, so this mostly shows call counts; %timeit gives a measurable comparison.
%prun nns_to_files(db, indexes)

 

         162 function calls in 0.000 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       79    0.000    0.000    0.000    0.000 <ipython-input-7-498f8d95be68>:3(<listcomp>)
        1    0.000    0.000    0.000    0.000 {built-in method builtins.exec}
       79    0.000    0.000    0.000    0.000 {method 'items' of 'dict' objects}
        1    0.000    0.000    0.000    0.000 <ipython-input-7-498f8d95be68>:2(nns_to_files)
        1    0.000    0.000    0.000    0.000 <string>:1(<module>)
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}

In [None]:
# Same profile for the explicit-loop version, for comparison with nns_to_files.
%prun n2flong(db, indexes)

 

         162 function calls in 0.000 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.000    0.000 <ipython-input-11-7425d3e4632d>:1(n2flong)
        1    0.000    0.000    0.000    0.000 {built-in method builtins.exec}
       79    0.000    0.000    0.000    0.000 {method 'append' of 'list' objects}
       79    0.000    0.000    0.000    0.000 {method 'items' of 'dict' objects}
        1    0.000    0.000    0.000    0.000 <string>:1(<module>)
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}

In [None]:

def sort_db_by_nns(db, nn_indexes):
    """Return the file paths of `db` records ordered by their rank in `nn_indexes`.

    Sort-based alternative to `nns_to_files`: instead of looking up each id,
    every record is ranked by where its own 'index' appears in `nn_indexes`.
    Records whose id does not appear in `nn_indexes` are left out.

    Assumes ids are unique in both `db` and `nn_indexes` (the case for a
    neighbour list); then the result equals `nns_to_files(db, nn_indexes)`
    whenever every id has a record.

    Returns:
        List of file paths, in `nn_indexes` order.
    """
    # Position of each neighbour id in the ranking; the first occurrence wins.
    rank_of = {}
    for position, ind in enumerate(nn_indexes):
        rank_of.setdefault(ind, position)
    # Rank each record by its OWN 'index' field. Zipping db order against
    # nn_indexes positionally would pair paths with unrelated neighbour ids.
    hits = [v for v in db.values() if v['index'] in rank_of]
    hits.sort(key=lambda v: rank_of[v['index']])
    return [v['fpath'] for v in hits]

In [None]:
# Profile the sort-based alternative for comparison with the lookup versions.
%prun sort_db_by_nns(db, indexes)

 

         86 function calls in 0.000 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.000    0.000 {built-in method builtins.exec}
        1    0.000    0.000    0.000    0.000 {built-in method builtins.sorted}
        1    0.000    0.000    0.000    0.000 <ipython-input-20-0bb824d6058d>:2(<listcomp>)
       79    0.000    0.000    0.000    0.000 <ipython-input-20-0bb824d6058d>:4(<lambda>)
        1    0.000    0.000    0.000    0.000 <ipython-input-20-0bb824d6058d>:1(sort_db_by_nns)
        1    0.000    0.000    0.000    0.000 <string>:1(<module>)
        1    0.000    0.000    0.000    0.000 {method 'values' of 'dict' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}

In [None]:
# Cross-check: the sort-based version must yield the same ranked file list as
# the lookup-based nns_to_files.
assert nns_to_files(db, indexes) == sort_db_by_nns(db, indexes)

AssertionError: 

In [None]:
# Inspect the sort-based result.
sort_db_by_nns(db, indexes)

[('images/mexican-food-concept-EXFWKZG.jpg', 0),
 ('images/Wholesome-Meme-84.jpg', 1),
 ('images/Wholesome-Meme-29.jpg', 2),
 ('images/Wholesome-Meme-64.jpg', 3),
 ('images/Wholesome-Meme-61.png', 4),
 ('images/Wholesome-Meme-72.jpg', 5),
 ('images/Wholesome-Meme-67.png', 6),
 ('images/Wholesome-Meme-40.png', 7),
 ('images/happy-young-couple-eat-breakfast-in-bed-in-morning-RH4KQ72.jpg', 8),
 ('images/Wholesome-Meme-69.jpg', 9),
 ('images/Wholesome-Meme-15.jpg', 10),
 ('images/Wholesome-Meme-97.jpg', 11),
 ('images/Wholesome-Meme-82.jpg', 12),
 ('images/Wholesome-Meme-98.jpg', 13),
 ('images/Wholesome-Meme-39.jpg', 14),
 ('images/Wholesome-Meme-44.png', 15),
 ('images/Wholesome-Meme-6.jpg', 16),
 ('images/Wholesome-Meme-89.jpg', 17),
 ('images/Wholesome-Meme-36.jpg', 18),
 ('images/Wholesome-Meme-86.jpg', 19),
 ('images/Wholesome-Meme-7.jpg', 20),
 ('images/Wholesome-Meme-33.jpg', 21),
 ('images/Wholesome-Meme-16.jpg', 22),
 ('images/Wholesome-Meme-60.jpg', 23),
 ('images/Wholesome-Meme

In [None]:
# Inspect the lookup-based result for comparison.
nns_to_files(db, indexes)

['images/Wholesome-Meme-8.jpg',
 'images/Wholesome-Meme-5.jpg',
 'images/Wholesome-Meme-35.jpg',
 'images/Wholesome-Meme-67.png',
 'images/embarassed-dog-on-bed-SA2BDZW.jpg',
 'images/Wholesome-Meme-72.jpg',
 'images/braydon-anderson-wOHH-NUTvVc-unsplash-min.jpg',
 'images/cute-dog-with-cupcake-P9E2YL5-min.jpg',
 'images/Wholesome-Meme-3.jpg',
 'images/Wholesome-Meme-18.jpg',
 'images/Wholesome-Meme-29.jpg',
 'images/Wholesome-Meme-13.jpg',
 'images/Wholesome-Meme-14.jpg',
 'images/Wholesome-Meme-71.jpg',
 'images/Wholesome-Meme-44.png',
 'images/Wholesome-Meme-63.jpg',
 'images/Wholesome-Meme-15.jpg',
 'images/Wholesome-Meme-17.jpg',
 'images/Wholesome-Meme-98.jpg',
 'images/stonks-meme.jpg',
 'images/Wholesome-Meme-39.jpg',
 'images/Wholesome-Meme-68.jpg',
 'images/Wholesome-Meme-45.jpg',
 'images/Wholesome-Meme-36.jpg',
 'images/Wholesome-Meme-22.jpg',
 'images/Wholesome-Meme-77.jpg',
 'images/Wholesome-Meme-61.png',
 'images/Wholesome-Meme-25.jpg',
 'images/Wholesome-Meme-84.jpg',


In [None]:
    # (file path, db index) pair for every record, in db order
    slugs = [(entry['fpath'], entry['index']) for entry in db.values()]

In [None]:
# (fpath, index) pairs in db order
slugs

[('images/Wholesome-Meme-3.jpg', 0),
 ('images/Wholesome-Meme-44.png', 1),
 ('images/Wholesome-Meme-69.jpg', 2),
 ('images/Wholesome-Meme-59.jpg', 3),
 ('images/Wholesome-Meme-68.jpg', 4),
 ('images/Wholesome-Meme-72.jpg', 5),
 ('images/Wholesome-Meme-57.jpg', 6),
 ('images/Wholesome-Meme-74.jpg', 7),
 ('images/mexican-food-concept-EXFWKZG.jpg', 8),
 ('images/Wholesome-Meme-35.jpg', 9),
 ('images/Envato-Elements.png', 10),
 ('images/Wholesome-Meme-89.jpg', 11),
 ('images/Wholesome-Meme-80.jpg', 12),
 ('images/Wholesome-Meme-45.jpg', 13),
 ('images/Wholesome-Meme-84.jpg', 14),
 ('images/Wholesome-Meme-5.jpg', 15),
 ('images/Wholesome-Meme-1.jpg', 16),
 ('images/Wholesome-Meme-13.jpg', 17),
 ('images/Wholesome-Meme-16.jpg', 18),
 ('images/Wholesome-Meme-23.jpg', 19),
 ('images/Wholesome-Meme-77.jpg', 20),
 ('images/Wholesome-Meme-61.png', 21),
 ('images/Wholesome-Meme-98.jpg', 22),
 ('images/Wholesome-Meme-73.png', 23),
 ('images/Wholesome-Meme-9.jpg', 24),
 ('images/Wholesome-Meme-7.jpg

In [None]:
# First-match lookup of each neighbour id over the (slug, index) pairs -- the same
# logic as nns_to_files; note it rescans all of `slugs` for every id.
[[slug for slug, index in slugs if index == i][0] for i in indexes]

['images/Wholesome-Meme-8.jpg',
 'images/Wholesome-Meme-5.jpg',
 'images/Wholesome-Meme-35.jpg',
 'images/Wholesome-Meme-67.png',
 'images/embarassed-dog-on-bed-SA2BDZW.jpg',
 'images/Wholesome-Meme-72.jpg',
 'images/braydon-anderson-wOHH-NUTvVc-unsplash-min.jpg',
 'images/cute-dog-with-cupcake-P9E2YL5-min.jpg',
 'images/Wholesome-Meme-3.jpg',
 'images/Wholesome-Meme-18.jpg',
 'images/Wholesome-Meme-29.jpg',
 'images/Wholesome-Meme-13.jpg',
 'images/Wholesome-Meme-14.jpg',
 'images/Wholesome-Meme-71.jpg',
 'images/Wholesome-Meme-44.png',
 'images/Wholesome-Meme-63.jpg',
 'images/Wholesome-Meme-15.jpg',
 'images/Wholesome-Meme-17.jpg',
 'images/Wholesome-Meme-98.jpg',
 'images/stonks-meme.jpg',
 'images/Wholesome-Meme-39.jpg',
 'images/Wholesome-Meme-68.jpg',
 'images/Wholesome-Meme-45.jpg',
 'images/Wholesome-Meme-36.jpg',
 'images/Wholesome-Meme-22.jpg',
 'images/Wholesome-Meme-77.jpg',
 'images/Wholesome-Meme-61.png',
 'images/Wholesome-Meme-25.jpg',
 'images/Wholesome-Meme-84.jpg',


In [None]:
# Print each neighbour id followed by every (slug, index) pair that matches it;
# makes ids with zero or multiple matching records easy to spot.
for i in indexes:
    print(i)
    for slug, index in slugs:
        if index == i:
            print(slug, index)

48
images/Wholesome-Meme-8.jpg 48
15
images/Wholesome-Meme-5.jpg 15
9
images/Wholesome-Meme-35.jpg 9
35
images/Wholesome-Meme-67.png 35
71
images/embarassed-dog-on-bed-SA2BDZW.jpg 71
5
images/Wholesome-Meme-72.jpg 5
27
images/braydon-anderson-wOHH-NUTvVc-unsplash-min.jpg 27
60
images/cute-dog-with-cupcake-P9E2YL5-min.jpg 60
0
images/Wholesome-Meme-3.jpg 0
72
images/Wholesome-Meme-18.jpg 72
43
images/Wholesome-Meme-29.jpg 43
17
images/Wholesome-Meme-13.jpg 17
41
images/Wholesome-Meme-14.jpg 41
70
images/Wholesome-Meme-71.jpg 70
1
images/Wholesome-Meme-44.png 1
76
images/Wholesome-Meme-63.jpg 76
77
images/Wholesome-Meme-15.jpg 77
53
images/Wholesome-Meme-17.jpg 53
22
images/Wholesome-Meme-98.jpg 22
47
images/stonks-meme.jpg 47
28
images/Wholesome-Meme-39.jpg 28
4
images/Wholesome-Meme-68.jpg 4
13
images/Wholesome-Meme-45.jpg 13
54
images/Wholesome-Meme-36.jpg 54
61
images/Wholesome-Meme-22.jpg 61
20
images/Wholesome-Meme-77.jpg 20
21
images/Wholesome-Meme-61.png 21
58
images/Wholesome-Me