In [39]:
import os
import datasets
from transformers import VisionEncoderDecoderModel, AutoFeatureExtractor,AutoTokenizer
os.environ["WANDB_DISABLED"] = "true"

In [3]:
import nltk
try:
    nltk.data.find("tokenizers/punkt")
except (LookupError, OSError):
    nltk.download("punkt", quiet=True)

In [21]:
data_path = "./coco2017_data_test/"

train_ds = datasets.load_dataset(data_path, split="train")
valid_ds = datasets.load_dataset(data_path, split="validation")
test_ds = datasets.load_dataset(data_path, split="test")

Found cached dataset imagefolder (/home/drcl-yang/.cache/huggingface/datasets/imagefolder/coco2017_data_test-ab80f7257c64201e/0.0.0/37fbb85cc714a338bea574ac6c7d0b5be5aff46c1862c1989b20e0771199e93f)
Found cached dataset imagefolder (/home/drcl-yang/.cache/huggingface/datasets/imagefolder/coco2017_data_test-ab80f7257c64201e/0.0.0/37fbb85cc714a338bea574ac6c7d0b5be5aff46c1862c1989b20e0771199e93f)
Found cached dataset imagefolder (/home/drcl-yang/.cache/huggingface/datasets/imagefolder/coco2017_data_test-ab80f7257c64201e/0.0.0/37fbb85cc714a338bea574ac6c7d0b5be5aff46c1862c1989b20e0771199e93f)


In [25]:
train_ds[3]['image']
train_ds[6]
os.getcwd()

'/home/drcl-yang/workspace/Image-Captioning'

In [27]:
data_path = os.getcwd() + "/coco2017_data_test/"
ds = datasets.load_dataset("ydshieh/coco_dataset_script", "2017", data_dir=data_path)

ds

Found cached dataset coco_dataset_script (/home/drcl-yang/.cache/huggingface/datasets/ydshieh___coco_dataset_script/2017-eca5cab7eb1b92f0/0.0.0/e033205c0266a54c10be132f9264f2a39dcf893e798f6756d224b1ff5078998f)


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['image_id', 'caption_id', 'caption', 'height', 'width', 'file_name', 'coco_url', 'image_path'],
        num_rows: 80
    })
    validation: Dataset({
        features: ['image_id', 'caption_id', 'caption', 'height', 'width', 'file_name', 'coco_url', 'image_path'],
        num_rows: 80
    })
    test: Dataset({
        features: ['image_id', 'caption_id', 'caption', 'height', 'width', 'file_name', 'coco_url', 'image_path'],
        num_rows: 16
    })
})

In [42]:
from transformers import VisionEncoderDecoderModel, AutoTokenizer, AutoFeatureExtractor

image_encoder_model = "google/vit-base-patch16-224-in21k"
text_decode_model = "gpt2"

model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(image_encoder_model, text_decode_model)

# image feature extractor
feature_extractor = AutoFeatureExtractor.from_pretrained(image_encoder_model)
# text tokenizer
tokenizer = AutoTokenizer.from_pretrained(text_decode_model)

# GPT2 only has bos/eos tokens but not decoder_start/pad tokens
tokenizer.pad_token = tokenizer.eos_token

# update the model config
model.config.eos_token_id = tokenizer.eos_token_id
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id



In [44]:
from PIL import Image

# text preprocessing step
def tokenization_fn(captions, max_target_length):
    """Run tokenization on captions."""
    labels = tokenizer(captions, padding="max_length", max_length=max_target_length).input_ids

    return labels

def feature_extraction_fn(image_paths, check_image=True):
    """
    Run feature extraction on images
    If `check_image` is `True`, the examples that fails during `Image.open()` will be caught and discarded.
    Otherwise, an exception will be thrown.
    """

    model_inputs = {}

    if check_image:
        images = []
        to_keep = []
        for image_file in image_paths:
            try:
                img = Image.open(image_file)
                images.append(img)
                to_keep.append(True)
            except Exception:
                to_keep.append(False)
    else:
        images = [Image.open(image_file) for image_file in image_paths]

    encoder_inputs = feature_extractor(images=images, return_tensors="np")

    return encoder_inputs.pixel_values

def preprocess_fn(examples, max_target_length, check_image = True):
    """Run tokenization + image feature extraction"""
    image_paths = examples['image_path']
    captions = examples['caption']    
    
    model_inputs = {}
    # This contains image path column
    model_inputs['labels'] = tokenization_fn(captions, max_target_length)
    model_inputs['pixel_values'] = feature_extraction_fn(image_paths, check_image=check_image)

    return model_inputs

In [46]:
processed_dataset = ds.map(function=preprocess_fn, batched=True, fn_kwargs={"max_target_length": 128}, remove_columns=ds['train'].column_names)
processed_dataset

Map:   0%|          | 0/80 [00:00<?, ? examples/s]

Loading cached processed dataset at /home/drcl-yang/.cache/huggingface/datasets/ydshieh___coco_dataset_script/2017-eca5cab7eb1b92f0/0.0.0/e033205c0266a54c10be132f9264f2a39dcf893e798f6756d224b1ff5078998f/cache-ef15e4f3cd9c3b51.arrow
Loading cached processed dataset at /home/drcl-yang/.cache/huggingface/datasets/ydshieh___coco_dataset_script/2017-eca5cab7eb1b92f0/0.0.0/e033205c0266a54c10be132f9264f2a39dcf893e798f6756d224b1ff5078998f/cache-13ee92af179bfb18.arrow


DatasetDict({
    train: Dataset({
        features: ['labels', 'pixel_values'],
        num_rows: 80
    })
    validation: Dataset({
        features: ['labels', 'pixel_values'],
        num_rows: 80
    })
    test: Dataset({
        features: ['labels', 'pixel_values'],
        num_rows: 16
    })
})

In [47]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(predict_with_generate=True, evaluation_strategy="epoch", per_device_train_batch_size=4, per_device_eval_batch_size=4, output_dir="./image-captioning-output",)

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


In [48]:
import evaluate
metric = evaluate.load("rouge")

import numpy as np

ignore_pad_token_for_loss = True

def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    # rougeLSum expects newline after each sentence
    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

    return preds, labels

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    if ignore_pad_token_for_loss:
        # Replace -100 in the labels as we can't decode them.
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    result = {k: round(v * 100, 4) for k, v in result.items()}
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    
    return result

In [49]:
from transformers import default_data_collator

# instantiate trainer
trainer = Seq2SeqTrainer(model=model, tokenizer=feature_extractor, args=training_args, compute_metrics=compute_metrics, train_dataset=processed_dataset['train'], eval_dataset=processed_dataset['validation'], data_collator=default_data_collator,)

trainer.train()



  0%|          | 0/60 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 100.00 MiB (GPU 0; 7.75 GiB total capacity; 5.21 GiB already allocated; 43.00 MiB free; 5.37 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
trainer.save_model("./image-captioning-output")

In [None]:
from transformers import pipeline

# full dataset trained model can be found at https://huggingface.co/nlpconnect/vit-gpt2-image-captioning
image_captioner = pipeline("image-to-text", model="./image-captioning-output")

image_captioner("sample_image.png")