In [1]:
import os
import wandb
import pandas as pd
from PIL import Image
from time import time
from evaluate import load
from torch.utils.data import Dataset
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import AutoProcessor, AutoModelForCausalLM

In [2]:
# Custom dataset class for the synthetic image-captioning dataset
class CustomDataset(Dataset):
    """Image/caption dataset backed by a directory of images and a caption table.

    Args:
        data_dir: Directory whose regular files are treated as the images.
        data: Caption table; passed unchanged to ``loader`` for every item.
        transform: Optional callable applied to each loaded image.
    """

    def __init__(self, data_dir, data, transform=None):
        self.data_dir = data_dir
        self.data = data
        self.transform = transform
        # os.listdir() returns entries in arbitrary order and also lists
        # sub-directories (e.g. Jupyter's .ipynb_checkpoints). Sort the names so
        # the index -> image mapping is reproducible, and keep regular files only
        # so that a stray directory cannot break image loading.
        candidates = (
            os.path.join(data_dir, name) for name in sorted(os.listdir(data_dir))
        )
        self.images = [path for path in candidates if os.path.isfile(path)]

    def __len__(self):
        """Return the number of image files found in ``data_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the ``(image, caption)`` pair for the image at position ``idx``."""
        image_path = self.images[idx]
        image, caption = loader(image_path, self.data)
        if self.transform:
            image = self.transform(image)
        return image, caption


def loader(path, data):
    """Load one image and look up its caption.

    Args:
        path: Path to the image file. Its basename is matched against the
            ``"image"`` column of ``data``.
        data: Table with ``"image"`` and ``"caption"`` columns, indexable via
            ``.loc`` (a pandas DataFrame in this notebook).

    Returns:
        Tuple ``(image, caption)``: the image with its pixel data loaded, and
        the caption of the first row whose ``"image"`` equals the basename.

    Raises:
        IndexError: If no row of ``data`` matches the image's basename.
    """
    file_name = os.path.basename(path)

    # Resolve the caption first so that a missing entry fails fast, with a
    # message naming the offending file, before any image I/O happens.
    matches = data.loc[data["image"] == file_name, "caption"].values
    if len(matches) == 0:
        raise IndexError(f"No caption found for image {file_name!r}")

    # Image.open() is lazy and keeps the file open; force the pixel data to be
    # read and release the file handle deterministically.
    with Image.open(path) as image:
        image.load()

    return image, matches[0]

In [3]:
# Root of the synthetic dataset: a captions.csv table next to an images/ folder
data_dir = "datasets/sin_dataset_img"

captions_csv = f"{data_dir}/captions.csv"
images_dir = f"{data_dir}/images"

# Caption table and the dataset that pairs every image file with its caption
data = pd.read_csv(captions_csv)
ds_sin = CustomDataset(images_dir, data)

In [4]:
# METEOR metric
def metric_meteor(predicted_captions, reference_captions):
    """Score predictions against references with the ``"meteor"`` metric.

    The metric object is obtained through ``load("meteor")`` on every call and
    whatever its ``compute`` method returns is handed back unchanged.

    Args:
        predicted_captions: Passed to ``compute`` as ``predictions``.
        reference_captions: Passed to ``compute`` as ``references``.
    """
    scorer = load("meteor")
    return scorer.compute(
        predictions=predicted_captions,
        references=reference_captions,
    )

In [5]:
# ROUGE metric
def metric_rouge(predicted_captions, reference_captions):
    """Score predictions against references with the ``"rouge"`` metric.

    The metric object is obtained through ``load("rouge")`` on every call and
    whatever its ``compute`` method returns is handed back unchanged.

    Args:
        predicted_captions: Passed to ``compute`` as ``predictions``.
        reference_captions: Passed to ``compute`` as ``references``.
    """
    scorer = load("rouge")
    return scorer.compute(
        predictions=predicted_captions,
        references=reference_captions,
    )

In [6]:
# WER metric
def metric_wer(predicted_captions, reference_captions):
    """Score predictions against references with the ``"wer"`` metric.

    The metric object is obtained through ``load("wer")`` on every call and
    whatever its ``compute`` method returns is handed back unchanged.

    Args:
        predicted_captions: Passed to ``compute`` as ``predictions``.
        reference_captions: Passed to ``compute`` as ``references``.
    """
    scorer = load("wer")
    return scorer.compute(
        predictions=predicted_captions,
        references=reference_captions,
    )

In [7]:
# Simple image-captioning model (original note: metrics were recorded on a V100 GPU)
# NOTE(review): this cell never moves the model or the inputs to a device
# explicitly - confirm which device the reported timings correspond to.
model_name = "microsoft/git-base"

# Using the run as a context manager finishes it even if the cell raises,
# so a failed evaluation does not leave a dangling W&B run behind.
with wandb.init(project="child_diary", group=model_name, job_type="base"):
    # Load the captioning model and its processor
    processor = AutoProcessor.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    predicted_captions = []
    reference_captions = []

    start_time = time()

    # Generate one caption per dataset image
    for image, caption in ds_sin:
        pixel_values = processor(images=image, return_tensors="pt").pixel_values
        generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
        pred_caption = processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0]

        predicted_captions.append(pred_caption)
        reference_captions.append(caption)

    end_time = time()

    rouge_result = metric_rouge(predicted_captions, reference_captions)
    # The METEOR metric returns a dict; log its scalar score so that "METEOR"
    # is recorded as a plain number like the other metrics.
    meteor_result = metric_meteor(predicted_captions, reference_captions)

    wandb.log(
        {
            "METEOR": meteor_result["meteor"],
            "ROUGE-1": rouge_result["rouge1"],
            "ROUGE-2": rouge_result["rouge2"],
            "ROUGE-L": rouge_result["rougeL"],
            "WER": metric_wer(predicted_captions, reference_captions),
            # Average wall-clock seconds per image (includes image loading)
            "Speed 1 image": (end_time - start_time) / len(ds_sin),
        }
    )

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mn-hilkovich[0m. Use [1m`wandb login --relogin`[0m to force relogin


[nltk_data] Downloading package wordnet to /home/nikolai/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to /home/nikolai/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /home/nikolai/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
ROUGE-1,▁
ROUGE-2,▁
ROUGE-L,▁
Speed 1 image,▁
WER,▁

0,1
ROUGE-1,0.25369
ROUGE-2,0.08123
ROUGE-L,0.23948
Speed 1 image,1.84151
WER,0.87461


In [8]:
# Medium image-captioning model (original note: metrics were recorded on a V100 GPU)
# NOTE(review): this cell never moves the model or the inputs to a device
# explicitly - confirm which device the reported timings correspond to.
model_name = "Salesforce/blip-image-captioning-large"

# Using the run as a context manager finishes it even if the cell raises,
# so a failed evaluation does not leave a dangling W&B run behind.
with wandb.init(project="child_diary", group=model_name, job_type="base"):
    # Load the captioning model and its processor
    processor = BlipProcessor.from_pretrained(model_name)
    model = BlipForConditionalGeneration.from_pretrained(model_name)

    predicted_captions = []
    reference_captions = []

    start_time = time()

    # Generate one caption per dataset image
    for image, caption in ds_sin:
        inputs = processor(image, return_tensors="pt")
        out = model.generate(**inputs)
        predicted_captions.append(processor.decode(out[0], skip_special_tokens=True))
        reference_captions.append(caption)

    end_time = time()

    rouge_result = metric_rouge(predicted_captions, reference_captions)
    # The METEOR metric returns a dict; log its scalar score so that "METEOR"
    # is recorded as a plain number like the other metrics.
    meteor_result = metric_meteor(predicted_captions, reference_captions)

    wandb.log(
        {
            "METEOR": meteor_result["meteor"],
            "ROUGE-1": rouge_result["rouge1"],
            "ROUGE-2": rouge_result["rouge2"],
            "ROUGE-L": rouge_result["rougeL"],
            "WER": metric_wer(predicted_captions, reference_captions),
            # Average wall-clock seconds per image (includes image loading)
            "Speed 1 image": (end_time - start_time) / len(ds_sin),
        }
    )

[nltk_data] Downloading package wordnet to /home/nikolai/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to /home/nikolai/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /home/nikolai/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
ROUGE-1,▁
ROUGE-2,▁
ROUGE-L,▁
Speed 1 image,▁
WER,▁

0,1
ROUGE-1,0.44335
ROUGE-2,0.19654
ROUGE-L,0.36868
Speed 1 image,2.94832
WER,0.80062


In [9]:
# Complex image-captioning model (original note: metrics were recorded on a V100 GPU)
# NOTE(review): this cell never moves the model or the inputs to a device
# explicitly - confirm which device the reported timings correspond to.
model_name = "abhijit2111/Pic2Story"

# Using the run as a context manager finishes it even if the cell raises,
# so a failed evaluation does not leave a dangling W&B run behind.
with wandb.init(project="child_diary", group=model_name, job_type="base"):
    # Load the captioning model and its processor
    processor = BlipProcessor.from_pretrained(model_name)
    model = BlipForConditionalGeneration.from_pretrained(model_name)

    predicted_captions = []
    reference_captions = []

    start_time = time()

    # Generate one caption per dataset image
    for image, caption in ds_sin:
        inputs = processor(image, return_tensors="pt")
        out = model.generate(**inputs)
        predicted_captions.append(processor.decode(out[0], skip_special_tokens=True))
        reference_captions.append(caption)

    end_time = time()

    rouge_result = metric_rouge(predicted_captions, reference_captions)
    # The METEOR metric returns a dict; log its scalar score so that "METEOR"
    # is recorded as a plain number like the other metrics.
    meteor_result = metric_meteor(predicted_captions, reference_captions)

    wandb.log(
        {
            "METEOR": meteor_result["meteor"],
            "ROUGE-1": rouge_result["rouge1"],
            "ROUGE-2": rouge_result["rouge2"],
            "ROUGE-L": rouge_result["rougeL"],
            "WER": metric_wer(predicted_captions, reference_captions),
            # Average wall-clock seconds per image (includes image loading)
            "Speed 1 image": (end_time - start_time) / len(ds_sin),
        }
    )

[nltk_data] Downloading package wordnet to /home/nikolai/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to /home/nikolai/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /home/nikolai/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


VBox(children=(Label(value='0.003 MB of 0.008 MB uploaded\r'), FloatProgress(value=0.35766670427620445, max=1.…

0,1
ROUGE-1,▁
ROUGE-2,▁
ROUGE-L,▁
Speed 1 image,▁
WER,▁

0,1
ROUGE-1,0.49204
ROUGE-2,0.22125
ROUGE-L,0.38709
Speed 1 image,3.76119
WER,0.89072
