In [1]:
"""
In this script, we want to convert annotations to conll style. There are several steps:
1.clustering input data
2.Turn cluster data into conll format
3.Try to run the preprocess script from HOI code base
"""

'\nIn this script, we want to convert annotations to conll style. There are several steps:\n1.clustering input data\n2.Turn cluster data into conll format\n3.Try to run the preprocess script from HOI code base\n'

In [5]:
import pickle as pkl
from copy import deepcopy
import jsonlines
from utils.my_util import cluster_mentions, remove_speaker_prefix
import json
from collections import defaultdict

## Prepare Dialogue Data

In [6]:
speaker_dict = {}
with open('data/raw_source/dialogue_zh/all_coref_data_en.json', 'r') as f:
    temp = json.load(f)
    for line in temp:
        scene_id = line['scene_id'][:-1]
        speakers = []
        for sent in line['sentences']:
            speakers.append(" ".join(sent[:sent.index(":")]))
        speaker_dict[scene_id] = speakers

split_dict = {"train":[], "dev":[], "test":[]}
with open('data/raw_source/dialogue_zh/dev_finalized.pkl', 'rb') as f:
    temp = pkl.load(f)
    for line in temp:
        split_dict['dev'].append(line['scene_id'])
with open('data/raw_source/dialogue_zh/test_finalized.pkl', 'rb') as f:
    temp = pkl.load(f)
    for line in temp:
        split_dict['test'].append(line['scene_id'])
with open('data/raw_source/dialogue_zh/train_finalized.pkl', 'rb') as f:
    temp = pkl.load(f)
    for line in temp:
        split_dict['train'].append(line['scene_id'])

In [7]:
def remove_empty_sentences(instance):
    sentences = instance['sentences']
    answers = instance['answers']
    speakers = instance['speakers']

    # Build old sent_id to new sent_id map
    map_sent_id = {}
    count = 0
    for i, sent in enumerate(sentences):
        if sent == []:
            continue
        map_sent_id[i] = count
        count += 1

    # Collect answers, speakers for each sentence
    temp = []
    for i, sent in enumerate(sentences):
        if sent == []:
            continue
        annotations = []
        for answer in answers:
            if answer[0][0]==i:
                annotations.append(answer)
        temp.append([sent, annotations, speakers[i]])

    # Change Sentence ID
    sentences = []
    answers = []
    speakers = []
    for i, (sent, annotations, speaker) in enumerate(temp):
        # print(i, speaker, sent)
        sentences.append(sent)
        temp_answers = []
        for query, antecedents in annotations:
            new_query = tuple((map_sent_id[query[0]], query[1], query[2]))
            # print(query, new_query)
            new_antecedents = []
            if isinstance(antecedents, str):
                new_antecedents = antecedents
                # print(new_antecedents)
            else:
                # print(antecedents)
                for antecedent in antecedents:
                    new_antecedents.append((map_sent_id[antecedent[0]], antecedent[1], antecedent[2]))
            # print(new_antecedents)
            temp_answers.append([new_query, new_antecedents])
        answers.extend(temp_answers)
        speakers.append(speaker)

    return {
        "sentences": sentences,
        "answers": answers,
        "speakers": speakers,
        "scene_id": instance['scene_id']
    }

In [8]:
def cluster_en_mentions_with_id(answers, sentences):
    """
    Cluster mention including plural. The clustering steps are as follows:
    1. Gather all non-plural query-annotation pairs
    2. Merge all non-plural pairs to build big cluster
    3. Add Plural: turn each mention in plural into {1.speaker name, 2.cluster id, 3.turn sent_id, start, end into special identification}, then use these to build a string for clustering
    4. Add each pair to cluster and do merging
    5. Remove strings from each cluster
    ?One Thing to Consider: Whether use speaker mention in sentence to merge to speaker cluster?
    """
    all_clusters = []
    speaker_set = set()
    # Generate all cluster (no plural), we will add plural latter
    for query, annotations in answers:
        if isinstance(annotations, str):
            if annotations=="notPresent":
                all_clusters.append([query])
        elif len(annotations)==1:
            temp = [query]
            for token in annotations:
                if token[1]==0:
                    try:
                        speaker = " ".join(sentences[token[0]][token[1]: sentences[token[0]].index(":")]).lower()
                        temp.append(speaker)
                        speaker_set.add(speaker)
                    except:
                        continue
                else:
                    temp.append(token)
            all_clusters.append(temp)

    # Merge clusters if any clusters have common mentions
    merged_clusters = []
    for cluster in all_clusters:
        existing = None
        for mention in cluster:
            for merged_cluster in merged_clusters:
                if mention in merged_cluster:
                    existing = merged_cluster
                    break
            if existing is not None:
                break
        if existing is not None:
            existing.update(cluster)
        else:
            merged_clusters.append(set(cluster))
    merged_clusters = [list(cluster) for cluster in merged_clusters]

    # Add Plural
    for query, annotations in answers:
        if isinstance(annotations, list):
            if len(annotations)>1:
                temp_anno = []
                for token in annotations:
                    if token[1]==0:
                        speaker = " ".join(sentences[token[0]][token[1]: sentences[token[0]].index(":")]).lower()
                        temp_anno.append(speaker)
                        speaker_set.add(speaker)
                    # if token[1]==0:
                    #     try:
                    #         speaker = " ".join(sentences[token[0]][token[1]: sentences[token[0]].index(":")]).lower()
                    #         temp_anno.append(speaker)
                    #         speaker_set.add(speaker)
                    #     except:
                    #         continue
                    else:
                        # If the cluster is already in cluster, use the cluster id as identification, else use the index
                        cluster_idx = -1
                        for idx, cluster in enumerate(merged_clusters):
                            if token in cluster:
                                cluster_idx = idx
                                break
                        if cluster_idx != -1:
                            temp_anno.append(str(cluster_idx))
                        else:
                            temp_anno.append("*" + "*".join([str(num) for num in token]) + "*")
                temp_cluster = [query, "||".join(sorted(temp_anno))]
                merged_clusters.append(temp_cluster)

    # Merge Plural
    all_clusters = deepcopy(merged_clusters)
    merged_clusters = []
    for cluster in all_clusters:
        existing = None
        for mention in cluster:
            for merged_cluster in merged_clusters:
                if mention in merged_cluster:
                    existing = merged_cluster
                    break
            if existing is not None:
                break
        if existing is not None:
            existing.update(cluster)
        else:
            merged_clusters.append(set(cluster))
    merged_clusters = [list(cluster) for cluster in merged_clusters]

    temp_output = []
    for cluster in merged_clusters:
        temp_output.append([token for token in cluster if isinstance(token, tuple)])
    return temp_output


In [9]:
"""
Build Mention ID Clusters in English Side
"""

en_data = []
en_all_ids = []
en_mention_id_clusters = {}
with open('data/raw_source/dialogue_en/all_coref_data_en_finalized.json', 'r') as f:
# with open('data/raw_source/dialogue_zh/all_coref_data_en_zh_seg.json', 'r') as f:
# with open('data/raw_source/dialogue_zh/dev-test-batch1_zh.json', 'r') as f:
    reader = jsonlines.Reader(f)
    for bulk in reader:
        for idx, instance in enumerate(bulk):
            # if idx>=5:
            #     break
            scene_id = instance['scene_id']
            if scene_id == "":
                continue
            sentences = instance['sentences']
            # print(sentences)
            # sentences = [[token for token in "".join(sent)] for sent in instance['sentences']]
            annotations = instance['annotations']

            # Build tuple_id to mention_id dictionary
            tuple_mention_dict = {}
            for item in annotations:
                query = tuple((item['query']['sentenceIndex'], item['query']['startToken'], item['query']['endToken']))
                antecedents = item['antecedents']
                for antecedent in antecedents:
                    if isinstance(antecedent, dict):
                        temp_antecedent = tuple((antecedent['sentenceIndex'], antecedent['startToken'], antecedent['endToken']))
                        tuple_mention_dict[temp_antecedent] = antecedent['mention_id']
                tuple_mention_dict[query] = item['query']['mention_id']
            # print(tuple_mention_dict)
            # print()

            en_all_ids.append(scene_id)
            speakers = speaker_dict[scene_id]
            answers = []
            for item in annotations:
                query = (item['query']['sentenceIndex'], item['query']['startToken'], item['query']['endToken'])
                antecedents = item['antecedents']
                # print(query)
                # print(antecedents)
                # print()
                if antecedents in [['n', 'o', 't', 'P', 'r', 'e', 's', 'e', 'n', 't'], ['null_projection'], ['empty_subtitle']]:
                    answers.append([query, "notPresent"])
                else:
                    temp_answer = []
                    for antecedent in antecedents:
                        if isinstance(antecedent, dict):
                            temp_answer.append((antecedent['sentenceIndex'], antecedent['startToken'], antecedent['endToken']))
                        else:
                            temp_answer = " ".join(antecedents)
                    answers.append([query, temp_answer])
            mention_id_cluster = []
            for cluster in cluster_en_mentions_with_id(answers, sentences):
                temp = []
                for mention in cluster:
                    temp.append(tuple_mention_dict[mention])
                mention_id_cluster.append(temp)
            en_mention_id_clusters[scene_id] = mention_id_cluster

            # print(mention_id_cluster)
            # print("=="*50)

            en_data.append({
                "sentences": sentences,
                "answers": answers,
                "speakers": speakers,
                "scene_id": scene_id,
                "mention_id_cluster": mention_id_cluster
            })

with open('en_mention_id_cluster.pkl', 'wb') as f:
    pkl.dump(en_mention_id_clusters, f)

In [60]:
"""
Load English Mention ID Cluster
"""
with open('en_mention_id_cluster.pkl', 'rb') as f:
    en_mention_id_clusters = pkl.load(f)

In [61]:
"""
Get Projected Data in Chinese Side
"""

zh_data = []
zh_all_ids = []
with open('data/raw_source/dialogue_zh/all_coref_data_en_zh_with_id.json', 'r') as f:
# with open('data/raw_source/dialogue_zh/all_coref_data_en_zh_seg.json', 'r') as f:
# with open('data/raw_source/dialogue_zh/dev-test-batch1_zh.json', 'r') as f:
    reader = jsonlines.Reader(f)
    for bulk in reader:
        for idx, instance in enumerate(bulk):
            if idx>=5:
                break
            scene_id = instance['scene_id']
            if scene_id == "":
                continue
            sentences = instance['sentences']
            # print(sentences)
            # sentences = [[token for token in "".join(sent)] for sent in instance['sentences']]
            annotations = instance['annotations']

            zh_all_ids.append(scene_id)
            speakers = speaker_dict[scene_id]
            answers = []
            for item in annotations:
                query = (item['query']['sentenceIndex'], item['query']['startToken'], item['query']['endToken'])
                antecedents = item['antecedents']
                if antecedents in [['n', 'o', 't', 'P', 'r', 'e', 's', 'e', 'n', 't'], ['null_projection'], ['empty_subtitle']]:
                    answers.append([query, "notPresent"])
                else:
                    temp_answer = []
                    for antecedent in antecedents:
                        if isinstance(antecedent, dict):
                            temp_answer.append((antecedent['sentenceIndex'], antecedent['startToken'], antecedent['endToken']))
                        else:
                            temp_answer = " ".join(antecedents)
                    answers.append([query, temp_answer])
            # mention_id_cluster = []
            # for cluster in cluster_en_mentions_with_id(answers, sentences):
            #     temp = []
            #     for mention in cluster:
            #         temp.append(tuple_mention_dict[mention])
            #     mention_id_cluster.append(temp)
            # en_mention_id_clusters[scene_id] = mention_id_cluster


            zh_data.append({
                "sentences": sentences,
                "answers": answers,
                "speakers": speakers,
                "scene_id": scene_id,
                "mention_id_cluster": mention_id_cluster
            })

In [62]:
print(zh_data)

[{'sentences': [['将', '光', '子', '正', '对', '平', '面', '上', '的', '双', '缝', '观', '察', '任', '意', '一', '个', '隙', '缝', '它', '不', '会', '穿', '过', '那', '两', '个', '隙', '缝', '如', '果', '没', '被', '观', '察', '那', '就', '会', '总', '之', '如', '果', '观', '察', '它', '在', '离', '开', '平', '面', '到', '击', '中', '目', '标', '之', '前', '它', '就', '不', '会', '穿', '过', '那', '两', '个', '隙', '缝'], ['没', '错', '但', '你', '为', '什', '么', '要', '说', '这', '个', '?'], ['没', '什', '么', '我', '只', '是', '觉', '得', '这', '个', '主', '意', '可', '以', '用', '于', '设', '计', 'T', '恤', '衫'], [], [], ['横', '1', '是', 'A', 'e', 'g', 'e', 'a', 'n', '竖', '8', '是', 'N', 'a', 'b', 'o', 'k', 'o', 'v', '横', '2', '6', '是', 'M', 'C', 'M', '竖', '1', '4', '是', '.', '.', '.', '手', '指', '挪', '开', '点', '这', '样', '一', '来', '横', '1', '4', '就', '是', 'P', 'o', 'r', 't', '瞧', '提', '示', '是', '"', 'P', 'a', 'p', 'a', 'd', 'o', 'c', '的', '首', '都', '"', '所', '以', '是', '太', '子', '港', '海', '地', '的'], ['能', '为', '你', '效', '劳', '吗', '?'], ['这', '里', '是', '高', '智', '商', '精', '子', '银', 

In [11]:
data = []
all_ids = []
with open('data/raw_source/dialogue_zh/all_coref_data_en_zh_seg.json', 'r') as f:
# with open('data/raw_source/dialogue_zh/dev-test-batch1_zh.json', 'r') as f:
    reader = jsonlines.Reader(f)
    for bulk in reader:
        for idx, instance in enumerate(bulk):
            if idx>=5:
                break
            scene_id = instance['scene_id']
            if scene_id == "":
                continue
            sentences = instance['sentences']
            # print(sentences)
            # sentences = [[token for token in "".join(sent)] for sent in instance['sentences']]
            annotations = instance['annotations']
            all_ids.append(scene_id)
            speakers = speaker_dict[scene_id]
            answers = []
            for item in annotations:
                query = (item['query']['sentenceIndex'], item['query']['startToken'], item['query']['endToken'])
                antecedents = item['antecedents']
                # print(query)
                # print(antecedents)
                # print()
                if antecedents in [['n', 'o', 't', 'P', 'r', 'e', 's', 'e', 'n', 't'], ['null_projection'], ['empty_subtitle']]:
                    answers.append([query, "notPresent"])
                else:
                    temp_answer = []
                    for antecedent in antecedents:
                        if isinstance(antecedent, dict):
                            temp_answer.append((antecedent['sentenceIndex'], antecedent['startToken'], antecedent['endToken']))
                        else:
                            temp_answer = " ".join(antecedents)
                    answers.append([query, temp_answer])

            clusters = cluster_en_mentions_with_id(answers, sentences)
            print(clusters)

            data.append(remove_empty_sentences({
                "sentences": sentences,
                "answers": answers,
                "speakers": speakers,
                "scene_id": scene_id
            }))

[[(0, 21, 22), (0, 8, 9), (0, 1, 2), (0, 27, 28), (0, 11, 12)], [(0, 3, 4), (0, 30, 31)], [(0, 39, 41), (0, 6, 11), (0, 15, 16)], [(0, 32, 33)], [(0, 34, 35)], [(1, 3, 4)], [(1, 2, 6)], [(2, 0, 1), (1, 4, 6)], [(0, 1, 41), (2, 4, 5), (2, 5, 6)], [(2, 9, 10)], [(5, 3, 4)], [(5, 7, 8)], [(5, 11, 12)], [(5, 18, 19)], [(5, 23, 24), (5, 25, 31), (5, 34, 36), (5, 32, 33)], [(5, 34, 35)], [(7, 0, 1)], [(7, 2, 5), (9, 3, 4)], [(10, 1, 2)], [(12, 0, 1)], [(12, 4, 5)], [(14, 4, 7)], [(16, 0, 1), (15, 4, 5)], [(15, 15, 16)], [(15, 13, 15)], [(15, 16, 21)], [(15, 20, 25)], [(15, 25, 26), (15, 16, 17)], [(15, 28, 29)], [(16, 2, 5)], [(16, 7, 11)], [(16, 12, 16)], [(16, 14, 15)], [(17, 6, 7)], [(17, 8, 12)], [(17, 13, 14), (17, 9, 12), (18, 2, 3)], [(17, 13, 15)], [(17, 17, 20)], [(17, 36, 37), (18, 5, 6), (17, 37, 38)], [(17, 30, 31)], [(17, 28, 29)], [(17, 25, 26)], [(22, 1, 4)], [(23, 9, 13)], [(17, 19, 20), (23, 10, 11)], [(0, 3, 11)]]
[[(0, 4, 6)], [(2, 3, 8)], [(2, 5, 6)], [(4, 6, 7), (4, 6, 9

In [13]:
correction_dict = {(5, 7, 8): (5, 7, 9), (7, 13, 17): (7, 15, 17), (11, 0, 1): (11, 0, 2), (11, 5, 7): (11, 4, 7), (16, 0, 3): (16, 0, 1)}
for item in data:
    if item['scene_id'] in ['s07e17c08t1']:
        print(item['sentences'])
        print()
        for token in item['answers']:
            print(token)
        print()
        print(item['speakers'])
        print('=='*50)

[['是', 'S', 'h', 'e', 'l', 'd', 'o', 'n', '的', '信', '息', '他', '说', '他', '一', '会', '儿', '就', '到'], ['你', '在', '干', '吗', '?'], ['他', '要', '确', '认', '你', '车', '里', '没', '有', '空', '气', '清', '新', '剂', '才', '会', '出', '现'], ['你', '这', '周', '日', '子', '不', '好', '过', '了'], ['那', '还', '不', '是', '因', '为', '你'], ['我', '让', '你', '们', '不', '要', '互', '相', '揶', '揄', '又', '没', '让', '你', '带', '他', '去', '渡', '蜜', '月'], ['我', '都', '不', '清', '楚', '你', '说', '了', '什', '么', '该', '死', '你', '的', '大', '波', '和', '消', '音', '器', '一', '样'], ['回', '得', '州', '很', '开', '心', '吧', '?'], ['和', '一', '名', '真', '正', '的', '宇', '航', '员', '一', '起', '参', '观', '宇', '航', '局', '可', '不', '是', '天', '天', '都', '有', '的', '好', '事'], ['和', '哪', '个', '宇', '航', '员', '?'], ['奥', '尔', '德', '林']]

[(0, 1, 8), 'notPresent']
[(0, 11, 12), [(0, 1, 8)]]
[(0, 13, 14), [(0, 11, 12)]]
[(0, 14, 17), 'notPresent']
[(1, 2, 4), 'notPresent']
[(1, 0, 1), 'Howard']
[(2, 0, 1), [(0, 1, 8)]]
[(2, 2, 14), 'notPresent']
[(2, 9, 14), 'notPresent']
[(2, 4, 7), 'n

In [34]:
split = "test"

document = []
for i in range(len(data)):
    sample = data[i]
    if sample['scene_id'] not in split_dict[split]:
        continue
    # print(sample)
    # print()
    # original_sentences = sample['sentences']
    # original_clusters = cluster_mentions(sample['answers'], original_sentences)
    # sentences, clusters, speakers = remove_speaker_prefix(original_sentences, original_clusters)
    sentences = sample['sentences']
    clusters = cluster_mentions(sample['answers'], sentences)
    speakers = sample['speakers']
    scene_id = sample['scene_id']
    part = int(scene_id[7:9])
    begin_line = "#begin document " + "(" + scene_id + "); part " + "%03d" % part
    end_line = "#end document"

    # Prepare for clustering
    cluster_field = []
    for sent in sentences:
        cluster_field.append([""]*len(sent))
    # Add start
    for idx, cluster in enumerate(clusters):
        for sent_id, start, end in cluster:
            end = end - 1
            if start != end:
                # print(cluster_field[sent_id])
                # print(sent_id, start, end, len(sentences[sent_id]))
                # print(sentences[sent_id])
                if cluster_field[sent_id][start] == "":
                    cluster_field[sent_id][start] += "(" + str(idx)
                else:
                    cluster_field[sent_id][start] += "|" + "(" + str(idx)
    # Add start==end
    for idx, cluster in enumerate(clusters):
        for sent_id, start, end in cluster:
            end = end - 1
            if start == end:
                if cluster_field[sent_id][start] == "":
                    cluster_field[sent_id][start] += "(" + str(idx) + ")"
                else:
                    cluster_field[sent_id][start] += "|" + "(" + str(idx) + ")"
    # Add End
    for idx, cluster in enumerate(clusters):
        for sent_id, start, end in cluster:
            end = end - 1
            if start != end:
                try:
                    if cluster_field[sent_id][end] == "":
                        cluster_field[sent_id][end] += str(idx) + ")"
                    else:
                        cluster_field[sent_id][end] += "|" + str(idx) + ")"
                except:
                    pass
                # if cluster_field[sent_id][end] == "":
                #     cluster_field[sent_id][end] += str(idx) + ")"
                # else:
                #     cluster_field[sent_id][end] += "|" + str(idx) + ")"

    # Build document
    document.append(begin_line + "\n")
    for sent, speaker, cluster_value in zip(sentences, speakers, cluster_field):
        for j, word in enumerate(sent):
            cluster_id = cluster_value[j]
            if cluster_id == "":
                cluster_id = "-"
            temp = [scene_id, str(part), str(j), word, "na", "na", "na", "na", "na", speaker, "na", "na", "na", cluster_id]
            document.append(" ".join(temp)+ "\n")
        document.append("" + "\n")
    document.append(end_line + "\n")

with open("data/conll_style/dialogue_chinese/"+ split+'.chinese.v4_gold_conll', 'w') as f:
    f.writelines(document)

## Build Chinese Sample

In [44]:
print(len(data))

1515


In [50]:
split = "train"

document = []
for i in range(len(data[:1])):
    sample = data[i]
    if sample['scene_id'] not in split_dict[split]:
        continue
    # print(sample)
    # print()
    # original_sentences = sample['sentences']
    # original_clusters = cluster_mentions(sample['answers'], original_sentences)
    # sentences, clusters, speakers = remove_speaker_prefix(original_sentences, original_clusters)
    sentences = sample['sentences']
    clusters = cluster_mentions(sample['answers'], sentences)
    speakers = sample['speakers']
    scene_id = sample['scene_id']
    part = int(scene_id[7:9])
    begin_line = "#begin document " + "(" + scene_id + "); part " + "%03d" % part
    end_line = "#end document"

    # Prepare for clustering
    cluster_field = []
    for sent in sentences:
        cluster_field.append([""]*len(sent))
    # Add start
    for idx, cluster in enumerate(clusters):
        for sent_id, start, end in cluster:
            end = end - 1
            if start != end:
                # print(cluster_field[sent_id])
                # print(sent_id, start, end, len(sentences[sent_id]))
                # print(sentences[sent_id])
                if cluster_field[sent_id][start] == "":
                    cluster_field[sent_id][start] += "(" + str(idx)
                else:
                    cluster_field[sent_id][start] += "|" + "(" + str(idx)
    # Add start==end
    for idx, cluster in enumerate(clusters):
        for sent_id, start, end in cluster:
            end = end - 1
            if start == end:
                if cluster_field[sent_id][start] == "":
                    cluster_field[sent_id][start] += "(" + str(idx) + ")"
                else:
                    cluster_field[sent_id][start] += "|" + "(" + str(idx) + ")"
    # Add End
    for idx, cluster in enumerate(clusters):
        for sent_id, start, end in cluster:
            end = end - 1
            if start != end:
                try:
                    if cluster_field[sent_id][end] == "":
                        cluster_field[sent_id][end] += str(idx) + ")"
                    else:
                        cluster_field[sent_id][end] += "|" + str(idx) + ")"
                except:
                    pass
                # if cluster_field[sent_id][end] == "":
                #     cluster_field[sent_id][end] += str(idx) + ")"
                # else:
                #     cluster_field[sent_id][end] += "|" + str(idx) + ")"

    # Build document
    document.append(begin_line + "\n")
    for sent, speaker, cluster_value in zip(sentences, speakers, cluster_field):
        for j, word in enumerate(sent):
            cluster_id = cluster_value[j]
            if cluster_id == "":
                cluster_id = "-"
            temp = [scene_id, str(part), str(j), word, "na", "na", "na", "na", "na", speaker, "na", "na", "na", cluster_id]
            document.append(" ".join(temp)+ "\n")
        document.append("" + "\n")
    document.append(end_line + "\n")

with open("data/conll_style/overfit_chinese/"+ "train"+'.chinese.v4_gold_conll', 'w') as f:
    f.writelines(document)
with open("data/conll_style/overfit_chinese/"+ "dev"+'.chinese.v4_gold_conll', 'w') as f:
    f.writelines(document)
with open("data/conll_style/overfit_chinese/"+ "test"+'.chinese.v4_gold_conll', 'w') as f:
    f.writelines(document)

In [45]:
print(len(document))

44425


In [97]:
file_name = "test"
data = []
with open('data/'+ file_name+'_temp.pkl', 'rb') as f:
    data.extend(pkl.load(f))

document = []
for i in range(len(data)):
    if file_name=="train" and i==38:
        continue
    if file_name=="test" and i==28:
        continue

    # if i>=100:
    #     continue
    sample = data[i]
    original_sentences = sample['sentences']
    original_clusters = cluster_mentions(sample['answers'], original_sentences)

    # Get Data ready for conversion
    sentences, clusters, speakers = remove_speaker_prefix(original_sentences, original_clusters)
    scene_id = sample['scene_id']
    part = int(scene_id[7:9])
    begin_line = "#begin document " + "(" + scene_id + "); part " + "%03d" % part
    end_line = "#end document"

    # Prepare for clustering
    cluster_field = []
    for sent in sentences:
        cluster_field.append([""]*len(sent))
    # Add start
    for idx, cluster in enumerate(clusters):
        for sent_id, start, end in cluster:
            end = end - 1
            if start != end:
                if cluster_field[sent_id][start] == "":
                    cluster_field[sent_id][start] += "(" + str(idx)
                else:
                    cluster_field[sent_id][start] += "|" + "(" + str(idx)
    # Add start==end
    for idx, cluster in enumerate(clusters):
        for sent_id, start, end in cluster:
            end = end - 1
            if start == end:
                if cluster_field[sent_id][start] == "":
                    cluster_field[sent_id][start] += "(" + str(idx) + ")"
                else:
                    cluster_field[sent_id][start] += "|" + "(" + str(idx) + ")"
    # Add End
    for idx, cluster in enumerate(clusters):
        for sent_id, start, end in cluster:
            end = end - 1
            if start != end:
                try:
                    if cluster_field[sent_id][end] == "":
                        cluster_field[sent_id][end] += str(idx) + ")"
                    else:
                        cluster_field[sent_id][end] += "|" + str(idx) + ")"
                except:
                    pass
                # if cluster_field[sent_id][end] == "":
                #     cluster_field[sent_id][end] += str(idx) + ")"
                # else:
                #     cluster_field[sent_id][end] += "|" + str(idx) + ")"

    # Build document
    document.append(begin_line + "\n")
    for sent, speaker, cluster_value in zip(sentences, speakers, cluster_field):
        for j, word in enumerate(sent):
            cluster_id = cluster_value[j]
            if cluster_id == "":
                cluster_id = "-"
            temp = [scene_id, str(part), str(j), word, "na", "na", "na", "na", "na", speaker, "na", "na", "na", cluster_id]
            document.append(" ".join(temp)+ "\n")
        document.append("" + "\n")
    document.append(end_line + "\n")

with open("data/input/"+ file_name+'.english.v4_gold_conll', 'w') as f:
    f.writelines(document)