In [None]:
#default_exp indexer

In [None]:
#hide
from nbdev.showdoc import *

# Indexer

Given a dataset of image embeddings (tensors), builds a dictionary archive and an Annoy index (the "treemap"), and saves both to disk.

## Joiner

The indexer `needs` vectors from two sources: new vectors from the Encoder and old (archived) vectors from the Loader. So it must be preceded by the **join_all** component, which merges the two. This ensures that no new data is missed before it reaches the indexer, and that the indexer does not index old data that no longer exists!

In [None]:
#export
def join_all(db, new_files, new_embeddings):
    """Merge newly encoded images into the archive dict `db`.

    Args:
        db: archive dict mapping integer ids to records of the form
            ``{'slug': ..., 'fpath': ..., 'embed': ...}``. It is updated
            **in place**.
        new_files: sequence of ``(path, slug)`` pairs for the new images.
        new_embeddings: indexable sequence where ``new_embeddings[i]`` is the
            embedding for ``new_files[i]``.

    Returns:
        The same `db` object, with one new record added per new file.

    New records get ids counting up from one past the largest existing id.
    Using ``len(db)`` as the starting id would overwrite existing records
    whenever the ids have gaps (e.g. if some entries were removed).
    """
    # Assumes integer ids; an empty db starts numbering at 0.
    start = max(db, default=-1) + 1
    for offset, (path, slug) in enumerate(new_files):
        db[start + offset] = {
            'slug': slug,
            'fpath': path,
            'embed': new_embeddings[offset],
        }
    return db

In [None]:
import torch
from pathlib import Path
from memery.loader import get_image_files, db_loader, archive_loader
from memery.crafter import crafter
from memery.encoder import image_encoder

# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Folder of images to index; the archive files are saved back into it.
root = Path('images')

In [None]:
filepaths = get_image_files(root)

# Split the images into already-archived encodings and files still to encode.
archive_db, new_files = archive_loader(filepaths, root, device)
print(f"Loaded {len(archive_db)} encodings")
print(f"Encoding {len(new_files)} new images")

if len(new_files) > 0:
    crafted_files = crafter(new_files, device)
    new_embeddings = image_encoder(crafted_files, device)
    db = join_all(archive_db, new_files, new_embeddings)
else:
    # Nothing new to encode: join_all would return the archive unchanged,
    # so skip crafting/encoding entirely.
    db = archive_db

  0%|          | 0/1 [00:00<?, ?it/s]

Loaded 0 encodings
Encoding 80 new images


100%|██████████| 1/1 [00:02<00:00,  2.75s/it]


In [None]:
# db = db_loader(root/'memery.pt',device)

In [None]:
# First five archive ids (iterating a dict yields its keys in insertion order).
list(db)[:5]

[0, 1, 2, 3, 4]

In [None]:
# Total number of archive entries (previously archived + newly encoded).
len(db)

80

Building the treemap takes a long time. `annoy` is a CPU-only library, so it does not use the GPU at all.

In [None]:
#export
from annoy import AnnoyIndex

In [None]:
#export
def build_treemap(db, n_trees=5, dims=512):
    """Build an Annoy index (the "treemap") over the embeddings in `db`.

    Args:
        db: dict mapping non-negative integer ids to records that hold an
            ``'embed'`` vector (as produced by `join_all`).
        n_trees: number of trees Annoy builds. More trees give better recall
            at the cost of build time and index size. Defaults to 5.
        dims: length of every embedding vector. Defaults to 512.

    Returns:
        A built `AnnoyIndex` using the 'angular' (cosine-based) metric, with
        each embedding stored under its `db` id.
    """
    treemap = AnnoyIndex(dims, 'angular')
    for item_id, record in db.items():
        vector = record['embed']
        # Hand Annoy a plain list of floats. Otherwise array-likes (torch
        # tensors, numpy arrays) get converted one element at a time, which is
        # slow, especially for tensors that live on the GPU.
        if hasattr(vector, 'tolist'):
            vector = vector.tolist()
        treemap.add_item(item_id, vector)

    treemap.build(n_trees)
    return treemap
    

In [None]:
# Build the Annoy index over every embedding in `db` (the slow step here).
t = build_treemap(db)

In [None]:
# Sanity check: (number of indexed items, number of trees).
t.get_n_items(), t.get_n_trees()

(80, 5)

In [None]:
#export
import torch

In [None]:
#export
def save_archives(root, treemap, db):
    """Write the embedding archive and its Annoy index into `root`.

    Args:
        root: directory (`pathlib.Path` or str) to write into.
        treemap: index object exposing ``save(path: str)``, e.g. the
            `AnnoyIndex` returned by `build_treemap`.
        db: archive dict (as returned by `join_all`), serialized with
            `torch.save`.

    Returns:
        Tuple ``(db_path, tree_path)`` of the written files, as strings.

    The archive is first written to a temporary sibling file and then moved
    over ``memery.pt`` with `Path.replace`. If the save fails (e.g. something
    in `db` cannot be pickled), the previous archive therefore stays intact
    instead of having already been deleted.
    """
    from pathlib import Path

    root = Path(root)

    dbpath = root/'memery.pt'
    tmp_dbpath = dbpath.with_name(dbpath.name + '.tmp')
    try:
        torch.save(db, tmp_dbpath)
        tmp_dbpath.replace(dbpath)
    except BaseException:
        # Don't leave a half-written temporary file behind.
        if tmp_dbpath.exists():
            tmp_dbpath.unlink()
        raise

    # The tree is derived from `db` and can be rebuilt from it, so it is
    # simply removed and saved afresh.
    # NOTE(review): removing first (instead of overwriting in place) presumably
    # avoids writing over a file that a loaded index still has mapped — confirm.
    treepath = root/'memery.ann'
    if treepath.exists():
        treepath.unlink()
    treemap.save(str(treepath))

    return str(dbpath), str(treepath)

In [None]:
# Persist the archive and its index next to the images.
save_archives(root=root, treemap=t, db=db)

('images/memery.pt', 'images/memery.ann')