In [2]:
import pandas as pd
import numpy as np
from tqdm import tqdm
tqdm.pandas()

from transformers import TextDataset
from torch.utils.data import DataLoader, Dataset
import torch
from torch.nn.functional import cross_entropy
from transformers import AutoModelForSeq2SeqLM, T5TokenizerFast, AdamW

from datasets import load_dataset

In [None]:
# для отладки проблем с cuda
# import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [3]:
# Первая часть данных, датасет с Rucode
train = pd.read_csv('/kaggle/input/spell-check-dataset-rucode/train.csv', usecols=['corrupted_text', 'correct_text'])

train = train.rename({'corrupted_text':'incorrect', 'correct_text':'correct'}, axis=1)
train = train.sample(26000).reset_index(drop=True)

In [4]:
train_dataset = load_dataset("ai-forever/spellcheck_benchmark", "RUSpellRU", split='train[:]')
eval_dataset = load_dataset("ai-forever/spellcheck_benchmark", "RUSpellRU", split='test[:]')

Downloading builder script:   0%|          | 0.00/9.07k [00:00<?, ?B/s]

Downloading and preparing dataset russian_spellcheck_benchmark/RUSpellRU to /root/.cache/huggingface/datasets/ai-forever___russian_spellcheck_benchmark/RUSpellRU/0.0.1/87bfa2950c7b82ec565b4da426533874af24d25436ad08dba065a45895ad3945...


Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/1.95M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.69M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Dataset russian_spellcheck_benchmark downloaded and prepared to /root/.cache/huggingface/datasets/ai-forever___russian_spellcheck_benchmark/RUSpellRU/0.0.1/87bfa2950c7b82ec565b4da426533874af24d25436ad08dba065a45895ad3945. Subsequent calls will reuse this data.


In [5]:
first_df = pd.DataFrame({'incorrect':train_dataset['source'], 'correct': train_dataset['correction']})
second_df = pd.DataFrame({'incorrect':eval_dataset['source'], 'correct': eval_dataset['correction']})

full_df = pd.concat([first_df, second_df, train])
full_df['incorrect'] = 'Spell correct: ' + full_df['incorrect']

In [23]:
first_df.sample(5)

Unnamed: 0,incorrect,correct
534,"Как я понял, это изначально сделано.",Как я понял это изначально сделано
823,"Я вам покажу, как будить Колдуна!",Я вам покажу как будить Колдуна
1198,"Я думала, кстате, что даже у маленбких утят пе...",Я думала кстати что даже у маленьких утят перь...
578,"А в 8.30 я пошагала на пары, а вечером дописыв...",А в 8.30 я пошагала на пары а вечером дописыва...
1285,"Мужык был в транче и тихо повторял - "" меня в ...",Мужик был в трансе и тихо повторял меня в горо...


In [45]:
MODEL_NAME = 'google/mt5-small'
tokenizer = T5TokenizerFast.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

Downloading tokenizer_config.json:   0%|          | 0.00/82.0 [00:00<?, ?B/s]

Downloading spiece.model:   0%|          | 0.00/4.31M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

Downloading config.json:   0%|          | 0.00/553 [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Downloading pytorch_model.bin:   0%|          | 0.00/1.20G [00:00<?, ?B/s]

Downloading generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

In [7]:
class SpellCheckDataset(Dataset):
    def __init__(self, df:pd.DataFrame):
        self.inputs = df.iloc[:, 0].values
        self.outputs = df.iloc[:, 1].values

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        input_text = self.inputs[idx]
        output_text = self.outputs[idx]
        return tokenizer(input_text, padding=False, truncation=True), tokenizer(output_text, padding=False, truncation=True)
    
def collate_fn(batch):
    input_texts = [item[0] for item in batch]
    output_texts = [item[1] for item in batch]

    input_texts = tokenizer.pad(input_texts, return_tensors="pt", padding='longest')
    output_texts = tokenizer.pad(output_texts, return_tensors="pt", padding='longest')

    return input_texts, output_texts

def bucketed_data_loader(dataframe: pd.DataFrame, batch_size:int):
    dataframe = dataframe.copy()
    dataframe['length'] = dataframe.iloc[:, 0].progress_apply(lambda x: len(tokenizer.encode(x, truncation=True)))
    dataframe.sort_values(by='length', inplace=True)

    dataset = SpellCheckDataset(dataframe.drop(columns=['length']))

    sampler = torch.utils.data.BatchSampler(
        torch.utils.data.SequentialSampler(dataset),
        batch_size=batch_size,
        drop_last=False
    )

    return DataLoader(dataset, batch_sampler=sampler, collate_fn=collate_fn)

batch_size = 8
loader = bucketed_data_loader(full_df, batch_size)

  0%|          | 0/30008 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
100%|██████████| 30008/30008 [00:03<00:00, 9296.67it/s] 


In [8]:
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = AdamW(model.parameters(), lr=5e-5)



In [9]:
num_epochs = 40
model.train()

for epoch in range(num_epochs):
    for batch_idx, batch in tqdm(enumerate(loader)):
        optimizer.zero_grad()

        input_ids = batch[0]["input_ids"].to(device)
        attention_mask = batch[0]["attention_mask"].to(device)
        labels = batch[1]["input_ids"].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        if batch_idx % 240 == 9:
            print(f"Epoch: {epoch + 1}, Batch: {batch_idx}/{len(loader)}, Loss: {loss.item()}")

0it [00:00, ?it/s]You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
11it [00:05,  5.48it/s]

Epoch: 1, Batch: 9/3751, Loss: 1.9971446990966797


251it [00:41,  6.70it/s]

Epoch: 1, Batch: 249/3751, Loss: 0.43817901611328125


491it [01:18,  6.85it/s]

Epoch: 1, Batch: 489/3751, Loss: 0.4459899663925171


731it [01:53,  6.79it/s]

Epoch: 1, Batch: 729/3751, Loss: 0.9319515228271484


971it [02:29,  6.75it/s]

Epoch: 1, Batch: 969/3751, Loss: 0.5056430697441101


1211it [03:05,  6.85it/s]

Epoch: 1, Batch: 1209/3751, Loss: 0.3211566209793091


1451it [03:41,  6.87it/s]

Epoch: 1, Batch: 1449/3751, Loss: 0.45441192388534546


1691it [04:16,  6.87it/s]

Epoch: 1, Batch: 1689/3751, Loss: 0.2680402100086212


1931it [04:52,  6.85it/s]

Epoch: 1, Batch: 1929/3751, Loss: 0.20242580771446228


2171it [05:28,  6.91it/s]

Epoch: 1, Batch: 2169/3751, Loss: 0.39576393365859985


2411it [06:03,  6.77it/s]

Epoch: 1, Batch: 2409/3751, Loss: 0.5102012753486633


2651it [06:39,  6.72it/s]

Epoch: 1, Batch: 2649/3751, Loss: 0.30002525448799133


2891it [07:14,  6.80it/s]

Epoch: 1, Batch: 2889/3751, Loss: 0.5882150530815125


3131it [07:50,  6.60it/s]

Epoch: 1, Batch: 3129/3751, Loss: 0.3034571707248688


3371it [08:26,  6.44it/s]

Epoch: 1, Batch: 3369/3751, Loss: 0.3213995099067688


3611it [09:04,  5.99it/s]

Epoch: 1, Batch: 3609/3751, Loss: 0.0950116217136383


3751it [09:32,  6.56it/s]
11it [00:01,  6.76it/s]

Epoch: 2, Batch: 9/3751, Loss: 0.4633176028728485


251it [00:37,  6.64it/s]

Epoch: 2, Batch: 249/3751, Loss: 0.28708985447883606


491it [01:12,  6.97it/s]

Epoch: 2, Batch: 489/3751, Loss: 0.30857500433921814


731it [01:48,  6.83it/s]

Epoch: 2, Batch: 729/3751, Loss: 0.4761451482772827


971it [02:24,  6.81it/s]

Epoch: 2, Batch: 969/3751, Loss: 0.253158301115036


1211it [02:59,  6.77it/s]

Epoch: 2, Batch: 1209/3751, Loss: 0.07676158100366592


1451it [03:35,  6.78it/s]

Epoch: 2, Batch: 1449/3751, Loss: 0.21788163483142853


1691it [04:11,  6.84it/s]

Epoch: 2, Batch: 1689/3751, Loss: 0.16486069560050964


1931it [04:47,  6.80it/s]

Epoch: 2, Batch: 1929/3751, Loss: 0.092677041888237


2171it [05:22,  6.90it/s]

Epoch: 2, Batch: 2169/3751, Loss: 0.23089006543159485


2411it [05:58,  6.84it/s]

Epoch: 2, Batch: 2409/3751, Loss: 0.3159630298614502


2651it [06:33,  6.74it/s]

Epoch: 2, Batch: 2649/3751, Loss: 0.260448157787323


2891it [07:09,  6.51it/s]

Epoch: 2, Batch: 2889/3751, Loss: 0.3076982796192169


3131it [07:45,  6.77it/s]

Epoch: 2, Batch: 3129/3751, Loss: 0.12892726063728333


3371it [08:21,  6.78it/s]

Epoch: 2, Batch: 3369/3751, Loss: 0.1899082362651825


3611it [08:58,  5.98it/s]

Epoch: 2, Batch: 3609/3751, Loss: 0.05206269025802612


3751it [09:26,  6.62it/s]
11it [00:01,  6.89it/s]

Epoch: 3, Batch: 9/3751, Loss: 0.13047991693019867


251it [00:37,  6.88it/s]

Epoch: 3, Batch: 249/3751, Loss: 0.19360317289829254


491it [01:12,  6.86it/s]

Epoch: 3, Batch: 489/3751, Loss: 0.2343810647726059


731it [01:47,  6.83it/s]

Epoch: 3, Batch: 729/3751, Loss: 0.21103784441947937


971it [02:23,  6.84it/s]

Epoch: 3, Batch: 969/3751, Loss: 0.15399006009101868


1211it [02:58,  6.83it/s]

Epoch: 3, Batch: 1209/3751, Loss: 0.07900543510913849


1451it [03:34,  6.91it/s]

Epoch: 3, Batch: 1449/3751, Loss: 0.38898175954818726


1691it [04:10,  6.43it/s]

Epoch: 3, Batch: 1689/3751, Loss: 0.047756556421518326


1931it [04:45,  6.88it/s]

Epoch: 3, Batch: 1929/3751, Loss: 0.11855590343475342


2171it [05:20,  6.87it/s]

Epoch: 3, Batch: 2169/3751, Loss: 0.2142293006181717


2411it [05:56,  6.84it/s]

Epoch: 3, Batch: 2409/3751, Loss: 0.37622424960136414


2651it [06:31,  6.83it/s]

Epoch: 3, Batch: 2649/3751, Loss: 0.11150596290826797


2891it [07:07,  6.52it/s]

Epoch: 3, Batch: 2889/3751, Loss: 0.1761285364627838


3131it [07:43,  6.83it/s]

Epoch: 3, Batch: 3129/3751, Loss: 0.1598057895898819


3371it [08:19,  6.69it/s]

Epoch: 3, Batch: 3369/3751, Loss: 0.0985126942396164


3611it [08:57,  6.01it/s]

Epoch: 3, Batch: 3609/3751, Loss: 0.05450660362839699


3751it [09:24,  6.64it/s]
11it [00:01,  6.78it/s]

Epoch: 4, Batch: 9/3751, Loss: 0.10883665084838867


251it [00:37,  6.86it/s]

Epoch: 4, Batch: 249/3751, Loss: 0.08462006598711014


491it [01:12,  6.76it/s]

Epoch: 4, Batch: 489/3751, Loss: 0.07649488747119904


731it [01:48,  6.45it/s]

Epoch: 4, Batch: 729/3751, Loss: 0.14134010672569275


971it [02:23,  6.90it/s]

Epoch: 4, Batch: 969/3751, Loss: 0.037552446126937866


1211it [02:59,  6.88it/s]

Epoch: 4, Batch: 1209/3751, Loss: 0.05176844820380211


1451it [03:34,  6.54it/s]

Epoch: 4, Batch: 1449/3751, Loss: 0.13337944447994232


1691it [04:10,  6.64it/s]

Epoch: 4, Batch: 1689/3751, Loss: 0.015207069925963879


1931it [04:45,  6.73it/s]

Epoch: 4, Batch: 1929/3751, Loss: 0.03202793374657631


2171it [05:21,  6.61it/s]

Epoch: 4, Batch: 2169/3751, Loss: 0.052407391369342804


2411it [05:57,  6.50it/s]

Epoch: 4, Batch: 2409/3751, Loss: 0.1362016499042511


2651it [06:32,  6.71it/s]

Epoch: 4, Batch: 2649/3751, Loss: 0.0790378674864769


2891it [07:08,  6.71it/s]

Epoch: 4, Batch: 2889/3751, Loss: 0.13644085824489594


3131it [07:43,  6.74it/s]

Epoch: 4, Batch: 3129/3751, Loss: 0.029689496383070946


3371it [08:20,  6.78it/s]

Epoch: 4, Batch: 3369/3751, Loss: 0.09692831337451935


3611it [08:57,  5.97it/s]

Epoch: 4, Batch: 3609/3751, Loss: 0.024262862280011177


3751it [09:25,  6.64it/s]
11it [00:01,  6.87it/s]

Epoch: 5, Batch: 9/3751, Loss: 0.09621081501245499


251it [00:37,  6.91it/s]

Epoch: 5, Batch: 249/3751, Loss: 0.016371937468647957


491it [01:12,  6.70it/s]

Epoch: 5, Batch: 489/3751, Loss: 0.06699804961681366


731it [01:47,  6.92it/s]

Epoch: 5, Batch: 729/3751, Loss: 0.10688319802284241


971it [02:23,  6.82it/s]

Epoch: 5, Batch: 969/3751, Loss: 0.0270080529153347


1211it [02:58,  6.98it/s]

Epoch: 5, Batch: 1209/3751, Loss: 0.016049183905124664


1451it [03:33,  6.92it/s]

Epoch: 5, Batch: 1449/3751, Loss: 0.1178140789270401


1691it [04:09,  6.85it/s]

Epoch: 5, Batch: 1689/3751, Loss: 0.06646321713924408


1931it [04:45,  6.79it/s]

Epoch: 5, Batch: 1929/3751, Loss: 0.041152823716402054


2171it [05:20,  6.32it/s]

Epoch: 5, Batch: 2169/3751, Loss: 0.05096874758601189


2411it [05:56,  6.80it/s]

Epoch: 5, Batch: 2409/3751, Loss: 0.15410064160823822


2651it [06:31,  6.82it/s]

Epoch: 5, Batch: 2649/3751, Loss: 0.10466513782739639


2891it [07:07,  6.85it/s]

Epoch: 5, Batch: 2889/3751, Loss: 0.03456886485219002


3131it [07:43,  6.83it/s]

Epoch: 5, Batch: 3129/3751, Loss: 0.11810452491044998


3371it [08:19,  6.72it/s]

Epoch: 5, Batch: 3369/3751, Loss: 0.07819945365190506


3611it [08:56,  6.02it/s]

Epoch: 5, Batch: 3609/3751, Loss: 0.021809043362736702


3751it [09:24,  6.64it/s]
11it [00:01,  6.93it/s]

Epoch: 6, Batch: 9/3751, Loss: 0.06406961381435394


251it [00:37,  6.60it/s]

Epoch: 6, Batch: 249/3751, Loss: 0.009546487592160702


491it [01:12,  6.72it/s]

Epoch: 6, Batch: 489/3751, Loss: 0.011705342680215836


731it [01:47,  6.73it/s]

Epoch: 6, Batch: 729/3751, Loss: 0.027988741174340248


971it [02:23,  6.89it/s]

Epoch: 6, Batch: 969/3751, Loss: 0.010717462748289108


1211it [02:58,  6.98it/s]

Epoch: 6, Batch: 1209/3751, Loss: 0.012508644722402096


1451it [03:34,  6.85it/s]

Epoch: 6, Batch: 1449/3751, Loss: 0.025888115167617798


1691it [04:09,  6.87it/s]

Epoch: 6, Batch: 1689/3751, Loss: 0.0593896247446537


1931it [04:44,  6.78it/s]

Epoch: 6, Batch: 1929/3751, Loss: 0.1704317182302475


2171it [05:19,  6.93it/s]

Epoch: 6, Batch: 2169/3751, Loss: 0.01649695821106434


2411it [05:55,  6.76it/s]

Epoch: 6, Batch: 2409/3751, Loss: 0.1172742173075676


2651it [06:30,  6.84it/s]

Epoch: 6, Batch: 2649/3751, Loss: 0.03495532646775246


2697it [06:37,  6.78it/s]


KeyboardInterrupt: 

# model inference

In [10]:
def restore_punctuation(original_text, good_text):
    res = []
    punctuation = '''!@#$%^&*(){}[]|._`/?:;"'\,~'''
    for orig_word, word in zip(original_text.split(), good_text.split()):
        if orig_word[-1] in punctuation:
            word = word + orig_word[-1]
        if orig_word[0] in punctuation:
            word = orig_word[0] + word

        if (orig_word == '-' or orig_word == '—') and word == 'я':
            word = orig_word

        if orig_word.istitle():
            word = word.capitalize()

        res.append(word)

    restored_text = ' '.join(res)
    for punct in punctuation:
        restored_text = restored_text.replace(' ' + punct, punct)

    return restored_text

In [42]:
prefix = 'Spell correct: '
input_text = 'Молако! Надя решила прикинутся мервтой'
input_ids = tokenizer.encode(prefix+input_text, return_tensors='pt')

input_ids = input_ids.to(device)

with torch.no_grad():
    outputs = model.generate(input_ids)

output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

In [43]:
from IPython.display import display, HTML

# для подкрашивания измененных слов в получившемся тексте
def highlight_corrections(original_text, corrected_text):
    original_words = original_text.split()
    corrected_words = corrected_text.split()

    def create_highlighted_html(words, corrections):
        highlighted_html = ""
        for word, is_corrected in zip(words, corrections):
            if is_corrected:
                highlighted_html += f"<mark>{word}</mark> "
            else:
                highlighted_html += word + " "
        return highlighted_html

    corrections = [ow != cw for ow, cw in zip(original_words, corrected_words)]
    highlighted_original = create_highlighted_html(original_words, corrections)
    highlighted_corrected = create_highlighted_html(corrected_words, corrections)

    display(HTML(f"Оригинальный текст: {highlighted_original}<br><br>Исправленный текст: {highlighted_corrected}"))

In [44]:
highlight_corrections(input_text, restore_punctuation(input_text, output_text))