In [1]:
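# Test script
# 1. Feed the support-set images of each class through the feature-extraction network to obtain their embeddings.
# 2. Average the embeddings of each class, then L2-normalise the mean to obtain the class vectors u1, u2, u3.
# 3. Feed the query image through the pretrained network, L2-normalise its feature vector, and multiply it by the
#    matrix M stacked from u1, u2, u3 to obtain the cosine similarity to each class.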

### 计算两个向量的余弦相似度
$$\cos(\theta)=\frac{a \cdot b}{\lVert a\rVert \times \lVert b\rVert}$$

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import os
import random
import glob

# 这里image需要重命名一下
from tensorflow.keras.preprocessing import image as keras_image
from tensorflow.keras.applications.resnet import preprocess_input

2022-08-25 23:15:46.234851: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0


In [3]:
# Load the saved Keras model from ./latest_new.h5; it is passed to extractFeat
# below as the feature-extraction (embedding) network.
embedding = tf.keras.models.load_model('./latest_new.h5')

2022-08-25 23:15:47.264143: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcuda.so.1
2022-08-25 23:15:47.305537: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1733] Found device 0 with properties: 
pciBusID: 0000:01:00.0 name: NVIDIA GeForce RTX 3090 computeCapability: 8.6
coreClock: 1.695GHz coreCount: 82 deviceMemorySize: 23.70GiB deviceMemoryBandwidth: 871.81GiB/s
2022-08-25 23:15:47.305568: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0
2022-08-25 23:15:47.309129: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcublas.so.11
2022-08-25 23:15:47.309158: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcublasLt.so.11
2022-08-25 23:15:47.310329: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcu



In [4]:
# embedding.summary()

In [5]:
def extractFeat(model, filename, target_size=(200, 200)):
    """Return the L2-normalised embedding of a single image file.

    Args:
        model: Object exposing a Keras-style ``predict`` method that maps a
            batch of preprocessed images to a batch of feature vectors.
        filename: Path of the image to load.
        target_size: ``(height, width)`` the image is resized to when loaded.
            Defaults to ``(200, 200)``, the size this notebook has always
            used (the original note says it mirrors the training-time
            preprocessing).

    Returns:
        The first (and only) row of ``model.predict`` divided by its
        Euclidean norm.

    Raises:
        ValueError: If the predicted embedding has zero norm. Dividing by it
            would yield NaNs that silently corrupt the later dot-product /
            argmax steps.
    """
    # Same preprocessing pipeline as at training time.
    img = keras_image.load_img(filename, target_size=target_size)
    img = keras_image.img_to_array(img)
    img = preprocess_input(img)
    # The model expects a batch, so add a leading batch axis of size 1.
    batch = np.expand_dims(img, axis=0)
    # Forward pass; keep the single row of the batch output.
    feat = model.predict(batch)[0]
    # L2 normalisation, so that a dot product of two embeddings is their cosine.
    norm = np.linalg.norm(feat)
    if norm == 0:
        raise ValueError(f"embedding of {filename!r} has zero norm; cannot L2-normalise")
    return feat / norm

In [6]:
# 获取测试的support set 和query

In [7]:
# Use the last five class directories as the unseen test classes.
# glob.glob gives no ordering guarantee (it depends on the filesystem), so the
# listing is sorted first to make the "last five" selection reproducible.
query_classes = sorted(glob.glob('dogImages/train/*'))[-5:]

In [8]:
# Show the selected class directories.
query_classes

['dogImages/train/129.Tibetan_mastiff',
 'dogImages/train/130.Welsh_springer_spaniel',
 'dogImages/train/131.Wirehaired_pointing_griffon',
 'dogImages/train/132.Xoloitzcuintli',
 'dogImages/train/133.Yorkshire_terrier']

In [9]:
# Collect the image paths of every selected class: one list per class, in the
# same order as query_classes.
# Each listing is sorted because glob's order depends on the filesystem and the
# later cells pick support/query images by position.
query_files = [sorted(glob.glob(class_item + '/*')) for class_item in query_classes]


In [10]:
# query_files

In [66]:
# Build one prototype vector per class: the first two images of each class form
# the support set; their embeddings are averaged and the mean is L2-normalised.
# (The query image is chosen separately, from a different index.)
class_u = []
for class_images in query_files:
    # Fail with a clear message instead of an opaque IndexError.
    if len(class_images) < 2:
        raise ValueError(
            f"each class needs at least two support images, got {len(class_images)}"
        )
    # Embeddings of the two support images of this class.
    feat_A, feat_B = (extractFeat(embedding, path) for path in class_images[:2])
    # Mean embedding of the support set.
    feat_avg = (feat_A + feat_B) / 2
    # Re-normalise: the mean of two unit vectors is generally not unit length.
    class_u.append(feat_avg / np.linalg.norm(feat_avg))
# Stack into a (number of classes, embedding size) matrix.
class_u = np.array(class_u)

In [67]:
# Shape of the prototype matrix: (number of classes, embedding size).
class_u.shape

(5, 256)

In [68]:
# Display the prototype matrix, one row per class.
class_u

array([[ 0.0778444 , -0.08980139,  0.01660734, ...,  0.07267957,
         0.1185353 , -0.00660693],
       [ 0.10270932, -0.06751491,  0.01926412, ..., -0.0031268 ,
         0.09171963,  0.05606749],
       [ 0.12412076, -0.12937087,  0.0158099 , ..., -0.01265056,
         0.071591  ,  0.08709649],
       [ 0.10841086, -0.12155013, -0.00542881, ..., -0.02353824,
        -0.00940758,  0.08753707],
       [ 0.10181352, -0.10390382,  0.0849184 , ..., -0.01457242,
         0.06649275,  0.02868202]], dtype=float32)

In [104]:
# Pick a fixed (not random) query image: the file at index 5 of the class at
# index 4. The prototypes are built from indices 0 and 1 of each class, so this
# image is not part of the support set.
query_img = query_files[4][5]

In [105]:
# Show the path of the chosen query image.
query_img

'dogImages/train/133.Yorkshire_terrier/Yorkshire_terrier_08319.jpg'

In [106]:
# Extract the L2-normalised embedding of the query image.
query_feat = extractFeat(embedding,query_img)

In [107]:
# query_feat

In [108]:
# Cosine similarity of the query to every class prototype. Both sides are
# L2-normalised, so a plain matrix-vector product yields the cosines.
# Note: despite the name, these are similarity scores, not probabilities.
class_prob = class_u @ query_feat

In [109]:
# Show the similarity of the query to each class prototype.
class_prob

array([0.3200401 , 0.53333116, 0.90293527, 0.7195747 , 0.9551848 ],
      dtype=float32)

In [110]:
# Predicted class: the directory whose prototype is most similar to the query.
query_classes[class_prob.argmax()]

'dogImages/train/133.Yorkshire_terrier'

In [21]:
# 测试猫咪

In [22]:
# Cat breeds to evaluate; each name is also used as a sub-directory of ./cat/.
labels = [
    'Abyssinian',
    'Aegean',
    'British_Shorthair',
    'Donskoy',
    'Persian',
]

In [123]:
# Build one prototype vector per cat breed from the images in ./cat/<label>/.
class_u = []
for label in labels:
    # Sorted so the choice of support images is reproducible; glob's order
    # depends on the filesystem.
    files = sorted(glob.glob('./cat/'+label+'/*'))
    # Fail with a clear message instead of an opaque IndexError.
    if len(files) < 2:
        raise ValueError(
            f"need at least two support images for {label!r}, found {len(files)}"
        )
    # Only the first two images form the support set, so only those are run
    # through the network (previously every file in the folder was embedded).
    feat_A, feat_B = (extractFeat(embedding, path) for path in files[:2])
    # Mean embedding of the support set.
    feat_avg = (feat_A + feat_B) / 2
    # Re-normalise: the mean of two unit vectors is generally not unit length.
    class_u.append(feat_avg / np.linalg.norm(feat_avg))

# Stack into a (number of classes, embedding size) matrix.
class_u = np.array(class_u)

In [124]:
# Shape of the cat prototype matrix: (number of classes, embedding size).
class_u.shape

(5, 256)

In [125]:
# Query images sit directly in ./cat/ and carry "_query" in their file name.
# Sorted because a later cell selects the query by position and glob's order
# depends on the filesystem.
query_list = sorted(glob.glob('./cat/*_query*'))

In [126]:
# Show the available query images.
query_list

['./cat/Abyssinian_query.jpg',
 './cat/Aegean_query.jpg',
 './cat/British_Shorthair_query.jpeg',
 './cat/Donskoy_query.jpeg',
 './cat/Persian_query.jpeg']

In [139]:
# Extract the L2-normalised embedding of the query image at index 3 of query_list.
query_feat = extractFeat(embedding,query_list[3])

In [140]:
# Cosine similarity of the query to every cat prototype (all vectors are
# L2-normalised). Despite the name, these are similarities, not probabilities.
class_prob = class_u @ query_feat

In [141]:
# Show the similarity of the query to each cat prototype.
class_prob

array([0.83904886, 0.843537  , 0.8892164 , 0.81702244, 0.68522656],
      dtype=float32)

In [142]:
# Predicted breed: the label whose prototype is most similar to the query.
labels[class_prob.argmax()]

'British_Shorthair'