In [1]:
import json

In [108]:
# Load the minimized test documents: one JSON object per line (JSON Lines).
# Each record is later indexed by its 'sentences' field.
with open("data/test.english.jsonlines", encoding="utf-8") as f:
    # Iterate the file lazily and skip blank lines (e.g. a stray empty line
    # at the end of the file), which json.loads would reject.
    minimized = [json.loads(line) for line in f if line.strip()]

In [109]:
# Load the model output: one JSON object per line (JSON Lines).
# The next cell unpacks this list into exactly two objects.
with open("output/preds.jsonl", encoding="utf-8") as f:
    # Iterate the file lazily and skip blank lines so that a stray empty
    # line neither crashes json.loads nor changes the number of records.
    preds = [json.loads(line) for line in f if line.strip()]

In [110]:
# The unpacking requires `preds` to contain exactly two entries:
# the per-document predictions and the per-document subtoken maps.
doc_to_prediction, doc_to_subtoken_map = preds

# Order the document keys numerically by the integer found between the
# first '/' and the first '_' of each key.
keys = sorted(
    doc_to_prediction,
    key=lambda raw_key: int(raw_key.partition('_')[0].split('/')[1]),
)

In [111]:
# Re-key both mappings as "faa_<n>_0", where <n> is the text between the
# first '/' and the first '_' of the original key. Iterating over the sorted
# `keys` keeps the new dictionaries in numeric document order.
predictions = {
    f"faa_{raw_key.split('_')[0].split('/')[1]}_0": doc_to_prediction[raw_key]
    for raw_key in keys
}
subtoken_map = {
    f"faa_{raw_key.split('_')[0].split('/')[1]}_0": doc_to_subtoken_map[raw_key]
    for raw_key in keys
}

In [112]:
# Using code taken from s2e-coref/conll.py

In [7]:
import collections, operator, re

In [8]:
# Matches a "#begin document (<doc_id>); part <digits>" header line.
# Group 1 is the text inside the parentheses, group 2 the part digits.
BEGIN_DOCUMENT_REGEX = re.compile(r"#begin document \((.*)\); part (\d+)")  # First line at each document
# Matches text containing a "Coreference: Recall: ... Precision: ... F1: ..." summary.
# Groups 1-3 capture the percentage values that follow Recall, Precision and F1;
# re.DOTALL lets the leading/trailing ".*" span multiple lines.
COREF_RESULTS_REGEX = re.compile(r".*Coreference: Recall: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\tPrecision: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\tF1: ([0-9.]+)%.*", re.DOTALL)

In [10]:
def get_doc_key(doc_id, part):
    """Return the document key "<doc_id>_<part>".

    `part` is converted with int() first, so zero-padded strings such as
    "001" are normalised to "1".
    """
    part_number = int(part)
    return f"{doc_id}_{part_number}"

def output_conll(input_file, output_file, predictions, subtoken_map):
    """Copy a CoNLL file, replacing the last column with predicted coreference.

    Args:
        input_file: readable text file object holding the CoNLL input.
        output_file: writable text file object receiving the result.
        predictions: mapping doc_key -> clusters, where each cluster is an
            iterable of (start, end) pairs of subtoken indices.
        subtoken_map: mapping doc_key -> sequence translating a subtoken
            index into a word index.

    Blank input lines are written as a single newline. Lines whose first
    field starts with '#' are written unchanged, followed by one additional
    newline; a "#begin document" header also selects that document's
    predictions and resets the word counter. Every other line has its last
    whitespace-separated field replaced by the coreference annotation
    ("-" when there is none) and is written with fields joined by three
    spaces.
    """
    second_item = operator.itemgetter(1)

    # Phase 1: per document, index cluster ids by the word where a mention
    # starts, ends, or (for one-word mentions) both.
    prediction_map = {}
    for doc_key, clusters in predictions.items():
        start_map = collections.defaultdict(list)
        end_map = collections.defaultdict(list)
        word_map = collections.defaultdict(list)
        for cluster_id, mentions in enumerate(clusters):
            for sub_start, sub_end in mentions:
                to_word = subtoken_map[doc_key]
                first_word, last_word = to_word[sub_start], to_word[sub_end]
                if first_word != last_word:
                    start_map[first_word].append((cluster_id, last_word))
                    end_map[last_word].append((cluster_id, first_word))
                else:
                    word_map[first_word].append(cluster_id)
        # Keep only the cluster ids, ordered by the other boundary (descending).
        for word, pairs in start_map.items():
            start_map[word] = [cid for cid, _ in sorted(pairs, key=second_item, reverse=True)]
        for word, pairs in end_map.items():
            end_map[word] = [cid for cid, _ in sorted(pairs, key=second_item, reverse=True)]
        prediction_map[doc_key] = (start_map, end_map, word_map)

    # Phase 2: stream the input and rewrite the coreference column.
    word_index = 0
    for line in input_file.readlines():
        row = line.split()

        if not row:
            output_file.write("\n")
            continue

        if row[0].startswith("#"):
            begin_match = BEGIN_DOCUMENT_REGEX.match(line)
            if begin_match:
                doc_key = get_doc_key(begin_match.group(1), begin_match.group(2))
                start_map, end_map, word_map = prediction_map[doc_key]
                word_index = 0
            output_file.write(line)
            output_file.write("\n")
            continue

        # Data row: its first two fields must identify the current document.
        assert get_doc_key(row[0], row[1]) == doc_key

        # Closing brackets first, then one-word mentions, then opening brackets.
        coref_list = (
            ["{})".format(cid) for cid in end_map.get(word_index, ())]
            + ["({})".format(cid) for cid in word_map.get(word_index, ())]
            + ["({}".format(cid) for cid in start_map.get(word_index, ())]
        )
        row[-1] = "|".join(coref_list) if coref_list else "-"

        output_file.write("   ".join(row))
        output_file.write("\n")
        word_index += 1

In [11]:
# Open both files in a `with` block so they are closed (and the written
# predictions flushed to disk) as soon as output_conll returns, even if it raises.
with open('data/test.english.v4_gold_conll') as input_file, \
        open('output/preds.conll', 'w') as output_file:
    output_conll(input_file, output_file, predictions, subtoken_map)

In [13]:
def _build_prediction_map(predictions, subtoken_map):
    """Return {doc_key: (start_map, end_map, word_map)} for every document.

    Each mention's (start, end) subtoken indices are translated to word
    indices through subtoken_map[doc_key]. One-word mentions add their
    cluster id to word_map; longer mentions add it to start_map under the
    first word and to end_map under the last word. The cluster ids stored
    under each word in start_map / end_map are ordered by the mention's
    other boundary, descending.
    """
    second_item = operator.itemgetter(1)
    result = {}
    for doc_key, clusters in predictions.items():
        start_map = collections.defaultdict(list)
        end_map = collections.defaultdict(list)
        word_map = collections.defaultdict(list)
        for cluster_id, mentions in enumerate(clusters):
            for sub_start, sub_end in mentions:
                to_word = subtoken_map[doc_key]
                first_word, last_word = to_word[sub_start], to_word[sub_end]
                if first_word != last_word:
                    start_map[first_word].append((cluster_id, last_word))
                    end_map[last_word].append((cluster_id, first_word))
                else:
                    word_map[first_word].append(cluster_id)
        for word, pairs in start_map.items():
            start_map[word] = [cid for cid, _ in sorted(pairs, key=second_item, reverse=True)]
        for word, pairs in end_map.items():
            end_map[word] = [cid for cid, _ in sorted(pairs, key=second_item, reverse=True)]
        result[doc_key] = (start_map, end_map, word_map)
    return result


prediction_map = _build_prediction_map(predictions, subtoken_map)

In [14]:
import pandas as pd

# Source table; its 'c5' and 'c119' columns are read when the results are assembled.
faa_csv_path = '../../data/FAA_data/Maintenance_Text_data_nona.csv'
faa_df = pd.read_csv(faa_csv_path)

In [103]:
# Collect one output row per FAA record: its id, its text, the three
# word-level lookup maps, and the text of every predicted mention grouped
# by cluster id.
outdict = {'id':[], 'sample':[], 'start_map':[], 'end_map':[], 'word_map':[], 'corefs':[]}

for idoc in range(len(faa_df)):
    # Sanity check: the idoc-th prediction key must carry the same
    # identifier as the idoc-th row of the table.
    if keys[idoc].split('/')[1].split('_')[1] != faa_df['c5'].iat[idoc]:
        print("we have a problem")

    doc_key = f'faa_{idoc}_0'
    doc_start_map, doc_end_map, doc_word_map = prediction_map[doc_key]

    outdict['id'].append(faa_df['c5'].iat[idoc])
    outdict['sample'].append(faa_df['c119'].iat[idoc])
    outdict['start_map'].append(dict(doc_start_map))
    outdict['end_map'].append(dict(doc_end_map))
    outdict['word_map'].append(dict(doc_word_map))

    # Flatten the document's sentences into a single list of words
    # (one pass instead of repeated list concatenation).
    part = [word for sentence in minimized[idoc]['sentences'] for word in sentence]

    # Build the mention texts straight from the predicted clusters.
    # start_map / end_map only keep cluster ids per word, in independent
    # dictionary orders, so the i-th start and the i-th end gathered from
    # them for a cluster need not belong to the same mention (e.g. when two
    # mentions share a start or end word). Reading each (start, end) pair
    # from the cluster itself keeps the boundaries of a mention together.
    to_word = subtoken_map[doc_key]
    corefs = {}
    for cluster_id, mentions in enumerate(predictions[doc_key]):
        for sub_start, sub_end in mentions:
            start, end = to_word[sub_start], to_word[sub_end]
            corefs.setdefault(cluster_id, []).append(' '.join(part[start:end+1]))

    outdict['corefs'].append(corefs)

In [106]:
# One row per document, one column per key of `outdict`.
results_df = pd.DataFrame(outdict)
results_df.to_csv('../../data/results/s2e-coref/s2e-coref.csv')