In [3]:
import os
import shutil
from cv2 import cv2
import h5py
import numpy as np
import seaborn as sns
from imutils import paths
# apply seaborn's default theme to subsequent matplotlib plots
sns.set()

In [4]:
# import the necessary keras/sklearn packages
import sklearn
from sklearn.neighbors import NearestNeighbors
from tensorflow.keras.models import Model, load_model

In [5]:
# Open the HDF5 signature database read-only. The handle is intentionally
# left open: a later cell reads `label_names` from it.
data_dir = r"C:\Users\mhasa\GDrive\mvcnn"
hdf5_path = f"{data_dir}//data_sig_col_roi_28px_255.hdf5"
signature_db = h5py.File(hdf5_path, "r")

# Pull the whole pre-computed feature matrix into memory and report its shape.
extracted_features = signature_db['extracted_features'][:]
print(extracted_features.shape)

# Restore the trained classifier from disk.
model_path = f"{data_dir}//model_mvcnn_color_roi_10class_28px1px_255_minvgg.h5"
loaded_model = load_model(model_path)

# Truncate the network at layer index 16 so it emits a feature vector
# instead of class scores.
# NOTE(review): index 16 is presumably the Flatten layer (per the variable
# name) -- confirm against the full model summary.
flatten_layer = loaded_model.layers[16]
retrieval_model = Model(inputs=loaded_model.input,
                        outputs=flatten_layer.output)
retrieval_model.summary()

(114458, 3136)
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_input (InputLayer)    [(None, 28, 28, 1)]       0         
_________________________________________________________________
conv2d (Conv2D)              (None, 28, 28, 32)        320       
_________________________________________________________________
activation (Activation)      (None, 28, 28, 32)        0         
_________________________________________________________________
batch_normalization (BatchNo (None, 28, 28, 32)        128       
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 28, 28, 32)        9248      
_________________________________________________________________
activation_1 (Activation)    (None, 28, 28, 32)        0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 28, 28, 32

In [6]:
# list (alphabetically) the distance metrics sklearn accepts for the kd_tree algorithm
sorted(sklearn.neighbors.VALID_METRICS['kd_tree'])

['chebyshev',
 'cityblock',
 'euclidean',
 'infinity',
 'l1',
 'l2',
 'manhattan',
 'minkowski',
 'p']

In [7]:
# Build a kd-tree nearest-neighbour index (50 neighbours, L2 distance)
# over the stored feature vectors. `fit` returns the estimator itself.
knn_estimator = NearestNeighbors(n_neighbors=50, algorithm='kd_tree', metric='l2')
nn = knn_estimator.fit(extracted_features)

In [9]:
# Read the sample image as single-channel grayscale and compute its signature.
sample_image_path = f"{data_dir}//headless_screws.png"
img = cv2.imread(sample_image_path, cv2.IMREAD_GRAYSCALE)

# cv2.imread does not raise on a missing/unreadable file -- it returns None,
# which would otherwise surface as an opaque AttributeError on `.astype`.
if img is None:
    raise FileNotFoundError(f"could not read image: {sample_image_path}")

# Convert to float32 and divide by 255 so pixel values are scaled down
# before being fed to the network.
img = img.astype('float32') / 255.0

# Add a trailing channel dim and a leading batch dim, since we are doing
# feature extraction with a Keras model: (H, W) -> (1, H, W, 1).
# NOTE(review): the model summary shows an input of (None, 28, 28, 1); the
# image is not resized here, so it is assumed to already be 28x28 -- confirm.
img = np.expand_dims(img, axis=-1)
img = np.expand_dims(img, axis=0)

# Forward pass through the truncated model yields the query signature.
source_part_sig = retrieval_model.predict(img)

In [10]:
# query the fitted index with the sample signature; `dist` receives the
# neighbour distances and `index` the matching row positions in the fitted data
dist, index = nn.kneighbors(source_part_sig)

In [11]:
# get the neighbor names
# Only one query was issued, so take its row of neighbour indices and sort
# it ascending (h5py fancy indexing expects increasing index lists).
# Note that sorting discards the nearest-first ranking.
indices = np.sort(index[0])

# Labels come back from the HDF5 dataset as bytes; decode them to str.
raw_names = signature_db['label_names'][indices]
neighbor_names = np.array([raw.decode() for raw in raw_names])

In [12]:
# Copy every image whose file name matches one of the retrieved neighbours
# from the pristine model directory into `target_dir`.
target_dir = r"C:\Users\mhasa\Desktop\neighbors"

pristine_retrieve_models_dir = r"C:\Users\mhasa\Desktop\retrieval_models"
image_paths = list(paths.list_images(pristine_retrieve_models_dir))

# shutil.copy2 does not create missing parent directories, so make sure the
# destination exists before copying anything.
os.makedirs(target_dir, exist_ok=True)

# Build the lookup once (as plain Python strings) instead of rebuilding a
# list on every loop iteration; set membership is O(1).
neighbor_name_set = set(np.asarray(neighbor_names).tolist())

for path in image_paths:
    image_name = os.path.basename(path)

    if image_name in neighbor_name_set:
        # copy2 preserves file metadata (timestamps) along with the contents
        shutil.copy2(src=path,
                     dst=os.path.join(target_dir, image_name))