In [13]:
import csv
import json
from copy import deepcopy
csv.field_size_limit(1131072)

1131072

In [14]:
# Collect All Correction Annotations
coe_annotations = []
for split_name in ["batch_0", "batch_1", "batch_2","batch_3","batch_4","batch_5","batch_6","batch_7"]:
    with open('data/zh/annotations/'+split_name+'.csv', 'r', encoding='utf-8') as f:
        reader = csv.DictReader(f)
        for item in reader:
            coe_annotations.append(item)

In [15]:
correction_inputs_ids = []
with open('data/zh/source/to_tasa_all_zh_trial.csv', 'r', encoding='utf-8') as f:
    reader = csv.DictReader(f)
    for item in reader:
        correction_inputs_ids.append(item)

In [16]:
# Collect Source Data in English
all_en_data = {}
with open('data/zh/source/all_coref_data_en_finalized.json', 'r') as f:
    temp = json.load(f)
    for item in temp:
        all_en_data[item['scene_id']] = item

en_mention_dict = {}
for item in all_en_data:
    annotations = all_en_data[item]['annotations']
    for anno in annotations:
        query = anno['query']
        antecedents = anno['antecedents']
        en_mention_dict[query['mention_id']] = [query['sentenceIndex'], query['startToken'],query['endToken']]
        if antecedents!=[]:
            if isinstance(antecedents[0], dict):
                for ante in antecedents:
                    en_mention_dict[ante['mention_id']] = [ante['sentenceIndex'], ante['startToken'],ante['endToken']]

In [17]:
# Collect Source Data in Chinese
all_zh_data = {}
with open('data/zh/source/all_coref_data_en_zh_finalized.json', 'r') as f:
    temp = json.load(f)
    for item in temp:
        all_zh_data[item['scene_id']] = item

zh_mention_dict = {}
for item in all_zh_data:
    annotations = all_zh_data[item]['annotations']
    for anno in annotations:
        query = anno['query']
        antecedents = anno['antecedents']
        zh_mention_dict[query['mention_id']] = [query['sentenceIndex'], query['startToken'],query['endToken']]
        if antecedents!=[]:
            if isinstance(antecedents[0], dict):
                for ante in antecedents:
                    zh_mention_dict[ante['mention_id']] = [ante['sentenceIndex'], ante['startToken'],ante['endToken']]

In [18]:
# Collect Source Data in Chinese
all_zh_data_char = {}
with open('data/zh/source/all_coref_data_en_zh_finalized_char.json', 'r') as f:
    temp = json.load(f)
    for item in temp:
        all_zh_data_char[item['scene_id']] = item

In [19]:
def remove_drop_words(alignment):
    result = set()
    for start, end in alignment:
        result.add((start, end//2))
    return result

def extract_pair(result):
    starts = set()
    ends = set()
    for start, end in result:
        starts.add(start)
        ends.add(end)
    start_idx = min(list(starts))
    end_idx = max(list(ends))
    return [min(list(ends)), max(list(ends))+1]

def extract_alignment_pair_original(alignment):
    result = set()
    for s in range(len(alignment)):
        for t in range(len(alignment[0])):
            if alignment[s][t]:
                result.add((s, t))
    return sorted(list(result))

def extract_alignment_pair_space(alignment):
    result = set()
    for s in range(len(alignment)):
        for t in range(len(alignment[0])):
            if alignment[s][t]:
                result.add((s, t))
    return sorted(list(remove_drop_words(result)))

In [21]:
correction_dict = {}
remove_set = set()
add_dict = {}

num_corrected = 0

for i in range(len(coe_annotations)):
    # ZH Side Numbers
    zh_original_result = extract_alignment_pair_space(json.loads(coe_annotations[i]['Input.config_obj'])['alignment'])
    zh_corrected_result = extract_alignment_pair_space(json.loads(coe_annotations[i]['Answer.alignment']))
    if zh_original_result==zh_corrected_result:
        continue
    mention_id = correction_inputs_ids[i]['mention_id']
    num_corrected += 1
    # Build char_to_word idx map
    sent_id = en_mention_dict[mention_id][0]
    sent_char = all_zh_data_char[mention_id[:10]]['sentences'][sent_id]
    sent_word = all_zh_data[mention_id[:10]]['sentences'][sent_id]
    map_char_word = {}
    count = 0
    for idx_w, word in enumerate(sent_word):
        for char in word:
            map_char_word[count] = idx_w
            count += 1
    # Extract Mention start_end and convert from char level to word level
    a = ()
    b = ()
    if zh_original_result:
        a_temp = extract_pair(zh_original_result)
        if a_temp:
            a = (map_char_word[a_temp[0]], map_char_word[a_temp[1] - 1] + 1)
    if zh_corrected_result:
        b_temp = extract_pair(zh_corrected_result)
        b = (map_char_word[b_temp[0]], map_char_word[b_temp[1] - 1] + 1)

    if b==(): # Add mention_id to remove set
        remove_set.add(mention_id)
    elif a==(): # Collect mentions to add
        add_dict[mention_id] = [sent_id, b[0], b[1]]
    else:
        correction_dict[mention_id] = [sent_id, b[0], b[1]]

In [22]:
print(num_corrected, len(coe_annotations))

1967 7671


In [9]:
import pickle as pkl

scene_ids = set()
for item in correction_dict:
    scene_ids.add(item[:10])
for item in remove_set:
    scene_ids.add(item[:10])
for item in add_dict:
    scene_ids.add(item[:10])

all_results = {}
for item in scene_ids:
    all_results[item] = {
        "correction_dict": {},
        "remove_set": set(),
        "add_dict": {}
    }

for item in remove_set:
    all_results[item[:10]]['remove_set'].add(item)
for item in correction_dict:
    all_results[item[:10]]['correction_dict'][item] = correction_dict[item]
for item in add_dict:
    all_results[item[:10]]['add_dict'][item] = add_dict[item]

with open('data/zh/all_correction_results.pkl', 'wb') as f:
    pkl.dump(all_results, f)

In [10]:
print(all_results)

{'s09e16c01t': {'correction_dict': {'s09e16c01t|13': [0, 1, 3], 's09e16c01t|21': [2, 11, 13], 's09e16c01t|30': [2, 3, 5], 's09e16c01t|39': [3, 4, 7], 's09e16c01t|66': [9, 6, 7], 's09e16c01t|80': [16, 1, 2], 's09e16c01t|82': [16, 6, 7], 's09e16c01t|85': [16, 12, 13], 's09e16c01t|100': [16, 11, 13], 's09e16c01t|103': [16, 22, 25], 's09e16c01t|110': [17, 2, 4]}, 'remove_set': set(), 'add_dict': {'s09e16c01t|4': [0, 1, 4], 's09e16c01t|5': [0, 2, 3], 's09e16c01t|9': [0, 1, 2]}}, 's09e07c03t': {'correction_dict': {'s09e07c03t|0': [0, 0, 3], 's09e07c03t|3': [1, 0, 24], 's09e07c03t|6': [1, 1, 3], 's09e07c03t|12': [1, 16, 18], 's09e07c03t|28': [5, 4, 5]}, 'remove_set': set(), 'add_dict': {'s09e07c03t|4': [1, 1, 2], 's09e07c03t|26': [4, 5, 6], 's09e07c03t|42': [9, 0, 1], 's09e07c03t|55': [11, 3, 4], 's09e07c03t|69': [12, 0, 1]}}, 's09e10c04t': {'correction_dict': {'s09e10c04t|15': [3, 1, 2], 's09e10c04t|16': [3, 1, 3], 's09e10c04t|22': [3, 16, 17], 's09e10c04t|28': [4, 7, 17], 's09e10c04t|37': [

In [12]:
print(39.23-37.18)
print(39.54-37.01)
print(50.58-47.20)

2.049999999999997
2.530000000000001
3.3799999999999955
