In [None]:
# Number of samples per batch handed to the test DataLoader.
TEST_BATCH_SIZE = 1
# Number of worker processes the test DataLoader uses to load samples.
NUM_WORKERS = 8
# Upper bound on newly generated tokens per sample (passed as max_new_tokens).
MAX_TOKENS = 600

In [None]:
import json
import os
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader

import datasets
from transformers import AutoProcessor, LlavaForConditionalGeneration, BitsAndBytesConfig

from CodaDatasets import CodaDataset

In [None]:
# Every tensor batch produced by the processor is moved to this device.
device = 'cuda'

# Switch off the Hugging Face tokenizers' own parallelism; it is set here,
# before any DataLoader worker processes are started.
os.environ.update({'TOKENIZERS_PARALLELISM': 'false'})

In [None]:
# Hub identifier shared by the model weights and the matching processor.
model_id = 'llava-hf/llava-1.5-7b-hf'

# 4-bit weight quantization (bitsandbytes) with float16 compute.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# Base vision-language model, loaded quantized with float16 for the
# non-quantized parts.
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    quantization_config=quant_config,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

# The processor bundles the image preprocessing and the tokenizer for model_id.
processor = AutoProcessor.from_pretrained(model_id)

In [None]:
# Conversation wrapper; the `{}` slot is filled with each question via
# str.format in the inference loop.
# NOTE(review): the template itself contains no `<image>` placeholder —
# presumably the questions yielded by CodaDataset already include it; confirm.
prompt_template = 'USER: {} ASSISTANT:'

In [None]:
# Attach the fine-tuned adapter weights (a LoRA checkpoint, judging by its
# path) to the quantized base model.
model.load_adapter('models/lora_r64_5e-5_augmented_ep1')

# Sub-modules created while injecting the adapter are new nn.Modules and may
# start out in training mode, which would leave any adapter dropout active
# during inference. Switch the whole model to eval mode explicitly.
# eval() returns the model itself, so the cell still displays the architecture.
model.eval()

In [None]:
# Raw Hugging Face splits keyed by split name; only the test split is loaded.
hf_dataset = dict(
    test=datasets.load_dataset('ntudlcv/dlcv_2024_final1', split='test'),
)

In [None]:
# Project-specific dataset wrappers keyed by split name.
# has_answer=False — presumably because the test split ships without
# reference answers; confirm against CodaDatasets.
dataset = dict(
    test=CodaDataset(hf_dataset['test'], has_answer=False),
)

In [None]:
def custom_collate_fn(batch):
    """Transpose a batch of per-sample tuples into per-field tuples.

    The default collate would try to stack fields into tensors; this one
    leaves every field untouched and only regroups them, so that images and
    strings can be passed to the processor as plain Python sequences.

    Args:
        batch: sequence of samples, each an iterable of fields in the same
            order (e.g. ``(data_id, question_type, image, question)``).

    Returns:
        A tuple with one inner tuple per field, each holding that field's
        value for every sample in ``batch``. An empty batch yields ``()``.
        As with ``zip``, fields beyond the shortest sample are dropped.
    """
    # Materialize the transposition: a bare ``zip`` object is a lazy,
    # single-use iterator, which is fragile to ship across DataLoader worker
    # processes and cannot be inspected or iterated twice.
    return tuple(zip(*batch))

In [None]:
# Sequential (unshuffled) loader over the test split; custom_collate_fn keeps
# each field as a plain Python sequence instead of stacking into tensors.
test_loader = DataLoader(
    dataset['test'],
    batch_size=TEST_BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    collate_fn=custom_collate_fn,
)

# Loaders keyed by split name.
dataloader = {'test': test_loader}

In [None]:
# Greedy-decode an answer for every test sample and collect them by sample id.
predictions = {}

# A decoder-only model continues from the last position of every row, so
# batches with more than one prompt must be padded on the left; otherwise the
# shorter prompts would be continued from pad tokens.
processor.tokenizer.padding_side = 'left'

for data_ids, question_types, images, questions in tqdm(dataloader['test']):
    prompts = [prompt_template.format(q) for q in questions]
    inputs = processor(images=images, text=prompts, padding=True, return_tensors='pt').to(device)
    outputs = model.generate(**inputs, max_new_tokens=MAX_TOKENS, do_sample=False)

    # generate() returns the prompt ids followed by the newly generated ids.
    # Slice off the prompt by length instead of splitting the decoded text on
    # 'ASSISTANT: ': the split raises IndexError when nothing is generated
    # and picks the wrong segment if the question contains that marker.
    prompt_length = inputs['input_ids'].shape[1]
    answers = processor.batch_decode(outputs[:, prompt_length:], skip_special_tokens=True)

    for data_id, answer in zip(data_ids, answers):
        # Drop the separator whitespace that may precede/follow the answer.
        generated_answer = answer.strip()
        predictions[data_id] = generated_answer
        print(repr(data_id))
        print(repr(generated_answer))

In [None]:
# Persist the id -> answer mapping as pretty-printed JSON for submission.
with open('submission.json', 'w') as submission_file:
    submission_file.write(json.dumps(predictions, indent=4))