In [1]:
import collections
import csv
import json
import os
import sys
import tarfile
import urllib.request

# CSpell Misspelling Correction Dataset

This script converts the misspelling detection/correction dataset by the CSpell team ([link](https://lsg3.nlm.nih.gov/LexSysGroup/Projects/cSpell/current/web/index.html)) into our format. We apply the following selection steps to filter examples from the CSpell dataset:

- The CSpell dataset contains different types of misspellings, and we carefully choose as many as possible. The dataset also has multiple-word misspellings and misspellings with special symbols (such as punctuation). We filter out such examples and keep only single-word, alphabet-only misspellings.
- Since CSpell is software that detects & corrects misspellings, while ours only corrects misspellings, we filter out the examples that are not detected by CSpell. Please refer to `scripts/cspell_results.ipynb` for reproduction of the performance by CSpell and the masks we used to filter examples.


In [26]:
# Download locations of the two CSpell archives.
# NOTE(review): these values have no scheme/host, so they look like truncated
# URLs -- confirm the full download addresses before running the download cells.
train_url = '/TrainSet_brat.tgz'
test_url = '/TestSet_reconciled.tgz'

# Archive file names are taken from the last component of each URL
train_fname, test_fname = (os.path.basename(url) for url in (train_url, test_url))

# Everything (archives, extracted files, outputs) lives under this directory
dataset_dir = '../data/cspell/'

# CSpell files: downloaded archives and the directories holding the extracted annotations
train_tar_fpath, test_tar_fpath = (
    os.path.join(dataset_dir, fname) for fname in (train_fname, test_fname)
)
train_dir = os.path.join(dataset_dir, 'train')
test_dir = os.path.join(dataset_dir, 'test')

print(train_dir)

# Output files: "*_before" are not filtered by CSpell detection, the others are
output_train_before_fname, output_test_before_fname = 'train_before.tsv', 'test_before.tsv'
output_train_fname, output_test_fname = 'train.tsv', 'test.tsv'

# Annotation files excluded by NLM
ann_fnames_exclude = {'11199.ann'}

../data/cspell/train


## Download CSpell dataset

In [None]:
# Fetch the training archive from train_url and store it at train_tar_fpath.
urllib.request.urlretrieve(train_url, train_tar_fpath)

In [None]:
# Fetch the test archive from test_url and store it at test_tar_fpath.
urllib.request.urlretrieve(test_url, test_tar_fpath)

In [None]:
def extract(tar_fpath, extract_path='.'):
    """Extract a tar archive, recursively unpacking archives nested inside it.

    Args:
        tar_fpath: Path of the (optionally compressed) tar archive to extract.
        extract_path: Directory the members are extracted into.

    A member whose name contains ``.tgz`` or ``.tar`` is treated as a nested
    archive and is unpacked into the directory it was extracted to.

    NOTE: members are extracted as-is without sanitising their paths, so only
    use this on archives from a trusted source.
    """
    print(f'Extract {tar_fpath} to {extract_path}')
    # The context manager guarantees the archive handle is closed, even on error
    with tarfile.open(tar_fpath, 'r') as tar:
        for item in tar:
            tar.extract(item, extract_path)
            if item.isfile() and ('.tgz' in item.name or '.tar' in item.name):
                # The nested archive now lives under extract_path, not under the
                # current working directory, so resolve it relative to extract_path
                nested_fpath = os.path.join(extract_path, item.name)
                extract(nested_fpath, os.path.dirname(nested_fpath))

In [None]:
# Unpack both downloaded archives into dataset_dir.
extract(train_tar_fpath, dataset_dir)
extract(test_tar_fpath, dataset_dir)

## Statistics of the misspellings + Selection strategy

In [27]:
def get_tokens(line, num_tokens):
    """Split off the first ``num_tokens`` whitespace-separated tokens of a line.

    Args:
        line: Text line; leading and trailing whitespace is ignored.
        num_tokens: Number of leading tokens to extract.

    Returns:
        A tuple ``(tokens, remaining)``: the list of the first ``num_tokens``
        tokens and the rest of the line (inner whitespace of the rest is kept).

    Raises:
        ValueError: If the line does not hold ``num_tokens`` tokens followed by
            at least one more non-whitespace character.
    """
    line = line.strip()
    if num_tokens <= 0:
        return [], line

    # maxsplit=num_tokens yields the tokens plus one trailing "remaining" part
    parts = line.split(None, num_tokens)
    if len(parts) <= num_tokens:
        # Index of the first token that is not followed by any further text
        missing = max(len(parts) - 1, 0)
        raise ValueError(f'Cannot find {missing}-th token: {line}')
    return parts[:num_tokens], parts[num_tokens]

def read_ann(ann_fpath):
    """Parse one ``.ann`` annotation file.

    Lines starting with ``A`` are skipped, lines starting with ``T`` give the
    misspelling locations and lines starting with ``#`` give the corrections
    attached to them. Any other line raises ``ValueError``.

    Args:
        ann_fpath: Path of the annotation file (it is also printed).

    Returns:
        List of ``(typo_label, typo_type, (start, end), typo, correction)``
        tuples sorted by ascending start and descending end, so that an
        annotation containing another one comes first. ``correction`` is
        ``None`` when the file holds no correction for that label.
    """
    print(ann_fpath)
    with open(ann_fpath) as fd:
        raw_lines = fd.readlines()

    spans = []             # (label, type, (start, end), text) for each 'T' line
    notes_by_label = {}    # typo label -> (note label, correction text)
    for raw in raw_lines:
        entry = raw.strip()
        marker = entry[:1]
        if marker == 'A':  # Important annotation -> skip
            continue
        if marker == 'T':  # Misspelling locations
            (label, kind, begin, finish), text = get_tokens(entry, 4)
            spans.append((label, kind, (int(begin), int(finish)), text))
        elif marker == '#':  # Correction to the annotations
            fields, text = get_tokens(entry, 3)
            notes_by_label[fields[2]] = (fields[0], text)
        else:
            raise ValueError(f'Wrong header: {entry}')

    # Overlapping annotations exist: order them so the outer span precedes the inner one
    spans.sort(key=lambda item: (item[2][0], -item[2][1]))

    return [
        (label, kind, span, text,
         notes_by_label[label][1] if label in notes_by_label else None)
        for label, kind, span, text in spans
    ]

In [24]:
def show_ann_stat(dataset_root, print_overlap=False, type_show=None):
    """Print annotation statistics for all ``.ann`` files in a directory.

    Files listed in ``ann_fnames_exclude`` are ignored.

    Args:
        dataset_root: Directory containing the ``.ann`` files.
        print_overlap: If true, print every pair of overlapping annotations in
            which neither annotation is of type ``RealWord``.
        type_show: If set, additionally list every annotation of this type.

    The final table has one row per typo type with the columns: total count,
    count with a correction, count without a correction, and count of
    annotations that overlap another annotation.
    """
    print(f'Dataset dir: {dataset_root}')
    ann_fnames = [fname for fname in os.listdir(dataset_root) if fname.endswith('.ann')]
    # Per typo type: [total, with correction, without correction, overlapped]
    ann_cor_counter = collections.defaultdict(lambda: [0, 0, 0, 0])
    total_ann_lines = 0
    total_anns = 0
    cnt_empty_ann_files = 0
    cnt_empty_anns = 0

    if type_show:
        print(f'\n[Annotations of ({type_show})]')
        print(f'{"Location(span)":60s} Typo            Correction      Overlapped')

    for ann_fname in ann_fnames:
        if ann_fname in ann_fnames_exclude:
            continue

        # Get typo annotations; the context manager closes the handle right away
        ann_fpath = os.path.join(dataset_root, ann_fname)
        with open(ann_fpath) as fd:
            ann_fcontent = [l.strip() for l in fd if l.strip()]
        anns = read_ann(ann_fpath)
        total_ann_lines += len(ann_fcontent)
        total_anns += len(anns)
        cnt_empty_ann_files += int(len(ann_fcontent) == 0)
        cnt_empty_anns += int(len(anns) == 0)

        # Count types of typos, according to the presence of a correction
        for ann in anns:
            typo_type, correction = ann[1], ann[4]
            ann_cor_counter[typo_type][0] += 1
            ann_cor_counter[typo_type][int(correction is None) + 1] += 1

        # Overlapped annotations: collect the index of every annotation whose
        # span intersects the span of another annotation in the same file
        overlapped = set()
        for i, ann1 in enumerate(anns):
            start1, end1 = ann1[2]
            for j in range(i + 1, len(anns)):
                ann2 = anns[j]
                start2, end2 = ann2[2]
                if not (end1 <= start2 or end2 <= start1):
                    overlapped.add(i)
                    overlapped.add(j)
                    if print_overlap and not (ann1[1] == 'RealWord' or ann2[1] == 'RealWord'):
                        print(f'\t{ann1}')
                        print(f'\t{ann2}')
        for i in overlapped:
            ann_cor_counter[anns[i][1]][3] += 1

        # Show a specific type of typo
        if type_show:
            for i, ann in enumerate(anns):
                if ann[1] == type_show:
                    print(f'{ann_fpath + str(ann[2]):60s} {ann[3]:15s} {ann[4] if ann[4] else "[None]":15s} {i in overlapped}')

    # Output stats
    print(f'\nTotal {len(set(ann_fnames) - ann_fnames_exclude)} ann files')
    print(f'  {cnt_empty_ann_files} empty files, {cnt_empty_anns} empty anns')
    print(f'Total {total_anns} anns from {total_ann_lines} lines')

    typo_types_temp = sorted(ann_cor_counter.keys())
    print('\nTypo Type        Total  Cor O  Cor X  Dupli')
    for typo_type in typo_types_temp:
        print(f'{typo_type:15s}' + ''.join([f'{n:7d}' for n in ann_cor_counter[typo_type]]))
    # Column-wise sums over all typo types
    print('(Total)        ' + ''.join([f'{sum(l):7d}' for l in zip(*list(ann_cor_counter.values()))]))

## CSpell train set

### Annotation types

Typo -> Use only 1 word/1 word annotations
- Misspelling: 1+ word -> 1+ word
- Grammatical: 1+ word -> 0+ word (0: to delete)
- Punctuation: 

Typo, but do not use it
- ToSplit: 2+ words -> 1 word
- ToMerge: 1 word -> 2+ word

Extra
- RealWord: Real word typo. 1+ word -> 1+ word

### Selection strategies

These types are chosen as an example
- Misspelling with one-word (& alphabet)
- Grammatical with one-word (& alphabet)
- Punctuation with one-word (& alphabet) -> no alphabet-only typo!
- RealWord with one-word (& alphabet): remove if duplicate with other annotations (ToSplit, Misspelling, Grammatical)

These are not chosen
- ToSplit, ToMerge

### Correcting strategies for each annotation types

- Grammatical: Apply (remove if correction is None)
- Misspelling, ToMerge, ToSplit: Apply (ignore if correction is None)
- Realword: Apply if not duplicate with others

In [25]:
# Train set distribution
# Set type_show to an annotation type (e.g. 'RealWord') to also list every
# annotation of that type; None prints only the summary statistics.
# type_show = 'RealWord'
type_show = None
show_ann_stat(train_dir, type_show=type_show)

Dataset dir: ../data/cspell/TrainSet_brat


FileNotFoundError: [Errno 2] No such file or directory: '../data/cspell/TrainSet_brat'

## CSpell test set

### Annotation types

Typo -> Use only 1 word/1 word annotations
- Informal: 1+ word (informal expression) -> 1+ word (ABB:/ACR: can precede)
- Misspelling: 1+ word -> 1+ word
- Punctuation: 1 word -> 1 word (ends with a punctuation)
- RealWord: Real word typo. 1+ word -> 1+ word

Typo, but do not use it
- ToMerge: 2+ words -> 1 word, mostly do not have label
- ToSplit: 1 word -> multiple words
- ToSplitOnPunct: Similar to "ToSplit", but these do not have corrections (because they are trivial)

Not a typo
- Unknown: Not valid words, but cannot find corrections
- WordExists: Not a typo -> Do not use this as a typo example
- Garbage: No need to be corrected

### Selection strategy

These types are chosen as an example
- Misspelling with one-word (& alphabet)
- RealWord with one-word (& alphabet)
- Informal with one-word (& alphabet)  -> **Need to check this**

These are not chosen
- Punctuation: the test examples of this type differ only in lacking a punctuation mark at the end


### Correcting strategy for each annotation types

- ToSplit, Misspelling, Punctuation: Apply correction
- Informal: Apply correction, with `ABB: ` and `ACR: ` deleted (ignore if correction is None)
- ToSplitOnPunct: find split chars (`.?,&()-*/`) and add a space after that
- ToMerge: delete chars (space, `-`)
- WordExists, Unknown, RealWord, Garbage: ignore

In [None]:
# Test set distribution (summary statistics only, no per-type listing)
show_ann_stat(test_dir)

## Extract misspelling examples

We define several filtering functions to extract the dataset, and we have different selection rules between train and test set based on the analysis above.

- Choose valid types of misspellings in
    - Train set: `choose_anns_for_example_train()`
    - Test set: `choose_anns_for_example_test()`
- Filter out duplicates in
    - Train set: `filter_anns_for_context_train()`
    - Test set: `X`
- Find out corrections for each annotation to filter examples in
    - Train set: `get_correction_for_clean_context_train()`
    - Test set: `get_correction_for_clean_context_test()`

In [None]:
def choose_anns_for_example_train(anns):
    """Select the train-set annotations that are usable as typo examples.

    An annotation is kept when its type is Misspelling, Grammatical,
    Punctuation or RealWord, and both the typo and its correction are exactly
    one word. A RealWord annotation is then dropped if its span overlaps the
    span of any other kept annotation.

    NOTE: the overlap test previously compared a RealWord annotation's span
    with itself, which dropped every RealWord annotation as soon as the file
    had any other valid annotation. With the fix more RealWord examples may be
    selected, so precomputed per-example masks (e.g. ``cspell_train_mask``)
    have to be regenerated.

    Args:
        anns: List of ``(typo_label, typo_type, (start, end), typo, correction)``.

    Returns:
        The selected annotations, in their original order.
    """
    selected_types = ('Misspelling', 'Grammatical', 'Punctuation', 'RealWord')

    anns_valid = []
    for ann in anns:
        typo_label, typo_type, (start, end), typo, correction = ann
        if typo_type not in selected_types:
            continue
        if len(typo.split()) != 1 or not correction or len(correction.split()) != 1:
            continue
        anns_valid.append(ann)

    # Remove 'RealWord' annotations that overlap another valid annotation
    anns_kept = []
    for ann in anns_valid:
        typo_type, (start, end) = ann[1], ann[2]
        if typo_type == 'RealWord' and any(
                other != ann and other[2][0] < end and other[2][1] > start
                for other in anns_valid):
            continue
        anns_kept.append(ann)

    return anns_kept

def choose_anns_for_example_test(anns):
    """Select the test-set annotations that are usable as typo examples.

    Keeps annotations of type Misspelling, RealWord or Informal whose typo and
    correction are both exactly one word; annotations without a correction are
    dropped.

    Args:
        anns: List of ``(typo_label, typo_type, (start, end), typo, correction)``.

    Returns:
        The selected annotations, in their original order.
    """
    wanted_types = ('Misspelling', 'RealWord', 'Informal')
    return [
        candidate for candidate in anns
        if candidate[1] in wanted_types
        and len(candidate[3].split()) == 1
        and candidate[4]
        and len(candidate[4].split()) == 1
    ]


def filter_anns_for_context_train(anns_all, ann):
    """Pick the annotations used to clean the context around ``ann``.

    Every annotation is kept except 'RealWord' annotations (other than ``ann``
    itself) whose span overlaps the span of another annotation.

    Args:
        anns_all: All annotations of the note, as
            ``(typo_label, typo_type, (start, end), typo, correction)`` tuples.
        ann: The annotation currently used as the example; it is never removed.

    Returns:
        The kept annotations, in their original order.
    """
    def overlaps_another(candidate):
        # True when some different annotation intersects the candidate's span
        cand_start, cand_end = candidate[2]
        return any(
            other != candidate and other[2][0] < cand_end and other[2][1] > cand_start
            for other in anns_all
        )

    return [
        candidate for candidate in anns_all
        if candidate == ann
        or candidate[1] != 'RealWord'
        or not overlaps_another(candidate)
    ]


def get_correction_for_clean_context_train(ann, txt_correct):
    """Return the replacement text for a train-set annotation when cleaning context.

    Args:
        ann: ``(typo_label, typo_type, (start, end), typo, correction)``.
        txt_correct: Text being corrected (unused; kept for a uniform signature).

    Returns:
        - Grammatical: the correction, or ``''`` (delete the span) when there is none.
        - Misspelling, ToMerge, ToSplit, RealWord, Punctuation: the correction,
          or the typo itself (leave the text unchanged) when there is none.
          Punctuation used to return ``None`` in that case, which broke the
          string concatenation done by the caller.

    Raises:
        ValueError: For any other annotation type.
    """
    typo_label, typo_type, (start, end), typo, correction = ann
    if typo_type == 'Grammatical':
        return correction if correction else ''
    if typo_type in ('Misspelling', 'ToMerge', 'ToSplit', 'RealWord', 'Punctuation'):
        return correction if correction else typo
    raise ValueError(f'Wrong annotation: {ann}')

def get_correction_for_clean_context_test(ann, txt_correct):
    """Return the replacement text for a test-set annotation when cleaning context.

    Args:
        ann: ``(typo_label, typo_type, (start, end), typo, correction)``.
        txt_correct: Text being corrected (unused; kept for a uniform signature).

    Returns:
        - ToSplit, Misspelling, Punctuation: the correction, or the typo itself
          when there is no correction (so the caller never receives ``None``).
        - Informal: the correction without a leading ``'ABB: '`` / ``'ACR: '``,
          or the typo when there is no correction.
        - ToSplitOnPunct: the typo with a space inserted after each run of
          split characters (``.?,&()-*/``) that is followed by another character.
        - ToMerge: the typo with spaces and ``-`` removed.
        - WordExists, Unknown, RealWord, Garbage: the typo unchanged.

    Raises:
        ValueError: For any other annotation type (previously ``None`` was
            returned silently, which failed later in the caller).
    """
    typo_label, typo_type, (start, end), typo, correction = ann
    if typo_type in ('ToSplit', 'Misspelling', 'Punctuation'):
        return correction if correction else typo
    if typo_type == 'Informal':
        if correction is None:
            return typo
        for prefix in ('ABB: ', 'ACR: '):
            if correction.startswith(prefix):
                return correction[len(prefix):]
        return correction
    if typo_type == 'ToSplitOnPunct':
        punct_chars = set(".?,&()-*/")
        new_correction = typo
        # Walk right-to-left so an insertion never shifts the indices still to
        # be visited; stop at len-2 because the last char has no successor
        for i in range(len(new_correction) - 2, -1, -1):
            if new_correction[i] in punct_chars and new_correction[i + 1] not in punct_chars:
                # Keep the split character and add a space after it
                new_correction = new_correction[:i + 1] + ' ' + new_correction[i + 1:]
        return new_correction
    if typo_type == 'ToMerge':
        return typo.replace(' ', '').replace('-', '')
    if typo_type in ('WordExists', 'Unknown', 'RealWord', 'Garbage'):
        return typo
    raise ValueError(f'Wrong annotation: {ann}')

In [None]:
def get_typo_examples(dataset_root,
                      choose_ex_ann_fn,       # Choose anns to use as examples
                      filter_context_ann_fn,  # Filter some anns for cleaning text
                      get_correction_fn,      # Get correction according to its annotation type
                      verbose=False):
    """Build typo examples, each with a cleaned left/right context.

    For every ``.ann`` file in ``dataset_root`` (except those listed in
    ``ann_fnames_exclude``) the annotations and the matching ``.txt`` note are
    read. Each selected annotation becomes one example; the other annotations
    of the note are replaced by their corrections so the context is clean.

    Args:
        dataset_root: Directory holding the ``.ann`` / ``.txt`` file pairs.
        choose_ex_ann_fn: ``fn(anns) -> anns`` selecting the example annotations.
        filter_context_ann_fn: ``fn(anns, ann) -> anns`` selecting the
            annotations used to clean the context of ``ann``; ``None`` uses all.
        get_correction_fn: ``fn(ann, text) -> str`` giving the replacement text
            of a context annotation.
        verbose: Print progress information.

    Returns:
        List of ``(index, note_id, typo_type, typo, left, right, correction)``.

    Raises:
        ValueError: If ``filter_context_ann_fn`` removes the example annotation.

    NOTE: two defects were fixed here. (1) The overlap filter never updated
    ``last_end`` and therefore filtered nothing. (2) An example annotation could
    be skipped in favour of an overlapping context annotation, leaving ``right``
    stale or undefined. Both fixes can change the generated examples, so
    precomputed per-example masks (e.g. ``cspell_train_mask``) must be
    regenerated.
    """
    # Load anns
    ann_fnames = sorted(fname for fname in os.listdir(dataset_root) if fname.endswith('.ann'))
    if verbose: print(f'{len(ann_fnames)} annotations')

    # Parsing each ann
    typo_examples = []  # index, note_id, type, typo, left, right, correct
    for ann_fname in ann_fnames:
        if ann_fname in ann_fnames_exclude:
            if verbose: print(f'  {ann_fname} - skip')
            continue

        # Read ann, text
        note_id = os.path.splitext(ann_fname)[0]
        anns = read_ann(os.path.join(dataset_root, ann_fname))
        txt_fname = note_id + '.txt'
        with open(os.path.join(dataset_root, txt_fname)) as fd:
            txt = fd.read()

        # Get the valid ann 1: from the filter function
        # Get the valid ann 2: not overlapped to each other (first one wins)
        anns_valid = []
        last_end = 0
        for ann in choose_ex_ann_fn(anns):
            start, end = ann[2]
            if start < last_end:
                continue
            anns_valid.append(ann)
            last_end = end

        # Get dataset examples
        # For each typo example, we correct other typos to make the surrounding context clean
        if verbose: print(f'  {ann_fname} - {len(anns_valid)} valid annos')
        for ann in anns_valid:
            # Get anns for cleaning context
            anns_context = filter_context_ann_fn(anns, ann) if filter_context_ann_fn else anns

            typo_label, typo_type, (start, end), typo, correction = ann
            txt_correct, last_start = txt, len(txt) + 1
            right = None
            # Apply replacements from the end of the text backwards so that the
            # character offsets of the annotations still to come stay valid
            for ann2 in anns_context[::-1]:
                start2, end2 = ann2[2]
                if ann == ann2:
                    # The example itself: cut the text into left / right here
                    right = txt_correct[end2:]
                    txt_correct = txt_correct[:start2]
                else:
                    # Skip annotations colliding with one that was already applied
                    if end2 >= last_start:
                        continue
                    # Never let a context annotation overwrite part of the example span
                    if start2 < end and end2 > start:
                        continue
                    txt_correct = txt_correct[:start2] + get_correction_fn(ann2, txt_correct) + txt_correct[end2:]
                last_start = start2
            if right is None:
                raise ValueError(f'Example annotation is missing from its context annotations: {ann}')
            left = txt_correct
            typo_examples.append((len(typo_examples), note_id, typo_type, typo, left, right, correction))
    return typo_examples

- Training set

In [None]:
# Build the training examples (with progress output), then report how many
# examples there are in total and per annotation type.
examples_train = get_typo_examples(
    train_dir,
    choose_ex_ann_fn=choose_anns_for_example_train,
    filter_context_ann_fn=filter_anns_for_context_train,
    get_correction_fn=get_correction_for_clean_context_train,
    verbose=True,
)
print(f'{train_dir}: {len(examples_train)} examples')

# Field 2 of every example is its annotation type
type_counter = collections.Counter(example[2] for example in examples_train)
for k, v in type_counter.items():
    print(f'{k:15s}: {v}')

In [None]:
# Normalise the training examples and keep only the alphabet-only ones.
dataset_train_all = []   # For CIM
examples_train_cspell_all = []  # For CSpell

for example in examples_train:
    ex_id, note_id, typo_type, typo, left, right, correction = example

    # Process for our model
    note_id = int(note_id.split('-')[-1])
    typo, correction = typo.lower(), correction.lower()
    # Flatten tabs and newlines on BOTH sides of the typo: the dataset is
    # written as TSV, so a tab left in `right` would end up inside a field
    left = left.strip().replace('\t', ' ').replace('\n', ' ')
    right = right.strip().replace('\t', ' ').replace('\n', ' ')

    # Select only alphabet-only examples
    alpha_check = all([c.isalpha() for c in typo]) and all([c.isalpha() for c in correction])
    if alpha_check:
        # Re-index so the ids are contiguous after filtering; the CSpell copy
        # keeps the raw (unprocessed) fields under the same new id
        ex_id = len(dataset_train_all)
        ex = (ex_id, note_id, typo, left, right, correction)
        dataset_train_all.append(ex)
        ex2 = (ex_id,) + example[1:]
        examples_train_cspell_all.append(ex2)

print(f'Train set (before CSpell filtering): {len(dataset_train_all)} examples')

- Test set

In [None]:
# Build the test examples (with progress output); no context filter is used,
# so every annotation of a note takes part in cleaning the context.
examples_test = get_typo_examples(
    test_dir,
    choose_ex_ann_fn=choose_anns_for_example_test,
    filter_context_ann_fn=None,
    get_correction_fn=get_correction_for_clean_context_test,
    verbose=True,
)
print(f'{test_dir}: {len(examples_test)} examples')

# Field 2 of every example is its annotation type
type_counter = collections.Counter(example[2] for example in examples_test)
for k, v in type_counter.items():
    print(f'{k:15s}: {v}')

In [None]:
# Normalise the test examples and keep only the alphabet-only ones.
dataset_test_all = []   # For CIM
examples_test_cspell_all = []  # For CSpell

for example in examples_test:
    ex_id, note_id, typo_type, typo, left, right, correction = example

    # Process for our model
    note_id = int(note_id.split('-')[-1])
    typo, correction = typo.lower(), correction.lower()
    # Flatten tabs and newlines on BOTH sides of the typo: the dataset is
    # written as TSV, so a tab left in `right` would end up inside a field
    left = left.strip().replace('\t', ' ').replace('\n', ' ')
    right = right.strip().replace('\t', ' ').replace('\n', ' ')

    # Select only single-word alphabet-only examples
    alpha_check = all([c.isalpha() for c in typo]) and all([c.isalpha() for c in correction])
    if alpha_check:
        # Re-index so the ids are contiguous after filtering; the CSpell copy
        # keeps the raw (unprocessed) fields under the same new id
        ex_id = len(dataset_test_all)
        ex = (ex_id, note_id, typo, left, right, correction)
        dataset_test_all.append(ex)
        ex2 = (ex_id,) + example[1:]
        examples_test_cspell_all.append(ex2)

print(f'Test set (before CSpell filtering): {len(dataset_test_all)} examples')

- Write dataset before CSpell filtering

In [None]:
# Dataset for our CIM model
def write_dataset(examples, output_fpath):
    """Write examples to a tab-separated file with a header row.

    Args:
        examples: Iterable of ``(ex_id, note_id, typo, left, right, correction)``.
        output_fpath: Path of the TSV file to create (overwritten if present).
    """
    # newline='' lets the csv module control line endings itself; without it
    # extra blank lines appear on platforms that translate '\n' on write
    with open(output_fpath, 'w', newline='') as fd:
        writer = csv.writer(fd, delimiter='\t')
        writer.writerow(['index', 'note_id', 'word', 'left', 'right', 'correct'])
        writer.writerows(examples)
            
def read_dataset(dataset_fpath):
    """Read a tab-separated dataset file and return its rows without the header.

    Args:
        dataset_fpath: Path of a TSV file whose first row is a header.

    Returns:
        List of rows, each a list of strings.
    """
    # The delimiter must be a tab for tab-separated files; the default comma
    # would split fields at the commas inside the context text
    with open(dataset_fpath, newline='') as fd:
        reader = csv.reader(fd, delimiter='\t')
        dataset = list(reader)[1:]
    return dataset

In [None]:
# Write the train examples that have not yet been filtered by CSpell detection
output_train_fpath = os.path.join(dataset_dir, output_train_before_fname)
print(f'Write {len(dataset_train_all)} examples to {output_train_fpath}')
write_dataset(dataset_train_all, output_train_fpath)

# Same for the test examples
output_test_fpath = os.path.join(dataset_dir, output_test_before_fname)
print(f'Write {len(dataset_test_all)} examples to {output_test_fpath}')
write_dataset(dataset_test_all, output_test_fpath)

## Write examples for CSpell

We write the datasets generated so far to files so that we can run CSpell on them.
After we run CSpell, we choose only examples that are detected as misspellings by CSpell (the filtering masks are already given below for easier reproduction).
To get the detailed output of CSpell, please run with `-t` and `-d` option enabled:
```
$ cspell -t -i:(input_file) -o:(output_file) -d > (debug_output_file)
```
Check `scripts/cspell_results.ipynb` for more details.

- Training set

In [None]:
# Write one text file per training example (context with the typo left in
# place) plus a JSON index recording where the typo sits in each file.
cspell_example_train_root = os.path.join(train_dir, 'examples')
if not os.path.exists(cspell_example_train_root):
    os.makedirs(cspell_example_train_root)

example_stat = []
for ex_id, note_id, typo_type, typo, left, right, correction in examples_train_cspell_all:
    # The typo directly follows the left context, which fixes its span
    start = len(left)
    end = start + len(typo)
    output = left + typo + right

    example_fpath = os.path.join(cspell_example_train_root, f'{ex_id}.txt')
    with open(example_fpath, 'w') as fd:
        fd.write(output)
    example_stat.append([ex_id, note_id, start, end, typo, correction])
    print(f'{example_fpath} {note_id}[{start}:{end}] {typo} -> {correction}')

example_stat_fpath = os.path.join(cspell_example_train_root, 'example_stat.json')
with open(example_stat_fpath, 'w') as fd:
    json.dump(example_stat, fd)
print(f'Write example stat to {example_stat_fpath}')

- Test set

In [None]:
# Write one text file per test example (context with the typo left in place)
# plus a JSON index recording where the typo sits in each file.
cspell_example_test_root = os.path.join(test_dir, 'examples')
if not os.path.exists(cspell_example_test_root):
    os.makedirs(cspell_example_test_root)

# Character frequencies over all written example texts
char_counter = collections.Counter()

example_stat = []
for ex_id, note_id, typo_type, typo, left, right, correction in examples_test_cspell_all:
    # The typo directly follows the left context, which fixes its span
    start = len(left)
    end = start + len(typo)
    output = left + typo + right

    char_counter.update(output)

    example_fpath = os.path.join(cspell_example_test_root, f'{ex_id}.txt')
    with open(example_fpath, 'w') as fd:
        fd.write(output)
    example_stat.append([ex_id, note_id, start, end, typo, correction])
    print(f'{example_fpath} {note_id}[{start}:{end}] {typo} -> {correction}')

example_stat_fpath = os.path.join(cspell_example_test_root, 'example_stat.json')
with open(example_stat_fpath, 'w') as fd:
    json.dump(example_stat, fd)
print(f'Write example stat to {example_stat_fpath}')

## Filter examples by CSpell

We further filter the examples that are not detected by CSpell software. Below are the masks that indicate whether each example's misspelling is detected by CSpell or not.

To get the masks by parsing the CSpell output of the files written above, see `scripts/cspell_results.ipynb`.

In [None]:
cspell_train_mask = [
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
    1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
    1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1
]

# After running CSpell and scripts/cspell_results.ipynb, you can get the same mask.
# Note that json.load() takes an open file object, not a path:
# with open(os.path.join(train_dir, 'cspell_train_mask.json')) as fd:
#     cspell_train_mask = json.load(fd)

print(f'Train examples detected by CSpell: {sum(cspell_train_mask)}/{len(cspell_train_mask)}')

In [None]:
cspell_test_mask = [
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
    1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 
    1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 
    1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
    0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
    1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1
]

# After running CSpell and scripts/cspell_results.ipynb, you can get the same mask.
# Note that json.load() takes an open file object, not a path:
# with open(os.path.join(train_dir, 'cspell_test_mask.json')) as fd:
#     cspell_test_mask = json.load(fd)
# NOTE(review): the test mask is looked up under train_dir -- confirm the intended location.

print(f'Test examples detected by CSpell: {sum(cspell_test_mask)}/{len(cspell_test_mask)}')

- Filter examples with the CSpell masks

In [None]:
# Keep only the examples whose misspelling was detected by CSpell (mask == 1).
# The masks are positional and zip() silently stops at the shorter input, so a
# length mismatch would drop or mis-filter examples without any warning.
# Fail loudly instead.
if len(cspell_train_mask) != len(dataset_train_all):
    raise ValueError(
        f'cspell_train_mask has {len(cspell_train_mask)} entries '
        f'but there are {len(dataset_train_all)} train examples'
    )
if len(cspell_test_mask) != len(dataset_test_all):
    raise ValueError(
        f'cspell_test_mask has {len(cspell_test_mask)} entries '
        f'but there are {len(dataset_test_all)} test examples'
    )

dataset_train_cspell_filtered = [
    e for e, m in zip(dataset_train_all, cspell_train_mask) if m
]
print(f'Train set: {len(dataset_train_cspell_filtered)} examples')

dataset_test_cspell_filtered = [
    e for e, m in zip(dataset_test_all, cspell_test_mask) if m
]
print(f'Test set: {len(dataset_test_cspell_filtered)} examples')

## Write the final dataset

In [None]:
# Write the train examples that remain after filtering by CSpell detection
output_train_fpath = os.path.join(dataset_dir, output_train_fname)
print(f'Write {len(dataset_train_cspell_filtered)} examples to {output_train_fpath}')
write_dataset(dataset_train_cspell_filtered, output_train_fpath)

# Same for the test examples
output_test_fpath = os.path.join(dataset_dir, output_test_fname)
print(f'Write {len(dataset_test_cspell_filtered)} examples to {output_test_fpath}')
write_dataset(dataset_test_cspell_filtered, output_test_fpath)