# 测试clip向量化+faiss向量搜索在5037类中的分类能力


In [None]:
import random
import time

import torch
from transformers import CLIPProcessor, CLIPModel, CLIPFeatureExtractor

import faiss
from PIL import Image
import os
import json
from collections import Counter
# Pretrained CLIP ViT-B/32 model and its matching processor (image preprocessing).
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# Maps faiss row ids to names; keys are strings because JSON object keys are
# always strings (hence the str(j) lookups in image_search).
# JSON is UTF-8 by spec: open it explicitly as UTF-8 so non-ASCII names decode
# identically on every platform instead of depending on the locale encoding.
with open('./output_all/id2filename.json', 'r', encoding='utf-8') as json_file:
    id2filename = json.load(json_file)
# Prebuilt faiss index searched with CLIP image features; row i corresponds to
# id2filename[str(i)]. NOTE(review): queries are L2-normalised in image_search,
# so the index presumably holds normalised vectors too — confirm with the builder.
index = faiss.read_index("./output_all/image.faiss")

In [None]:

def most_frequent_element(arr):
    """Return the most common element of *arr* together with its count.

    Ties go to the element encountered first, because ``Counter.most_common``
    orders equal counts by first appearance.

    Args:
        arr: A sequence of hashable items.

    Returns:
        ``(element, count)`` for a non-empty *arr*; ``None`` when *arr* is
        empty. Note the asymmetric return type: callers that unpack the result
        must make sure *arr* is non-empty.
    """
    if arr:
        [(winner, hits)] = Counter(arr).most_common(1)
        return winner, hits
    return None

def get_random_image_path(directory):
    """Pick one image file at random from *directory*, searching recursively.

    A file counts as an image when its lower-cased name ends in .png, .jpg,
    .jpeg, .bmp or .gif.

    Args:
        directory: Root directory to walk (subdirectories included).

    Returns:
        A path built with ``os.path.join(root, file)``, chosen with
        ``random.choice``; ``None`` if no image file was found.
    """
    image_suffixes = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')
    candidates = [
        os.path.join(folder, name)
        for folder, _subdirs, names in os.walk(directory)
        for name in names
        if name.lower().endswith(image_suffixes)
    ]
    return random.choice(candidates) if candidates else None

def get_image_path(directory, extensions=('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
    """Return the paths of all image files under *directory*, recursively.

    Args:
        directory: Root directory to walk; subdirectories are included. A
            non-existent directory yields an empty list (``os.walk`` yields
            nothing for it).
        extensions: File-name suffix or tuple of suffixes treated as images.
            Matching is case-insensitive on the file name.

    Returns:
        A list of ``os.path.join(root, file)`` paths in ``os.walk`` order
        (every match is returned; nothing is sampled). Empty if nothing matches.
    """
    if isinstance(extensions, str):
        # Accept a single suffix; iterating a str would split it into chars.
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)
    return [
        os.path.join(root, file)
        for root, _dirs, files in os.walk(directory)
        for file in files
        if file.lower().endswith(suffixes)
    ]

def get_image_feature(filename: str):
    """Embed the image stored at *filename* with the module-level CLIP model.

    The image is converted to RGB before preprocessing. The returned tensor is
    the raw output of ``model.get_image_features``; it is NOT L2-normalised,
    unlike the query vectors built in ``image_search``. Normalise it yourself
    before comparing against the faiss index if that index holds unit vectors.

    Args:
        filename: Path to an image file readable by PIL.

    Returns:
        The image-feature tensor, computed without gradient tracking.
    """
    # Use a context manager so the file handle is released right away;
    # convert() forces PIL to decode the pixels before the file is closed.
    with Image.open(filename) as img:
        image = img.convert("RGB")
    # padding/truncation are text-tokenizer options kept from the original
    # call; presumably they have no effect on image-only input.
    processed = processor(images=image, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        image_features = model.get_image_features(pixel_values=processed["pixel_values"])
    return image_features


def image_search(image, k=1, img_path=None):
    """Find the *k* indexed images nearest to *image* in CLIP embedding space.

    The query embedding is L2-normalised, then looked up in the module-level
    faiss ``index``. Result row ids are mapped to names through ``id2filename``.

    Args:
        image: PIL image used as the query.
        k: Number of neighbours to retrieve.
        img_path: Path to report back for the query. Defaults to the image's
            ``filename`` attribute (set by ``Image.open`` for images opened
            from a path), or ``None`` if it has none.

    Returns:
        ``(img_path, D, filenames)``. ``D`` is the distance/score array from
        ``index.search`` (one row per query). ``filenames`` is a list holding,
        for each query row, the k neighbour names in faiss result order. Slots
        faiss could not fill (id ``-1``) map to ``None``.
    """
    if img_path is None:
        # Derive the path from the image itself rather than reading a
        # module-level global that only exists once the evaluation loop runs.
        img_path = getattr(image, "filename", None) or None

    inputs = processor(images=image, return_tensors="pt")
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        image_features = model.get_image_features(**inputs)
        image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)  # normalize

    D, I = index.search(image_features.numpy(), k)  # the actual query

    # faiss pads results with id -1 when fewer than k vectors are available;
    # map those to None instead of raising KeyError on id2filename["-1"].
    filenames = [[id2filename[str(j)] if j >= 0 else None for j in row] for row in I]

    return img_path, D, filenames


# 统计正确分类总个数，图片总个数，正确率，按照查询类别最优指标来计算
correct_count: 33608 total_count: 43598 accuracy: 0.7708610486719575

In [None]:
base_images = os.getcwd()
img_path = os.path.join(base_images, "data", "Trax_bbox出来的小图含label_20230207")
img_paths = get_image_path(img_path)

# Count correct classifications, total images and accuracy using the top-1
# neighbour. NOTE(review): hit 0 is treated as the query's own index entry and
# used as ground truth, and hit 1 as the prediction. This assumes every query
# image is itself stored in image.faiss; confirm against how the index was built.
correct_count = 0
total_count = 0
for img_path in img_paths:
    # Close each image after the search; thousands of images are scanned.
    with Image.open(img_path) as image:
        _, D, filenames = image_search(image, k=4)
    total_count += 1
    # One query image -> exactly one result row; no inner loop needed.
    true_name, pred_name = filenames[0][0], filenames[0][1]
    if true_name == pred_name:
        correct_count += 1
    else:
        print("img_path:", img_path, "true_name:", true_name, "pred_name:", pred_name)

# Guard against an empty image directory.
accuracy = correct_count / total_count if total_count else float("nan")
print("correct_count:", correct_count, "total_count:", total_count, "accuracy:", accuracy)


# 统计正确分类总个数，图片总个数，正确率，按照查询类别最优的三个结果中出现次数最多来计算
correct_count: 32877 total_count: 43598 accuracy: 0.7540942245057113

In [None]:
base_images = os.getcwd()
img_path = os.path.join(base_images, "data", "Trax_bbox出来的小图含label_20230207")
img_paths = get_image_path(img_path)

# Count correct classifications, total images and accuracy using a majority
# vote over the three nearest neighbours after the first hit.
# NOTE(review): hit 0 is treated as the query's own index entry and used as
# ground truth; confirm the query images are part of image.faiss.
correct_count = 0
total_count = 0
for img_path in img_paths:
    # Close each image after the search; thousands of images are scanned.
    with Image.open(img_path) as image:
        _, D, filenames = image_search(image, k=4)
    hits = filenames[0]
    # Slice instead of pop(0) so the search result is not mutated via an alias.
    true_name, neighbours = hits[0], hits[1:]
    total_count += 1
    # most_frequent_element uses Counter, which breaks ties by first
    # appearance, so when all neighbours differ the nearest one wins.
    pred_name, freq = most_frequent_element(neighbours)
    if true_name == pred_name:
        correct_count += 1
    else:
        print("img_path:", img_path, "true_name:", true_name, "pred_name:", pred_name, "freq:", freq)

# Guard against an empty image directory.
accuracy = correct_count / total_count if total_count else float("nan")
print("correct_count:", correct_count, "total_count:", total_count, "accuracy:", accuracy)

