In [None]:
import os
import sys
import argparse
import logging
import datetime
import json
import time
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from textwrap import wrap

from transformers import TrainingArguments, Trainer

import torch
from torch.utils.data import DataLoader
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

from dataset_config import COCO_dataset_config
from utils import get_device_map

%matplotlib inline

# Set TOKENIZERS_PARALLELISM to "false" before any tokenizer is created.
# NOTE(review): presumably this silences the tokenizers fork/parallelism warning — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [None]:
# GPU indices handed to get_device_map(); inference inputs are moved to the first one.
devices = [1, 5, 6, 7]
start_device = f"cuda:{devices[0]}"

### Logging

In [None]:
# Every run writes into its own timestamped directory under ../results.
results_dir = '../results'
result_dirname = f"{datetime.datetime.now():%Y%m%d-%H%M%S}"
result_dir_fullpath = os.path.join(results_dir, result_dirname)
os.makedirs(result_dir_fullpath, exist_ok=True)

In [None]:
# Send log records both to train.log inside the run directory and to the cell output.
# The time is formatted with %H (24-hour clock): the previous %I (12-hour clock) had no
# AM/PM marker, so timestamps twelve hours apart were indistinguishable in a long run.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s: %(message)s",
    level=logging.INFO,
    datefmt="%H:%M:%S",
    handlers=[
        logging.FileHandler(os.path.join(result_dir_fullpath, 'train.log')),
        logging.StreamHandler(sys.stdout),
    ])

In [None]:
# TensorBoard event files go to ../runs/<run timestamp>, reusing the results directory name.
writer = SummaryWriter(log_dir=os.path.join('../runs', result_dirname))

### Configuration

In [None]:
# Reference BLIP-2 fine-tuning settings:
# https://github.com/salesforce/LAVIS/blob/main/lavis/projects/blip2/train/caption_coco_ft.yaml

# Hugging Face checkpoint id; the OPT variant is kept as a commented-out alternative.
# checkpoint = "Salesforce/blip2-opt-2.7b"
checkpoint = "Salesforce/blip2-flan-t5-xl"
# Root of the local cache; the NAS location is kept as a commented-out alternative.
# cache_dir = "/mnt/nas2/kjh/huggingface_cache"
cache_dir = "../caches"
# Passed as cache_dir when loading the processor and the model.
cache_pretrained_files_dir = os.path.join(cache_dir, "pretrained_files")
# Passed as cache_dir to load_dataset.
cache_dataset_dir = os.path.join(cache_dir, "datasets")
# YAML file loaded with OmegaConf.
cfg_path = "../configs/caption_coco_ft.yaml"
# torch dtype used when loading the model and when casting the inference inputs.
dtype = torch.float32
# Per-device training batch size.
batch_size = 4
# NOTE(review): not referenced by any visible cell — confirm it is still needed.
num_workers = 4
# Padding length for prompt/caption tokens; also used as max_new_tokens for generation.
max_length = 50
# Passed as num_train_epochs to TrainingArguments.
epochs = 5
# Passed as learning_rate to TrainingArguments.
learning_rate = 1e-4
# Text given to the processor together with every training image.
prompt = "a photo of "

# Part of the checkpoint id after the "/"; names the training output directory.
model_name = checkpoint.split("/")[1]

In [None]:
from omegaconf import OmegaConf, DictConfig

# Load the YAML configuration found at cfg_path.
config = OmegaConf.load(cfg_path)

# Show the loaded configuration in the cell output.
print(config)

def dict_to_str_recursive(input_dict, depth=0):
    """Render a (possibly nested) mapping as indented Markdown text.

    Every entry becomes one ``key: value`` line ending in two spaces and a
    newline (a Markdown hard line break). A nested mapping is rendered
    recursively one level deeper and is followed by an extra empty line.

    Args:
        input_dict: Mapping to render. Any value that is a
            ``collections.abc.Mapping`` (``dict`` and its subclasses, and
            mapping-like config containers) is rendered as a nested section;
            every other value is rendered with ``str()``.
        depth: Nesting level of ``input_dict``; each level is indented with
            four ``&nbsp;`` entities.

    Returns:
        The rendered text, or an empty string for an empty mapping.
    """
    # Local import keeps the function self-contained.
    from collections.abc import Mapping

    indent_str = '&nbsp;&nbsp;&nbsp;&nbsp;' * depth
    rendered = []
    for key in input_dict:
        value = input_dict[key]  # look the value up once per key
        # isinstance() also covers dict subclasses (e.g. OrderedDict), which the
        # previous exact-type comparison silently rendered as flat strings.
        if isinstance(value, Mapping):
            nested_str = dict_to_str_recursive(value, depth + 1)
            rendered.append(indent_str + str(key) + ':  \n' + nested_str + '  \n')
        else:
            rendered.append(indent_str + str(key) + ': ' + str(value) + '  \n')
    return ''.join(rendered)

# Record the rendered configuration in TensorBoard as text under the 'configs' tag.
config_str = dict_to_str_recursive(config)
writer.add_text('configs', config_str)

### Processor
##### image-processor + tokenizer

In [None]:
from transformers import Blip2Processor

# Load the processor (image processor + tokenizer) that belongs to the checkpoint.
# NOTE(review): torch_dtype is forwarded here too; confirm the processor actually uses it.
processor = Blip2Processor.from_pretrained(
    checkpoint,
    cache_dir=cache_pretrained_files_dir,
    torch_dtype=dtype
)

### Dataset

In [None]:
from datasets import load_dataset
from datasets import Dataset, Image

# Caption table loaded from the CSV that ships with the validation images.
caption_ds = load_dataset(
    '../datasets/cvpr-nice-val',
    data_files={'caption': 'nice-val-5k.csv'},
    split='caption',
    cache_dir=cache_dataset_dir,
)

# One image path per caption row, in table order: <image dir>/<public_id>.jpg
image_dir = '../datasets/cvpr-nice-val/val'
image_filename_list = caption_ds['public_id']
image_path_list = [os.path.join(image_dir, f'{public_id}.jpg') for public_id in image_filename_list]

# Start from an image-only dataset whose 'image' column is cast to the Image feature ...
train_ds = Dataset.from_dict({'image': image_path_list})
train_ds = train_ds.cast_column("image", Image())

# ... then attach every column of the caption table.
for feature in caption_ds.features:
    train_ds = train_ds.add_column(name=feature, column=caption_ds[feature])

In [None]:
# Show the dataset summary before the transform is applied.
print(train_ds)

In [None]:
# Tokenize the prompt on its own, padded to max_length.
# NOTE(review): no active code in the visible cells reads prompt_tokens — confirm it is still needed.
prompt_tokens = processor.tokenizer(
    prompt, padding='max_length', max_length=max_length
)

In [None]:
def transforms(input_batch, prefix=None):
    """Convert one raw example into model inputs for caption fine-tuning.

    Must be applied to a single example, not a batch: the processor output
    carries a leading dimension of size 1, which is squeezed away here.

    Args:
        input_batch: Example providing an 'image' and a 'caption_gt' string.
        prefix: Optional string prepended to the ground-truth caption
            (the example's 'caption_gt' entry is overwritten in place).

    Returns:
        The processor output holding 'pixel_values', 'input_ids' and
        'attention_mask' for the prompt, plus 'labels' with the token ids of
        the ground-truth caption padded/truncated to ``max_length``.
    """
    if prefix is not None:
        input_batch['caption_gt'] = prefix + input_batch['caption_gt']

    batch = processor(
        images=input_batch['image'],
        text=prompt,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors='pt',
    )

    # Drop the leading dimension of size 1 so that the collator can stack examples.
    # The attention mask is kept (it used to be deleted): the prompt is padded to
    # max_length, and when no mask is passed the model builds one itself that treats
    # the padding as real tokens.
    for key in ('pixel_values', 'input_ids', 'attention_mask'):
        batch[key] = batch[key].squeeze(0)

    # truncation=True keeps every label row at exactly max_length ids; without it a
    # longer caption produced a longer row that cannot be stacked into a batch.
    label_ids = processor.tokenizer(
        input_batch['caption_gt'],
        padding='max_length',
        truncation=True,
        max_length=max_length,
    )['input_ids']
    # NOTE(review): padded label positions still hold the pad token id and therefore
    # presumably count toward the loss. Replacing them with -100 would exclude them,
    # but any code that decodes 'labels' would have to map -100 back first.
    batch.update({'labels': label_ids})
    # batch.update({'decoder_input_ids': prompt_tokens['input_ids']})
    return batch

In [None]:
# Apply `transforms` example by example and drop the raw columns afterwards.
train_ds = train_ds.map(
    transforms,
    remove_columns=['public_id', 'caption_gt', 'image', 'category'],
)
# With batched mapping the pixel_values come out wrong for some reason, and the conversion is a bit slower.
# NOTE(review): likely because `transforms` squeezes the leading dimension, i.e. it assumes one example per call.


In [None]:
# Show the dataset summary after the transform has replaced the raw columns.
print(train_ds)

### Plot Images

In [None]:
def denormalize_image(normalized_image, mean, std):
    """Undo per-channel normalization and return a displayable image.

    Args:
        normalized_image: Array with three axes whose first axis is moved last
            (channel-first input becomes channel-last output).
        mean: Per-channel mean, broadcast along the last axis.
        std: Per-channel standard deviation, broadcast along the last axis.

    Returns:
        ``std * image + mean`` on the transposed array, clipped to [0, 1].
    """
    # Move the first axis to the end so the statistics line up with the last axis.
    hwc = normalized_image.transpose(1, 2, 0)
    return np.clip(std * hwc + mean, 0, 1)

def plot_images(images, captions, wrap_width=20):
    """Draw the images side by side in one row, each titled with its caption.

    Args:
        images: Sequence of images accepted by ``plt.imshow``.
        captions: Sequence of caption strings, indexed in parallel with ``images``.
        wrap_width: Maximum characters per title line before the caption wraps.
    """
    plt.figure(figsize=(20, 20))
    column_count = len(images)
    for position, image in enumerate(images, start=1):
        plt.subplot(1, column_count, position)
        title_lines = wrap(captions[position - 1], wrap_width)
        plt.title("\n".join(title_lines))
        plt.imshow(image)
        plt.axis("off")

# Preview the first few transformed examples together with their label captions.
num_samples = 5
samples = [train_ds[i] for i in range(num_samples)]

sample_images = []
sample_captions = []
for sample in samples:
    # Undo the image processor's normalization so the picture can be displayed.
    sample_image = denormalize_image(
        np.array(sample['pixel_values']),
        processor.image_processor.image_mean,
        processor.image_processor.image_std,
    )
    sample_images.append(sample_image)

    # Decode the whole id sequence in one call. batch_decode() treats each element
    # of a flat id list as its own sequence, so the space-joined pieces came out
    # split at sub-word boundaries (plus one empty piece per padding token).
    sample_caption = processor.tokenizer.decode(sample['labels'], skip_special_tokens=True)
    sample_captions.append(sample_caption)

plot_images(sample_images, sample_captions)

### Model

In [None]:
from transformers import Blip2ForConditionalGeneration

# Device placement computed by the project helper from the checkpoint id and the GPU list.
# NOTE(review): presumably a module-name -> device mapping spread over `devices` — confirm in utils.
device_map = get_device_map(checkpoint, devices)

# Load the pretrained weights in `dtype`, placed according to device_map.
model = Blip2ForConditionalGeneration.from_pretrained(
    checkpoint,
    cache_dir=cache_pretrained_files_dir,
    torch_dtype=dtype,
    device_map=device_map,
    low_cpu_mem_usage=True,
)

In [None]:
# Freeze
# All sub-modules of the model that can be frozen or trained.
block_list = [
    model.vision_model,
    model.qformer,
    model.language_projection,
    model.language_model,
]

# Sub-modules whose weights stay fixed; commented-out entries remain trainable.
freeze_list = [
    # model.vision_model,
    # model.qformer,
    # model.language_projection,
    model.language_model,
]

# Frozen blocks: no gradients, evaluation mode.
for freeze_block in freeze_list:
    for name, param in freeze_block.named_parameters():
        param.requires_grad = False
    freeze_block.eval()

# Remaining blocks: gradients enabled, training mode.
# train() is now called only for blocks that are NOT frozen; it used to run for
# every block, which switched the frozen blocks straight back to training mode.
for block in block_list:
    if any(block is frozen for frozen in freeze_list):
        continue
    for name, param in block.named_parameters():
        param.requires_grad = True
    block.train()

# NOTE(review): Trainer.train() is expected to call model.train() on the whole model,
# which would re-enable training mode for the frozen blocks as well — confirm whether
# their eval mode has to survive into the training loop.

### Train

In [None]:
# Training configuration for the Hugging Face Trainer.
training_args = TrainingArguments(
    output_dir=f"../training_outputs/{model_name}",
    learning_rate=learning_rate,
    num_train_epochs=epochs,
    fp16=dtype is torch.float16,
    # Apex optimization levels are spelled with the letter O ("O0".."O3"); the
    # previous value "02" used a digit zero.
    fp16_opt_level="O2",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=1,
    save_total_limit=3,
    save_strategy="epoch",
    logging_steps=100,
    # Keep every dataset column; the dataset only holds model inputs.
    remove_unused_columns=False,
    push_to_hub=False,
    label_names=['labels'],
)
# set_lr_scheduler() also assigns the number of epochs and falls back to its own
# default when num_epochs is omitted, which silently replaced `epochs` above.
training_args.set_lr_scheduler(name='linear', num_epochs=epochs, warmup_steps=1000)
# NOTE(review): set_optimizer() assigns the learning rate as well, so 1e-6 — not the
# `learning_rate` passed to the constructor — appears to be the effective value.
# Confirm which of the two is intended.
training_args.set_optimizer(name='adamw_hf', learning_rate=1e-6, weight_decay=0.05)

In [None]:
# Trainer over the training split only; no evaluation dataset is provided.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=None,
)

In [None]:
from pynvml import *

def print_gpu_utilization(device_indices=None):
    """Print the memory currently in use on the GPUs the model runs on.

    Args:
        device_indices: Iterable of NVML device indices to report. Defaults to
            the notebook-level ``devices`` list. (The previous version always
            queried index 0, which is not one of the configured GPUs.)

    NOTE(review): NVML indices are assumed to match the indices in ``devices``;
    confirm this if CUDA_VISIBLE_DEVICES remaps the GPUs.
    """
    if device_indices is None:
        device_indices = devices
    nvmlInit()
    for index in device_indices:
        handle = nvmlDeviceGetHandleByIndex(index)
        info = nvmlDeviceGetMemoryInfo(handle)
        print(f"GPU {index} memory occupied: {info.used//1024**2} MB.")

def print_summary(result):
    """Print runtime and throughput of a finished training run, then GPU memory use.

    Args:
        result: Object whose ``metrics`` mapping provides 'train_runtime' and
            'train_samples_per_second' (here: the value returned by ``trainer.train()``).
    """
    metrics = result.metrics
    report_rows = (
        ("Time", 'train_runtime'),
        ("Samples/second", 'train_samples_per_second'),
    )
    for label, metric_key in report_rows:
        print(f"{label}: {metrics[metric_key]:.2f}")
    print_gpu_utilization()

In [None]:
from torch.cuda.amp import autocast

# Run the training loop inside a CUDA autocast context set to `dtype`, then print the run summary.
# NOTE(review): `dtype` is torch.float32 in the configuration cell; confirm that autocast has any
# effect with that dtype, and whether it is needed at all given the fp16 flag in TrainingArguments.
with autocast(dtype=dtype):
    result = trainer.train()
    print_summary(result)

In [None]:
# NOTE: this import rebinds `Image`, which the dataset cell imported from `datasets`.
from PIL import Image
from utils import show_image_caption

# Caption a single validation image with the fine-tuned model.
image = '../datasets/cvpr-nice-val/val/215268662.jpg'
# caption_gt = 'Bicycles leaning against tree in wood Close up low angle view'
raw_image = Image.open(image).convert('RGB')

# Pass the same prompt that accompanied every training image, so the inputs at
# inference time match what the model was fine-tuned on.
inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(start_device, dtype)

# Training leaves the model in training mode; switch to evaluation mode so that
# dropout does not perturb the generated caption.
model.eval()
generated_ids = model.generate(**inputs, max_new_tokens=max_length)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(generated_text)

show_image_caption(raw_image, [generated_text], show_fig=True)

### Training