In [None]:
from PIL import Image
import torch
import numpy as np

from utils import show_image_caption

In [None]:
# Inputs are moved to the default CUDA device; report which GPU that is.
device = 'cuda'
print('Use ', torch.cuda.get_device_name(device=device))

In [None]:
# Hugging Face Hub id of the BLIP-2 (OPT-2.7B) checkpoint, and the dtype used
# both for the model weights and for the preprocessed inputs.
checkpoint, dtype = "Salesforce/blip2-opt-2.7b", torch.float32

### Processor

In [None]:
from transformers import Blip2Processor

# Processor bundling the image preprocessing (processor.image_processor) and the
# text decoding (processor.batch_decode) for this checkpoint; files cached locally.
processor = Blip2Processor.from_pretrained(checkpoint, cache_dir='../pretrained_files')


### Model

In [None]:
# This placement is what device_map="auto" produced for this checkpoint. GPU 0
# holds the vision tower, Q-Former, language projection, LM embeddings/head/final
# norm and the first LM_SPLIT_LAYER OPT decoder layers. The remaining decoder
# layers go to GPU 1. The map is built from named constants instead of an
# 80-entry literal, so layer counts and the split point are explicit.
N_VISION_LAYERS = 39   # vision_model.encoder.layers.0 .. 38
N_LM_LAYERS = 32       # language_model.model.decoder.layers.0 .. 31
LM_SPLIT_LAYER = 6     # decoder layers with index >= this are placed on GPU 1

device_map = {
    "query_tokens": 0,
    "vision_model.embeddings": 0,
    **{f"vision_model.encoder.layers.{i}": 0 for i in range(N_VISION_LAYERS)},
    "vision_model.post_layernorm": 0,
    "qformer": 0,
    "language_projection": 0,
    "language_model.model.decoder.embed_tokens": 0,
    "language_model.lm_head": 0,
    "language_model.model.decoder.embed_positions": 0,
    "language_model.model.decoder.final_layer_norm": 0,
    **{
        f"language_model.model.decoder.layers.{i}": 0 if i < LM_SPLIT_LAYER else 1
        for i in range(N_LM_LAYERS)
    },
}

In [None]:
from transformers import Blip2ForConditionalGeneration

# Load BLIP-2 with weights in `dtype`, placing submodules according to the explicit
# `device_map` (which reproduces the placement device_map='auto' chose).
model = Blip2ForConditionalGeneration.from_pretrained(
    checkpoint,
    torch_dtype=dtype,
    device_map=device_map,
    cache_dir='../pretrained_files',
)

In [None]:
# A single validation image used as a quick smoke test of the pipeline.
image = '../datasets/cvpr-nice-val/val/215268662.jpg'
raw_image = Image.open(image).convert(mode='RGB')

# Preprocess, then move to the input device in the model dtype.
inputs = processor(images=raw_image, return_tensors="pt").to(device, dtype)

In [None]:
def denormalize_image(normalized_image, mean, std):
    """Undo per-channel normalization of a channel-first image for display.

    Args:
        normalized_image: image laid out as (C, H, W). Either a torch tensor (any
            device, with or without grad) or a numpy array.
        mean: per-channel means used for normalization (length C).
        std: per-channel standard deviations used for normalization (length C).

    Returns:
        numpy array of shape (H, W, C) equal to ``std * x + mean``, clipped to [0, 1].
    """
    if torch.is_tensor(normalized_image):
        # Tensor.numpy() only works on CPU tensors that do not require grad.
        normalized_image = normalized_image.detach().cpu().numpy()
    image = np.asarray(normalized_image).transpose(1, 2, 0)  # CHW -> HWC
    # After the transpose the channel axis is last, so per-channel mean/std broadcast.
    image = np.asarray(std) * image + np.asarray(mean)
    return np.clip(image, 0, 1)

In [None]:
# Caption the sample image and display it next to the generated text.
generated_ids = model.generate(max_new_tokens=20, **inputs)
generated_text = processor.batch_decode(
    generated_ids,
    skip_special_tokens=True,
)[0].strip()
print(generated_text)

# The preprocessed image with normalization undone; used by the alternative display below.
denormalized_image = denormalize_image(
    inputs['pixel_values'].cpu()[0],
    mean=processor.image_processor.image_mean,
    std=processor.image_processor.image_std,
)
# show_image_caption(denormalized_image, [generated_text], show_fig=True, save_path='sample.png')
show_image_caption(raw_image, [generated_text], show_fig=True)

### NICE Validation

In [None]:
from tqdm import tqdm

def inference(dataset, model, processor, device=None, dtype=None, max_new_tokens=20):
    """Generate one caption per item of ``dataset``.

    Args:
        dataset: iterable of ``(inputs, labels)`` pairs. ``inputs`` must provide
            ``.to(device, dtype)``, and its result is unpacked as keyword arguments
            to ``model.generate``. ``labels`` is stored unchanged as the item id.
        model: captioning model exposing ``generate``.
        processor: object whose ``batch_decode`` turns generated ids into text.
        device: passed to ``inputs.to``; defaults to the notebook-level ``device``.
        dtype: passed to ``inputs.to``; defaults to the notebook-level ``dtype``.
        max_new_tokens: generation length limit (20, as before).

    Returns:
        dict with parallel lists: ``'public_id'`` (the labels) and ``'caption'``
        (the first decoded sequence per item, whitespace-stripped).
    """
    # Fall back to the notebook-level configuration so existing calls
    # `inference(ds, model, processor)` keep working unchanged.
    if device is None:
        device = globals()['device']
    if dtype is None:
        dtype = globals()['dtype']

    generated_texts = {
        'public_id': [],
        'caption': [],
    }

    for inputs, labels in tqdm(dataset):
        # Rebind to the result of .to() instead of relying on it mutating `inputs`.
        inputs = inputs.to(device, dtype)
        generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        generated_texts['public_id'].append(labels)
        generated_texts['caption'].append(generated_text)

    return generated_texts

In [None]:
import os
import csv

def save_dict_to_csv(dict_to_save, save_path):
    """Write a dict of equal-length columns to ``save_path`` as CSV.

    The keys form the header row. Data row i holds the i-th element of every
    column, in key order.

    Args:
        dict_to_save: mapping of column name -> iterable of cell values.
        save_path: destination file; its parent directories are created if needed.

    Raises:
        ValueError: if the columns do not all have the same length. ``zip`` would
            otherwise silently drop the extra rows.
    """
    columns = [list(col) for col in dict_to_save.values()]
    if len({len(col) for col in columns}) > 1:
        lengths = {key: len(col) for key, col in zip(dict_to_save.keys(), columns)}
        raise ValueError(f'columns have different lengths: {lengths}')

    parent = os.path.dirname(save_path)
    if parent:  # dirname is '' for a bare filename, and os.makedirs('') raises
        os.makedirs(parent, exist_ok=True)

    # newline='' is what the csv module requires (avoids blank lines on Windows).
    # Default minimal quoting: captions with commas/quotes get quoted properly,
    # whereas QUOTE_NONE without an escapechar raised csv.Error on them.
    with open(save_path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(dict_to_save.keys())
        writer.writerows(zip(*columns))

In [None]:
from dataset_config import nice_dataset_config
from torch.utils.data import DataLoader
from custom_datasets.nice_dataset import NICETestDataset
from datasets import Dataset

# TODO: switch to NICEValDataset (see the commented-out cell below) so that evaluation
# metrics can also be computed; for now the val image folder is wrapped in NICETestDataset.
nice_val_ds = NICETestDataset(nice_dataset_config['val_image_folder'], processor.image_processor)

In [None]:
# Captions for every image of the validation split.
generated_texts = inference(nice_val_ds, model=model, processor=processor)

In [None]:
# Persist validation captions (columns: public_id, caption).
save_dict_to_csv(generated_texts, save_path='../results/coco_val_blip2.csv')

In [None]:
# from datasets import load_dataset
# from dataset_config import nice_dataset_config
# from custom_datasets.nice_dataset import NICEValDataset

# nice_val_ds = NICEValDataset(
#     nice_dataset_config['val_image_folder'],
#     nice_dataset_config['val_caption_csv'],
#     processor.image_processor,
# )

### NICE Test

In [None]:
from dataset_config import nice_dataset_config
from torch.utils.data import DataLoader
from custom_datasets.nice_dataset import NICETestDataset
from datasets import Dataset

# Test split, built from the test image folder only.
nice_test_ds = NICETestDataset(nice_dataset_config['test_image_folder'], processor.image_processor)

In [None]:
# Captions for every image of the test split.
generated_texts = inference(nice_test_ds, model=model, processor=processor)

In [None]:
# Write test predictions to their own file. Reusing the val path here overwrote
# the validation results saved earlier in the notebook.
save_dict_to_csv(generated_texts, '../results/nice_test_blip2.csv')