In [None]:
import json
import os
from tqdm import tqdm
import numpy as np
import torch
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

import torch

# Run models on the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Load the rule texts. The JSON has 'normal' and 'abnormal' entries (presumably lists
# of rule strings — confirm); they are concatenated, normal first, into the single
# corpus that retrieval indexes into.
data = 'data/aihub_rules_prev.json'
with open(data, 'r', encoding="UTF-8") as rules_file:
    aihub_rule = json.load(rules_file)

normal_rule, abnormal_rule = aihub_rule['normal'], aihub_rule['abnormal']
combined_rules = normal_rule + abnormal_rule

In [None]:
# Directory of per-image caption JSON files. Set AIHUB_CAPTION_ROOT to point elsewhere
# instead of editing this machine-specific default path.
root_path = os.environ.get(
    'AIHUB_CAPTION_ROOT',
    '/data1/sliver/jwsuh/construction_dataset/aihub/llava/llava_image_result_with_obj',
)
# Sorted so that position i in `files` and `files_captions` refers to the same file;
# later cells also index the cached query embeddings by this position.
files = sorted(name for name in os.listdir(root_path) if name.endswith('.json'))

# The 'outputs' field of each caption file is the text used as the retrieval query.
files_captions = []
for file in tqdm(files):
    with open(os.path.join(root_path, file), 'r', encoding="UTF-8") as j:
        caption = json.load(j)['outputs']
    files_captions.append(caption)
    

In [None]:
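# # Load the DPR models and tokenizers (this disabled cell generated the context/query embedding .pt files that the next cell loads)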
# dpr_question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
# dpr_question_model = DPRQuestionEncoder.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
# dpr_context_tokenizer = DPRContextEncoderTokenizer.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')
# dpr_context_model = DPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')

# dpr_question_model.to(device)
# dpr_context_model.to(device)

# # DPR 문서 임베딩 계산
# def encode_contexts_dpr(rules):
#     inputs = dpr_context_tokenizer(rules, padding=True, truncation=True, max_length=512, return_tensors='pt')
#     inputs = {key: value.to(device) for key, value in inputs.items()}
#     with torch.no_grad():
#         outputs = dpr_context_model(**inputs)
#     return outputs.pooler_output.cpu().numpy()

# # DPR 쿼리 임베딩 계산
# def encode_query_dpr(query):
#     inputs = dpr_question_tokenizer(query,  padding=True, return_tensors='pt', truncation=True, max_length=512)
#     inputs = {key: value.to(device) for key, value in inputs.items()}
#     with torch.no_grad():
#         outputs = dpr_question_model(**inputs)
#     return outputs.pooler_output.cpu().numpy()

# # Compute and save the document (rule) and query (caption) embeddings
# context_embeddings_dpr = encode_contexts_dpr(combined_rules)
# torch.save(context_embeddings_dpr, 'context_embeddings_dpr.pt')
# query_embeddings_dpr = encode_query_dpr(files_captions)
# torch.save(query_embeddings_dpr, 'query_embeddings_dpr.pt')
# print("DPR 문서 및 쿼리 임베딩 저장 완료")

In [None]:
# Load the cached DPR embeddings written by the commented-out encoding cell above.
# They were saved as NumPy arrays via torch.save; PyTorch >= 2.6 defaults to
# weights_only=True, which rejects the NumPy globals needed to unpickle them, so opt
# out explicitly (the kwarg needs torch >= 1.13). weights_only=False runs a full
# unpickle — only use it on files you generated yourself.
context_embeddings_dpr = torch.load('context_embeddings_dpr.pt', weights_only=False)
query_embeddings_dpr = torch.load('query_embeddings_dpr.pt', weights_only=False)

# Rows are looked up by position (rule index / file index), so stale caches would
# silently pair the wrong embeddings with the wrong text.
assert len(context_embeddings_dpr) == len(combined_rules), \
    "context embeddings are out of sync with combined_rules"
assert len(query_embeddings_dpr) == len(files), \
    "query embeddings are out of sync with the caption files"

In [None]:
from tqdm import tqdm
from sentence_transformers.cross_encoder import CrossEncoder

# Cross-encoder that scores each (caption, rule) pair for re-ranking; inputs longer
# than 512 tokens are truncated.
model = CrossEncoder(
    "cross-encoder/stsb-roberta-large",
    max_length=512,
    device=device,
)

def dpr_initial_search_and_rerank(query, query_idx, top_k=50, *,
                                  query_embeddings=None, context_embeddings=None,
                                  rules=None, reranker=None):
    """Retrieve candidate rules for one caption with DPR, then re-rank them with a cross-encoder.

    Args:
        query: caption text; it is paired with each candidate rule for re-ranking.
        query_idx: row of the query-embedding matrix that corresponds to ``query``.
        top_k: number of DPR candidates (highest dot-product score) to keep and re-rank.
        query_embeddings: optional override for the notebook global ``query_embeddings_dpr``.
        context_embeddings: optional override for ``context_embeddings_dpr``
            (one row per rule).
        rules: optional override for ``combined_rules``.
        reranker: optional override for ``model``; any object whose ``predict`` takes a
            list of ``[query, rule]`` pairs and returns one score per pair.

    Returns:
        list[int]: indices into ``rules`` of the (at most) ``top_k`` DPR candidates,
        ordered by descending re-ranker score. Empty when ``top_k <= 0`` or there
        are no rules.
    """
    # Fall back to the notebook globals so existing calls behave exactly as before.
    if query_embeddings is None:
        query_embeddings = query_embeddings_dpr
    if context_embeddings is None:
        context_embeddings = context_embeddings_dpr
    if rules is None:
        rules = combined_rules
    if reranker is None:
        reranker = model

    # A non-positive top_k would otherwise slice nonsensically (e.g. [:-3]).
    if top_k <= 0:
        return []

    # Stage 1: DPR scores every rule by dot product with the query embedding.
    query_embedding = query_embeddings[query_idx]
    similarities = np.dot(query_embedding, np.asarray(context_embeddings).T).flatten()
    initial_top_k_indices = np.argsort(similarities)[::-1][:top_k]
    if initial_top_k_indices.size == 0:
        return []

    # Stage 2: re-score the surviving candidates with the cross-encoder.
    sentence_combinations = [[query, rules[i]] for i in initial_top_k_indices]
    # Assumes one relevance score per pair (single-label cross-encoder) — TODO confirm
    # if the reranker is swapped for a multi-label model.
    reranked_scores = np.asarray(reranker.predict(sentence_combinations)).reshape(-1)
    reranked_order = np.argsort(reranked_scores)[::-1]

    # .tolist() yields plain Python ints (np.int64 is not JSON-serialisable).
    return initial_top_k_indices[reranked_order].tolist()

In [None]:
def find_rank_of_answer_in_results(final_indices, answer_index):
    """Return the 1-based rank of ``answer_index`` within a ranked result list.

    Args:
        final_indices: ranked rule indices, best first. Any iterable works
            (list, tuple, NumPy array), not only lists.
        answer_index: rule index of the ground-truth answer.

    Returns:
        The 1-based position of the first occurrence of ``answer_index``,
        or -1 if it does not occur.
    """
    for rank, candidate in enumerate(final_indices, start=1):
        if candidate == answer_index:
            return rank
    return -1

In [None]:
# Evaluate DPR retrieval + cross-encoder re-ranking: for each caption file, retrieve and
# re-rank the top-50 rules, then count how often the ground-truth rule lands within the
# first `final_top_k` positions.
save_right_index = {}
save_wrong_index = {}
bm25_and_cross_encoder_final_top_k_indices = []

for final_top_k in [50]:
    correct = 0
    save_right_index[final_top_k] = {}
    save_wrong_index[final_top_k] = {}
    # One re-ranked index list per file, in `files` order; the next cell indexes it by
    # file position. Reset per cutoff so extra cutoffs cannot append duplicate entries.
    # NOTE: despite its name, the first retrieval stage here is DPR, not BM25.
    bm25_and_cross_encoder_final_top_k_indices = []

    for idx, file in enumerate(tqdm(files)):
        # files_captions[idx] already holds this file's 'outputs' text (loaded from the
        # same sorted `files` list), so the JSON does not need to be re-read from disk.
        final_top_k_indices = dpr_initial_search_and_rerank(files_captions[idx], idx, top_k=50)
        bm25_and_cross_encoder_final_top_k_indices.append(final_top_k_indices)

        # The third '_'-separated field of the file name encodes the ground truth:
        # a 'Y'/'N' flag at position 0 and a 1-based rule number from position 2 on.
        # NOTE(review): the +49 offset assumes `normal_rule` holds exactly 50 rules, so
        # 'Y' rules map into normal_rule and 'N' rules start at index 50 of
        # `combined_rules` — confirm against the rules file.
        answer = file.split('_')[2]
        if answer[0] == 'Y':
            answer_index = int(answer[2:]) - 1
        elif answer[0] == 'N':
            answer_index = int(answer[2:]) + 49
        else:
            # Fail loudly rather than silently reusing the previous file's answer_index.
            raise ValueError(f"Unexpected answer code {answer!r} in file name {file!r}")

        rank = find_rank_of_answer_in_results(final_top_k_indices, answer_index)

        # rank is -1 when the answer is absent, otherwise 1-based.
        if 1 <= rank <= final_top_k:
            correct += 1
            save_right_index[final_top_k][idx] = final_top_k_indices
        else:
            save_wrong_index[final_top_k][idx] = final_top_k_indices

    print(f"Top-{final_top_k} accuracy:", correct / len(files))

In [None]:
# Re-score the cached re-ranked top-50 lists at smaller cutoffs. Neither retrieval nor
# file I/O is repeated here; only the file names are needed to recover the answers.
for final_top_k in [1, 5, 10, 15, 20, 25, 30]:
    correct = 0

    for idx, file in enumerate(tqdm(files)):
        final_top_k_indices = bm25_and_cross_encoder_final_top_k_indices[idx]

        # The third '_'-separated field of the file name encodes the ground truth:
        # a 'Y'/'N' flag at position 0 and a 1-based rule number from position 2 on.
        # NOTE(review): the +49 offset assumes `normal_rule` holds exactly 50 rules —
        # confirm against the rules file.
        answer = file.split('_')[2]
        if answer[0] == 'Y':
            answer_index = int(answer[2:]) - 1
        elif answer[0] == 'N':
            answer_index = int(answer[2:]) + 49
        else:
            # Fail loudly rather than silently reusing the previous file's answer_index.
            raise ValueError(f"Unexpected answer code {answer!r} in file name {file!r}")

        rank = find_rank_of_answer_in_results(final_top_k_indices, answer_index)

        # rank is -1 when the answer is absent, otherwise 1-based.
        if 1 <= rank <= final_top_k:
            correct += 1

    print(f"Top-{final_top_k} accuracy:", correct / len(files))