In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import json
import collections
import itertools
import pprint

from nsc.api import utils
from nsc.utils import io
from nsc.data import utils as data_utils

from spell_checking import BENCHMARK_DIR, DATA_DIR

In [3]:
with open(os.path.join(DATA_DIR, "misspellings", "train_misspellings.json"), "r") as inf:
    train_missp = json.load(inf)
with open(os.path.join(DATA_DIR, "misspellings", "dev_misspellings.json"), "r") as inf:
    dev_missp = json.load(inf)
with open(os.path.join(DATA_DIR, "misspellings", "test_misspellings.json"), "r") as inf:
    test_missp = json.load(inf)
with open(os.path.join(DATA_DIR, "misspellings", "misspellings.json"), "r") as inf:
    missp = json.load(inf)
len(train_missp), len(dev_missp), len(test_missp), len(missp)

(119554, 36645, 36753, 119725)

In [4]:
# make sure that all misspellings lists are disjoint
for correct, misspellings in train_missp.items():
    misspellings_set = set(misspellings)
    if misspellings_set.intersection(set(dev_missp.get(correct, set()))):
        raise RuntimeError("found train misspellings in dev misspellings")
    if misspellings_set.intersection(set(test_missp.get(correct, set()))):
        raise RuntimeError("found train misspellings in test misspellings")
                                     
for correct, misspellings in dev_missp.items():
    misspellings_set = set(misspellings)
    if misspellings_set.intersection(set(test_missp.get(correct, set()))):
        raise RuntimeError("found dev misspellings in test misspellings")

In [6]:
with open(os.path.join(DATA_DIR, "misspellings_no_neuspell_no_bea", "misspellings.json"), "r") as inf:
    no_neuspell_no_bea_missp = json.load(inf)
len(missp), len(no_neuspell_no_bea_missp), sum(len(v) for v in missp.values()), sum(len(v) for v in no_neuspell_no_bea_missp.values())

(119725, 119647, 2303867, 2240553)

In [61]:
benchmarks = io.glob_safe(os.path.join(BENCHMARK_DIR, "test", "sec", "neuspell", "bea60k", "corrupt.txt"))
benchmarks

['/home/sebastian/msc/masters_thesis/code/spell_checking/benchmarks/test/sec/neuspell/bea60k/corrupt.txt']

In [62]:
def extract_misspellings_from_parallel_corpus(corrupt_file, correct_file):
    corrupt_lines = utils.load_text_file(corrupt_file)
    correct_lines = utils.load_text_file(correct_file)
    
    misspellings = collections.defaultdict(list)
    for i, (corrupt_line, correct_line) in enumerate(zip(corrupt_lines, correct_lines)):
        affected = False
        assert len(corrupt_line.split()) == len(correct_line.split())
        for corrupt_word, correct_word in zip(corrupt_line.split(), correct_line.split()):
            if corrupt_word != correct_word:
                misspellings[correct_word].append(corrupt_word)
    return misspellings

In [63]:
def misspelling_overlap(m1, m2):
    overlap = 0
    total = 0
    for cor, m1s in m1.items():
        total += len(m1s)
        if cor in m2:
            overlap += len(set(m1s).intersection(set(m2[cor])))
    return 100 * overlap / total

In [64]:
from typing import Dict, List
def merge_dicts(*args: Dict) -> Dict[str, List[str]]:
    merged = {}
    for key in merge_keys(*args):
        values = merge_values(*args, key=key)
        if len(values) == 0:
            continue
        merged[key] = values
    return merged


def merge_keys(*args: Dict) -> List[str]:
    keys = set(args[0].keys())
    for i in range(1, len(args)):
        keys = keys.union(set(args[i].keys()))
    return list(keys)


def merge_values(*args: Dict, key: str) -> List[str]:
    values = set(args[0].get(key, set()))
    for i in range(1, len(args)):
        values = values.union(set(args[i].get(key, set())))
    values.discard(key)
    return list(values)

In [65]:
neuspell_corrupt = os.path.join(DATA_DIR, "raw", "neuspell", "traintest", "train.1blm.noise.prob")
neuspell_correct = os.path.join(DATA_DIR, "raw", "neuspell", "traintest", "train.1blm")
misspellings_extracted_from_prob_training = extract_misspellings_from_parallel_corpus(neuspell_corrupt, neuspell_correct)

neuspell_corrupt = os.path.join(DATA_DIR, "raw", "neuspell", "traintest", "train.1blm.noise.word")
misspellings_extracted_from_word_training = extract_misspellings_from_parallel_corpus(neuspell_corrupt, neuspell_correct)

# neuspell_corrupt = os.path.join(DATA_DIR, "raw", "neuspell", "traintest", "train.1blm.noise.random")
# misspellings_extracted_from_rand_training = extract_misspellings_from_parallel_corpus(neuspell_corrupt, neuspell_correct)

misspellings_extracted_from_training = merge_dicts(misspellings_extracted_from_prob_training, misspellings_extracted_from_word_training) #, misspellings_extracted_from_rand_training)

neuspell_corrupt = os.path.join(DATA_DIR, "raw", "neuspell", "traintest", "test.bea60k.noise")
neuspell_correct = os.path.join(DATA_DIR, "raw", "neuspell", "traintest", "test.bea60k")

misspellings_extracted_from_test = extract_misspellings_from_parallel_corpus(neuspell_corrupt, neuspell_correct)

In [36]:
count_missp = collections.Counter()
for misspelled in misspellings_extracted_from_test.values():
    count_missp += collections.Counter(misspelled)

In [51]:
found_bea = 0
found_me = 0
for m in count_missp:
    if any(m in m_ for m_ in misspellings_extracted_from_training.values()):
        found_bea += 1
    if any(m in m_ for m_ in no_neuspell_no_bea_missp.values()):
        found_me += 1
    print(found_bea, found_me)
found_bea, found_me

0 0
0 0
1 1
1 1
2 2
3 3
3 3
4 4
5 5
6 6
6 6
6 6
6 6
7 7
8 7
8 7
9 7
10 8
11 9
11 9
12 10
13 11
14 12
15 13
16 14
17 15
18 16
18 16
18 16
19 17
20 18
21 19
22 19
23 20
23 20
24 21
25 22
26 23
26 24
26 24
26 24
27 25
27 25
27 25
27 26
27 26
28 27
28 28
29 29
30 30
30 30
30 30
31 31
32 32
33 33
34 34
34 34
35 35
36 36
37 37
37 37
38 38
38 38
38 38
38 38
38 38
38 38
38 38
38 38
39 39
40 40
40 40
41 41
42 41
42 41
43 42
44 42
45 43
45 43
46 44
47 45
48 46
49 47
50 48
51 49
51 49
51 49
52 50
53 51
54 52
54 52
55 53
56 54
56 54
57 55
58 56
58 56
59 57
60 58
61 59
62 60
63 61
64 62
64 62
65 63
65 63
66 64
67 65
68 66
69 67
69 67
69 67
70 68
71 68
72 68
73 69
74 69
75 70
76 71
77 72
77 72
77 72
77 72
77 72
78 73
79 74
79 74
79 74
80 74
81 75
82 76
83 77
84 78
85 79
86 80
87 81
88 82
89 82
90 83
91 84
92 84
93 85
94 86
95 87
96 87
97 88
98 89
99 89
100 90
101 90
102 91
103 91
104 92
105 92
106 93
107 94
108 95
108 95
108 95
108 95
109 96
109 97
109 97
110 98
111 98
112 99
112 100
112 100
113 101

KeyboardInterrupt: 

In [66]:
print(f"overlap neuspell train --> bea60 test: {misspelling_overlap(misspellings_extracted_from_training, misspellings_extracted_from_test)}")
print(f"overlap bea60k test --> neuspell train: {misspelling_overlap(misspellings_extracted_from_test, misspellings_extracted_from_training)}")

overlap neuspell train --> bea60 test: 0.6845065102921232
overlap bea60k test --> neuspell train: 19.9175039963462


In [69]:
print(f"overlap bea60k test --> missp train: {misspelling_overlap(missp, misspellings_extracted_from_test)}")
print(f"overlap bea60k test --> no neuspell no bea train: {misspelling_overlap(misspellings_extracted_from_test, no_neuspell_no_bea_missp)}")

overlap bea60k test --> missp train: 1.3804616325508374
overlap bea60k test --> no neuspell no bea train: 16.452100936286822
