In [1]:
from datasets import load_from_disk
from pathlib import Path
from functools import partial
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Processed-data folder: <parent of the current working directory>/data/processed.
PROCESSED_DATA_DIR = Path.cwd().parent.joinpath("data", "processed")

In [3]:
# Load the saved captioning dataset from the processed-data folder.
dataset = load_from_disk(PROCESSED_DATA_DIR.joinpath("captioning_dataset_augmented"))

In [4]:
# Spot-check one training record (index 4): shows its image, clip_score,
# file_name and the list of captions stored for it.
dataset["train"][4]

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=220x275>,
 'clip_score': 0.3598970293496636,
 'file_name': 'eric-fischl_birthday-boy.jpg',
 'captions': ['The artwork depicts a naked boy and woman, reclining on the bed of a room with red sheets, with a cityscape in the background window.',
  'The artwork depicts a naked boy and woman lying on the bed of a room with red sheets, with an urban landscape in the background window.',
  'The artwork depicts a naked boy and a naked woman who lean on the bed of a room with red leaves, with a cityscape in the back window.']}

In [4]:
import nlpaug.augmenter.word as naw
import torch

# Back-translation augmenter: text is round-tripped English -> German -> English
# using the two Helsinki-NLP opus-mt checkpoints named below.
# Run on the GPU when one is visible and fall back to the CPU otherwise, so this
# cell does not fail on a machine without CUDA (the CPU path will be much slower).
back_translation_device = "cuda" if torch.cuda.is_available() else "cpu"
back_translation_aug_de = naw.BackTranslationAug(
    from_model_name='Helsinki-NLP/opus-mt-en-de',
    to_model_name='Helsinki-NLP/opus-mt-de-en',
    device=back_translation_device,
    batch_size=16,
)

Downloading (…)lve/main/config.json: 100%|██████████| 1.33k/1.33k [00:00<00:00, 7.95MB/s]
Downloading pytorch_model.bin: 100%|██████████| 298M/298M [00:06<00:00, 47.5MB/s] 
Downloading (…)neration_config.json: 100%|██████████| 293/293 [00:00<00:00, 914kB/s]
Downloading (…)lve/main/config.json: 100%|██████████| 1.38k/1.38k [00:00<00:00, 3.73MB/s]
Downloading pytorch_model.bin: 100%|██████████| 298M/298M [00:06<00:00, 47.9MB/s] 
Downloading (…)neration_config.json: 100%|██████████| 293/293 [00:00<00:00, 812kB/s]
Downloading (…)okenizer_config.json: 100%|██████████| 42.0/42.0 [00:00<00:00, 168kB/s]
Downloading (…)olve/main/source.spm: 100%|██████████| 768k/768k [00:00<00:00, 1.76MB/s]
Downloading (…)olve/main/target.spm: 100%|██████████| 797k/797k [00:00<00:00, 1.79MB/s]
Downloading (…)olve/main/vocab.json: 100%|██████████| 1.27M/1.27M [00:00<00:00, 2.30MB/s]
Downloading (…)okenizer_config.json: 100%|██████████| 42.0/42.0 [00:00<00:00, 161kB/s]
Downloading (…)olve/main/source.spm: 100%|██

In [5]:
# Disabled snippet, kept for reference only. When enabled it wraps each single
# `caption` string in a one-element list stored under `captions`, then drops the
# original `caption` column. The records displayed in this notebook already carry
# a `captions` list, so it is not needed for the dataset loaded above.
# def captions_as_lists(examples):
#     captions_as_lists = [[caption] for caption in examples["caption"]]
#     examples["captions"] = captions_as_lists
#     return examples

# dataset = dataset.map(captions_as_lists, batched=True, remove_columns=["caption"])

In [6]:
# Show the first training record to check its fields and caption list.
dataset["train"][0]

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=220x275>,
 'clip_score': 0.3412533427114651,
 'file_name': 'jamie-wyeth_pumpkinhead-self-portrait-1972.jpg',
 'captions': ['The artwork depicts a surreal self-portrait of the artist with a pumpkin for a head, standing in a desolated field.',
  'The artwork depicts a surrealist self-portrait of the artist with a pumpkin for a head, standing in a sorry field.']}

In [7]:
def augment_captions(examples, augmenter, prefix="The artwork depicts "):
    """Append one augmented caption to every example in a batch.

    The first caption of each example is used as the source text. A leading
    ``prefix`` is stripped before augmentation and put back in front of the
    augmented text, so the augmenter only ever sees the descriptive part.

    Args:
        examples: Batch mapping whose ``"captions"`` entry is a list with one
            list of caption strings per example.
        augmenter: Object exposing ``augment(list_of_str)`` that returns one
            augmented string per input string.
        prefix: Fixed lead-in shared by the captions. Only a *leading*
            occurrence is removed; the same words appearing later in a
            caption are left untouched.

    Returns:
        The same ``examples`` mapping, with ``"captions"`` replaced by new
        lists that each carry one extra caption. Examples with no captions
        are returned unchanged because there is nothing to augment.

    Raises:
        ValueError: If the augmenter returns a different number of texts than
            it was given.
    """
    caption_lists = examples["captions"]

    # Only examples that already have a caption can provide a source text.
    source_rows = [row for row, captions in enumerate(caption_lists) if captions]
    if not source_rows:
        return examples

    sources = []
    for row in source_rows:
        first_caption = caption_lists[row][0]
        # Strip the prefix only at the start; str.replace would also delete
        # any later occurrence inside the sentence.
        if prefix and first_caption.startswith(prefix):
            first_caption = first_caption[len(prefix):]
        sources.append(first_caption)

    augmented = augmenter.augment(sources)
    # NOTE(review): presumably some augmenter versions hand back a bare string
    # for a single input; wrap it so it is not indexed character by character.
    if isinstance(augmented, str):
        augmented = [augmented]
    if len(augmented) != len(sources):
        raise ValueError(
            f"augmenter returned {len(augmented)} texts for {len(sources)} captions"
        )

    # Build fresh lists instead of appending to the caller's lists in place.
    extended = [list(captions) for captions in caption_lists]
    for row, new_text in zip(source_rows, augmented):
        extended[row].append(prefix + new_text)

    examples["captions"] = extended
    return examples

In [8]:
# Add one German back-translated caption to every example, batch by batch.
# Each run appends a further caption per example, so re-running this cell
# grows the caption lists again.
# Disabled alternative using `back_translation_aug_jap`, which is not defined
# in the cells shown here:
# dataset = dataset.map(partial(augment_captions, augmenter=back_translation_aug_jap), batched=True)
dataset = dataset.map(
    function=partial(augment_captions, augmenter=back_translation_aug_de),
    batched=True,
)

                                                                     

In [9]:
# Write the augmented dataset to disk.
# NOTE(review): this relative path resolves against the current working
# directory, whereas the dataset was loaded from
# PROCESSED_DATA_DIR / "captioning_dataset_augmented". Presumably deliberate
# (it avoids writing over the directory being read) — confirm, and move the
# result under PROCESSED_DATA_DIR if that is where it is expected to live.
dataset.save_to_disk("captioning_dataset_augmented")

                                                                                                

In [10]:
# Spot-check the caption list of training example 55000 after the augmentation
# step, to compare the original caption with the appended variants.
dataset["train"][55000]["captions"]

['The artwork depicts a woman in a red dress with a white parasol, standing in front of a blue and pink background with swirling patterns and flowers.',
 'The artwork depicts a woman in a red dress with a white umbrella, standing in front of a blue and pink background with swirling motifs and flowers.',
 'The artwork depicts a woman in a red dress with a white parasol standing in front of a blue and pink background with swirling patterns and flowers.']