In [1]:
from transformers import ViTImageProcessor, BertJapaneseTokenizer
from PIL import Image
import datasets
import os
import json
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Pretrained checkpoints: a ViT image processor for the pixels and a
# Japanese whole-word-masking BERT tokenizer for the captions.
VIT_CHECKPOINT = "google/vit-base-patch16-224-in21k"
BERT_JA_CHECKPOINT = "cl-tohoku/bert-base-japanese-whole-word-masking"

image_processor = ViTImageProcessor.from_pretrained(VIT_CHECKPOINT)
tokenizer = BertJapaneseTokenizer.from_pretrained(BERT_JA_CHECKPOINT)

In [3]:
# Corpus locations. The defaults are the original hard-coded paths; set the
# STAIR_CAPTIONS_DIR / COCO_DIR environment variables to run on another machine
# without editing the notebook.
STAIR_CAPTIONS_DIR = os.environ.get("STAIR_CAPTIONS_DIR", "/autofs/diamond2/share/corpus/STAIR-captions")
STAIR_CAPTIONS_TRAIN_JSON_PATH = os.path.join(STAIR_CAPTIONS_DIR, "stair_captions_v1.2_train.json")
STAIR_CAPTIONS_VAL_JSON_PATH = os.path.join(STAIR_CAPTIONS_DIR, "stair_captions_v1.2_val.json")

COCO_DIR = os.environ.get("COCO_DIR", "/autofs/diamond2/share/corpus/MS-COCO")
COCO_TRAIN2014_DIR = os.path.join(COCO_DIR, "train2014")
COCO_VAL2014_DIR = os.path.join(COCO_DIR, "val2014")

In [4]:
# Load the raw STAIR Captions annotation files. `with` closes each handle as
# soon as it is parsed, and the explicit UTF-8 encoding keeps the Japanese
# captions decodable regardless of the machine's default locale encoding.
with open(STAIR_CAPTIONS_TRAIN_JSON_PATH, encoding="utf-8") as train_json_file:
    train_json = json.load(train_json_file)
with open(STAIR_CAPTIONS_VAL_JSON_PATH, encoding="utf-8") as val_json_file:
    val_json = json.load(val_json_file)

In [5]:
# train_json["annotations"][0]

In [6]:
# train_json["images"][0]

In [7]:
def convert_stair_caption_json_to_datalist(json, coco_image_dir):
    """Flatten a parsed STAIR Captions file into one record per caption.

    Args:
        json: Parsed annotation dict holding an ``"images"`` list (entries with
            ``id``, ``file_name``, ``height``, ``width``) and an
            ``"annotations"`` list (entries with ``id``, ``image_id``,
            ``caption``). The name shadows the ``json`` module inside this
            function; the module is not needed here.
        coco_image_dir: Directory that ``file_name`` values are joined onto.

    Returns:
        A list of dicts, in annotation order, with keys ``id``, ``caption``,
        ``image_path``, ``height`` and ``width``.

    Raises:
        KeyError: if an annotation's ``image_id`` is not among the images.
    """
    images_by_id = {}
    for info in json["images"]:
        images_by_id[info["id"]] = info

    def _record(annotation):
        # Build one flat record from an annotation and its image metadata.
        info = images_by_id[annotation["image_id"]]
        path = os.path.join(coco_image_dir, info["file_name"])
        return {
            'id': annotation["id"],
            'caption': annotation["caption"],
            'image_path': path,
            'height': info["height"],
            'width': info["width"],
        }

    return [_record(annotation) for annotation in json["annotations"]]

In [8]:
# Flatten each split, pairing its annotation JSON with the COCO image folder
# of the same split.
datalist_train, datalist_val = (
    convert_stair_caption_json_to_datalist(split_json, split_image_dir)
    for split_json, split_image_dir in (
        (train_json, COCO_TRAIN2014_DIR),
        (val_json, COCO_VAL2014_DIR),
    )
)

In [9]:
# Wrap the flattened records as Hugging Face datasets, one split per key.
dataset_dict = datasets.DatasetDict({
    "train": datasets.Dataset.from_list(datalist_train),
    "val": datasets.Dataset.from_list(datalist_val),
})

In [10]:
# Write the raw (untransformed) splits to disk; the feature transform is only
# attached in a later cell.
STAIR_CAPTIONS_DATASET_DIR = "./stair_captions_dataset"
dataset_dict.save_to_disk(STAIR_CAPTIONS_DATASET_DIR)

                                                                                                    

In [11]:
def convert_to_features(example_batch, max_length=128):
    """Turn a batch of caption/image-path records into model inputs.

    Used as the on-the-fly transform given to ``dataset_dict.set_transform``:
    it receives a batch (column name -> list of values) and returns tokenized
    captions plus preprocessed images as PyTorch tensors.

    Args:
        example_batch: Batch mapping with at least the ``"caption"`` and
            ``"image_path"`` columns.
        max_length: Token length every caption is padded/truncated to.
            Defaults to 128, the value previously hard-coded here.

    Returns:
        The tokenizer's encoding of the captions (``return_tensors="pt"``)
        with an added ``"pixel_values"`` entry from the image processor.

    Raises:
        FileNotFoundError: if any ``image_path`` does not exist.
    """
    inputs = tokenizer(
        example_batch["caption"],
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    images = []
    for image_path in example_batch["image_path"]:
        # Explicit check instead of `assert`: asserts are stripped under `python -O`.
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"image_path={image_path} is not found.")
        # `convert` returns a new, fully loaded image, so the source file can be
        # closed right away instead of leaking one handle per image.
        with Image.open(image_path) as image:
            images.append(image.convert("RGB"))
    inputs["pixel_values"] = image_processor(images=images, return_tensors="pt").pixel_values
    return inputs

In [12]:
# Attach the feature conversion lazily to every split: rows are tokenized and
# their images loaded only when accessed, nothing is precomputed.
dataset_dict.set_transform(transform=convert_to_features)