# Testing the embedding processing methods
find_similar_qdrant, find_representative_kmedoids, leverage_OOD, find_mismatches_centroids

In [None]:
import cv2
import matplotlib.pyplot as plt
import numpy as np

from luxonis_ml.embeddings.methods.duplicate import find_similar_qdrant
from luxonis_ml.embeddings.methods.mistakes import find_mismatches_centroids
from luxonis_ml.embeddings.methods.OOD import leverage_OOD
from luxonis_ml.embeddings.methods.representative import (
    calculate_similarity_matrix,
    find_representative_kmedoids,
)
from luxonis_ml.embeddings.utils.qdrant import Distance, QdrantAPI, QdrantManager

In [None]:
# Bring up the Qdrant docker container before anything tries to connect to it.
qdrant_manager = QdrantManager("qdrant/qdrant", "qdrant_container2")
qdrant_manager.start_docker_qdrant()

# Open the API handle that every later cell in this notebook goes through.
qdrant_api = QdrantAPI("localhost", 6333, "mnist3")

# Create the collection: 2048-dimensional vectors compared by cosine distance.
qdrant_api.create_collection(vector_size=2048, distance=Distance.COSINE)

### Find representative images

In [None]:
# Ask Qdrant for the ids held in the collection.
# NOTE(review): `ids` is rebound by the get_all_embeddings() cell further down,
# so this call is only useful as a quick sanity check.
ids = qdrant_api.get_all_ids()

In [None]:
# ids, embs, res = qdrant_api.get_full_similarity_matrix()

In [None]:
# Fetch the ids together with their embeddings. Later cells use positions
# computed from `embeddings` to index into `ids`, so the two are presumably
# aligned element-by-element — TODO confirm against QdrantAPI.
ids, embeddings = qdrant_api.get_all_embeddings()

In [None]:
# Similarity matrix over all embeddings; it is the input to the k-medoids
# representative selection in the next cell.
similarity_matrix = calculate_similarity_matrix(embeddings)

In [None]:
# Aim for roughly 5% of the collection as representatives. A bare
# int(len(embeddings) * 0.05) truncates to 0 when there are fewer than 20
# embeddings, which would request zero representatives, so clamp the size to
# at least one and never more than the number of embeddings available.
desired_size = min(len(embeddings), max(1, int(len(embeddings) * 0.05)))
# desired_size = 10
selected_image_indices = find_representative_kmedoids(similarity_matrix, desired_size)
# Alternative kept for reference; its names are not imported/defined in the
# visible cells:
# selected_image_indices = find_representative_greedy_qdrant(qdrant_client, desired_size, 0, "mnist3")

In [None]:
# Turn the selected positions into Qdrant ids, then look up their payloads.
id_array = np.array(ids)
ids_sel = id_array[selected_image_indices].tolist()
payloads = qdrant_api.get_payloads_from_ids(ids_sel)

In [None]:
# Collect the "image_path" entry of each representative's payload.
represent_imgs = [p["image_path"] for p in payloads]
# Last expression of the cell: the notebook displays how many were selected.
len(represent_imgs)

In [None]:
# Show up to ten of the representative images side by side in a 1x10 grid.
plt.rcParams["figure.figsize"] = [30, 10]

for j, img_path in enumerate(represent_imgs[:10]):
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports a missing or undecodable file by returning None
        # instead of raising; fail here with a message that names the path.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    plt.subplot(1, 10, j + 1)
    # OpenCV decodes to BGR channel order, while matplotlib expects RGB.
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

plt.show()

### Out-of-distribution detection

In [None]:
# Run the out-of-distribution detector on the embeddings as a NumPy array.
# The result is used in the next cell to index into `ids`.
idx = leverage_OOD(np.array(embeddings))

In [None]:
# Map the outlier positions back to Qdrant ids and fetch their payloads.
id_array = np.array(ids)
ids_sel = id_array[idx].tolist()
payloads = qdrant_api.get_payloads_from_ids(ids_sel)

In [None]:
# Collect the "image_path" entry of each outlier's payload.
outlier_imgs = [p["image_path"] for p in payloads]

In [None]:
# Show up to ten of the detected outlier images side by side in a 1x10 grid.
plt.rcParams["figure.figsize"] = [30, 10]

for j, img_path in enumerate(outlier_imgs[:10]):
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports a missing or undecodable file by returning None
        # instead of raising; fail here with a message that names the path.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    plt.subplot(1, 10, j + 1)
    # OpenCV decodes to BGR channel order, while matplotlib expects RGB.
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

plt.show()

### Find similar images

In [None]:
# Settings for the similarity search, gathered in one place so they are easy
# to tweak between runs.
search_options = {
    "dataset": "",
    "k": 100,
    "n": 100,
    "method": "first",
    "k_method": "kde_peaks",
    "kde_bw": "scott",
    "plot": True,
}

# Use the fifth stored id as the reference for the search.
i_sim, path_sim = find_similar_qdrant(ids[4], qdrant_api, **search_options)

In [None]:
# Show up to ten of the images returned by the similarity search in a 1x10 grid.
plt.rcParams["figure.figsize"] = [30, 10]

for j, img_path in enumerate(path_sim[:10]):
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports a missing or undecodable file by returning None
        # instead of raising; fail here with a message that names the path.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    plt.subplot(1, 10, j + 1)
    # OpenCV decodes to BGR channel order, while matplotlib expects RGB.
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

plt.show()

### Find mismatches

In [None]:
# Re-fetch ids and embeddings, then load the payload for every id. The
# earlier cells rebound `payloads` to subsets, so the full set is needed again.
ids, embeddings = qdrant_api.get_all_embeddings()
# The cells below index `payloads` with positions derived from `embeddings`;
# presumably the payloads come back in the same order as `ids` — TODO confirm.
payloads = qdrant_api.get_payloads_from_ids(ids)

In [None]:
# Feature matrix: the embeddings as a NumPy array.
X = np.array(embeddings)

# Label vector: the "class" entry of each payload, in payload order.
class_labels = [p["class"] for p in payloads]
y = np.array(class_labels)

In [None]:
# Detect samples whose label looks inconsistent with their embedding.
# As used by the following cells: `mis_ix` holds positions into `y`/`payloads`,
# and `new_y` is indexed by position within `mis_ix`.
mis_ix, new_y = find_mismatches_centroids(X, y)

In [None]:
# Look up the image path of every flagged sample; `mis_ix` entries are used
# as positions into `payloads`.
mis_img_paths = [payloads[i]["image_path"] for i in mis_ix]

In [None]:
# Show up to five flagged images, each titled with its stored label and the
# label suggested by the mismatch detection. The 1x10 grid is kept so the
# images are drawn at the same size as in the other plots.
plt.rcParams["figure.figsize"] = [30, 10]

for j, img_path in enumerate(mis_img_paths[:5]):
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports a missing or undecodable file by returning None
        # instead of raising; fail here with a message that names the path.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    plt.subplot(1, 10, j + 1)
    plt.title(f"True: {y[mis_ix[j]]}, Pred: {new_y[j]}")
    # OpenCV decodes to BGR channel order, while matplotlib expects RGB.
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

plt.show()