# Grammar Error Handling

This file implements the complete pipeline for grammar error handling. The pipeline consists of:
- Text preprocessing
- Sloleks component
- Grammar Error Detection (GED)
- Grammar Error Recognition (GER)
- Grammar Error Correction (GEC).

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-


import argparse
import functools
import string

import pandas as pd
from nltk.tokenize import word_tokenize, sent_tokenize
from transformers import (
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    TextClassificationPipeline,
    TokenClassificationPipeline,
    Text2TextGenerationPipeline,
)

from train_model.utils.logging import get_logger


In [None]:
# Module-level logger for the grammar error handling (GEH) pipeline;
# main() uses it to log intermediate model results and the final corrections
logger_geh = get_logger("GEH")


In [None]:
# Constants
# Path to the Sloleks lexicon CSV file read by sloleks_tool()
SLOLEKS = "../data/sloleks/sloleks.csv"
# Locations of the fine-tuned models (and their tokenizers) loaded in main():
# Grammar Error Detection, Grammar Error Recognition and Grammar Error Correction
GED = "./models/ged_model_finetuned/"
GER = "./models/ger_model_finetuned/"
GEC = "./models/gec_model_finetuned/"
# A sentence is accepted as correct only if its sloleks_tool() score is
# strictly below this value (and the GED model agrees)
MAX_ERROR_RATE_SLOLEKS = 0.1
# Maximum number of correction passes main() applies to a single sentence
MAX_NUMBER_OF_CORRECTIONS = 5


In [None]:
def preprocess(text):
    """
    Splits text into sentences and strips non-alphanumeric tokens from them.

    Each sentence is tokenized into words and only the tokens for which
    ``str.isalnum()`` is true are kept, joined by single spaces. This drops
    all punctuation tokens inside the sentence. If the original sentence ends
    with a character from ``string.punctuation``, that single character is
    appended back to the cleaned sentence.

    @param text: text that needs to be preprocessed
    @return: a list of cleaned sentences
    """
    cleaned_sentences = []

    for raw_sentence in sent_tokenize(text):
        # Keep only the purely alphanumeric tokens of the sentence
        kept_tokens = [
            token for token in word_tokenize(raw_sentence) if token.isalnum()
        ]
        cleaned = " ".join(kept_tokens)

        # Restore the closing punctuation mark of the original sentence
        closing_char = raw_sentence[-1]
        if closing_char in string.punctuation:
            cleaned = cleaned + closing_char

        cleaned_sentences.append(cleaned)

    return cleaned_sentences


In [None]:
@functools.lru_cache(maxsize=1)
def _load_sloleks_lexicon(path):
    """
    Loads the Sloleks lexicon and returns all of its cell values as a set.

    The result is cached so the CSV file is parsed only once per path instead
    of on every call of sloleks_tool().

    @param path: path to the Sloleks CSV file
    @return: a frozenset with every cell value of the CSV (header excluded)
    """
    sloleks = pd.read_csv(path, keep_default_na=False)
    return frozenset(sloleks.values.ravel())


def sloleks_tool(sentence):
    """
    For every word in sentence checks if it is contained in the Sloleks lexicon.

    Punctuation tokens are never counted as incorrect, and the first token of
    the sentence is also accepted when only its lowercase form is in the
    lexicon (sentence-initial capitalization). The number of unknown words is
    divided by the number of ALL tokens, punctuation included.

    @param sentence: a sentence that needs to be handled
    @result: the fraction (0.0 - 1.0) of tokens not found in the lexicon;
             0.0 when the sentence contains no tokens
    """
    lexicon = _load_sloleks_lexicon(SLOLEKS)

    # Tokenize the sentence into words
    tokenized_sentence = word_tokenize(sentence)
    if not tokenized_sentence:
        # Nothing to check; also avoids a division by zero below
        return 0.0

    number_of_incorrect_words = 0
    for index, word in enumerate(tokenized_sentence):
        if word in string.punctuation:
            # Ignore punctuations
            continue
        if word in lexicon:
            continue
        if index == 0 and word.lower() in lexicon:
            # First word of the sentence: ignore case
            continue
        number_of_incorrect_words += 1

    return number_of_incorrect_words / len(tokenized_sentence)


In [None]:
def main(text):
    """
    Grammatically handles the input text.

    The text is split into sentences by preprocess(). Each sentence is then
    checked with the Sloleks tool and the GED model; while it is not accepted
    as correct, it is replaced by the output of the GEC model, for at most
    MAX_NUMBER_OF_CORRECTIONS passes. The GER model output is only logged.

    @param text: input text that user wants to handle
    @return: grammatically handled text (handled sentences joined by spaces)
    """
    # Model, tokenizer and pipeline for GED
    model_ged = AutoModelForSequenceClassification.from_pretrained(GED)
    tokenizer_ged = AutoTokenizer.from_pretrained(GED)
    pipeline_ged = TextClassificationPipeline(
        model=model_ged, tokenizer=tokenizer_ged, task="Grammar Error Detection"
    )

    # Model, tokenizer and pipeline for GER
    model_ger = AutoModelForTokenClassification.from_pretrained(GER)
    tokenizer_ger = AutoTokenizer.from_pretrained(GER)
    pipeline_ger = TokenClassificationPipeline(
        model=model_ger, tokenizer=tokenizer_ger, task="Grammar Error Recognition"
    )

    # Model, tokenizer and pipeline for GEC
    model_gec = AutoModelForSeq2SeqLM.from_pretrained(GEC)
    tokenizer_gec = AutoTokenizer.from_pretrained(GEC)
    pipeline_gec = Text2TextGenerationPipeline(
        model=model_gec, tokenizer=tokenizer_gec, task="Grammar Error Correction"
    )

    # Text preprocessing
    unhandled_sentence_list = preprocess(text)
    handled_sentence_list = []

    # Repeat for every sentence
    for sentence in unhandled_sentence_list:
        handled_sentence = sentence

        # Repeat for the max allowed number of corrections
        for _ in range(MAX_NUMBER_OF_CORRECTIONS):
            # Sloleks tool
            sloleks_result = sloleks_tool(handled_sentence)

            # GED
            ged_result = pipeline_ged(handled_sentence)
            logger_geh.info(f"GED result {ged_result}")

            # The sentence is accepted only when both checks agree:
            # GED predicts LABEL_1 and the Sloleks error rate is low enough
            if (
                ged_result[0]["label"] == "LABEL_1"
                and sloleks_result < MAX_ERROR_RATE_SLOLEKS
            ):
                # Sentence is grammatically correct
                break

            # GER (result is only logged, it does not influence the correction)
            ger_result = pipeline_ger(handled_sentence)
            logger_geh.info(f"GER result {ger_result}")

            # GEC
            gec_result = pipeline_gec(handled_sentence)
            logger_geh.info(f"GEC result {gec_result}")

            # Update the sentence with the generated correction(s)
            handled_sentence = " ".join(
                corrected_sentence["generated_text"]
                for corrected_sentence in gec_result
            )

        # Add the sentence to the final output
        handled_sentence_list.append(handled_sentence)
        logger_geh.info(f"{sentence} -> {handled_sentence}")

    # Return the handled text instead of silently discarding it
    return " ".join(handled_sentence_list)


In [None]:
if __name__ == "__main__":
    # Command line entry point: handle the text passed with -t/--text
    parser = argparse.ArgumentParser(description="Grammar error handling")
    parser.add_argument(
        "-t",
        "--text",
        # The argument is the text itself, not a path to a file
        metavar="TEXT",
        required=True,
        type=str,
        help="Text you want to grammatically handle",
    )
    args = parser.parse_args()

    # Print the handled text when main() returns one, so that running the
    # script produces usable output on stdout and not only log records
    handled_text = main(args.text)
    if handled_text is not None:
        print(handled_text)
