In [None]:
# Star import from fastai; later cells use names that are never imported
# explicitly (e.g. `pickle`, `np`, `torch`, `PILImage`, `get_image_files`,
# `load_learner`, `set_seed`, `ImageDataLoaders`, `Resize`) and presumably
# rely on this import to provide them.
from fastai.vision.all import *

# Project-local helpers: `get_embedding(learn, embedder, imgs)` computes
# embeddings; `label_func(path)` maps an image path to its class name.
from utils import get_embedding, label_func

# Re-execute the local `search` module so edits to search.py are picked up
# without restarting the kernel. The `from search import ...` line must stay
# AFTER the reload so it binds the refreshed definitions.
from importlib import reload
import search
reload(search)
from search import FaissImageSearch, search_from_path, search_accuracy, plot_results, calculate_rejection_accuracy

import matplotlib.pyplot as plt
# Turn on matplotlib interactive mode.
plt.ion()

In [None]:
# Load the exported fastai learner and wrap it in the project's image searcher.
# load_learner unpickles the file, so only load exports you trust.
model_path = "models/exported_resnext50_32x4d.pickle"
learn = load_learner(model_path)
searcher = FaissImageSearch(learn)

In [None]:
# Enrollment gallery: every image file under the enroll folder, with its
# class name derived from the path by label_func.
enroll_paths = get_image_files("dogs/recognition/enroll/")
enroll_class_names = list(map(label_func, enroll_paths))
enroll_imgs = list(map(PILImage.create, enroll_paths))

In [None]:
# Add all enrollment images (with their class names) to the searcher, then
# dump the searcher's state into the models directory.
models_dir = "models"
searcher.enroll_many(enroll_imgs, enroll_class_names)
searcher.dump(models_dir)

In [None]:
# Held-out query images for the enrolled classes, labelled by label_func.
test_paths = get_image_files("dogs/recognition/test/")
test_class_names = list(map(label_func, test_paths))
test_imgs = list(map(PILImage.create, test_paths))

In [None]:
# Test-set embeddings are expensive to compute, so they are cached on disk.
# The cache is reused whenever it exists; delete the file after changing the
# model or the test images to force recomputation.
# NOTE: pickle.load can execute arbitrary code -- only load caches you created.
test_cache = Path("cache") / "test_embeddings.pickle"
if test_cache.exists():
    with open(test_cache, "rb") as f:
        test_embeddings = pickle.load(f)
else:
    test_embeddings = get_embedding(searcher.learn, searcher.embedder, test_imgs)
    test_cache.parent.mkdir(parents=True, exist_ok=True)
    with open(test_cache, "wb") as f:
        pickle.dump(test_embeddings, f)
# Guard against a stale cache built from a different image set, which would
# silently misalign embeddings with test_class_names.
# (assumes one embedding row per image -- TODO confirm get_embedding's output)
assert len(test_embeddings) == len(test_paths), "stale cache: delete cache/test_embeddings.pickle"

In [None]:
# search_from_path(searcher, test_paths[100])

In [None]:
# Calculate test accuracy for some values of the hyperparameter k
# (k is the number of nearest neighbours used when searching), at a fixed
# distance threshold of 0.78.
ks = [1, 3, 5, 7, 9]
test_accs = []
for k in ks:
    # search_accuracy returns a tuple whose first element is the accuracy; the
    # remaining distance lists are not needed here, so they are not bound to
    # globals.
    acc = search_accuracy(searcher, test_embeddings, test_class_names, k=k, threshold=0.78)[0]
    test_accs.append(acc)
    print(f"k: {k}  test_acc: {acc:.04f}")
plt.plot(ks, test_accs)
plt.xlabel("k")
plt.ylabel("accuracy")
plt.title("Test accuracy as function of k")
plt.show()

## Plot distances

In [None]:
# Result plot for the test set, using plot_results' default k and threshold.
plot_results(
    searcher,
    test_imgs,
    test_embeddings,
    test_class_names,
)

## Find suitable distance threshold

In [None]:
# Superseded by the threshold sweep in "Plot test and rejection accuracy" below;
# kept for reference only. The call was updated to the signature the live
# search_accuracy calls in this notebook use (no test_imgs argument).
# k = 5
# accs = []
# # threshs = np.linspace(0.7, 0.8, num=50, endpoint=False)
# threshs = [0.85]
# for thresh in threshs:
#     acc, dist_all, dist_correct, dist_incorrect, dist_empty = search_accuracy(
#         searcher,
#         test_embeddings,
#         test_class_names,
#         k=k,
#         threshold=thresh
#     )
#     accs.append(acc)
#     print(f"k: {k}  thresh: {thresh:.04f}  acc: {acc:.04f}")
# plt.plot(threshs, accs)
# plt.show()

## Search for unknown breeds

In [None]:
print(f"cuda.is_available: {torch.cuda.is_available()}")
# Seed the RNGs (fastai set_seed) before building the random train/valid split below.
set_seed(1234)

In [None]:
def load_data(path, seed=None):
    """Build fastai image DataLoaders from a folder-per-class directory.

    For each image file, the parent folder's name is used as the label, and
    20% of the items are held out at random as the validation split.

    Args:
        path: root directory containing one sub-folder per class.
        seed: optional seed for the random train/valid split. When given, the
            split no longer depends on whatever else has consumed the global
            RNG. The default ``None`` keeps the original behaviour.

    Returns:
        ImageDataLoaders with items resized to 224 and batch size 64.
    """
    return ImageDataLoaders.from_folder(
        path,
        valid_pct=0.2,
        seed=seed,
        item_tfms=Resize(224),
        bs=64,
    )


path = "dogs/train/"

# The validation items become the "unknown" query set whose embeddings are
# cached to disk later, so the split must be reproducible across restarts.
dls = load_data(path, seed=1234)

In [None]:
# "Unknown" query set: the validation split of dogs/train/.
unknown_paths = dls.valid.items
unknown_class_names = [label_func(p) for p in unknown_paths]
unknown_imgs = [PILImage.create(p) for p in unknown_paths]

# Embeddings are cached on disk. The cache is only valid for this exact split;
# delete the file if the split, images or model change.
# NOTE: pickle.load can execute arbitrary code -- only load caches you created.
unknown_cache = Path("cache") / "unknown_embeddings.pickle"
if unknown_cache.exists():
    with open(unknown_cache, "rb") as f:
        unknown_embeddings = pickle.load(f)
else:
    unknown_embeddings = get_embedding(searcher.learn, searcher.embedder, unknown_imgs)
    unknown_cache.parent.mkdir(parents=True, exist_ok=True)
    with open(unknown_cache, "wb") as f:
        pickle.dump(unknown_embeddings, f)
# Guard against a stale cache built from a different split.
# (assumes one embedding row per image -- TODO confirm get_embedding's output)
assert len(unknown_embeddings) == len(unknown_paths), "stale cache: delete cache/unknown_embeddings.pickle"

In [None]:
# Result plot for the unknown set with 19 neighbours and threshold 0.0.
plot_results(
    searcher,
    unknown_imgs,
    unknown_embeddings,
    unknown_class_names,
    threshold=0.0,
    k=19,
)

In [None]:
# Query with the first unknown image; display the result next to its path.
first_unknown = unknown_paths[0]
search_from_path(searcher, first_unknown), first_unknown

In [None]:
# Rejection accuracy on the unknown set for the same k values, at the fixed
# distance threshold 0.78.
ks = [1, 3, 5, 7, 9]
rejection_accs = []
for k in ks:
    rejection_accs.append(
        calculate_rejection_accuracy(searcher, unknown_embeddings, k=k, threshold=0.78)
    )
    print(f"k: {k}  rejection_acc: {rejection_accs[-1]:.04f}")
plt.plot(ks, rejection_accs)
plt.gca().set(xlabel="k", ylabel="accuracy", title="Rejection accuracy as function of k")
plt.show()

## Plot test and rejection accuracy

In [None]:
# Sweep the distance threshold at fixed k and record both metrics: test
# accuracy on the test set and rejection accuracy on the unknown set.
# Note: this overwrites the per-k `rejection_accs` / `test_accs` lists from
# the earlier cells with per-threshold values.
k = 5
thresholds = np.linspace(0.75, 0.85, 11)  # 0.75, 0.76, ..., 0.85
rejection_accs = [
    calculate_rejection_accuracy(searcher, unknown_embeddings, k=k, threshold=threshold)
    for threshold in thresholds
]
test_accs = [
    search_accuracy(searcher, test_embeddings, test_class_names, k=k, threshold=threshold)[0]
    for threshold in thresholds
]

In [None]:
# Both accuracies as functions of the distance threshold, from the sweep above.
plt.plot(thresholds, rejection_accs, 'r', label='rejection accuracy')
plt.plot(thresholds, test_accs, 'b', label='test accuracy')
plt.xlabel("distance threshold")
plt.ylabel("accuracy")
plt.title("Test and rejection accuracy vs distance threshold")
plt.legend()
plt.show()

In [None]:
# Combined score: product of rejection accuracy and test accuracy at each
# threshold. Its argmax is reported and marked as the balancing threshold.
combined_accs = np.array(rejection_accs) * np.array(test_accs)
best_threshold = thresholds[np.argmax(combined_accs)]
print(f"best threshold: {best_threshold:.2f}  combined: {combined_accs.max():.04f}")
plt.plot(thresholds, combined_accs)
plt.axvline(best_threshold, color="gray", linestyle="--")
plt.xlabel("distance threshold")
plt.ylabel("rejection accuracy * test accuracy")
plt.title("Combined accuracy vs distance threshold")
plt.show()