# 準備

In [None]:
# 必要なライブラリのインストール
! pip install pixivpy --upgrade
! pip install scann

In [None]:
# データベース・索引のダウンロード
! wget https://github.com/kosuke1701/illust-search/releases/download/0.0/vectors.sql
! wget https://github.com/kosuke1701/illust-search/releases/download/0.0/scann_save_dir.zip
! unzip scann_save_dir.zip

In [None]:
from getpass import getpass
import io
import sqlite3
import time

from IPython.display import display
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from pixivpy3 import PixivAPI
import scann
import scipy as sp

In [None]:
# Following codes which define `array` type in sqlite3 is copied from the following Stack Overflow:
# https://stackoverflow.com/questions/18621513
# question by:
# Joe Flip (https://stackoverflow.com/users/1715453/joe-flip)
# answered by:
# unutbu (https://stackoverflow.com/users/190597/unutbu)
def adapt_array(arr):
    """Serialize a NumPy array into an SQLite BLOB using the ``.npy`` format.

    http://stackoverflow.com/a/31312102/190597 (SoulNibbler)
    """
    buffer = io.BytesIO()
    np.save(buffer, arr)
    return sqlite3.Binary(buffer.getvalue())


def convert_array(text):
    """Deserialize ``.npy`` bytes read from SQLite back into a NumPy array."""
    return np.load(io.BytesIO(text))


# Store np.ndarray query parameters as BLOBs (via adapt_array).
sqlite3.register_adapter(np.ndarray, adapt_array)
# Decode columns declared with type "array" back into np.ndarray when reading;
# this only applies to connections opened with ``detect_types`` (e.g.
# sqlite3.PARSE_DECLTYPES).
sqlite3.register_converter("array", convert_array)

In [None]:
# Load every face record (ordered by id) into a structured NumPy array `data`.
# NOTE(review): `SELECT *` assumes the columns of table `face` are, in order,
# id, face, xmin, xmax, ymin, ymax, vector — confirm against vectors.sql.
from contextlib import closing

dim = 500
data_dtype = [("id", int), ("face", int), ("xmin", int), ("xmax", int),
              ("ymin", int), ("ymax", int), ("vector", np.float32, dim)]

# A bare `with sqlite3.connect(...)` only commits/rolls back the transaction;
# it never closes the connection. `closing` releases the database file.
# PARSE_DECLTYPES makes the registered "array" converter decode the vectors.
with closing(sqlite3.connect("vectors.sql",
                             detect_types=sqlite3.PARSE_DECLTYPES)) as conn:
    c = conn.cursor()

    c.execute("SELECT * FROM face ORDER BY id")
    _data = c.fetchall()
    data = np.array(_data, dtype=data_dtype)

# `data` holds its own copy of every row; drop the intermediate list of
# tuples so the vectors are not kept in memory twice.
del _data


In [None]:
# Load the prebuilt ScaNN index. "scann_save_dir" is presumably the directory
# unpacked from scann_save_dir.zip in the download cell — verify.
# NOTE(review): the code below indexes `data` with the neighbour ids returned
# by this searcher, i.e. it assumes index row i corresponds to data[i] (rows
# ordered by id) — confirm.
searcher = scann.scann_ops_pybind.load_searcher("scann_save_dir")

In [None]:
# Raises AssertionError when the illustration is unavailable (e.g. it has
# already been deleted); callers catch it to skip such images.
def download(id, username, password, api=None):
    """Download pixiv illustration ``id`` into the current directory.

    Args:
        id: pixiv illustration ID. (The name shadows the builtin but is kept so
            existing keyword callers keep working.)
        username: pixiv user name, used only when ``api`` is None.
        password: pixiv password, used only when ``api`` is None.
        api: optional ``PixivAPI`` client that is already logged in. Passing
            one avoids a fresh login for every image; by default a new client
            is created and logged in, as before.

    Returns:
        File name ``"{id}.{ext}"`` the image was saved under, in ".".

    Raises:
        AssertionError: if ``works`` does not report success or returns no
            work (e.g. the illustration has already been deleted). Raised
            explicitly rather than via ``assert``, which ``python -O`` strips,
            because callers rely on catching it.
    """
    # NOTE(review): recent pixivpy releases may no longer provide PixivAPI or
    # password login — confirm against the installed version.
    if api is None:
        api = PixivAPI()
        api.login(username, password)

    result = api.works(id)
    if result["status"] != "success":
        raise AssertionError(
            "PixivAPI.works({}) returns failure status.".format(id))
    if not result.response:
        # Previously an unhandled IndexError; treat it like a missing work.
        raise AssertionError("PixivAPI.works({}) returned no work.".format(id))
    illust = result.response[0]

    url = illust["image_urls"]["large"]
    file_type = url.split(".")[-1]  # extension of the "large" image URL

    # Download the illustration.
    fn = f"{id}.{file_type}"
    api.download(url, path=".", name=fn)

    return fn

In [None]:
# Smoke test: show one face from the database, then up to 5 *other* faces that
# the ScaNN index returns as its nearest neighbours.
username = getpass("Username:")
password = getpass("Password:")


def _show_face(i_data):
    """Download the illustration of row ``i_data`` and display its face crop.

    Raises AssertionError (from ``download``) if the illustration is
    unavailable.
    """
    fn = download(data["id"][i_data], username, password)
    with Image.open(fn) as im:
        # crop() loads the pixels, so the file can be closed right after.
        face = im.crop((data["xmin"][i_data], data["ymin"][i_data],
                        data["xmax"][i_data], data["ymax"][i_data]))
    display(face.resize((128, 128)))


target_id = 400
_show_face(target_id)

print("Display similar 5 faces.")
neighbor_ids, _ = searcher.search(data["vector"][target_id],
                                  final_num_neighbors=100)
counter = 0
for nei_id in neighbor_ids:
    # The query vector is itself in the index, so its own row comes back as a
    # neighbour; skip it so only different faces are shown.
    if nei_id == target_id:
        continue
    try:
        _show_face(nei_id)
    except AssertionError:
        i_illust = data["id"][nei_id]
        print(f"No image found: {nei_id}\t{i_illust}")
        continue
    counter += 1
    if counter == 5:
        break

In [None]:
# NOTE: the original author reported that this approach did not work well.
class CharacterRetriever(object):
    """Interactive face search driven by relevance feedback.

    Random query vectors are sampled, guided by feedback vectors of the form
    (chosen - rejected), and the nearest faces are looked up in the ScaNN
    index and displayed.
    """

    def __init__(self, data, searcher, dim=500):
        """
        Args:
            data: structured array with "id", "xmin", "xmax", "ymin", "ymax"
                and "vector" fields.
            searcher: ScaNN searcher; its neighbour ids are used as row
                indices into ``data`` (assumed to be built over
                ``data["vector"]`` in the same order — confirm).
            dim: dimensionality of the random query vectors.
        """
        # Previously hard-coded to 500, silently ignoring ``dim``.
        self.dim = dim

        self.data = data
        self.searcher = searcher

        # Prompted once; reused for every download.
        self.username = getpass("Username:")
        self.password = getpass("Password:")

        # List of (positive - negative) difference vectors.
        self.feedback_vectors = []

    def display(self, i_data):
        """Download and show the face crop of row ``i_data``; return its vector.

        Raises:
            AssertionError: propagated from ``download`` when the
                illustration is unavailable.
        """
        row = self.data[i_data]
        fn = download(row["id"], self.username, self.password)
        with Image.open(fn) as im:
            # crop() loads the pixels, so the file can be closed right after.
            face = im.crop((row["xmin"], row["ymin"], row["xmax"], row["ymax"]))
        # Module-level `display` (IPython), not this method.
        display(face.resize((128, 128)))

        return self.data["vector"][i_data]

    def query(self, q_vector, n_neighbor=5):
        """Return the row ids of the faces nearest to ``q_vector``.

        NOTE(review): ``n_neighbor`` is ignored; 100 neighbours are always
        requested (presumably so callers have spare candidates when downloads
        fail). Kept as-is for backward compatibility.
        """
        neighbor_ids, _ = self.searcher.search(q_vector, final_num_neighbors=100)
        return neighbor_ids

    def generate_random_query(self):
        """Sample a random query vector guided by the collected feedback.

        Ideally this would sample from {x : x . f_i > 0 for every feedback
        vector f_i}. Doing that exactly is tedious, so it is approximated as
        sum_i w_i * f_i (w_i ~ U[0, 1)) plus a random combination of an
        orthonormal basis of the feedback matrix's null space. The null-space
        part leaves every x . f_i unchanged, but the sum is not guaranteed to
        satisfy all constraints when feedback vectors point in conflicting
        directions.

        Without feedback, returns a sample from U[0, 1)^dim.
        """
        if not self.feedback_vectors:
            return np.random.uniform(size=self.dim)

        # Explicit submodule import: `import scipy` alone does not guarantee
        # that `scipy.linalg` is loaded on older SciPy versions.
        from scipy.linalg import null_space

        f_mat = np.stack(self.feedback_vectors)  # (n_feedback, vec_dim)
        kernel = null_space(f_mat)  # (vec_dim, null_dim)
        if kernel.shape[1] == 0:
            # Feedback spans the whole space. (np.zeros has no `size`
            # keyword; the old call raised TypeError here.)
            kernel_vec = np.zeros(f_mat.shape[1])
        else:
            kernel_vec = kernel @ np.random.uniform(size=kernel.shape[1])

        f_vec = np.random.uniform(size=f_mat.shape[0]) @ f_mat

        return kernel_vec + f_vec


In [None]:
# Interactive relevance-feedback loop. Each round shows up to 5 faces (one per
# random query) and asks for the best one, numbered from 0 in display order.
# Interrupt the kernel to stop.
retriever = CharacterRetriever(data, searcher)

while True:
    vectors = []
    for _ in range(5):
        random_q = retriever.generate_random_query()
        # A random query is not itself a row of the index, so its first
        # neighbour is simply the best match; the former `[1:]` slice threw
        # it away. (Loop variable renamed so it no longer shadows builtin id.)
        for i_data in retriever.query(random_q):
            try:
                vec = retriever.display(i_data)
            except AssertionError:
                # Illustration unavailable; wait 1 s before trying the next
                # candidate (presumably to go easy on the API).
                time.sleep(1)
                continue
            vectors.append(vec)
            break

    if not vectors:
        print("No images could be displayed; retrying.")
        continue

    while True:
        try:
            i_best = int(input("Please select best image: "))
        except ValueError:
            # Retry only on malformed numbers; EOFError (no stdin) now
            # propagates instead of looping forever.
            continue
        if 0 <= i_best < len(vectors):
            break
        print(f"Please enter a number between 0 and {len(vectors) - 1}.")

    best = vectors[i_best]
    retriever.feedback_vectors.extend(
        best - vec for i_vec, vec in enumerate(vectors) if i_vec != i_best
    )
        