In [None]:
from PIL import Image
import torch
import numpy as np

from utils import show_image_caption

In [None]:
# Device that `inference` moves each input batch to.
# NOTE(review): 'cuda' resolves to the *current* CUDA device, while the model is
# sharded onto the GPUs listed in `devices` in the Model section. Inputs may
# therefore land on a different GPU than the model's first shard; confirm this is intended.
device = 'cuda'
print(f"Use  {torch.cuda.get_device_name(device)}")

In [None]:
# --- Run configuration ---
checkpoint = "Salesforce/blip2-opt-2.7b"            # Hugging Face model id
# NOTE(review): the file name says "coco_val", but this notebook captions the NICE
# validation images; confirm the intended output name.
result_save_path = '../results/coco_val_blip2.csv'  # captions CSV written at the end
dtype = torch.float32                               # model weights / input precision
batch_size, num_workers = 16, 8                     # DataLoader settings
max_new_tokens = 50                                 # cap on generated tokens per caption

### Processor

In [None]:
from transformers import Blip2Processor

# Processor for the chosen checkpoint; its `image_processor` feeds the dataset and its
# `batch_decode` turns generated ids back into captions. Files are cached locally.
processor = Blip2Processor.from_pretrained(checkpoint, cache_dir='../pretrained_files')


### Model

In [None]:
# Explicit placement of BLIP-2 submodules across two GPUs.
# Per the original author's note, this mirrors the layout that device_map="auto"
# produced, pinned here to the GPU indices in `devices`.
devices = [4, 5]

# Layer counts as enumerated in the previous hand-written map.
num_vision_layers = 39            # vision_model.encoder.layers.0 .. 38
num_decoder_layers = 32           # language_model.model.decoder.layers.0 .. 31
num_decoder_layers_on_first = 6   # decoder layers 0..5 go to devices[0], the rest to devices[1]

# Generated instead of written out by hand: 80 near-identical lines are easy to get
# wrong (one vision layer had been appended out of order at the end of the literal).
device_map = {
    "query_tokens": devices[0],
    "vision_model.embeddings": devices[0],
    **{f"vision_model.encoder.layers.{i}": devices[0] for i in range(num_vision_layers)},
    "vision_model.post_layernorm": devices[0],
    "qformer": devices[0],
    "language_projection": devices[0],
    "language_model.model.decoder.embed_tokens": devices[0],
    "language_model.lm_head": devices[0],
    "language_model.model.decoder.embed_positions": devices[0],
    "language_model.model.decoder.final_layer_norm": devices[0],
    **{
        f"language_model.model.decoder.layers.{i}": (
            devices[0] if i < num_decoder_layers_on_first else devices[1]
        )
        for i in range(num_decoder_layers)
    },
}

In [None]:
from transformers import Blip2ForConditionalGeneration

# Load the BLIP-2 weights in `dtype`, sharded across GPUs by the explicit `device_map`
# from the previous cell (used instead of device_map='auto'). Shares the processor's cache dir.
model = Blip2ForConditionalGeneration.from_pretrained(
    checkpoint,
    cache_dir='../pretrained_files',
    torch_dtype=dtype,
    device_map=device_map,
)

In [None]:
def denormalize_image(normalized_image, mean, std):
    """Undo per-channel normalization and return a channels-last image clipped to [0, 1].

    Args:
        normalized_image: Rank-3 image with the channel axis first (assumed CHW).
            It may be a ``torch.Tensor`` on any device, possibly requiring grad and
            possibly bfloat16, or a NumPy array.
        mean: Per-channel means used for normalization; broadcast over the last axis
            after the transpose.
        std: Per-channel standard deviations used for normalization; broadcast likewise.

    Returns:
        ``np.ndarray`` of shape (H, W, C) with values clipped to [0, 1].
    """
    if isinstance(normalized_image, torch.Tensor):
        # Tensor.numpy() only works on CPU tensors that do not require grad.
        tensor = normalized_image.detach().cpu()
        if tensor.dtype == torch.bfloat16:
            # NumPy has no bfloat16 dtype, so upcast first.
            tensor = tensor.float()
        normalized_image = tensor.numpy()
    image = np.asarray(normalized_image).transpose(1, 2, 0)  # channel axis first -> last
    image = np.asarray(std) * image + np.asarray(mean)
    return np.clip(image, 0, 1)

### NICE Validation

In [None]:
from tqdm import tqdm

def inference(dataloader, model, processor, device=None, dtype=None, max_new_tokens=None):
    """Generate one caption per image in ``dataloader``.

    Args:
        dataloader: Iterable of ``(inputs, labels)`` batches. ``inputs`` must support
            ``.to(device, dtype)`` and unpack as keyword arguments to ``model.generate``
            (presumably a transformers ``BatchFeature``). ``labels`` is an iterable of
            per-image ids.
        model: Model exposing ``generate``.
        processor: Object whose ``batch_decode`` turns generated ids into text.
        device: Device the inputs are moved to. Defaults to the notebook-level ``device``.
        dtype: Dtype passed to ``inputs.to``. Defaults to the notebook-level ``dtype``.
        max_new_tokens: Generation length cap. Defaults to the notebook-level
            ``max_new_tokens``.

    Returns:
        dict with parallel lists ``'public_id'`` (labels) and ``'caption'`` (stripped texts).
    """
    # Fall back to the configuration cells so existing 3-argument calls keep working,
    # while allowing callers to pass the settings explicitly.
    notebook_config = globals()
    device = notebook_config['device'] if device is None else device
    dtype = notebook_config['dtype'] if dtype is None else dtype
    max_new_tokens = notebook_config['max_new_tokens'] if max_new_tokens is None else max_new_tokens

    results = {
        'public_id': [],
        'caption': [],
    }

    for inputs, labels in tqdm(dataloader):
        # Keep the returned object: .to() is not guaranteed to act in place
        # (for plain tensors it returns a new tensor).
        inputs = inputs.to(device, dtype)
        generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
        decoded = processor.batch_decode(generated_ids, skip_special_tokens=True)
        results['public_id'].extend(labels)
        results['caption'].extend(text.strip() for text in decoded)

    return results

In [None]:
import os
import csv

def save_dict_to_csv(dict_to_save, save_path):
    """Write a column-oriented dict to a CSV file.

    The keys form the header row. The value sequences are zipped together, so data
    row i holds the i-th element of every column.

    Args:
        dict_to_save: Mapping of column name -> sequence of values. Columns are
            expected to have equal length; ``zip`` truncates to the shortest one.
        save_path: Destination file path. Missing parent directories are created.
    """
    parent_dir = os.path.dirname(save_path)
    # dirname() is '' for a bare file name, and os.makedirs('') would raise.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    rows = zip(*dict_to_save.values())
    # newline='' is what the csv module requires (otherwise blank lines appear on
    # Windows). Explicit UTF-8 keeps non-ASCII captions independent of the locale.
    with open(save_path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(dict_to_save.keys())
        writer.writerows(rows)

In [None]:
from dataset_config import nice_dataset_config
from torch.utils.data import DataLoader
from custom_datasets.nice_dataset import NICETestDataset
from datasets import Dataset

# TODO: switching to a NICEValDataset would make it possible to compute evaluation
# metrics from this run as well.
nice_val_ds = NICETestDataset(
    nice_dataset_config['val_image_folder'],
    processor.image_processor,
)

nice_val_dataloader = DataLoader(
    nice_val_ds,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,  # deterministic iteration order across runs
    pin_memory=True,
)

In [None]:
# Caption every validation image; the result is {'public_id': [...], 'caption': [...]}.
generated_texts = inference(dataloader=nice_val_dataloader, model=model, processor=processor)

In [None]:
# Persist the generated captions (one row per image) to the configured CSV path.
save_dict_to_csv(dict_to_save=generated_texts, save_path=result_save_path)