In [None]:
# Inference configuration shared by the cells below.
# Batch size and worker-process count handed to the test DataLoader.
TEST_BATCH_SIZE = 1
NUM_WORKERS = 8
# Upper bound on the number of tokens produced per answer by the decoding loop.
MAX_TOKENS = 600

# Query dimension passed to every DecoderWithCrossAttention wrapper.
# NOTE(review): presumably the language model's hidden size — confirm.
CROSS_ATTN_Q_DIM = 4096

# Settings for the first cross-attention branch (its context is set to
# features['patch_tokens'] in the inference loop).
# NOTE(review): KV_DIM_1 presumably equals the patch-token feature width — confirm
# against CodaFeatureExtractor.
CROSS_ATTN_EMBED_DIM_1 = 512
CROSS_ATTN_NUM_HEADS_1 = 8
CROSS_ATTN_KV_DIM_1 = 916

# Settings for the second cross-attention branch (its context is set to
# features['instance_tokens'] in the inference loop).
# NOTE(review): KV_DIM_2 presumably equals the instance-token feature width — confirm
# against CodaFeatureExtractor.
CROSS_ATTN_EMBED_DIM_2 = 128
CROSS_ATTN_NUM_HEADS_2 = 2
CROSS_ATTN_KV_DIM_2 = 21

In [None]:
import json
import os
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader

import datasets
from transformers import AutoProcessor, LlavaForConditionalGeneration, BitsAndBytesConfig

from CodaDatasets import CodaDataset
from CodaFeatureExtractor import CodaFeatureExtractor
from CodaLayers import DecoderWithCrossAttention

In [None]:
# All tensors and modules created in this notebook are placed on this device.
device = 'cuda'
# Turn off parallelism inside the tokenizers library.
# NOTE(review): presumably to avoid its fork warning once DataLoader workers start — confirm.
os.environ.update({'TOKENIZERS_PARALLELISM': 'false'})

In [None]:
# Checkpoint identifier shared by the model and its processor.
model_id = 'llava-hf/llava-1.5-7b-hf'

# 4-bit weight quantization with float16 as the compute dtype.
four_bit_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)

# The processor prepares the image/text inputs used by the inference loop.
processor = AutoProcessor.from_pretrained(model_id)

# The quantized LLaVA model itself.
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    quantization_config=four_bit_config,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

In [None]:
# Prompt prefix applied to each dataset question before it is passed to the processor;
# the 'ASSISTANT:' turn is fed to the model separately in the inference loop.
# NOTE(review): there is no '<image>' placeholder here — presumably the dataset question
# already contains it; confirm against CodaDataset.
question_template = 'USER: {}'

In [None]:
# Attach the adapter weights stored under 'models/lora_r64_5e-5_ep1' to the base model.
# NOTE(review): presumably a LoRA adapter, judging by the path — confirm.
model.load_adapter('models/lora_r64_5e-5_ep1')
# Bare expression: renders the module tree as this cell's output for inspection.
model

In [None]:
# Feature extractor whose process_images() result supplies the cross-attention
# contexts ('patch_tokens', 'instance_tokens') and mask ('instance_attention_mask')
# used by the inference loop.
extractor = CodaFeatureExtractor(device)

In [None]:
# Replace every decoder layer of the language model, in place, with a
# DecoderWithCrossAttention wrapper built around the original layer.
decoder_stack = model.language_model.model.layers
for layer_idx in range(len(decoder_stack)):
    wrapped_layer = DecoderWithCrossAttention(
        decoder_stack[layer_idx],
        CROSS_ATTN_EMBED_DIM_1,
        CROSS_ATTN_EMBED_DIM_2,
        CROSS_ATTN_NUM_HEADS_1,
        CROSS_ATTN_NUM_HEADS_2,
        CROSS_ATTN_Q_DIM,
        CROSS_ATTN_KV_DIM_1,
        CROSS_ATTN_KV_DIM_2,
    )
    # Move the wrapper (including its newly created modules) to the target device.
    decoder_stack[layer_idx] = wrapped_layer.to(device)

In [None]:
# Raw Hugging Face dataset, keyed by split name; only the 'test' split is needed here.
hf_dataset = dict(
    test=datasets.load_dataset('ntudlcv/dlcv_2024_final1', split='test'),
)

In [None]:
# Project dataset wrapper around the raw test split; has_answer=False because
# this notebook only produces predictions.
dataset = dict(
    test=CodaDataset(hf_dataset['test'], has_answer=False),
)

In [None]:
def custom_collate_fn(batch):
    """Transpose a list of samples into per-field tuples, without stacking anything.

    Args:
        batch: Sequence of samples, each an iterable of fields.

    Returns:
        A tuple containing one tuple per field, e.g. ``[(a1, b1), (a2, b2)]``
        becomes ``((a1, a2), (b1, b2))``. An empty batch yields ``()``.
    """
    # Materialise the transposition: a bare ``zip`` object is a one-shot iterator,
    # so the batch would be empty on a second pass, and it is an iterator object
    # (rather than a plain container) that gets handed between DataLoader worker
    # processes. Consumers that do ``zip(*batch)`` work unchanged with a tuple.
    return tuple(zip(*batch))

In [None]:
# Test-split loader: fixed order (no shuffling), samples grouped by custom_collate_fn.
dataloader = dict(
    test=DataLoader(
        dataset['test'],
        collate_fn=custom_collate_fn,
        batch_size=TEST_BATCH_SIZE,
        num_workers=NUM_WORKERS,
        shuffle=False,
    ),
)

In [None]:
# Load the checkpoint 'models/lora_crossattn_1e-5_ep1_model.pt' into the model (this
# runs after the decoder layers were wrapped, so wrapper parameters can be targeted).
# strict=False: keys that are missing from or unexpected in the checkpoint do not raise,
# and the trailing ';' hides the returned key report — a mismatched checkpoint would
# therefore go unnoticed here.
# SECURITY: weights_only=False lets torch.load unpickle arbitrary Python objects;
# only use this with checkpoint files from a trusted source.
model.load_state_dict(torch.load('models/lora_crossattn_1e-5_ep1_model.pt', weights_only=False), strict=False);

In [None]:
def _set_cross_attention(decoder_layers, enabled, features=None):
    """Switch the cross-attention branch of every wrapped decoder layer on or off.

    Args:
        decoder_layers: Iterable of DecoderWithCrossAttention layers.
        enabled: Value assigned to each layer's ``enable_cross_attn`` flag.
        features: Required when ``enabled`` is true. Mapping returned by
            ``extractor.process_images`` whose 'patch_tokens', 'instance_tokens'
            and 'instance_attention_mask' entries are attached to each layer.
            Ignored when ``enabled`` is false (previously attached contexts are
            left in place, only the flag is cleared).
    """
    for layer in decoder_layers:
        layer.enable_cross_attn = enabled
        if enabled:
            layer.cross_attn_context_1 = features['patch_tokens']
            layer.cross_attn_context_2 = features['instance_tokens']
            layer.cross_attn_mask_1 = None
            layer.cross_attn_mask_2 = features['instance_attention_mask']


decoder_layers = model.language_model.model.layers
eos_token_id = processor.tokenizer.eos_token_id

# Inference only: nothing below switches the mode back, so one call is enough.
model.eval()

# The answer prefix is identical for every sample: tokenize and embed it once.
answer_inputs = processor(text='ASSISTANT:', add_special_tokens=False, return_tensors='pt').to(device)
with torch.no_grad():
    answer_embeds = model.get_input_embeddings()(answer_inputs['input_ids'])

# Maps each sample id to its generated answer; written to disk by the next cell.
predictions = {}
for data in tqdm(dataloader['test']):
    # Samples are handled one at a time, so the loop works for any batch size.
    for data_id, question_type, image, question in zip(*data):
        # 1. Encode image + question with cross-attention disabled, keeping the KV cache.
        _set_cross_attention(decoder_layers, False)

        prompt = question_template.format(question)
        question_inputs = processor(images=image, text=prompt, return_tensors='pt').to(device)

        with torch.autocast(device), torch.no_grad():
            outputs = model(**question_inputs, use_cache=True)

        # 2. Enable cross-attention on the extracted image features and feed the
        #    'ASSISTANT:' prefix on top of the cached question.
        with torch.autocast(device), torch.no_grad():
            features = extractor.process_images([image])
        _set_cross_attention(decoder_layers, True, features)

        attention_mask = torch.cat([question_inputs['attention_mask'], answer_inputs['attention_mask']], dim=1)

        with torch.autocast(device), torch.no_grad():
            outputs = model(inputs_embeds=answer_embeds, attention_mask=attention_mask, past_key_values=outputs['past_key_values'], use_cache=True, num_logits_to_keep=1)

        # 3. Greedy decoding, one token per step, until EOS or MAX_TOKENS tokens.
        response = []
        for _ in range(MAX_TOKENS):
            next_token_id = outputs['logits'].argmax(2)
            # Single .item() call: each one forces a device-to-host synchronisation.
            token = next_token_id.item()
            response.append(token)
            if token == eos_token_id:
                break

            # new_ones inherits the mask's dtype and device. A default-dtype
            # torch.ones would silently promote the integer mask to float32.
            attention_mask = torch.cat([attention_mask, attention_mask.new_ones((1, 1))], dim=1)

            with torch.autocast(device), torch.no_grad():
                outputs = model(input_ids=next_token_id, attention_mask=attention_mask, past_key_values=outputs['past_key_values'], use_cache=True)

        generated_answer = processor.decode(response, skip_special_tokens=True)
        predictions[data_id] = generated_answer
        print(repr(data_id))
        print(repr(generated_answer))

In [None]:
# Serialize the id -> answer mapping with 4-space indentation, then write it out.
submission_text = json.dumps(predictions, indent=4)
with open('submission.json', 'w') as out_file:
    out_file.write(submission_text)