In [None]:
# --- Batching ---------------------------------------------------------------
TRAIN_BATCH_SIZE = 1
VAL_BATCH_SIZE = 1
TEST_BATCH_SIZE = 1
# Micro-batches whose gradients are accumulated before each optimizer step.
NUM_ACCUMULATION_STEPS = 1

# When True, the real datasets are replaced by small DummyDataset instances (smoke test).
ENABLE_DUMMY_DATASET = False
# DataLoader worker processes per split.
NUM_WORKERS = 8
# Training: answer tokens kept per sample (longer answers are truncated).
# Inference: maximum number of tokens generated per test answer.
MAX_TOKENS = 600

# --- Optimization -------------------------------------------------------------
EPOCHS = 1
# Peak learning rate of the OneCycleLR schedule.
LR = 1e-5

# --- Cross-attention (DecoderWithCrossAttention) ------------------------------
# Query dimension passed to every wrapped decoder layer; presumably the hidden
# size of the 7B LLaMA backbone (4096) -- TODO confirm.
CROSS_ATTN_Q_DIM = 4096

# First branch; presumably attends to the extractor's 'patch_tokens'
# (KV_DIM_1 assumed to be their feature size -- TODO confirm).
CROSS_ATTN_EMBED_DIM_1 = 512
CROSS_ATTN_NUM_HEADS_1 = 8
CROSS_ATTN_KV_DIM_1 = 916

# Second branch; presumably attends to the extractor's 'instance_tokens'
# (KV_DIM_2 assumed to be their feature size -- TODO confirm).
CROSS_ATTN_EMBED_DIM_2 = 128
CROSS_ATTN_NUM_HEADS_2 = 2
CROSS_ATTN_KV_DIM_2 = 21

# Checkpoints are written to models/{SAVE_MODEL_PREFIX}_ep{epoch}_model.pt.
SAVE_MODEL_PREFIX = 'lora_crossattn_1e-5'

In [None]:
import json
import os
import seaborn as sns
import matplotlib.pyplot as plt
from IPython.display import display
from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import datasets
from transformers import AutoProcessor, LlavaForConditionalGeneration, BitsAndBytesConfig
from peft import PeftModel

from CodaDatasets import CodaDataset, DummyDataset
from CodaFeatureExtractor import CodaFeatureExtractor
from CodaLayers import DecoderWithCrossAttention

In [None]:
# Turn off HF tokenizers' internal parallelism; the DataLoaders below fork
# NUM_WORKERS worker processes.
os.environ.update(TOKENIZERS_PARALLELISM='false')
device = 'cuda'
# TensorBoard writer (default log directory) used by the training loop.
writer = SummaryWriter()

In [None]:
model_id = 'llava-hf/llava-1.5-7b-hf'

# 4-bit bitsandbytes weights, fp16 compute.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    quantization_config=quantization_config,
)
processor = AutoProcessor.from_pretrained(model_id)

# Prompt templates. Answers end with '</s>' (presumably the tokenizer's EOS
# string, which the inference loop stops on).
question_template = 'USER: {}'
answer_template = 'ASSISTANT: {}</s>'

In [None]:
# Attach the LoRA adapter saved under models/ (presumably from an earlier LoRA
# fine-tuning run) to the quantized base model, then display the model.
# Alternative not used here: peft.PeftModel.from_pretrained(model, path) and then
# merging; note that merge_and_unload() RETURNS the merged model, so its result
# must be assigned (model = model.merge_and_unload()).
model.load_adapter('models/lora_r64_5e-5_ep1')
model

In [None]:
# Freeze every existing LLaVA / adapter weight; only layers added afterwards train.
for weight in model.parameters():
    weight.requires_grad_(False)

In [None]:
# Image feature extractor that supplies the cross-attention contexts: its
# `process_images` output provides 'patch_tokens', 'instance_tokens' and
# 'instance_attention_mask', which the training and inference loops consume.
extractor = CodaFeatureExtractor(device)

In [None]:
# The extractor is only run under torch.no_grad(); freeze all of its weights.
for weight in extractor.parameters():
    weight.requires_grad_(False)

In [None]:
# Wrap every decoder layer of the language model in DecoderWithCrossAttention,
# which adds two cross-attention branches (their contexts are set per batch in
# the training / inference loops). Layers that are already wrapped are skipped,
# so re-running this cell does not nest a second wrapper around the first.
decoder_layers = model.language_model.model.layers
for i in range(len(decoder_layers)):
    if isinstance(decoder_layers[i], DecoderWithCrossAttention):
        continue
    decoder_layers[i] = DecoderWithCrossAttention(
        decoder_layers[i],
        CROSS_ATTN_EMBED_DIM_1,
        CROSS_ATTN_EMBED_DIM_2,
        CROSS_ATTN_NUM_HEADS_1,
        CROSS_ATTN_NUM_HEADS_2,
        CROSS_ATTN_Q_DIM,
        CROSS_ATTN_KV_DIM_1,
        CROSS_ATTN_KV_DIM_2
    ).to(device)

In [None]:
# Load each split of the course dataset from the Hugging Face Hub.
hf_dataset = {
    split_name: datasets.load_dataset('ntudlcv/dlcv_2024_final1', split=split_name)
    for split_name in ('train', 'val', 'test')
}

In [None]:
# Validate on a fixed (seed=1234) random subset of at most 100 samples, presumably
# to keep each epoch's validation pass short; clamping to the split size avoids
# an out-of-range select() on smaller splits.
hf_dataset['val'] = hf_dataset['val'].shuffle(seed=1234).select(range(min(100, len(hf_dataset['val']))))

In [None]:
# Wrap each split; only the test split comes without reference answers.
dataset = {
    split_name: CodaDataset(hf_dataset[split_name], has_answer=(split_name != 'test'))
    for split_name in ('train', 'val', 'test')
}

In [None]:
if ENABLE_DUMMY_DATASET:
    # Swap in tiny DummyDataset splits for a quick end-to-end smoke test.
    dataset.update(
        train=DummyDataset(50, has_answer=True),
        val=DummyDataset(2, has_answer=True),
        test=DummyDataset(2, has_answer=False),
    )

In [None]:
def custom_collate_fn(batch):
    """Transpose a list of per-sample tuples into a tuple of per-field tuples.

    ``[(id0, type0, img0, q0, a0), (id1, type1, img1, q1, a1)]`` becomes
    ``((id0, id1), (type0, type1), (img0, img1), (q0, q1), (a0, a1))``.
    Values are passed through unchanged (no tensor stacking); the processor
    prepares them later in the loop.

    A concrete tuple is returned rather than a lazy ``zip`` iterator so the
    batch can be indexed, measured with ``len`` and iterated more than once.
    """
    return tuple(zip(*batch))

In [None]:
# One DataLoader per split; only the training split is shuffled.
dataloader = {
    split_name: DataLoader(
        dataset[split_name],
        batch_size=split_batch_size,
        shuffle=(split_name == 'train'),
        num_workers=NUM_WORKERS,
        collate_fn=custom_collate_fn,
    )
    for split_name, split_batch_size in (
        ('train', TRAIN_BATCH_SIZE),
        ('val', VAL_BATCH_SIZE),
        ('test', TEST_BATCH_SIZE),
    )
}

In [None]:
# Token length of each validation answer exactly as the training loop tokenizes
# it (wrapped in `answer_template`, no special tokens added), so the ECDF can be
# read directly against the MAX_TOKENS truncation limit.
answer_lens = []
for _, _, _, _, answer in tqdm(dataset['val']):
    answer_ids = processor.tokenizer(answer_template.format(answer), add_special_tokens=False)['input_ids']
    answer_lens.append(len(answer_ids))

fig, ax = plt.subplots(figsize=(4, 3))
sns.ecdfplot(answer_lens, ax=ax)
ax.axvline(MAX_TOKENS, color='red', linestyle='--', label=f'MAX_TOKENS={MAX_TOKENS}')
ax.set(title='Validation answer length', xlabel='tokens', ylabel='proportion')
ax.legend()
plt.show()

In [None]:
# Distinct question types / questions in the validation split. The longest
# question is reused by the training loop's step-0 stress test.
all_question_types = set()
all_questions = set()

for sample in tqdm(dataset['val']):
    _, sample_type, _, sample_question, _ = sample
    all_question_types.add(sample_type)
    all_questions.add(sample_question)
longest_question = max(all_questions, key=len)

print(repr(all_question_types))
print(repr(longest_question))

In [None]:
batch = next(iter(dataloader['train']))
# The collated batch is field-major (ids, types, images, questions, answers);
# transpose it back to samples and inspect the first one.
first_sample = tuple(zip(*batch))[0]
data_id, question_type, image, question, answer = first_sample

for label, value in (('data_id:', data_id), ('question_type:', question_type)):
    print(label, repr(value))
display(image.resize((200, 200)))
for label, value in (('question:', question), ('answer:', answer)):
    print(label, repr(value))

In [None]:
# How many tokens the "ASSISTANT:" prefix occupies, alone and before an answer;
# the training loop masks the labels for this prefix.
for probe_text in ('ASSISTANT:', 'ASSISTANT: Hello world.', 'ASSISTANT: Goodbye world.'):
    print(processor.tokenizer(probe_text, add_special_tokens=False))

In [None]:
# Names of all parameters that will receive gradients (presumably only what
# DecoderWithCrossAttention added -- verify in the output).
trainable_names = (name for name, param in model.named_parameters() if param.requires_grad)
for name in trainable_names:
    print(name)

In [None]:
def save_module(module, path):
    """Save only the trainable entries of ``module``'s state dict to ``path``.

    Parameters with ``requires_grad=False`` and all buffers are left out, so the
    checkpoint holds just the trainable weights; reload it with
    ``module.load_state_dict(..., strict=False)``.

    Args:
        module: a ``torch.nn.Module``.
        path: destination passed to ``torch.save``.
    """
    # A set keeps the per-key membership test O(1) instead of scanning a list.
    trainable_names = {name for name, param in module.named_parameters() if param.requires_grad}
    trainable_params = {k: v for k, v in module.state_dict().items() if k in trainable_names}
    torch.save(trainable_params, path)

In [None]:
# Optimize only parameters left trainable after the freezing cells; the frozen
# LLaVA / adapter weights never receive gradients anyway.
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.Adam(trainable_params)

batches_per_epoch = len(dataloader['train'])
# Optimizer/scheduler updates per epoch; at least 1 so OneCycleLR accepts it.
steps_per_epoch = max(1, batches_per_epoch // NUM_ACCUMULATION_STEPS)
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=LR,
    epochs=EPOCHS,
    steps_per_epoch=steps_per_epoch
)
scaler = torch.GradScaler()
decoder_layers = model.language_model.model.layers

for epoch in range(1, EPOCHS+1):
    print(f'=== epoch {epoch} ===')

    for phase in ['train', 'val']:
        loss_total = 0.0
        if phase == 'train':
            # Drop any partial gradient accumulation left from the previous epoch.
            optimizer.zero_grad()

        pbar = tqdm(dataloader[phase])
        for step, (data_ids, question_types, images, questions, answers) in enumerate(pbar):
            batch_size = len(data_ids)
            # TensorBoard x-axis: one tick per batch, continuous across epochs.
            global_step = (epoch - 1) * batches_per_epoch + step

            # Stress test: the first batch of each phase uses the longest question
            # and an over-long answer so an OOM surfaces immediately. Its loss is
            # neither logged nor back-propagated (see `step != 0` below).
            if step == 0:
                questions = [longest_question for _ in range(batch_size)]
                answers = ['answer ' * 1000 for _ in range(batch_size)]

            # 1. Encode image + question with cross-attention disabled; keep the KV cache.
            for layer in decoder_layers:
                layer.enable_cross_attn = False

            questions = [question_template.format(q) for q in questions]
            question_inputs = processor(images=images, text=questions, padding=True, return_tensors='pt').to(device)

            model.eval()
            with torch.autocast(device), torch.no_grad():
                question_outputs = model(**question_inputs, use_cache=True)

            # 2. Enable cross-attention on this batch's image features.
            with torch.autocast(device), torch.no_grad():
                features = extractor.process_images(images)
            for layer in decoder_layers:
                layer.enable_cross_attn = True
                layer.cross_attn_context_1 = features['patch_tokens']
                layer.cross_attn_context_2 = features['instance_tokens']
                layer.cross_attn_mask_1 = None
                layer.cross_attn_mask_2 = features['instance_attention_mask']

            answers = [answer_template.format(a) for a in answers]
            answer_inputs = processor(text=answers, padding=True, add_special_tokens=False, return_tensors='pt').to(device)

            # truncate to avoid OOM
            answer_inputs['input_ids'] = answer_inputs['input_ids'][:, :MAX_TOKENS]
            answer_inputs['attention_mask'] = answer_inputs['attention_mask'][:, :MAX_TOKENS]

            answer_embeds = model.get_input_embeddings()(answer_inputs['input_ids'])
            attention_mask = torch.cat([question_inputs['attention_mask'], answer_inputs['attention_mask']], dim=1)

            # Position t is trained to predict answer token t+1; the last position has no target.
            ignored = torch.full((batch_size, 1), -100, device=device)
            labels = torch.cat([answer_inputs['input_ids'][:, 1:], ignored], dim=1)
            # Padding tokens are never targets (matters once batch_size > 1).
            next_is_padding = answer_inputs['attention_mask'][:, 1:] == 0
            labels[:, :-1] = labels[:, :-1].masked_fill(next_is_padding, -100)
            # "ASSISTANT:" need not be trained.
            # NOTE(review): assumes right padding and that "ASSISTANT:" fills answer
            # positions 0-4 -- confirm with the tokenizer-check cell.
            labels[:, :4] = -100

            with torch.autocast(device), torch.set_grad_enabled(phase == 'train'):
                if phase == 'train':
                    model.train()
                else:
                    model.eval()
                outputs = model(inputs_embeds=answer_embeds, attention_mask=attention_mask, past_key_values=question_outputs['past_key_values'])

                logits = outputs['logits']
                loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
                loss_before_scaling = loss.item()
                loss = loss / NUM_ACCUMULATION_STEPS

            if step != 0:
                pbar.set_postfix_str(f'loss={loss_before_scaling:.3f}')
                loss_total += loss_before_scaling

                if phase == 'train':
                    writer.add_scalar('lr', optimizer.param_groups[0]['lr'], global_step)
                    writer.add_scalar('loss', loss_before_scaling, global_step)
                    writer.add_scalar('gate', decoder_layers[-1].cross_attn_layer_1.gate.item(), global_step)
                    writer.add_scalar('gate_2', decoder_layers[-1].cross_attn_layer_2.gate.item(), global_step)
                    writer.flush()

                    scaler.scale(loss).backward()
                    if (step + 1) % NUM_ACCUMULATION_STEPS == 0:
                        scaler.step(optimizer)
                        scaler.update()
                        lr_scheduler.step()
                        optimizer.zero_grad()

        # The stress-test batch (step 0) is excluded from the average.
        num_scored = len(dataloader[phase]) - 1
        loss_avg = loss_total / num_scored if num_scored > 0 else float('nan')
        print(f'{phase} loss: {loss_avg}')

    os.makedirs('models', exist_ok=True)
    save_module(model, f'models/{SAVE_MODEL_PREFIX}_ep{epoch}_model.pt')

In [None]:
%%script echo skipped
model.load_state_dict(torch.load('models/test_ep1_model.pt', weights_only=False), strict=False);

In [None]:
# Greedy decoding over the test split. Per sample:
#   1. encode image + question with cross-attention off and keep the KV cache,
#   2. enable cross-attention on this image's features and feed "ASSISTANT:",
#   3. generate token by token until EOS or MAX_TOKENS tokens.
predictions = {}
decoder_layers = model.language_model.model.layers
# Nothing below switches back to train mode, so eval mode is set once.
model.eval()

for test_batch in tqdm(dataloader['test']):
    for data_id, _question_type, image, question in zip(*test_batch):
        # 1. disable cross attention
        for layer in decoder_layers:
            layer.enable_cross_attn = False

        prompt = question_template.format(question)
        question_inputs = processor(images=image, text=prompt, return_tensors='pt').to(device)
        with torch.autocast(device), torch.no_grad():
            question_outputs = model(**question_inputs, use_cache=True)

        # 2. enable cross attention
        with torch.autocast(device), torch.no_grad():
            features = extractor.process_images([image])
        for layer in decoder_layers:
            layer.enable_cross_attn = True
            layer.cross_attn_context_1 = features['patch_tokens']
            layer.cross_attn_context_2 = features['instance_tokens']
            layer.cross_attn_mask_1 = None
            layer.cross_attn_mask_2 = features['instance_attention_mask']

        answer_inputs = processor(text='ASSISTANT:', add_special_tokens=False, return_tensors='pt').to(device)
        answer_embeds = model.get_input_embeddings()(answer_inputs['input_ids'])
        attention_mask = torch.cat([question_inputs['attention_mask'], answer_inputs['attention_mask']], dim=1)
        with torch.autocast(device), torch.no_grad():
            outputs = model(inputs_embeds=answer_embeds, attention_mask=attention_mask, past_key_values=question_outputs['past_key_values'], use_cache=True, num_logits_to_keep=1)

        # 3. generate (greedy)
        response = []
        for _ in range(MAX_TOKENS):
            next_token_id = outputs['logits'].argmax(2)
            token = next_token_id.item()
            response.append(token)
            if token == processor.tokenizer.eos_token_id:
                break

            # Append a 1 of the mask's own dtype: the processor's mask is an integer
            # tensor, while torch.ones(...) defaults to float32 and would promote it.
            attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, :1])], dim=1)
            with torch.autocast(device), torch.no_grad():
                outputs = model(input_ids=next_token_id, attention_mask=attention_mask, past_key_values=outputs['past_key_values'], use_cache=True)

        generated_answer = processor.decode(response, skip_special_tokens=True)
        predictions[data_id] = generated_answer
        print(repr(data_id))
        print(repr(generated_answer))

In [None]:
# Persist the generated answers (keyed by data_id) as the submission file.
with open('submission.json', 'w') as submission_file:
    json.dump(predictions, submission_file, indent=4)