In [1]:
from torchvision.models import vgg16

In [2]:
from PIL import Image
from torchvision.transforms import ToTensor

In [3]:
import torch
from torch import nn

In [4]:
from pathlib import Path

In [5]:
import os

In [6]:
import faiss

In [7]:
import numpy as np

In [8]:
import PIL

In [9]:
import sys

In [10]:
import time

In [11]:
from loguru import logger

In [12]:
import h5py

In [13]:
model =vgg16(pretrained=True)

In [14]:
feature_map = nn.Sequential(model.features, model.avgpool)

In [15]:
dirname = "/data/DIV2K/DIV2K_valid_HR/"

In [16]:
def resize_img(original_img: Image, input_size, fill_value=0) -> Image:
    """

    :param original_img:
    :param input_size:
    :param fill_value:
    :return:
    """
    img_w = original_img.width
    img_h = original_img.height
    w, h = input_size
    scale_rate = min(w / img_w, h / img_h)
    new_w = int(img_w * scale_rate)
    new_h = int(img_h * scale_rate)

    img = original_img.resize((new_w, new_h), resample=PIL.Image.BICUBIC)
    _img = Image.new("RGB", (w, h), color=fill_value)
    _img.paste(img, ((w - new_w) // 2, (h - new_h) // 2))
    return _img


In [17]:
def data_iter(dirname, feature_map):
    """
    """
    if not isinstance(dirname, Path):
        dirname = Path(dirname)
    for filename in dirname.iterdir():
        try:
            img = Image.open(filename)
        except Exception as e:
            logger.warning(e)
        else:
            img = resize_img(img, (224, 224))
            name = os.fspath(Path(filename))
            with torch.no_grad():
                data_in = ToTensor()(img).unsqueeze(0)
                data_out = feature_map(data_in).reshape(1, -1).numpy()
            yield name, data_out

In [75]:
fr = h5py.File("features.hdf5", "w")

In [76]:
fr.create_dataset("features", (100, d), dtype="f4")

<HDF5 dataset "features": shape (100, 25088), type "<f4">

In [77]:
dt = h5py.special_dtype(vlen=str) 

In [78]:
fr.create_dataset("filenames", (100, ), dtype=dt)

<HDF5 dataset "filenames": shape (100,), type "|O">

In [79]:
feature_dataset = fr['features']

In [80]:
name_list = fr['filenames']

In [81]:
for i, (name, data_out) in enumerate(data_iter(dirname, feature_map)):
    logger.info(f"add {i}th data: {name}")
    feature_dataset[i] = data_out
    name_list[i] = np.string_(name)

2019-07-19 18:06:35.295 | INFO     | __main__:<module>:2 - add 0th data: /data/DIV2K/DIV2K_valid_HR/0836.png
2019-07-19 18:06:35.465 | INFO     | __main__:<module>:2 - add 1th data: /data/DIV2K/DIV2K_valid_HR/0851.png
2019-07-19 18:06:35.642 | INFO     | __main__:<module>:2 - add 2th data: /data/DIV2K/DIV2K_valid_HR/0856.png
2019-07-19 18:06:35.842 | INFO     | __main__:<module>:2 - add 3th data: /data/DIV2K/DIV2K_valid_HR/0809.png
2019-07-19 18:06:36.026 | INFO     | __main__:<module>:2 - add 4th data: /data/DIV2K/DIV2K_valid_HR/0804.png
2019-07-19 18:06:36.209 | INFO     | __main__:<module>:2 - add 5th data: /data/DIV2K/DIV2K_valid_HR/0816.png
2019-07-19 18:06:36.384 | INFO     | __main__:<module>:2 - add 6th data: /data/DIV2K/DIV2K_valid_HR/0805.png
2019-07-19 18:06:36.557 | INFO     | __main__:<module>:2 - add 7th data: /data/DIV2K/DIV2K_valid_HR/0845.png
2019-07-19 18:06:36.736 | INFO     | __main__:<module>:2 - add 8th data: /data/DIV2K/DIV2K_valid_HR/0813.png
2019-07-19 18:06:36

In [82]:
fr.close()

In [83]:
fr = h5py.File("features.hdf5", "r")

In [84]:
fr['features'][:]

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)

In [85]:
fr['filenames'][0]

'/data/DIV2K/DIV2K_valid_HR/0836.png'

In [86]:
fr.close()

In [20]:
filename = "/data/DIV2K/DIV2K_valid_HR/0803.png"

In [21]:
img = Image.open(filename)
img = resize_img(img, (224, 224))

In [22]:
with torch.no_grad():
    data_in = ToTensor()(img).unsqueeze(0)
    out = feature_map(data_in).reshape(1, -1).numpy()

In [23]:
out

array([[0., 0., 0., ..., 0., 0., 0.]], dtype=float32)

In [24]:
d = out.shape[1]

In [25]:
d

25088

In [26]:
index = faiss.IndexFlatL2(d)

In [27]:
index

<faiss.swigfaiss.IndexFlatL2; proxy of <Swig Object of type 'faiss::IndexFlatL2 *' at 0x7fbf46cf4780> >

In [28]:
index.is_trained

True

In [29]:
name_list = []

In [None]:
def build_index(h5df_path):
    """
    """
    index = faiss.IndexFlatL2(d)
    with h5py.File("features.hdf5", "r") as fr:
        index.add(fr['features'][:])
    

In [30]:
start = time.time()
for i, (name, data_out) in enumerate(data_iter(dirname, feature_map)):
    logger.info(f"add {i}th data: {name}")
    name_list.append(name)
    index.add(data_out)
logger.info(f"duration: {time.time() - start}")

2019-07-19 17:45:59.785 | INFO     | __main__:<module>:3 - add 0th data: /data/DIV2K/DIV2K_valid_HR/0836.png
2019-07-19 17:45:59.956 | INFO     | __main__:<module>:3 - add 1th data: /data/DIV2K/DIV2K_valid_HR/0851.png
2019-07-19 17:46:00.132 | INFO     | __main__:<module>:3 - add 2th data: /data/DIV2K/DIV2K_valid_HR/0856.png
2019-07-19 17:46:00.331 | INFO     | __main__:<module>:3 - add 3th data: /data/DIV2K/DIV2K_valid_HR/0809.png
2019-07-19 17:46:00.514 | INFO     | __main__:<module>:3 - add 4th data: /data/DIV2K/DIV2K_valid_HR/0804.png
2019-07-19 17:46:00.697 | INFO     | __main__:<module>:3 - add 5th data: /data/DIV2K/DIV2K_valid_HR/0816.png
2019-07-19 17:46:00.869 | INFO     | __main__:<module>:3 - add 6th data: /data/DIV2K/DIV2K_valid_HR/0805.png
2019-07-19 17:46:01.039 | INFO     | __main__:<module>:3 - add 7th data: /data/DIV2K/DIV2K_valid_HR/0845.png
2019-07-19 17:46:01.218 | INFO     | __main__:<module>:3 - add 8th data: /data/DIV2K/DIV2K_valid_HR/0813.png
2019-07-19 17:46:01

In [31]:
index.search(out, 4)

(array([[   0.   , 9132.658, 9288.479, 9791.4  ]], dtype=float32),
 array([[13, 44, 99, 52]]))

In [32]:
name_list[0]

'/data/DIV2K/DIV2K_valid_HR/0836.png'