In [23]:

import itertools
import random
from functools import partial

import datasets
import huggingface_hub
from tqdm import tqdm

In [24]:
config_names = [
    'conll2012_indiscrim_english_v4',
    'gum_indiscrim_ontogum',
    'arrau_indiscrim_default',
    'gap_indiscrim_default',
    'davis_pdp_indiscrim_default',
    'preco_indiscrim_default',
    'litbank_indiscrim_split_0',
    'gum_indiscrim_original',
    'phrase_detectives_indiscrim_default',
    'mmc_indiscrim_mmc_en',
    'davis_wsc_indiscrim_wsc273',
    'superglue_wsc_indiscrim_default',
    'dpr_indiscrim_default',
    'knowref_60k_indiscrim_default',
    'pronominal_winogrande_default'
]

# config_names = [
#     'davis_wsc_indiscrim_wsc273',
# ]

In [25]:
def get_token_mapping(sentences, context_start, context_end):
    local_to_global = {}
    global_to_local = {}
    t = 0
    for s_i in range(context_start, context_end):
        sentence = sentences[s_i]
        for i in range(len(sentence["tokens"])):
            local_to_global[(s_i, i)] = t
            global_to_local[t] = (s_i, i)
            t += 1
    return local_to_global, global_to_local


def local_mention_to_global(local_to_global, mention):
    sent, start, end = mention
    return (
                local_to_global[(sent, start)],
                local_to_global[(sent, end)]
            )


def global_mention_to_local(global_to_local, mention):
    start, end = mention
    start_sent, start_tok = global_to_local[start]
    end_sent, end_tok = global_to_local[end]
    assert start_sent == end_sent and end_tok >= start_tok
    return [start_sent, start_tok, end_tok]

In [26]:
def get_examples(config_name, split, dataset, use_local_context, include_speaker):
    examples = []
    for ex in tqdm(dataset):
        sentences = ex["sentences"]

        context_start = 0
        context_end = len(sentences)

        ex_id = ex["id"]
        psent, pstart, pend = ex["pronoun"]
        ex_id = str(ex["id"]) + f"_{psent}_{pstart}_{pend}"
        
        if use_local_context:
            context_start = ex["local_context_start"]
            context_end = ex["local_context_end"]

        local_to_global, global_to_local = get_token_mapping(sentences, context_start, context_end)
        words = [[x["text"] for x in s["tokens"]] for s in sentences[context_start:context_end]]

        speakers = None
        if include_speaker:
            speakers = [[s["speaker"] if s["speaker"] is not None else ""]*len(s["tokens"])
                        for s in sentences[context_start:context_end]]
            speakers = [spk for s in speakers for spk in s]

        lm_to_global = partial(local_mention_to_global, local_to_global)
        mentions = [lm_to_global(ex["pronoun"]),
                    lm_to_global(ex["antecedents"][0]),
                    lm_to_global(ex["distractors"][0])] # (start, end)
        
        # make sure each
        instructions = "Annotate all entity mentions in the following text with coreference clusters. " \
                "Use Markdown tags to indicate clusters in the output, " \
                "with the following format [mention](#cluster_name)\n\n"
        
        passage_words = [w for s in words for w in s]

        global_antecedent = mentions[1]
        expected_output_words = passage_words[global_antecedent[0] : global_antecedent[1] + 1]
        expected_output = " ".join(expected_output_words).lower()

        global_distractor = mentions[2]
        global_distractor_words = passage_words[global_distractor[0] : global_distractor[1] + 1]
        negative_output = " ".join(global_distractor_words).lower()

        global_pronoun = mentions[0]
        assert global_pronoun[0] == global_pronoun[1], "Pronoun should be exactly one word"
        original_pronoun = passage_words[global_pronoun[0]]
        passage_words[global_pronoun[0]] = f"[{original_pronoun}](#)" # add astericks around pronoun

        cluster_ids = [0, 1]
        random.shuffle(cluster_ids)
        gold_id = cluster_ids[0]
        distractor_id = cluster_ids[1]

        # add square brackets to words in passage
        passage_words[global_antecedent[0]] = f"[{passage_words[global_antecedent[0]]}"
        passage_words[global_antecedent[1]] = f"{passage_words[global_antecedent[1]]}](cluster_{gold_id})"

        passage_words[global_distractor[0]] = f"[{passage_words[global_distractor[0]]}"
        passage_words[global_distractor[1]] = f"{passage_words[global_distractor[1]]}](cluster_{distractor_id})"

        def words_to_passage(passage_words):
            if include_speaker:
                last_speaker = None
                passage = ""
                for i, w in enumerate(passage_words):
                    curr_speaker = speakers[i] if speakers[i] else "Anonymous"
                    if curr_speaker != last_speaker:
                        passage += f"\n\n{curr_speaker}:\n"
                        last_speaker = curr_speaker
                    passage += (" " if passage else "") + w
            else:
                passage = " ".join(passage_words)
            return passage
        
        passage = words_to_passage(passage_words)

        question = f"In the above passage, what \"*{original_pronoun}*\" refer to?"

        input_str = instructions + "Input: " + passage + "\n" + \
            "Output: "
        
        passage_words[global_pronoun[0]] = f"[{original_pronoun}](#cluster_{gold_id})"
        expected_output = words_to_passage(passage_words)

        passage_words[global_pronoun[0]] = f"[{original_pronoun}](#cluster_{distractor_id})"
        negative_output = words_to_passage(passage_words)
        
        # (dataset, split, example_id, local_context, include_speaker, input, expected_output)
        output_example = {
            "dataset": config_name,
            "split": split,
            "example_id": ex_id,
            "local_context": use_local_context,
            "include_speaker": include_speaker,
            "input": input_str,
            "expected_output": expected_output,
            "negative_output": negative_output,
            "passage_words": passage_words,
            "mentions": mentions,
        }
        examples.append(output_example)
    return examples

In [27]:
"""
Convert all examples to a GPT-3 style input string.

(dataset, split, example_id, local_context, include_speaker, input, expected_output)

Scored using uncased exact match.

Prompt:
```
Please carefully read the following passages. For each passage, you must identify
which noun the mention marked in *bold* refers to.

Passage: [Tom] and [Mary] go to [the park]. *It* was full of trees.
Question: In the above passage, what does *It* refer to?
Answer: *It* refers to [the park]
```
"""

def get_all_examples(config_name, split, dataset):
    examples = []
    for use_local_context, include_speaker in itertools.product([True, False], [True, False]):
        examples += get_examples(config_name, split, dataset,
                        use_local_context=use_local_context, include_speaker=include_speaker)
    return examples


def main():
    examples = []
    for config_name in config_names:
        dataset_name = "coref-data/pcr_single_antecedent"
        dataset = datasets.load_dataset(dataset_name, config_name)
        for split in ["validation", "test"]:
            if split not in dataset:
                continue
            examples += get_all_examples(config_name, split, dataset[split])
    return examples

data = main()

100%|██████████| 1536/1536 [00:06<00:00, 243.92it/s]
100%|██████████| 1536/1536 [00:06<00:00, 245.20it/s]
100%|██████████| 1536/1536 [00:07<00:00, 203.72it/s]
100%|██████████| 1536/1536 [00:06<00:00, 231.62it/s]
100%|██████████| 1642/1642 [00:07<00:00, 234.04it/s]
100%|██████████| 1642/1642 [00:06<00:00, 252.89it/s]
100%|██████████| 1642/1642 [00:07<00:00, 233.41it/s]
100%|██████████| 1642/1642 [00:07<00:00, 231.65it/s]
100%|██████████| 272/272 [00:02<00:00, 114.35it/s]
100%|██████████| 272/272 [00:02<00:00, 130.73it/s]
100%|██████████| 272/272 [00:02<00:00, 118.46it/s]
100%|██████████| 272/272 [00:02<00:00, 122.57it/s]
100%|██████████| 236/236 [00:02<00:00, 112.57it/s]
100%|██████████| 236/236 [00:01<00:00, 130.33it/s]
100%|██████████| 236/236 [00:01<00:00, 121.29it/s]
100%|██████████| 236/236 [00:01<00:00, 128.89it/s]
100%|██████████| 179/179 [00:03<00:00, 55.69it/s] 
100%|██████████| 179/179 [00:02<00:00, 68.71it/s] 
100%|██████████| 179/179 [00:03<00:00, 54.10it/s] 
100%|██████████

In [28]:
len(data)

152436

In [29]:
for d in data[:1]:
    print(d["input"])
    print("*"*20)
    print(d["expected_output"])

Annotate all entity mentions in the following text with coreference clusters. Use Markdown tags to indicate clusters in the output, with the following format [mention](#cluster_name)

Input: 

Speaker#1:
 [The world 's fifth [Disney](cluster_1) park](cluster_0) will soon open to the public here .

Zhou_liangshuyi:
 The most important thing about Disney is that it is a global brand . Well , for several years , although [it](#) was still under construction and , er , not yet open , it can be said that many people have viewed Hong Kong with new respect .
Output: 
********************


Speaker#1:
 [The world 's fifth [Disney](cluster_1) park](cluster_0) will soon open to the public here .

Zhou_liangshuyi:
 The most important thing about Disney is that it is a global brand . Well , for several years , although [it](#cluster_0) was still under construction and , er , not yet open , it can be said that many people have viewed Hong Kong with new respect .


In [30]:
dataset = datasets.Dataset.from_list(data)
dataset.push_to_hub("coref-data/pcr_doc_prompt", private=True)

Creating parquet from Arrow format: 100%|██████████| 77/77 [00:01<00:00, 66.06ba/s]
Creating parquet from Arrow format: 100%|██████████| 77/77 [00:00<00:00, 375.63ba/s]
Uploading the dataset shards: 100%|██████████| 2/2 [00:23<00:00, 11.63s/it]


CommitInfo(commit_url='https://huggingface.co/datasets/coref-data/pcr_doc_prompt/commit/54f740a82a8a09b6bc3c1bfcedf9e5cfc45dae47', commit_message='Upload dataset', commit_description='', oid='54f740a82a8a09b6bc3c1bfcedf9e5cfc45dae47', pr_url=None, pr_revision=None, pr_num=None)