In [7]:
import json
import os
from tqdm import tqdm
import numpy as np
import torch
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

In [8]:
# GPU 설정
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [9]:
# 데이터 로드
data = 'data/aihub_rules_prev.json'
with open(data, 'r', encoding="UTF-8") as j:
    aihub_rule = json.load(j)

normal_rule = aihub_rule['normal']
abnormal_rule = aihub_rule['abnormal']
combined_rules = normal_rule + abnormal_rule

In [10]:
# 데이터 로드 및 평가
root_path = '/data1/sliver/jwsuh/construction_dataset/aihub/llava/llava_image_result_with_obj'
files = sorted([i for i in os.listdir(root_path) if i.endswith('.json')])

# top-K 리스트
list_k = [1, 5, 10, 15, 20, 25, 30, 50]
save_right_index = {}
save_wrong_index = {}

def find_rank_of_answer_in_results(final_indices, answer_index):
    try:
        return final_indices.index(answer_index) + 1
    except ValueError:
        return -1

In [11]:
# DPR 모델과 토크나이저 로드
dpr_question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
dpr_question_model = DPRQuestionEncoder.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
dpr_context_tokenizer = DPRContextEncoderTokenizer.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')
dpr_context_model = DPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')

dpr_question_model.to(device)
dpr_context_model.to(device)

# DPR 문서 임베딩 계산
def encode_contexts_dpr(rules):
    inputs = dpr_context_tokenizer(rules, padding=True, truncation=True, max_length=512, return_tensors='pt')
    inputs = {key: value.to(device) for key, value in inputs.items()}
    with torch.no_grad():
        outputs = dpr_context_model(**inputs)
    return outputs.pooler_output.cpu().numpy()

# DPR 쿼리 임베딩 계산
def encode_query_dpr(query):
    inputs = dpr_question_tokenizer(query, return_tensors='pt', truncation=True, max_length=512)
    inputs = {key: value.to(device) for key, value in inputs.items()}
    with torch.no_grad():
        outputs = dpr_question_model(**inputs)
    return outputs.pooler_output.cpu().numpy()

# 문서 임베딩 계산 및 저장
context_embeddings_dpr = encode_contexts_dpr(combined_rules)
# np.save('context_embeddings_dpr.npy', context_embeddings_dpr)
# print("DPR 문서 임베딩 저장 완료")

Some weights of the model checkpoint at facebook/dpr-question_encoder-single-nq-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRCon

In [12]:
from tqdm import tqdm
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("all-MiniLM-L6-v2")

# 유사도 계산 함수 (Cross-encoder 사용)
def calculate_similarity(query, document):
    emb1 = model.encode(query)
    emb2 = model.encode(document)
    cos_sim = util.cos_sim(emb1, emb2)
    return cos_sim

# 초기 검색 결과 및 re-ranking 함수
def initial_search_and_rerank(query, top_k=50):
    query_embedding = encode_query_dpr(query)
    # context_embeddings_dpr = np.load('context_embeddings_dpr.npy')
    
    similarities = np.dot(query_embedding, context_embeddings_dpr.T).flatten()
    initial_top_k_indices = np.argsort(similarities)[::-1][:top_k]
    initial_top_k_docs = [combined_rules[i] for i in initial_top_k_indices]
    
    reranked_scores = [calculate_similarity(query, doc) for doc in initial_top_k_docs]
    reranked_indices = np.argsort(reranked_scores)[::-1]
    final_top_k_indices = [initial_top_k_indices[i] for i in reranked_indices]
    
    return final_top_k_indices

In [13]:
for final_top_k in list_k:
    correct = 0
    save_right_index[final_top_k] = {}
    save_wrong_index[final_top_k] = {}
    
    for idx, file in enumerate(tqdm(files)):
        with open(os.path.join(root_path, file), 'r', encoding="UTF-8") as j:
            caption = json.load(j)
        
        final_top_k_indices = initial_search_and_rerank(caption['outputs'], top_k=50)[:final_top_k]
        
        answer = file.split('_')[2]
        if answer[0] == 'Y':
            answer_index = int(answer[2:]) - 1
        elif answer[0] == 'N':
            answer_index = int(answer[2:]) + 49
            
        rank = find_rank_of_answer_in_results(final_top_k_indices, answer_index)

        if rank <= final_top_k and rank != -1:
            correct += 1
            save_right_index[final_top_k][idx] = final_top_k_indices
        else:
            save_wrong_index[final_top_k][idx] = final_top_k_indices
            

    print(f"Top-{final_top_k} accuracy:", correct / len(files))

100%|██████████| 943/943 [32:12<00:00,  2.05s/it]


Top-1 accuracy: 0.06574761399787911


100%|██████████| 943/943 [32:27<00:00,  2.07s/it]


Top-5 accuracy: 0.06574761399787911


  2%|▏         | 22/943 [00:46<32:36,  2.12s/it]


KeyboardInterrupt: 