In [None]:
import os
import argparse
from ruamel.yaml import YAML
from pathlib import Path
import utils
import json
from vqaTools.vqa import VQA
import datetime
import csv
from dataset.utils import vqa_eval, save_result, load_json
import time

In [2]:
import torch
import torch.backends.cudnn as cudnn
from accelerate import init_empty_weights, dispatch_model, infer_auto_device_map, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
from huggingface_hub import snapshot_download
import en_core_web_sm
# Load the `en_core_web_sm` pipeline once so later cells can reuse the same `nlp` object.
nlp = en_core_web_sm.load()
# NOTE(review): presumably opts in to parallel tokenization in the Hugging Face
# `tokenizers` library (and silences its fork warning) — confirm this is intended.
os.environ["TOKENIZERS_PARALLELISM"] = "true"

In [3]:
def update(params, args):
    """Overlay the command-line options onto the loaded configuration.

    Every attribute named below is read from ``args`` and stored in ``params``
    under the same key, replacing any value already present.

    Args:
        params: Mutable mapping holding the configuration; modified in place.
        args: Object exposing the options as attributes (e.g. an
            ``argparse.Namespace``).

    Returns:
        The same ``params`` object, after the overlay.

    Raises:
        AttributeError: If ``args`` lacks one of the expected attributes.
    """
    option_names = (
        # generation / model selection
        'min_answer_length',
        'max_answer_length',
        'model_selection',
        'dist_selection',
        # dataset and run bookkeeping
        'dataset',
        'split_seed',
        'num_sample',
        'output_dir',
        'test_server',
        # caption inputs
        'num_caps_per_img',
        'num_question_per_img',
        'caption_file',
        # question / answer inputs
        'question_file',
        'question_ppl_file',
        'ans_dict_file',
        'question_type',
        # evaluation switches
        'random_question',
        'result_tag',
        'evaluate_direct',
        'resume',
    )
    for option in option_names:
        params[option] = getattr(args, option)
    return params

In [4]:
def create_cap_dic(caption_data):
    """Build a ``question_id -> list of captions`` lookup.

    Each caption is stripped of surrounding whitespace, capitalized, and
    terminated with a period. Only the first 100 captions of a record are kept.
    When a ``question_id`` occurs more than once, the last record wins.

    Args:
        caption_data: Iterable of records, each with a ``question_id`` and a
            ``caption`` entry holding a list of strings.

    Returns:
        dict mapping each ``question_id`` to its list of normalized captions.

    Raises:
        NotImplementedError: If a record's ``caption`` entry is not a list.
    """
    caption_dict = {}
    for record in caption_data:
        question_id = record['question_id']
        captions = record['caption']
        if not isinstance(captions, list):
            raise NotImplementedError()
        # Strip *before* capitalizing: str.capitalize() only upper-cases the
        # very first character, so a caption with leading whitespace would
        # otherwise stay lower-case.
        caption_dict[question_id] = [
            caption.strip().capitalize() + "." for caption in captions[:100]
        ]
    return caption_dict

In [5]:
def create_ans_to_cap_dic(ans_to_cap_data):
    """Build a ``question_id -> answer dictionary`` lookup.

    The per-record dictionary is read from ``ans_to_cap_dict`` when that key is
    present and from ``tag`` otherwise. The dictionaries are stored as-is (not
    copied). When a ``question_id`` occurs more than once, the last record wins.

    Args:
        ans_to_cap_data: Iterable of dict records carrying ``question_id`` and
            one of the two payload keys.

    Returns:
        dict mapping each ``question_id`` to the record's payload dictionary.

    Raises:
        NotImplementedError: If the selected payload is not a dict.
    """
    lookup = {}
    for record in ans_to_cap_data:
        question_id = record['question_id']
        field = 'ans_to_cap_dict' if 'ans_to_cap_dict' in record else 'tag'
        payload = record[field]
        if not isinstance(payload, dict):
            raise NotImplementedError()
        lookup[question_id] = payload
    return lookup

In [6]:
def create_generated_question_dic(question_data):
    """Build lookups for the synthetic questions and answers of each sample.

    Every question and answer string is stripped of surrounding whitespace and
    then capitalized. When a ``question_id`` occurs more than once, the last
    record wins.

    Args:
        question_data: Iterable of records, each with a ``question_id`` plus
            ``question`` and ``answer`` entries holding lists of strings.

    Returns:
        Tuple ``(syn_question_dict, syn_answer_dict)``; both map each
        ``question_id`` to the corresponding list of normalized strings.

    Raises:
        NotImplementedError: If ``question`` or ``answer`` is not a list.
    """
    def _normalize(text):
        # Strip first: str.capitalize() only upper-cases the first character,
        # so leading whitespace would otherwise leave the text lower-case.
        return text.strip().capitalize()

    syn_question_dict = {}
    syn_answer_dict = {}
    for record in question_data:
        question_id = record['question_id']

        questions = record['question']
        if not isinstance(questions, list):
            raise NotImplementedError()
        normalized_questions = [_normalize(text) for text in questions]

        answers = record['answer']
        if not isinstance(answers, list):
            raise NotImplementedError()
        normalized_answers = [_normalize(text) for text in answers]

        syn_question_dict[question_id] = normalized_questions
        syn_answer_dict[question_id] = normalized_answers

    return syn_question_dict, syn_answer_dict

In [7]:
# Command-line style options for the notebook. They are parsed from an empty
# argument list at the end of this cell, so only the defaults declared here
# take effect unless a default is edited in place.
# Disabled: os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"  (works together with the dist barrier timeout)
parser = argparse.ArgumentParser()
# --- Input files ---
# NOTE(review): the caption/question/answer-dict defaults are absolute paths
# inside one user's home directory; they must be changed on any other machine.
parser.add_argument('--config', default='./configs/AOKVQA_caption.yaml')
parser.add_argument('--caption_file', default='/home/mingyu/projects/Img2Prompt/caption_question_files/aokvqa_val_caption.json')
parser.add_argument('--question_file', default='/home/mingyu/projects/Img2Prompt/caption_question_files/aokvqa_val_question.json')
parser.add_argument('--question_ppl_file', default=None)
parser.add_argument('--ans_dict_file', default='/home/mingyu/projects/Img2Prompt/caption_question_files/aokvqa_val_ans_to_cap_dict.json')
parser.add_argument('--question_type', default='g_q', type=str)

# --- Output location and resume flag ---
parser.add_argument('--output_dir', default='output/VQA_caption')
parser.add_argument('--resume', action='store_true')

# --- Evaluation switches (all default to False) ---
parser.add_argument('--evaluate_direct', action='store_true')

parser.add_argument('--evaluate', action='store_true')
parser.add_argument('--vqa_eval', action='store_true')

# --- Runtime and sampling ---
parser.add_argument('--device', default='cuda')
parser.add_argument('--seed', default=42, type=int)
parser.add_argument('--split_seed', default=0, type=int)
parser.add_argument('--num_sample', default=16, type=int)
parser.add_argument('--ensemble', default=1, type=int)
parser.add_argument('--random_question', action='store_true')
parser.add_argument('--test_server', action='store_true')



# --- Model selection ---
parser.add_argument('--model_selection', default='opt-6.7b', type=str)
parser.add_argument('--dist_selection', default='hugging', type=str)
parser.add_argument('--select_cap', action='store_true')

# --- Dataset name and free-form tag; both end up in the result file name ---
parser.add_argument('--dataset', default='vqa_caption', type=str)
parser.add_argument('--result_tag', default='', type=str)

parser.add_argument('--batch_size_test', default=64, type=int)


# --- Number of captions / synthetic questions per image ---
parser.add_argument('--num_caps_per_img', default=30, type=int)
parser.add_argument('--num_question_per_img', default=30, type=int)

# --- Answer length bounds used during generation ---
parser.add_argument('--min_answer_length', default=1, type=int,
                    help='min answer length during inference (generate); '
                         'None  == self.model.config.min_length (0 for t0)')
parser.add_argument('--max_answer_length', default=10, type=int,
                    help='max answer length during inference (generate); '
                         'None  == self.model.config.max_length (20 for t0)')
# An explicit empty list makes argparse ignore sys.argv (which belongs to the
# Jupyter kernel) and return the defaults declared above.
args = parser.parse_args(args=[])

In [8]:
# Round-trip ('rt') YAML handler, used both to read the base configuration and
# to write the effective configuration back out.
yaml = YAML(typ='rt')

# Read the base configuration; the context manager closes the handle promptly
# instead of leaving it to the garbage collector.
with open(args.config, 'r') as config_in:
    config = yaml.load(config_in)
# Values from `args` overwrite whatever the YAML file provided for the same keys.
config = update(config, args)
args.result_dir = os.path.join(args.output_dir, 'result')

Path(args.output_dir).mkdir(parents=True, exist_ok=True)
Path(args.result_dir).mkdir(parents=True, exist_ok=True)

# Save the effective configuration next to the results. Closing the file
# guarantees the dump is flushed to disk before the run continues.
with open(os.path.join(args.output_dir, 'config.yaml'), 'w') as config_out:
    yaml.dump(config, config_out)
logger, writer = utils.setup_default_logging(args)

In [None]:
#### Dataset ####
print("Creating vqa datasets")

# Collect the test questions from every file listed in the config. Earlier
# versions rebound `test_data` on each iteration, which silently discarded all
# files but the last; extending keeps them all (identical for a single file).
test_data = []
for test_path in config['test_file']:
    with open(test_path, 'r') as fh:
        test_data.extend(json.load(fh))

# Captions, keyed by question id.
with open(config['caption_file'], 'r') as fh:
    caption_data = json.load(fh)
quesID_to_cap_dict = create_cap_dic(caption_data)

# Synthetic questions and their synthetic answers, keyed by question id.
with open(config['question_file'], 'r') as fh:
    question_data = json.load(fh)
quesID_to_ques_data, syn_answer_dict = create_generated_question_dic(question_data)

# Per-question answer dictionaries, keyed by question id.
with open(config['ans_dict_file'], 'r') as fh:
    ans_dict_data = json.load(fh)
ans_to_cap_dicts = create_ans_to_cap_dic(ans_dict_data)


In [None]:
# Build the VQA helper only when an annotation path is configured; otherwise
# `vqa` stays undefined.
val_ann_path = config['val_ann_path']
if val_ann_path:
    vqa = VQA(val_ann_path, config['val_ques_path'])

# Assemble the result file name from the run settings. Note that there is no
# separator between the dist_selection value and the literal 'caps'.
result_name_parts = [
    config['result_tag'], '_',
    config['dataset'], '_',
    config['model_selection'], '_',
    config['dist_selection'],
    'caps', str(config['num_caps_per_img']),
    '_question', str(config['num_question_per_img']),
    '_questiontype', '_', config['question_type'],
]
result_filename = ''.join(result_name_parts)
start_time = time.time()

In [11]:
metric_logger = utils.MetricLogger(delimiter="  ", logger=logger)
header = 'Generate VQA test result:'
print_freq = 1000
result = []
# Index the question ids whose stored answer is not None. `result` is created
# empty just above, so this index starts out empty as well.
tested_quesId_dict = {
    entry['question_id']: 1
    for entry in result
    if entry['answer'] is not None
}

In [12]:
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
from huggingface_hub import snapshot_download, notebook_login
import huggingface_hub
import os

import networkx as nx
import matplotlib.pyplot as plt
from sentence_transformers import SentenceTransformer, util

In [13]:
# NOTE(review): a string that looked like a Hugging Face access token was kept
# in a comment here. It has been removed; treat it as leaked (revoke it) and
# supply credentials through the environment instead of the notebook.
os.environ["CUDA_VISIBLE_DEVICES"]= "0"

# Load the sentence-embedding model used to compare the question with the knowledge sentences.
model = SentenceTransformer('all-MiniLM-L6-v2')

In [14]:
from tqdm import tqdm

In [None]:
# Rank each question's knowledge sentences with PageRank over a similarity graph.
#
# For every test question that does not yet have an output file:
#   1. load its generated knowledge sentences and de-duplicate them,
#   2. embed the question and the sentences,
#   3. connect sentences to the question node and to each other wherever the
#      cosine similarity reaches SIMILARITY_THRESHOLD,
#   4. run PageRank and store the sentences that made it into the graph,
#      ordered from highest to lowest score.
#
# Variants the author experimented with earlier (removed from this cell):
# restricting sentence-sentence edges to sentences linked to the question,
# personalized PageRank seeded on the question node, and a matplotlib drawing
# of the graph.

# NOTE(review): absolute path inside one user's home directory — adjust per machine.
KB_INPUT_DIR = "/home/hsh/DKSVQA-main/knowledge_files/aokvqa_cluster_generated_kb"
KB_OUTPUT_DIR = "./knowledge_files/aokvqa_pagerank_kb"
SIMILARITY_THRESHOLD = 0.5  # values tried by the author: 0.25 / 0.5 / 0.75

# Create the output folder up front so the final write cannot fail on a fresh checkout.
os.makedirs(KB_OUTPUT_DIR, exist_ok=True)

for n, per_test_data in enumerate(test_data):
    question = per_test_data['question'].lower().strip()
    question_id = per_test_data['question_id']

    # Resume support: questions that already have an output file are skipped.
    file_path = f"{KB_OUTPUT_DIR}/{question_id}.json"
    if os.path.exists(file_path):
        print(f"File {file_path} exists.")
        continue

    with open(f"{KB_INPUT_DIR}/{question_id}.json", "r") as f:
        kb = json.load(f)

    # De-duplicate while keeping first-seen order. list(set(...)) yields a
    # hash-dependent order that changes between interpreter runs, and that
    # order decides how ties in the final ranking are broken.
    kb = list(dict.fromkeys(kb))

    selected_kb_list = []
    if kb:  # nothing to embed or rank when there are no knowledge sentences
        question_embedding = model.encode(question, convert_to_tensor=True)
        knowledge_embeddings = model.encode(kb, convert_to_tensor=True)

        # Question-to-sentence similarities (one row) and the full
        # sentence-to-sentence matrix, each computed in a single call instead
        # of one call per pair. Values may differ from per-pair calls only by
        # floating-point rounding.
        question_sims = util.pytorch_cos_sim(question_embedding, knowledge_embeddings)[0].tolist()
        pairwise_sims = util.pytorch_cos_sim(knowledge_embeddings, knowledge_embeddings).tolist()

        G = nx.Graph()

        # Author's note: the well-performing runs used the lower-case name
        # "question" here.
        # NOTE(review): the edges below attach to "Question" (capitalized), so
        # this labelled node stays isolated and a second, unlabelled "Question"
        # node is created implicitly. Kept as-is because unifying the names
        # would change the PageRank scores the reported results rely on —
        # confirm before changing.
        G.add_node("question", label=question)

        # Link the question to every sufficiently similar knowledge sentence.
        for idx, sentence in enumerate(kb):
            similarity = question_sims[idx]
            if similarity >= SIMILARITY_THRESHOLD:
                G.add_node(f"Knowledge_{idx}", label=sentence)
                G.add_edge("Question", f"Knowledge_{idx}", weight=similarity)

        # Link knowledge sentences to each other. add_edge creates missing
        # nodes, so sentences not linked to the question can still enter the
        # graph through these edges.
        for i in range(len(kb)):
            row = pairwise_sims[i]
            for j in range(i + 1, len(kb)):
                if row[j] >= SIMILARITY_THRESHOLD:
                    G.add_edge(f"Knowledge_{i}", f"Knowledge_{j}", weight=row[j])

        # Edge weights are taken into account (networkx uses the 'weight'
        # attribute by default).
        pagerank = nx.pagerank(G, alpha=0.85)
        sorted_pagerank = sorted(pagerank.items(), key=lambda x: x[1], reverse=True)

        # Keep only knowledge nodes, best score first; the node name encodes
        # the sentence's index in `kb`.
        for node, score in sorted_pagerank:
            if node.startswith("Knowledge"):
                idx = int(node.split("_")[1])
                selected_kb_list.append(kb[idx])

    with open(file_path, "w") as f:
        json.dump(selected_kb_list, f)