In [1]:
import pickle as pkl
from copy import deepcopy
import jsonlines
from utils.my_util import cluster_mentions, remove_speaker_prefix
import json

## Check whether data is correct： Correct

In [4]:
with open('data/raw_source/dialogue_proj_zh/all_coref_data_en_zh_finalized_word.json', 'r') as f:
    original_data = json.load(f)
with open('data/raw_source/dialogue_proj_zh/all_coref_data_en_zh_finalized_word_prob.json', 'r') as f:
    new_data = json.load(f)
print(len(original_data), len(new_data))

1240 1240


In [5]:
names = ['sentenceIndex', 'startToken', 'endToken', 'mention_id']
for i in range(len(original_data)):
    original_sample = original_data[i]
    new_sample = new_data[i]
    original_annotations = original_sample['annotations']
    new_annotations = new_sample['annotations']
    for a, b in zip(original_annotations, new_annotations):
        if sum([a['query'][name]==b['query'][name] for name in names])!=4:
            print(sum([a['query'][name]==b['query'][name] for name in names]))
        # print([a['query'][name]==b['query'][name] for name in names], sum([a['query'][name]==b['query'][name] for name in names]))

## Prepare Dialogue Data

In [6]:
speaker_dict = {}
with open('data/raw_source/dialogue_zh/all_coref_data_en.json', 'r') as f:
    temp = json.load(f)
    for line in temp:
        scene_id = line['scene_id'][:-1]
        speakers = []
        for sent in line['sentences']:
            speakers.append(" ".join(sent[:sent.index(":")]))
        speaker_dict[scene_id] = speakers

split_dict = {"train":[], "dev":[], "test":[]}
with open('data/raw_source/dialogue_zh/dev_finalized.pkl', 'rb') as f:
    temp = pkl.load(f)
    for line in temp:
        split_dict['dev'].append(line['scene_id'])
with open('data/raw_source/dialogue_zh/test_finalized.pkl', 'rb') as f:
    temp = pkl.load(f)
    for line in temp:
        split_dict['test'].append(line['scene_id'])
with open('data/raw_source/dialogue_zh/train_finalized.pkl', 'rb') as f:
    temp = pkl.load(f)
    for line in temp:
        split_dict['train'].append(line['scene_id'])

In [7]:
def remove_empty_sentences(instance):
    sentences = instance['sentences']
    answers = instance['answers']
    speakers = instance['speakers']

    # Build old sent_id to new sent_id map
    map_sent_id = {}
    count = 0
    for i, sent in enumerate(sentences):
        if sent == []:
            continue
        map_sent_id[i] = count
        count += 1

    # Collect answers, speakers for each sentence
    temp = []
    for i, sent in enumerate(sentences):
        if sent == []:
            continue
        annotations = []
        for answer in answers:
            if answer[0][0]==i:
                annotations.append(answer)
        temp.append([sent, annotations, speakers[i]])

    # Change Sentence ID
    sentences = []
    answers = []
    speakers = []
    for i, (sent, annotations, speaker) in enumerate(temp):
        # print(i, speaker, sent)
        sentences.append(sent)
        temp_answers = []
        for query, antecedents in annotations:
            new_query = tuple((map_sent_id[query[0]], query[1], query[2], query[3]))
            # print(query, new_query)
            new_antecedents = []
            if isinstance(antecedents, str):
                new_antecedents = antecedents
                # print(new_antecedents)
            else:
                # print(antecedents)
                for antecedent in antecedents:
                    new_antecedents.append((map_sent_id[antecedent[0]], antecedent[1], antecedent[2], antecedent[3]))
            # print(new_antecedents)
            temp_answers.append([new_query, new_antecedents])
        answers.extend(temp_answers)
        speakers.append(speaker)

    return {
        "sentences": sentences,
        "answers": answers,
        "speakers": speakers,
        "scene_id": instance['scene_id']
    }

In [9]:
"""
Cluster Chinese Mentions using English-Side Mention_IDs
"""
data = []
all_ids = []
with open('data/raw_source/dialogue_zh/all_coref_data_en_zh_finalized.json', 'r') as f:
# with open('data/raw_source/dialogue_zh/all_coref_data_en_zh_finalized_back.json', 'r') as f:
# with open('data/raw_source/dialogue_zh/dev-test-batch1_zh.json', 'r') as f:
    reader = jsonlines.Reader(f)
    for bulk in reader:
        for idx, instance in enumerate(bulk):
            # if idx>=10:
            #     break
            scene_id = instance['scene_id']
            if scene_id == "":
                continue
            sentences = instance['sentences']
            # print(sentences)
            # sentences = [[token for token in "".join(sent)] for sent in instance['sentences']]
            annotations = instance['annotations']
            all_ids.append(scene_id)
            speakers = speaker_dict[scene_id]
            answers = []
            for item in annotations:
                # Correct query
                query = (item['query']['sentenceIndex'], item['query']['startToken'], item['query']['endToken'], item['query']['mention_id'])
                antecedents = item['antecedents']
                if antecedents in [['n', 'o', 't', 'P', 'r', 'e', 's', 'e', 'n', 't'], ['null_projection'], ['empty_subtitle']]:
                    answers.append([query, "notPresent"])
                else:
                    temp_answer = []
                    for antecedent in antecedents:
                        if isinstance(antecedent, dict):
                            temp_answer.append((antecedent['sentenceIndex'], antecedent['startToken'], antecedent['endToken'], antecedent['mention_id']))
                        else:
                            temp_answer = " ".join(antecedents)
                    answers.append([query, temp_answer])

            # # Add correction results
            # if scene_id in correction_result:
            #     to_correct = correction_result[scene_id]
            #     correction_dict = to_correct['correction_dict']
            #     remove_set = to_correct['remove_set']
            #     add_dict = to_correct['add_dict']
            #     # Deal with to add_dict
            #     for item in add_dict:
            #         temp = [
            #             (add_dict[item][0], add_dict[item][1], add_dict[item][2], item),
            #             'notPresent'
            #         ]
            #         answers.append(temp)

            data.append(remove_empty_sentences({
                "sentences": sentences,
                "answers": answers,
                "speakers": speakers,
                "scene_id": scene_id
            }))

In [10]:
print(data[0])

{'sentences': [['将', '光子', '正对', '平面', '上', '的', '双缝', '观察', '任意', '一个', '隙缝', '它', '不会', '穿过', '那', '两个', '隙缝', '如果', '没', '被', '观察', '那', '就', '会', '总之', '如果', '观察', '它', '在', '离开', '平面', '到', '击中目标', '之前', '它', '就', '不会', '穿过', '那', '两个', '隙缝'], ['没错', '但', '你', '为什么', '要说', '这个', '?'], ['没什么', '我', '只是', '觉得', '这个', '主意', '可以', '用于', '设计', 'T恤衫'], ['横', '1', '是', 'Aegean', '竖', '8', '是', 'Nabokov', '横', '26', '是', 'MCM', '竖', '14', '是', '...', '手指', '挪开', '点', '这样一来', '横', '14', '就是', 'Port', '瞧', '提示', '是', '"', 'Papadoc', '的', '首都', '"', '所以', '是', '太子港', '海地', '的'], ['能', '为', '你', '效劳', '吗', '?'], ['这里', '是', '高智商', '精子', '银行', '吗', '?'], ['如果', '你', '这么', '问', '也许', '你', '不该', '来', '这'], ['我', '想', '就是', '这', '没错', '了'], ['把', '这个', '填一填'], ['谢谢', '我们', '马上', '好'], ['慢慢来', '我', '还要', '玩', '填字游戏', '噢', '慢', '着'], ['我', '办不到'], ['开玩笑', '?', '你', '可是', '半', '职业', '人士'], ['不', '我们', '这样', '是', '诈骗', '我们', '没法', '保证', '生', '出来', '的', '一定', '是', '高智商', '小孩', '我', '姐姐', '跟', '我', '有'

In [12]:
"""
Clustering Algorithm with Chinese Correction
"""
def cluster_mention_id_index(answers, sentences, en_clusters):
    """
    We cluster Chinese Side mentions according to the index in English Side
    """
    # Collect Mention_ID to Chinese Side tuple
    zh_mention_dict = {}
    for answer in answers:
        query = answer[0]
        zh_mention_dict[query[3]] = (query[0], query[1], query[2])
        antecedents = answer[1]
        if isinstance(antecedents, list):
            for antecedent in antecedents:
                zh_mention_dict[antecedent[3]] = (antecedent[0], antecedent[1], antecedent[2])

    # Incorporate Maunal Correction
    scene_id = list(zh_mention_dict.keys())[0]
    if scene_id[:10] in correction_result:
        to_correct = correction_result[scene_id[:10]]
        correction_dict = to_correct['correction_dict']
        remove_set = to_correct['remove_set']
        source_zh_mention_dict = deepcopy(zh_mention_dict)
        zh_mention_dict = {}

        # Perform Correction
        for mention_id in source_zh_mention_dict:
            # remove mentions
            if mention_id in remove_set:
                continue
            # Correct start, end
            elif mention_id in correction_dict:
                zh_mention_dict[mention_id] = tuple([source_zh_mention_dict[mention_id][0], correction_dict[mention_id][1], correction_dict[mention_id][2]])
            else:
                zh_mention_dict[mention_id] = source_zh_mention_dict[mention_id]

    # Gather Chinese Side cluster according to en_clusters
    new_cluster = []
    for cluster in en_clusters:
        temp = []
        for mention_id in cluster:
            if mention_id in zh_mention_dict:
                temp.append(zh_mention_dict[mention_id])
        if temp:
            new_cluster.append(temp)

    # Merge Cluster using (sent_id, start_id, end_id)
    all_clusters = deepcopy(new_cluster)
    merged_clusters = []
    for cluster in all_clusters:
        existing = None
        for mention in cluster:
            for merged_cluster in merged_clusters:
                if mention in merged_cluster:
                    existing = merged_cluster
                    break
            if existing is not None:
                break
        if existing is not None:
            existing.update(cluster)
        else:
            merged_clusters.append(set(cluster))
    merged_clusters = [list(cluster) for cluster in merged_clusters]
    return merged_clusters

In [None]:
split = "train"

with open('en_mention_id_cluster.pkl', 'rb') as f:
    en_mention_id_clusters = pkl.load(f)

document = []
for i in range(len(data)):
    # if i > 2:
    #     continue
    sample = data[i]
    if sample['scene_id'] not in split_dict[split]:
        continue

    # if sample['scene_id'] != "s05e14c00t0":
    #     continue

    # print(sample)
    # print()
    # original_sentences = sample['sentences']
    # original_clusters = cluster_mentions(sample['answers'], original_sentences)
    # sentences, clusters, speakers = remove_speaker_prefix(original_sentences, original_clusters)
    sentences = sample['sentences']
    # clusters = cluster_mentions(sample['answers'], sentences)
    speakers = sample['speakers']
    scene_id = sample['scene_id']
    clusters = cluster_mention_id_index(sample['answers'], sentences, en_mention_id_clusters[scene_id])
    part = int(scene_id[7:9])
    begin_line = "#begin document " + "(" + scene_id + "); part " + "%03d" % part
    end_line = "#end document"

    # Prepare for clustering
    cluster_field = []
    for sent in sentences:
        cluster_field.append([""]*len(sent))
    # Add start
    for idx, cluster in enumerate(clusters):
        for sent_id, start, end in cluster:
            end = end - 1
            if start != end:
                # print(cluster_field[sent_id])
                # print(sent_id, start, end, len(sentences[sent_id]))
                # print(sentences[sent_id])
                if cluster_field[sent_id][start] == "":
                    cluster_field[sent_id][start] += "(" + str(idx)
                else:
                    cluster_field[sent_id][start] += "|" + "(" + str(idx)
    # Add start==end
    for idx, cluster in enumerate(clusters):
        for sent_id, start, end in cluster:
            end = end - 1
            if start == end:
                if cluster_field[sent_id][start] == "":
                    cluster_field[sent_id][start] += "(" + str(idx) + ")"
                else:
                    cluster_field[sent_id][start] += "|" + "(" + str(idx) + ")"
    # Add End
    for idx, cluster in enumerate(clusters):
        for sent_id, start, end in cluster:
            end = end - 1
            if start != end:
                try:
                    if cluster_field[sent_id][end] == "":
                        cluster_field[sent_id][end] += str(idx) + ")"
                    else:
                        cluster_field[sent_id][end] += "|" + str(idx) + ")"
                except:
                    print("ERROR")
                    pass
                # if cluster_field[sent_id][end] == "":
                #     cluster_field[sent_id][end] += str(idx) + ")"
                # else:
                #     cluster_field[sent_id][end] += "|" + str(idx) + ")"

    # Build document
    document.append(begin_line + "\n")
    for sent, speaker, cluster_value in zip(sentences, speakers, cluster_field):
        for j, word in enumerate(sent):
            cluster_id = cluster_value[j]
            if cluster_id == "":
                cluster_id = "-"
            temp = [scene_id, str(part), str(j), word, "na", "na", "na", "na", "na", speaker, "na", "na", "na", cluster_id]
            document.append(" ".join(temp)+ "\n")
        document.append("" + "\n")
    document.append(end_line + "\n")

with open("data/conll_style/dialogue_prob_source_chinese/"+ split+'.chinese.v4_gold_conll', 'w') as f:
    f.writelines(document)

print(len(document))