In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to project code take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Obtain the Spark handle from the project helper and render it in the cell
# output (presumably a SparkSession — confirm in plantclef.spark).
from plantclef.spark import get_spark

spark = get_spark()
display(spark)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/02/18 16:51:32 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/02/18 16:51:33 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).
25/02/18 16:51:35 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [3]:
from pathlib import Path

# Name of the embeddings subset; reused later to name the FAISS index file.
subset_folder_name = "subset_top20_train_embeddings"

# The parquet data for the subset lives under the user's home directory.
home = Path.home()
subset_train_path = home.joinpath(
    f"shared/plantclef/data/embeddings/{subset_folder_name}/data"
)

# Spark expects a plain string path, so the Path object is converted first.
df = spark.read.parquet(str(subset_train_path))

                                                                                

In [4]:
# Inspect the loaded embeddings: print the column schema, then show the
# first 5 rows.
df.printSchema()
df.show(5)

root
 |-- image_name: string (nullable = true)
 |-- species_id: integer (nullable = true)
 |-- cls_embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)
 |-- sample_id: integer (nullable = true)



                                                                                

+--------------------+----------+--------------------+---------+
|          image_name|species_id|       cls_embedding|sample_id|
+--------------------+----------+--------------------+---------+
|a7d82ce1a990d21f3...|   1743246|[0.10556483, -0.0...|        0|
|3fa3f94a36e33331b...|   1394624|[0.17451216, 0.40...|        0|
|f4dedf4fadf6c67d1...|   1394624|[0.61694634, 0.40...|        0|
|5d0a8ae90e19c3c1d...|   1359162|[0.2670945, 0.218...|        0|
|b52775178eefe5558...|   1359162|[0.98361766, 0.75...|        0|
+--------------------+----------+--------------------+---------+
only showing top 5 rows



In [4]:
# Collect the Spark DataFrame into a local pandas DataFrame and preview it.
# NOTE(review): toPandas() pulls every row into driver memory; fine for this
# small subset, but confirm before pointing this at a larger dataset.
pandas_df = df.toPandas()
pandas_df.head()

                                                                                

Unnamed: 0,image_name,species_id,cls_embedding,sample_id
0,a7d82ce1a990d21f3b77a6c1c650d42ffda3c40a.jpg,1743246,"[0.10556483, -0.016215105, -0.42823476, 0.6918...",0
1,3fa3f94a36e33331be9989f0a1d251f292153fec.jpg,1394624,"[0.17451216, 0.40098402, -0.5628805, 0.2425829...",0
2,f4dedf4fadf6c67d1df8e5bae2c6c2352909bcec.jpg,1394624,"[0.61694634, 0.40478063, 1.1805801, -0.2453794...",0
3,5d0a8ae90e19c3c1d94a96d67279700980506e5e.jpg,1359162,"[0.2670945, 0.21836717, -0.10214946, 0.8792165...",0
4,b52775178eefe55586eea365227c31e39c034f74.jpg,1359162,"[0.98361766, 0.758035, 0.2368648, 0.39937067, ...",0


In [5]:
import numpy as np
import faiss

In [15]:
# Fetch a single row from Spark and take the length of its embedding vector
# as the dimensionality used to build the index.
first_embedding = df.select("cls_embedding").first()[0]
dim = np.array(first_embedding).shape[0]
dim

768

In [16]:
# Flat FAISS index over `dim`-dimensional vectors, compared by L2 distance.
index = faiss.IndexFlatL2(dim)

In [7]:
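# NOTE: unused draft helper, kept for reference only. The notebook normalizes
# the embeddings and adds them to the index in one bulk step instead.
# def add_to_index(embedding: np.ndarray, index: faiss.IndexFlatL2):
#     embedding = embedding.astype("float32")
#     faiss.normalize_L2(embedding)
#     index.add(embedding)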

In [6]:
# Stack the per-row embedding arrays into a single 2-D matrix (one row per
# DataFrame row) and cast it to float32 before it is handed to FAISS.
embedding_rows = pandas_df["cls_embedding"].values
embeddings = np.stack(embedding_rows).astype("float32")
embeddings.shape

(9124, 768)

In [17]:
import time

# index.add() always appends, so re-running this cell would store every
# vector a second time. Emptying the index first keeps the cell idempotent;
# on a freshly created index this is a no-op.
index.reset()

# perf_counter is a monotonic clock intended for measuring short intervals.
t0 = time.perf_counter()
# normalize_L2 rescales each row to unit length *in place*, so `embeddings`
# itself stays normalized for the cells that follow (e.g. when used as queries).
faiss.normalize_L2(embeddings)
index.add(embeddings)
print(f"Indexed {embeddings.shape[0]} embeddings in {time.perf_counter() - t0:.2f} seconds")

# Sanity check: the index holds exactly one entry per embedding row.
assert index.ntotal == embeddings.shape[0]


Indexed 9124 embeddings in 0.14 seconds


In [18]:
# Directory (under the user's home) where the FAISS index file is stored;
# created on demand, including any missing parent directories.
index_path_root = home.joinpath("scratch/plantclef/data")
index_path_root.mkdir(parents=True, exist_ok=True)

# The index file is named after the embeddings subset it was built from.
index_path = index_path_root.joinpath(f"{subset_folder_name}.index")

In [None]:
# Persist the index to disk at index_path; the path is passed as a plain string.
faiss.write_index(index, str(index_path))

In [24]:
# Reload the index from disk (replacing the in-memory one), then time a
# k-nearest-neighbor search using the first few embeddings as queries.
index = faiss.read_index(str(index_path))

k, num_queries = 5, 10
t0 = time.time()
query_vectors = embeddings[:num_queries]
# d holds the distances and i the matching row positions, one row per query.
d, i = index.search(query_vectors, k)
elapsed = time.time() - t0
print(f"Found {k} nearest neighbors for {num_queries} queries in {elapsed:.2f} seconds")

Found 5 nearest neighbors for 10 queries in 0.07 seconds


In [20]:
# Distances returned by the search: one row per query, one column per
# neighbor. The queries are themselves rows of the indexed embeddings, so
# each query's closest match is itself.
d

array([[0.        , 0.30579454, 0.3433559 , 0.45759228, 0.5765661 ],
       [0.        , 0.3261646 , 0.35298112, 0.36710143, 0.37554586],
       [0.        , 0.33431265, 0.34244055, 0.38167328, 0.3944954 ],
       [0.        , 0.51580256, 0.5725595 , 0.59718996, 0.6022958 ],
       [0.        , 0.31769183, 0.35711777, 0.35823455, 0.36029816],
       [0.        , 0.91882014, 0.927521  , 0.98016006, 1.0213945 ],
       [0.        , 0.49491674, 0.6042012 , 0.6203002 , 0.66143185],
       [0.        , 0.47169265, 0.5406811 , 0.5412945 , 0.5460907 ],
       [0.        , 0.2578891 , 0.27204877, 0.28260392, 0.28792253],
       [0.        , 0.40747592, 0.46009785, 0.48361915, 0.4896986 ]],
      dtype=float32)

In [19]:
# Index positions of the neighbors found for each query, aligned
# element-for-element with the distance array `d`.
i

array([[   0, 2751, 2591, 3135, 3429],
       [   1, 2820, 3020,  212, 5262],
       [   2,   86,   14,  265, 2864],
       [   3, 2978, 3145, 4207, 7203],
       [   4, 3129,  418, 2678,   28],
       [   5, 1250,  131,   89,  111],
       [   6,  162, 2939, 7210, 3341],
       [   7, 4214, 7198,  180, 2748],
       [   8,  442, 1245,   86,  522],
       [   9, 2649,  146, 2948,  602]])

In [34]:
# Map every neighbor position returned by the search to the species_id stored
# at that row of the pandas DataFrame, giving one row of species ids per query.
species_ids_array = pandas_df["species_id"].to_numpy()
pred_species = species_ids_array[i]
pred_species

array([[1743246, 1743246, 1743246, 1743246, 1743246],
       [1394624, 1394624, 1394624, 1394624, 1394624],
       [1394624, 1394624, 1394624, 1394624, 1394624],
       [1359162, 1359162, 1359162, 1359162, 1359162],
       [1359162, 1359162, 1359162, 1359162, 1359162],
       [1359162, 1359162, 1359162, 1359162, 1359162],
       [1743246, 1743246, 1743246, 1743246, 1743246],
       [1743246, 1743246, 1743246, 1743246, 1743246],
       [1394624, 1394624, 1394624, 1394624, 1394624],
       [1394624, 1394624, 1394624, 1394624, 1394624]], dtype=int32)