In [None]:
%%capture
!pip install transformers datasets gdown
!pip install --upgrade accelerate

In [None]:
from datasets import Dataset, DatasetDict
import pandas
import transformers
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    T5ForConditionalGeneration,
    T5Config,
)

In [None]:
%%capture
!gdown "https://drive.google.com/uc?id=1e0vl0UwBiWyQgwWevul_sDhwtkb7ASWP"
!gdown "https://drive.google.com/uc?id=13YfUdSJDPvkn4_weszCkKIIHjnkpEQ5T"
!gdown "https://drive.google.com/uc?id=1AFsyj4RepOzCPWQ7CIBQoqOEqRi4KFEO"

In [None]:
# Load each pre-split CSV into a DataFrame, then wrap them in a DatasetDict
# keyed "test", "train", "valid" (in that insertion order).
data_train = pandas.read_csv("t5-train-laeme-data.csv")
data_valid = pandas.read_csv("t5-valid-laeme-data.csv")
data_test = pandas.read_csv("t5-test-laeme-data.csv")

frames_by_split = {"test": data_test, "train": data_train, "valid": data_valid}
dataset = DatasetDict(
    {split_name: Dataset.from_pandas(frame) for split_name, frame in frames_by_split.items()}
)

In [None]:
def preprocess(dataset, text_tokenizer=None):
    """Tokenize one batch of source/target strings for seq2seq training.

    Meant for ``DatasetDict.map(preprocess, batched=True)``. Despite its name,
    ``dataset`` is the batch handed over by ``map``: a mapping whose
    ``"input"`` and ``"target"`` entries are lists of strings.

    Args:
        dataset: batch with ``"input"`` and ``"target"`` string lists.
        text_tokenizer: tokenizer to apply. Defaults to the notebook-level
            ``tokenizer`` when omitted, which is how ``map`` calls it.

    Returns:
        The encoded inputs (``input_ids``, ``attention_mask``, padded to the
        longest item in the batch) plus ``labels``: the encoded targets with
        every padding position set to -100 so the loss ignores it.
    """
    tok = tokenizer if text_tokenizer is None else text_tokenizer

    # Inputs are encoded without special tokens (no EOS appended).
    # NOTE(review): confirm that dropping EOS on inputs is intended.
    model_inputs = tok(
        dataset["input"],
        padding=True,
        add_special_tokens=False,
        return_attention_mask=True,
        return_tensors="pt",
    )
    # `text_target=` replaces the deprecated `as_target_tokenizer()` context
    # manager. Only the targets are encoded in target mode now; the old code
    # also wrapped the input encoding in it.
    target_encoding = tok(
        text_target=dataset["target"],
        padding=True,
        add_special_tokens=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    # Padding positions (attention_mask != 1) become -100, the ignore index of
    # the seq2seq cross-entropy loss.
    model_inputs["labels"] = target_encoding["input_ids"].masked_fill(
        target_encoding["attention_mask"].ne(1), -100
    )
    return model_inputs

In [None]:
# Hugging Face Hub id of the pretrained checkpoint; the tokenizer, the model
# and the architecture config are all loaded from it.
CHECKPOINT: str = "google/byt5-small"

In [None]:
# Tokenizer that matches CHECKPOINT; `preprocess` reads it as a global.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=CHECKPOINT)

In [None]:
# Raw CSV columns are dropped after tokenization; only model inputs and labels remain.
RAW_TEXT_COLUMNS = ["input", "prefix", "id", "target"]
tokenized_dataset = dataset.map(
    function=preprocess,
    batched=True,
    remove_columns=RAW_TEXT_COLUMNS,
)

In [None]:
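# === How to decode token ids back to text with the ByT5 tokenizer === #
# Note: the sequences are padded, so with skip_special_tokens=False the pad
# tokens appear in the decoded output; pass True to drop them.

# decoded_text = tokenizer.decode(tokenized_dataset["test"]["input_ids"][1], skip_special_tokens=False)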

In [None]:
# False (default): fine-tune the pretrained checkpoint with its own architecture.
# True: build a smaller T5 from the reduced `config` below and train it from
# random initialisation. This is the only way that config takes effect, because
# loading pretrained weights fixes the checkpoint's architecture.
TRAIN_FROM_SCRATCH = False
SEED = 42

# Reduced architecture, starting from the checkpoint's config (other fields kept).
config = T5Config.from_pretrained(CHECKPOINT)
config.num_decoder_layers = 2
config.num_layers = 6
config.d_kv = 64
config.d_model = 256
config.d_ff = 512

# Seed before constructing the model so a from-scratch init is reproducible.
transformers.set_seed(SEED)
if TRAIN_FROM_SCRATCH:
    model = T5ForConditionalGeneration(config)
else:
    model = T5ForConditionalGeneration.from_pretrained(CHECKPOINT)

# Trailing `# N` comments record alternative values (presumably for a
# larger-scale run).
args = Seq2SeqTrainingArguments(
    output_dir="contents",
    seed=SEED,
    # NOTE(review): the trainer is built with compute_metrics=None, so evaluation
    # presumably reports loss only; these generation settings start to matter
    # once a metric is supplied.
    predict_with_generate=True,
    generation_num_beams=5,
    # `eval_strategy` replaces `evaluation_strategy`, which recent transformers
    # releases no longer accept.
    eval_strategy="steps",
    per_device_train_batch_size=32,  # 256
    per_device_eval_batch_size=64,  # 512
    num_train_epochs=5,  # 10
    gradient_accumulation_steps=1,  # 2
    learning_rate=1e-4,
    warmup_steps=1000,
    lr_scheduler_type="cosine",
    # NOTE(review): T5-family models are known to produce NaN/inf losses under
    # fp16; prefer bf16=True where the GPU supports it.
    fp16=True,
    logging_steps=1000,
    # load_best_model_at_end requires save_steps to be a multiple of eval_steps.
    save_steps=20000,  # 5000
    eval_steps=20000,  # 5000
    save_total_limit=2,
    load_best_model_at_end=True,
)

# Re-pads each training batch. Passing the model lets the collator derive
# decoder inputs from the labels.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [None]:
# Train on the "train" split and evaluate periodically on "valid".
# compute_metrics is None, so no task-specific metric is computed.
# NOTE(review): newer transformers releases deprecate `tokenizer=` in favour
# of `processing_class=`.
trainer_kwargs = dict(
    model=model,
    args=args,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["valid"],
    compute_metrics=None,
)
trainer = Seq2SeqTrainer(**trainer_kwargs)

In [None]:
# Run fine-tuning; the returned TrainOutput summarises the run.
train_output = trainer.train()
train_output