In [None]:
from fastai.vision.all import *
from enroll import NaiveImageSearch, FaissImageSearch
from utils import label_func, get_embedding

In [None]:
# Load the exported learner from disk; it is handed to the image searcher below.
# (The file name suggests a ResNeXt-50 32x4d backbone — confirm against the training run.)
learn = load_learner("dogs/train/exported_resnext50_32x4d.pickle")

In [None]:
# Pick the search backend wrapped around the learner. FaissImageSearch is the
# active choice; swap in the commented-out NaiveImageSearch line to use the
# other implementation imported from `enroll`.
# searcher = NaiveImageSearch(learn)
searcher = FaissImageSearch(learn)

In [None]:
# Collect the paths of the images that will be enrolled into the searcher.
enroll_paths = get_image_files("dogs/recognition/enroll/")

In [None]:
# One class label per enrollment image, derived from its path by label_func.
enroll_class_names = list(map(label_func, enroll_paths))

In [None]:
# Open every enrollment image up front so enroll_many receives image objects.
enroll_imgs = list(map(PILImage.create, enroll_paths))

In [None]:
# Enroll the first 100 images together with their labels.
# NOTE(review): re-running this cell presumably enrolls the same images a second
# time — confirm how enroll_many treats duplicates before re-executing.
searcher.enroll_many(enroll_imgs[:100], enroll_class_names[:100])

In [None]:
# Enroll the remaining images (index 100 onward) in a second call.
# NOTE(review): the split at 100 looks like a check that enrollment can be done
# incrementally — confirm the intent.
searcher.enroll_many(enroll_imgs[100:], enroll_class_names[100:])

In [None]:
# Collect the paths of the held-out query images used for evaluation.
test_paths = get_image_files("dogs/recognition/test/")

In [None]:
# Expected label for every test image, derived from its path by label_func.
test_class_names = list(map(label_func, test_paths))

In [None]:
# Open every test image; the list is later passed to search_accuracy / plot_results.
test_imgs = list(map(PILImage.create, test_paths))

In [None]:
def search_faiss(searcher, path, k=3):
    """Query the searcher with the image stored at ``path``.

    Args:
        searcher: Object exposing ``search(img, k=...)``.
        path: Location of the query image, opened with ``PILImage.create``.
        k: Number of results requested from the searcher.

    Returns:
        The three values produced by ``searcher.search`` as a tuple
        ``(ds, ixs, names)`` — presumably distances, enrolled-image indices
        and class names; confirm against the searcher implementation.
    """
    query = PILImage.create(path)
    distances, indices, labels = searcher.search(query, k=k)
    return distances, indices, labels

In [None]:
# Spot-check a single query: the test image at index 200 with the default k=3.
# Requires at least 201 test images, otherwise the index lookup fails.
search_faiss(searcher, test_paths[200])

In [None]:
def search(searcher, path, k=3):
    """Run a query and display it next to the enrolled images it matched.

    Prints the raw ``(ds, ixs, names)`` triples and the query's label (via
    ``label_func``), shows the query image, then prints the name and shows the
    enrolled image for each returned hit.

    Args:
        searcher: Object exposing ``search(img, k=...)`` and an ``imgs``
            sequence that can be indexed with the returned indices.
        path: Location of the query image.
        k: Number of results requested; at most ``k`` hits are displayed.

    Returns:
        None. The function only prints and displays.
    """
    img = PILImage.create(path)
    ds, ixs, names = searcher.search(img, k=k)
    print(list(zip(ds, ixs, names)))
    print(f"query: {label_func(path)}")
    # Reuse the image opened above instead of reading the file a second time.
    show_image(img)
    # Iterate over the hits actually returned (capped at k) rather than assuming
    # exactly k of them, so a short result list cannot raise IndexError.
    for ix, name in zip(ixs[:k], names[:k]):
        print(f"result: {name}")
        show_image(searcher.imgs[ix])

In [None]:
def search_accuracy(imgs, embeddings, class_names, searcher=None):
    """Measure accuracy on a labelled set for k=1 nearest neighbor.

    Each embedding is sent to ``searcher.search_from_vector`` as a batch of one
    and the name of the single returned hit is compared with the expected label.
    A progress counter is printed every 100 queries and the final accuracy is
    printed as ``acc: <value>``.

    Args:
        imgs: The evaluated images. Only ``len(imgs)`` is used, to set the
            number of queries.
        embeddings: Indexable collection holding one embedding per image.
        class_names: Expected label for each image, aligned with ``embeddings``.
        searcher: Object exposing ``search_from_vector(vec, k=...)``. When
            omitted, the notebook-level ``searcher`` is used so that existing
            three-argument calls keep working.

    Returns:
        ``(acc, distances_all, distances_correct, distances_incorrect)``: the
        fraction of queries whose top hit has the expected name, followed by
        NumPy arrays of the first returned distance for all, correctly matched
        and incorrectly matched queries. ``acc`` is ``nan`` when ``imgs`` is
        empty.
    """
    if searcher is None:
        # Fall back to the notebook global that the function used implicitly before.
        searcher = globals()["searcher"]

    total = len(imgs)
    distances_correct = []
    distances_incorrect = []
    distances_all = []
    for i in range(total):
        if i % 100 == 0:
            print(i)  # coarse progress indicator
        # The vector is passed with an extra leading axis, i.e. as a batch of one.
        ds, ixs, names = searcher.search_from_vector(np.expand_dims(embeddings[i], axis=0), k=1)
        bucket = distances_correct if names[0] == class_names[i] else distances_incorrect
        bucket.append(ds[0])
        distances_all.append(ds[0])

    # Accuracy is undefined without any query; report nan instead of dividing by zero.
    acc = len(distances_correct) / total if total else float("nan")
    print(f"acc: {acc}")
    return acc, np.array(distances_all), np.array(distances_correct), np.array(distances_incorrect)

In [None]:
# Test-set embeddings are cached on disk. Load the cache when it exists;
# otherwise compute the embeddings once and write the cache, so the notebook
# also runs on a fresh checkout.
# NOTE: the cache is not invalidated automatically — delete the file after
# changing the test images or the model. Only unpickle files you created
# yourself; pickle can execute arbitrary code while loading.
try:
    with open("test_embeddings.pickle", "rb") as cache_file:
        test_embeddings = pickle.load(cache_file)
except FileNotFoundError:
    test_embeddings = get_embedding(searcher.learn, searcher.embedder, test_imgs)
    with open("test_embeddings.pickle", "wb") as cache_file:
        pickle.dump(test_embeddings, cache_file)

In [None]:
# Evaluate k=1 accuracy on the test set; the three distance arrays are kept
# as notebook variables for later inspection.
acc, distances_all, distances_correct, distances_incorrect = search_accuracy(test_imgs, test_embeddings, test_class_names)

## Plot distances

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
def plot_results(imgs, embeddings, class_names, y_max=1.0):
    """Run ``search_accuracy`` on the given set and plot one distance per query.

    Args:
        imgs: Forwarded unchanged to ``search_accuracy``.
        embeddings: Forwarded unchanged to ``search_accuracy``.
        class_names: Forwarded unchanged to ``search_accuracy``.
        y_max: Upper limit of the y axis. The default keeps the fixed 0–1 range
            used so far; raise it if the plotted distances exceed 1 and are cut off.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Only the per-query distances are plotted; search_accuracy already prints the accuracy.
    _, distances_all, _, _ = search_accuracy(imgs, embeddings, class_names)
    _, ax = plt.subplots()
    ax.plot(distances_all)
    ax.axis([0, len(distances_all), 0, y_max])
    # Label the figure so it can be read without the surrounding cells.
    ax.set_xlabel("query index")
    ax.set_ylabel("top-1 distance")
    ax.set_title("Top-1 match distance per query")
    plt.show()

In [None]:
# Plot the per-query distance for the test set (this re-runs search_accuracy internally).
plot_results(test_imgs, test_embeddings, test_class_names)

## Search for unknown breed

In [None]:
# Build the same inputs as for the test set (paths, labels, images, embeddings),
# this time from the "unknown" directory.
unknown_paths = get_image_files("dogs/recognition/unknown/test")
unknown_class_names = [label_func(p) for p in unknown_paths]
unknown_imgs = [PILImage.create(p) for p in unknown_paths]
unknown_embeddings = get_embedding(searcher.learn, searcher.embedder, unknown_imgs)
# Persist the embeddings; the context manager guarantees the file is flushed
# and closed even if pickling fails part-way.
with open("unknown_embeddings.pickle", "wb") as cache_file:
    pickle.dump(unknown_embeddings, cache_file)
# To reuse the cached file instead of recomputing, replace the get_embedding
# call and the dump above with:
# with open("unknown_embeddings.pickle", "rb") as cache_file:
#     unknown_embeddings = pickle.load(cache_file)

In [None]:
# Plot the per-query distance for the "unknown" images.
# NOTE(review): presumably these breeds are absent from the enrolled set, so the
# printed accuracy is not meaningful and only the distance profile is of
# interest — confirm.
plot_results(unknown_imgs, unknown_embeddings, unknown_class_names)

In [None]:
# Spot-check the raw search output for the first "unknown" image (default k=3).
search_faiss(searcher, unknown_paths[0])