In [6]:
import json
import os
from tqdm import tqdm
import numpy as np
from rank_bm25 import BM25Okapi
from nltk.tokenize import word_tokenize
import nltk
nltk.download('punkt')

import torch

[nltk_data] Downloading package punkt to /home/sliver/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [7]:
# GPU 설정
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [8]:
# 데이터 로드
data = 'data/aihub_rules_prev.json'
with open(data, 'r', encoding="UTF-8") as j:
    aihub_rule = json.load(j)

normal_rule = aihub_rule['normal']
abnormal_rule = aihub_rule['abnormal']
combined_rules = normal_rule + abnormal_rule

In [9]:
# 데이터 로드 및 평가
root_path = '/data1/sliver/jwsuh/construction_dataset/aihub/llava/llava_image_result_with_obj'
files = sorted([i for i in os.listdir(root_path) if i.endswith('.json')])

# top-K 리스트
list_k = [1, 5, 10, 15, 20, 25, 30, 50]
save_right_index = {}
save_wrong_index = {}

def find_rank_of_answer_in_results(final_indices, answer_index):
    try:
        return final_indices.index(answer_index) + 1
    except ValueError:
        return -1

In [10]:
def tokenize(text):
    return word_tokenize(text.lower())

tokenized_rules = [tokenize(rule) for rule in combined_rules]

# BM25 모델 생성
bm25 = BM25Okapi(tokenized_rules)

In [12]:
from tqdm import tqdm
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("all-MiniLM-L6-v2").to(device)

# 유사도 계산 함수 (Cross-encoder 사용)
def calculate_similarity(query, document):
    emb1 = model.encode(query)
    emb2 = model.encode(document)
    cos_sim = util.cos_sim(emb1, emb2)
    return cos_sim


def bm25_initial_search_and_rerank(query, top_k=50):
    tokenized_query = tokenize(query)
    initial_top_k_scores = bm25.get_scores(tokenized_query)
    initial_top_k_indices =  np.argsort(initial_top_k_scores)[::-1][:top_k]
    
    reranked_scores = [calculate_similarity(query, combined_rules[i]) for i in initial_top_k_indices]
    reranked_indices = np.argsort(reranked_scores)[::-1]
    final_top_k_indices = [initial_top_k_indices[i] for i in reranked_indices]
    
    return final_top_k_indices

In [13]:
for final_top_k in list_k:
    correct = 0
    save_right_index[final_top_k] = {}
    save_wrong_index[final_top_k] = {}
    
    for idx, file in enumerate(tqdm(files)):
        with open(os.path.join(root_path, file), 'r', encoding="UTF-8") as j:
            caption = json.load(j)
        
        final_top_k_indices = bm25_initial_search_and_rerank(caption['outputs'], top_k=50)[:final_top_k]
        
        answer = file.split('_')[2]
        if answer[0] == 'Y':
            answer_index = int(answer[2:]) - 1
        elif answer[0] == 'N':
            answer_index = int(answer[2:]) + 49
            
        rank = find_rank_of_answer_in_results(final_top_k_indices, answer_index)

        if rank <= final_top_k and rank != -1:
            correct += 1
            save_right_index[final_top_k][idx] = final_top_k_indices
        else:
            save_wrong_index[final_top_k][idx] = final_top_k_indices

    print(f"Top-{final_top_k} accuracy:", correct / len(files))

100%|██████████| 943/943 [31:13<00:00,  1.99s/it]


Top-1 accuracy: 0.03499469777306469


 88%|████████▊ | 828/943 [27:36<03:50,  2.00s/it]


KeyboardInterrupt: 

In [15]:
for key in save_wrong_index[10].keys():
    print("file name:", files[key])
    answer = files[key].split('_')[2]
    
    if answer[0] == 'Y':
        answer_index = int(answer[2:]) - 1
    elif answer[0] == 'N':
        answer_index = int(answer[2:]) + 49
    print(answer_index)
    for idx, value in enumerate(save_wrong_index[10][key]):
        print(f"rank-{idx+1:<2}", f"{combined_rules[value]:>4} {value} ")
    print("\n")

file name: H-220607_B16_Y-14_002_0029.json
13
rank-1  Working only on the top (or bottom) of the system scaffold 11 
rank-2  Simultaneous top and bottom work on system scaffold (two-person operation) 61 
rank-3  No fire extinguisher placement next to welding equipment 82 
rank-4  Material placement at the edge of a horse scaffold 92 
rank-5  Improper overloading of a Ladder Truck 66 
rank-6  Proper placement of materials and tools on a horse scaffold 42 
rank-7  Fire extinguisher placement next to welding equipment 32 
rank-8  rolling tower safety railing installation 6 
rank-9  No materials placed on the end of the formwork 14 
rank-10 Faulty concrete pump truck safety device installation 94 


file name: H-220609_A18_Y-04_001_0021.json
3
rank-1  Fire extinguisher placement next to welding equipment 32 
rank-2  No fire extinguisher placement next to welding equipment 82 
rank-3  Worker operating within the hazard radius of a dump truck 75 
rank-4  Improper overloading of a Ladder Truc