### Training the model

In [None]:
#! pip install transformers[sentencepiece] datasets
#! pip install sacrebleu sentencepiece
#! pip install huggingface_hub
# pip install tensorflow==2.9
# pip freeze > requirements.txt
# from huggingface_hub import notebook_login
# notebook_login()
# !apt install git-lfs
# !git config --global user.email "you@example.com"
# !git config --global user.name "Your Name"

In [None]:
import transformers
import pandas as pd
from datasets import Dataset, load_from_disk, load_dataset
from evaluate import load
import datasets
import random
import pandas as pd
from IPython.display import display, HTML
from transformers import AutoTokenizer
import spacy
from transformers.keras_callbacks import KerasMetricCallback
import numpy as np
from transformers.keras_callbacks import PushToHubCallback
from tensorflow.keras.callbacks import TensorBoard
from transformers import TFAutoModelForSeq2SeqLM, DataCollatorForSeq2Seq
from transformers import AdamWeightDecay
import tensorflow as tf
from copy import deepcopy

def csv_to_dataset(filename, source_lang, target_lang, pos_tags=False, wa_tags=False, store=False):
    """Read a parallel-corpus CSV and return it as an 80/20 train/test split.

    Args:
        filename: Path of the CSV file, read with ``pd.read_csv``.
        source_lang: Name of the CSV column holding the source sentences; also
            used as the key of the source text in each ``translation`` dict.
        target_lang: Same as ``source_lang`` but for the target sentences.
        pos_tags: When true, also build a ``pos`` column from the CSV columns
            ``pos_<source_lang>`` and ``pos_<target_lang>``.
        wa_tags: When true, copy the CSV column ``wa`` into the result.
        store: Accepted but not used by this function.

    Returns:
        The result of ``Dataset.from_pandas(...).train_test_split(test_size=0.2)``.
    """
    frame = pd.read_csv(filename)

    def pair_up(src_column, trg_column):
        # One {source_lang: ..., target_lang: ...} dict per CSV row.
        return [
            {source_lang: src, target_lang: trg}
            for src, trg in zip(frame[src_column], frame[trg_column])
        ]

    columns = pd.DataFrame()
    columns['translation'] = pair_up(source_lang, target_lang)
    if pos_tags:
        columns['pos'] = pair_up(f'pos_{source_lang}', f'pos_{target_lang}')
    if wa_tags:
        columns['wa'] = frame['wa']

    full_dataset = Dataset.from_pandas(columns)
    return full_dataset.train_test_split(test_size=0.2)

# Load the train/test split that was previously saved to disk.
loaded_dataset = load_from_disk('fr_dataset_split.hf')
# NOTE: a bare `loaded_dataset.remove_columns(['pos', 'wa'])` used to follow here.
# `remove_columns` returns a new dataset instead of modifying the receiver, so
# the discarded call changed nothing and has been removed. The 'wa' column
# therefore stays available for the explicit `remove_columns(['wa'])` that is
# performed right before the dataset is mapped.

# Pretrained checkpoint to fine-tune, the BLEU metric and the matching tokenizer.
model_checkpoint = "Helsinki-NLP/opus-mt-en-fr"
metric = load("sacrebleu")
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# mBART tokenizers need the language pair set explicitly. Their language codes
# use an underscore ("en_XX", "fr_XX"); the hyphenated forms are not valid codes.
if "mbart" in model_checkpoint:
    tokenizer.src_lang = "en_XX"
    tokenizer.tgt_lang = "fr_XX"


# spaCy pipelines used to POS-tag decoded tokens (English source, French target).
en_pos_sp = spacy.load("en_core_web_sm")
fr_pos_sp = spacy.load('fr_core_news_sm')

# `prefix` is prepended to every source sentence during preprocessing.
# Only the T5 checkpoints get the task instruction; every other checkpoint
# gets an empty prefix. ("t5-large" was previously misspelled "t5-larg", so
# that checkpoint silently received no prefix.)
if model_checkpoint in ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]:
    prefix = "translate English to French: "
else:
    prefix = ""

def token_to_pos(token, lang):
    """Return the spaCy part-of-speech ID for a single tokenizer token.

    The token is decoded back to text with the module-level ``tokenizer``, the
    text is run through the spaCy pipeline for ``lang``, and the ``pos``
    attribute (spaCy's integer ID, not the string ``pos_``) of the last spaCy
    token is returned.

    Args:
        token: Token ID (or IDs) accepted by ``tokenizer.decode``.
        lang: ``'en'`` or ``'fr'``; selects the spaCy pipeline.

    Returns:
        The integer ``pos`` of the last spaCy token, or ``-1`` when the decoded
        text produces no spaCy tokens.

    Raises:
        ValueError: If ``lang`` is neither ``'en'`` nor ``'fr'``.
    """
    if lang == 'en':
        tagger = en_pos_sp
    elif lang == 'fr':
        tagger = fr_pos_sp
    else:
        # Previously this fell through and failed with an UnboundLocalError.
        raise ValueError(f"Unsupported language {lang!r}; expected 'en' or 'fr'")

    decoded = list(tagger(tokenizer.decode(token)))
    return decoded[-1].pos if decoded else -1

def get_pos_tags(tokenized_sent, lang):
    """Map every token of a tokenized sentence to its POS ID.

    Args:
        tokenized_sent: Iterable of token IDs.
        lang: Language code forwarded to ``token_to_pos``.

    Returns:
        A list with one ``token_to_pos`` result per input token, in order.
    """
    return [token_to_pos(token_id, lang) for token_id in tokenized_sent]

def metric_fn(eval_predictions):
    """Compute the BLEU score and mean generation length for an evaluation run.

    Args:
        eval_predictions: A ``(predictions, labels)`` pair of token-ID arrays.

    Returns:
        A dict with ``"bleu"`` (the ``"score"`` entry returned by the
        module-level ``metric``) and ``"gen_len"`` (mean number of non-pad
        tokens per prediction).
    """
    preds, labels = eval_predictions
    pad_id = tokenizer.pad_token_id

    # Length of each prediction, not counting padding.
    prediction_lens = []
    for pred in preds:
        prediction_lens.append(np.count_nonzero(pred != pad_id))

    hypotheses = tokenizer.batch_decode(preds, skip_special_tokens=True)

    # Label positions masked with -100 are swapped for the pad token so that
    # decoding emits nothing for them.
    labels = np.where(labels != -100, labels, pad_id)
    references = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Light post-processing: trim whitespace; the metric is given one list of
    # references per prediction.
    hypotheses = [text.strip() for text in hypotheses]
    references = [[text.strip()] for text in references]

    scores = metric.compute(predictions=hypotheses, references=references)
    return {"bleu": scores["score"], "gen_len": np.mean(prediction_lens)}

# Truncation limits passed to the tokenizer as `max_length` (source / target).
max_input_length = max_target_length = 128

# Optimisation hyper-parameters.
batch_size, num_train_epochs = 16, 1
learning_rate, weight_decay = 2e-5, 0.01

# Language pair: keys of each "translation" dict in the dataset.
source_lang, target_lang = "en", "fr"

# Extra annotations.
# pos_tags must be True in this cell: preprocess_function only adds the POS
# columns when it is truthy.
pos_tags = True
# wa_type selects the word-alignment encoding ('trg-ids', 'sums' or 'mult');
# None disables it.
wa_tags, wa_type = False, None

def encode_wa(tokenized_input, tokenized_target, wa, wa_type):
    """Encode a word alignment as one integer per source position.

    Args:
        tokenized_input: Sequence of source token IDs.
        tokenized_target: Sequence of target token IDs.
        wa: Alignment string of whitespace-separated ``"src-trg"`` index pairs,
            e.g. ``"0-0 1-2 2-1"``. If a source index occurs several times the
            last pair wins.
        wa_type: Encoding to apply at every aligned source position ``k`` with
            target position ``v``:
            ``'trg-ids'`` -> ``target[v]`` (unaligned positions are 0),
            ``'sums'``    -> ``input[k] + target[v]``,
            ``'mult'``    -> ``input[k] * target[v]``.
            For ``'sums'`` and ``'mult'`` unaligned positions keep ``input[k]``.

    Returns:
        A new list of ``len(tokenized_input)`` integers; the inputs are not
        modified.

    Raises:
        ValueError: If ``wa_type`` is not one of the supported encodings, or if
            an alignment pair is not of the form ``"int-int"``.
    """
    combiners = {
        'trg-ids': lambda src_id, trg_id: trg_id,
        'sums': lambda src_id, trg_id: src_id + trg_id,
        'mult': lambda src_id, trg_id: src_id * trg_id,
    }
    if wa_type not in combiners:
        # Fail loudly instead of silently returning None for a mistyped mode.
        raise ValueError(
            f"Unsupported wa_type {wa_type!r}; expected one of {sorted(combiners)}"
        )
    combine = combiners[wa_type]

    wa_dict = {int(src): int(trg) for src, trg in map(lambda x: x.split('-'), wa.split())}
    n = len(tokenized_input)
    m = len(tokenized_target)

    # 'trg-ids' starts from zeros; the other modes start from a copy of the input.
    wa_emb = [0] * n if wa_type == 'trg-ids' else list(tokenized_input)

    for k, v in wa_dict.items():
        if k >= n or v >= m:
            # Skip only this pair (e.g. it points past a truncated sequence).
            # Later pairs may still be in range, so do not stop the loop.
            continue
        wa_emb[k] = combine(wa_emb[k], tokenized_target[v])
    return wa_emb

def preprocess_function(dataset):
    """Tokenize a batch of translation pairs and attach optional annotations.

    Reads the module-level settings ``prefix``, ``source_lang``,
    ``target_lang``, ``max_input_length``, ``max_target_length``, ``pos_tags``
    and ``wa_type``.

    Args:
        dataset: Batch with a ``"translation"`` column of dicts keyed by
            language code, plus a ``"wa"`` column when ``wa_type`` is set.

    Returns:
        The tokenizer output for the source sentences, extended with
        ``"labels"`` (target token IDs) and, depending on the settings,
        ``"pos"``/``"target_pos"`` and ``"wa"``.
    """
    sources, targets = [], []
    for pair in dataset["translation"]:
        sources.append(prefix + pair[source_lang])
        targets.append(pair[target_lang])

    model_inputs = tokenizer(sources, max_length=max_input_length, truncation=True)

    # Targets are tokenized with the tokenizer switched to target mode.
    with tokenizer.as_target_tokenizer():
        target_encoding = tokenizer(targets, max_length=max_target_length, truncation=True)
    model_inputs["labels"] = target_encoding["input_ids"]

    if pos_tags:
        # One POS ID per token, for both sides of the pair.
        model_inputs['pos'] = [
            get_pos_tags(token_ids, 'en') for token_ids in model_inputs['input_ids']
        ]
        model_inputs['target_pos'] = [
            get_pos_tags(token_ids, 'fr') for token_ids in model_inputs['labels']
        ]

    if wa_type:
        aligned_rows = zip(model_inputs['input_ids'], model_inputs['labels'], dataset["wa"])
        model_inputs['wa'] = [
            encode_wa(src_ids, trg_ids, alignment, wa_type)
            for src_ids, trg_ids, alignment in aligned_rows
        ]
    return model_inputs

# Drop the raw word-alignment column; with wa_type unset, preprocess_function
# never reads it.
split_dataset = loaded_dataset.remove_columns(['wa'])

# Tokenize every split and add the POS columns.
pos_anno_dataset = split_dataset.map(preprocess_function, batched=True)

model_with_pos = TFAutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
# NOTE(review): prepare_tf_dataset presumably keeps only the columns the model
# accepts as inputs, so the 'pos'/'target_pos' columns would not reach the
# model - confirm that this is intended.

# Dynamic padding for training/validation batches.
data_collator_pos = DataCollatorForSeq2Seq(tokenizer, model=model_with_pos, return_tensors="tf")
# Generation batches are padded to a multiple of 128 so that XLA-compiled
# generation sees few distinct input shapes and does not recompile per batch.
generation_data_collator_pos = DataCollatorForSeq2Seq(tokenizer, model=model_with_pos, return_tensors="tf", pad_to_multiple_of=128)

train_dataset = model_with_pos.prepare_tf_dataset(
    pos_anno_dataset["train"],
    batch_size=batch_size,
    shuffle=True,
    collate_fn=data_collator_pos,
)

validation_dataset = model_with_pos.prepare_tf_dataset(
    pos_anno_dataset["test"],
    batch_size=batch_size,
    shuffle=False,
    collate_fn=data_collator_pos,
)

# Dataset used by the metric callback for generate(); it must use the
# fixed-shape collator (it previously used data_collator_pos, leaving
# generation_data_collator_pos unused).
generation_dataset = model_with_pos.prepare_tf_dataset(
    pos_anno_dataset["test"],
    batch_size=8,
    shuffle=False,
    collate_fn=generation_data_collator_pos,
)

# No loss is passed to compile(), so the model's own loss is used.
optimizer = AdamWeightDecay(learning_rate=learning_rate, weight_decay_rate=weight_decay)
model_with_pos.compile(optimizer=optimizer)

# Computes BLEU / generation length on generated text (via metric_fn) during
# training; generation is capped at the same length as the tokenized targets.
metric_callback = KerasMetricCallback(
    metric_fn=metric_fn, eval_dataset=generation_dataset, predict_with_generate=True, use_xla_generation=True,
    generate_kwargs={"max_length": max_target_length}
)
tensorboard_callback = TensorBoard(log_dir="./translation_model_save/logs")

callbacks = [metric_callback, tensorboard_callback]

# Use the configured epoch count instead of a second hard-coded literal, so that
# changing num_train_epochs actually takes effect.
model_with_pos.fit(
    train_dataset, validation_data=validation_dataset, epochs=num_train_epochs, callbacks=callbacks
)

### Generating translations

In [None]:
# Load the table of sentences to translate; the generation cell reads its
# "sentences" column.
vital = pd.read_csv("translations_ft_wa.csv")
# Preview the first rows to check that the dataframe opens correctly
# (displayed because it is the last expression of the cell).
vital.head()

In [None]:
# Translate every sentence with the fine-tuned model, one sentence at a time,
# and collect the decoded outputs in row order.
outputs_pos = []

for input_sentence in vital["sentences"]:
    # Tokenize a batch of one sentence as NumPy arrays and generate up to 128 tokens.
    tokenized_sentence = tokenizer([input_sentence], return_tensors='np')
    out = model_with_pos.generate(**tokenized_sentence, max_length=128)
    # NOTE(review): decoding is done in target-tokenizer mode; whether decode()
    # actually depends on that mode for this tokenizer is not visible here.
    with tokenizer.as_target_tokenizer():
        output_sentence = tokenizer.decode(out[0], skip_special_tokens=True)
        print(output_sentence)
        outputs_pos.append(output_sentence)

# Store the translations next to their source sentences.
vital["ft_pos"] = outputs_pos

In [None]:
# Write the table, including the new "ft_pos" column, to a new CSV file;
# index=False leaves the dataframe's row index out of the file.
vital.to_csv("translations_ft_wa_pos.csv", index=False)