In [3]:
from rank_bm25 import BM25Okapi
import json
import os
import nltk
from nltk.tokenize import word_tokenize
nltk.download('punkt')

# 데이터 로드
data = 'data/aihub_rules_prev.json'
with open(data, 'r', encoding="UTF-8") as j:
    aihub_rule = json.load(j)
    
normal_rule = aihub_rule['normal']
abnormal_rule = aihub_rule['abnormal']
combined_rules = normal_rule + abnormal_rule

# BM25 초기 검색 설정
tokenized_rules = [word_tokenize(rule) for rule in combined_rules]
bm25 = BM25Okapi(tokenized_rules)

def bm25_initial_search(query, top_k=50):
    tokenized_query = word_tokenize(query)
    top_docs = bm25.get_top_n(tokenized_query, combined_rules, n=top_k)
    top_indices = [combined_rules.index(doc) for doc in top_docs]
    return top_docs, top_indices

[nltk_data] Downloading package punkt to /home/sliver/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [4]:
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from tqdm import tqdm

# GPU 설정
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# BERT Cross-encoder 모델과 토크나이저 로드
cross_encoder_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
cross_encoder_model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=1)
cross_encoder_model.to(device)
cross_encoder_model.eval()

# re-ranking 함수 정의
def cross_encoder_rerank(query, documents):
    inputs = [cross_encoder_tokenizer(query, doc, return_tensors='pt', truncation=True, padding=True, max_length=512) for doc in documents]
    scores = []
    with torch.no_grad():
        for input in inputs:
            input = {key: value.to(device) for key, value in input.items()}
            outputs = cross_encoder_model(**input)
            scores.append(outputs.logits.item())
    ranked_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    ranked_docs = [documents[i] for i in ranked_indices]
    return ranked_docs, ranked_indices

# re-ranking 예제
query = "example query"
initial_docs, initial_indices = bm25_initial_search(query, top_k=50)
reranked_docs, reranked_indices = cross_encoder_rerank(query, initial_docs)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
max(initial_indices)

99

In [12]:

# re-ranking 예제
query = "14.1.1 it is wrong"
initial_docs, initial_indices = bm25_initial_search(query, top_k=50)
reranked_docs, reranked_indices = cross_encoder_rerank(query, initial_docs)

In [16]:
len(reranked_indices)

50

In [6]:
# discreption 가져오기
root_path = '/data1/sliver/jwsuh/construction_dataset/aihub/llava/llava_image_result_with_obj'
files = sorted([i for i in os.listdir(root_path) if i.endswith('.json')])

In [7]:
# top-K 리스트
list_k = [1, 5, 10, 15, 20, 25, 30, 50]
save_right_index = {}
save_wrong_index = {}

In [17]:
def find_rank_of_answer_in_dpr_results(final_indices_mapped, answer_index):
    try:
        return final_indices_mapped.index(answer_index) + 1
    except ValueError:
        return -1

In [19]:
for final_top_k in list_k:
    correct = 0
    save_right_index[final_top_k] = {}
    save_wrong_index[final_top_k] = {}
    ind = 0
    
    for file in tqdm(files):
        with open(os.path.join(root_path, file), 'r', encoding="UTF-8") as j:
            caption = json.load(j)
        initial_docs, initial_indices = bm25_initial_search(caption['outputs'], top_k=50)
        reranked_docs, reranked_indices = cross_encoder_rerank(caption['outputs'], initial_docs)
        
        final_indices_mapped = [initial_indices[i] for i in reranked_indices]
        
        answer = file.split('_')[2]
        if answer[0] == 'Y':
            answer_index = int(answer[2:]) - 1
        elif answer[0] == 'N':
            answer_index = int(answer[2:]) + 49
            
        rank = find_rank_of_answer_in_dpr_results(final_indices_mapped, answer_index)

        if rank <= final_top_k and rank != -1:
            correct += 1
            save_right_index[final_top_k][ind] = reranked_docs
        else:
            save_wrong_index[final_top_k][ind] = reranked_docs
            
        ind += 1

    print(f"Top-{final_top_k} accuracy:", correct / len(files))

  2%|▏         | 23/943 [00:33<22:35,  1.47s/it]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned fo