In [None]:
from PIL import Image
import torch
import numpy as np

from utils import get_device_map

In [None]:
# GPU indices handed to get_device_map for model placement; batch inputs
# are moved to the first of them before generation.
devices = [1, 5, 6, 7]
start_device = f'cuda:{devices[0]}'

### Configs

In [None]:
# Alternative BLIP-2 checkpoints:
# checkpoint = "Salesforce/blip2-opt-2.7b"
# checkpoint = "Salesforce/blip2-flan-t5-xxl"
checkpoint = "Salesforce/blip2-flan-t5-xl"
# NOTE(review): filename says "coco_test" but this notebook captions the NICE
# test set — confirm the intended output name.
result_file_path = '../results/coco_test_blip2.csv'
# Hugging Face download cache. (A shared-NAS alternative used previously:
# "/mnt/nas2/kjh/huggingface_cache" — it was assigned and immediately
# overwritten, so only this relative path was ever in effect.)
cache_dir = "../pretrained_files"
dtype = torch.float16   # model weights are loaded in, and batch inputs cast to, this dtype
batch_size = 16
num_workers = 8         # DataLoader worker processes
max_new_tokens = 50     # per-caption cap passed to model.generate

### Processor

In [4]:
from transformers import Blip2Processor

# One object bundling the image processor (used by the dataset) and the
# tokenizer (used to decode generated ids) for `checkpoint`.
processor = Blip2Processor.from_pretrained(checkpoint, cache_dir=cache_dir)


KeyboardInterrupt: 

### Model

In [None]:
from transformers import Blip2ForConditionalGeneration

# Explicit placement built from `devices` by the project helper (presumably a
# module-name -> GPU mapping; TODO confirm). Passing device_map='auto'
# instead would let transformers/accelerate decide the placement.
device_map = get_device_map(checkpoint, devices)

model = Blip2ForConditionalGeneration.from_pretrained(
    checkpoint,
    device_map=device_map,
    torch_dtype=dtype,
    cache_dir=cache_dir,
)

In [None]:
# Freeze every BLIP-2 component: this notebook only runs inference.
freeze_list = [
    model.vision_model,
    model.qformer,
    model.language_projection,
    model.language_model,
]

for freeze_block in freeze_list:
    # requires_grad_ and eval() both act in place on the module and all of its
    # sub-modules, so there is no need to walk named_parameters() or to rebind
    # the loop variable.
    freeze_block.requires_grad_(False)
    freeze_block.eval()

# Also clear the top-level training flag, which the loop above does not touch.
model.eval()


### NICE Test

In [None]:
def denormalize_image(normalized_image, mean, std):
    """Undo per-channel normalization and return a displayable HWC image.

    Args:
        normalized_image: channel-first (C, H, W) image, as a torch tensor (on
            any device, with or without grad) or an array-like.
        mean: per-channel means used during normalization (length C).
        std: per-channel standard deviations used during normalization (length C).

    Returns:
        numpy array of shape (H, W, C) computed as ``std * x + mean`` and
        clipped to [0, 1].
    """
    if isinstance(normalized_image, torch.Tensor):
        # Tensor.numpy() refuses CUDA tensors and tensors that require grad.
        normalized_image = normalized_image.detach().cpu().numpy()
    image = np.asarray(normalized_image).transpose(1, 2, 0)
    # mean/std broadcast over the trailing channel axis after the transpose.
    image = np.asarray(std) * image + np.asarray(mean)
    return np.clip(image, 0, 1)

In [None]:
from tqdm.auto import tqdm

def inference(dataloader, model, processor, num_beams=6, device=None, **generate_kwargs):
    """Caption every batch in ``dataloader`` with beam-search generation.

    Args:
        dataloader: yields ``(inputs, filenames)`` pairs. ``inputs`` must support
            ``.to(device, dtype)`` and unpack into ``model.generate`` —
            presumably a ``BatchFeature`` from the image processor (TODO confirm
            against NICETestDataset).
        model: generation model exposing ``generate``.
        processor: used for ``batch_decode`` of the generated token ids.
        num_beams: beam width passed to ``generate`` (default 6, as before).
        device: where batch inputs are moved; defaults to the notebook-level
            ``start_device``.
        **generate_kwargs: extra keyword arguments forwarded to ``generate``;
            ``max_new_tokens`` defaults to the notebook-level config value.

    Returns:
        dict with parallel lists: ``'public_id'`` (filenames) and ``'caption'``
        (decoded, whitespace-stripped captions).
    """
    if device is None:
        device = start_device
    generate_kwargs.setdefault('max_new_tokens', max_new_tokens)

    results = {
        'public_id': [],
        'caption': [],
    }

    for inputs, filenames in tqdm(dataloader):
        # Rebind instead of relying on `.to` mutating `inputs` in place.
        inputs = inputs.to(device, dtype)
        generated_ids = model.generate(**inputs, num_beams=num_beams, **generate_kwargs)
        captions = [
            text.strip()
            for text in processor.batch_decode(generated_ids, skip_special_tokens=True)
        ]
        results['public_id'].extend(filenames)
        results['caption'].extend(captions)

    return results

In [None]:
import os
import csv

def save_dict_to_csv(dict_to_save, save_path):
    """Write a dict of equal-length columns to a CSV file.

    Args:
        dict_to_save: mapping of column name -> iterable of values. The keys
            form the header row; the i-th element of every column forms row i.
        save_path: destination file. Missing parent directories are created.

    Raises:
        ValueError: if the columns have different lengths (zip would otherwise
            silently drop the surplus values).
    """
    columns = [list(col) for col in dict_to_save.values()]
    if len({len(col) for col in columns}) > 1:
        sizes = {key: len(col) for key, col in zip(dict_to_save.keys(), columns)}
        raise ValueError(f'columns have mismatched lengths: {sizes}')

    parent = os.path.dirname(save_path)
    if parent:  # dirname is '' for a bare filename, and os.makedirs('') raises
        os.makedirs(parent, exist_ok=True)

    # newline='' is required by the csv module (prevents blank lines on
    # Windows); explicit utf-8 keeps non-ASCII captions independent of locale.
    with open(save_path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(dict_to_save.keys())
        writer.writerows(zip(*columns))

In [None]:
from dataset_config import nice_dataset_config
from torch.utils.data import DataLoader
from custom_datasets.nice_dataset import NICETestDataset
from datasets import Dataset

# NICE test images, preprocessed with the BLIP-2 image processor.
nice_test_ds = NICETestDataset(
    nice_dataset_config['test_image_folder'],
    processor.image_processor,
)

# shuffle=False keeps dataset order, so the output CSV order is reproducible.
nice_test_dataloader = DataLoader(
    nice_test_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
)

In [None]:
generated_texts = inference(dataloader=nice_test_dataloader, model=model, processor=processor)  # dict: public_id / caption lists

In [None]:
save_dict_to_csv(dict_to_save=generated_texts, save_path=result_file_path)

In [None]:
# Release unused memory held by PyTorch's CUDA caching allocator.
torch.cuda.empty_cache()