In [2]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before code is executed so edits to imported modules are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [15]:
import os
import json
import collections
import itertools
import pprint

from nsc.api import utils
from nsc.utils import io
from nsc.data import utils as data_utils

from spell_checking import BENCHMARK_DIR, DATA_DIR

In [82]:
# Load the train/dev/test misspelling dictionaries and the combined one.
# JSON is UTF-8 encoded text, so the encoding is passed explicitly instead of
# relying on the platform's default locale encoding.
with open(os.path.join(DATA_DIR, "misspellings", "train_misspellings.json"), "r", encoding="utf8") as inf:
    train_missp = json.load(inf)
with open(os.path.join(DATA_DIR, "misspellings", "dev_misspellings.json"), "r", encoding="utf8") as inf:
    dev_missp = json.load(inf)
with open(os.path.join(DATA_DIR, "misspellings", "test_misspellings.json"), "r", encoding="utf8") as inf:
    test_missp = json.load(inf)
with open(os.path.join(DATA_DIR, "misspellings", "misspellings.json"), "r", encoding="utf8") as inf:
    missp = json.load(inf)
# Displayed as the cell result: number of entries per dictionary.
len(train_missp), len(dev_missp), len(test_missp), len(missp)

(119554, 36645, 36753, 119725)

In [90]:
# Sanity check: for every correct word, the misspellings assigned to the
# train, dev and test dictionaries must not share any entry.
for correct, misspellings in train_missp.items():
    misspellings_set = set(misspellings)
    # A word that is absent from the other dictionary has nothing to overlap with.
    if not misspellings_set.isdisjoint(dev_missp.get(correct, ())):
        raise RuntimeError("found train misspellings in dev misspellings")
    if not misspellings_set.isdisjoint(test_missp.get(correct, ())):
        raise RuntimeError("found train misspellings in test misspellings")

for correct, misspellings in dev_missp.items():
    misspellings_set = set(misspellings)
    if not misspellings_set.isdisjoint(test_missp.get(correct, ())):
        raise RuntimeError("found dev misspellings in test misspellings")

In [91]:
# Load the misspelling dictionaries stored under the "misspellings_neuspell",
# "misspellings_bea" and "misspellings_no_neuspell_no_bea" directories.
# JSON is UTF-8 encoded text, so the encoding is passed explicitly instead of
# relying on the platform's default locale encoding.
with open(os.path.join(DATA_DIR, "misspellings_neuspell", "misspellings.json"), "r", encoding="utf8") as inf:
    neuspell_missp = json.load(inf)
with open(os.path.join(DATA_DIR, "misspellings_bea", "misspellings.json"), "r", encoding="utf8") as inf:
    bea_missp = json.load(inf)
with open(os.path.join(DATA_DIR, "misspellings_no_neuspell_no_bea", "misspellings.json"), "r", encoding="utf8") as inf:
    no_neuspell_no_bea_missp = json.load(inf)
# Displayed as the cell result: number of entries per dictionary.
len(neuspell_missp), len(bea_missp), len(no_neuspell_no_bea_missp)

(23332, 18168, 119647)

In [35]:
# Path pattern of the corrupted input of the neuspell bea60k test benchmark.
bea60k_corrupt_pattern = os.path.join(
    BENCHMARK_DIR, "test", "sec", "neuspell", "bea60k", "corrupt.txt"
)
benchmarks = io.glob_safe(bea60k_corrupt_pattern)
# Displayed as the cell result.
benchmarks

['/home/sebastian/msc/masters_thesis/code/spell_checking/benchmarks/test/sec/neuspell/bea60k/corrupt.txt']

In [73]:
def extract_misspellings_from_parallel_corpus(corrupt_file, correct_file):
    """Collect word-level misspellings from two line-aligned text files.

    Both files are read with ``utils.load_text_file`` and walked in lockstep.
    Every line is split on whitespace; wherever the token at a position
    differs between the two files, the corrupt token is recorded as a
    misspelling of the correct token.

    Args:
        corrupt_file: Path of the file holding the misspelled text.
        correct_file: Path of the file holding the corresponding correct text.

    Returns:
        A ``collections.defaultdict(set)`` mapping each correct word to the
        set of distinct corrupt variants seen for it.

    Raises:
        AssertionError: If the two files differ in their number of lines, or
            if a pair of lines differs in its number of whitespace-separated
            tokens.
    """
    corrupt_lines = utils.load_text_file(corrupt_file)
    correct_lines = utils.load_text_file(correct_file)

    # Sentinel marking an exhausted input: a plain zip() would silently drop
    # the surplus lines of the longer file and hide a misaligned corpus.
    missing = object()
    misspellings = collections.defaultdict(set)
    line_pairs = itertools.zip_longest(corrupt_lines, correct_lines, fillvalue=missing)
    for line_no, (corrupt_line, correct_line) in enumerate(line_pairs, start=1):
        assert corrupt_line is not missing and correct_line is not missing, (
            f"{corrupt_file} and {correct_file} differ in their number of lines "
            f"(first unmatched line: {line_no})"
        )
        # Tokenize once per line; the tokens serve both the alignment check
        # and the comparison loop.
        corrupt_words = corrupt_line.split()
        correct_words = correct_line.split()
        assert len(corrupt_words) == len(correct_words), (
            f"line {line_no}: {len(corrupt_words)} corrupt tokens vs. "
            f"{len(correct_words)} correct tokens"
        )
        for corrupt_word, correct_word in zip(corrupt_words, correct_words):
            if corrupt_word != correct_word:
                misspellings[correct_word].add(corrupt_word)
    return misspellings

In [76]:
def misspelling_overlap(m1, m2):
    """Return the percentage of the misspellings in ``m1`` that also occur in ``m2``.

    A misspelling counts as shared when ``m2`` lists it under the same correct
    word. The measure is directional: swapping the arguments generally gives a
    different result.

    Args:
        m1: Mapping from correct word to a collection of its misspellings.
        m2: Mapping of the same shape to compare against.

    Returns:
        ``100 * shared / total`` as a float, where ``total`` is the number of
        misspellings in ``m1``. Returns 0.0 when ``m1`` holds no misspellings,
        instead of dividing by zero.
    """
    overlap = 0
    total = 0
    for correct, variants in m1.items():
        total += len(variants)
        if correct in m2:
            overlap += len(set(variants).intersection(m2[correct]))
    if total == 0:
        # nothing to compare: define the overlap of an empty collection as 0 %
        return 0.0
    return 100 * overlap / total

In [97]:
from typing import Dict, List
def merge_dicts(*args: Dict) -> Dict[str, List[str]]:
    """Combine several dictionaries into a single one.

    Every key reported by ``merge_keys`` is mapped to the result of
    ``merge_values`` for that key; keys whose merged value list is empty are
    left out of the result.
    """
    # lazily pair each key with its merged values, then keep the non-empty ones
    candidates = ((key, merge_values(*args, key=key)) for key in merge_keys(*args))
    return {key: values for key, values in candidates if len(values) > 0}


def merge_keys(*args: Dict) -> List[str]:
    """Return the union of the keys of all given dictionaries.

    Args:
        *args: Any number of dictionaries, including none at all.

    Returns:
        Every distinct key exactly once, in sorted order so that repeated
        runs yield the same sequence. An empty list when no dictionary is
        given.
    """
    # set().union() takes zero or more iterables, so the first dictionary
    # needs no special treatment and calling without arguments is safe.
    return sorted(set().union(*args))


def merge_values(*args: Dict, key: str) -> List[str]:
    """Merge the values stored under ``key`` across all given dictionaries.

    Args:
        *args: Any number of dictionaries whose values are collections
            (sets, lists, ...) of strings.
        key: The key to look up in every dictionary; dictionaries that lack
            it contribute nothing.

    Returns:
        The distinct values found under ``key``, with ``key`` itself removed,
        in sorted order so that repeated runs yield the same sequence.
    """
    # Accumulate into a fresh set: the inputs are never mutated, and values
    # may be any iterable rather than only sets.
    values = set()
    for mapping in args:
        values.update(mapping.get(key, ()))
    # a word is not a misspelling of itself
    values.discard(key)
    return sorted(values)

In [104]:
# Directory holding the raw neuspell train/test files.
neuspell_traintest_dir = os.path.join(DATA_DIR, "raw", "neuspell", "traintest")

# Training misspellings: compare train.1blm with its "prob" and "word"
# noised counterparts.
neuspell_correct = os.path.join(neuspell_traintest_dir, "train.1blm")
neuspell_corrupt = os.path.join(neuspell_traintest_dir, "train.1blm.noise.prob")
misspellings_extracted_from_prob_training = extract_misspellings_from_parallel_corpus(
    neuspell_corrupt, neuspell_correct
)

neuspell_corrupt = os.path.join(neuspell_traintest_dir, "train.1blm.noise.word")
misspellings_extracted_from_word_training = extract_misspellings_from_parallel_corpus(
    neuspell_corrupt, neuspell_correct
)

# The "random" noise file is currently not part of the merge (kept disabled):
# neuspell_corrupt = os.path.join(neuspell_traintest_dir, "train.1blm.noise.random")
# misspellings_extracted_from_rand_training = extract_misspellings_from_parallel_corpus(neuspell_corrupt, neuspell_correct)

misspellings_extracted_from_training = merge_dicts(
    misspellings_extracted_from_prob_training,
    misspellings_extracted_from_word_training,
    # misspellings_extracted_from_rand_training,
)

# Test misspellings: compare test.bea60k with its noised counterpart.
neuspell_corrupt = os.path.join(neuspell_traintest_dir, "test.bea60k.noise")
neuspell_correct = os.path.join(neuspell_traintest_dir, "test.bea60k")
misspellings_extracted_from_test = extract_misspellings_from_parallel_corpus(
    neuspell_corrupt, neuspell_correct
)

In [105]:
# misspelling_overlap is directional (percentage of the first argument's
# misspellings that are found in the second), so both directions are reported.
train_in_test_overlap = misspelling_overlap(
    misspellings_extracted_from_training, misspellings_extracted_from_test
)
print(f"overlap neuspell train --> bea60 test: {train_in_test_overlap}")
test_in_train_overlap = misspelling_overlap(
    misspellings_extracted_from_test, misspellings_extracted_from_training
)
print(f"overlap bea60k test --> neuspell train: {test_in_train_overlap}")

overlap neuspell train --> bea60 test: 0.6845065102921232
overlap bea60k test --> neuspell train: 40.98745851323171


In [106]:
# Percentage of the misspellings extracted from the bea60k test set that is
# also contained in each loaded misspelling dictionary. The labels are paired
# with their dictionaries in one place so they cannot drift apart (the
# bea_missp comparison used to be printed under the "neuspell train" label).
# NOTE(review): the recorded output shows identical values for neuspell_missp
# and bea_missp - confirm after re-running that this is expected.
for label, reference_missp in (
    ("missp", missp),
    ("neuspell", neuspell_missp),
    ("bea", bea_missp),
    ("no neuspell no bea", no_neuspell_no_bea_missp),
):
    print(f"overlap bea60k test --> {label} train: {misspelling_overlap(misspellings_extracted_from_test, reference_missp)}")

overlap bea60k test --> missp train: 93.41204805122331
overlap bea60k test --> neuspell train: 93.41498516756249
overlap bea60k test --> neuspell train: 93.41498516756249
overlap bea60k test --> no neuspell no bea train: 33.85614004170705
