<a href="https://colab.research.google.com/github/khadkechetan/spell_checker/blob/main/grammar/grammar_check_correct_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Name: Chetan Khadke 

Email: khadkechetan@gmail.com

# <center> Grammar Check & Correct

- This demo uses two models: one to check grammar and one to correct text that fails the check.
    - Grammar checking uses `textattack/roberta-base-CoLA`.
    - Grammar correction uses `pszemraj/flan-t5-large-grammar-synthesis`.


---




In [None]:
pip install -U -q transformers accelerate


In [None]:
import re
from tqdm.auto import tqdm
import pprint as pp

In [None]:
#@title define functions
# Generation settings passed to the corrector pipeline
params = {
    'max_length': 1024,
    'repetition_penalty': 1.05,
    'num_beams': 4
}


def split_text(text: str) -> list:
    """Split text into sentences and group them into batches of up to two."""
    # Split on spaces that follow '.' or '?' and precede a capital letter
    sentences = re.split(r"(?<=[^A-Z].[.?]) +(?=[A-Z])", text)
    sentence_batches = []
    temp_batch = []
    for i, sentence in enumerate(sentences):
        temp_batch.append(sentence)
        # Flush once the batch holds two sentences, or at the final sentence.
        # The last sentence is detected by index, so a repeated sentence
        # cannot end a batch early.
        if len(temp_batch) >= 2 or i == len(sentences) - 1:
            sentence_batches.append(temp_batch)
            temp_batch = []
    return sentence_batches


def correct_text(text: str, checker, corrector, separator: str = " ") -> str:
    """Check each sentence batch and rewrite only those the checker does not accept."""
    sentence_batches = split_text(text)
    corrected_text = []
    for batch in tqdm(
        sentence_batches, total=len(sentence_batches), desc="correcting text.."
    ):
        raw_text = " ".join(batch)
        results = checker(raw_text)
        # Rewrite unless the checker returns LABEL_1 with a score of at least 0.9
        if results[0]["label"] != "LABEL_1" or results[0]["score"] < 0.9:
            # Correct the text using the text2text-generation pipeline
            corrected_batch = corrector(raw_text, **params)
            corrected_text.append(corrected_batch[0]["generated_text"])
            print("-----------------------------")  # marks a batch that was rewritten
        else:
            corrected_text.append(raw_text)
    return separator.join(corrected_text)
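
To see how the splitting regex behaves, here is a minimal sketch that reuses the same pattern. `batch_sentences` is a hypothetical helper, not part of the notebook; it chunks sentences in pairs, which is what the size check in `split_text` amounts to. The sample sentences are made up for illustration.

```python
import re

# Same pattern as in split_text: split on spaces that follow '.' or '?'
# and precede a capital letter, unless the character two positions before
# the punctuation is itself a capital (this skips abbreviations like "Dr.").
SENTENCE_BOUNDARY = re.compile(r"(?<=[^A-Z].[.?]) +(?=[A-Z])")


def batch_sentences(text, batch_size=2):
    """Hypothetical helper: split into sentences, then chunk them in pairs."""
    sentences = SENTENCE_BOUNDARY.split(text)
    return [sentences[i:i + batch_size] for i in range(0, len(sentences), batch_size)]


sample = "I has a apple. She go to school. They is late? We was happy. It rain today."
batches = batch_sentences(sample)
for batch in batches:
    print(batch)

# "Dr." does not trigger a split because 'D' is a capital letter
abbrev = "See Dr. Smith today. He is nice."
print(batch_sentences(abbrev))
```

Note that the character class `[.?]` does not include `!`, so a sentence ending in an exclamation mark stays attached to the one that follows it.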

In [None]:
from transformers import pipeline

# Initialize the text-classification pipeline (grammar check)
checker = pipeline("text-classification", "textattack/roberta-base-CoLA")

# Initialize the text2text-generation pipeline (grammar correction)
corrector = pipeline(
    "text2text-generation",
    "pszemraj/flan-t5-large-grammar-synthesis",
    device=0,  # run on the first GPU
)
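
The gating rule in `correct_text` can be tried without downloading either model. The sketch below is an illustration only: `fake_checker` and `fake_corrector` are hypothetical stand-ins that mimic the output shapes the notebook code reads (`[{"label", "score"}]` and `[{"generated_text"}]`), and `needs_correction` and `check_and_correct` are names introduced here. It assumes, as the notebook code does, that `LABEL_1` means the text was judged acceptable.

```python
def needs_correction(result, threshold=0.9):
    """True unless the checker says LABEL_1 with a score of at least `threshold`."""
    return result["label"] != "LABEL_1" or result["score"] < threshold


def fake_checker(text):
    # Hypothetical stub: flags any text containing the misspelling "helth"
    label = "LABEL_0" if "helth" in text else "LABEL_1"
    return [{"label": label, "score": 0.95}]


def fake_corrector(text, **generate_kwargs):
    # Hypothetical stub: fixes one known misspelling, ignores generation settings
    return [{"generated_text": text.replace("helth", "health")}]


def check_and_correct(text, checker, corrector):
    """Run the checker, and call the corrector only when the check fails."""
    result = checker(text)[0]
    if needs_correction(result):
        return corrector(text)[0]["generated_text"]
    return text


print(check_and_correct("my helth is fine.", fake_checker, fake_corrector))
print(check_and_correct("All good here.", fake_checker, fake_corrector))
```

A confident `LABEL_1` leaves the text untouched, while a low-confidence `LABEL_1` or any other label sends it to the corrector, so raising the threshold makes the demo rewrite more often.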


In [None]:
#@title enter text to correct
raw_text = "my helth is not well, I hv to tak 2 day leave." #@param {type:"string"}
pp.pprint(raw_text)

'my helth is not well, I hv to tak 2 day leave.'


In [None]:
# Correct the text using the correct_text function
corrected_text = correct_text(raw_text, checker, corrector)
pp.pprint(corrected_text)

-----------------------------
'my health is not well, I have to take 2 days leave.'
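
To list exactly which words the corrector changed, one option is a word-level diff with the standard library's `difflib`. This is a sketch, not part of the original demo; `word_changes` is a hypothetical helper, and the two strings are copied from the cells above.

```python
import difflib

raw = "my helth is not well, I hv to tak 2 day leave."
corrected = "my health is not well, I have to take 2 days leave."


def word_changes(before, after):
    """Return (old, new) pairs for every span of words that differs."""
    a, b = before.split(), after.split()
    matcher = difflib.SequenceMatcher(a=a, b=b, autojunk=False)
    changes = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag != "equal":
            changes.append((" ".join(a[i1:i2]), " ".join(b[j1:j2])))
    return changes


for old, new in word_changes(raw, corrected):
    print(f"{old!r} -> {new!r}")
```

For this input the diff reports four single-word replacements: `helth`, `hv`, `tak`, and `day`.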
