In [None]:
#default_exp loader

In [None]:
#hide
from nbdev.showdoc import *

# Loader
> Functions for finding and loading image files and saved embeddings


## File manipulation

In [None]:
#export
from pathlib import Path
from PIL import Image
from tqdm import tqdm

**NB: A lot of this implementation is too specific, especially the slugified filenames being used for dictionary IDs. Should be replaced with a better database implementation.**

In [None]:
#export
def slugify(filepath):
    """Build an ID string of the form `<stem>_<mtime>` for a file.

    `filepath` is a `pathlib.Path`; `<mtime>` is the file's modification time
    with everything from the decimal point onwards dropped.
    """
    mtime_text = str(filepath.stat().st_mtime)
    whole_seconds = mtime_text.split(".")[0]
    return f'{filepath.stem}_{whole_seconds}'

def get_image_files(path):
    """Recursively collect the image files found under `path`.

    Files are selected by extension only, compared case-insensitively so that
    `photo.JPG` is found as well as `photo.jpg`. They are not opened or
    validated here.

    Args:
        path: a `pathlib.Path` directory to search recursively.

    Returns:
        A list of `(filepath, slug)` tuples, where `slug` is `slugify(filepath)`.
    """
    img_extensions = {'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'}
    return [
        (f, slugify(f))
        for f in tqdm(path.rglob('*'))
        # is_file() rejects directories and dangling symlinks whose name merely
        # ends in an image extension; it is checked before slugify() runs.
        if f.suffix.lower() in img_extensions and f.is_file()
    ]

def get_valid_images(path):
    """Return the `(filepath, slug)` tuples under `path` whose file passes `verify_image`."""
    valid_entries = []
    for entry in get_image_files(path):
        # entry is (filepath, slug); only the path is checked.
        if verify_image(entry[0]):
            valid_entries.append(entry)
    return valid_entries

# This is a predicate (it returns a bool); `is_valid_image` would be a more
# accurate name, but `verify_image` is kept so existing callers keep working.
def verify_image(f):
    """Report whether `f` can be opened and verified as an image by PIL.

    Args:
        f: a path (or file object) accepted by `PIL.Image.open`.

    Returns:
        True if `Image.open(f)` and `verify()` both succeed, False otherwise.
        Failures are reported with a printed message instead of being raised.
    """
    try:
        # The context manager closes the underlying file even if verify() raises,
        # so scanning a large folder does not accumulate open file handles.
        with Image.open(f) as img:
            img.verify()
        return True
    except Exception as e:
        print(f'Skipping bad file: {f}\ndue to {type(e)}')
        return False
    

Demonstrating the usage here, not a great test though:

In [None]:
# Demo: scan the local ./images folder (relative to the working directory).
root = Path('./images')


# Each entry is a (path, slug) tuple; files are picked by extension, not validated.
filepaths = get_image_files(root)

# Number of image files found.
len(filepaths)

85it [00:00, 59419.31it/s]


82

In [None]:
# Peek at the first three (path, slug) tuples.
filepaths[:3]

[(PosixPath('images/Wholesome-Meme-3.jpg'), 'Wholesome-Meme-3_1621298477'),
 (PosixPath('images/Wholesome-Meme-44.png'), 'Wholesome-Meme-44_1621298480'),
 (PosixPath('images/Wholesome-Meme-69.jpg'), 'Wholesome-Meme-69_1621298481')]

## Loaders

So we have a list of paths and slugified filenames from the folder. We want to see if there's an archive, so that we don't have to recalculate tensors for images we've seen before. Then we want to pass that directly to the indexer, but send the new images through the crafter and encoder first.



In [None]:
#export
import torch
import torchvision

We want to use the GPU, if possible, for all the PyTorch functions. But if we can't get access to it, we need to fall back to the CPU. Either way we call it `device` and pass it to each function in the executors that uses torch.

In [None]:
#export
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Display the selected device in the notebook.
device

device(type='cuda')

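The `archive_loader` is only called in `indexFlow`. It takes the list of image files and the folder they're in (and the torch device), opens an archive if there is one, and returns two things: the archived entries whose images are still in the folder, and the list of new, valid image files that are not in the archive yet.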

In [None]:
#export
def archive_loader(filepaths, root, device):
    """Split the current image files into already-archived entries and new files.

    Args:
        filepaths: list of `(path, slug)` tuples, as returned by `get_image_files`.
            It is iterated twice, so it must not be a one-shot generator.
        root: `pathlib.Path` of the image folder; the archive is read from
            `root/'memery.pt'`.
        device: torch device, passed through to `db_loader`.

    Returns:
        A tuple `(archive_db, new_files)`:
        - `archive_db` maps an entry's position in the saved archive to the
          entry itself, keeping only entries whose `'slug'` matches a current
          file (so the integer keys may have gaps).
        - `new_files` is a list of `(str(path), slug)` for files that are not
          in `archive_db` and that pass `verify_image`.
    """
    dbpath = root/'memery.pt'
    db = db_loader(dbpath, device)

    # Sets give constant-time membership tests; the comprehensions below would
    # otherwise rescan a list for every archive entry and every file.
    current_slugs = {slug for _, slug in filepaths}
    archive_db = {i: entry for i, entry in enumerate(db.values()) if entry['slug'] in current_slugs}
    archive_slugs = {entry['slug'] for entry in archive_db.values()}

    # verify_image is only reached for files that are not already archived.
    new_files = [
        (str(path), slug)
        for path, slug in filepaths
        if slug not in archive_slugs and verify_image(path)
    ]

    return archive_db, new_files

The `db_loader` takes a location and returns either the archive dictionary or an empty dictionary. Decomposed to its own function so it can be called separately from `archive_loader` or `queryFlow`. 

In [None]:
#export
def db_loader(dbpath, device):
    """Return the saved archive at `dbpath`, or an empty dict if there is none.

    `dbpath` may be a string or a `pathlib.Path`. `device` is handed to
    `torch.load` as its second positional argument (the map location).
    """
    archive_file = Path(dbpath)
    if not archive_file.exists():
        # No savefile yet: start from an empty archive.
        return {}
    # NOTE: torch.load unpickles the file, so only load archives you trust.
    return torch.load(dbpath, device)

The library `annoy`, [Approximate Nearest Neighbors Oh Yeah!](https://github.com/spotify/annoy), allows us to search through vector space for approximate matches instead of exact best-similarity matches. We sacrifice some accuracy for speed: searching through tens of thousands of images takes far less than a thousand times as long as searching through tens of images.

In [None]:
#export
from annoy import AnnoyIndex

In [None]:
#export
def treemap_loader(treepath, dims=512, metric='angular'):
    """Load a saved Annoy index from `treepath`, or return None if it is missing.

    Args:
        treepath: location of the saved index, as a string or `pathlib.Path`.
        dims: vector length the index was built with. Defaults to 512.
        metric: Annoy distance metric the index was built with. Defaults to
            'angular'.

    Returns:
        An `AnnoyIndex` with the file loaded, or None when `treepath` does not exist.
    """
    treepath = Path(treepath)
    if not treepath.exists():
        # Nothing saved yet; skip building an index that would be thrown away.
        return None

    treemap = AnnoyIndex(dims, metric)
    treemap.load(str(treepath))
    return treemap

In [None]:
# Location of the saved Annoy index file for the demo image folder.
treepath = Path('images/memery.ann')

In [None]:
# Empty index created with the same arguments treemap_loader uses by default.
treemap = AnnoyIndex(512, 'angular')

In [None]:
# Same steps as treemap_loader, run by hand: load the saved index if the file
# exists, otherwise leave treemap as None.
if treepath.exists():
    treemap.load(str(treepath))
else:
    treemap = None

Here we just test on the local image folder

In [None]:
# Re-scan the demo folder and split it into archived entries and new files.
archive_db, new_files = archive_loader(get_image_files(root), root, device)

85it [00:00, 53092.46it/s]


Skipping bad file: images/corrupted-file.jpeg
due to <class 'PIL.UnidentifiedImageError'>
Skipping bad file: images/.ipynb_checkpoints/corrupted-file-checkpoint.jpeg
due to <class 'PIL.UnidentifiedImageError'>


In [None]:
# Archived entries, new files, and items in the Annoy index.
# NOTE: treemap is None when images/memery.ann is missing, and this line then raises.
len(archive_db), len(new_files), treemap.get_n_items()

(80, 0, 80)

In [None]:
# Step through the body of archive_loader by hand, starting with the archive itself.
dbpath = root/'memery.pt'
#     dbpath_backup = root/'memery.pt'
db = db_loader(dbpath, device)

# Slugs of the files found by the earlier get_image_files(root) call.
current_slugs = [slug for path, slug in filepaths]    

In [None]:
# Keep only archive entries whose slug matches a current file; each key is the
# entry's position in db's iteration order.
archive_db = {i:db[item[0]] for i, item in enumerate(db.items()) if item[1]['slug'] in current_slugs}  

In [None]:
# Number of archive entries that match a file currently in the folder.
len(archive_db)

80