In [38]:
import re

In [None]:
def no_corrections(alignment_dict):
    """Tell whether an alignment dictionary records no edits at all.

    Args:
        alignment_dict: mapping in the shape returned by ``align_tokens``.

    Returns:
        True only when the mapping has exactly one key and the list stored
        under ``'X'`` (insertions) is empty; False otherwise.

    Raises:
        KeyError: if the mapping has a single key that is not ``'X'``.
    """
    if len(alignment_dict) != 1:
        # Any extra key is a recorded replacement or deletion.
        return False
    insertions = alignment_dict['X']
    return not len(insertions)

In [175]:
def _parse_m2_edits(edit_lines):
    """Parse M2 annotation lines into ``((start, end), correction)`` pairs ordered by span."""
    parsed = []
    for line in edit_lines:
        coords, _, correction, *_ = line.split('|||')
        start, end = map(int, coords.split()[1:])  # coords look like "A <start> <end>"
        if start < 0:
            # "A -1 -1|||noop|||-NONE-|||..." marks a sentence that needs no
            # correction; it must not be mistaken for an insertion of the
            # token "-NONE-".
            continue
        # Strip padding so a one-token correction compares equal to the
        # corresponding token of the corrected sentence.
        parsed.append(((start, end), correction.strip()))
    # Sort on the span only. list.sort is stable, so edits that share a span
    # keep their order from the file instead of being reordered
    # alphabetically by their correction text.
    parsed.sort(key=lambda pair: pair[0])
    return parsed


def align_tokens(m2_block):
    """Map source-token positions to positions in the corrected sentence.

    Only "simple" edits are recorded: one-token insertions, one-token
    deletions and one-to-one replacements. Every other edit is skipped, but
    still moves the output cursor by the length of its correction so later
    positions stay right.

    Args:
        m2_block: one M2 block as a string -- an ``S ...`` sentence line
            followed by zero or more ``A start end|||type|||correction|||...``
            lines.

    Returns:
        A dict with:
          * ``'X'``: list of ``{'idx', 'inp', 'out'}`` dicts, one per
            one-token insertion (``idx`` is the position of the inserted
            token in the corrected sentence, ``inp`` is always ``''``);
          * ``i`` (int source index) -> ``'X'`` for a one-token deletion;
          * ``i`` (int source index) -> ``{'idx', 'inp', 'out'}`` for a
            one-to-one replacement.
        An empty block yields ``{'X': []}``.

    Raises:
        ValueError: if an annotation line is malformed (fewer than three
            ``|||`` fields, or span coordinates that are not two integers).

    Note:
        The annotator-id field is ignored, so edits of all annotators are
        merged -- assumes one annotator per block; TODO confirm for the data.
    """
    alignment_dict = {'X': []}  # 'X' collects insertions
    lines = m2_block.splitlines()
    if not lines:
        return alignment_dict

    unc_sentence, *edit_lines = lines
    # str.split() ignores leading/trailing whitespace, so no empty tokens
    # appear; [1:] drops the leading "S" marker.
    unc_tokens = unc_sentence.split()[1:]

    last_span_end = 0     # end of the previous edit span in the source
    output_token_num = 0  # cursor into the corrected sentence
    for (i, j), out in _parse_m2_edits(edit_lines):
        if i > last_span_end:
            # Tokens between two edits are copied unchanged.
            output_token_num += i - last_span_end
        out_len = len(out.split())
        if i == j:  # Insertion (zero-length source span)
            if out_len == 1:
                alignment_dict['X'].append({
                    'idx': output_token_num,
                    'inp': '',  # an insertion consumes no source tokens
                    'out': out  # For testing
                })
            output_token_num += out_len
        elif out_len == 0:  # Deletion (empty correction)
            if j - i == 1:
                alignment_dict[i] = 'X'
        elif j - i == out_len == 1:  # One-to-one replacement
            alignment_dict[i] = {
                'idx': output_token_num,
                'inp': ' '.join(unc_tokens[i:j]),
                'out': out
            }
            output_token_num += 1
        else:  # Multi-token edit: not recorded, only skipped over
            output_token_num += out_len
        last_span_end = j
    return alignment_dict

In [207]:
def print_sentences(unc_tokens, cor_tokens):
    """Print both token sequences with each token tagged by its index.

    The uncorrected tokens go on the first line and the corrected tokens on
    the second, each rendered as ``token[index]`` followed by a space. A
    blank line is printed after the corrected sentence.
    """
    # The second element is what the closing print() emits for that row:
    # '' ends the line, '\n' ends the line and adds a blank one.
    for tokens, closing in ((unc_tokens, ''), (cor_tokens, '\n')):
        for position, token in enumerate(tokens):
            print(f'{token}[{position}]', end=' ')
        print(closing)
    
def test_alignment(m2_block, cor):
    """Check the alignment of ``m2_block`` against the corrected sentence.

    For every insertion and one-to-one replacement returned by
    ``align_tokens``, the token found in ``cor`` at the aligned index must
    equal the correction text of the edit. Deletions are not checked.

    Args:
        m2_block: one M2 block (sentence line plus annotation lines).
        cor: the corrected sentence as a whitespace-separated string; a
            trailing newline or other surrounding whitespace is ignored.

    Raises:
        ValueError: the token at an aligned index differs from the expected
            correction; the message is ``'<source key>-><output index>'``.
        IndexError: an aligned index lies beyond the end of ``cor``.
    """
    # str.split() drops leading/trailing whitespace, so a line read with
    # readlines() (trailing '\n') or one with leading spaces produces no
    # empty tokens that would shift or pad the indices.
    cor_tokens = cor.split()
    alignment_dict = align_tokens(m2_block)

    for k, v in alignment_dict.items():
        if v == 'X':  # Don't test deletions for now
            continue
        # 'X' holds a list of insertions; any other key holds a single
        # replacement record. Treat both uniformly.
        records = v if k == 'X' else [v]
        for el in records:
            i = el['idx']
            if cor_tokens[i] != el['out']:
                raise ValueError(f'{k}->{i}')

In [208]:
# For every annotated sentence, check that the tokens predicted by the
# alignment match the tokens of the corrected sentence, and count the
# sentences where they do (corrected) and where they do not (errors).
corrected = 0
errors = 0
# (corrected, errors) for each corpus part. `corrected` and `errors` are
# reset for every part, so after the loop they describe the last part only.
alignment_stats = {}
for part in [
    'dev',
    'train',
    'test'
]:
    # NOTE: despite their names, the m2_dev_* variables hold the data of the
    # current part, not only of 'dev'.
    # Read with an explicit encoding so the result does not depend on the
    # platform locale -- assumes the corpus files are UTF-8; TODO confirm.
    with open(f'm2_files/RULEC-GEC.{part}.M2', 'r', encoding='utf-8') as inp:
        m2_dev_blocks = inp.read().strip().split('\n\n')
    with open(f'preprocessing/RULEC-GEC.{part}.corrected', 'r', encoding='utf-8') as inp:
        m2_dev_cor_blocks = inp.readlines()
    assert len(m2_dev_blocks) == len(m2_dev_cor_blocks), (
        f'{part}: {len(m2_dev_blocks)} M2 blocks vs '
        f'{len(m2_dev_cor_blocks)} corrected sentences'
    )
    corrected = 0
    errors = 0
    for i, (m2_block, cor_sentence) in enumerate(
            zip(m2_dev_blocks, m2_dev_cor_blocks)):
        alignment_dict = align_tokens(m2_block)
        if no_corrections(alignment_dict):
            continue
        try:
            test_alignment(m2_block, cor_sentence)
            corrected += 1
        except (ValueError, IndexError):
            # ValueError: aligned token differs from the annotation;
            # IndexError: aligned index is past the end of the sentence.
            errors += 1
    alignment_stats[part] = (corrected, errors)

In [209]:
corrected

2247

In [210]:
errors

23