# Train grammar error correction model


In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch
import numpy as np

from utils.logging import get_logger
from utils.metrics import (
    metric_bleu,
    metric_sacrebleu,
    metric_gleu,
    metric_chrf,
    metric_meteor,
    metric_ter,
    metric_cer,
    metric_wer,
)
from helper_model import GEC_MODEL, GEC_DIRECTORY
from prepare_gec_dataset import load_gec_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
)


In [None]:
# Get logger
# NOTE: despite its name, train_gec_model is the logger object; every progress
# message in this notebook is emitted through its .info() method.
train_gec_model = get_logger("Train GEC model")


In [None]:
# Constants
# Pretrained checkpoint from which both the tokenizer and the model are loaded
MODEL_CHECKPOINT = "cjvt/t5-sl-small"
# Name used in the log messages, as the training output directory and as the
# path the final model is saved to
MODEL_NAME = GEC_MODEL
# Per-device batch size used for both training and evaluation
BATCH_SIZE = 16  # 32 is presumably an alternative value that was tried - TODO confirm


In [None]:
# Read the GEC dataset from its directory
dataset = load_gec_dataset(GEC_DIRECTORY)
train_gec_model.info(f"{MODEL_NAME} dataset read")

# Instantiate the fast tokenizer and the sequence-to-sequence model from the
# pretrained checkpoint (MODEL_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT, use_fast=True)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINT)
train_gec_model.info(f"{MODEL_NAME} model and tokenizer initialized")


In [None]:
# Keep the code device-agnostic: use the GPU when CUDA is available, otherwise
# fall back to the CPU
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Move the model onto the selected device
model = model.to(DEVICE)


In [None]:
def tokenize_function(data):
    """
    Tokenize the source sentences and their corrected targets with the
    tokenizer that belongs to our model, producing the inputs in the format
    the model expects.

    NB: the targets are passed as text_target so that they are tokenized as
    labels.
    NB: truncation is enabled so that an input longer than what the model can
    handle is cut down to the maximum length accepted by the model.
    NB: the function is applied with batched processing to leverage the full
    benefit of the fast tokenizer.

    @param data: the data we want to tokenize; its "source" and "target"
    entries are read
    @return: the output of the tokenizer for the given data
    """
    return tokenizer(
        data["source"],
        text_target=data["target"],
        truncation=True,
    )


In [None]:
# Run the tokenize function over the dataset in batches
encoded_dataset = dataset.map(tokenize_function, batched=True)

# Training configuration, grouped by purpose
args = Seq2SeqTrainingArguments(
    # Where the checkpoints are written
    output_dir=MODEL_NAME,
    # Evaluate and save once per epoch, and reload the checkpoint with the
    # highest GLEU score when training ends
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="gleu",
    greater_is_better=True,
    # Optimisation schedule
    num_train_epochs=5,
    learning_rate=1e-5,
    weight_decay=0.01,
    # Batching and memory usage
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    auto_find_batch_size=True,
    deepspeed="./deepspeed_config.json",
    # Generate sequences during evaluation so text metrics can be computed
    predict_with_generate=True,
    # Logging integrations
    report_to="all",
)

# Data collator, which pads the inputs and the labels of every batch
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)


In [None]:
def compute_metrics(eval_pred):
    """
    Decode the generated predictions and the reference labels back to text and
    score the predictions with every metric imported from utils.metrics.

    @param eval_pred: a (predictions, labels) pair of token id arrays; when the
    predictions are wrapped in a tuple, its first element is used
    @return: dictionary with the bleu, sacrebleu, gleu, chrf, meteor, ter, cer
    and wer scores and the mean number of non-padding tokens per prediction
    """
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # -100 is the ignore index used for padding; it is not a vocabulary id and
    # cannot be decoded, so replace it with the pad token in both arrays.
    pad_token_id = tokenizer.pad_token_id
    predictions = np.where(predictions != -100, predictions, pad_token_id)
    labels = np.where(labels != -100, labels, pad_token_id)

    decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post processing
    decoded_predictions = [text.strip() for text in decoded_predictions]
    # CER and WER take one reference string per prediction ...
    flat_references = [text.strip() for text in decoded_labels]
    # ... while the other metrics take a list of references per prediction
    nested_references = [[reference] for reference in flat_references]

    scores = {}

    # (name in the returned dictionary, metric object, key of the metric result)
    for score_name, metric, result_key in (
        ("bleu", metric_bleu, "bleu"),
        ("sacrebleu", metric_sacrebleu, "score"),
        ("gleu", metric_gleu, "google_bleu"),
        ("chrf", metric_chrf, "score"),
        ("meteor", metric_meteor, "meteor"),
        ("ter", metric_ter, "score"),
    ):
        scores[score_name] = metric.compute(
            predictions=decoded_predictions, references=nested_references
        )[result_key]

    # The results of these two metrics are reported as returned
    scores["cer"] = metric_cer.compute(
        predictions=decoded_predictions, references=flat_references
    )
    scores["wer"] = metric_wer.compute(
        predictions=decoded_predictions, references=flat_references
    )

    # Length of each prediction without its padding tokens
    prediction_lengths = [
        np.count_nonzero(prediction != pad_token_id) for prediction in predictions
    ]
    scores["prediction_lengths"] = np.mean(prediction_lengths)

    return scores


In [None]:
def model_init():
    """
    Load a fresh sequence-to-sequence model from the pretrained checkpoint
    (MODEL_CHECKPOINT) and move it to DEVICE.

    @return: a model, which we will fine tune
    """
    fresh_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINT)
    return fresh_model.to(DEVICE)


In [None]:
# Hyperparameter search
def _gleu_objective(metrics):
    """
    Objective of the hyperparameter search: the GLEU score of an evaluation.

    Without an explicit objective the trainer sums every reported metric, which
    would mix scores where higher is better with the error rates (ter, cer,
    wer) and the prediction length. GLEU is the same metric that is used to
    pick the best checkpoint (metric_for_best_model).

    @param metrics: the evaluation metrics, with their names prefixed by "eval_"
    @return: the value the search should maximize
    """
    return metrics["eval_gleu"]


trainer = Seq2SeqTrainer(
    model_init=model_init,
    args=args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
train_gec_model.info("{} trainer initialized".format(MODEL_NAME))

# Find most optimal parameters for our model
train_gec_model.info("{} GEC hyperparameter search started".format(MODEL_NAME))
hyperparameters = trainer.hyperparameter_search(
    direction="maximize",
    compute_objective=_gleu_objective,
)
train_gec_model.info("{} GEC hyperparameter search ended".format(MODEL_NAME))


In [None]:
# Copy the hyperparameters found by the search onto the training arguments
for parameter_name, parameter_value in hyperparameters.hyperparameters.items():
    setattr(trainer.args, parameter_name, parameter_value)
train_gec_model.info(f"Hyperparameters: {hyperparameters.hyperparameters}")

# Fine tune the model for the GEC task
train_gec_model.info(f"{MODEL_NAME} model training started")
trainer.train()
train_gec_model.info(f"{MODEL_NAME} model training ended")

# Log an evaluation to check that the trainer reloaded the best model and not
# the last one
train_gec_model.info(trainer.evaluate())

# Save the model so it can be reloaded with from_pretrained()
trainer.save_model(MODEL_NAME)
train_gec_model.info(f"{MODEL_NAME} model saved")
