In [16]:
import torch
import faiss
import time
from sklearn.metrics import f1_score

In [17]:
d = torch.load("gat_3.pt")
infer_target_mask = d["infer_target_mask"]
train_feats = d['features'][torch.logical_not(infer_target_mask)]
train_labels = d['labels'][torch.logical_not(infer_target_mask)]
train_embs = d['embeddings'][torch.logical_not(infer_target_mask)]
query_feats = d['features'][infer_target_mask]
query_labels = d['labels'][infer_target_mask]
feat_dim = train_feats.shape[1]

In [21]:
def cal_labels(y_pred, multilabel):
    if multilabel:
        y_pred[y_pred > 0] = 1
        y_pred[y_pred <= 0] = 0
        return y_pred
    else:
        y_pred = np.argmax(y_pred, axis=1)
        return y_pred

def get_acc(y_true, y_pred, multilabel):
    y_pred = cal_labels(y_pred, multilabel)
    return f1_score(y_true, y_pred, average="micro")

def build_index(train_feats, nlist, nprobe):
    res = faiss.StandardGpuResources()
    quantizer = faiss.IndexFlatL2(train_feats.shape[1])  # the other index
    index = faiss.IndexIVFFlat(quantizer, feat_dim, nlist)
    gpu_index = faiss.index_cpu_to_gpu(res, 0, index)
    gpu_index.train(train_feats)
    gpu_index.add(train_feats)
    gpu_index.nprobe = nprobe
    return gpu_index

index = build_index(train_feats, 1024, 1)

true_labels_arr = []
approx_embs_arr = []
approx_labels_arr = []
exec_time = 0
query_idx = 1024
while query_idx < query_feats.shape[0]:
    f = query_feats[query_idx-1024:query_idx]
    true_labels = query_labels[query_idx-1024:query_idx]
    start_t = time.time()
    r = index.search(f, 1)
    exec_time += time.time() - start_t
    approx_labels = train_labels[r[1].reshape(-1)]
    approx_embs = train_embs[r[1].reshape(-1)]
    approx_embs_arr.append(approx_embs)
    true_labels_arr.append(true_labels)
    approx_labels_arr.append(approx_labels)
    query_idx += 1024

true_labels = torch.concat(true_labels_arr)
approx_labels = torch.concat(approx_labels_arr)
approx_embs = torch.concat(approx_embs_arr)
exec_time = exec_time / len(true_labels_arr)
acc1 = f1_score(true_labels, approx_labels, average='micro')
acc2 = get_acc(true_labels, approx_embs, True)
print(f"exec_time={exec_time}, acc1={acc1}, acc2={acc2}")

exec_time=0.0024348848006304573, acc1=0.3898802698174548, acc2=0.44045697542382306
