# 版本1
- 仅支持n_piece为int，不能为列表，array等，默认值为5
- K也仅支持int，默认值是100
- query数量仅支持int，默认100

In [25]:
import h5py
import numpy as np
import faiss
import os
import time
from config import *
import math

## 基础函数导入

In [3]:
def get_test_data(data_choice:str='glove',dim:int=25,number:int=100):
    """Load the first ``number`` test queries and their ground truth.

    The dataset file is ``data/<data_name>.hdf5`` where ``data_name`` is
    ``data_info[data_choice][dim]`` (from ``config``).

    Args:
        data_choice: dataset family key in ``data_info`` (e.g. ``'glove'``).
        dim: vector dimensionality key in ``data_info[data_choice]``.
        number: how many test queries to take (a prefix of ``test``).

    Returns:
        ``(test, neighbors, distances)``: the first ``number`` rows of the
        HDF5 ``test``, ``neighbors`` and ``distances`` datasets. For
        ``'glove'`` the query vectors are L2-normalised in place, because glove
        uses angular distance and is searched with ``faiss.METRIC_L2`` on
        normalised vectors.
    """
    data_name = data_info[data_choice][dim]
    glove_file_path = f"data/{data_name}.hdf5"
    # Slicing an h5py dataset reads the rows into memory, so the file can be
    # closed as soon as the three slices have been taken.
    with h5py.File(glove_file_path, "r") as glove_hdf:
        glove_test = glove_hdf['test'][:number]
        glove_neighbors = glove_hdf['neighbors'][:number]
        glove_distances = glove_hdf['distances'][:number]
    if data_choice == 'glove':
        # faiss.normalize_L2 modifies the array in place and returns nothing.
        faiss.normalize_L2(glove_test)
    return glove_test, glove_neighbors, glove_distances

In [4]:
def data_piece(data_choice:str='glove',n_piece:int=5,dim:int=25)->tuple:
    """Split the ``train`` vectors of a dataset into ``n_piece`` contiguous shards.

    Args:
        data_choice: dataset family key in ``data_info``. Per the project notes,
            ``'glove'`` supports dim 25/50/100/200 and ``'sift'`` dim 128; other
            combinations fail on the ``data_info`` lookup.
        n_piece: number of shards; must be >= 1.
        dim: vector dimensionality key in ``data_info[data_choice]``.

    Returns:
        ``(data_dict, id_dict, data_name)``. Both dicts are keyed by
        ``f"{data_name}_n{n_piece}_{i}"`` for ``i`` in ``1..n_piece``.
        ``data_dict`` holds each shard's vectors. ``id_dict`` holds the list of
        global row indices of those vectors inside ``train``. Every shard has
        ``len(train) // n_piece`` rows, except the last one, which also takes
        the remainder.

    Raises:
        ValueError: if ``n_piece`` < 1.
    """
    if n_piece < 1:
        raise ValueError(f"n_piece must be >= 1, got {n_piece}")
    data_name = data_info[data_choice][dim]
    glove_file_path = f"data/{data_name}.hdf5"
    data_dict = {}
    id_dict = {}
    # Close the HDF5 file once every shard has been read into memory.
    with h5py.File(glove_file_path, "r") as glove_hdf:
        train = glove_hdf['train']
        length_all = len(train)
        cut_point = length_all // n_piece
        for i in range(n_piece):
            data_key = data_name + '_n' + str(n_piece) + '_' + str(i + 1)
            start = i * cut_point
            # The last shard absorbs the remainder when len(train) is not a
            # multiple of n_piece.
            end = length_all if i + 1 == n_piece else start + cut_point
            data_dict[data_key] = train[start:end]
            id_dict[data_key] = list(range(start, end))
    return data_dict, id_dict, data_name
# Usage example: split glove-25-angular into 5 shards and inspect a few of them
# (shard key list, sizes of shards 1 and 4, and the container type of shard 4).
data_dict, id_dict, data_name = data_piece(data_choice='glove', n_piece=5)
(
    data_dict.keys(),
    len(data_dict['glove-25-angular_n5_1']),
    len(data_dict['glove-25-angular_n5_4']),
    type(data_dict['glove-25-angular_n5_4']),
)

In [5]:
def create_folder_if_not_exists(data_name:str,n_piece:int=5)->str:
    """Ensure ``index/<data_name>_n<n_piece>`` exists and return that path.

    ``n_piece`` should match the value used with ``data_piece`` so that the
    folder lines up with the shard index file names. A message is printed
    saying whether the folder was created or already existed.
    """
    target = f"index/{data_name}_n{n_piece}"
    if os.path.exists(target):
        print(f"文件夹 '{target}' 已经存在。")
        return target
    os.makedirs(target)
    print(f"文件夹 '{target}' 不存在，已创建。")
    return target

In [24]:
def create_index(data_choice:str='glove',n_piece:int=5,dim:int=25,param:str='HNSW64'):
    """Build one faiss index per data shard and write each to disk.

    The training data is split with ``data_piece``. For each shard, an index
    described by ``param`` (a ``faiss.index_factory`` string) is wrapped in an
    ``IndexIDMap``, so search results carry the vectors' global row ids. It is
    then saved as ``index/<data_name>_n<n_piece>/<shard_key>.index``.

    Args:
        data_choice: dataset family. ``'glove'`` vectors are L2-normalised
            before insertion (angular distance searched as L2). Per the project
            notes, glove supports dim 25/50/100/200 and sift dim 128.
        n_piece: number of shards; must match the value used when searching.
        dim: vector dimensionality.
        param: faiss ``index_factory`` description, e.g. ``'HNSW64'``.
    """
    data_dict, id_dict, data_name = data_piece(data_choice, n_piece, dim)
    folder_path = create_folder_if_not_exists(data_name, n_piece)
    # Fixed for both datasets: sift is L2 and glove is L2 after normalisation.
    measure = faiss.METRIC_L2
    for data_key, shard in data_dict.items():
        file_path = f"{folder_path}/{data_key}.index"
        index = faiss.index_factory(dim, param, measure)
        # faiss works on C-contiguous float32; no copy is made when the shard
        # already has that layout.
        df = np.ascontiguousarray(shard, dtype=np.float32)
        if data_choice == 'glove':
            print(f'现在处理的glove数据集{data_key},需要进行归一化处理')
            faiss.normalize_L2(df)
        # faiss ids are 64-bit. Be explicit, because np.array on a list of
        # Python ints is int32 on some platforms (e.g. Windows with NumPy < 2).
        xids = np.asarray(id_dict[data_key], dtype=np.int64)
        # The IDMap wrapper lets search return global row ids instead of
        # per-shard positions.
        IDMap_index = faiss.IndexIDMap(index)
        IDMap_index.add_with_ids(df, xids)
        faiss.write_index(IDMap_index, file_path)
        print(f"{data_key}.index 已创建。")

In [7]:
def combine_list(data: list, ids: list, k: int = 100):
    """Sort ``(distance, id)`` pairs by distance and keep the ``k`` smallest.

    Args:
        data: per-candidate distances (list or 1-D array).
        ids: candidate ids aligned with ``data``. Surplus trailing items in the
            longer of the two are ignored, as with ``zip``.
        k: how many of the closest candidates to keep.

    Returns:
        ``(distances, ids)``: two lists of at most ``k`` items, ordered by
        ascending distance. Ties keep their input order (stable sort).
    """
    n = min(len(data), len(ids))
    distances = np.asarray(data)[:n]
    id_values = np.asarray(ids)[:n]
    # A stable NumPy argsort gives the same order as Python's stable
    # sorted(..., key=distance) over (distance, id) tuples, without building
    # and sorting a Python list of tuples. That matters for n_piece * k
    # candidates per query.
    order = np.argsort(distances, kind='stable')[:k]
    return list(distances[order]), list(id_values[order])


In [8]:
def get_search_result(data_choice:str='glove',n_piece:int=5,dim:int=25,k:int=100,number:int=100):
    """Search every shard index and merge the per-shard hits into a global top-k.

    Each shard index written by ``create_index`` is loaded from
    ``index/<data_name>_n<n_piece>/`` and queried with the first ``number``
    test vectors (via ``get_test_data``, which also normalises glove queries).
    Every shard returns its own top ``k``. The ``n_piece * k`` candidates per
    query are then merged by distance. No metric argument is needed, because
    both datasets are searched with ``faiss.METRIC_L2``.

    Returns:
        ``(test_id, test_distance)``: arrays of shape ``(n_queries, k)`` with
        the merged ids and distances in ascending distance order.
        ``n_queries`` is the number of test rows actually loaded.
    """
    data_name = data_info[data_choice][dim]
    folder_path = create_folder_if_not_exists(data_name, n_piece)
    glove_test, _, _ = get_test_data(data_choice, dim, number)
    search_id = []
    search_distance = []
    for i in range(n_piece):
        data_key = data_name + '_n' + str(n_piece) + '_' + str(i + 1)
        print(f'正在处理数据集{data_key}')
        file_path = f"{folder_path}/{data_key}.index"
        index = faiss.read_index(file_path)
        sd, sid = index.search(glove_test, k)
        search_id.append(sid)
        search_distance.append(sd)
    # Each shard yields (n_queries, k) arrays. Placing them side by side gives
    # (n_queries, n_piece * k) candidates per query.
    search_id = np.concatenate(search_id, axis=1)
    search_distance = np.concatenate(search_distance, axis=1)
    # A stable row-wise argsort keeps ties in shard order, matching a per-row
    # sorted(..., key=distance). It handles every loaded query at once instead
    # of looping over exactly `number` rows.
    order = np.argsort(search_distance, axis=1, kind='stable')[:, :k]
    test_id = np.take_along_axis(search_id, order, axis=1)
    test_distance = np.take_along_axis(search_distance, order, axis=1)
    return test_id, test_distance

In [17]:
def calculate_recall_np(test_id, ground_truth_id):
    """Return the fraction of ``ground_truth_id`` that appears in ``test_id``.

    Matches are counted with ``np.intersect1d``, which de-duplicates, so the
    result is |unique(test_id) ∩ unique(ground_truth_id)| / len(ground_truth_id).
    Only true positives are needed for recall. An empty ground truth gives 0.0.
    """
    expected = len(ground_truth_id)
    if expected == 0:
        return 0.0
    hits = np.intersect1d(test_id, ground_truth_id).size
    return hits / expected

def get_recall(data_choice:str='glove',n_piece:int=5,dim:int=25,k:int=100,number:int=100):
    """Run the sharded search and report mean recall@k and the search time.

    Returns:
        ``(mean_recall, elapsed_time)``. ``mean_recall`` is the mean of
        ``calculate_recall_np`` over all loaded queries. ``elapsed_time`` is
        the seconds spent in ``get_search_result`` (index loading, search and
        merge). Loading the ground truth happens before the timer starts.
    """
    _, neighbors_true, _ = get_test_data(data_choice, dim, number)
    start_time = time.perf_counter()
    test_id, _ = get_search_result(data_choice, n_piece, dim, k, number)
    elapsed_time = time.perf_counter() - start_time
    # Recall@k compares the top-k result with the top-k true neighbours.
    # Without truncation, a small k is scored against every stored ground-truth
    # column and can never exceed k / n_columns. For k >= n_columns nothing
    # changes. NOTE(review): assumes neighbor columns are ordered nearest-first
    # (as in ann-benchmarks files); confirm per dataset.
    ground_truth = neighbors_true[:, :k]
    recall_list = [
        calculate_recall_np(found, truth)
        for found, truth in zip(test_id, ground_truth)
    ]
    return np.mean(recall_list), elapsed_time

In [19]:
# Mean recall and search time for glove-25 split into 5 shards, k=100, 100 queries.
get_recall(data_choice='glove', n_piece=5, dim=25, k=100, number=100)

文件夹 'index/glove-25-angular_n5' 已经存在。
正在处理数据集glove-25-angular_n5_1
正在处理数据集glove-25-angular_n5_2
正在处理数据集glove-25-angular_n5_3
正在处理数据集glove-25-angular_n5_4
正在处理数据集glove-25-angular_n5_5


(0.9214000000000001, 1.8633124828338623)

In [25]:
# Build a single-shard (n_piece=1) glove-25 index, then measure recall and
# search time for k=100 over 100 queries.
create_index('glove',1,25)
print('数据集已创建完')
get_recall('glove',1,25,100,100)  # returns (mean recall, search time in seconds)

文件夹 'index/glove-25-angular_n1' 已经存在。
现在处理的glove数据集glove-25-angular_n1_1,需要进行归一化处理
glove-25-angular_n1_1.index 已创建。
数据集以创建完
文件夹 'index/glove-25-angular_n1' 已经存在。
正在处理数据集glove-25-angular_n1_1


(0.7817999999999999, 0.4783446788787842)

## 多组实验报告生成

In [6]:
# # 创建不同分片大小的向量数据库所需要的时间对比
# def get_report_create_index_time(data_choice:str='glove',dim:int=25,n_piece:list=[5]):
#     # 关于创建向量数据库的实验
#     create_index_time={}# 创建不同切片向量数据库所需要的时间
#     data_name= data_info[data_choice][dim]
#     for n in n_piece:
#         start_time = time.time()
#         create_index(data_choice,n,dim)
#         end_time = time.time()
#         elapsed_time = end_time - start_time  # 计算创建n_piece向量数据库所需的时间
#         database_key=data_name+'_n'+str(n)
#         create_index_time[database_key]=elapsed_time
#     return create_index_time

In [7]:
def get_report_create_index_time(data_choice:str='glove',dim:int=25,n_piece:'list | tuple'=(1,5,10,15,20,25)):
    """Time ``create_index`` for every shard count in ``n_piece``.

    Args:
        data_choice: dataset family passed to ``create_index``.
        dim: vector dimensionality.
        n_piece: shard counts to build (any iterable of ints). The default is
            an immutable tuple, so it cannot be shared and mutated across calls.

    Returns:
        dict mapping each shard count ``n`` to the seconds taken to build and
        save its indexes. Keys are plain shard counts, so the result can be
        plotted directly.
    """
    create_index_time = {}
    for n in n_piece:
        # perf_counter is monotonic and high-resolution, suited to durations.
        start_time = time.perf_counter()
        create_index(data_choice, n, dim)
        create_index_time[n] = time.perf_counter() - start_time
    return create_index_time

In [8]:
# # 不同切片数据库在不同k值下的查询时间和召回率对比,query=100(也可以自己定义)
# def get_report_n_piece_k(data_choice:str='glove',dim:int=25,n_piece:list=[1,5,10,15,20,25],
#                k:list=[10,100,1000,5000,10000,50000,100000],number:int=100):
#     # 定义一个两层循环，外层循环是n_piece，内层循环是k
#     # 首先是创建数据库，这里记录创建不同分片数据库的时间
#     create_index_time=get_report_create_index_time(data_choice,dim,n_piece)
#     print('数据库创建完毕')
#     # 然后是查询数据库，这里记录不同分片数据库在不同k值下的查询时间和召回率对比
#     search_time={}# 查询时间
#     seach_recall={}# 查询召回率
#     data_name= data_info[data_choice][dim]
#     for n in n_piece:
#         name1=data_name+'_n'+str(n)+'_q'+str(number)
#         search_time[name1]={}
#         seach_recall[name1]={}
#         for k_ in k:
#             print(f'现在处理的是{n}分片数据库，k值是{k_}')
#             name2=name1+'_k'+str(k_)
#             mean_recall,elapsed_time=get_recall(data_choice,n,dim,k_,number)#get_recall返回的是平均召回率和查询时间(np.mean(recall_list),elapsed_time)
#             search_time[name1][name2]=elapsed_time
#             seach_recall[name1][name2]=mean_recall
#     return create_index_time,search_time,seach_recall

In [10]:
def get_report_n_piece_k(data_choice:str='glove',dim:int=25,n_piece:'list | tuple'=(1,5,10,15,20,25),
               k:'list | tuple'=(10,100,1000,5000,10000,50000,100000),number:int=100):
    """Compare search time and recall across shard counts and k values.

    First builds (and times) the indexes for every shard count via
    ``get_report_create_index_time``. Then it calls ``get_recall`` for every
    ``(n_piece, k)`` pair with ``number`` queries.

    Returns:
        ``(create_index_time, search_time, search_recall)``.
        ``create_index_time`` is ``{n_piece: seconds}``. The other two are
        nested dicts keyed first by shard count, then by k:
        ``{n_piece: {k: value}}``.
    """
    # Materialise the inputs: both are iterated more than once below.
    n_piece = list(n_piece)
    k = list(k)
    create_index_time = get_report_create_index_time(data_choice, dim, n_piece)
    print('数据库创建完毕')
    search_time = {}
    search_recall = {}
    for n in n_piece:
        search_time[n] = {}
        search_recall[n] = {}
        for k_ in k:
            print(f'现在处理的是{n}分片数据库，k值是{k_}')
            mean_recall, elapsed_time = get_recall(data_choice, n, dim, k_, number)
            search_time[n][k_] = elapsed_time
            search_recall[n][k_] = mean_recall
    return create_index_time, search_time, search_recall

In [3]:
# # 此时n是固定的为5，对比不同query数量下的查询时间和召回率对比
# def get_report_number_k(data_choice:str='glove',dim:int=25,n_piece:int=5,
#                k:list=[10,100,1000,5000,10000,50000,100000],number:list=[10,500,1000,5000,10000]):
#     # 首先是创建相关向量数据库
#     create_index(data_choice=data_choice,n_piece=n_piece,dim=dim)
#     print('数据库创建完毕')
#     # 接下来是查询数据库
#     search_time={}# 查询时间
#     seach_recall={}# 查询召回率
#     data_name= data_info[data_choice][dim]
#     for n in number:
#         name1=data_name+'_n'+str(n_piece)+'_q'+str(n)
#         search_time[name1]={}
#         seach_recall[name1]={}
#         for k_ in k:
#             print(f'现在处理的是{n}个查询数据，k值是{k_}')
#             name2=name1+'_k'+str(k_)
#             mean_recall,elapsed_time=get_recall(data_choice,n_piece,dim,k_,n)
#             search_time[name1][name2]=elapsed_time
#             seach_recall[name1][name2]=mean_recall
#     return search_time,seach_recall

In [None]:
def get_report_number_k(data_choice:str='glove',dim:int=25,n_piece:int=5,
               k:'list | tuple'=(10,100,1000,5000,10000,50000,100000),
               number:'list | tuple'=(10,500,1000,5000,10000)):
    """Compare search time and recall across query counts and k for a fixed shard count.

    Builds the ``n_piece`` shard indexes once (default 5). Then it calls
    ``get_recall`` for every ``(number, k)`` combination.

    Returns:
        ``(search_time, search_recall)``: nested dicts keyed first by query
        count, then by k: ``{number: {k: value}}``.
    """
    # Materialise the inputs: k is iterated once per query count.
    k = list(k)
    number = list(number)
    create_index(data_choice=data_choice, n_piece=n_piece, dim=dim)
    print('数据库创建完毕')
    search_time = {}
    search_recall = {}
    for n in number:
        search_time[n] = {}
        search_recall[n] = {}
        for k_ in k:
            print(f'现在处理的是{n}个查询数据，k值是{k_}')
            mean_recall, elapsed_time = get_recall(data_choice, n_piece, dim, k_, n)
            search_time[n][k_] = elapsed_time
            search_recall[n][k_] = mean_recall
    return search_time, search_recall

In [2]:
# # 此时k固定为1000，对比不同分片数据库下，不同query数量下的查询时间和召回率对比
# def get_report_n_piece_number(data_choice:str='glove',dim:int=25,n_piece:list=[1,5,10,15,20,25],
#                k:int=1000,number:list=[10,500,1000,5000,10000]):
#     # 首先是创建相关向量数据库
#     create_index_time=get_report_create_index_time(data_choice,dim,n_piece)
#     print('数据库创建完毕')
#     # 接下来是查询数据库
#     search_time={}# 查询时间
#     seach_recall={}# 查询召回率
#     data_name= data_info[data_choice][dim]
#     for n in n_piece:
#         name1=data_name+'_n'+str(n)
#         search_time[name1]={}
#         seach_recall[name1]={}
#         for n_ in number:
#             print(f'现在处理的是{n}分片数据库，查询数据是{n_}')
#             name2=name1+'_q'+str(n_)+'_k'+str(k)
#             mean_recall,elapsed_time=get_recall(data_choice,n,dim,k,n_)
#             search_time[name1][name2]=elapsed_time
#             seach_recall[name1][name2]=mean_recall
#     return create_index_time,search_time,seach_recall

In [None]:
def get_report_n_piece_number(data_choice:str='glove',dim:int=25,n_piece:'list | tuple'=(1,5,10,15,20,25),
               k:int=1000,number:'list | tuple'=(10,500,1000,5000,10000)):
    """Compare search time and recall across shard counts and query counts for a fixed k.

    First builds (and times) the indexes for every shard count via
    ``get_report_create_index_time``. Then it calls ``get_recall`` for every
    ``(n_piece, number)`` pair with the given ``k`` (default 1000).

    Returns:
        ``(create_index_time, search_time, search_recall)``.
        ``create_index_time`` is ``{n_piece: seconds}``. The other two are
        nested dicts keyed first by shard count, then by query count:
        ``{n_piece: {number: value}}``.
    """
    # Materialise the inputs: both are iterated more than once below.
    n_piece = list(n_piece)
    number = list(number)
    create_index_time = get_report_create_index_time(data_choice, dim, n_piece)
    print('数据库创建完毕')
    search_time = {}
    search_recall = {}
    for n in n_piece:
        search_time[n] = {}
        search_recall[n] = {}
        for n_ in number:
            print(f'现在处理的是{n}分片数据库，查询数据是{n_}')
            mean_recall, elapsed_time = get_recall(data_choice, n, dim, k, n_)
            search_time[n][n_] = elapsed_time
            search_recall[n][n_] = mean_recall
    return create_index_time, search_time, search_recall

## 可视化

In [3]:
# 这里用的库是matplotlib
import matplotlib.pyplot as plt

In [4]:
def create_photo_store(data_choice:str='glove',dim:int=25):
    """Ensure the figure folder ``figure/<data_name>`` exists and return its path.

    ``data_name`` is looked up as ``data_info[data_choice][dim]``. A message
    reports whether the folder was created or already existed.
    """
    figure_dir = f"figure/{data_info[data_choice][dim]}"
    if os.path.exists(figure_dir):
        print(f"文件夹 '{figure_dir}' 已经存在。")
    else:
        os.makedirs(figure_dir)
        print(f"文件夹 '{figure_dir}' 不存在，已创建。")
    return figure_dir

### create_index_time可视化

In [5]:
def plot_create_index_time(create_index_time:dict,data_choice:str='glove',dim:int=25):
    """Plot index-build time against shard count and save it as a PNG.

    Args:
        create_index_time: ``{n_piece: seconds}``, as returned by
            ``get_report_create_index_time``.
        data_choice: dataset family; with ``dim`` selects ``figure/<data_name>``.
        dim: vector dimensionality.

    The chart is written to ``<figure folder>/create_index_time.png``.
    """
    folder_path = create_photo_store(data_choice, dim)
    file_path = f"{folder_path}/create_index_time.png"
    x_values = list(create_index_time.keys())
    y_values = list(create_index_time.values())
    # Draw on a fresh figure so that an earlier, still-current figure (e.g.
    # from another plot function) is neither drawn on nor saved here.
    fig, ax = plt.subplots()
    ax.plot(x_values, y_values, label='create_index_time', marker='o')
    # Annotate each point with its time, to three decimals.
    for x, y in zip(x_values, y_values):
        ax.text(x, y, f'{y:.3f}', ha='left', va='bottom')
    ax.set_xlabel("Number of pieces")
    ax.set_ylabel("Time (s)")
    ax.set_title("Time to create index for different number of pieces")
    ax.legend()
    fig.savefig(file_path)

### 接下来是对report_n_piece_k进行可视化

In [11]:
def plot_search_time_npiece_k_bar3d(search_time_npiece_k:dict,data_choice:str='glove',dim:int=25,
                                    dx:float=1.5,dy:float=1.5):
    """Draw search time over (n_piece, k) as a 3-D bar chart and save it.

    Args:
        search_time_npiece_k: nested dict ``{n_piece: {k: seconds}}``, e.g. the
            ``search_time`` returned by ``get_report_n_piece_k``.
        data_choice: dataset family; with ``dim`` selects ``figure/<data_name>``.
        dim: vector dimensionality.
        dx: bar width along the n_piece axis, in data units (default 1.5).
        dy: bar depth along the k axis, in data units (default 1.5). The
            default k values of ``get_report_n_piece_k`` span 10 to 100000, so
            a larger ``dy`` may be needed for bars to be visible.

    The flattened x / y / z lists are printed. The chart is saved as
    ``search_time_npiece_k_bar3d.png`` in the figure folder.
    """
    folder_path = create_photo_store(data_choice, dim)
    file_path = f"{folder_path}/search_time_npiece_k_bar3d.png"
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    # One bar per (outer key, inner key) cell of the nested dict.
    cells = [(outer, inner, value)
             for outer, inner_dict in search_time_npiece_k.items()
             for inner, value in inner_dict.items()]
    x_values = [cell[0] for cell in cells]
    y_values = [cell[1] for cell in cells]
    z_values = [cell[2] for cell in cells]
    print(x_values, y_values, z_values)
    # Colour each bar by its height.
    norm = plt.Normalize(min(z_values), max(z_values))
    colors = plt.cm.coolwarm(norm(z_values))
    # Semi-transparent bars rising from z = 0.
    ax.bar3d(x_values, y_values, np.zeros_like(z_values), dx, dy, z_values,
             shade=False, color=colors, alpha=0.5)
    ax.set_xlabel('Number of pieces')
    ax.set_ylabel('k')
    ax.set_zlabel('Time (s)')
    plt.title("Time to search for different number of pieces and k,query = 100", fontsize=12)
    # Viewing angle: elevation 15 degrees, azimuth 250 degrees.
    ax.view_init(elev=15, azim=250)
    plt.savefig(file_path)

In [12]:
def plot_search_recall_npiece_k_bar3d(search_recall_npiece_k:dict,data_choice:str='glove',dim:int=25,
                                      dx:float=1.5,dy:float=1.5):
    """Draw recall over (n_piece, k) as a 3-D bar chart and save it.

    Args:
        search_recall_npiece_k: nested dict ``{n_piece: {k: recall}}``, e.g. the
            recall dict returned by ``get_report_n_piece_k``.
        data_choice: dataset family; with ``dim`` selects ``figure/<data_name>``.
        dim: vector dimensionality.
        dx: bar width along the n_piece axis, in data units (default 1.5).
        dy: bar depth along the k axis, in data units (default 1.5). The
            default k values of ``get_report_n_piece_k`` span 10 to 100000, so
            a larger ``dy`` may be needed for bars to be visible.

    The flattened x / y / z lists are printed. The chart is saved as
    ``search_recall_npiece_k_bar3d.png`` in the figure folder.
    """
    folder_path = create_photo_store(data_choice, dim)
    file_path = f"{folder_path}/search_recall_npiece_k_bar3d.png"
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    # One bar per (outer key, inner key) cell of the nested dict.
    cells = [(outer, inner, value)
             for outer, inner_dict in search_recall_npiece_k.items()
             for inner, value in inner_dict.items()]
    x_values = [cell[0] for cell in cells]
    y_values = [cell[1] for cell in cells]
    z_values = [cell[2] for cell in cells]
    print(x_values, y_values, z_values)
    # Colour each bar by its height.
    norm = plt.Normalize(min(z_values), max(z_values))
    colors = plt.cm.coolwarm(norm(z_values))
    # Semi-transparent bars rising from z = 0.
    ax.bar3d(x_values, y_values, np.zeros_like(z_values), dx, dy, z_values,
             shade=False, color=colors, alpha=0.5)
    ax.set_xlabel('Number of pieces')
    ax.set_ylabel('k')
    ax.set_zlabel('Recall')
    plt.title("Recall to search for different number of pieces and k,query = 100", fontsize=12)
    # Viewing angle: elevation 15 degrees, azimuth 250 degrees.
    ax.view_init(elev=15, azim=250)
    plt.savefig(file_path)

In [13]:
def plot_search_time_npiece_k_subplot(search_time_npiece_k:dict,data_choice:str='glove',dim:int=25):
    """Plot search time against k in one panel per shard count and save it.

    Args:
        search_time_npiece_k: nested dict ``{n_piece: {k: seconds}}``.
        data_choice: dataset family; with ``dim`` selects ``figure/<data_name>``.
        dim: vector dimensionality.

    Panels are laid out three per row and spare grid cells are hidden. The
    figure is saved as ``search_time_npiece_k_subplot.png`` and then shown.

    Raises:
        ValueError: if ``search_time_npiece_k`` is empty.
    """
    if not search_time_npiece_k:
        raise ValueError("search_time_npiece_k is empty; nothing to plot")
    folder_path = create_photo_store(data_choice, dim)
    file_path = f"{folder_path}/search_time_npiece_k_subplot.png"
    num_subplots = len(search_time_npiece_k)
    subplots_per_row = 3
    num_rows = math.ceil(num_subplots / subplots_per_row)
    # squeeze=False keeps `axs` 2-D even for a single row, so the same
    # flattening works for every grid size.
    fig, axs = plt.subplots(num_rows, subplots_per_row, figsize=(15, 4 * num_rows),
                            sharex=True, squeeze=False)
    flat_axes = axs.ravel()
    # One panel per shard count, filled row by row.
    for ax, (n_piece, inner_dict) in zip(flat_axes, search_time_npiece_k.items()):
        x_values = list(inner_dict.keys())
        y_values = list(inner_dict.values())
        # Label the line so that the legend has an entry to show.
        ax.plot(x_values, y_values, marker='o', label='search_time')
        ax.set_xlabel('k ')
        ax.set_ylabel('Time (s)')
        ax.set_title(f'n_piece = {n_piece}')
        ax.legend()
    # Hide grid cells that received no data.
    for ax in flat_axes[num_subplots:]:
        ax.axis('off')
    plt.tight_layout()
    # Leave room between the overall title and the panel titles.
    plt.subplots_adjust(top=0.85)
    plt.suptitle('Time to search for different number of pieces and k,query = 100')
    plt.savefig(file_path)
    plt.show()

In [14]:
def plot_search_recall_npiece_k_subplot(seach_recall_n_piece_k:dict,data_choice:str='glove',dim:int=25):
    """Plot recall against k in one panel per shard count and save it.

    Args:
        seach_recall_n_piece_k: nested dict ``{n_piece: {k: recall}}``.
        data_choice: dataset family; with ``dim`` selects ``figure/<data_name>``.
        dim: vector dimensionality.

    Panels are laid out three per row and spare grid cells are hidden. The
    figure is saved as ``search_recall_npiece_k_subplot.png`` and then shown.

    Raises:
        ValueError: if ``seach_recall_n_piece_k`` is empty.
    """
    if not seach_recall_n_piece_k:
        raise ValueError("seach_recall_n_piece_k is empty; nothing to plot")
    folder_path = create_photo_store(data_choice, dim)
    file_path = f"{folder_path}/search_recall_npiece_k_subplot.png"
    num_subplots = len(seach_recall_n_piece_k)
    subplots_per_row = 3
    num_rows = math.ceil(num_subplots / subplots_per_row)
    # squeeze=False keeps `axs` 2-D even for a single row, so the same
    # flattening works for every grid size.
    fig, axs = plt.subplots(num_rows, subplots_per_row, figsize=(15, 4 * num_rows),
                            sharex=True, squeeze=False)
    flat_axes = axs.ravel()
    # One panel per shard count, filled row by row.
    for ax, (n_piece, inner_dict) in zip(flat_axes, seach_recall_n_piece_k.items()):
        x_values = list(inner_dict.keys())
        y_values = list(inner_dict.values())
        # Label the line so that the legend has an entry to show.
        ax.plot(x_values, y_values, marker='o', label='search_recall')
        ax.set_xlabel('k ')
        ax.set_ylabel('Recall')
        ax.set_title(f'n_piece = {n_piece}')
        ax.legend()
    # Hide grid cells that received no data.
    for ax in flat_axes[num_subplots:]:
        ax.axis('off')
    plt.tight_layout()
    # Leave room between the overall title and the panel titles.
    plt.subplots_adjust(top=0.85)
    plt.suptitle('Recall to search for different number of pieces and k,query = 100')
    plt.savefig(file_path)
    plt.show()


### 接下来是对result_number_k进行可视化

In [15]:
def plot_search_time_number_k_bar3d(search_time_number_k:dict,data_choice:str='glove',dim:int=25,
                                    dx:float=300,dy:float=50):
    """Draw search time over (query count, k) as a 3-D bar chart, save and show it.

    Args:
        search_time_number_k: nested dict ``{number: {k: seconds}}``, e.g. the
            ``search_time`` returned by ``get_report_number_k``.
        data_choice: dataset family; with ``dim`` selects ``figure/<data_name>``.
        dim: vector dimensionality.
        dx: bar width along the query-count axis, in data units (default 300).
        dy: bar depth along the k axis, in data units (default 50).

    The flattened x / y / z lists are printed. The chart is saved as
    ``search_time_number_k_bar3d.png`` in the figure folder.
    """
    folder_path = create_photo_store(data_choice, dim)
    file_path = f"{folder_path}/search_time_number_k_bar3d.png"
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    # One bar per (outer key, inner key) cell of the nested dict.
    cells = [(outer, inner, value)
             for outer, inner_dict in search_time_number_k.items()
             for inner, value in inner_dict.items()]
    x_values = [cell[0] for cell in cells]
    y_values = [cell[1] for cell in cells]
    z_values = [cell[2] for cell in cells]
    print(x_values, y_values, z_values)
    # Colour each bar by its height.
    norm = plt.Normalize(min(z_values), max(z_values))
    colors = plt.cm.coolwarm(norm(z_values))
    # Semi-transparent bars rising from z = 0.
    ax.bar3d(x_values, y_values, np.zeros_like(z_values), dx, dy, z_values,
             shade=False, color=colors, alpha=0.5)
    ax.set_xlabel('Number of querys')
    ax.set_ylabel('k')
    ax.set_zlabel('Time (s)')
    plt.title("Time to search for different number of query and k,npiece = 5", fontsize=12)
    # Viewing angle: elevation 15 degrees, azimuth 250 degrees.
    ax.view_init(elev=15, azim=250)
    plt.savefig(file_path)
    plt.show()

In [16]:
def plot_search_recall_number_k_bar3d(search_recall_number_k:dict,data_choice:str='glove',dim:int=25,
                                      dx:float=300,dy:float=50):
    """Draw recall over (query count, k) as a 3-D bar chart, save and show it.

    Args:
        search_recall_number_k: nested dict ``{number: {k: recall}}``, e.g. the
            recall dict returned by ``get_report_number_k``.
        data_choice: dataset family; with ``dim`` selects ``figure/<data_name>``.
        dim: vector dimensionality.
        dx: bar width along the query-count axis, in data units (default 300).
        dy: bar depth along the k axis, in data units (default 50).

    The flattened x / y / z lists are printed. The chart is saved as
    ``search_recall_number_k_bar3d.png`` in the figure folder.
    """
    folder_path = create_photo_store(data_choice, dim)
    file_path = f"{folder_path}/search_recall_number_k_bar3d.png"
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    # One bar per (outer key, inner key) cell of the nested dict.
    cells = [(outer, inner, value)
             for outer, inner_dict in search_recall_number_k.items()
             for inner, value in inner_dict.items()]
    x_values = [cell[0] for cell in cells]
    y_values = [cell[1] for cell in cells]
    z_values = [cell[2] for cell in cells]
    print(x_values, y_values, z_values)
    # Colour each bar by its height.
    norm = plt.Normalize(min(z_values), max(z_values))
    colors = plt.cm.coolwarm(norm(z_values))
    # Semi-transparent bars rising from z = 0.
    ax.bar3d(x_values, y_values, np.zeros_like(z_values), dx, dy, z_values,
             shade=False, color=colors, alpha=0.5)
    ax.set_xlabel('Number of querys')
    ax.set_ylabel('k')
    ax.set_zlabel('Recall')
    plt.title("Recall to search for different number of query and k,npiece = 5", fontsize=12)
    # Viewing angle: elevation 15 degrees, azimuth 250 degrees.
    ax.view_init(elev=15, azim=250)
    plt.savefig(file_path)
    plt.show()

In [17]:
def plot_search_time_number_k_subplot(search_time_number_k:dict,data_choice:str='glove',dim:int=25):
    """Plot search time against k in one panel per query count and save it.

    Args:
        search_time_number_k: nested dict ``{number: {k: seconds}}``.
        data_choice: dataset family; with ``dim`` selects ``figure/<data_name>``.
        dim: vector dimensionality.

    Panels are laid out three per row and spare grid cells are hidden. The
    figure is saved as ``search_time_number_k_subplot.png`` and then shown.

    Raises:
        ValueError: if ``search_time_number_k`` is empty.
    """
    if not search_time_number_k:
        raise ValueError("search_time_number_k is empty; nothing to plot")
    folder_path = create_photo_store(data_choice, dim)
    file_path = f"{folder_path}/search_time_number_k_subplot.png"
    num_subplots = len(search_time_number_k)
    subplots_per_row = 3
    num_rows = math.ceil(num_subplots / subplots_per_row)
    # squeeze=False keeps `axs` 2-D even for a single row, so the same
    # flattening works for every grid size.
    fig, axs = plt.subplots(num_rows, subplots_per_row, figsize=(15, 4 * num_rows),
                            sharex=True, squeeze=False)
    flat_axes = axs.ravel()
    # One panel per query count, filled row by row.
    for ax, (query_count, inner_dict) in zip(flat_axes, search_time_number_k.items()):
        x_values = list(inner_dict.keys())
        y_values = list(inner_dict.values())
        # Label the line so that the legend has an entry to show.
        ax.plot(x_values, y_values, marker='o', label='search_time')
        ax.set_xlabel('k ')
        ax.set_ylabel('Time (s)')
        ax.set_title(f'number of querys = {query_count}')
        ax.legend()
    # Hide grid cells that received no data.
    for ax in flat_axes[num_subplots:]:
        ax.axis('off')
    plt.tight_layout()
    # Leave room between the overall title and the panel titles.
    plt.subplots_adjust(top=0.85)
    plt.suptitle('Time to search for different number of query and k,npiece = 5')
    plt.savefig(file_path)
    plt.show()

In [18]:
def plot_search_recall_number_k_subplot(search_recall_number_k:dict,data_choice:str='glove',dim:int=25):
    """Plot recall against k in one panel per query count and save it.

    Args:
        search_recall_number_k: nested dict ``{number: {k: recall}}``.
        data_choice: dataset family; with ``dim`` selects ``figure/<data_name>``.
        dim: vector dimensionality.

    Panels are laid out three per row and spare grid cells are hidden. The
    figure is saved as ``search_recall_number_k_subplot.png`` and then shown.

    Raises:
        ValueError: if ``search_recall_number_k`` is empty.
    """
    if not search_recall_number_k:
        raise ValueError("search_recall_number_k is empty; nothing to plot")
    folder_path = create_photo_store(data_choice, dim)
    file_path = f"{folder_path}/search_recall_number_k_subplot.png"
    num_subplots = len(search_recall_number_k)
    subplots_per_row = 3
    num_rows = math.ceil(num_subplots / subplots_per_row)
    # squeeze=False keeps `axs` 2-D even for a single row, so the same
    # flattening works for every grid size.
    fig, axs = plt.subplots(num_rows, subplots_per_row, figsize=(15, 4 * num_rows),
                            sharex=True, squeeze=False)
    flat_axes = axs.ravel()
    # One panel per query count, filled row by row.
    for ax, (query_count, inner_dict) in zip(flat_axes, search_recall_number_k.items()):
        x_values = list(inner_dict.keys())
        y_values = list(inner_dict.values())
        # Label the line so that the legend has an entry to show.
        ax.plot(x_values, y_values, marker='o', label='search_recall')
        ax.set_xlabel(' k ')
        ax.set_ylabel('Recall')
        ax.set_title(f'number of querys = {query_count}')
        ax.legend()
    # Hide grid cells that received no data.
    for ax in flat_axes[num_subplots:]:
        ax.axis('off')
    plt.tight_layout()
    # Leave room between the overall title and the panel titles.
    plt.subplots_adjust(top=0.85)
    plt.suptitle('Recall to search for different number of query and k,npiece = 5')
    plt.savefig(file_path)
    plt.show()

### 接下来是对result_n_piece_number进行可视化

In [19]:
def plot_search_time_npiece_number_bar3d(search_time_npiece_number:dict,data_choice:str='glove',dim:int=25,
                                         dx:float=1.5,dy:float=300):
    """Draw search time over (n_piece, query count) as a 3-D bar chart and save it.

    Args:
        search_time_npiece_number: nested dict ``{n_piece: {number: seconds}}``,
            e.g. the ``search_time`` returned by ``get_report_n_piece_number``.
        data_choice: dataset family; with ``dim`` selects ``figure/<data_name>``.
        dim: vector dimensionality.
        dx: bar width along the n_piece axis, in data units (default 1.5).
        dy: bar depth along the query-count axis, in data units (default 300).

    The flattened x / y / z lists are printed. The chart is saved as
    ``search_time_npiece_number_bar3d.png`` in the figure folder.
    """
    folder_path = create_photo_store(data_choice, dim)
    file_path = f"{folder_path}/search_time_npiece_number_bar3d.png"
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    # One bar per (outer key, inner key) cell of the nested dict.
    cells = [(outer, inner, value)
             for outer, inner_dict in search_time_npiece_number.items()
             for inner, value in inner_dict.items()]
    x_values = [cell[0] for cell in cells]
    y_values = [cell[1] for cell in cells]
    z_values = [cell[2] for cell in cells]
    print(x_values, y_values, z_values)
    # Colour each bar by its height.
    norm = plt.Normalize(min(z_values), max(z_values))
    colors = plt.cm.coolwarm(norm(z_values))
    # Semi-transparent bars rising from z = 0.
    ax.bar3d(x_values, y_values, np.zeros_like(z_values), dx, dy, z_values,
             shade=False, color=colors, alpha=0.5)
    ax.set_xlabel('Number of pieces')
    ax.set_ylabel('Number of querys')
    ax.set_zlabel('Time (s)')
    plt.title("Time to search for different number of pieces and number,k = 1000", fontsize=12)
    # Viewing angle: elevation 15 degrees, azimuth 250 degrees.
    ax.view_init(elev=15, azim=250)
    plt.savefig(file_path)

In [20]:
def plot_search_recall_npiece_number_bar3d(search_recall_npiece_number:dict,data_choice:str='glove',dim:int=25,
                                           dx:float=1.5,dy:float=300):
    """Draw recall over (n_piece, query count) as a 3-D bar chart and save it.

    Args:
        search_recall_npiece_number: nested dict ``{n_piece: {number: recall}}``,
            e.g. the recall dict returned by ``get_report_n_piece_number``.
        data_choice: dataset family; with ``dim`` selects ``figure/<data_name>``.
        dim: vector dimensionality.
        dx: bar width along the n_piece axis, in data units (default 1.5).
        dy: bar depth along the query-count axis, in data units (default 300).

    The flattened x / y / z lists are printed. The chart is saved as
    ``search_recall_npiece_number_bar3d.png`` in the figure folder.
    """
    folder_path = create_photo_store(data_choice, dim)
    file_path = f"{folder_path}/search_recall_npiece_number_bar3d.png"
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    # One bar per (outer key, inner key) cell of the nested dict.
    cells = [(outer, inner, value)
             for outer, inner_dict in search_recall_npiece_number.items()
             for inner, value in inner_dict.items()]
    x_values = [cell[0] for cell in cells]
    y_values = [cell[1] for cell in cells]
    z_values = [cell[2] for cell in cells]
    print(x_values, y_values, z_values)
    # Colour each bar by its height.
    norm = plt.Normalize(min(z_values), max(z_values))
    colors = plt.cm.coolwarm(norm(z_values))
    # Semi-transparent bars rising from z = 0.
    ax.bar3d(x_values, y_values, np.zeros_like(z_values), dx, dy, z_values,
             shade=False, color=colors, alpha=0.5)
    ax.set_xlabel('Number of pieces')
    ax.set_ylabel('Number of querys')
    ax.set_zlabel('Recall')
    plt.title("Recall to search for different number of pieces and number,k = 1000", fontsize=12)
    # Viewing angle: elevation 15 degrees, azimuth 250 degrees.
    ax.view_init(elev=15, azim=250)
    plt.savefig(file_path)

In [21]:
def plot_search_time_npiece_number_subplot(search_time_npiece_number:dict,data_choice:str='glove',dim:int=25):
    """Plot search time vs. number of queries, one line-chart subplot per n_piece value.

    Args:
        search_time_npiece_number: nested dict ``{n_piece: {n_query: time_seconds}}``.
            Each outer key gets its own subplot; the inner dict's keys/values
            are the x/y points of that subplot's line.
        data_choice: dataset family, forwarded to ``create_photo_store``.
        dim: dataset dimension, forwarded to ``create_photo_store``.

    Subplots are laid out three per row.
    The figure is saved to ``<photo folder>/search_time_npiece_number_subplot.png``
    and then shown.

    Raises:
        ValueError: if ``search_time_npiece_number`` is empty (nothing to plot).
    """
    if not search_time_npiece_number:
        raise ValueError("search_time_npiece_number is empty; nothing to plot")
    folder_path = create_photo_store(data_choice, dim)
    file_path = f"{folder_path}/search_time_npiece_number_subplot.png"
    num_subplots = len(search_time_npiece_number)
    # Number of subplots shown per row, and rows needed to hold them all.
    subplots_per_row = 3
    num_rows = math.ceil(num_subplots / subplots_per_row)
    # squeeze=False always yields a 2-D array of Axes.
    # Without it, a single row returns a 1-D array and axs[row, col] raises IndexError.
    fig, axs = plt.subplots(num_rows, subplots_per_row, figsize=(15, 4 * num_rows), sharex=True, squeeze=False)
    flat_axes = axs.ravel()
    # One subplot per outer key (number of pieces), filled row by row.
    for ax, (n_piece, time_by_query) in zip(flat_axes, search_time_npiece_number.items()):
        x_values = list(time_by_query.keys())
        y_values = list(time_by_query.values())
        ax.plot(x_values, y_values, marker='o')
        ax.set_xlabel('number of querys ')
        ax.set_ylabel('Time (s)')
        ax.set_title(f'number of npiece = {n_piece}')
        # No ax.legend(): the only line is unlabelled, so a legend would be empty
        # and matplotlib would warn. The subplot title already identifies the series.
    # Hide the leftover empty cells in the last row.
    for ax in flat_axes[num_subplots:]:
        ax.axis('off')
    plt.tight_layout()
    # Leave room above the subplots for the figure-level title.
    plt.subplots_adjust(top=0.85)
    plt.suptitle('Time to search for different number of query and npiece,k=1000')
    plt.savefig(file_path)
    plt.show()

In [22]:
def plot_search_recall_npiece_number_subplot(search_recall_npiece_number:dict,data_choice:str='glove',dim:int=25):
    """Plot recall vs. number of queries, one line-chart subplot per n_piece value.

    Args:
        search_recall_npiece_number: nested dict ``{n_piece: {n_query: recall}}``.
            Each outer key gets its own subplot; the inner dict's keys/values
            are the x/y points of that subplot's line.
        data_choice: dataset family, forwarded to ``create_photo_store``.
        dim: dataset dimension, forwarded to ``create_photo_store``.

    Subplots are laid out three per row.
    The figure is saved to ``<photo folder>/search_recall_npiece_number_subplot.png``
    and then shown.

    Raises:
        ValueError: if ``search_recall_npiece_number`` is empty (nothing to plot).
    """
    if not search_recall_npiece_number:
        raise ValueError("search_recall_npiece_number is empty; nothing to plot")
    folder_path = create_photo_store(data_choice, dim)
    file_path = f"{folder_path}/search_recall_npiece_number_subplot.png"
    num_subplots = len(search_recall_npiece_number)
    # Number of subplots shown per row, and rows needed to hold them all.
    subplots_per_row = 3
    num_rows = math.ceil(num_subplots / subplots_per_row)
    # squeeze=False always yields a 2-D array of Axes.
    # Without it, a single row returns a 1-D array and axs[row, col] raises IndexError.
    fig, axs = plt.subplots(num_rows, subplots_per_row, figsize=(15, 4 * num_rows), sharex=True, squeeze=False)
    flat_axes = axs.ravel()
    # One subplot per outer key (number of pieces), filled row by row.
    for ax, (n_piece, recall_by_query) in zip(flat_axes, search_recall_npiece_number.items()):
        x_values = list(recall_by_query.keys())
        y_values = list(recall_by_query.values())
        ax.plot(x_values, y_values, marker='o')
        ax.set_xlabel('number of querys ')
        ax.set_ylabel('Recall')
        ax.set_title(f'number of npiece = {n_piece}')
        # No ax.legend(): the only line is unlabelled, so a legend would be empty
        # and matplotlib would warn. The subplot title already identifies the series.
    # Hide the leftover empty cells in the last row.
    for ax in flat_axes[num_subplots:]:
        ax.axis('off')
    plt.tight_layout()
    # Leave room above the subplots for the figure-level title.
    plt.subplots_adjust(top=0.85)
    plt.suptitle('Recall to search for different number of query and npiece,k=1000')
    plt.savefig(file_path)
    plt.show()

### 可视化函数封装

In [24]:
# 创建存储图片所需要的文件夹
def create_photo_store(data_choice:str='glove',dim:int=25):
    """Make sure ``photo/<data_name>`` exists and return that folder path.

    ``data_name`` is looked up as ``data_info[data_choice][dim]``. ``data_info`` is
    not defined in this block; presumably it comes from ``config``, so verify there.
    A message (in Chinese) is printed saying whether the folder already existed
    or was just created.

    Args:
        data_choice: dataset family key into ``data_info``.
        dim: dimension key into ``data_info[data_choice]``.

    Returns:
        The folder path string, ``"photo/<data_name>"``.
    """
    folder_path = f"photo/{data_info[data_choice][dim]}"
    # Nothing to do when the folder is already there.
    if os.path.exists(folder_path):
        print(f"文件夹 '{folder_path}' 已经存在。")
        return folder_path
    # Otherwise create it, including any missing parent directories.
    os.makedirs(folder_path)
    print(f"文件夹 '{folder_path}' 不存在，已创建。")
    return folder_path

In [26]:
from config import *

In [27]:
# 这里的task_choice是用来选择我们要执行的任务类型，这个在config.py中有定义。task_data_plot这个字典中
# 画柱状图
# data_dict是该任务对应的输出的数据
def plot_bar3d(task_choice:str,data_dict:dict, data_choice:str='glove', dim:int=25):
    """Draw a 3-D bar chart of a two-level nested dict and save it as a PNG.

    Args:
        task_choice: key into ``task_data_plot``. That config entry supplies the bar
            footprint (``dx_bar3d``/``dy_bar3d``), the axis labels
            (``x_label_bar3d``/``y_label_bar3d``/``z_label_bar3d``) and ``title``.
            It also names the output file.
        data_dict: nested dict ``{outer_key: {inner_key: value}}``. Outer keys are
            the x positions, inner keys the y positions, values the bar heights.
        data_choice: dataset family, forwarded to ``create_photo_store``.
        dim: dataset dimension, forwarded to ``create_photo_store``.

    The image is written to ``<photo folder>/<task_choice>_bar3d.png``.
    The figure is not shown explicitly.
    """
    # Output folder first, then the 3-D axes to draw on.
    folder_path = create_photo_store(data_choice, dim)
    file_path = f"{folder_path}/{task_choice}_bar3d.png"
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # Flatten the nested dict into parallel coordinate lists.
    x_values, y_values, z_values = [], [], []
    for outer_key, inner_dict in data_dict.items():
        for inner_key, value in inner_dict.items():
            x_values.append(outer_key)
            y_values.append(inner_key)
            z_values.append(value)
    print(x_values, y_values, z_values)

    # Colour bars by height on the coolwarm colormap.
    scaler = plt.Normalize(min(z_values), max(z_values))
    bar_colors = plt.cm.coolwarm(scaler(z_values))

    # Per-task plotting configuration (bar size, labels, title).
    task_cfg = task_data_plot[task_choice]
    dx = task_cfg['dx_bar3d']
    dy = task_cfg['dy_bar3d']
    ax.bar3d(x_values, y_values, np.zeros_like(z_values), dx, dy, z_values, shade=False, color=bar_colors, alpha=0.5)

    axis_labels = (task_cfg['x_label_bar3d'], task_cfg['y_label_bar3d'], task_cfg['z_label_bar3d'])
    ax.set_xlabel(axis_labels[0])
    ax.set_ylabel(axis_labels[1])
    ax.set_zlabel(axis_labels[2])
    plt.title(task_cfg['title'], fontsize=12)
    # Camera position: elevation 15 degrees, azimuth 250 degrees.
    ax.view_init(elev=15, azim=250)
    plt.savefig(file_path)
    # plt.show()


In [28]:
# 用于画子图，这的操作和plot_bar3d类似，只是这里是画子图
def plot_subplot(task_choice:str,data_dict:dict, data_choice:str='glove', dim:int=25):
    """Draw one line-chart subplot per outer key of ``data_dict`` and save the figure.

    Args:
        task_choice: key into ``task_data_plot``. That config entry supplies
            ``set_xlabel``, ``set_ylabel``, ``set_title`` (the subplot-title prefix)
            and ``title`` (the figure-level title). It also names the output file.
        data_dict: nested dict ``{outer_key: {x: y}}``. Each outer key gets its own
            subplot titled ``"<set_title> = <outer_key>"``.
        data_choice: dataset family, forwarded to ``create_photo_store``.
        dim: dataset dimension, forwarded to ``create_photo_store``.

    Subplots are laid out three per row.
    The figure is saved to ``<photo folder>/<task_choice>_subplot.png`` and shown.

    Raises:
        ValueError: if ``data_dict`` is empty (nothing to plot).
    """
    if not data_dict:
        raise ValueError("data_dict is empty; nothing to plot")
    folder_path = create_photo_store(data_choice, dim)
    file_path = f"{folder_path}/{task_choice}_subplot.png"
    # Labels are loop-invariant: read them from the task config once.
    task_cfg = task_data_plot[task_choice]
    set_xlabel, set_ylabel, set_title = task_cfg['set_xlabel'], task_cfg['set_ylabel'], task_cfg['set_title']
    num_subplots = len(data_dict)
    # Number of subplots shown per row, and rows needed to hold them all.
    subplots_per_row = 3
    num_rows = math.ceil(num_subplots / subplots_per_row)
    # squeeze=False always yields a 2-D array of Axes.
    # Without it, a single row returns a 1-D array and axs[row, col] raises IndexError.
    fig, axs = plt.subplots(num_rows, subplots_per_row, figsize=(15, 4 * num_rows), sharex=True, squeeze=False)
    flat_axes = axs.ravel()
    # One subplot per outer key, filled row by row.
    for ax, (outer_key, inner_dict) in zip(flat_axes, data_dict.items()):
        x_values = list(inner_dict.keys())
        y_values = list(inner_dict.values())
        ax.plot(x_values, y_values, marker='o')
        ax.set_xlabel(set_xlabel)
        ax.set_ylabel(set_ylabel)
        ax.set_title(f'{set_title} = {outer_key}')
        # No ax.legend(): the only line is unlabelled, so a legend would be empty
        # and matplotlib would warn. The subplot title already identifies the series.
    # Hide the leftover empty cells in the last row.
    for ax in flat_axes[num_subplots:]:
        ax.axis('off')
    plt.tight_layout()
    # Leave room above the subplots for the figure-level title.
    plt.subplots_adjust(top=0.85)
    plt.suptitle(task_cfg['title'])
    plt.savefig(file_path)
    plt.show()

In [1]:
# 画折线图
def plot_multiple_lines(task_choice: str, data_dict: dict, data_choice: str = 'glove', dim: int = 25):
    """Overlay one line per outer key of ``data_dict`` on a single chart and save it.

    Args:
        task_choice: key into ``task_data_plot``. That config entry supplies
            ``set_xlabel``, ``set_ylabel``, ``set_title`` (the legend-label prefix)
            and ``title``. It also names the output file.
        data_dict: nested dict ``{outer_key: {x: y}}``. Each outer key becomes one
            line, labelled ``"<set_title> <outer_key>"``.
        data_choice: dataset family, forwarded to ``create_photo_store``.
        dim: dataset dimension, forwarded to ``create_photo_store``.

    The chart is saved to ``<photo folder>/<task_choice>_multiple_lines.png`` and
    shown. The legend sits outside the axes, to the right.
    """
    folder_path = create_photo_store(data_choice, dim)
    file_path = f"{folder_path}/{task_choice}_multiple_lines.png"
    # Labels are loop-invariant: read them from the task config once.
    task_cfg = task_data_plot[task_choice]
    sub_label = task_cfg['set_title']
    num_lines = len(data_dict)
    plt.figure(figsize=(8, 5))
    for i, (outer_key, inner_dict) in enumerate(data_dict.items()):
        x_values = list(inner_dict.keys())
        y_values = list(inner_dict.values())
        # Spread line colours evenly across the viridis colormap.
        line_color = plt.cm.viridis(i / num_lines)
        plt.plot(x_values, y_values, marker='o', label=f'{sub_label} {outer_key}', color=line_color)
    plt.xlabel(task_cfg['set_xlabel'])
    plt.ylabel(task_cfg['set_ylabel'])
    plt.title(task_cfg['title'])
    # Legend to the right of the axes. Skip it when there are no lines,
    # since an empty legend only triggers a matplotlib warning.
    if data_dict:
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    # bbox_inches='tight' grows the saved canvas to include the out-of-axes legend.
    # Without it, the legend is clipped in the PNG.
    plt.savefig(file_path, bbox_inches='tight')
    plt.show()


