In [1]:
import os
import datasets
from transformers import VisionEncoderDecoderModel, AutoFeatureExtractor,AutoTokenizer
os.environ["WANDB_DISABLED"] = "true"

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import nltk
try:
    nltk.data.find("tokenizers/punkt")
except (LookupError, OSError):
    nltk.download("punkt", quiet=True)

In [4]:
data_path = os.getcwd() + "/coco2017_data/"
ds = datasets.load_dataset("ydshieh/coco_dataset_script", "2017", data_dir=data_path)

ds

Downloading and preparing dataset coco_dataset_script/2017 to C:/Users/DRCL/.cache/huggingface/datasets/ydshieh___coco_dataset_script/2017-11a07fbbbc8b3323/0.0.0/e033205c0266a54c10be132f9264f2a39dcf893e798f6756d224b1ff5078998f...


Downloading data files: 100%|██████████| 5/5 [00:00<?, ?it/s]
Extracting data files: 100%|██████████| 5/5 [01:26<00:00, 17.25s/it]
                                                                        

Dataset coco_dataset_script downloaded and prepared to C:/Users/DRCL/.cache/huggingface/datasets/ydshieh___coco_dataset_script/2017-11a07fbbbc8b3323/0.0.0/e033205c0266a54c10be132f9264f2a39dcf893e798f6756d224b1ff5078998f. Subsequent calls will reuse this data.


100%|██████████| 3/3 [00:00<00:00, 44.80it/s]


DatasetDict({
    train: Dataset({
        features: ['image_id', 'caption_id', 'caption', 'height', 'width', 'file_name', 'coco_url', 'image_path'],
        num_rows: 591753
    })
    validation: Dataset({
        features: ['image_id', 'caption_id', 'caption', 'height', 'width', 'file_name', 'coco_url', 'image_path'],
        num_rows: 25014
    })
    test: Dataset({
        features: ['image_id', 'caption_id', 'caption', 'height', 'width', 'file_name', 'coco_url', 'image_path'],
        num_rows: 40670
    })
})

In [None]:
new_data = {"image_id" : -1, 
            "caption_id" : -1, 
            "height" : 640, 
            "width" : 480, 
            "file_name" : "test.png", 
            "coco_url" : "asd", 
            "image_path" : os.getcwd() + "/test.png",
            "caption": "The man is building a fire."}


new_ds = ds

new_ds["train"] = new_ds["train"].add_item(new_data)


In [5]:
from transformers import VisionEncoderDecoderModel, AutoTokenizer, AutoFeatureExtractor

image_encoder_model = "google/vit-base-patch16-224-in21k"
text_decode_model = "gpt2"

model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(image_encoder_model, text_decode_model)

# image feature extractor
feature_extractor = AutoFeatureExtractor.from_pretrained(image_encoder_model)
# text tokenizer
tokenizer = AutoTokenizer.from_pretrained(text_decode_model)

# GPT2 only has bos/eos tokens but not decoder_start/pad tokens
tokenizer.pad_token = tokenizer.eos_token

# update the model config
model.config.eos_token_id = tokenizer.eos_token_id
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id



In [6]:
output_dir = "./model/test-model"
model.save_pretrained(output_dir)
feature_extractor.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

('./model/test-model\\tokenizer_config.json',
 './model/test-model\\special_tokens_map.json',
 './model/test-model\\vocab.json',
 './model/test-model\\merges.txt',
 './model/test-model\\added_tokens.json',
 './model/test-model\\tokenizer.json')

In [16]:
# from PIL import Image
import cv2
import numpy as np

# text preprocessing step
def tokenization_fn(captions, max_target_length):
    """Run tokenization on captions."""
    labels = tokenizer(captions, truncation = "True", padding="max_length", max_length=max_target_length).input_ids

    return labels

def feature_extraction_fn(image_paths, check_image=True):
    """
    Run feature extraction on images
    If `check_image` is `True`, the examples that fails during `Image.open()` will be caught and discarded.
    Otherwise, an exception will be thrown.
    """

    model_inputs = {}

    if check_image:
        images = []
        to_keep = []
        for image_file in image_paths:
            try:
                # img = Image.open(image_file)
                img = cv2.imread(image_file)
                if len(img.shape) == 2:  # 이미지가 2차원인 경우
                    img = np.expand_dims(img, axis=2)  # 차원을 확장하여 3차원으로 만듦

                images.append(img)
                to_keep.append(True)
            except Exception:
                to_keep.append(False)
    else:
        images = [Image.open(image_file) for image_file in image_paths]

    encoder_inputs = feature_extractor(images=images, return_tensors="np")

    return encoder_inputs.pixel_values

def preprocess_fn(examples, max_target_length, check_image = True):
    """Run tokenization + image feature extraction"""
    image_paths = examples['image_path']
    captions = examples['caption']    
    
    model_inputs = {}
    # This contains image path column
    model_inputs['labels'] = tokenization_fn(captions, max_target_length)
    model_inputs['pixel_values'] = feature_extraction_fn(image_paths, check_image=check_image)

    return model_inputs

In [22]:
processed_dataset = ds.map(function=preprocess_fn, batched=True, fn_kwargs={"max_target_length": 2000}, remove_columns=ds['train'].column_names)
processed_dataset

# processed_dataset = new_ds.map(function=preprocess_fn, batched=True, fn_kwargs={"max_target_length": 128}, remove_columns=ds['train'].column_names)
# processed_dataset



In [None]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(predict_with_generate=True, evaluation_strategy="epoch", per_device_train_batch_size=4, per_device_eval_batch_size=4, output_dir="./model/image-captioning-output",num_train_epochs = 5)

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


In [None]:
import evaluate
metric = evaluate.load("rouge")

import numpy as np

ignore_pad_token_for_loss = True

def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    # rougeLSum expects newline after each sentence
    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

    return preds, labels

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    if ignore_pad_token_for_loss:
        # Replace -100 in the labels as we can't decode them.
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    result = {k: round(v * 100, 4) for k, v in result.items()}
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    
    return result

In [None]:
from transformers import default_data_collator

# instantiate trainer
trainer = Seq2SeqTrainer(model=model, tokenizer=feature_extractor, args=training_args, compute_metrics=compute_metrics, train_dataset=processed_dataset['train'], eval_dataset=processed_dataset['validation'], data_collator=default_data_collator,)

trainer.train()

  0%|          | 500/739695 [02:54<70:43:16,  2.90it/s]

{'loss': 0.3573, 'learning_rate': 4.996620228607738e-05, 'epoch': 0.0}


  0%|          | 1000/739695 [05:48<70:13:14,  2.92it/s]

{'loss': 0.2909, 'learning_rate': 4.993240457215474e-05, 'epoch': 0.01}


  0%|          | 1500/739695 [08:40<68:42:11,  2.98it/s] 

{'loss': 0.2811, 'learning_rate': 4.989860685823211e-05, 'epoch': 0.01}


  0%|          | 2000/739695 [11:32<69:19:01,  2.96it/s] 

{'loss': 0.2688, 'learning_rate': 4.986480914430948e-05, 'epoch': 0.01}


  0%|          | 2500/739695 [14:26<69:23:31,  2.95it/s] 

{'loss': 0.2629, 'learning_rate': 4.9831011430386854e-05, 'epoch': 0.02}


  0%|          | 3000/739695 [17:19<69:46:04,  2.93it/s] 

{'loss': 0.2647, 'learning_rate': 4.979721371646422e-05, 'epoch': 0.02}


  0%|          | 3500/739695 [20:14<69:55:45,  2.92it/s] 

{'loss': 0.2589, 'learning_rate': 4.976341600254159e-05, 'epoch': 0.02}


  1%|          | 4000/739695 [23:09<69:19:18,  2.95it/s] 

{'loss': 0.2596, 'learning_rate': 4.9729618288618964e-05, 'epoch': 0.03}


  1%|          | 4500/739695 [26:03<69:09:31,  2.95it/s] 

{'loss': 0.2608, 'learning_rate': 4.969582057469633e-05, 'epoch': 0.03}


  1%|          | 5000/739695 [28:57<69:01:22,  2.96it/s] 

{'loss': 0.2533, 'learning_rate': 4.96620228607737e-05, 'epoch': 0.03}


  1%|          | 5500/739695 [31:50<68:34:53,  2.97it/s] 

{'loss': 0.2514, 'learning_rate': 4.962822514685107e-05, 'epoch': 0.04}


  1%|          | 6000/739695 [34:45<69:06:48,  2.95it/s] 

{'loss': 0.2514, 'learning_rate': 4.9594427432928434e-05, 'epoch': 0.04}


  1%|          | 6500/739695 [37:38<68:57:58,  2.95it/s] 

{'loss': 0.252, 'learning_rate': 4.956062971900581e-05, 'epoch': 0.04}


  1%|          | 7000/739695 [40:33<68:34:13,  2.97it/s] 

{'loss': 0.2523, 'learning_rate': 4.9526832005083176e-05, 'epoch': 0.05}


  1%|          | 7500/739695 [43:28<67:49:55,  3.00it/s] 

{'loss': 0.2481, 'learning_rate': 4.949303429116055e-05, 'epoch': 0.05}


  1%|          | 8000/739695 [46:19<67:18:32,  3.02it/s] 

{'loss': 0.2492, 'learning_rate': 4.945923657723792e-05, 'epoch': 0.05}


  1%|          | 8500/739695 [49:10<67:08:09,  3.03it/s] 

{'loss': 0.2535, 'learning_rate': 4.9425438863315286e-05, 'epoch': 0.06}


  1%|          | 9000/739695 [52:00<67:12:18,  3.02it/s] 

{'loss': 0.2479, 'learning_rate': 4.939164114939266e-05, 'epoch': 0.06}


  1%|▏         | 9500/739695 [54:53<67:56:59,  2.99it/s] 

{'loss': 0.2467, 'learning_rate': 4.935784343547003e-05, 'epoch': 0.06}


  1%|▏         | 10000/739695 [57:42<67:33:27,  3.00it/s]

{'loss': 0.2449, 'learning_rate': 4.9324045721547395e-05, 'epoch': 0.07}


  1%|▏         | 10500/739695 [1:00:32<68:16:19,  2.97it/s]

{'loss': 0.2467, 'learning_rate': 4.929024800762477e-05, 'epoch': 0.07}


  1%|▏         | 11000/739695 [1:03:24<67:17:06,  3.01it/s] 

{'loss': 0.2389, 'learning_rate': 4.925645029370214e-05, 'epoch': 0.07}


  2%|▏         | 11500/739695 [1:06:16<67:04:24,  3.02it/s] 

{'loss': 0.2407, 'learning_rate': 4.9222652579779505e-05, 'epoch': 0.08}


  2%|▏         | 12000/739695 [1:09:07<67:35:59,  2.99it/s] 

{'loss': 0.2445, 'learning_rate': 4.918885486585688e-05, 'epoch': 0.08}


  2%|▏         | 12500/739695 [1:11:56<66:27:34,  3.04it/s] 

{'loss': 0.2427, 'learning_rate': 4.915505715193425e-05, 'epoch': 0.08}


  2%|▏         | 13000/739695 [1:14:49<67:38:12,  2.98it/s] 

{'loss': 0.2465, 'learning_rate': 4.9121259438011614e-05, 'epoch': 0.09}


  2%|▏         | 13500/739695 [1:17:40<66:26:13,  3.04it/s] 

{'loss': 0.2385, 'learning_rate': 4.908746172408898e-05, 'epoch': 0.09}


  2%|▏         | 14000/739695 [1:20:31<68:50:49,  2.93it/s] 

{'loss': 0.2346, 'learning_rate': 4.9053664010166356e-05, 'epoch': 0.09}


  2%|▏         | 14500/739695 [1:23:25<66:25:01,  3.03it/s] 

{'loss': 0.2374, 'learning_rate': 4.9019866296243724e-05, 'epoch': 0.1}


  2%|▏         | 15000/739695 [1:26:19<66:24:00,  3.03it/s] 

{'loss': 0.2359, 'learning_rate': 4.898606858232109e-05, 'epoch': 0.1}


  2%|▏         | 15500/739695 [1:29:10<65:49:51,  3.06it/s] 

{'loss': 0.2416, 'learning_rate': 4.8952270868398466e-05, 'epoch': 0.1}


  2%|▏         | 16000/739695 [1:32:01<66:32:31,  3.02it/s] 

{'loss': 0.2385, 'learning_rate': 4.8918473154475833e-05, 'epoch': 0.11}


  2%|▏         | 16500/739695 [1:34:54<66:50:17,  3.01it/s] 

{'loss': 0.2344, 'learning_rate': 4.88846754405532e-05, 'epoch': 0.11}


  2%|▏         | 17000/739695 [1:37:45<66:15:09,  3.03it/s] 

{'loss': 0.237, 'learning_rate': 4.8850877726630575e-05, 'epoch': 0.11}


  2%|▏         | 17500/739695 [1:40:36<66:54:35,  3.00it/s] 

{'loss': 0.2329, 'learning_rate': 4.881708001270794e-05, 'epoch': 0.12}


  2%|▏         | 18000/739695 [1:43:31<66:24:11,  3.02it/s] 

{'loss': 0.2367, 'learning_rate': 4.878328229878531e-05, 'epoch': 0.12}


  3%|▎         | 18500/739695 [1:46:21<67:08:15,  2.98it/s] 

{'loss': 0.2357, 'learning_rate': 4.874948458486268e-05, 'epoch': 0.13}


  3%|▎         | 18563/739695 [1:46:49<66:51:00,  3.00it/s] 

ValueError: expected sequence of length 128 at dim 1 (got 182)

In [11]:
trainer.save_model("./image-captioning-output")
tokenizer.save_pretrained("./image-captioning-output")

('./image-captioning-output\\tokenizer_config.json',
 './image-captioning-output\\special_tokens_map.json',
 './image-captioning-output\\vocab.json',
 './image-captioning-output\\merges.txt',
 './image-captioning-output\\added_tokens.json',
 './image-captioning-output\\tokenizer.json')

In [22]:

from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
import torch
from PIL import Image

model = VisionEncoderDecoderModel.from_pretrained("./image-captioning-output")
feature_extractor = ViTImageProcessor.from_pretrained("./image-captioning-output")
tokenizer = AutoTokenizer.from_pretrained("./image-captioning-output")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

max_length = 32
num_beams = 12
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

def predict_step(image_paths):
  images = []
  for image_path in image_paths:
    i_image = Image.open(image_path)
    if i_image.mode != "RGB":
      i_image = i_image.convert(mode="RGB")

    images.append(i_image)

  pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
  pixel_values = pixel_values.to(device)

  output_ids = model.generate(pixel_values, **gen_kwargs)

  preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
  preds = [pred.strip() for pred in preds]
  return preds




In [23]:
predict_step(['./images/Sunset_TB_360_00-00-03.jpg']) 

['A giraffe standing next to a tree.']