In [2]:
import codecs
import collections
import Levenshtein
import re

In [3]:
russian_word_pattern = re.compile('^[а-яА-Я]+$')

def is_russian(word):
    return russian_word_pattern.match(word)

In [4]:
operation_probability = collections.defaultdict(int)
action_probability = collections.defaultdict(int)
dependent_probability = collections.defaultdict(int)

def count_word_mutation(word, expected_word):
    edit_ops = Levenshtein.editops(word, expected_word)
    edit_indexes = set(map(lambda op: op[1], edit_ops))
    mutation = []
    cur_edit_pos = 0
    for i in range(len(word)):
        if i not in edit_indexes:
            mutation.append(('replace', word[i], word[i]))
        else:
            cur_edit = edit_ops[cur_edit_pos]
            if cur_edit[0] == 'replace':
                mutation.append(('replace', word[cur_edit[1]], expected_word[cur_edit[2]]))
            elif cur_edit[0] == 'delete':
                mutation.append(('delete', word[cur_edit[1]]))
            else:
                mutation.append(('insert', expected_word[cur_edit[2]]))
    return mutation

def count_probabilities(word_mutation):
    prev_op = None
    for mut in word_mutation:
        operation_probability[mut[0]] += 1
        action_probability[mut] += 1
        dependent_probability[(prev_op, mut)] += 1
        prev_op = mut
        

def build_error_model():
    with codecs.open("train.csv", 'r', 'utf8') as train:
        train.readline()
        i = 0
        for line in train:
            word, expected_word = line.split(',')
            word, expected_word = word.strip(), expected_word.strip()
            if (is_russian(word)):
                word_mutation = count_word_mutation(word, expected_word)
                count_probabilities(word_mutation)
                

build_error_model()

In [7]:
class TrieNode(object):
    def __init__(self, char):
        self.char = char
        self.children = []
        self.terminal = False
        self.word = None
        self.freq = None
    
    def set_frequency(self, freq):
        self.freq = freq

    def set_word(self, word):
        self.word = word
    

def add(root, word, freq):
    node = root
    for char in word:
        found_in_child = False
        for child in node.children:
            if child.char == char:
                node = child
                found_in_child = True
                break
        if not found_in_child:
            new_node = TrieNode(char)
            node.children.append(new_node)
            node = new_node
    node.terminal = True
    node.set_frequency(freq)
    node.set_word(word)

In [8]:
words_freq = collections.defaultdict(int)

def build_trie():
    root = TrieNode("*")
    with codecs.open('words.csv', 'r', 'utf8') as words:
        words.readline()
        for line in words:
            processed_line = line.split(',')
            word, freq = processed_line[0].strip(), int(processed_line[1].strip())
            if is_russian(word):
                add(root, word, freq)
                words_freq[word] = freq
    return root
            
trie = build_trie()

In [9]:
def fix_ratio(fix, word):
    return words_freq[fix] / words_freq[word]

def best_ratio(fix1, fix2):
    return words_freq[fix1] / words_freq[fix2]

In [10]:
a = 0
b = 0.95
c = 0.05
top = 3
fix_ratio_threshold = 5.0
best_ratio_threshold = 3.0

In [11]:
from operator import itemgetter, attrgetter

In [12]:
def compute_correction_probability(correction, prev_op):
    return a * operation_probability[correction[0]] +\
           b * action_probability[correction] +\
           c * dependent_probability[(prev_op, correction)]
    
    
def fix_word(root, word, fix_variants, prev_op = None, corrections_num = 0):
    if corrections_num > 1:
        return
    if not root.terminal and not word:
        return
    if root.terminal and not word:
        fix_variants.append((root.word, root.freq))
        return
    corrections = []
    cur_char = word[0] 
    for idx in range(len(root.children)):
        child = root.children[idx]
        if (len(word) >= 4) or (cur_char == child.char):
            corrections.append(('replace', cur_char, child.char, idx))    
    corrections_probs = list(map(lambda cor: compute_correction_probability(cor[0:3], prev_op), corrections))
    corrections_statistic = sorted(list(zip(corrections, corrections_probs)), key=itemgetter(1), reverse=True)
    top_corrections = corrections_statistic[0:top]
    for correction, _ in top_corrections:
        cur_char, fixed_char, idx = correction[1:4]
        next_node = root.children[idx]
        fix_word(next_node, word[1:], fix_variants, correction[0:3],
                 corrections_num if cur_char == fixed_char else corrections_num + 1)

In [13]:
def process_word(word):
    if is_russian(word):
        fix_variants = []
        fix_word(trie, word , fix_variants)
        best_fixes = sorted(fix_variants, key = itemgetter(1), reverse = True)
        result = best_fixes[0] if fix_ratio(best_fixes[0][0], word) >= fix_ratio_threshold else word
        if len(best_fixes) > 1:
            result = best_fixes[0][0] if best_ratio(best_fixes[0][0], best_fixes[1][0]) >= best_ratio_threshold else word
        return result
    else:
        return word

In [14]:
fixed = 0

with codecs.open('no_fix.submission.csv', 'r', 'utf8') as file:
    with codecs.open('submission.csv', 'w', 'utf8') as submission:
        header = file.readline()
        submission.write(header)
        for line in file:
            processed_line = line.split(',')
            word = processed_line[0]
            if is_russian(word):
                fixed_word = process_word(word)
                if fixed_word != word:
                    fixed += 1
                submission.write(f'{word},{fixed_word}\n')
            else:
                submission.write(line)

print(fixed)

17268
