In [4]:
import re
from nltk import sent_tokenize
from string import punctuation
from collections import Counter
punctuation += "«»—…“”"
punct = set(punctuation)

In [5]:
bad = open('sents_with_mistakes.txt', encoding='utf8').read().splitlines()
true = open('correct_sents.txt', encoding='utf8').read().splitlines()

def align_words(sent_1, sent_2):
    tokens_1 = sent_1.lower().split()
    tokens_2 = sent_2.lower().split()
    
    tokens_1 = [re.sub('(^\W+|\W+$)', '', token) for token in tokens_1 if (set(token)-punct)]
    tokens_2 = [re.sub('(^\W+|\W+$)', '', token) for token in tokens_2 if (set(token)-punct)]
    
    return list(zip(tokens_1, tokens_2))

mistakes = []
total = 0
for i in range(len(true)):
    word_pairs = align_words(true[i], bad[i])
    for pair in word_pairs:
        if pair[0] != pair[1]:
            mistakes.append(pair)
        total += 1

In [6]:
# import gzip, csv
# corpus = open('corpus_20000.txt', 'w')
# with gzip.open('lenta-ru-news.csv.gz', 'rt') as archive:
#     reader = csv.reader(archive, delimiter=',', quotechar='"')
#     for i, line in enumerate(reader):
#         if i < 20000: # увеличьте количество текстов тут
#             corpus.write(line[2].replace('\xa0', ' ') + '\n')

In [7]:
def normalize(text):
    
    normalized_text = [(word.strip(punctuation)) for word \
                                                            in text.lower().split()]
    normalized_text = [word for word in normalized_text if word]
    return normalized_text

corpus = []
for text in open('corpus_20000.txt').read().splitlines():
    sents = sent_tokenize(text)
    norm_sents = [normalize(sent) for sent in sents]
    corpus += norm_sents
    
vocab = set()

for sent in corpus:
    vocab.update(sent)
    
def predict_mistaken(word, vocab):
    if word in vocab:
        return 0
    else:
        return 1

In [8]:
WORDS = Counter()
for sent in corpus:
    WORDS.update(sent)
    
N = sum(WORDS.values())
def P(word, N=N):
    return WORDS[word] / N

In [15]:
class Deletions:
    def __init__(self):
        self.deletions = {}
        self.known = []

    def add_deletion(self, cut, origin):
        origins = self.deletions.get(cut, [])
        if origin not in origins:
            origins.append(origin)
        self.deletions[cut] = origins

    def get_origin(self, cut):
        origins = self.deletions.get(cut, [])
        return origins

    def cut(self, origins):
        cuts = []
        for word in origins:
            for i in range(len(word)):
                cut = word[:i] + word[i+1:]
                cuts.append(cut)
        return cuts

    def create_deletions(self, origin, n = 1):
        if origin not in self.known:
            self.known.append(origin)
        cuts = [origin]
        while n > 0:
            n -= 1
            cuts = self.cut(cuts)
            for cut in cuts:
                self.add_deletion(cut, origin)

    def find_origin(self, word, n = 1):
        if word in self.known:
            return [word]
        all_cuts = [word]
        cuts = [word]
        while n > 0:
            n -= 1
            cuts = self.cut(cuts)
            for cut in cuts:
                if cut not in all_cuts:
                    all_cuts.append(cut)
        all_origins = []
        for cut in all_cuts:
            origins = self.get_origin(cut)
            for origin in origins:
                if origin not in all_origins:
                    all_origins.append(origin)
        if len(all_origins) == 0:
            return [word]
        return all_origins
    

In [16]:
d = Deletions()
for word in WORDS:
    d.create_deletions(word.lower())

In [17]:
def correction(word):
    origins = d.find_origin(word)
    return max(origins, key=P)

In [175]:
correct = 0
total = 0

total_mistaken = 0
mistaken_fixed = 0

total_correct = 0
correct_broken = 0

cashed = {}
for i in range(len(true)):
    word_pairs = align_words(true[i], bad[i])
    for pair in word_pairs:
        pair = (pair[0].lower(), pair[1].lower())
        predicted = cashed.get(pair[1], correction(pair[1]))
        cashed[pair[0]] = predicted
        if predicted == pair[0]:
            correct += 1
        total += 1
        
        if pair[0] == pair[1]:
            total_correct += 1
            if pair[0] !=  predicted:
                correct_broken += 1
        else:
            total_mistaken += 1
            if pair[0] == predicted:
                mistaken_fixed += 1
        
    if not i % 100:
        print(i)
        
print(correct/total)
print(mistaken_fixed/total_mistaken)
print(correct_broken/total_correct)

0
100
200
300
400
500
600
700
800
900
0.7273726273726274
0.3622409823484267
0.21798552888480532


Как мы видим, качество оставляет желать лучшего - по всем 3 критериям проигрываем алгоритму Норвига, однако он гораздо быстрее

In [12]:
tokens = []
for text in corpus:
    tokens.extend(text)

In [13]:
def ngrammer(tokens, n):
    ngrams = []
    for i in range(0,len(tokens)-n+1):
        ngrams.append(' '.join(tokens[i:i+n]))
    return ngrams

unigrams = Counter()
bigrams = Counter()
trigrams = Counter()
for t in tokens:
    unigrams.update(t)
    bigrams.update(ngrammer(t, 2))
    trigrams.update(ngrammer(t, 3))

In [27]:
correct = 0
total = 0

total_mistaken = 0
mistaken_fixed = 0

total_correct = 0
correct_broken = 0

for i in range(len(true)):
    word_pairs = align_words(true[i], bad[i])
    word_pairs = [('<start>', '<start>'),('<start>', '<start>')] + word_pairs
    
    for j in range(2, len(word_pairs)):
        pair = (word_pairs[j][0].lower(), word_pairs[j][1].lower())
        prev_bi = word_pairs[j-2][1] + " " + word_pairs[j-1][1]
        if prev_bi not in bigrams:
            pred = correction(pair[1])
        else:
            predicted = d.find_origin(pair[1])
            max_prob = -1
            best_word = ""
            for word in predicted:
                trigram = prev_bi + " " + word
                prob = trigrams[trigram]/bigrams[prev_bi]
                if prob > max_prob:
                    max_prob = prob
                    best_word = word
            pred = best_word
            
        if pred == pair[0]:
            correct += 1
        total += 1
        if pair[0] == pair[1]:
            total_correct += 1
            if pair[0] !=  pred:
                correct_broken += 1
        else:
            total_mistaken += 1
            if pair[0] == pred:
                mistaken_fixed += 1
        
    if not i % 100:
        print(i)
        
print(correct/total)
print(mistaken_fixed/total_mistaken)
print(correct_broken/total_correct)

0
100
200
300
400
500
600
700
800
900
0.8758241758241758
0.3622409823484267
0.04731824968416217


Модель показала себя достаточно хорошо - она почти не изменяет правильные слова, однако исправляет все так же мало ошибок, что и раньше. Однако итоговая точность стала выше