In [None]:
# imports

!python -m spacy download ru_core_news_sm en_core_web_sm
import warnings
warnings.filterwarnings('ignore')
import os
import nltk
nltk.download('wordnet')
import string
import torch
import numpy as np
from tqdm.auto import tqdm
import random
import pymorphy3
from datasets import (Dataset, 
                      DatasetDict, 
                      load_dataset, 
                      load_from_disk, 
                      concatenate_datasets)

import random
import spacy
from nltk.corpus import wordnet

# Load the small Russian and English spaCy pipelines once at import time.
# `degrade_text` picks one of them according to its `lang` argument.
nlp_ru = spacy.load("ru_core_news_sm")
nlp_en = spacy.load("en_core_web_sm")

In [None]:
import sys
import os

# Register the sibling fastText data directory on sys.path (only once).
data_path = os.path.abspath(os.path.join('..', 'fasttext_data'))
if sys.path.count(data_path) == 0:
    sys.path.append(data_path)

# Only the first 10000 rows of the embedding table / word list are kept.
path = '../fasttext_data'
embeds_file = os.path.join(path, 'ru_embeds.pt')
words_file = os.path.join(path, 'ru_words.npy')
# NOTE(review): depending on the torch version, torch.load may unpickle
# arbitrary objects -- only load files from a trusted source.
embeds = torch.load(embeds_file)[:10000]
words = np.load(words_file)[:10000]

# Reverse lookup: word -> its position in `words`.
w2idx = dict(zip(words, range(len(words))))

def get_synonym_ru(word, embeds=embeds, words=words, w2idx=w2idx):
    """Pick a random near neighbour of `word` in embedding space.

    The word is lower-cased and looked up in `w2idx`. Its embedding is
    compared against every row of `embeds` with cosine similarity, and one of
    the (up to) nine most similar *other* words is returned at random.

    Args:
        word: Word to find a pseudo-synonym for.
        embeds: Embedding table; assumed to be a 2-D tensor whose row `i`
            belongs to `words[i]` -- TODO confirm against the data files.
        words: Indexable sequence of vocabulary words.
        w2idx: Mapping from word to its row index in `embeds` / `words`.

    Returns:
        A neighbouring word as a plain `str`, or the placeholder '<NULL>'
        when the word is out of vocabulary or has no other word to pick.
    """
    word = word.lower()
    if word not in w2idx:
        return '<NULL>'

    word_idx = int(w2idx[word])
    similarity = torch.cosine_similarity(embeds, embeds[word_idx].unsqueeze(0), dim=1)

    # Clamp k so tiny vocabularies do not make topk raise.
    k = min(10, similarity.shape[0])
    neighbour_idx = torch.topk(similarity, k).indices.tolist()

    # Exclude the query word explicitly instead of assuming it is ranked
    # first: duplicate or degenerate vectors can tie with it.
    candidates = [idx for idx in neighbour_idx if idx != word_idx][:9]
    if not candidates:
        return '<NULL>'

    return str(words[np.random.choice(candidates)])

# Precompute one randomly chosen neighbour per vocabulary word so that
# lookups during text degradation are plain dictionary accesses.
synonyms_ru = {vocab_word: get_synonym_ru(vocab_word) for vocab_word in tqdm(words)}


def get_synonyms(word):
    """Get synonyms for an English word using WordNet.

    Collects the lemma names of every synset of `word`, turns underscores
    into spaces and drops the word itself.

    Args:
        word: The word to look up.

    Returns:
        A sorted list of distinct synonyms (empty when WordNet has none).
        The list is sorted so that `random.choice` over it is reproducible
        under a fixed seed; iterating a set of strings is not, because
        string hashing is randomised per process.
    """
    synonyms = {
        lemma.name().replace('_', ' ')
        for synset in wordnet.synsets(word)
        for lemma in synset.lemmas()
    }
    synonyms.discard(word)
    return sorted(synonyms)

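Each text in the original dataset is processed with the `degrade_text` function, which
1. deletes punctuation and drops determiners, adpositions and pronouns
2. lemmatizes verbs (and English plural nouns)
3. lowercases the remaining words
4. randomly shuffles words inside consecutive chunks of 3 words (each chunk is shuffled with probability 0.3)
5. replaces random words with their synonyms (probability 0.2 per word) and occasionally with a `<NULL>` token (probability 0.02)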

In [None]:
def degrade_text(text, lang='ru', chunk_size=3, shuffle_prob=0.3,
                 synonym_prob=0.2, null_prob=0.02):
    """Produce a noisy, "broken" version of `text`.

    Steps:
      1. Newlines are replaced by spaces and the text is parsed with the
         spaCy pipeline for `lang` (`nlp_ru` for 'ru', `nlp_en` otherwise).
      2. Punctuation, whitespace tokens, determiners, adpositions and
         pronouns are dropped; verbs and nouns tagged "NNS" are replaced by
         their lemma; every other token is lower-cased.
         NOTE(review): "NNS" looks like a Penn Treebank tag, so presumably
         only the English pipeline emits it -- verify.
      3. Tokens are split into consecutive chunks of `chunk_size`; each chunk
         with more than one token is shuffled with probability `shuffle_prob`.
      4. Each token is replaced by a synonym with probability `synonym_prob`
         (`get_synonyms` for 'en', the `synonyms_ru` table for 'ru') and,
         independently, by '<NULL>' with probability `null_prob`.

    Args:
        text: Input text.
        lang: 'ru' or 'en'; other values use the English pipeline and get no
            synonym replacement.
        chunk_size: Size of the shuffling chunks (must be >= 1).
        shuffle_prob: Probability of shuffling a chunk.
        synonym_prob: Per-token probability of synonym replacement.
        null_prob: Per-token probability of replacement by '<NULL>'.

    Returns:
        The degraded tokens joined by single spaces.

    Raises:
        ValueError: If `chunk_size` is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be >= 1")

    # str.replace returns a new string; the result has to be kept. A space is
    # used (not '') so words on adjacent lines are not glued together.
    text = text.replace("\n", " ")
    nlp = nlp_ru if lang == 'ru' else nlp_en

    skipped_pos = {"DET", "ADP", "PRON"}
    degraded_tokens = []
    for token in nlp(text):
        if token.is_punct or token.is_space:
            continue
        if token.pos_ in skipped_pos:
            continue
        if token.pos_ == "VERB" or (token.pos_ == "NOUN" and token.tag_ == "NNS"):
            degraded_tokens.append(token.lemma_)
        else:
            degraded_tokens.append(token.text.lower())

    # Local word-order noise: shuffle some of the fixed-size chunks in place.
    chunks = [degraded_tokens[i:i + chunk_size]
              for i in range(0, len(degraded_tokens), chunk_size)]
    for chunk in chunks:
        if len(chunk) > 1 and random.random() < shuffle_prob:
            random.shuffle(chunk)
    degraded_tokens = [item for chunk in chunks for item in chunk]

    for i, current in enumerate(degraded_tokens):
        # One draw decides on synonym replacement for either language.
        if random.random() < synonym_prob:
            if lang == 'en':
                synonyms = get_synonyms(current)
                if synonyms:
                    degraded_tokens[i] = random.choice(synonyms)
            elif lang == 'ru':
                # Test and index with the same (lower-cased) key so that a
                # capitalised lemma can never raise KeyError.
                key = current.lower()
                if key in synonyms_ru:
                    degraded_tokens[i] = synonyms_ru[key]

        if random.random() < null_prob:
            degraded_tokens[i] = '<NULL>'

    return " ".join(degraded_tokens)

# Smoke test: degrade one Russian and one English sentence and show both versions.
for original_text, demo_lang in [
    ('Меня зовут по имени каждый раз, когда ко мне обращаются', 'ru'),
    ('I am going to school right now', 'en'),
]:
    degraded_text = degrade_text(original_text, lang=demo_lang)
    print("Original Text: ", original_text)
    print("Degraded Text: ", degraded_text)

# Original Text:  Меня зовут по имени каждый раз, когда ко мне обращаются
# Degraded Text:  звать имени разом когда обращаться
# Original Text:  I am going to school right now
# Degraded Text:  am crack school right now

Finally, this function is applied to filtered subsets of the English and Russian Wikipedia dumps, which are then merged into a single shuffled dataset:

In [None]:
# Load the English (dump date 20220301) and Russian (dump date 20240520)
# Wikipedia datasets. The results are indexed with ['train'] further down,
# so each is presumably a DatasetDict with a 'train' split -- verify.
dataset_en = load_dataset('wikipedia', language='en', date='20220301')
dataset_ru = load_dataset('wikipedia', language="ru", date="20240520")

def length(example, subset_ratio):
    """Filter predicate: keep mid-length texts, randomly subsampled.

    Args:
        example: Mapping with a 'text' entry holding the article text.
        subset_ratio: Probability of keeping an example that passes the
            length check.

    Returns:
        True when the whitespace-split text has more than 5 and fewer than
        300 words and a fresh random draw falls below `subset_ratio`;
        False otherwise. No random number is drawn for texts that fail the
        length check.
    """
    n_words = len(example['text'].split())
    if not 5 < n_words < 300:
        return False
    return random.random() < subset_ratio

# Keep roughly 10% of the length-eligible English texts and 30% of the Russian ones.
filtered_dataset_en = dataset_en.filter(
    lambda example: length(example, subset_ratio=0.1)
)
filtered_dataset_ru = dataset_ru.filter(
    lambda example: length(example, subset_ratio=0.3)
)


def preprocess_dataset(ex, lang):
    """Add a degraded copy of every text in a batch.

    Args:
        ex: Batch mapping whose 'text' entry is an iterable of strings.
        lang: Language code forwarded to `degrade_text`.

    Returns:
        The same `ex` object, with a new 'bad_text' entry holding the
        degraded version of each text, in the same order.
    """
    degraded_batch = []
    for source_text in ex['text']:
        degraded_batch.append(degrade_text(source_text, lang=lang))
    ex['bad_text'] = degraded_batch
    return ex


# Degrade the texts batch-wise; each example gains a `bad_text` field.
final_en = filtered_dataset_en.map(
    lambda batch: preprocess_dataset(batch, lang='en'),
    batched=True,
)
final_ru = filtered_dataset_ru.map(
    lambda batch: preprocess_dataset(batch, lang='ru'),
    batched=True,
)

# Tag every row of the train splits with its language before merging.
train_ru = final_ru['train']
train_en = final_en['train']
dataset_ru_new = train_ru.add_column('lang', ['ru'] * len(train_ru))
dataset_en_new = train_en.add_column('lang', ['en'] * len(train_en))

# Russian rows first, then English; a fixed seed keeps the shuffle repeatable.
concatenated_dataset = concatenate_datasets([dataset_ru_new, dataset_en_new])
final_dataset = concatenated_dataset.shuffle(seed=42)
#final_dataset.push_to_hub("gudleifrr/text-correction-en-ru")