In [1]:
"""
In this script, we want to convert annotations to conll style. There are several steps:
1.Clustering input data
2.Turn cluster data into conll format
3.Try to run the preprocess script from HOI code base
"""

'\nIn this script, we want to convert annotations to conll style. There are several steps:\n1.clustering input data\n2.Turn cluster data into conll format\n3.Try to run the preprocess script from HOI code base\n'

In [1]:
import pickle as pkl
from copy import deepcopy
import jsonlines
from utils.my_util import cluster_mentions, remove_speaker_prefix
import json

## Prepare Dialogue Data

In [2]:
{
    "0": {
        "train": [0, 1, 2],
        "val": [3],
        "test": [4]
    },
    "1": {
        "train": [0, 1, 2],
        "val": [3],
        "test": [4]
    },
}

{'0': {'train': [0, 1, 2], 'val': [3], 'test': [4]},
 '1': {'train': [0, 1, 2], 'val': [3], 'test': [4]}}

In [None]:
python train_data.py --set=0

In [6]:
speaker_dict = {}
with open('data/raw_source/dialogue_zh/all_coref_data_en.json', 'r') as f:
    temp = json.load(f)
    for line in temp:
        scene_id = line['scene_id']
        speakers = []
        for sent in line['sentences']:
            speakers.append(" ".join(sent[:sent.index(":")]))
        speaker_dict[scene_id] = speakers

split_dict = {"train":[], "dev":[], "test":[]}
with open('data/raw_source/dialogue_zh/dev_temp.pkl', 'rb') as f:
    temp = pkl.load(f)
    for line in temp:
        split_dict['dev'].append(line['scene_id']+"0")
with open('data/raw_source/dialogue_zh/test_temp.pkl', 'rb') as f:
    temp = pkl.load(f)
    for line in temp:
        split_dict['test'].append(line['scene_id']+"0")
with open('data/raw_source/dialogue_zh/train_temp.pkl', 'rb') as f:
    temp = pkl.load(f)
    for line in temp:
        split_dict['train'].append(line['scene_id'])

In [5]:
def remove_empty_sentences(instance):
    sentences = instance['sentences']
    answers = instance['answers']
    speakers = instance['speakers']

    # Build old sent_id to new sent_id map
    map_sent_id = {}
    count = 0
    for i, sent in enumerate(sentences):
        if sent == []:
            continue
        map_sent_id[i] = count
        count += 1

    # Collect answers, speakers for each sentence
    temp = []
    for i, sent in enumerate(sentences):
        if sent == []:
            continue
        annotations = []
        for answer in answers:
            if answer[0][0]==i:
                annotations.append(answer)
        temp.append([sent, annotations, speakers[i]])

    # Change Sentence ID
    sentences = []
    answers = []
    speakers = []
    for i, (sent, annotations, speaker) in enumerate(temp):
        # print(i, speaker, sent)
        sentences.append(sent)
        temp_answers = []
        for query, antecedents in annotations:
            new_query = tuple((map_sent_id[query[0]], query[1], query[2], query[3]))
            # print(query, new_query)
            new_antecedents = []
            if isinstance(antecedents, str):
                new_antecedents = antecedents
                # print(new_antecedents)
            else:
                # print(antecedents)
                for antecedent in antecedents:
                    new_antecedents.append((map_sent_id[antecedent[0]], antecedent[1], antecedent[2], antecedent[3]))
            # print(new_antecedents)
            temp_answers.append([new_query, new_antecedents])
        answers.extend(temp_answers)
        speakers.append(speaker)

    return {
        "sentences": sentences,
        "answers": answers,
        "speakers": speakers,
        "scene_id": instance['scene_id']
    }

In [None]:
def cluster_mentions(answers, sentences):
    """
    Cluster mention including plural. The clustering steps are as follows:
    1. Gather all non-plural query-annotation pairs
    2. Merge all non-plural pairs to build big cluster
    3. Add Plural: turn each mention in plural into {1.speaker name, 2.cluster id, 3.turn sent_id, start, end into special identification}, then use these to build a string for clustering
    4. Add each pair to cluster and do merging
    5. Remove strings from each cluster
    ?One Thing to Consider: Whether use speaker mention in sentence to merge to speaker cluster?
    """
    all_clusters = []
    speaker_set = set()
    # Generate all cluster (no plural), we will add plural latter
    for query, annotations in answers:
        if isinstance(annotations, str):
            if annotations=="notPresent":
                all_clusters.append([query])
        elif len(annotations)==1:
            temp = [query]
            for token in annotations:
                if token[1]==0:
                    try:
                        speaker = " ".join(sentences[token[0]][token[1]: sentences[token[0]].index(":")]).lower()
                        temp.append(speaker)
                        speaker_set.add(speaker)
                    except:
                        continue
                else:
                    temp.append(token)
            all_clusters.append(temp)

    # Merge clusters if any clusters have common mentions
    merged_clusters = []
    for cluster in all_clusters:
        existing = None
        for mention in cluster:
            for merged_cluster in merged_clusters:
                if mention in merged_cluster:
                    existing = merged_cluster
                    break
            if existing is not None:
                break
        if existing is not None:
            existing.update(cluster)
        else:
            merged_clusters.append(set(cluster))
    merged_clusters = [list(cluster) for cluster in merged_clusters]

    # Add Plural
    for query, annotations in answers:
        if isinstance(annotations, list):
            if len(annotations)>1:
                temp_anno = []
                for token in annotations:
                    if token[1]==0:
                        speaker = " ".join(sentences[token[0]][token[1]: sentences[token[0]].index(":")]).lower()
                        temp_anno.append(speaker)
                        speaker_set.add(speaker)
                    # if token[1]==0:
                    #     try:
                    #         speaker = " ".join(sentences[token[0]][token[1]: sentences[token[0]].index(":")]).lower()
                    #         temp_anno.append(speaker)
                    #         speaker_set.add(speaker)
                    #     except:
                    #         continue
                    else:
                        # If the cluster is already in cluster, use the cluster id as identification, else use the index
                        cluster_idx = -1
                        for idx, cluster in enumerate(merged_clusters):
                            if token in cluster:
                                cluster_idx = idx
                                break
                        if cluster_idx != -1:
                            temp_anno.append(str(cluster_idx))
                        else:
                            temp_anno.append("*" + "*".join([str(num) for num in token]) + "*")
                temp_cluster = [query, "||".join(sorted(temp_anno))]
                merged_clusters.append(temp_cluster)

    # Merge Plural
    all_clusters = deepcopy(merged_clusters)
    merged_clusters = []
    for cluster in all_clusters:
        existing = None
        for mention in cluster:
            for merged_cluster in merged_clusters:
                if mention in merged_cluster:
                    existing = merged_cluster
                    break
            if existing is not None:
                break
        if existing is not None:
            existing.update(cluster)
        else:
            merged_clusters.append(set(cluster))
    merged_clusters = [list(cluster) for cluster in merged_clusters]

    output = []
    for cluster in merged_clusters:
        output.append([token for token in cluster if isinstance(token, tuple)])

    return output


In [41]:
data = []
all_ids = []
with open('data/raw_source/dialogue_zh/all_coref_data_en.json', 'r') as f:
# with open('data/raw_source/dialogue_zh/dev-test-batch1_zh.json', 'r') as f:
    reader = jsonlines.Reader(f)
    for bulk in reader:
        for idx, instance in enumerate(bulk):
            # if idx>=5:
            #     break
            scene_id = instance['scene_id']
            if scene_id == "":
                continue
            sentences = instance['sentences']
            # print(sentences)
            # sentences = [[token for token in "".join(sent)] for sent in instance['sentences']]
            annotations = instance['annotations']
            all_ids.append(scene_id)
            speakers = speaker_dict[scene_id]
            answers = []
            for item in annotations:
                query = (item['query']['sentenceIndex'], item['query']['startToken'], item['query']['endToken'])
                antecedents = item['antecedents']
                # print(query)
                # print(antecedents)
                # print()
                if antecedents in [['n', 'o', 't', 'P', 'r', 'e', 's', 'e', 'n', 't'], ['null_projection'], ['empty_subtitle']]:
                    answers.append([query, "notPresent"])
                else:
                    temp_answer = []
                    for antecedent in antecedents:
                        if isinstance(antecedent, dict):
                            temp_answer.append((antecedent['sentenceIndex'], antecedent['startToken'], antecedent['endToken']))
                        else:
                            temp_answer = " ".join(antecedents)
                    answers.append([query, temp_answer])
            data.append(remove_empty_sentences({
                "sentences": sentences,
                "answers": answers,
                "speakers": speakers,
                "scene_id": scene_id
            }))

IndexError: tuple index out of range

In [42]:
with open('en_mention_id_cluster.pkl', 'rb') as f:
    en_mention_id_clusters = pkl.load(f)

In [43]:
"""
Cluster Chinese Mentions using English-Side Mention_IDs
"""
data = []
all_ids = []
with open('data/raw_source/dialogue_zh/all_coref_data_en_zh_with_id.json', 'r') as f:
# with open('data/raw_source/dialogue_zh/dev-test-batch1_zh.json', 'r') as f:
    reader = jsonlines.Reader(f)
    for bulk in reader:
        for idx, instance in enumerate(bulk):
            if idx>=10:
                break
            scene_id = instance['scene_id']
            if scene_id == "":
                continue
            sentences = instance['sentences']
            # print(sentences)
            # sentences = [[token for token in "".join(sent)] for sent in instance['sentences']]
            annotations = instance['annotations']
            all_ids.append(scene_id)
            speakers = speaker_dict[scene_id]
            answers = []
            for item in annotations:
                query = (item['query']['sentenceIndex'], item['query']['startToken'], item['query']['endToken'], item['query']['mention_id'])
                antecedents = item['antecedents']
                if antecedents in [['n', 'o', 't', 'P', 'r', 'e', 's', 'e', 'n', 't'], ['null_projection'], ['empty_subtitle']]:
                    answers.append([query, "notPresent"])
                else:
                    temp_answer = []
                    for antecedent in antecedents:
                        if isinstance(antecedent, dict):
                            temp_answer.append((antecedent['sentenceIndex'], antecedent['startToken'], antecedent['endToken'], antecedent['mention_id']))
                        else:
                            temp_answer = " ".join(antecedents)
                    answers.append([query, temp_answer])
            data.append(remove_empty_sentences({
                "sentences": sentences,
                "answers": answers,
                "speakers": speakers,
                "scene_id": scene_id
            }))

In [None]:
correction_dict = {(5, 7, 8): (5, 7, 9), (7, 13, 17): (7, 15, 17), (11, 0, 1): (11, 0, 2), (11, 5, 7): (11, 4, 7), (16, 0, 3): (16, 0, 1)}
for item in data:
    if item['scene_id'] in ['s07e17c08t1']:
        print(item['sentences'])
        print()
        for token in item['answers']:
            print(token)
        print()
        print(item['speakers'])
        print('=='*50)

In [44]:
def cluster_mention_by_en_id(answers, sentences):
    """
    We cluster Chinese Side mentions according to the index in English Side
    """
    print(answers)
    print()
    print(sentences)
    print("=="*50)
    pass

In [45]:
split = "test"

document = []
for i in range(len(data)):
    if i >= 2:
        continue
    sample = data[i]
    if sample['scene_id'] not in split_dict[split]:
        continue
    # print(sample)
    # print()
    # original_sentences = sample['sentences']
    # original_clusters = cluster_mentions(sample['answers'], original_sentences)
    # sentences, clusters, speakers = remove_speaker_prefix(original_sentences, original_clusters)
    sentences = sample['sentences']
    cluster_mention_by_en_id(sample['answers'], sentences)
    clusters = cluster_mentions(sample['answers'], sentences)
    speakers = sample['speakers']
    scene_id = sample['scene_id']
    part = int(scene_id[7:9])
    begin_line = "#begin document " + "(" + scene_id + "); part " + "%03d" % part
    end_line = "#end document"

    # Prepare for clustering
    cluster_field = []
    for sent in sentences:
        cluster_field.append([""]*len(sent))
    # Add start
    for idx, cluster in enumerate(clusters):
        for sent_id, start, end in cluster:
            end = end - 1
            if start != end:
                # print(cluster_field[sent_id])
                # print(sent_id, start, end, len(sentences[sent_id]))
                # print(sentences[sent_id])
                if cluster_field[sent_id][start] == "":
                    cluster_field[sent_id][start] += "(" + str(idx)
                else:
                    cluster_field[sent_id][start] += "|" + "(" + str(idx)
    # Add start==end
    for idx, cluster in enumerate(clusters):
        for sent_id, start, end in cluster:
            end = end - 1
            if start == end:
                if cluster_field[sent_id][start] == "":
                    cluster_field[sent_id][start] += "(" + str(idx) + ")"
                else:
                    cluster_field[sent_id][start] += "|" + "(" + str(idx) + ")"
    # Add End
    for idx, cluster in enumerate(clusters):
        for sent_id, start, end in cluster:
            end = end - 1
            if start != end:
                try:
                    if cluster_field[sent_id][end] == "":
                        cluster_field[sent_id][end] += str(idx) + ")"
                    else:
                        cluster_field[sent_id][end] += "|" + str(idx) + ")"
                except:
                    pass
                # if cluster_field[sent_id][end] == "":
                #     cluster_field[sent_id][end] += str(idx) + ")"
                # else:
                #     cluster_field[sent_id][end] += "|" + str(idx) + ")"

    # Build document
    document.append(begin_line + "\n")
    for sent, speaker, cluster_value in zip(sentences, speakers, cluster_field):
        for j, word in enumerate(sent):
            cluster_id = cluster_value[j]
            if cluster_id == "":
                cluster_id = "-"
            temp = [scene_id, str(part), str(j), word, "na", "na", "na", "na", "na", speaker, "na", "na", "na", cluster_id]
            document.append(" ".join(temp)+ "\n")
        document.append("" + "\n")
    document.append(end_line + "\n")

# with open("data/conll_style/dialogue_chinese/"+ split+'.chinese.v4_gold_conll', 'w') as f:
#     f.writelines(document)

## Build Chinese Sample

In [44]:
print(len(data))

1515


In [50]:
split = "train"

document = []
for i in range(len(data[:1])):
    sample = data[i]
    if sample['scene_id'] not in split_dict[split]:
        continue
    # print(sample)
    # print()
    # original_sentences = sample['sentences']
    # original_clusters = cluster_mentions(sample['answers'], original_sentences)
    # sentences, clusters, speakers = remove_speaker_prefix(original_sentences, original_clusters)
    sentences = sample['sentences']
    clusters = cluster_mentions(sample['answers'], sentences)
    speakers = sample['speakers']
    scene_id = sample['scene_id']
    part = int(scene_id[7:9])
    begin_line = "#begin document " + "(" + scene_id + "); part " + "%03d" % part
    end_line = "#end document"

    # Prepare for clustering
    cluster_field = []
    for sent in sentences:
        cluster_field.append([""]*len(sent))
    # Add start
    for idx, cluster in enumerate(clusters):
        for sent_id, start, end in cluster:
            end = end - 1
            if start != end:
                # print(cluster_field[sent_id])
                # print(sent_id, start, end, len(sentences[sent_id]))
                # print(sentences[sent_id])
                if cluster_field[sent_id][start] == "":
                    cluster_field[sent_id][start] += "(" + str(idx)
                else:
                    cluster_field[sent_id][start] += "|" + "(" + str(idx)
    # Add start==end
    for idx, cluster in enumerate(clusters):
        for sent_id, start, end in cluster:
            end = end - 1
            if start == end:
                if cluster_field[sent_id][start] == "":
                    cluster_field[sent_id][start] += "(" + str(idx) + ")"
                else:
                    cluster_field[sent_id][start] += "|" + "(" + str(idx) + ")"
    # Add End
    for idx, cluster in enumerate(clusters):
        for sent_id, start, end in cluster:
            end = end - 1
            if start != end:
                try:
                    if cluster_field[sent_id][end] == "":
                        cluster_field[sent_id][end] += str(idx) + ")"
                    else:
                        cluster_field[sent_id][end] += "|" + str(idx) + ")"
                except:
                    pass
                # if cluster_field[sent_id][end] == "":
                #     cluster_field[sent_id][end] += str(idx) + ")"
                # else:
                #     cluster_field[sent_id][end] += "|" + str(idx) + ")"

    # Build document
    document.append(begin_line + "\n")
    for sent, speaker, cluster_value in zip(sentences, speakers, cluster_field):
        for j, word in enumerate(sent):
            cluster_id = cluster_value[j]
            if cluster_id == "":
                cluster_id = "-"
            temp = [scene_id, str(part), str(j), word, "na", "na", "na", "na", "na", speaker, "na", "na", "na", cluster_id]
            document.append(" ".join(temp)+ "\n")
        document.append("" + "\n")
    document.append(end_line + "\n")

with open("data/conll_style/overfit_chinese/"+ "train"+'.chinese.v4_gold_conll', 'w') as f:
    f.writelines(document)
with open("data/conll_style/overfit_chinese/"+ "dev"+'.chinese.v4_gold_conll', 'w') as f:
    f.writelines(document)
with open("data/conll_style/overfit_chinese/"+ "test"+'.chinese.v4_gold_conll', 'w') as f:
    f.writelines(document)

In [45]:
print(len(document))

44425


In [97]:
file_name = "test"
data = []
with open('data/'+ file_name+'_temp.pkl', 'rb') as f:
    data.extend(pkl.load(f))

document = []
for i in range(len(data)):
    if file_name=="train" and i==38:
        continue
    if file_name=="test" and i==28:
        continue

    # if i>=100:
    #     continue
    sample = data[i]
    original_sentences = sample['sentences']
    original_clusters = cluster_mentions(sample['answers'], original_sentences)

    # Get Data ready for conversion
    sentences, clusters, speakers = remove_speaker_prefix(original_sentences, original_clusters)
    scene_id = sample['scene_id']
    part = int(scene_id[7:9])
    begin_line = "#begin document " + "(" + scene_id + "); part " + "%03d" % part
    end_line = "#end document"

    # Prepare for clustering
    cluster_field = []
    for sent in sentences:
        cluster_field.append([""]*len(sent))
    # Add start
    for idx, cluster in enumerate(clusters):
        for sent_id, start, end in cluster:
            end = end - 1
            if start != end:
                if cluster_field[sent_id][start] == "":
                    cluster_field[sent_id][start] += "(" + str(idx)
                else:
                    cluster_field[sent_id][start] += "|" + "(" + str(idx)
    # Add start==end
    for idx, cluster in enumerate(clusters):
        for sent_id, start, end in cluster:
            end = end - 1
            if start == end:
                if cluster_field[sent_id][start] == "":
                    cluster_field[sent_id][start] += "(" + str(idx) + ")"
                else:
                    cluster_field[sent_id][start] += "|" + "(" + str(idx) + ")"
    # Add End
    for idx, cluster in enumerate(clusters):
        for sent_id, start, end in cluster:
            end = end - 1
            if start != end:
                try:
                    if cluster_field[sent_id][end] == "":
                        cluster_field[sent_id][end] += str(idx) + ")"
                    else:
                        cluster_field[sent_id][end] += "|" + str(idx) + ")"
                except:
                    pass
                # if cluster_field[sent_id][end] == "":
                #     cluster_field[sent_id][end] += str(idx) + ")"
                # else:
                #     cluster_field[sent_id][end] += "|" + str(idx) + ")"

    # Build document
    document.append(begin_line + "\n")
    for sent, speaker, cluster_value in zip(sentences, speakers, cluster_field):
        for j, word in enumerate(sent):
            cluster_id = cluster_value[j]
            if cluster_id == "":
                cluster_id = "-"
            temp = [scene_id, str(part), str(j), word, "na", "na", "na", "na", "na", speaker, "na", "na", "na", cluster_id]
            document.append(" ".join(temp)+ "\n")
        document.append("" + "\n")
    document.append(end_line + "\n")

with open("data/input/"+ file_name+'.english.v4_gold_conll', 'w') as f:
    f.writelines(document)