In [None]:
# Per-split batch sizes passed to the corresponding DataLoader.
TRAIN_BATCH_SIZE = 1
VAL_BATCH_SIZE = 1
TEST_BATCH_SIZE = 1
# The loss is divided by this value and the optimizer steps once every this many batches.
NUM_ACCUMULATION_STEPS = 1

# When True, all three splits are replaced by small DummyDataset instances.
ENABLE_DUMMY_DATASET = False
# Worker process count shared by every DataLoader.
NUM_WORKERS = 8
# Training inputs are cut to this many tokens to avoid OOM; the same value is also
# passed as max_new_tokens when generating test answers.
MAX_TOKENS = 1300

# Number of passes over the training split.
EPOCHS = 1
# Used as max_lr of the OneCycleLR schedule.
LR = 5e-5

# LoRA rank (r) and scaling factor (lora_alpha) passed to LoraConfig.
LORA_RANK = 64
LORA_ALPHA = 128

# Adapters are saved to models/{SAVE_MODEL_PREFIX}_ep{epoch}.
# NOTE(review): the name looks like it encodes LORA_RANK and LR by hand — keep in sync when tuning.
SAVE_MODEL_PREFIX = 'lora_r64_5e-5'

In [None]:
import json
import os
import seaborn as sns
import matplotlib.pyplot as plt
from IPython.display import display
from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import datasets
from transformers import AutoProcessor, LlavaForConditionalGeneration, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, PeftModel

from CodaDatasets import CodaDataset, DummyDataset

In [None]:
# Disable parallelism inside the Hugging Face tokenizers library
# (presumably to avoid fork warnings with the multi-worker DataLoaders — TODO confirm).
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
# Processor outputs are moved to this device and autocast is opened on it.
device = 'cuda'
# TensorBoard writer created with default arguments; the training loop logs 'lr' and 'loss' to it.
writer = SummaryWriter()

In [None]:
# Base checkpoint shared by the model and its processor.
model_id = 'llava-hf/llava-1.5-7b-hf'

# Load the weights 4-bit quantized, computing in fp16.
quantization = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    quantization_config=quantization,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
processor = AutoProcessor.from_pretrained(model_id)

# Conversation templates: prompt only (inference / prompt-length measurement)
# and prompt + answer + end-of-sequence marker (training target).
prompt_template = 'USER: {} ASSISTANT:'
full_template = 'USER: {} ASSISTANT: {}</s>'

In [None]:
def find_all_linear_names(model):
    """Collect the names of the linear layers that should receive LoRA adapters.

    Walks ``model.named_modules()`` and keeps every ``torch.nn.Linear``
    instance (subclasses included) whose qualified name does not contain
    ``vision_tower``, ``multi_modal_projector`` or ``lm_head``.

    Args:
        model: any ``torch.nn.Module``.

    Returns:
        list[str]: the qualified module names, in no particular order.
    """
    skipped_fragments = ('vision_tower', 'multi_modal_projector', 'lm_head')

    def _is_target(qualified_name, layer):
        # Only linear layers outside the excluded sub-trees qualify.
        if not isinstance(layer, torch.nn.Linear):
            return False
        return not any(fragment in qualified_name for fragment in skipped_fragments)

    return list({
        qualified_name
        for qualified_name, layer in model.named_modules()
        if _is_target(qualified_name, layer)
    })

In [None]:
# Adapt every linear layer selected by find_all_linear_names
# (vision tower, projector and LM head are excluded there).
lora_target_modules = find_all_linear_names(model)
lora_config = LoraConfig(
    task_type='CAUSAL_LM',
    target_modules=lora_target_modules,
    r=LORA_RANK,
    lora_alpha=LORA_ALPHA,
    lora_dropout=0.05,
    bias='none',
)
# NOTE: `model` is rebound to the PEFT wrapper, so re-running this cell on the
# same kernel wraps an already-wrapped model.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

In [None]:
# Bare expression: shows the model's repr (its module tree) as the cell output.
model

In [None]:
# Load each split of the Hugging Face dataset, keyed by split name.
hf_dataset = {
    split_name: datasets.load_dataset('ntudlcv/dlcv_2024_final1', split=split_name)
    for split_name in ('train', 'val', 'test')
}

In [None]:
# Shrink validation to a fixed, seeded random sample of 100 examples.
val_shuffled = hf_dataset['val'].shuffle(seed=1234)
hf_dataset['val'] = val_shuffled.select(range(100))

In [None]:
# Wrap every split in CodaDataset; only the test split is built without answers.
dataset = {
    split_name: CodaDataset(hf_dataset[split_name], has_answer=(split_name != 'test'))
    for split_name in ('train', 'val', 'test')
}

In [None]:
# Optional smoke-test mode: replace all three splits with small dummy datasets.
if ENABLE_DUMMY_DATASET:
    dataset.update(
        train=DummyDataset(50, has_answer=True),
        val=DummyDataset(2, has_answer=True),
        test=DummyDataset(2, has_answer=False),
    )

In [None]:
def custom_collate_fn(batch):
    """Transpose a list of per-sample tuples into a tuple of per-field tuples.

    ``[(id0, img0, q0), (id1, img1, q1)]`` becomes
    ``((id0, id1), (img0, img1), (q0, q1))``. No tensor stacking is done, so
    fields may hold arbitrary objects such as PIL images and strings.

    The result is materialised as a tuple rather than returned as a lazy
    ``zip`` iterator: a batch can then be unpacked, indexed, inspected or
    iterated more than once, while callers that unpack it or re-transpose it
    with ``zip(*batch)`` behave exactly as before.

    Args:
        batch: list of equally long per-sample tuples produced by the dataset.

    Returns:
        tuple[tuple, ...]: one tuple per field; ``()`` for an empty batch.
    """
    return tuple(zip(*batch))

In [None]:
# One DataLoader per split; only the training split is shuffled.
batch_size_by_split = {
    'train': TRAIN_BATCH_SIZE,
    'val': VAL_BATCH_SIZE,
    'test': TEST_BATCH_SIZE,
}
dataloader = {
    split_name: DataLoader(
        dataset[split_name],
        batch_size=split_batch_size,
        shuffle=(split_name == 'train'),
        num_workers=NUM_WORKERS,
        collate_fn=custom_collate_fn,
    )
    for split_name, split_batch_size in batch_size_by_split.items()
}

In [None]:
# Pull one training batch and look at its first sample.
batch = next(iter(dataloader['train']))
# The collate function yields per-field groups; transpose back to per-sample tuples.
samples = list(zip(*batch))
data_id, question_type, image, question, answer = samples[0]

print(f'data_id: {data_id!r}')
print(f'question_type: {question_type!r}')
display(image.resize((200, 200)))
print(f'question: {question!r}')
print(f'answer: {answer!r}')

In [None]:
# Sanity check: tokenize a prompt with and without an answer to see how the
# prompt tokens line up as a prefix of the full sequence.
prompt_ids = processor.tokenizer('USER: question. ASSISTANT:')
full_ids = processor.tokenizer('USER: question. ASSISTANT: answer.</s>')
# Index of the last prompt token.
colon_idx = len(prompt_ids['input_ids']) - 1

print(prompt_ids, full_ids, colon_idx, sep='\n')

In [None]:
# Token count of every validation example formatted as prompt + answer.
full_lens = [
    processor(
        images=sample_image,
        text=full_template.format(sample_question, sample_answer),
        return_tensors='pt',
    )['input_ids'].shape[1]
    for _, _, sample_image, sample_question, sample_answer in tqdm(dataset['val'])
]

print(max(full_lens))
# Empirical CDF of the sequence lengths.
fig, ax = plt.subplots(figsize=(2, 2))
sns.ecdfplot(full_lens, ax=ax)
plt.show()

In [None]:
# Gather the distinct question types and question texts of the validation split.
all_question_types, all_questions = set(), set()

for _, sample_type, _, sample_question, _ in tqdm(dataset['val']):
    all_question_types.add(sample_type)
    all_questions.add(sample_question)

# The longest question is reused by the training cell's first-step stress test.
longest_question = max(all_questions, key=len)

print(f'{all_question_types!r}')
print(f'{longest_question!r}')

In [None]:
# List the name of every parameter that will receive gradients.
trainable_names = (
    param_name
    for param_name, param in model.named_parameters()
    if param.requires_grad
)
for trainable_name in trainable_names:
    print(trainable_name)

In [None]:
def _build_labels(full_ids, full_mask, prompt_mask):
    """Return next-token labels that supervise only the answer tokens.

    Position ``i`` is labelled with token ``i + 1``. A label is set to -100
    (the default ``ignore_index`` of ``F.cross_entropy``) when its target is
    padding or still belongs to the prompt. Real token counts are read from
    the attention masks instead of the padded tensor width, so the labels are
    also correct for padded batches (batch size > 1) with either padding side.

    Args:
        full_ids: (batch, seq) token ids of prompt + answer.
        full_mask: (batch, seq) attention mask matching ``full_ids``.
        prompt_mask: (batch, prompt_seq) attention mask of the prompt-only encoding.

    Returns:
        (batch, seq) integer tensor of labels.
    """
    n_rows = full_ids.size(0)
    ignored = torch.full((n_rows, 1), -100, dtype=full_ids.dtype, device=full_ids.device)
    labels = torch.cat([full_ids[:, 1:], ignored], dim=1)

    # A padded target must never contribute to the loss.
    target_is_real = torch.cat([full_mask[:, 1:], torch.zeros_like(full_mask[:, :1])], dim=1).bool()
    labels[~target_is_real] = -100

    prompt_lens = prompt_mask.sum(dim=1)
    # Index of the first real token: 0 with right padding, the pad count with left padding.
    first_real = full_mask.long().argmax(dim=1)
    for b in range(n_rows):
        # The last prompt token is the first position whose target is an answer
        # token; everything before it is prompt and is ignored.
        colon_idx = int(first_real[b] + prompt_lens[b]) - 1
        labels[b, :colon_idx] = -100
    return labels


# Only parameters that require gradients (the LoRA adapters) are optimized.
optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad])
batches_per_epoch = len(dataloader['train'])
steps_per_epoch = batches_per_epoch // NUM_ACCUMULATION_STEPS
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=LR,
    epochs=EPOCHS,
    steps_per_epoch=steps_per_epoch
)
scaler = torch.GradScaler()

for epoch in range(1, EPOCHS+1):
    print(f'=== epoch {epoch} ===')

    for phase in ['train', 'val']:
        is_train = phase == 'train'
        # The mode is constant within a phase, so set it once instead of every step.
        if is_train:
            model.train()
            # Drop gradients left over from an incomplete accumulation window.
            optimizer.zero_grad()
        else:
            model.eval()

        loss_total = 0.0
        num_scored = 0  # batches that actually contributed to loss_total

        pbar = tqdm(dataloader[phase])
        for step, (data_ids, question_types, images, questions, answers) in enumerate(pbar):
            batch_size = len(data_ids)

            # stress test: the first batch of every phase is replaced by a
            # worst-case input (longest question, very long answer) so that an
            # out-of-memory failure surfaces immediately. This batch is only run
            # forward; it is excluded from the loss statistics and from updates.
            if step == 0:
                questions = [longest_question for _ in range(batch_size)]
                answers = ['answer ' * 1000 for _ in range(batch_size)]

            # The prompt-only encoding is needed just to know how many tokens to mask.
            prompts = [prompt_template.format(q) for q in questions]
            prompt_inputs = processor(images=images, text=prompts, padding=True, return_tensors='pt').to(device)

            fulls = [full_template.format(q, a) for q, a in zip(questions, answers)]
            full_inputs = processor(images=images, text=fulls, padding=True, return_tensors='pt').to(device)

            # truncate to avoid OOM
            full_inputs['input_ids'] = full_inputs['input_ids'][:, :MAX_TOKENS]
            full_inputs['attention_mask'] = full_inputs['attention_mask'][:, :MAX_TOKENS]

            labels = _build_labels(
                full_inputs['input_ids'],
                full_inputs['attention_mask'],
                prompt_inputs['attention_mask'],
            )

            with torch.autocast(device), torch.set_grad_enabled(is_train):
                outputs = model(**full_inputs)

                logits = outputs['logits']
                loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
                loss_before_scaling = loss.item()
                # Scale so that accumulated gradients average over the window.
                loss /= NUM_ACCUMULATION_STEPS

            if step != 0:
                pbar.set_postfix_str(f'loss={loss_before_scaling:.3f}')
                loss_total += loss_before_scaling
                num_scored += 1

                if is_train:
                    # `step` counts batches, so the epoch offset must be in batches
                    # too; otherwise steps collide when NUM_ACCUMULATION_STEPS > 1.
                    global_step = (epoch - 1) * batches_per_epoch + step
                    writer.add_scalar('lr', optimizer.param_groups[0]['lr'], global_step)
                    writer.add_scalar('loss', loss_before_scaling, global_step)
                    writer.flush()

                    scaler.scale(loss).backward()
                    if (step + 1) % NUM_ACCUMULATION_STEPS == 0:
                        scaler.step(optimizer)
                        scaler.update()
                        lr_scheduler.step()
                        optimizer.zero_grad()

        # Average over the batches that were scored; a loader with a single
        # batch has none (only the stress-test step), so report NaN, not crash.
        loss_avg = loss_total / num_scored if num_scored else float('nan')
        print(f'{phase} loss: {loss_avg}')

    # Save the adapter after every epoch.
    os.makedirs('models', exist_ok=True)
    model.save_pretrained(f'models/{SAVE_MODEL_PREFIX}_ep{epoch}')

In [None]:
# Greedy decoding over the test split; answers are collected by data_id.
# Switch to eval mode explicitly so that LoRA dropout is off even when this cell
# is run without the training cell having just finished its validation phase.
model.eval()

predictions = {}
for data_ids, question_types, images, questions in tqdm(dataloader['test']):
    prompts = [prompt_template.format(q) for q in questions]
    inputs = processor(images=images, text=prompts, padding=True, return_tensors='pt').to(device)
    outputs = model.generate(**inputs, max_new_tokens=MAX_TOKENS, do_sample=False)

    # generate() returns the prompt tokens followed by the new tokens. Slicing
    # the prompt off is robust where splitting the decoded text on
    # 'ASSISTANT: ' is not: that raised IndexError for an empty generation and
    # cut off any answer that itself contained the marker.
    prompt_len = inputs['input_ids'].shape[1]
    for data_id, output in zip(data_ids, outputs):
        generated_answer = processor.decode(output[prompt_len:], skip_special_tokens=True).strip()
        predictions[data_id] = generated_answer
        print(repr(data_id))
        print(repr(generated_answer))

In [None]:
# Write the {data_id: answer} mapping as pretty-printed JSON.
with open('submission.json', 'w') as submission_file:
    submission_file.write(json.dumps(predictions, indent=4))