In [1]:
import os
import re
import csv
import json
import gzip
import numpy as np
from pprint import pprint
from string import punctuation
from nltk import sent_tokenize
from collections import Counter, defaultdict
from typing import List, Tuple, Dict, Union, Generator

punctuation += '«»–—…“”'
punct = set(punctuation)

In [2]:
bad = open('sents_with_mistakes.txt', encoding='utf8').read().splitlines()
true = open('correct_sents.txt', encoding='utf8').read().splitlines()

In [3]:
def align_words(sent_1: str, sent_2: str) -> List[Tuple[str]]:
    tokens_1 = sent_1.lower().split()
    tokens_2 = sent_2.lower().split()
    
    tokens_1 = [re.sub('(^\W+|\W+$)', '', token) for token in tokens_1 if (set(token)-punct)]
    tokens_2 = [re.sub('(^\W+|\W+$)', '', token) for token in tokens_2 if (set(token)-punct)]
    
    return list(zip(tokens_1, tokens_2))

In [4]:
def normalize(text: str) -> List[str]:
    normalized_text = [(word.strip(punctuation)) for word in text.lower().split()]
    normalized_text = [word for word in normalized_text if word]
    return normalized_text

In [5]:
# Let's take a bigger corpus.

# with open('corpus_10000.txt', 'w') as corpus, gzip.open('lenta-ru-news.csv.gz', 'rt') as archive:
#     reader = csv.reader(archive, delimiter=',', quotechar='"')
#     for i, line in enumerate(reader):
#         if i < 10000: # увеличьте количество текстов тут
#             corpus.write(line[2].replace('\xa0', ' ') + '\n')

In [6]:
corpus = []
for text in open('corpus_10000.txt').read().splitlines():
    sentences = sent_tokenize(text)
    norm_sentences = [normalize(sentence) for sentence in sentences]
    corpus += norm_sentences

In [7]:
WORDS = Counter()
for sentence in corpus:
    WORDS.update(sentence)

In [8]:
N = sum(WORDS.values())
def P(word): 
    return WORDS[word] / N

In [9]:
misspells = defaultdict(set)

for word in WORDS:
    for i in range(len(word)):
        misspells[word[:i] + word[i+1:]].add(word)

In [10]:
list(misspells.items())[10:15]

[('вице-пемьер', {'вице-премьер'}),
 ('вице-прмьер', {'вице-премьер'}),
 ('вице-преьер', {'вице-премьер'}),
 ('вице-премер', {'вице-премьер'}),
 ('вице-премьр', {'вице-премьер'})]

In [11]:
def get_correction(word: str) -> Union[str, None]:
    best_option = (word, 0)
    for i in range(len(word)):
        candidate = misspells.get(word[:i] + word[i+1:])
        if candidate is not None:
            candidate, p = sorted(zip(candidate, (P(x) for x in candidate)), key=lambda x: x[1], reverse=True)[0]
            if p > best_option[1]:
                best_option = (candidate, p)
    if best_option[1] > 0:
        return best_option[0]
    return None

In [12]:
%%time

get_correction('сомнце')

CPU times: user 47 µs, sys: 0 ns, total: 47 µs
Wall time: 53.2 µs


'солнце'

# No language model (unigram probabilities)

In [13]:
correct = 0
total = 0

total_correct = 0
correct_broken = 0

total_mistaken = 0
mistaken_fixed = 0

for i in range(len(true)):
    word_pairs = align_words(true[i], bad[i])
    for (true_word, bad_word) in word_pairs:
        if bad_word in WORDS:
            pred = bad_word
        else:
            pred = get_correction(bad_word)
            if pred is None:
                pred = bad_word
        if pred == true_word:
            correct += 1
        total += 1
        
        if true_word == bad_word:
            total_correct += 1
            if true_word != pred:
                correct_broken += 1
        else:
            total_mistaken += 1
            if true_word == pred:
                mistaken_fixed += 1

In [14]:
print(np.round(correct/total, 3))
print(np.round(mistaken_fixed/total_mistaken, 3))
print(np.round(correct_broken/total_correct, 3))

0.85
0.21
0.054


# Using a language model to range the results

In [15]:
class NgramLanguageModel:
    def __init__(self, path: str, n: int = 3) -> None:
        if n < 2:
            raise Exception('Parameter N of an n-gram language model cannot be less than 2!')
        self.n = n
        with open(path, encoding='utf-8') as f:
            self.text = f.read()
        self.sentences = map(lambda sentence: ['<START>'] * (self.n - 1) + \
                             [x for x in self.tokenize(sentence) if x] + ['<END>'],
                             sent_tokenize(self.text))
        self.token2id = {}
        if self.n > 2:
            self.ngram2id = {}
        self.ngrams = [Counter(), Counter()]
        self.get_ngram_counts()
        self.vocab_size = len(self.token2id)
        self.id2token = {value: key for key, value in self.token2id.items()}
        self.matrix = np.zeros((len(self.ngrams[0]), self.vocab_size))
        if self.n > 2:
            self.id2ngram = {value: key for key, value in self.ngram2id.items()}
        self.populate_matrix()
    
    def get_ngram_counts(self) -> None:
        if self.n == 2:
            for sentence in self.sentences:
                self.ngrams[0].update(sentence)
                self.ngrams[1].update(list(self.ngrammer(sentence, 2)))
                self.update_token2id(sentence)
        if self.n > 2:
            for sentence in self.sentences:
                ngrams = list(self.ngrammer(sentence, self.n-1))
                self.ngrams[0].update(ngrams)
                self.ngrams[1].update(list(self.ngrammer(sentence, self.n)))
                self.update_token2id(sentence)
                self.update_ngram2id(ngrams)
    
    def update_token2id(self, tokens: List[str]) -> None:
        for token in tokens:
            if token not in self.token2id:
                self.token2id[token] = len(self.token2id)
    
    def update_ngram2id(self, ngrams: List[str]) -> None:
        for ngram in ngrams:
            if ngram not in self.ngram2id:
                self.ngram2id[ngram] = len(self.ngram2id)
    
    def populate_matrix(self):
        if self.n == 2:
            for ngram, count in self.ngrams[1].items():
                src, dest = ngram.split()
                self.matrix[self.token2id[src], self.token2id[dest]] = \
                    count / self.ngrams[0][src]
        if self.n > 2:
            for ngram, count in self.ngrams[1].items():
                ngram_splitted = ngram.split()
                src = ' '.join(ngram_splitted[:-1])
                dest = ngram_splitted[-1]
                self.matrix[self.ngram2id[src], self.token2id[dest]] = \
                    count / self.ngrams[0][src]
    
    def get_trigram_probability(self, word: str, prev_bigram: str) -> float:
        if prev_bigram in self.ngram2id and word in self.token2id:
            return self.matrix[self.ngram2id[prev_bigram], self.token2id[word]]
        return P(word)
            
    @staticmethod
    def tokenize(sentence: str) -> map:
        return map(lambda x: x.strip(punctuation).replace('ё', 'е'),
                   sentence.strip().lower().split())
    
    @staticmethod
    def ngrammer(sentence: List[str], n: int) -> Generator[str, None, None]:
        for i in range(len(sentence)-n+1):
            yield ' '.join(sentence[i: i+n])

In [16]:
trigram_lm = NgramLanguageModel('corpus_10000.txt')

In [17]:
def get_lm_correction(word: str, bigram: str) -> Union[str, None]:
    best_option = (word, 0)
    for i in range(len(word)):
        candidate = misspells.get(word[:i] + word[i+1:])
        if candidate is not None:
            candidate, p = sorted(zip(candidate, (trigram_lm.get_trigram_probability(x, bigram) for x in candidate)),
                                  key=lambda x: x[1], reverse=True)[0]
            if p > best_option[1]:
                best_option = (candidate, p)
    if best_option[1] > 0:
        return best_option[0]
    return None

In [18]:
%%time

get_lm_correction('сомнце', '<START> взошло')

CPU times: user 664 µs, sys: 8.33 ms, total: 8.99 ms
Wall time: 58.7 ms


'солнце'

In [19]:
correct = 0
total = 0

total_correct = 0
correct_broken = 0

total_mistaken = 0
mistaken_fixed = 0

for i in range(len(true)):
    word_pairs = [('<START>, <START>')] * 2 + align_words(true[i], bad[i])
    for i, (true_word, bad_word) in enumerate(word_pairs[2:]):
        if bad_word in WORDS:
            pred = bad_word
        else:
            pred = get_lm_correction(bad_word, word_pairs[i][1] + ' ' + word_pairs[i+1][1])
            if pred is None:
                pred = bad_word
        if pred == true_word:
            correct += 1
        total += 1
        
        if true_word == bad_word:
            total_correct += 1
            if true_word != pred:
                correct_broken += 1
        else:
            total_mistaken += 1
            if true_word == pred:
                mistaken_fixed += 1

In [20]:
print(np.round(correct/total, 3))
print(np.round(mistaken_fixed/total_mistaken, 3))
print(np.round(correct_broken/total_correct, 3))

0.858
0.181
0.041


It might look like using a language model to choose the most probable result led to an improvement. But:

In [21]:
correct = 0
total = 0

total_correct = 0
correct_broken = 0

total_mistaken = 0
mistaken_fixed = 0

for i in range(len(true)):
    word_pairs = align_words(true[i], bad[i])
    for (true_word, bad_word) in word_pairs:
        pred = bad_word
        if pred == true_word:
            correct += 1
        total += 1
        
        if true_word == bad_word:
            total_correct += 1
            if true_word != pred:
                correct_broken += 1
        else:
            total_mistaken += 1
            if true_word == pred:
                mistaken_fixed += 1

In [22]:
print(np.round(correct/total, 3))
print(np.round(mistaken_fixed/total_mistaken, 3))
print(np.round(correct_broken/total_correct, 3))

0.87
0.0
0.0


If we let the misspells be, the first metric (the main one) will be higher. That's why SymSpell is a really dubious algorithm. Let's give it its last chance — edit distance of 2 (we will be deleting two random characters in a word).

P. S. The following cells are executed really slowly.

In [23]:
misspells = defaultdict(set)

for word in WORDS:
    for i in range(len(word)):
        subword = word[:i] + word[i+1:]
        for j in range(len(subword)):
            misspells[subword[:j] + subword[j+1:]].add(word)

In [24]:
misspells['сонц']

{'солнца', 'солнце', 'солнцу'}

In [25]:
def get_lm_correction_ld2(word: str, bigram: str) -> Union[str, None]:
    best_option = (word, 0)
    for i in range(len(word)):
        subword = word[:i] + word[i+1:]
        for j in range(len(subword)):
            candidate = misspells.get(subword[:j] + subword[j+1:])
            if candidate is not None:
                candidate, p = sorted(zip(candidate, (trigram_lm.get_trigram_probability(x, bigram) for x in candidate)),
                                      key=lambda x: x[1], reverse=True)[0]
                if p > best_option[1]:
                    best_option = (candidate, p)
    if best_option[1] > 0:
        return best_option[0]
    return None

In [26]:
correct = 0
total = 0

total_correct = 0
correct_broken = 0

total_mistaken = 0
mistaken_fixed = 0

for i in range(len(true)):
    word_pairs = [('<START>, <START>')] * 2 + align_words(true[i], bad[i])
    for i, (true_word, bad_word) in enumerate(word_pairs[2:]):
        if bad_word in WORDS:
            pred = bad_word
        else:
            pred = get_lm_correction_ld2(bad_word, word_pairs[i][1] + ' ' + word_pairs[i+1][1])
            if pred is None:
                pred = bad_word
        if pred == true_word:
            correct += 1
        total += 1
        
        if true_word == bad_word:
            total_correct += 1
            if true_word != pred:
                correct_broken += 1
        else:
            total_mistaken += 1
            if true_word == pred:
                mistaken_fixed += 1

In [27]:
print(np.round(correct/total, 3))
print(np.round(mistaken_fixed/total_mistaken, 3))
print(np.round(correct_broken/total_correct, 3))

0.82
0.094
0.071


It only became worse, so it's safe to say that SymSpell doesn't work too well, at least for our synthetically generated data.