In [2]:
import json
import os
from tqdm import tqdm
import numpy as np
import torch
# GPU configuration: index of the CUDA card this notebook should run on.
GPU_INDEX = 7
if torch.cuda.is_available() and GPU_INDEX < torch.cuda.device_count():
    device = torch.device(f'cuda:{GPU_INDEX}')
else:
    if torch.cuda.is_available():
        # torch.device('cuda:N') does not validate N; a missing card would only
        # fail later (invalid device ordinal), so make the fallback explicit.
        print(f"cuda:{GPU_INDEX} not available "
              f"({torch.cuda.device_count()} GPU(s) visible); falling back to CPU")
    device = torch.device('cpu')

In [3]:
# Load the rule set; these rules are the retrieval corpus ("documents").
data = 'data/aihub_rules_prev.json'
with open(data, encoding="UTF-8") as rules_fp:
    aihub_rule = json.load(rules_fp)

# Corpus order: all 'normal' rules first, then all 'abnormal' rules.
# Later cells convert answer labels into positions within this combined list.
normal_rule, abnormal_rule = (aihub_rule[key] for key in ('normal', 'abnormal'))
combined_rules = normal_rule + abnormal_rule

In [4]:
# Caption load: each image caption (description) serves as one retrieval query.
# The directory can be overridden with the CAPTION_ROOT environment variable.
root_path = os.environ.get(
    'CAPTION_ROOT',
    '/data1/sliver/jwsuh/construction_dataset/aihub/llava/llava_image_result_with_obj',
)
# os.listdir order is arbitrary; sorting keeps files[i] and files_captions[i]
# (and any embeddings built from them) aligned and reproducible across runs.
files = sorted(name for name in os.listdir(root_path) if name.endswith('.json'))

files_captions = []
for file in tqdm(files):
    with open(os.path.join(root_path, file), 'r', encoding="UTF-8") as caption_fp:
        files_captions.append(json.load(caption_fp)['outputs'])
    

100%|██████████| 943/943 [00:00<00:00, 19475.06it/s]


In [5]:
def find_rank_of_answer_in_results(final_indices, answer_index):
    """Return the 1-based rank of ``answer_index`` within ``final_indices``.

    Args:
        final_indices: ordered retrieval result (best match first), e.g. a list
            of corpus indices.
        answer_index: the ground-truth corpus index to look for.

    Returns:
        Position of the first occurrence, counting from 1, or -1 when the
        answer is not among the retrieved indices.
    """
    if answer_index not in final_indices:
        return -1
    return final_indices.index(answer_index) + 1

In [11]:
from sentence_transformers import SentenceTransformer, util

# Bi-encoder retrieval: rules (corpus) and captions (queries) are embedded
# independently and compared later by vector similarity.
BIENCODER_MODEL = "msmarco-distilbert-base-tas-b"

embedder = SentenceTransformer(BIENCODER_MODEL, device=device)

# One embedding row per rule, in combined_rules order (normal rules first).
corpus_embeddings = embedder.encode(combined_rules, convert_to_tensor=True)
# One embedding row per caption, aligned with `files` (query_embedding[i] <-> files[i]).
query_embedding = embedder.encode(files_captions, convert_to_tensor=True)

modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/122 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/4.02k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/548 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/265M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/547 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

In [14]:
# Bi-encoder retrieval: rank every rule for every caption and record, per cut-off,
# which queries have their ground-truth rule inside the top-k.
save_right_index = {}
save_wrong_index = {}
biencoder_final_top_k_indices = []

# Similarity of every query against every rule in one batched call;
# row i corresponds to files[i] / query_embedding[i].
# NOTE(review): the sentence-transformers docs list msmarco-distilbert-base-tas-b
# among models tuned for dot-product scoring; util.dot_score may be the better fit.
# cos_sim is kept to stay comparable with previously reported numbers.
similarity = util.cos_sim(query_embedding, corpus_embeddings)

for final_top_k in [50]:
    correct = 0
    save_right_index[final_top_k] = {}
    save_wrong_index[final_top_k] = {}
    # torch.topk raises if k exceeds the number of rules, so clamp it.
    k = min(final_top_k, similarity.shape[1])
    top_k_indices = torch.topk(similarity, k=k, dim=1).indices.tolist()

    for idx, file in enumerate(tqdm(files)):
        final_top_k_indices = top_k_indices[idx]
        # Appended once per cut-off; the next cell indexes this list by query
        # position, which assumes a single cut-off in the loop above.
        biencoder_final_top_k_indices.append(final_top_k_indices)

        # Ground truth is the third '_'-separated field of the file name:
        # a 'Y'/'N' letter, one separator character, then a 1-based rule number.
        answer = file.split('_')[2]
        if answer[0] == 'Y':
            # 'Y' labels index into normal_rule, the first part of combined_rules.
            answer_index = int(answer[2:]) - 1
        elif answer[0] == 'N':
            # 'N' labels index into abnormal_rule, which follows normal_rule in
            # combined_rules, so the offset is the number of normal rules.
            answer_index = len(normal_rule) + int(answer[2:]) - 1
        else:
            # Fail loudly instead of scoring against a stale answer_index.
            raise ValueError(f"Unrecognized answer label {answer!r} in file name {file!r}")

        rank = find_rank_of_answer_in_results(final_top_k_indices, answer_index)

        # rank is 1-based, or -1 when the answer was not retrieved at all.
        if 1 <= rank <= final_top_k:
            correct += 1
            save_right_index[final_top_k][idx] = final_top_k_indices
        else:
            save_wrong_index[final_top_k][idx] = final_top_k_indices

    print(f"Top-{final_top_k} accuracy:", correct / len(files))

100%|██████████| 943/943 [00:00<00:00, 2567.93it/s]

Top-50 accuracy: 0.9522799575821845





In [15]:
# Top-k accuracy at smaller cut-offs, reusing the retrieved lists from the
# previous cell: each answer's rank is computed once, then thresholded per k.
assert len(biencoder_final_top_k_indices) == len(files), \
    "expected exactly one retrieved list per caption file"

answer_ranks = []
for idx, file in enumerate(tqdm(files)):
    # Ground truth is the third '_'-separated field of the file name:
    # a 'Y'/'N' letter, one separator character, then a 1-based rule number.
    answer = file.split('_')[2]
    if answer[0] == 'Y':
        # 'Y' labels index into normal_rule, the first part of combined_rules.
        answer_index = int(answer[2:]) - 1
    elif answer[0] == 'N':
        # 'N' labels index into abnormal_rule, which follows normal_rule.
        answer_index = len(normal_rule) + int(answer[2:]) - 1
    else:
        # Fail loudly instead of scoring against a stale answer_index.
        raise ValueError(f"Unrecognized answer label {answer!r} in file name {file!r}")
    answer_ranks.append(
        find_rank_of_answer_in_results(biencoder_final_top_k_indices[idx], answer_index)
    )

for final_top_k in [1, 5, 10, 15, 20, 25, 30]:
    # rank is 1-based, or -1 when the answer is outside the retrieved list.
    correct = sum(1 <= rank <= final_top_k for rank in answer_ranks)
    print(f"Top-{final_top_k} accuracy:", correct / len(files))

100%|██████████| 943/943 [00:00<00:00, 23370.39it/s]


Top-1 accuracy: 0.2163308589607635


100%|██████████| 943/943 [00:00<00:00, 22688.69it/s]


Top-5 accuracy: 0.5906680805938495


100%|██████████| 943/943 [00:00<00:00, 22746.23it/s]


Top-10 accuracy: 0.7656415694591728


100%|██████████| 943/943 [00:00<00:00, 23484.87it/s]


Top-15 accuracy: 0.8324496288441146


100%|██████████| 943/943 [00:00<00:00, 22279.21it/s]


Top-20 accuracy: 0.8674443266171792


100%|██████████| 943/943 [00:00<00:00, 22273.57it/s]


Top-25 accuracy: 0.8939554612937434


100%|██████████| 943/943 [00:00<00:00, 21403.56it/s]

Top-30 accuracy: 0.9056203605514316



