In [1]:
import numpy as np
from numpy.linalg import norm
from sklearn.neighbors import NearestNeighbors

import os

import tensorflow as tf
from tensorflow.keras.preprocessing import image
# from tensorflow.keras.applications.resnet50 import ResNet50
import tensorflow.keras.applications.resnet50 as resnet50

In [2]:
# ResNet50 with ImageNet weights; include_top=True keeps the top layers,
# so predictions come from the complete network for 224x224 RGB inputs.
resnet50_model = resnet50.ResNet50(
    include_top=True,
    weights='imagenet',
    input_shape=(224, 224, 3),
)

In [3]:
# File-name suffixes that are treated as images.
extensions = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']

def get_file_list(root_dir):
    """Recursively collect the image files found under ``root_dir``.

    A file is kept only when its name *ends with* one of the suffixes in the
    module-level ``extensions`` list. (A plain substring test would also
    accept names such as ``photo.jpg.txt``.)

    Args:
        root_dir: Directory to walk with ``os.walk``.

    Returns:
        list[str]: ``os.path.join(root, filename)`` for every matching file,
        in the order ``os.walk`` yields them (callers sort if they need a
        stable order). Empty when nothing matches or ``root_dir`` does not
        exist.
    """
    # str.endswith accepts a tuple of suffixes; read the list at call time so
    # later edits to ``extensions`` are honoured.
    suffixes = tuple(extensions)
    file_list = []
    for root, _directories, filenames in os.walk(root_dir):
        for filename in filenames:
            if filename.endswith(suffixes):
                file_list.append(os.path.join(root, filename))
    return file_list

In [4]:
def extract_features(img_path, model):
    """Return the L2-normalised, flattened model output for one image.

    The image at ``img_path`` is loaded at 224x224, converted to an array,
    given a leading batch axis, passed through ``resnet50.preprocess_input``
    and then through ``model.predict``.

    Args:
        img_path: Path of the image file to load.
        model: Object exposing ``predict(batch)``.

    Returns:
        1-D array: the flattened prediction divided by its Euclidean norm.
    """
    height, width = 224, 224
    pixels = image.img_to_array(
        image.load_img(img_path, target_size=(height, width))
    )
    # The model expects a batch, so wrap the single image in a batch of one.
    batch = resnet50.preprocess_input(np.expand_dims(pixels, 0))
    embedding = model.predict(batch).flatten()
    return embedding / norm(embedding)

In [5]:
# Gather the test image paths, put them in sorted order, then display them.
test_filenames = get_file_list('./cat_test_images')
test_filenames.sort()
test_filenames

['./cat_test_images\\jjokgo1.jpg',
 './cat_test_images\\jjokgo2.jpg',
 './cat_test_images\\jjokgo3.jpg',
 './cat_test_images\\jjokgo4.jpg',
 './cat_test_images\\jjokgo5.jpg',
 './cat_test_images\\mango1.jpg',
 './cat_test_images\\mango2.jpg',
 './cat_test_images\\mango3.jpg',
 './cat_test_images\\mango4.jpg',
 './cat_test_images\\mango5.jpg',
 './cat_test_images\\sabum1.jpg',
 './cat_test_images\\sabum2.jpg',
 './cat_test_images\\sabum3.jpg',
 './cat_test_images\\sabum4.jpg',
 './cat_test_images\\sabum5.jpg',
 './cat_test_images\\samsak1.jpg',
 './cat_test_images\\samsak2.jpg',
 './cat_test_images\\samsak3.jpg',
 './cat_test_images\\samsak4.jpg',
 './cat_test_images\\samsak5.jpg',
 './cat_test_images\\sango1.jpg',
 './cat_test_images\\sango2.jpg',
 './cat_test_images\\sango3.jpg',
 './cat_test_images\\sango4.jpg',
 './cat_test_images\\sango5.jpg',
 './cat_test_images\\yuksi1.jpg',
 './cat_test_images\\yuksi2.jpg',
 './cat_test_images\\yuksi3.jpg',
 './cat_test_images\\yuksi4.jpg',
 './

In [46]:
import re

def catname(idx):
    """Return the cat name encoded in the file name of ``test_filenames[idx]``.

    The name is the last path component, cut at its first ``.``, with every
    digit removed (e.g. ``'./cat_test_images\\jjokgo1.jpg'`` -> ``'jjokgo'``).

    Args:
        idx: Index into the module-level ``test_filenames`` list.

    Returns:
        str: The file stem without digits.
    """
    # Split on both '\\' and '/' and keep the LAST component, so the lookup
    # works for Windows and POSIX style paths and for nested directories.
    basename = re.split(r"[\\/]", test_filenames[idx])[-1]
    stem = basename.split(".")[0]
    return re.sub(r"[0-9]", '', stem)

In [51]:
def scoring(neighbors, distances, n):
    """Print a rank-weighted neighbour vote for every image in ``test_filenames``.

    For each index ``i`` the nearest neighbours of ``distances[i]`` are
    queried from ``neighbors``. The query image itself is left out, and the
    remaining neighbours, closest first, add ``n - 1, n - 2, ..., 1`` points
    to the cat name returned by ``catname``. The image's own cat name, each
    ``"<name> <points>"`` line and the final ranking (highest total first)
    are printed.

    Args:
        neighbors: Fitted object exposing
            ``kneighbors(X, return_distance=False)``.
        distances: Sequence indexed like ``test_filenames``; element ``i`` is
            the query row for image ``i``. NOTE(review): despite the name, the
            caller in this notebook passes the feature vectors here.
        n: Number of neighbours requested per query; the best-ranked
            neighbour receives ``n - 1`` points.

    Returns:
        None. Results are only printed.
    """
    for test_num in range(len(test_filenames)):
        cur_cat = catname(test_num)
        print(cur_cat)

        neighbor_ids = neighbors.kneighbors([distances[test_num]],
                                            return_distance=False)[0]
        # Exclude the query image by index rather than by position: with tied
        # distances it is not guaranteed to be the first neighbour, and
        # skipping "whatever comes first" would let an image vote for itself.
        # Assumes neighbour indices refer to test_filenames, as catname(item) does.
        others = [item for item in neighbor_ids if item != test_num]

        prediction = dict()
        # zip stops after n - 1 neighbours, so no zero/negative points are given.
        for cur_score, item in zip(range(n - 1, 0, -1), others):
            name = catname(item)
            prediction[name] = prediction.get(name, 0) + cur_score
            print(name + " " + str(cur_score))

        rank = sorted(prediction.items(), key=(lambda x: x[1]), reverse=True)
        print(rank)

In [52]:
n_neighbors = 6

# One normalised feature vector per test image, in test_filenames order.
result = []
for filename in test_filenames:
    result += [extract_features(filename, resnet50_model)]

# Metric names noted by the author: minkowski, euclidean, mahalanobis
neighbors = NearestNeighbors(
    n_neighbors=n_neighbors,
    algorithm='brute',
    metric='euclidean',
)
neighbors.fit(result)  # fit() returns the estimator itself

scoring(neighbors, result, n_neighbors)

jjokgo
0
31
23
32
34
20
[('zado', 10), ('sango', 5)]
jjokgo
1
0
20
31
32
23
[('jjokgo', 5), ('sango', 5), ('zado', 5)]
jjokgo
2
24
4
3
25
30
[('jjokgo', 7), ('sango', 5), ('yuksi', 2), ('zado', 1)]
jjokgo
3
24
2
4
25
30
[('jjokgo', 7), ('sango', 5), ('yuksi', 2), ('zado', 1)]
jjokgo
4
24
25
30
2
20
[('sango', 6), ('yuksi', 4), ('zado', 3), ('jjokgo', 2)]
mango
5
8
14
18
16
7
[('mango', 6), ('samsak', 5), ('sabum', 4)]
mango
6
17
13
22
15
0
[('samsak', 7), ('sabum', 4), ('sango', 3), ('jjokgo', 1)]
mango
7
18
11
10
12
13
[('sabum', 10), ('samsak', 5)]
mango
8
19
33
9
5
29
[('samsak', 5), ('mango', 5), ('zado', 4), ('yuksi', 1)]
mango
9
19
33
8
29
15
[('samsak', 6), ('zado', 4), ('mango', 3), ('yuksi', 2)]
sabum
10
11
12
0
31
32
[('sabum', 9), ('jjokgo', 3), ('zado', 3)]
sabum
11
10
18
12
15
0
[('sabum', 8), ('samsak', 6), ('jjokgo', 1)]
sabum
12
10
11
2
34
32
[('sabum', 9), ('jjokgo', 3), ('zado', 3)]
sabum
13
17
22
15
6
0
[('samsak', 8), ('sango', 4), ('mango', 2), ('jjokgo', 1)]
sabum