In [1]:
import json
import os
from tqdm import tqdm
import numpy as np
import torch
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

# The experiment prefers GPU 7. The index is clamped to the last visible GPU so
# the notebook still runs on machines exposing fewer devices (an out-of-range
# index would otherwise only fail later, when something is moved to `device`).
# Falls back to CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device(f'cuda:{min(7, torch.cuda.device_count() - 1)}')
else:
    device = torch.device('cpu')

In [2]:
# Load the rule data
data = 'data/aihub_rules_prev.json'
with open(data, mode='r', encoding="UTF-8") as rules_file:
    aihub_rule = json.loads(rules_file.read())

# The JSON object carries two groups of rules; they are concatenated
# normal-first so one index space covers both groups.
normal_rule, abnormal_rule = aihub_rule['normal'], aihub_rule['abnormal']
combined_rules = normal_rule + abnormal_rule

In [3]:
root_path = '/data1/sliver/jwsuh/construction_dataset/aihub/llava/llava_image_result_with_obj'
# Sorted so that a file's position in `files` is stable between runs.
files = sorted(name for name in os.listdir(root_path) if name.endswith('.json'))

# One caption (the 'outputs' field of each JSON file) per entry of `files`, same order.
files_captions = []
for idx, file in enumerate(tqdm(files)):
    caption_path = os.path.join(root_path, file)
    with open(caption_path, 'r', encoding="UTF-8") as j:
        payload = json.load(j)
    caption = payload['outputs']
    files_captions.append(caption)
    

100%|██████████| 943/943 [00:00<00:00, 23996.39it/s]


In [4]:
# Precomputed DPR embeddings. They are consumed with NumPy (np.dot) further
# down, so any tensors are mapped onto the CPU regardless of the device they
# were saved from.
# NOTE(review): torch.load unpickles the file contents - only load .pt files
# that come from a trusted source.
context_embeddings_dpr = torch.load('context_embeddings_dpr.pt', map_location='cpu')
query_embeddings_dpr = torch.load('query_embeddings_dpr.pt', map_location='cpu')

In [5]:
from tqdm import tqdm
from sentence_transformers.cross_encoder import CrossEncoder

def dpr_initial_search_and_rerank(query, query_idx, top_k=50, cross_encoder=None):
    """Retrieve candidate rules with precomputed DPR embeddings, then re-rank them.

    Stage 1 scores every rule by the dot product between the query embedding
    and the context embeddings and keeps the ``top_k`` best. Stage 2 pairs the
    query text with each surviving rule and orders the pairs by the score the
    cross-encoder assigns to them.

    Args:
        query: Query text paired with every candidate rule in stage 2.
        query_idx: Position of this query's embedding in ``query_embeddings_dpr``.
        top_k: Number of stage-1 candidates passed on to the re-ranker.
        cross_encoder: Object exposing ``predict(pairs)`` that returns one score
            per ``[query, doc]`` pair. When omitted, the module-level ``model``
            is used, which is how existing callers work.

    Returns:
        list: Indices into ``combined_rules``, best match first, at most
        ``top_k`` long. Empty when stage 1 yields no candidates.
    """
    # Resolve the scorer at call time so the evaluation loop can keep swapping
    # the module-level `model` between runs.
    scorer = model if cross_encoder is None else cross_encoder

    # Stage 1: dense retrieval over the precomputed embeddings.
    query_embedding = query_embeddings_dpr[query_idx]
    similarities = np.dot(query_embedding, context_embeddings_dpr.T).flatten()
    initial_top_k_indices = np.argsort(similarities)[::-1][:top_k]
    if len(initial_top_k_indices) == 0:
        # Nothing to re-rank (e.g. top_k == 0); skip the cross-encoder call.
        return []
    initial_top_k_docs = [combined_rules[i] for i in initial_top_k_indices]

    # Stage 2: cross-encoder re-ranking, highest score first.
    sentence_combinations = [[query, doc] for doc in initial_top_k_docs]
    reranked_scores = scorer.predict(sentence_combinations)
    reranked_indices = np.argsort(reranked_scores)[::-1]

    # Map positions within the candidate list back to indices into combined_rules.
    final_top_k_indices = [initial_top_k_indices[i] for i in reranked_indices]

    return final_top_k_indices

In [6]:
def find_rank_of_answer_in_results(final_indices, answer_index):
    """Return the 1-based rank of ``answer_index`` within ``final_indices``.

    Args:
        final_indices: Ranked result indices, best first. Any iterable is
            accepted (list, tuple, NumPy array, ...), not only objects that
            provide an ``index`` method.
        answer_index: Index of the expected (gold) item.

    Returns:
        int: Position of the first occurrence of ``answer_index`` counting
        from 1, or -1 when it does not appear in ``final_indices``.
    """
    for position, candidate in enumerate(final_indices, start=1):
        if candidate == answer_index:
            return position
    return -1

In [7]:
# Cross-encoder checkpoints compared as re-rankers, evaluated in this order.
model_names = [
    "cross-encoder/qnli-electra-base",
    "cross-encoder/quora-roberta-large",
]

In [8]:
# Evaluate DPR retrieval + cross-encoder re-ranking for every candidate model.
# Each file is retrieved and re-ranked once (top-50); the smaller cut-offs are
# then derived from the ranks recorded during that single pass.
for model_name in model_names:
    print(model_name)
    # `model` must stay module-level: dpr_initial_search_and_rerank reads it as a global.
    model = CrossEncoder(model_name, device=device, max_length=512)

    save_right_index = {}
    save_wrong_index = {}
    # NOTE: despite the "bm25" prefix (kept so other cells that reference this
    # name keep working), this list holds the DPR + cross-encoder rankings.
    bm25_and_cross_encoder_final_top_k_indices = []
    # 1-based rank of the expected rule for each file, -1 when it was not retrieved.
    answer_ranks = []

    for final_top_k in [50]:
        correct = 0
        save_right_index[final_top_k] = {}
        save_wrong_index[final_top_k] = {}

        for idx, file in enumerate(tqdm(files)):
            with open(os.path.join(root_path, file), 'r', encoding="UTF-8") as j:
                caption = json.load(j)

            final_top_k_indices = dpr_initial_search_and_rerank(caption['outputs'], idx, top_k=50)
            bm25_and_cross_encoder_final_top_k_indices.append(final_top_k_indices)

            # The expected rule is encoded in the third '_'-separated field of
            # the file name: a 'Y' or 'N' prefix, one separator character, then
            # a 1-based rule number.
            answer = file.split('_')[2]
            if answer[0] == 'Y':
                answer_index = int(answer[2:]) - 1
            elif answer[0] == 'N':
                # NOTE(review): the +49 offset presumably assumes 50 'normal'
                # rules precede the 'abnormal' ones in combined_rules - confirm
                # against len(normal_rule).
                answer_index = int(answer[2:]) + 49
            else:
                # Previously this fell through and silently reused the index
                # from the previous file (or hit a NameError on the first one).
                raise ValueError(f"Unrecognised answer label {answer!r} in file name {file!r}")

            rank = find_rank_of_answer_in_results(final_top_k_indices, answer_index)
            answer_ranks.append(rank)

            if rank != -1 and rank <= final_top_k:
                correct += 1
                save_right_index[final_top_k][idx] = final_top_k_indices
            else:
                save_wrong_index[final_top_k][idx] = final_top_k_indices

        print(f"Top-{final_top_k} accuracy:", correct / len(files))

    # Smaller cut-offs reuse the recorded ranks; the caption files do not need
    # to be opened or parsed again.
    for final_top_k in [1, 5, 10, 15, 20, 25, 30]:
        correct = sum(1 for rank in answer_ranks if rank != -1 and rank <= final_top_k)

        print(f"Top-{final_top_k} accuracy:", correct / len(files))

cross-encoder/qnli-electra-base


100%|██████████| 943/943 [04:53<00:00,  3.21it/s]


Top-50 accuracy: 0.8197242841993637


100%|██████████| 943/943 [00:00<00:00, 18591.14it/s]


Top-1 accuracy: 0.009544008483563097


100%|██████████| 943/943 [00:00<00:00, 19312.64it/s]


Top-5 accuracy: 0.04029692470837752


100%|██████████| 943/943 [00:00<00:00, 19183.28it/s]


Top-10 accuracy: 0.05726405090137858


100%|██████████| 943/943 [00:00<00:00, 19703.14it/s]


Top-15 accuracy: 0.08589607635206786


100%|██████████| 943/943 [00:00<00:00, 19744.26it/s]


Top-20 accuracy: 0.12937433722163308


100%|██████████| 943/943 [00:00<00:00, 19497.72it/s]


Top-25 accuracy: 0.1728525980911983


100%|██████████| 943/943 [00:00<00:00, 16811.02it/s]


Top-30 accuracy: 0.21845174973488865
cross-encoder/quora-roberta-large


100%|██████████| 943/943 [12:37<00:00,  1.24it/s]


Top-50 accuracy: 0.8197242841993637


100%|██████████| 943/943 [00:00<00:00, 19857.56it/s]


Top-1 accuracy: 0.17179215270413573


100%|██████████| 943/943 [00:00<00:00, 21113.13it/s]


Top-5 accuracy: 0.383881230116649


100%|██████████| 943/943 [00:00<00:00, 20303.32it/s]


Top-10 accuracy: 0.48568398727465534


100%|██████████| 943/943 [00:00<00:00, 19517.34it/s]


Top-15 accuracy: 0.5673382820784729


100%|██████████| 943/943 [00:00<00:00, 21288.48it/s]


Top-20 accuracy: 0.6277836691410392


100%|██████████| 943/943 [00:00<00:00, 21553.32it/s]


Top-25 accuracy: 0.6818663838812301


100%|██████████| 943/943 [00:00<00:00, 21718.18it/s]


Top-30 accuracy: 0.7242841993637328


In [9]:
from rank_bm25 import BM25Okapi
from nltk.tokenize import word_tokenize
import nltk

def tokenize(text):
    """Lower-case ``text`` and split it into tokens with ``word_tokenize``."""
    lowered = text.lower()
    return word_tokenize(lowered)

# Tokenise every rule once up front; the BM25 index is built from these token lists.
tokenized_rules = list(map(tokenize, combined_rules))

# Build the BM25 model
bm25 = BM25Okapi(tokenized_rules)

In [10]:
def bm25_initial_search_and_rerank(query, top_k=50, cross_encoder=None):
    """Retrieve candidate rules with BM25, then re-rank them with a cross-encoder.

    Stage 1 tokenises the query, scores every rule with the module-level
    ``bm25`` index and keeps the ``top_k`` best. Stage 2 pairs the query text
    with each surviving rule and orders the pairs by the score the
    cross-encoder assigns to them.

    Args:
        query: Query text; used for BM25 scoring and for the cross-encoder pairs.
        top_k: Number of stage-1 candidates passed on to the re-ranker.
        cross_encoder: Object exposing ``predict(pairs)`` that returns one score
            per ``[query, doc]`` pair. When omitted, the module-level ``model``
            is used, which is how existing callers work.

    Returns:
        list: Indices into ``combined_rules``, best match first, at most
        ``top_k`` long. Empty when stage 1 yields no candidates.
    """
    # Resolve the scorer at call time so the evaluation loop can keep swapping
    # the module-level `model` between runs.
    scorer = model if cross_encoder is None else cross_encoder

    # Stage 1: lexical retrieval.
    tokenized_query = tokenize(query)
    initial_top_k_scores = bm25.get_scores(tokenized_query)
    initial_top_k_indices = np.argsort(initial_top_k_scores)[::-1][:top_k]
    if len(initial_top_k_indices) == 0:
        # Nothing to re-rank (e.g. top_k == 0); skip the cross-encoder call.
        return []
    initial_top_k_docs = [combined_rules[i] for i in initial_top_k_indices]

    # Stage 2: cross-encoder re-ranking, highest score first.
    sentence_combinations = [[query, doc] for doc in initial_top_k_docs]
    reranked_scores = scorer.predict(sentence_combinations)
    reranked_indices = np.argsort(reranked_scores)[::-1]

    # Map positions within the candidate list back to indices into combined_rules.
    final_top_k_indices = [initial_top_k_indices[i] for i in reranked_indices]

    return final_top_k_indices

In [11]:
# Evaluate BM25 retrieval + cross-encoder re-ranking for every candidate model.
# Each file is retrieved and re-ranked once (top-50); the smaller cut-offs are
# then derived from the ranks recorded during that single pass.
for model_name in model_names:
    print(model_name)
    # `model` must stay module-level: bm25_initial_search_and_rerank reads it as a global.
    model = CrossEncoder(model_name, device=device, max_length=512)

    save_right_index = {}
    save_wrong_index = {}
    bm25_and_cross_encoder_final_top_k_indices = []
    # 1-based rank of the expected rule for each file, -1 when it was not retrieved.
    answer_ranks = []

    for final_top_k in [50]:
        correct = 0
        save_right_index[final_top_k] = {}
        save_wrong_index[final_top_k] = {}

        for idx, file in enumerate(tqdm(files)):
            with open(os.path.join(root_path, file), 'r', encoding="UTF-8") as j:
                caption = json.load(j)

            final_top_k_indices = bm25_initial_search_and_rerank(caption['outputs'], top_k=50)
            bm25_and_cross_encoder_final_top_k_indices.append(final_top_k_indices)

            # The expected rule is encoded in the third '_'-separated field of
            # the file name: a 'Y' or 'N' prefix, one separator character, then
            # a 1-based rule number.
            answer = file.split('_')[2]
            if answer[0] == 'Y':
                answer_index = int(answer[2:]) - 1
            elif answer[0] == 'N':
                # NOTE(review): the +49 offset presumably assumes 50 'normal'
                # rules precede the 'abnormal' ones in combined_rules - confirm
                # against len(normal_rule).
                answer_index = int(answer[2:]) + 49
            else:
                # Previously this fell through and silently reused the index
                # from the previous file (or hit a NameError on the first one).
                raise ValueError(f"Unrecognised answer label {answer!r} in file name {file!r}")

            rank = find_rank_of_answer_in_results(final_top_k_indices, answer_index)
            answer_ranks.append(rank)

            if rank != -1 and rank <= final_top_k:
                correct += 1
                save_right_index[final_top_k][idx] = final_top_k_indices
            else:
                save_wrong_index[final_top_k][idx] = final_top_k_indices

        print(f"Top-{final_top_k} accuracy:", correct / len(files))

    # Smaller cut-offs reuse the recorded ranks; the caption files do not need
    # to be opened or parsed again.
    for final_top_k in [1, 5, 10, 15, 20, 25, 30]:
        correct = sum(1 for rank in answer_ranks if rank != -1 and rank <= final_top_k)

        print(f"Top-{final_top_k} accuracy:", correct / len(files))

cross-encoder/qnli-electra-base


100%|██████████| 943/943 [04:32<00:00,  3.46it/s]


Top-50 accuracy: 0.9406150583244963


100%|██████████| 943/943 [00:00<00:00, 18180.95it/s]


Top-1 accuracy: 0.010604453870625663


100%|██████████| 943/943 [00:00<00:00, 19012.32it/s]


Top-5 accuracy: 0.043478260869565216


100%|██████████| 943/943 [00:00<00:00, 18735.30it/s]


Top-10 accuracy: 0.07211028632025451


100%|██████████| 943/943 [00:00<00:00, 18758.23it/s]


Top-15 accuracy: 0.10604453870625663


100%|██████████| 943/943 [00:00<00:00, 19344.28it/s]


Top-20 accuracy: 0.1569459172852598


100%|██████████| 943/943 [00:00<00:00, 19446.72it/s]


Top-25 accuracy: 0.19936373276776245


100%|██████████| 943/943 [00:00<00:00, 18896.24it/s]


Top-30 accuracy: 0.2545068928950159
cross-encoder/quora-roberta-large


100%|██████████| 943/943 [12:15<00:00,  1.28it/s]


Top-50 accuracy: 0.9406150583244963


100%|██████████| 943/943 [00:00<00:00, 22766.13it/s]


Top-1 accuracy: 0.16967126193001061


100%|██████████| 943/943 [00:00<00:00, 22177.03it/s]


Top-5 accuracy: 0.45705196182396607


100%|██████████| 943/943 [00:00<00:00, 22384.37it/s]


Top-10 accuracy: 0.6002120890774125


100%|██████████| 943/943 [00:00<00:00, 22161.12it/s]


Top-15 accuracy: 0.6903499469777307


100%|██████████| 943/943 [00:00<00:00, 21610.56it/s]


Top-20 accuracy: 0.7348886532343585


100%|██████████| 943/943 [00:00<00:00, 22329.02it/s]


Top-25 accuracy: 0.7857900318133616


100%|██████████| 943/943 [00:00<00:00, 20898.50it/s]

Top-30 accuracy: 0.8144220572640509



