In [None]:
from fastai.vision.all import *

from utils import get_embedding, label_func

# Re-import the local `search` module before pulling names from it, so edits
# made to search.py are picked up without restarting the kernel.
from importlib import reload
import search
reload(search)
from search import FaissImageSearch, search_from_path, search_accuracy, plot_results, calculate_rejection_accuracy

In [None]:
# Load the exported learner from disk; it is handed to the searcher below.
learn = load_learner("models/exported_resnext50_32x4d.pickle")

In [None]:
# Build the image searcher around the loaded learner. Every enroll, search and
# evaluation call in this notebook goes through this object.
searcher = FaissImageSearch(learn)

In [None]:
# Collect the paths of all images under the enroll directory.
enroll_paths = get_image_files("dogs/recognition/enroll/")

In [None]:
# One class label per enroll image, in the same order as enroll_paths.
enroll_class_names = list(map(label_func, enroll_paths))

In [None]:
# Open every enroll image up front (all images are held in memory at once).
enroll_imgs = list(map(PILImage.create, enroll_paths))

In [None]:
# Enroll all images together with their class names; the two lists are
# index-aligned because both were derived from enroll_paths.
searcher.enroll_many(enroll_imgs, enroll_class_names)

In [None]:
# Persist the enrolled searcher under "models".
# NOTE(review): the exact files written are decided by FaissImageSearch.dump — see search.py.
searcher.dump("models")

In [None]:
# Collect the paths of all images under the test directory.
test_paths = get_image_files("dogs/recognition/test/")

In [None]:
# Ground-truth label for each test image, in the same order as test_paths.
test_class_names = list(map(label_func, test_paths))

In [None]:
# Open every test image up front (all images are held in memory at once).
test_imgs = list(map(PILImage.create, test_paths))

In [None]:
# Test-set embeddings are cached on disk: load them when the cache file exists,
# otherwise compute them once and write the cache. Delete the cache file to
# force a recompute (e.g. after changing the model or the test images).
test_cache_path = Path("cache/test_embeddings.pickle")
if test_cache_path.exists():
    # pickle.load can execute arbitrary code; only load cache files that this
    # notebook produced itself.
    with test_cache_path.open("rb") as cache_file:
        test_embeddings = pickle.load(cache_file)
else:
    test_embeddings = get_embedding(searcher.learn, searcher.embedder, test_imgs)
    test_cache_path.parent.mkdir(exist_ok=True)
    with test_cache_path.open("wb") as cache_file:
        pickle.dump(test_embeddings, cache_file)

In [None]:
# Run a single example search for one test image.
# The index 100 is hard-coded and raises IndexError if the test set has
# fewer than 101 images.
search_from_path(searcher, test_paths[100])

In [None]:
# Sweep the hyperparameter k (the number of nearest neighbors used when
# searching) and report the test accuracy for each value.
# Only `acc` is reported; the four distance collections are unpacked but not
# used in this cell.
for k in (1, 3, 5, 7, 9):
    acc, dist_all, dist_correct, dist_incorrect, dist_empty = search_accuracy(
        searcher,
        test_imgs,
        test_embeddings,
        test_class_names,
        k=k,
    )
    print(f"k: {k}  acc: {acc:.04f}")

## Plot distances

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# NOTE(review): this import sits mid-notebook; consider moving it to the import cell at the top.
import matplotlib.pyplot as plt

In [None]:
# Plot the search results for the test set using the cached test embeddings.
plot_results(searcher, test_imgs, test_embeddings, test_class_names)

## Find suitable distance threshold

In [None]:
# Sweep the distance threshold at a fixed k and plot test accuracy against it.
k = 5
# 50 evenly spaced candidate thresholds in [0.7, 0.8); the upper bound is excluded.
threshs = np.linspace(0.7, 0.8, num=50, endpoint=False)
accs = []
for thresh in threshs:
    # Only `acc` is used below; the distance collections are unpacked but unused here.
    acc, dist_all, dist_correct, dist_incorrect, dist_empty = search_accuracy(
        searcher,
        test_imgs,
        test_embeddings,
        test_class_names,
        k=k,
        threshold=thresh
    )
    accs.append(acc)
    print(f"k: {k}  thresh: {thresh:.04f}  acc: {acc:.04f}")

# Label the figure so it can be read on its own when the notebook is skimmed.
fig, ax = plt.subplots()
ax.plot(threshs, accs)
ax.set_xlabel("distance threshold")
ax.set_ylabel("accuracy")
ax.set_title(f"Test accuracy vs. distance threshold (k={k})")
plt.show()

## Search for unknown breed

In [None]:
# Breed directory names treated as "unknown". They are compared against each
# image's parent directory name when filtering the enroll paths below.
unknown_breeds = [
    "n02111500-Great_Pyrenees",
    "n02111889-Samoyed",
    "n02113186-Cardigan",
    "n02113978-Mexican_hairless",
    "n02116738-African_hunting_dog",
]

In [None]:
# Re-list the enroll images, keeping only those whose parent directory is not
# one of the unknown breeds.
# NOTE(review): no later cell visible here re-enrolls the searcher with this
# filtered list, so it has no effect on the searches below — confirm whether a
# re-enroll step is missing.
enroll_paths = [
    path
    for path in get_image_files("dogs/recognition/enroll/")
    if path.parent.name not in unknown_breeds
]

In [None]:
# Load the unknown-breed images and their labels.
unknown_paths = get_image_files("dogs/recognition/unknown/test")
unknown_class_names = [label_func(p) for p in unknown_paths]
unknown_imgs = [PILImage.create(p) for p in unknown_paths]

# Embeddings are cached on disk: load them when the cache file exists,
# otherwise compute them once and write the cache. Delete the cache file to
# force a recompute (e.g. after changing the model or the unknown images).
unknown_cache_path = Path("cache/unknown_embeddings.pickle")
if unknown_cache_path.exists():
    # pickle.load can execute arbitrary code; only load cache files that this
    # notebook produced itself.
    with unknown_cache_path.open("rb") as cache_file:
        unknown_embeddings = pickle.load(cache_file)
else:
    unknown_embeddings = get_embedding(searcher.learn, searcher.embedder, unknown_imgs)
    # The cache directory may not exist yet on a fresh checkout.
    unknown_cache_path.parent.mkdir(exist_ok=True)
    with unknown_cache_path.open("wb") as cache_file:
        pickle.dump(unknown_embeddings, cache_file)

In [None]:
# Plot the search results for the unknown-breed images.
# `searcher` is passed first, matching the plot_results call made for the test set.
plot_results(searcher, unknown_imgs, unknown_embeddings, unknown_class_names)

In [None]:
# Run a single example search for the first unknown-breed image.
# Raises IndexError if no unknown images were found.
search_from_path(searcher, unknown_paths[0])

In [None]:
# Evaluate rejection accuracy on the unknown-breed images (see search.py for
# the exact metric).
# NOTE(review): unlike the search_accuracy(...) calls above, no `searcher` is
# passed here — confirm against the function's signature in search.py.
calculate_rejection_accuracy(unknown_imgs, unknown_embeddings, unknown_class_names)