In [None]:
import warnings
from collections import Counter

import pandas as pd
import torch
from tqdm import tqdm
from transformers import AutoProcessor,AutoModelForCausalLM

from literal import ANSWER,IMG,QUESTION
from utils import get_dataset
# Install a blanket "ignore" filter: with no category/module arguments this
# silences every warning raised after this point, presumably to keep cell
# output readable. NOTE(review): this also hides deprecation and numerical
# warnings from the libraries used below — narrow the filter if that matters.
warnings.filterwarnings('ignore')

In [None]:
# CSV handed to the project helper `get_dataset`; judging by the file name it
# is the preprocessed test split (TODO confirm against utils.get_dataset).
TEST_CSV_PATH = '/home/hwlee/dacon/imgQA/preprocess_test.csv'

test_datasets = get_dataset(TEST_CSV_PATH)

In [None]:
# Run on the first CUDA device when one is visible, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# A single processor (image preprocessing + tokenizer) is shared by all models.
processors = AutoProcessor.from_pretrained("microsoft/git-large-coco")

# One checkpoint directory per ensemble member; the numeric suffixes suggest
# these are separate training runs/folds (TODO confirm).
checkpoint_dirs = [
    "/home/user/4TB/hwlee/imgQA_output_0/",
    "/home/user/4TB/hwlee/imgQA_output_1/",
    "/home/user/4TB/hwlee/imgQA_output_2/",
    "/home/user/4TB/hwlee/imgQA_output_3/",
    "/home/user/4TB/hwlee/imgQA_output_4/",
]

# Load every checkpoint and move it to the selected device up front, so all
# ensemble members stay resident for the whole inference pass.
models = []
for checkpoint_dir in checkpoint_dirs:
    loaded_model = AutoModelForCausalLM.from_pretrained(checkpoint_dir)
    models.append(loaded_model.to(device))

In [None]:
# Pre-compute the model inputs for every test sample once, so the ensemble
# loop can reuse the same tensors for each model.
pixel_value_list, input_ids_list = [], []

for sample in tqdm(test_datasets):
    sample_image = sample[IMG]
    # Questions are lower-cased before tokenisation.
    lowered_question = sample[QUESTION].lower()

    image_inputs = processors(images=sample_image, return_tensors="pt")
    text_inputs = processors(text=lowered_question, return_tensors="pt")

    # Index i of both lists refers to the i-th sample of `test_datasets`.
    pixel_value_list.append(image_inputs.pixel_values)
    input_ids_list.append(text_inputs.input_ids)

In [None]:
# Generation settings forwarded unchanged to every model's `generate` call.
MAX_LENGTH = 100    # value passed as generate(max_length=...)
EOS_TOKEN_ID = 102  # presumably the tokenizer's [SEP] id — TODO confirm against processors.tokenizer.sep_token_id


def _strip_question_prefix(decoded, question):
    """Return `decoded` without a leading copy of `question`, whitespace-trimmed.

    Only the leading occurrence is removed. A plain ``str.replace`` would also
    delete the question text if it happened to reappear inside the answer.
    """
    if decoded.startswith(question):
        decoded = decoded[len(question):]
    return decoded.strip()


def _majority_vote(answers):
    """Return the most frequent answer in `answers` (a non-empty list of str).

    Ties are resolved in favour of the answer that appears first in `answers`
    (i.e. the earliest model), because `Counter.most_common` keeps
    first-encountered order for equal counts. Taking ``max`` over a ``set`` of
    strings instead would make tie-breaking depend on hash order, so the
    submission could change between otherwise identical runs.
    """
    return Counter(answers).most_common(1)[0][0]


# Final answer per test sample, in the same order as `test_datasets`.
labels = []

for i in tqdm(range(len(test_datasets))):
    pixel_values = pixel_value_list[i].to(device)
    input_ids = input_ids_list[i].to(device)
    # Text form of the prompt, used to cut the echoed question off each output.
    question = processors.tokenizer.decode(input_ids_list[i][0], skip_special_tokens=True)

    # Collect one answer per ensemble member for this sample.
    answer_list = []
    for model in models:
        generated_ids = model.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            max_length=MAX_LENGTH,
            eos_token_id=EOS_TOKEN_ID,
        )[0]
        decoded = processors.tokenizer.decode(generated_ids, skip_special_tokens=True)
        answer_list.append(_strip_question_prefix(decoded, question))

    labels.append(_majority_vote(answer_list))

In [None]:
# Input template and output file for the submission.
SAMPLE_SUBMISSION_PATH = '/home/hwlee/dacon/imgQA/sample_submission.csv'
SUBMISSION_OUTPUT_PATH = '0.csv'

# Start from the sample submission so its existing columns are kept as-is,
# then set the ANSWER column to the ensemble predictions.
sub = pd.read_csv(SAMPLE_SUBMISSION_PATH)
sub[ANSWER] = labels

# Write without the DataFrame index column.
sub.to_csv(SUBMISSION_OUTPUT_PATH, index=False)