In [3]:
import numpy as np

def read_fvecs(filename):
    """Load an .fvecs file into a 2-D float32 array.

    Each record in the file is read as one int32 header (the vector
    dimension) followed by that many 4-byte values. Only the header of the
    first record is inspected; every record is assumed to share it.

    Args:
        filename: Path of the .fvecs file to read.

    Returns:
        A C-contiguous float32 array of shape (num_vectors, dim). An empty
        file yields an array of shape (0, 0).

    Raises:
        ValueError: If the leading dimension is not positive or the file
            size is not a whole number of (dim + 1)-word records.
    """
    with open(filename, 'rb') as f:
        # Read everything as 32-bit words so the int32 headers can be decoded.
        data = np.fromfile(f, dtype='int32')

    if data.size == 0:
        return np.empty((0, 0), dtype='float32')

    # The first word is the dimension of the first (and every) vector.
    dim = int(data[0])
    if dim <= 0 or data.size % (dim + 1) != 0:
        raise ValueError(
            f'{filename!r} is not a valid fvecs file '
            f'(dim={dim}, total 32-bit words={data.size})'
        )

    # Drop the header column, then REINTERPRET the remaining bytes as float32.
    # astype() would convert the raw integer bit patterns numerically and
    # produce meaningless values; view() keeps the bytes untouched. The copy
    # makes the sliced array contiguous before the reinterpretation.
    return data.reshape(-1, dim + 1)[:, 1:].copy().view('float32')

def read_ivecs(filename):
    """Load an .ivecs file into a 2-D int32 array.

    The whole file is read as int32 words. The first word is taken as the
    per-record length, the words are arranged into rows of (length + 1)
    columns, and the leading length column is dropped.

    Args:
        filename: Path of the .ivecs file to read.

    Returns:
        An int32 array with one row per record and `length` columns.
    """
    with open(filename, 'rb') as stream:
        words = np.fromfile(stream, dtype='int32')

    # Only the first record's header is consulted for the row width.
    width = words[0]
    table = words.reshape(-1, width + 1)

    # Column 0 repeats the header on every row; keep only the payload.
    return table[:, 1:]

# Locations of the GIST dataset files
base_path = '/data/matmang/gist/gist_base.fvecs'
learn_path = '/data/matmang/gist/gist_learn.fvecs'
query_path = '/data/matmang/gist/gist_query.fvecs'
groundtruth_path = '/data/matmang/gist/gist_groundtruth.ivecs'

# Load the three float-vector sets in order (base, learn, query), then the
# integer ground-truth table.
gist_base, gist_learn, gist_query = map(
    read_fvecs, (base_path, learn_path, query_path)
)
gist_groundtruth = read_ivecs(groundtruth_path)

# Report the shape of every loaded array as a quick sanity check.
print(f'Base shape: {gist_base.shape}')
print(f'Learn shape: {gist_learn.shape}')
print(f'Query shape: {gist_query.shape}')
print(f'Groundtruth shape: {gist_groundtruth.shape}')

Base shape: (1000000, 960)
Learn shape: (500000, 960)
Query shape: (1000, 960)
Groundtruth shape: (1000, 100)


In [4]:
import faiss

# Vector dimensionality, taken from the second axis of the base set
dim = gist_base.shape[1]

# Flat L2 index: create it and add the base vectors (no train() call is made)
index_flat = faiss.IndexFlatL2(dim)
index_flat.add(gist_base)

# IVF index: create it, train it on the learn set, then add the base vectors
# NOTE(review): despite its name, `nprove` is the value passed as the third
# argument of IndexIVFFlat (the number of clusters), not a search-time probe
# count; no nprobe is set anywhere in this cell.
nprove = 100  # Number of clusters
quantizer = faiss.IndexFlatL2(dim)
index_ivf = faiss.IndexIVFFlat(quantizer, dim, nprove)
index_ivf.train(gist_learn)
index_ivf.add(gist_base)

# HNSW index: create it and add the base vectors
# (second constructor argument is 32 — presumably HNSW's M, the number of
# graph links per node; confirm against the faiss documentation)
index_hnsw = faiss.IndexHNSWFlat(dim, 32)
index_hnsw.add(gist_base)

In [None]:
def search_and_measure(index, queries, groundtruth, k=5):
    """Run a k-NN search and score it against ground truth with recall@k.

    Recall@k is computed as the fraction of the first `k` ground-truth
    neighbours (per query) that appear anywhere among the `k` returned ids,
    averaged over all queries. The result is always in [0, 1].

    Args:
        index: Object exposing `search(queries, k)` that returns a
            `(distances, indices)` pair of arrays with one row per query.
        queries: 2-D array of query vectors, one per row.
        groundtruth: 2-D integer array with one row per query listing the
            true neighbour ids, nearest first. It may have more than `k`
            columns; only the first `k` are used.
        k: Number of neighbours to retrieve and to score.

    Returns:
        Tuple `(distances, indices, recall)` where the first two are exactly
        what `index.search` returned and `recall` is a float.
    """
    distances, indices = index.search(queries, k)

    # Compare against the k true nearest neighbours only; the ground-truth
    # table usually holds more columns than k, which made the original
    # element-wise `indices == groundtruth` comparison shape-incompatible.
    truth = np.asarray(groundtruth)[:, :k]

    # Number of ground-truth ids that could possibly be recovered.
    total = len(queries) * truth.shape[1]
    if total == 0:
        return distances, indices, 0.0

    # For every returned id, check membership in that query's truth row
    # (order-independent), then count the hits over all queries.
    found = (np.asarray(indices)[:, :, None] == truth[:, None, :]).any(axis=2)
    recall = float(found.sum()) / total
    return distances, indices, recall

k = 5  # neighbours requested per query

# Query each index with the same query set and ground truth, keeping the
# distances, the returned ids and the recall score for each one.
flat_distances, flat_indices, flat_recall = search_and_measure(
    index_flat, gist_query, gist_groundtruth, k
)
ivf_distances, ivf_indices, ivf_recall = search_and_measure(
    index_ivf, gist_query, gist_groundtruth, k
)
hnsw_distances, hnsw_indices, hnsw_recall = search_and_measure(
    index_hnsw, gist_query, gist_groundtruth, k
)

# Report the recall of each index to four decimal places.
print(f'FlatL2 Recall@{k}: {flat_recall:.4f}')
print(f'IVF Recall@{k}: {ivf_recall:.4f}')
print(f'HNSW Recall@{k}: {hnsw_recall:.4f}')

In [None]:
import matplotlib.pyplot as plt

def plot_results(dataset_name, distances):
    """Show a 100-bin histogram of all distance values for one index.

    Args:
        dataset_name: Label used in the legend entry and the figure title.
        distances: Array of distances; it is flattened before binning.
    """
    # Explicit figure/axes pair instead of the implicit pyplot state machine.
    _figure, axes = plt.subplots()
    axes.hist(distances.flatten(), bins=100, label=f'{dataset_name} distances')
    axes.set_xlabel('Distance')
    axes.set_ylabel('Frequency')
    axes.set_title(f'Distance Distribution for {dataset_name}')
    axes.legend()
    plt.show()

# Draw one distance histogram per index, in the same order as before.
for index_label, index_distances in (
    ('FlatL2', flat_distances),
    ('IVF', ivf_distances),
    ('HNSW', hnsw_distances),
):
    plot_results(index_label, index_distances)