In [12]:
from collections import Counter, defaultdict
from edit import SubwordEdit
import json
from utils import load_data, write_json, write_tsv, apply_edits

In [13]:
def read_edits_tsv(comps_edits_path, nocomprs_edits_path):
    """Read two parallel edit TSV files into per-sentence examples.

    The two files are consumed line by line in lockstep. A line pair in which
    BOTH lines split into exactly two tab-separated fields is treated as a
    (subword, edit) row; any other line pair (e.g. a blank line) closes the
    current example and starts a new one. The literal '<s>' is removed from
    every stored field.

    Args:
        comps_edits_path: Path of the file whose edit column is stored under
            'cmpr_edits' (presumably the compressed edits -- confirm with the
            producer of these files).
        nocomprs_edits_path: Path of the file whose edit column is stored
            under 'edits'.

    Returns:
        A list of dicts with keys 'subwords', 'cmpr_edits' and 'edits', each
        mapping to a list of strings of equal length.

    Raises:
        AssertionError: If the subword columns of the two files disagree on
            a row.
    """
    def _new_example():
        return {'subwords': [], 'cmpr_edits': [], 'edits': []}

    examples = []
    example = _new_example()

    # Explicit UTF-8 so the result does not depend on the machine's locale.
    with open(comps_edits_path, encoding='utf-8') as f1, \
            open(nocomprs_edits_path, encoding='utf-8') as f2:
        # NOTE: zip stops at the shorter file; a length mismatch is not detected here.
        for cmpr_line, no_cmpr_line in zip(f1, f2):
            cmpr_fields = cmpr_line.strip().split('\t')
            no_cmpr_fields = no_cmpr_line.strip().split('\t')

            if len(cmpr_fields) == 2 and len(no_cmpr_fields) == 2:
                assert cmpr_fields[0] == no_cmpr_fields[0]

                example['subwords'].append(cmpr_fields[0].replace('<s>', ''))
                example['cmpr_edits'].append(cmpr_fields[1].replace('<s>', ''))
                example['edits'].append(no_cmpr_fields[1].replace('<s>', ''))
            else:
                # Separator line: close the current example.
                examples.append(example)
                example = _new_example()

    # Flush the last example when the files do not end with a separator line;
    # previously it was silently dropped.
    if example['subwords']:
        examples.append(example)

    return examples

In [14]:
def detokenize_sent(sent):
    """Join a sequence of subwords back into a space-separated string.

    A piece that starts with '##' is glued onto the previous word, with every
    '##' occurrence removed from it; any other piece starts a new word.
    A leading '##' piece (no previous word) raises IndexError.
    """
    words = []
    for piece in sent:
        if not piece.startswith('##'):
            words.append(piece)
            continue
        # Continuation piece: extend the word built so far.
        words[-1] += piece.replace('##', '')

    return ' '.join(words)

In [15]:
def edits_intersection(qalb14_edits, zaebuc_edits):
    """Return the edit labels that occur in both datasets.

    Each dataset is an iterable of examples; an example's
    'subword-edits-append' entry is an iterable of objects exposing `.edit`.
    The result lists each shared edit once, in the order it is first seen in
    `qalb14_edits`.
    """
    zaebuc_seen = {
        sw_edit.edit
        for ex in zaebuc_edits
        for sw_edit in ex['subword-edits-append']
    }
    # dict.fromkeys de-duplicates while keeping first-occurrence order.
    qalb14_unique = dict.fromkeys(
        sw_edit.edit
        for ex in qalb14_edits
        for sw_edit in ex['subword-edits-append']
    )

    return [edit for edit in qalb14_unique if edit in zaebuc_seen]

In [28]:
def adjust_edits(qalb14_edits, edits_to_keep):
    """Rewrite examples so that only edits listed in `edits_to_keep` remain.

    For every example, each (subword, edit) pair from
    example['subword-edits-append'] is handled as follows:

    * edit in `edits_to_keep`: subword and edit are copied unchanged.
    * otherwise the edit is applied to the subword via `SubwordEdit.apply` and
      - if the edit starts with 'M' (presumably a merge-type edit -- confirm
        against SubwordEdit): the rewritten text is appended to the previous
        output subword and that subword's edit is extended with 'K's; when the
        rewritten text has several whitespace-separated parts, the parts after
        the first become new subwords labelled 'K'. An empty rewritten text
        drops the subword.
      - else: the rewritten subword is emitted with edit 'K'.

    Examples whose detokenized rewritten source already equals the target are
    skipped. For the rest, the adjusted edits are applied with `apply_edits`
    and must reproduce the target.

    Args:
        qalb14_edits: Iterable of dicts with keys 'tgt' and
            'subword-edits-append' (objects exposing `.subword` and `.edit`).
        edits_to_keep: Iterable of hashable edit labels to leave untouched.

    Returns:
        A list of dicts with keys 'src', 'tgt' and 'subword-edits-append'.

    Raises:
        ValueError: If an 'M' edit rewrites a subword to something starting
            with '##', or if applying the adjusted edits does not reproduce
            the target (previously this dropped into pdb).
        IndexError: If a non-kept 'M' edit yields text while no output
            subword exists yet to merge into.
    """
    # Set lookup instead of scanning a list once per subword.
    keep = frozenset(edits_to_keep)
    adjusted_dataset = []

    for ex_num, example in enumerate(qalb14_edits):
        tgt = example['tgt']
        subwords = [e.subword for e in example['subword-edits-append']]
        edits = [e.edit for e in example['subword-edits-append']]

        # Ensure subwords and edits align
        assert len(subwords) == len(edits)

        rewritten_subwords = []
        adjusted_edits = []

        for subword, edit in zip(subwords, edits):
            if edit in keep:
                # Retain the subword and edit as is
                rewritten_subwords.append(subword)
                adjusted_edits.append(edit)
                continue

            sw_edit = SubwordEdit(subword=subword, raw_subword=subword, edit=edit)
            rewritten_subword = sw_edit.apply(subword)

            if edit.startswith('M'):
                if rewritten_subword.startswith('##'):
                    raise ValueError("Unexpected modification: rewritten_subword starts with '##'")

                rewritten_parts = rewritten_subword.split()
                if len(rewritten_parts) > 1:
                    # Merge the first part with the last subword, and add remaining parts
                    rewritten_subwords[-1] += rewritten_parts[0]

                    if set(adjusted_edits[-1]) == {'K'}:
                        # All-keep edit: one 'K' per character of the merged subword.
                        adjusted_edits[-1] = 'K' * len(rewritten_subwords[-1].replace('##', ''))
                    else:
                        adjusted_edits[-1] += 'K' * len(rewritten_parts[0].replace('##', ''))

                    rewritten_subwords.extend(rewritten_parts[1:])
                    adjusted_edits.extend(['K'] * (len(rewritten_parts) - 1))

                elif rewritten_subword:
                    # Simple merge with the last subword
                    rewritten_subwords[-1] += rewritten_subword

                    if set(adjusted_edits[-1]) == {'K'}:
                        adjusted_edits[-1] = 'K' * len(rewritten_subwords[-1].replace('##', ''))
                    else:
                        if rewritten_subword.startswith(' ') or rewritten_subword.endswith(' '):
                            # Surrounding spaces were merged too, so they are counted as well.
                            adjusted_edits[-1] += 'K' * len(rewritten_subword.replace('##', ''))
                        else:
                            adjusted_edits[-1] += 'K' * len(rewritten_parts[0].replace('##', ''))

            else:
                rewritten_subwords.append(rewritten_subword)
                adjusted_edits.append('K')

        assert len(rewritten_subwords) == len(adjusted_edits)

        # Nothing left to correct once the non-kept edits have been applied.
        if detokenize_sent(rewritten_subwords) == tgt:
            continue

        check = [SubwordEdit(subword, subword, edit)
                 for subword, edit in zip(rewritten_subwords, adjusted_edits)]

        rewritten_src = apply_edits(rewritten_subwords, check)

        if ' '.join(rewritten_src) != tgt:
            # Fail loudly instead of entering an interactive debugger, which
            # blocks or aborts unattended runs.
            raise ValueError(
                f"Example {ex_num}: adjusted edits do not reproduce the target "
                f"(got {' '.join(rewritten_src)!r}, expected {tgt!r})"
            )

        # Append the adjusted example to the dataset. 'src' and the edit
        # objects are rebuilt after apply_edits, exactly as before, in case
        # apply_edits touches its arguments.
        adjusted_dataset.append({
            'src': detokenize_sent(rewritten_subwords),
            'tgt': tgt,
            'subword-edits-append': [SubwordEdit(subword, subword, edit)
                                     for subword, edit in zip(rewritten_subwords, adjusted_edits)]
        })

    return adjusted_dataset

In [29]:
# Load the subword-level training edits of both corpora (paths point at the
# 'edits_no_compressed' arabertv02 outputs for qalb14 and zaebuc).
qalb14 = load_data('/scratch/ba63/arabic-text-editing/edits/gec/qalb14/edits_no_compressed/qalb14-arabertv02/subword-level/train_edits.json',
                   edits_granularity='subword')
zaebuc = load_data('/scratch/ba63/arabic-text-editing/edits/gec/zaebuc/edits_no_compressed/zaebuc-arabertv02/subword-level/train_edits.json',
                   edits_granularity='subword')

In [30]:
# Edit labels that occur in both qalb14 and zaebuc.
edits_to_keep = edits_intersection(qalb14, zaebuc)

In [31]:
# Rewrite qalb14 so that only the shared edit labels remain.
qalb14_adjusted = adjust_edits(qalb14, edits_to_keep)

In [36]:
# Write the combined training data via write_tsv: the adjusted qalb14 examples
# followed by the zaebuc examples repeated 10 times (list repetition).
write_tsv(path='/scratch/ba63/arabic-text-editing/edits/gec/qalb14_adj+zaebuc_x10/edits_no_compressed/qalb14_adj+zaebuc_x10-arabertv02/subword-level/train',
          data=qalb14_adjusted + zaebuc * 10,
          edits_granularity='subword')

In [37]:
# Write the same combined data (adjusted qalb14 + zaebuc repeated 10 times)
# via write_json.
write_json(path='/scratch/ba63/arabic-text-editing/edits/gec/qalb14_adj+zaebuc_x10/edits_no_compressed/qalb14_adj+zaebuc_x10-arabertv02/subword-level/train_edits.json',
          data=qalb14_adjusted + zaebuc * 10,
          edits_granularity='subword')

In [214]:
# adjusted_edits = [edit for ex in qalb14_adjusted for edit in ex['edits']]

In [150]:
# Counter(adjusted_edits)