In [None]:

from PIL import Image
import os
import sys
import csv
import torch
import numpy as np
import logging
from tqdm.notebook import tqdm

from utils import get_device_map

# Turn off the Hugging Face tokenizers' internal thread pool: it is not fork-safe,
# and the DataLoader below forks worker processes.
os.environ.update(TOKENIZERS_PARALLELISM="false")
# Silence every logging call at any level (this also hides library warnings).
logging.disable(level=sys.maxsize)

In [None]:
# GPU ids the model is spread over; inputs are moved to the first one.
devices = [1, 5, 6, 7]
start_device = f"cuda:{devices[0]}"

### Configs

In [None]:
# Base Hugging Face checkpoint (the processor is loaded from it). Alternatives tried:
# checkpoint = "Salesforce/blip2-opt-2.7b"
# checkpoint = "Salesforce/blip2-flan-t5-xxl"
checkpoint = "Salesforce/blip2-flan-t5-xl"
# NOTE(review): the file name says "coco", but this notebook captions the NICE test set — confirm.
result_file_path = '../results/coco_test_blip2.csv'
# cache_dir = "/mnt/nas2/kjh/huggingface_cache"
cache_dir = "../caches"
cache_pretrained_files_dir = os.path.join(cache_dir, "pretrained_files")
# Fine-tuned weights that are actually loaded for inference.
# saved_model_path = "../saved_models/blip2-flan-t5-xl_5epochs/"
saved_model_path = "../training_outputs/blip2-flan-t5-xl/checkpoint-3750/"
cache_dataset_dir = os.path.join(cache_dir, "datasets")


dtype = torch.float16       # model weights and floating-point inputs run in half precision
batch_size = 4
num_workers = 8             # DataLoader worker processes
# Used both as the padded length of the (optional) text prompt and as the
# generation budget (max_new_tokens).
max_length = 50
num_beams = 6
length_penalty = 1          # Beam-search length penalty: > 0 favours longer captions, < 0 shorter ones.
repetition_penalty = 1.5    # Penalty on repeated tokens, in [1, inf); 1 means no penalty (the default).
temperature = 1             # Sampling temperature (higher = more diverse). Only used when sampling
                            # (do_sample=True); plain beam search, as configured here, ignores it.

### Processor

In [None]:
from transformers import Blip2Processor

# Image processor + tokenizer of the base checkpoint; downloads go to the local cache.
processor = Blip2Processor.from_pretrained(checkpoint, cache_dir=cache_pretrained_files_dir)


### Dataset

In [None]:
from torch.utils.data import DataLoader
from dataset_config import nice_dataset_config
from custom_datasets.nice_dataset import NICETestDataset

# Test images, preprocessed with the BLIP-2 image processor.
nice_test_ds = NICETestDataset(
    nice_dataset_config['test_image_folder'],
    processor.image_processor,
)

# shuffle=False keeps a deterministic order, so the output CSV is reproducible.
nice_test_dataloader = DataLoader(
    nice_test_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
)

### Model

In [None]:
from transformers import Blip2ForConditionalGeneration

# Placement of the model's modules on the GPUs in `devices`, computed by the
# project helper `utils.get_device_map` (see that module for the exact layout).
device_map = get_device_map(checkpoint, devices)

# Load the fine-tuned weights (not the base `checkpoint`) in `dtype`,
# dispatched according to `device_map`.
model = Blip2ForConditionalGeneration.from_pretrained(
    saved_model_path,
    device_map=device_map,
    torch_dtype=dtype,
    cache_dir=cache_pretrained_files_dir,
)

In [None]:
# Freeze: this notebook only runs inference, so no parameter needs gradients and
# every module should be in evaluation mode (e.g. dropout disabled).
freeze_list = [
    model.vision_model,
    model.qformer,
    model.language_projection,
    model.language_model,
]

for freeze_block in freeze_list:
    # Both calls act in place on the module and all of its children.
    freeze_block.requires_grad_(False)
    freeze_block.eval()

# Also freeze the top-level model itself: this catches parameters registered
# directly on it rather than on one of the sub-modules above (presumably the
# learned query tokens in the HF BLIP-2 implementation — TODO confirm).
model.requires_grad_(False)
model.eval()


### NICE Test

In [None]:
def denormalize_image(normalized_image, mean, std):
    """Undo per-channel normalisation and return a displayable channels-last image.

    Computes ``std * x + mean`` per channel (the inverse of ``(x - mean) / std``)
    and clips the result to [0, 1].

    Args:
        normalized_image: channels-first image ``(C, H, W)`` or batch ``(N, C, H, W)``,
            as a torch tensor (any device, may require grad) or a NumPy array.
        mean: per-channel means used for normalisation (length ``C``).
        std: per-channel standard deviations used for normalisation (length ``C``).

    Returns:
        np.ndarray of shape ``(H, W, C)`` (or ``(N, H, W, C)``) with values in [0, 1].
    """
    if isinstance(normalized_image, torch.Tensor):
        # detach/cpu so tensors that track gradients or live on a GPU can be converted.
        array = normalized_image.detach().cpu().numpy()
    else:
        array = np.asarray(normalized_image)
    # Move the channel axis (third from the end) last: CHW -> HWC, NCHW -> NHWC.
    image = np.moveaxis(array, -3, -1)
    image = np.asarray(std) * image + np.asarray(mean)
    return np.clip(image, 0, 1)

In [None]:
# Tokenised text prompt padded to `max_length`. Not currently passed to
# `model.generate`; kept for experimenting with prompted captioning.
prompt_tokens = processor.tokenizer(
    "a photo of ",
    return_tensors='pt',
    padding='max_length',
    max_length=max_length,
)

def inference(dataloader, model, processor, verbose=False):
    """Generate a caption for every image yielded by ``dataloader``.

    Args:
        dataloader: iterable of ``(inputs, filenames)`` batches. ``inputs`` must
            support ``.to(device, dtype)`` and be unpackable as keyword arguments
            to ``model.generate`` (presumably a ``BatchFeature`` holding
            ``pixel_values`` — TODO confirm against ``NICETestDataset``).
        model: captioning model exposing ``generate``.
        processor: processor whose ``batch_decode`` turns token ids into text.
        verbose: if True, print each batch's captions as they are produced
            (previously always on); off by default to avoid flooding the output.

    Returns:
        dict with parallel lists: ``'public_id'`` (the filenames) and
        ``'caption'`` (decoded, whitespace-stripped captions).

    The target device/dtype and the generation settings are read from the
    notebook's config globals (``start_device``, ``dtype``, ``num_beams``,
    ``length_penalty``, ``repetition_penalty``, ``max_length``, ``temperature``)
    at call time.
    """
    results = {
        'public_id': [],
        'caption': [],
    }

    for inputs, filenames in tqdm(dataloader, desc='captioning'):
        # Rebind the result instead of relying on `.to` mutating `inputs` in place.
        inputs = inputs.to(start_device, dtype)
        generated_ids = model.generate(
            **inputs,
            # NOTE: to condition on a text prompt, pass input_ids built from
            # `prompt_tokens`; a fixed `.repeat(batch_size, 1)` would not match the
            # last batch, which can be smaller (the DataLoader does not drop it).
            num_beams=num_beams,
            length_penalty=length_penalty,
            repetition_penalty=repetition_penalty,
            max_new_tokens=max_length,
            # Only affects sampling; ignored by beam search unless do_sample is enabled.
            temperature=temperature,
        )
        decoded = processor.batch_decode(generated_ids, skip_special_tokens=True)
        generated_texts = [text.strip() for text in decoded]
        results['public_id'].extend(filenames)
        results['caption'].extend(generated_texts)
        if verbose:
            print(generated_texts)
    return results

In [None]:
def save_dict_to_csv(dict_to_save, save_path):
    """Write a dict of equal-length columns to ``save_path`` as a CSV file.

    The dict keys become the header row; row ``i`` holds the ``i``-th value of
    every column, in key order. Missing parent directories are created.

    Args:
        dict_to_save: mapping of column name -> iterable of cell values.
        save_path: destination file path (overwritten if it exists).

    Raises:
        ValueError: if the columns do not all have the same length (``zip``
            would otherwise silently drop the surplus values).
    """
    columns = [list(column) for column in dict_to_save.values()]
    sizes = {name: len(column) for name, column in zip(dict_to_save.keys(), columns)}
    if len(set(sizes.values())) > 1:
        raise ValueError(f"columns have different lengths: {sizes}")

    parent_dir = os.path.dirname(save_path)
    if parent_dir:  # os.makedirs('') raises for a bare filename such as 'out.csv'
        os.makedirs(parent_dir, exist_ok=True)

    # newline='' lets the csv module control line endings (otherwise Windows gets
    # blank rows); utf-8 so non-ASCII captions don't depend on the locale encoding.
    with open(save_path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(dict_to_save.keys())
        writer.writerows(zip(*columns))

In [None]:
# Dict of parallel lists: 'public_id' (filenames) and 'caption' (generated text).
generated_texts = inference(nice_test_dataloader, model=model, processor=processor)

In [None]:
# Persist the captions (one row per image) to the configured results file.
save_dict_to_csv(dict_to_save=generated_texts, save_path=result_file_path)

In [None]:
# Release cached-but-unused GPU memory held by PyTorch's caching allocator.
# NOTE(review): memory still referenced by live tensors (e.g. `model`'s weights)
# is not freed by this call; `del model` first if the GPUs must be released.
torch.cuda.empty_cache()