In [1]:
from transformers import MarianMTModel, MarianTokenizer
import pandas as pd
from tqdm import tqdm
import os
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import json
from dotenv import load_dotenv
import numpy as np

In [2]:
# Load the variables defined in the project's private .env file into the environment.
load_dotenv("../../private_data/.env")

# PARENT is the relative path from this notebook up to the root of the project.
PARENT = "./../../"


def _dataset_path(env_var):
    """Return PARENT joined with the project-relative path stored in `env_var`.

    Args:
        env_var: name of the environment variable holding the relative path.

    Returns:
        The string PARENT + value of the environment variable.

    Raises:
        KeyError: if `env_var` is not set. Raised explicitly so the failure
            names the missing variable instead of surfacing as a
            "can only concatenate str (not NoneType)" TypeError.
    """
    relative_path = os.getenv(env_var)
    if relative_path is None:
        raise KeyError(
            f"Environment variable {env_var!r} is not set; "
            "check that ../../private_data/.env exists and defines it."
        )
    return PARENT + relative_path


MERGED_DATA_VALIDATION_SET = _dataset_path("MERGED_DATA_VALIDATION_SET")
MERGED_DATA_TESTING_SET = _dataset_path("MERGED_DATA_TESTING_SET")
MERGED_DATA_TRAINING_SET = _dataset_path("MERGED_DATA_TRAINING_SET")

In [3]:
# Number of captions handed to the translation model per batch.
BATCH_SIZE = 4

# Stage switches: a flag set to True makes the matching translation cell
# run its model; False skips that translation step.
RUN_EN_NL = True
RUN_FR_EN = False


In [4]:
# Read the validation split from the CSV path resolved in the configuration
# cell, then show the resulting frame as the cell output.
validation_captions = pd.read_csv(filepath_or_buffer=MERGED_DATA_VALIDATION_SET)
validation_captions

Unnamed: 0,recordID,category,focus,caption,length_tokenization
0,5002,Tableau,luminosity,"Une fête, un repas, une foule en intérieur, de...",30
1,10900,Tableau,luminosity,"Une femme joue du piano, elle lit une partitio...",36
2,5510,Dessin,luminosity,"Un adolescente est assise contre le mur, elle ...",28
3,4576,Tableau,content,"Un paysage avec un ciel nuageux, un lac en fon...",22
4,7689,Tableau,emotion,"Des hommes barbus et moustachus qui discutent,...",26
...,...,...,...,...,...
949,3904,Tableau,content,Une jeune femme avec longue robe blanche desce...,27
950,4072,Tableau,content,"Portrait d'une jeune femme qui lit un livre, v...",24
951,1294,Tableau,emotion,Un paysage désolé avec des arbres sans feuille...,31
952,482,Tableau,colors,Portrait d'une jeune femme avec une robe noire...,33


In [5]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [6]:
class TranslationDataset(Dataset):
    """Expose a sequence of source texts through the torch `Dataset` protocol."""

    def __init__(self, source_texts):
        """Keep a reference to `source_texts`; the sequence is not copied."""
        self.source_texts = source_texts

    def __getitem__(self, idx):
        """Return the text stored at position `idx`, unchanged."""
        text = self.source_texts[idx]
        return text

    def __len__(self):
        """Return how many source texts are held."""
        count = len(self.source_texts)
        return count
    
def translate_batch(source_captions, model, tokenizer, device, batch_size=16, source_df=None):
    """Translate captions batch by batch and return them inside a copy of their source frame.

    Args:
        source_captions: sequence of caption strings to translate, one per row
            of `source_df` and in the same order.
        model: translation model; only its `generate(**inputs)` method is used.
        tokenizer: tokenizer used both to encode each batch and to decode the
            generated token ids.
        device: device the tokenized inputs are moved to before generation.
            NOTE(review): the model is assumed to already live on this device.
        batch_size: number of captions translated per forward pass.
        source_df: DataFrame the captions were taken from. When omitted, the
            notebook-level `validation_captions` frame is used, which is what
            this function always did before the parameter existed.

    Returns:
        A copy of `source_df` whose "caption" column holds the translations and
        whose "length_tokenization" column is reset to NaN.

    Raises:
        ValueError: if the number of captions differs from the number of rows
            in `source_df`. Checked up front so the mismatch is reported before
            the (slow) translation loop instead of after it.
    """
    if source_df is None:
        # Backward-compatible default: fall back to the notebook global.
        source_df = validation_captions

    if len(source_captions) != len(source_df):
        raise ValueError(
            f"Got {len(source_captions)} captions for a frame of "
            f"{len(source_df)} rows; they must match one-to-one."
        )

    dataset = TranslationDataset(source_captions)
    # shuffle=False keeps translations aligned with the rows of source_df.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    translated_texts = []

    for batch in tqdm(dataloader):
        inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True)
        inputs = inputs.to(device)

        # Perform translation
        translated = model.generate(**inputs)

        # Decode translations, dropping special tokens such as padding.
        translated_texts.extend(
            tokenizer.decode(tokens, skip_special_tokens=True) for tokens in translated
        )

    resulting_df = source_df.copy()
    resulting_df["caption"] = translated_texts
    # The stored token counts described the original captions, not the translations.
    resulting_df["length_tokenization"] = np.nan
    return resulting_df


In [7]:
# Resolve the cache location first so a missing environment variable fails
# immediately rather than after the translation has already run.
validation_set_en_path = PARENT + os.getenv("MERGED_DATA_VALIDATION_SET_EN")

if RUN_FR_EN:
    # Keep the checkpoint name and the loaded model under separate names.
    model_name_fr_en = "Helsinki-NLP/opus-mt-fr-en"
    tokenizer_fr_en = MarianTokenizer.from_pretrained(model_name_fr_en)
    model_fr_en = MarianMTModel.from_pretrained(model_name_fr_en).to(device)

    validation_captions_EN = translate_batch(
        list(validation_captions["caption"]),
        model_fr_en,
        tokenizer_fr_en,
        device,
        batch_size=BATCH_SIZE,
    )
    validation_captions_EN.to_csv(validation_set_en_path, index=False)

# Always reload from the CSV so the frame is the same whether the translation
# was just computed or comes from an earlier run.
validation_captions_EN = pd.read_csv(validation_set_en_path)
validation_captions_EN

Unnamed: 0,recordID,category,focus,caption,length_tokenization
0,5002,Tableau,luminosity,"A party, a meal, a crowd inside, children play...",
1,10900,Tableau,luminosity,"A woman plays the piano, she reads a score, th...",
2,5510,Dessin,luminosity,"A teenage girl is sitting against the wall, lo...",
3,4576,Tableau,content,"A landscape with a cloudy sky, a lake in the b...",
4,7689,Tableau,emotion,"Bearded and mustachu men who discuss, costumes...",
...,...,...,...,...,...
949,3904,Tableau,content,A young woman with a long white dress goes dow...,
950,4072,Tableau,content,"Portrait of a young woman reading a book, blin...",
951,1294,Tableau,emotion,"A sorry landscape with leafless trees, a cloud...",
952,482,Tableau,colors,Portrait of a young woman with a black dress a...,


In [8]:
# Resolve the cache location first so a missing environment variable fails
# immediately rather than after the translation has already run.
validation_set_nl_path = PARENT + os.getenv("MERGED_DATA_VALIDATION_SET_NL")

if RUN_EN_NL:
    # Original author's note: OPUS cannot be used to translate French to Dutch
    # directly, so the English captions are translated into Dutch instead.
    model_name_en_nl = "Helsinki-NLP/opus-mt-en-nl"
    tokenizer_en_nl = MarianTokenizer.from_pretrained(model_name_en_nl)
    model_en_nl = MarianMTModel.from_pretrained(model_name_en_nl).to(device)

    validation_captions_NL = translate_batch(
        list(validation_captions_EN["caption"]),
        model_en_nl,
        tokenizer_en_nl,
        device,
        batch_size=BATCH_SIZE,
    )
    validation_captions_NL.to_csv(validation_set_nl_path, index=False)
else:
    # Translation skipped: load the previously saved result so that
    # `validation_captions_NL` is defined for the cells that follow.
    validation_captions_NL = pd.read_csv(validation_set_nl_path)

100%|██████████| 239/239 [01:38<00:00,  2.43it/s]


In [9]:
# Show the Dutch-translated validation captions as this cell's output.
validation_captions_NL

Unnamed: 0,recordID,category,focus,caption,length_tokenization
0,5002,Tableau,luminosity,"Een feestje, een maaltijd, een menigte binnen,...",
1,10900,Tableau,luminosity,"Een vrouw speelt piano, ze leest een partituur...",
2,5510,Dessin,luminosity,"Een tienermeisje zit tegen de muur, kijkt omho...",
3,4576,Tableau,content,"Een landschap met een bewolkte hemel, een meer...",
4,7689,Tableau,emotion,"Bebaarde en snorren mannen die bespreken, kost...",
...,...,...,...,...,...
949,3904,Tableau,content,Een jonge vrouw met een lange witte jurk gaat ...,
950,4072,Tableau,content,Portret van een jonge vrouw die een boek leest...,
951,1294,Tableau,emotion,"Een zielig landschap met bladloze bomen, een b...",
952,482,Tableau,colors,Portret van een jonge vrouw met een zwarte jur...,
