In [1]:
import json
import os
from tqdm import tqdm
import numpy as np
import torch
from sentence_transformers import SentenceTransformer, util
from sentence_transformers.cross_encoder import CrossEncoder

# GPU 설정
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [2]:
# Rule load = document = corpus
data = 'data/aihub_rules_prev.json'
with open(data, 'r', encoding="UTF-8") as j:
    aihub_rule = json.load(j)

normal_rule = aihub_rule['normal']
abnormal_rule = aihub_rule['abnormal']
combined_rules = normal_rule + abnormal_rule

In [3]:
# Caption load = discreption = query
root_path = '/data1/sliver/jwsuh/construction_dataset/aihub/llava/llava_image_result_with_obj'
files = sorted([i for i in os.listdir(root_path) if i.endswith('.json')])

files_captions = []
for idx, file in enumerate(tqdm(files)):
    with open(os.path.join(root_path, file), 'r', encoding="UTF-8") as j:
        caption = json.load(j)['outputs']
    files_captions.append(caption)
    

100%|██████████| 943/943 [00:00<00:00, 24657.92it/s]


In [4]:
def find_rank_of_answer_in_results(final_indices, answer_index):
    try:
        return final_indices.index(answer_index) + 1
    except ValueError:
        return -1

In [None]:
embedder_names = ["all-MiniLM-L6-v2", "all-distilroberta-v1", "multi-qa-mpnet-base-dot-v1", "msmarco-bert-base-dot-v5"]
model_names = ["cross-encoder/ms-marco-MiniLM-L-12-v2", "cross-encoder/stsb-roberta-large", "cross-encoder/qnli-electra-base", "cross-encoder/quora-roberta-large"]

for embedder_name in embedder_names:
    for model_name in model_names:
        print("embedder_names:", embedder_name)
        print("model_names:", model_name)
        embedder = SentenceTransformer(embedder_name)  # 우선 임시로 사용할 모델

        corpus_embeddings = embedder.encode(combined_rules, convert_to_tensor=True)
        query_embedding = embedder.encode(files_captions, convert_to_tensor=True)

        model = CrossEncoder(model_name, device=device, max_length=512)

        save_right_index = {}
        save_wrong_index = {}
        biencoder_and_cross_encoder_final_top_k_indices = []

        for final_top_k in [50]:
            correct = 0
            save_right_index[final_top_k] = {}
            save_wrong_index[final_top_k] = {}
            
            for idx, file in enumerate(tqdm(files)):
                with open(os.path.join(root_path, file), 'r', encoding="UTF-8") as j:
                    caption = json.load(j)
                
                tokenized_query = query_embedding[idx]
                initial_top_k_scores = util.dot_score(tokenized_query, corpus_embeddings)[0]
                top_results =  torch.topk(initial_top_k_scores, k=50)
                initial_top_k_docs = [combined_rules[i] for i in top_results[1]]
                
                sentence_combinations = [[caption['outputs'], doc] for doc in initial_top_k_docs]
                reranked_scores = model.predict(sentence_combinations)
                reranked_indices = np.argsort(reranked_scores)[::-1]
                final_top_k_indices = [top_results[1][i] for i in reranked_indices]
                
                biencoder_and_cross_encoder_final_top_k_indices.append(final_top_k_indices)
                
                answer = file.split('_')[2]
                if answer[0] == 'Y':
                    answer_index = int(answer[2:]) - 1
                elif answer[0] == 'N':
                    answer_index = int(answer[2:]) + 49
                    
                rank = find_rank_of_answer_in_results(final_top_k_indices, answer_index)

                if rank <= final_top_k and rank != -1:
                    correct += 1
                    save_right_index[final_top_k][idx] = final_top_k_indices
                else:
                    save_wrong_index[final_top_k][idx] = final_top_k_indices

            print(f"Top-{final_top_k} accuracy:", correct / len(files))
            
        for final_top_k in [1, 5, 10, 15, 20, 25, 30]:
            correct = 0
            
            for idx, file in enumerate(tqdm(files)):
                with open(os.path.join(root_path, file), 'r', encoding="UTF-8") as j:
                    caption = json.load(j)
                
                final_top_k_indices = biencoder_and_cross_encoder_final_top_k_indices[idx]

                answer = file.split('_')[2]
                if answer[0] == 'Y':
                    answer_index = int(answer[2:]) - 1
                elif answer[0] == 'N':
                    answer_index = int(answer[2:]) + 49
                    
                rank = find_rank_of_answer_in_results(final_top_k_indices, answer_index)

                if rank <= final_top_k and rank != -1:
                    correct += 1

            print(f"Top-{final_top_k} accuracy:", correct / len(files))

In [5]:
embedder_names = ["msmarco-distilbert-base-v4"]
model_names = ["cross-encoder/ms-marco-MiniLM-L-12-v2", "cross-encoder/stsb-roberta-large", "cross-encoder/qnli-electra-base", "cross-encoder/quora-roberta-large"]

for embedder_name in embedder_names:
    for model_name in model_names:
        print("embedder_names:", embedder_name)
        print("model_names:", model_name)
        embedder = SentenceTransformer(embedder_name)  # 우선 임시로 사용할 모델

        corpus_embeddings = embedder.encode(combined_rules, convert_to_tensor=True)
        query_embedding = embedder.encode(files_captions, convert_to_tensor=True)

        model = CrossEncoder(model_name, device=device, max_length=512)

        save_right_index = {}
        save_wrong_index = {}
        biencoder_and_cross_encoder_final_top_k_indices = []

        for final_top_k in [50]:
            correct = 0
            save_right_index[final_top_k] = {}
            save_wrong_index[final_top_k] = {}
            
            for idx, file in enumerate(tqdm(files)):
                with open(os.path.join(root_path, file), 'r', encoding="UTF-8") as j:
                    caption = json.load(j)
                
                tokenized_query = query_embedding[idx]
                initial_top_k_scores = util.cos_sim(tokenized_query, corpus_embeddings)[0]
                top_results =  torch.topk(initial_top_k_scores, k=50)
                initial_top_k_docs = [combined_rules[i] for i in top_results[1]]
                
                sentence_combinations = [[caption['outputs'], doc] for doc in initial_top_k_docs]
                reranked_scores = model.predict(sentence_combinations)
                reranked_indices = np.argsort(reranked_scores)[::-1]
                final_top_k_indices = [top_results[1][i] for i in reranked_indices]
                
                biencoder_and_cross_encoder_final_top_k_indices.append(final_top_k_indices)
                
                answer = file.split('_')[2]
                if answer[0] == 'Y':
                    answer_index = int(answer[2:]) - 1
                elif answer[0] == 'N':
                    answer_index = int(answer[2:]) + 49
                    
                rank = find_rank_of_answer_in_results(final_top_k_indices, answer_index)

                if rank <= final_top_k and rank != -1:
                    correct += 1
                    save_right_index[final_top_k][idx] = final_top_k_indices
                else:
                    save_wrong_index[final_top_k][idx] = final_top_k_indices

            print(f"Top-{final_top_k} accuracy:", correct / len(files))
            
        for final_top_k in [1, 5, 10, 15, 20, 25, 30]:
            correct = 0
            
            for idx, file in enumerate(tqdm(files)):
                with open(os.path.join(root_path, file), 'r', encoding="UTF-8") as j:
                    caption = json.load(j)
                
                final_top_k_indices = biencoder_and_cross_encoder_final_top_k_indices[idx]

                answer = file.split('_')[2]
                if answer[0] == 'Y':
                    answer_index = int(answer[2:]) - 1
                elif answer[0] == 'N':
                    answer_index = int(answer[2:]) + 49
                    
                rank = find_rank_of_answer_in_results(final_top_k_indices, answer_index)

                if rank <= final_top_k and rank != -1:
                    correct += 1

            print(f"Top-{final_top_k} accuracy:", correct / len(files))

embedder_names: msmarco-distilbert-base-v4
model_names: cross-encoder/ms-marco-MiniLM-L-12-v2


100%|██████████| 943/943 [02:15<00:00,  6.96it/s]


Top-50 accuracy: 0.9469777306468717


100%|██████████| 943/943 [00:00<00:00, 1146.63it/s]


Top-1 accuracy: 0.0911983032873807


100%|██████████| 943/943 [00:00<00:00, 1153.58it/s]


Top-5 accuracy: 0.27465535524920465


100%|██████████| 943/943 [00:00<00:00, 1191.49it/s]


Top-10 accuracy: 0.37857900318133614


100%|██████████| 943/943 [00:00<00:00, 1210.86it/s]


Top-15 accuracy: 0.43902439024390244


100%|██████████| 943/943 [00:00<00:00, 1239.79it/s]


Top-20 accuracy: 0.49204665959703076


100%|██████████| 943/943 [00:00<00:00, 1221.61it/s]


Top-25 accuracy: 0.5503711558854719


100%|██████████| 943/943 [00:00<00:00, 1211.47it/s]


Top-30 accuracy: 0.6076352067868505
embedder_names: msmarco-distilbert-base-v4
model_names: cross-encoder/stsb-roberta-large


100%|██████████| 943/943 [11:55<00:00,  1.32it/s]


Top-50 accuracy: 0.9469777306468717


100%|██████████| 943/943 [00:00<00:00, 1982.34it/s]


Top-1 accuracy: 0.20784729586426298


100%|██████████| 943/943 [00:00<00:00, 2103.31it/s]


Top-5 accuracy: 0.4740190880169671


100%|██████████| 943/943 [00:00<00:00, 1972.04it/s]


Top-10 accuracy: 0.591728525980912


100%|██████████| 943/943 [00:00<00:00, 1951.23it/s]


Top-15 accuracy: 0.6765641569459173


100%|██████████| 943/943 [00:00<00:00, 1993.65it/s]


Top-20 accuracy: 0.7423117709437964


100%|██████████| 943/943 [00:00<00:00, 1988.35it/s]


Top-25 accuracy: 0.8038176033934252


100%|██████████| 943/943 [00:00<00:00, 1983.11it/s]


Top-30 accuracy: 0.8568398727465536
embedder_names: msmarco-distilbert-base-v4
model_names: cross-encoder/qnli-electra-base


100%|██████████| 943/943 [04:15<00:00,  3.69it/s]


Top-50 accuracy: 0.9469777306468717


100%|██████████| 943/943 [00:01<00:00, 756.50it/s]


Top-1 accuracy: 0.010604453870625663


100%|██████████| 943/943 [00:01<00:00, 780.85it/s]


Top-5 accuracy: 0.05090137857900318


100%|██████████| 943/943 [00:01<00:00, 776.22it/s]


Top-10 accuracy: 0.07741251325556733


100%|██████████| 943/943 [00:01<00:00, 798.96it/s]


Top-15 accuracy: 0.11983032873807


100%|██████████| 943/943 [00:01<00:00, 785.01it/s]


Top-20 accuracy: 0.17709437963944857


100%|██████████| 943/943 [00:01<00:00, 762.38it/s]


Top-25 accuracy: 0.22799575821845175


100%|██████████| 943/943 [00:01<00:00, 762.38it/s]


Top-30 accuracy: 0.29374337221633084
embedder_names: msmarco-distilbert-base-v4
model_names: cross-encoder/quora-roberta-large


100%|██████████| 943/943 [11:54<00:00,  1.32it/s]


Top-50 accuracy: 0.9469777306468717


100%|██████████| 943/943 [00:00<00:00, 1791.80it/s]


Top-1 accuracy: 0.1633085896076352


100%|██████████| 943/943 [00:00<00:00, 1813.75it/s]


Top-5 accuracy: 0.42948038176033937


100%|██████████| 943/943 [00:00<00:00, 1740.36it/s]


Top-10 accuracy: 0.545068928950159


100%|██████████| 943/943 [00:00<00:00, 1804.45it/s]


Top-15 accuracy: 0.6352067868504772


100%|██████████| 943/943 [00:00<00:00, 1856.22it/s]


Top-20 accuracy: 0.704135737009544


100%|██████████| 943/943 [00:00<00:00, 1819.67it/s]


Top-25 accuracy: 0.7709437963944857


100%|██████████| 943/943 [00:00<00:00, 1839.70it/s]

Top-30 accuracy: 0.8207847295864263





In [6]:
embedder_names = ["msmarco-distilbert-base-tas-b"]
model_names = ["cross-encoder/ms-marco-MiniLM-L-12-v2", "cross-encoder/stsb-roberta-large", "cross-encoder/qnli-electra-base", "cross-encoder/quora-roberta-large"]

for embedder_name in embedder_names:
    for model_name in model_names:
        print("embedder_names:", embedder_name)
        print("model_names:", model_name)
        embedder = SentenceTransformer(embedder_name)  # 우선 임시로 사용할 모델

        corpus_embeddings = embedder.encode(combined_rules, convert_to_tensor=True)
        query_embedding = embedder.encode(files_captions, convert_to_tensor=True)

        model = CrossEncoder(model_name, device=device, max_length=512)

        save_right_index = {}
        save_wrong_index = {}
        biencoder_and_cross_encoder_final_top_k_indices = []

        for final_top_k in [50]:
            correct = 0
            save_right_index[final_top_k] = {}
            save_wrong_index[final_top_k] = {}
            
            for idx, file in enumerate(tqdm(files)):
                with open(os.path.join(root_path, file), 'r', encoding="UTF-8") as j:
                    caption = json.load(j)
                
                tokenized_query = query_embedding[idx]
                initial_top_k_scores = util.dot_score(tokenized_query, corpus_embeddings)[0]
                top_results =  torch.topk(initial_top_k_scores, k=50)
                initial_top_k_docs = [combined_rules[i] for i in top_results[1]]
                
                sentence_combinations = [[caption['outputs'], doc] for doc in initial_top_k_docs]
                reranked_scores = model.predict(sentence_combinations)
                reranked_indices = np.argsort(reranked_scores)[::-1]
                final_top_k_indices = [top_results[1][i] for i in reranked_indices]
                
                biencoder_and_cross_encoder_final_top_k_indices.append(final_top_k_indices)
                
                answer = file.split('_')[2]
                if answer[0] == 'Y':
                    answer_index = int(answer[2:]) - 1
                elif answer[0] == 'N':
                    answer_index = int(answer[2:]) + 49
                    
                rank = find_rank_of_answer_in_results(final_top_k_indices, answer_index)

                if rank <= final_top_k and rank != -1:
                    correct += 1
                    save_right_index[final_top_k][idx] = final_top_k_indices
                else:
                    save_wrong_index[final_top_k][idx] = final_top_k_indices

            print(f"Top-{final_top_k} accuracy:", correct / len(files))
            
        for final_top_k in [1, 5, 10, 15, 20, 25, 30]:
            correct = 0
            
            for idx, file in enumerate(tqdm(files)):
                with open(os.path.join(root_path, file), 'r', encoding="UTF-8") as j:
                    caption = json.load(j)
                
                final_top_k_indices = biencoder_and_cross_encoder_final_top_k_indices[idx]

                answer = file.split('_')[2]
                if answer[0] == 'Y':
                    answer_index = int(answer[2:]) - 1
                elif answer[0] == 'N':
                    answer_index = int(answer[2:]) + 49
                    
                rank = find_rank_of_answer_in_results(final_top_k_indices, answer_index)

                if rank <= final_top_k and rank != -1:
                    correct += 1

            print(f"Top-{final_top_k} accuracy:", correct / len(files))

embedder_names: msmarco-distilbert-base-tas-b
model_names: cross-encoder/ms-marco-MiniLM-L-12-v2


100%|██████████| 943/943 [02:14<00:00,  6.99it/s]


Top-50 accuracy: 0.9639448568398727


100%|██████████| 943/943 [00:00<00:00, 1091.39it/s]


Top-1 accuracy: 0.08907741251325557


100%|██████████| 943/943 [00:00<00:00, 1199.67it/s]


Top-5 accuracy: 0.2799575821845175


100%|██████████| 943/943 [00:00<00:00, 1196.33it/s]


Top-10 accuracy: 0.3711558854718982


100%|██████████| 943/943 [00:00<00:00, 1198.55it/s]


Top-15 accuracy: 0.44750795334040294


100%|██████████| 943/943 [00:00<00:00, 1214.35it/s]


Top-20 accuracy: 0.49946977730646874


100%|██████████| 943/943 [00:00<00:00, 1190.03it/s]


Top-25 accuracy: 0.5556733828207847


100%|██████████| 943/943 [00:00<00:00, 1265.07it/s]


Top-30 accuracy: 0.616118769883351
embedder_names: msmarco-distilbert-base-tas-b
model_names: cross-encoder/stsb-roberta-large


100%|██████████| 943/943 [11:55<00:00,  1.32it/s]


Top-50 accuracy: 0.9639448568398727


100%|██████████| 943/943 [00:00<00:00, 1986.10it/s]


Top-1 accuracy: 0.2067868504772004


100%|██████████| 943/943 [00:00<00:00, 2091.15it/s]


Top-5 accuracy: 0.4814422057264051


100%|██████████| 943/943 [00:00<00:00, 1974.94it/s]


Top-10 accuracy: 0.591728525980912


100%|██████████| 943/943 [00:00<00:00, 2051.11it/s]


Top-15 accuracy: 0.6808059384941676


100%|██████████| 943/943 [00:00<00:00, 2011.37it/s]


Top-20 accuracy: 0.7444326617179216


100%|██████████| 943/943 [00:00<00:00, 1819.81it/s]


Top-25 accuracy: 0.8048780487804879


100%|██████████| 943/943 [00:00<00:00, 2021.16it/s]


Top-30 accuracy: 0.8716861081654295
embedder_names: msmarco-distilbert-base-tas-b
model_names: cross-encoder/qnli-electra-base


100%|██████████| 943/943 [04:16<00:00,  3.68it/s]


Top-50 accuracy: 0.9639448568398727


100%|██████████| 943/943 [00:01<00:00, 811.60it/s]


Top-1 accuracy: 0.013785790031813362


100%|██████████| 943/943 [00:01<00:00, 805.68it/s]


Top-5 accuracy: 0.04665959703075292


100%|██████████| 943/943 [00:01<00:00, 777.83it/s]


Top-10 accuracy: 0.0784729586426299


100%|██████████| 943/943 [00:01<00:00, 787.20it/s]


Top-15 accuracy: 0.11876988335100742


100%|██████████| 943/943 [00:01<00:00, 769.91it/s]


Top-20 accuracy: 0.17815482502651114


100%|██████████| 943/943 [00:01<00:00, 791.72it/s]


Top-25 accuracy: 0.22799575821845175


100%|██████████| 943/943 [00:01<00:00, 793.33it/s]


Top-30 accuracy: 0.2948038176033934
embedder_names: msmarco-distilbert-base-tas-b
model_names: cross-encoder/quora-roberta-large


100%|██████████| 943/943 [11:55<00:00,  1.32it/s]


Top-50 accuracy: 0.9639448568398727


100%|██████████| 943/943 [00:00<00:00, 1764.86it/s]


Top-1 accuracy: 0.16436903499469777


100%|██████████| 943/943 [00:00<00:00, 1816.64it/s]


Top-5 accuracy: 0.42735949098621423


100%|██████████| 943/943 [00:00<00:00, 1856.43it/s]


Top-10 accuracy: 0.5365853658536586


100%|██████████| 943/943 [00:00<00:00, 1806.85it/s]


Top-15 accuracy: 0.6341463414634146


100%|██████████| 943/943 [00:00<00:00, 1860.19it/s]


Top-20 accuracy: 0.704135737009544


100%|██████████| 943/943 [00:00<00:00, 1800.93it/s]


Top-25 accuracy: 0.7762460233297985


100%|██████████| 943/943 [00:00<00:00, 1836.17it/s]

Top-30 accuracy: 0.8313891834570519



