In [None]:
import os
import sys
import argparse
import logging
import datetime
import json
import time
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from textwrap import wrap

from transformers import TrainingArguments, Trainer

import torch
from torch.utils.data import DataLoader
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

from dataset_config import COCO_dataset_config
from utils import get_device_map

%matplotlib inline

# Turn off the HuggingFace tokenizers' own thread pool for this process.
# NOTE(review): presumably done to avoid the fork-related warning when data-loading
# worker processes are used — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [None]:
# CUDA indices the model is spread over; the first one doubles as the entry device.
devices = [1, 5, 6, 7]
start_device = f"cuda:{devices[0]}"

### Logging

In [None]:
# Each run writes into its own timestamped folder below ../results
# (folder name format: YYYYmmdd-HHMMSS).
results_dir = '../results'
result_dirname = f"{datetime.datetime.now():%Y%m%d-%H%M%S}"
result_dir_fullpath = os.path.join(results_dir, result_dirname)
os.makedirs(result_dir_fullpath, exist_ok=True)

In [None]:
# Send every log record both to <result dir>/train.log and to the notebook output.
log_file_path = os.path.join(result_dir_fullpath, 'train.log')
logging.basicConfig(
    format="%(asctime)s - %(levelname)s: %(message)s",
    level=logging.INFO,
    # 24-hour clock: the previous "%I" (12-hour, no AM/PM marker) produced ambiguous
    # timestamps as soon as a run crossed noon or midnight.
    datefmt="%H:%M:%S",
    handlers=[
        logging.FileHandler(log_file_path),
        logging.StreamHandler(sys.stdout),
    ])

In [None]:
# TensorBoard event files live in ../runs/<timestamp>, reusing the run's folder name.
tensorboard_dir = os.path.join('../runs', result_dirname)
writer = SummaryWriter(tensorboard_dir)

### Configuration

In [None]:
# Reference for the BLIP-2 fine-tuning settings:
# https://github.com/salesforce/LAVIS/blob/main/lavis/projects/blip2/train/caption_coco_ft.yaml

# --- model and files ---
checkpoint = "Salesforce/blip2-flan-t5-xl"
model_name = checkpoint.split("/")[1]  # repository name without the organisation prefix
cache_dir = "../pretrained_files"  # alternative location: "/mnt/nas2/kjh/huggingface_cache"
cfg_path = "../configs/caption_coco_ft.yaml"
dtype = torch.float16

# --- data loading ---
batch_size = 16
num_workers = 4

# --- captioning ---
prompt = "a photo of "
max_length = 50

# --- schedule ---
epochs = 5

In [None]:
from omegaconf import OmegaConf, DictConfig

# Read the YAML training configuration located at `cfg_path`.
config = OmegaConf.load(cfg_path)

# Echo the loaded configuration into the cell output for a quick visual check.
print(config)

def dict_to_str_recursive(input_dict, depth=0):
    """Render a (possibly nested) mapping as indented ``key: value`` text.

    Every nesting level is indented with four ``&nbsp;`` entities and every line ends
    with two spaces plus a newline (presumably so the text renders with hard line
    breaks when shown as Markdown in TensorBoard).

    Args:
        input_dict: Mapping to render. Values that are themselves mappings (plain
            dicts, dict subclasses, or any other ``collections.abc.Mapping``) are
            rendered recursively; everything else is converted with ``str``.
        depth: Current nesting level; controls the indentation prefix.

    Returns:
        The rendered text; an empty string for an empty mapping.
    """
    # Imported locally so the function does not depend on another notebook cell.
    from collections.abc import Mapping

    indent_str = '&nbsp;&nbsp;&nbsp;&nbsp;' * depth
    pieces = []
    for key in input_dict:
        value = input_dict[key]
        # isinstance instead of an exact type comparison: dict subclasses and other
        # mapping types are now recursed into rather than flattened through str().
        if isinstance(value, Mapping):
            nested_str = dict_to_str_recursive(value, depth + 1)
            pieces.append(indent_str + str(key) + ':  \n' + nested_str + '  \n')
        else:
            pieces.append(indent_str + str(key) + ': ' + str(value) + '  \n')
    # Join once at the end instead of growing a string inside the loop.
    return ''.join(pieces)

# Store the rendered configuration as a text entry under the 'configs' tag so the
# settings used for this run are kept alongside its TensorBoard logs.
config_str = dict_to_str_recursive(config)
writer.add_text('configs', config_str)

### Processor
##### image-processor + tokenizer

In [None]:
from transformers import Blip2Processor

# Load the processor that belongs to `checkpoint`; `cache_dir` is passed through as the
# location for cached pretrained files.
processor = Blip2Processor.from_pretrained(
    checkpoint,
    cache_dir=cache_dir,
)

### Dataset

In [None]:
from datasets import load_dataset
# Images are loaded from the dataset folder, captions/metadata from the CSV next to it.
# NOTE(review): the CSV columns are attached positionally, which assumes the CSV rows
# are in the same order as the images — confirm.
train_ds = load_dataset('../datasets/cvpr-nice-val/', split='validation')
caption_ds = load_dataset('../datasets/cvpr-nice-val', data_files={'caption': 'nice-val-5k.csv'}, split='caption')
for feature in caption_ds.features:
    print(feature)
    train_ds = train_ds.add_column(name=feature, column=caption_ds[feature])

# The image column keeps its original name "image". A former
# `train_ds.rename_column("image", "images")` call was removed: `rename_column` returns
# a new dataset and the result was discarded, so it never had any effect, and the
# preprocessing below reads and removes the "image" column.

In [None]:
# Show the dataset summary after the caption columns have been attached.
print(train_ds)

In [None]:
# Tokenize the fixed prompt, padded to `max_length`.
# In the cells visible here it is only referenced from a commented-out line in `transforms`.
prompt_tokens = processor.tokenizer(
    prompt, padding='max_length', max_length=max_length
)

In [None]:
def transforms(input_batch, prefix=None):
    """Convert one raw dataset row into model inputs plus caption labels.

    The module-level ``processor`` is applied to the row's image together with the
    fixed ``prompt``; the ground-truth caption is tokenized separately and stored as
    the training target. Written for un-batched ``Dataset.map`` (one row per call).

    Note: this function's name shadows the ``torchvision.transforms`` alias imported
    at the top of the notebook; it is kept because the mapping cell refers to it.

    Args:
        input_batch: Row mapping providing at least ``'image'`` and ``'caption_gt'``.
        prefix: Optional string placed in front of the caption before tokenization.

    Returns:
        The processor output with the size-1 batch dimension removed from
        ``pixel_values``, ``input_ids`` and ``attention_mask``, plus ``'labels'``
        holding the caption token ids padded/truncated to ``max_length``.
    """
    # Work on a local copy instead of writing the prefixed caption back into the row.
    caption = input_batch['caption_gt']
    if prefix is not None:
        caption = prefix + caption

    # batch = processor(images=batch['image'], text=batch['caption_gt'], padding="max_length", max_length=max_length, return_tensors='pt')
    batch = processor(
        images=input_batch['image'],
        text=prompt,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors='pt',
    )
    # With return_tensors='pt' a single example comes back with a leading batch
    # dimension of size 1. Drop it from every tensor (not only pixel_values) so that
    # collating rows later yields (batch, ...) rather than (batch, 1, ...).
    for key in ('pixel_values', 'input_ids', 'attention_mask'):
        if key in batch:
            batch[key] = batch[key].squeeze(0)

    # The tokenizer returns a mapping (input_ids, attention_mask); only the ids are the
    # labels. truncation=True keeps every row at exactly max_length tokens.
    caption_tokens = processor.tokenizer(
        caption, padding='max_length', max_length=max_length, truncation=True
    )
    batch['labels'] = caption_tokens['input_ids']
    # NOTE(review): padding positions keep the pad token id here, so they count towards
    # the loss; consider masking them (-100) in the collator — confirm desired behavior.
    # batch.update({'decoder_input_ids': prompt_tokens['input_ids']})
    return batch

In [None]:
# Un-batched map on purpose: with batching enabled the pixel_values came out wrong for
# some reason, and the conversion also got slower (possibly because the images are
# carried along during the conversion).
raw_columns = ['public_id', 'caption_gt', 'image', 'category']
train_ds = train_ds.map(transforms, remove_columns=raw_columns)

In [None]:
# Show the dataset summary again, now after preprocessing and column removal.
print(train_ds)

### Plot Images

In [None]:
def denormalize_image(normalized_image, mean, std):
    """Undo per-channel normalization and return a displayable image.

    Args:
        normalized_image: Array whose axes are reordered with ``transpose(1, 2, 0)``,
            i.e. the first axis is moved to the end before de-normalizing.
        mean: Per-channel mean, broadcast over the last axis.
        std: Per-channel standard deviation, broadcast over the last axis.

    Returns:
        ``std * image + mean`` on the reordered array, clipped to ``[0, 1]``.
    """
    channels_last = normalized_image.transpose(1, 2, 0)
    restored = channels_last * std + mean
    return np.clip(restored, 0, 1)

def plot_images(images, captions):
    """Draw the images side by side in one row, each titled with its caption.

    Args:
        images: Sequence of images accepted by ``plt.imshow``.
        captions: Sequence of caption strings, indexed in parallel with ``images``.
            Captions are wrapped at 16 characters per line.
    """
    n_columns = len(images)
    plt.figure(figsize=(20, 20))
    for idx, image in enumerate(images):
        axis = plt.subplot(1, n_columns, idx + 1)
        wrapped_caption = "\n".join(wrap(captions[idx], 16))
        axis.set_title(wrapped_caption)
        plt.imshow(image)
        axis.axis("off")

# Preview the first few processed examples as a sanity check of images and captions.
num_samples = 5
# Never index past the end of the dataset.
num_samples = min(num_samples, len(train_ds))
# Previously range(5) was hard-coded here, so raising num_samples caused an IndexError
# in the loop below.
samples = [train_ds[i] for i in range(num_samples)]

# The same normalization statistics the image processor exposes are used to undo it.
image_mean = processor.image_processor.image_mean
image_std = processor.image_processor.image_std

sample_images = []
sample_captions = []
for sample in samples:
    # Convert to an ndarray first; denormalize_image relies on ndarray.transpose.
    sample_image = np.array(sample['pixel_values'])
    sample_images.append(denormalize_image(sample_image, image_mean, image_std))

    # NOTE(review): batch_decode is given the label ids of one example, so each id is
    # presumably decoded on its own and the pieces are joined with spaces — confirm.
    sample_caption = ' '.join(processor.batch_decode(sample['labels'], skip_special_tokens=True))
    sample_captions.append(sample_caption)

plot_images(sample_images, sample_captions)

### Model

In [None]:
from transformers import Blip2ForConditionalGeneration

# Ask the project helper for a device map for `checkpoint` over the GPUs in `devices`.
# NOTE(review): presumably a module-to-GPU placement for model parallelism — confirm
# against utils.get_device_map.
device_map = get_device_map(checkpoint, devices)

# Load the pretrained BLIP-2 weights with the configured `dtype` and device placement.
model = Blip2ForConditionalGeneration.from_pretrained(
    checkpoint,
    cache_dir=cache_dir,
    torch_dtype=dtype,
    device_map=device_map,
)

In [None]:
# Freeze
# All top-level sub-modules of the model that can be trained or frozen.
block_list = [
    model.vision_model,
    model.qformer,
    model.language_projection,
    model.language_model,
]

# Sub-modules whose weights stay fixed during fine-tuning; uncomment to freeze more.
freeze_list = [
    # model.vision_model,
    # model.qformer,
    # model.language_projection,
    model.language_model,
]

# Frozen blocks: no gradients, evaluation mode.
for freeze_block in freeze_list:
    freeze_block.requires_grad_(False)
    freeze_block.eval()

# Remaining blocks: gradients on, training mode. Frozen blocks are skipped entirely;
# previously `.train()` was applied to every block, which put the frozen ones back
# into training mode right after they had been switched to eval().
# NOTE(review): Trainer.train() is expected to put the whole model into train mode
# again — confirm if eval mode of the frozen blocks must hold during training.
for block in block_list:
    if any(block is frozen for frozen in freeze_list):
        continue
    block.requires_grad_(True)
    block.train()

### Train

In [None]:
# Single source for the learning rate. The value matches what was effectively used
# before: set_optimizer() overwrote the constructor's 1e-5 with 1e-6.
learning_rate = 1e-6

training_args = TrainingArguments(
    output_dir=f"../training_outputs/{model_name}",
    learning_rate=learning_rate,
    num_train_epochs=epochs,
    fp16=dtype is torch.float16,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    save_total_limit=3,
    # Checkpoints are written once per epoch; the former save_steps=5 only applies to
    # save_strategy="steps" and was therefore dropped.
    save_strategy="epoch",
    logging_steps=1,
    # Keep every dataset column; the preprocessing already removed the raw ones.
    remove_unused_columns=False,
    push_to_hub=False,
    # NOTE(review): unusual value — confirm that 'input_ids' (not 'labels') is intended.
    label_names=['input_ids'],
)
# set_lr_scheduler() also assigns num_train_epochs (default 3.0), so the epoch count
# has to be passed again or the `epochs` setting above is silently replaced.
training_args.set_lr_scheduler(name='linear', num_epochs=epochs, warmup_steps=1000)
# set_optimizer() also assigns learning_rate, hence the shared variable.
training_args.set_optimizer(name='adamw_hf', learning_rate=learning_rate, weight_decay=0.05)

In [None]:
def compute_metrics(output):
    """Compute a shifted cross-entropy loss from logits and token ids.

    Position ``t`` of the logits is scored against token ``t + 1`` of the ids.

    Args:
        output: Object whose ``loss`` attribute supports ``.get('logits')`` and
            ``.get('input_ids')``, returning tensors.
            NOTE(review): assumes logits are (..., seq, vocab) and ids are (..., seq)
            — confirm against the caller.

    Returns:
        ``{'loss': float}`` with the mean cross-entropy over all shifted positions.
    """
    logits = output.loss.get('logits')
    input_ids = output.loss.get('input_ids')
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = input_ids[..., 1:].contiguous().to(logits.device)

    # Replaces the former unconditional debug prints; only NaN logits are reported
    # (the check on the integer labels could never be true).
    if torch.isnan(logits).any():
        logging.warning("compute_metrics: logits contain NaN values")

    # Take the vocabulary size from the logits instead of hard-coding 32128.
    vocab_size = shift_logits.size(-1)
    loss_fn = torch.nn.CrossEntropyLoss()
    loss = loss_fn(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
    return {
        'loss': loss.item(),
    }

In [None]:
# Training-only setup: no evaluation dataset is passed and the custom compute_metrics
# hook is left disabled.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=None,
    # compute_metrics=compute_metrics,
)

In [None]:
# Run the fine-tuning loop with the settings held in `training_args`.
trainer.train()

### Training