Image captioning[[image-captioning]]

[[open-in-colab]]

Image captioning is the task of predicting a caption for a given image. A common real-world application is helping visually impaired people navigate different situations. By describing images, image captioning makes content more accessible to people.

This guide will show you how to:

  • Fine-tune an image captioning model.
  • Use the fine-tuned model for inference.

Before you begin, make sure you have all the necessary libraries installed:

```bash
pip install transformers datasets evaluate -q
pip install jiwer -q
```

We encourage you to log in to your Hugging Face account so you can upload and share your model with the community. When prompted, enter your token to log in:

```python
from huggingface_hub import notebook_login

notebook_login()
```

Load the Pokémon BLIP captions dataset[[load-the-pokmon-blip-captions-dataset]]

Use the 🤗 Datasets library to load a dataset that consists of {image-caption} pairs. To create your own image captioning dataset in PyTorch, you can follow this notebook.

```python
from datasets import load_dataset

ds = load_dataset("lambdalabs/pokemon-blip-captions")
ds
```

```
DatasetDict({
    train: Dataset({
        features: ['image', 'text'],
        num_rows: 833
    })
})
```

The dataset has two features, `image` and `text`.

Many image captioning datasets contain multiple captions per image. In those cases, a common strategy is to randomly sample one of the available captions during training.
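A minimal sketch of that strategy in plain Python, assuming a hypothetical record layout with a `captions` list per image (the Pokémon dataset in this guide has a single `text` string per row, so it does not need this step). In a 🤗 Datasets pipeline, the same draw could happen inside the transform function so that a fresh caption is picked each time an example is loaded.

```python
import random

# Hypothetical records: each image comes with a list of reference captions.
records = [
    {"image_id": 0, "captions": ["a red dragon with wings", "an orange pokemon breathing fire"]},
    {"image_id": 1, "captions": ["a small yellow mouse"]},
]


def sample_caption(record, rng):
    """Pick one caption uniformly at random from the available references."""
    return rng.choice(record["captions"])


rng = random.Random(0)  # seeded so the example is reproducible
for epoch in range(2):
    # A new draw each time a record is visited, so different epochs
    # can pair the same image with different captions.
    picked = [sample_caption(r, rng) for r in records]
    print(epoch, picked)
```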

Use the [~datasets.Dataset.train_test_split] method to split the dataset's train split into a train and a test set:

```python
ds = ds["train"].train_test_split(test_size=0.1)
train_ds = ds["train"]
test_ds = ds["test"]
```
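Conceptually, `train_test_split` shuffles the rows and holds out a fraction of them. The plain-Python sketch below reproduces that idea for the 833-row dataset; `split_indices` is a hypothetical helper, and the rounding rule (test size rounded up, as in scikit-learn) is an assumption about the library's behavior. The split is random on every run unless you pass `seed=` to `train_test_split`.

```python
import math
import random


def split_indices(n_samples, test_size, seed=0):
    """Shuffle row indices and cut them into (train, test) lists.

    Assumes the scikit-learn convention of rounding the test size up.
    """
    n_test = math.ceil(test_size * n_samples)
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    return indices[n_test:], indices[:n_test]


train_idx, test_idx = split_indices(833, 0.1)
print(len(train_idx), len(test_idx))  # 749 84
```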

Let's visualize a couple of samples from the training set.

```python
from textwrap import wrap
import matplotlib.pyplot as plt
import numpy as np


def plot_images(images, captions):
    plt.figure(figsize=(20, 20))
    for i in range(len(images)):
        plt.subplot(1, len(images), i + 1)
        # Wrap long captions at 12 characters per line so they fit above each image.
        caption = "\n".join(wrap(captions[i], 12))
        plt.title(caption)
        plt.imshow(images[i])
        plt.axis("off")


sample_images_to_visualize = [np.array(train_ds[i]["image"]) for i in range(5)]
sample_captions = [train_ds[i]["text"] for i in range(5)]
plot_images(sample_images_to_visualize, sample_captions)
plt.show()
```
Sample training images

Preprocess the dataset[[preprocess-the-dataset]]

Since the dataset has two modalities (image and text), the preprocessing pipeline preprocesses both the images and the captions.

To do so, load the processor class associated with the model you are about to fine-tune.

```python
from transformers import AutoProcessor

checkpoint = "microsoft/git-base"
processor = AutoProcessor.from_pretrained(checkpoint)
```

Internally, the processor preprocesses the image (which includes resizing and pixel scaling) and tokenizes the caption.
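To make "resizing and pixel scaling" concrete, here is a NumPy-only sketch of the kind of steps an image processor applies. It is not the GIT processor's implementation: the 224×224 target size, the nearest-neighbour resize, and the 0.5 mean/std are placeholder assumptions. The real values come from the checkpoint's configuration; inspect `processor.image_processor` to see them.

```python
import numpy as np


def preprocess_image(image, size=224, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Resize, rescale, and normalize an HxWx3 uint8 image (illustrative only)."""
    h, w, _ = image.shape
    # Nearest-neighbour resize: map each output pixel to a source pixel.
    rows = np.arange(size) * h // size
    cols = np.arange(size) * w // size
    resized = image[rows][:, cols]
    # Pixel scaling: uint8 values in [0, 255] -> floats in [0, 1].
    scaled = resized.astype(np.float32) / 255.0
    # Per-channel normalization with the (assumed) mean and std.
    normalized = (scaled - np.array(mean, dtype=np.float32)) / np.array(std, dtype=np.float32)
    # Channels-first layout (C, H, W), as PyTorch vision models expect.
    return normalized.transpose(2, 0, 1)


dummy = np.random.default_rng(0).integers(0, 256, size=(475, 475, 3), dtype=np.uint8)
pixel_values = preprocess_image(dummy)
print(pixel_values.shape)  # (3, 224, 224)
```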

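```python
def transforms(example_batch):
    # Collect the raw images and caption strings from the batch.
    images = [x for x in example_batch["image"]]
    captions = [x for x in example_batch["text"]]
    # Preprocess the images and tokenize the captions, padding every caption to the same length.
    inputs = processor(images=images, text=captions, padding="max_length")
    # For causal language modeling, the labels are the caption token ids themselves.
    inputs.update({"labels": inputs["input_ids"]})
    return inputs


# Apply the transform on the fly whenever examples are accessed.
train_ds.set_transform(transforms)
test_ds.set_transform(transforms)
```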
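Setting `labels` to a copy of `input_ids` works because a causal language model shifts the labels when computing its loss: the logits at position *t* are scored against the token at position *t + 1*. The NumPy sketch below illustrates that shift with made-up token ids and a toy vocabulary; `shifted_causal_lm_loss` is a hypothetical helper and a simplified illustration, not GIT's actual loss implementation.

```python
import numpy as np


def shifted_causal_lm_loss(logits, labels):
    """Mean cross-entropy where position t predicts token t + 1.

    logits: (seq_len, vocab_size) scores; labels: (seq_len,) token ids.
    """
    pred_logits = logits[:-1]  # the last position has no next token to predict
    targets = labels[1:]       # the first token is never a prediction target
    # Numerically stable log-softmax over the vocabulary.
    shifted = pred_logits - pred_logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Log-probability assigned to each target token.
    picked = log_probs[np.arange(len(targets)), targets]
    return -picked.mean()


input_ids = np.array([101, 7, 3, 102])  # toy token ids; the labels are just a copy
vocab_size = 128
logits = np.zeros((len(input_ids), vocab_size))  # uniform predictions
loss = shifted_causal_lm_loss(logits, input_ids.copy())
print(round(float(loss), 4))  # 4.852, i.e. log(128) for uniform predictions
```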

Now that the dataset is ready, you can set up the model for fine-tuning.

Load a base model[[load-a-base-model]]

Load "microsoft/git-base" into an `AutoModelForCausalLM` object.

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(checkpoint)
```

Evaluate[[evaluate]]

Image captioning models are typically evaluated with the ROUGE score or the Word Error Rate (WER). This guide uses WER.
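WER counts the word-level substitutions (S), deletions (D), and insertions (I) needed to turn a prediction into its reference, divided by the number of words N in the reference: WER = (S + D + I) / N, so 0 is a perfect match and lower is better. The sketch below computes it with a word-level edit distance, summing errors and reference words over all pairs. `word_edit_distance` and `word_error_rate` are hypothetical helpers for illustration; in this guide the 🤗 Evaluate `wer` metric does this work.

```python
def word_edit_distance(reference, hypothesis):
    """Minimum substitutions + deletions + insertions to turn one word sequence into the other."""
    ref, hyp = reference.split(), hypothesis.split()
    # prev[j] = distance between the first i reference words and the first j hypothesis words
    prev = list(range(len(hyp) + 1))
    for i in range(1, len(ref) + 1):
        curr = [i] + [0] * len(hyp)
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            curr[j] = min(prev[j] + 1,         # deletion
                          curr[j - 1] + 1,     # insertion
                          prev[j - 1] + cost)  # substitution or match
        prev = curr
    return prev[-1]


def word_error_rate(references, predictions):
    """Corpus-level WER: total word edits divided by total reference words."""
    errors = sum(word_edit_distance(r, p) for r, p in zip(references, predictions))
    total = sum(len(r.split()) for r in references)
    return errors / total


refs = ["a drawing of a pink and blue pokemon"]
preds = ["a drawing of a pink pokemon"]
print(word_error_rate(refs, preds))  # 2 deletions / 8 reference words = 0.25
```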

We use the 🤗 Evaluate library for this. For potential limitations and other gotchas of WER, refer to this guide.

```python
from evaluate import load
import torch

wer = load("wer")


def compute_metrics(eval_pred):
    logits, labels = eval_pred
    # Take the highest-scoring token at each position as the prediction.
    predicted = logits.argmax(-1)
    # Convert token ids back to text, dropping special tokens such as padding.
    decoded_labels = processor.batch_decode(labels, skip_special_tokens=True)
    decoded_predictions = processor.batch_decode(predicted, skip_special_tokens=True)
    wer_score = wer.compute(predictions=decoded_predictions, references=decoded_labels)
    return {"wer_score": wer_score}
```

Train![[train!]]

You are now ready to start fine-tuning the model. You'll use the 🤗 [Trainer] for this.

First, define the training arguments using [TrainingArguments].

```python
from transformers import TrainingArguments, Trainer

model_name = checkpoint.split("/")[1]

training_args = TrainingArguments(
    output_dir=f"{model_name}-pokemon",
    learning_rate=5e-5,
    num_train_epochs=50,
    fp16=True,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=2,
    save_total_limit=3,
    eval_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=50,
    logging_steps=50,
    remove_unused_columns=False,
    push_to_hub=True,
    label_names=["labels"],
    load_best_model_at_end=True,
)
```
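A quick back-of-the-envelope check of what these arguments imply. The sketch assumes a single GPU and a 749-row training split (the exact size depends on how `test_size=0.1` is rounded), and it follows how [Trainer] counts optimizer updates (dataloader batches per epoch divided by the gradient accumulation steps). It also shows why `save_steps` matches `eval_steps`: with `load_best_model_at_end=True`, checkpoints have to line up with evaluations.

```python
import math

# Assumptions: one GPU and 749 training rows after the 90/10 split.
num_train_rows = 749
per_device_train_batch_size = 32
gradient_accumulation_steps = 2
num_train_epochs = 50
eval_steps = 50

effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps
batches_per_epoch = math.ceil(num_train_rows / per_device_train_batch_size)
optimizer_steps_per_epoch = batches_per_epoch // gradient_accumulation_steps
total_steps = optimizer_steps_per_epoch * num_train_epochs
num_evaluations = total_steps // eval_steps

print(effective_batch_size, optimizer_steps_per_epoch, total_steps, num_evaluations)
# 64 12 600 12
```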

Then pass the training arguments to 🤗 [Trainer] along with the datasets and the model.

```python
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    compute_metrics=compute_metrics,
)
```

To start training, simply call [~Trainer.train] on the [Trainer] object.

```python
trainer.train()
```

You should see the training loss drop smoothly as training progresses.

Once training is complete, share your model on the Hub with the [~Trainer.push_to_hub] method so everyone can use it:

```python
trainer.push_to_hub()
```

Inference[[inference]]

Load a sample Pokémon image to test the model.

```python
from PIL import Image
import requests

url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/pokemon.png"
image = Image.open(requests.get(url, stream=True).raw)
image
```
Test image

Prepare the image for the model.

```python
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)  # keep the model on the same device as the inputs

inputs = processor(images=image, return_tensors="pt").to(device)
pixel_values = inputs.pixel_values
```

Call [generate] and decode the predictions.

```python
generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(generated_caption)
```

```
a drawing of a pink and blue pokemon
```

It looks like the fine-tuned model generated a pretty good caption!
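For intuition about what `generate` does here: when a checkpoint's generation config enables neither sampling nor beam search (an assumption worth checking via `model.generation_config`), decoding is greedy. Starting from a beginning-of-sequence token, it repeatedly appends the highest-scoring next token until it emits an end-of-sequence token or reaches `max_length`. The sketch below shows that loop with a hypothetical toy scoring function standing in for the model; the vocabulary and token ids are made up.

```python
import numpy as np

BOS, EOS = 0, 1
VOCAB = ["<bos>", "<eos>", "a", "drawing", "of", "pokemon"]


def toy_next_token_logits(token_ids):
    """Hypothetical stand-in for the model: favors a fixed caption, then EOS."""
    script = [2, 3, 4, 2, 5, EOS]  # "a drawing of a pokemon" followed by EOS
    step = len(token_ids) - 1      # number of tokens generated so far
    logits = np.zeros(len(VOCAB))
    logits[script[min(step, len(script) - 1)]] = 1.0
    return logits


def greedy_generate(next_token_logits, max_length=50):
    ids = [BOS]
    while len(ids) < max_length:
        next_id = int(np.argmax(next_token_logits(ids)))  # pick the top-scoring token
        ids.append(next_id)
        if next_id == EOS:  # stop once end-of-sequence is produced
            break
    return ids


ids = greedy_generate(toy_next_token_logits)
caption = " ".join(VOCAB[i] for i in ids if i not in (BOS, EOS))
print(caption)  # a drawing of a pokemon
```

Lowering `max_length` truncates the caption, which is why the guide passes a generous `max_length=50`.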