In [36]:
from pprint import pprint

import datasets
from stanza.utils.conll import FIELD_TO_IDX, CoNLL

In [37]:
dataset = datasets.load_dataset("coref-data/gum_raw")

In [49]:
etypes = set()
for ds in dataset.values():
    for coref_entities in ds["coref_entities"]:
        for entity in coref_entities:
            for mention in entity:
                    etypes.add(mention["span"])

In [51]:
[x for x in etypes if "," in x]

[]

In [48]:
[x for x in etypes if "link:" in x] # "identity:" not in x and "minspan:" not in x and "centering:" not in x]

['link:appos',
 'link:disc',
 'link:pred',
 'link:coref',
 'link:ana',
 'link:cata',
 'link:sgl']

In [33]:
pprint(dataset["train"].features)

{'coref_entities': [[{'eid': Value(dtype='string', id=None),
                      'eid_or_grp': Value(dtype='string', id=None),
                      'etype': Value(dtype='string', id=None),
                      'other': Value(dtype='string', id=None),
                      'sent_id': Value(dtype='string', id=None),
                      'span': Value(dtype='string', id=None)}]],
 'doc_id': Value(dtype='string', id=None),
 'ontogum_coref_chains': Sequence(feature=Sequence(feature=Sequence(feature=Value(dtype='int64',
                                                                                  id=None),
                                                                    length=-1,
                                                                    id=None),
                                                   length=-1,
                                                   id=None),
                                  length=-1,
                                  id=None),
 'ontogum_sent

In [20]:
sents = [0]

In [29]:
def a():
    for sents in dataset["train"]["sentences"]:
        for sent in sents:
            for tok in sent["tokens"]:
                if tok["coref_mentions"]:
                    pprint(tok)
                    return

a()

{'coref_mentions': [{'eid': 'd2.1',
                     'eid_or_grp': '1',
                     'etype': 'person',
                     'other': {'centering': 'cf1',
                               'identity': None,
                               'infstat': 'new',
                               'link': 'coref',
                               'minspan': '1'},
                     'span': '3'}],
 'deprel': 'nsubj:pass',
 'feats': 'NumForm=Digit|NumType=Card',
 'form': '107',
 'head': 4,
 'lemma': '107',
 'ord': 3.0,
 'upos': 'NUM',
 'xpos': 'CD'}


In [13]:
pprint(dataset["train"]["sentences"][0].keys())

AttributeError: 'list' object has no attribute 'keys'

In [15]:
import re
re.search("GUM_(.*)_", "GUM_news_crane").group(1)

'news'

In [2]:
def split_doc_into_doc_parts(example):
    """take a doc and return the doc parts"""
    doc_parts_dict = {} # {0: [], 1:[], ...}
    for sent_dict in example['sentences'][0]:
        sent_part_id = sent_dict['part_id']
        if sent_part_id in doc_parts_dict:
            doc_parts_dict[sent_part_id].append(sent_dict)
        else:
            doc_parts_dict[sent_part_id] = [sent_dict]
    document_id = example['document_id'][0]
    return {'document_id': [f'{document_id}/part_{k}' for k in doc_parts_dict],
            'sentences': [doc_parts_dict[k] for k in doc_parts_dict]}

In [3]:
dataset = datasets.load_dataset("coref-data/conll2012_raw", "english_v4")



Downloading readme: 100%|██████████| 10.1k/10.1k [00:00<00:00, 12.5MB/s]
Downloading data: 100%|██████████| 16.8M/16.8M [00:01<00:00, 11.3MB/s]
Downloading data: 100%|██████████| 2.21M/2.21M [00:00<00:00, 6.16MB/s]
Downloading data: 100%|██████████| 2.24M/2.24M [00:00<00:00, 2.97MB/s]
Generating train split: 1940 examples [00:00, 8520.39 examples/s]
Generating validation split: 222 examples [00:00, 10351.59 examples/s]
Generating test split: 222 examples [00:00, 10224.62 examples/s]


In [4]:
d = dataset.map(
        split_doc_into_doc_parts,
        batched=True,
        batch_size=1,
    )

Map: 100%|██████████| 1940/1940 [00:07<00:00, 255.91 examples/s]
Map: 100%|██████████| 222/222 [00:00<00:00, 240.55 examples/s]
Map: 100%|██████████| 222/222 [00:01<00:00, 221.91 examples/s]


In [5]:
d = dataset.map(
        split_doc_into_doc_parts,
        batched=True,
        batch_size=1,
        num_proc=4
    )

Map (num_proc=4): 100%|██████████| 1940/1940 [00:02<00:00, 765.25 examples/s] 
Map (num_proc=4): 100%|██████████| 222/222 [00:00<00:00, 616.99 examples/s]
Map (num_proc=4): 100%|██████████| 222/222 [00:00<00:00, 578.12 examples/s]


In [6]:
d = dataset.map(
        split_doc_into_doc_parts,
        batched=True,
        batch_size=1,
        num_proc=12
    )

Map (num_proc=12): 100%|██████████| 1940/1940 [00:01<00:00, 1230.95 examples/s]
Map (num_proc=12): 100%|██████████| 222/222 [00:00<00:00, 845.29 examples/s] 
Map (num_proc=12): 100%|██████████| 222/222 [00:00<00:00, 804.42 examples/s] 


In [7]:
d = dataset.map(
        split_doc_into_doc_parts,
        batched=True,
        batch_size=1,
        num_proc=18
    )

Map (num_proc=18): 100%|██████████| 1940/1940 [00:01<00:00, 1357.55 examples/s]
Map (num_proc=18): 100%|██████████| 222/222 [00:00<00:00, 653.50 examples/s] 
Map (num_proc=18): 100%|██████████| 222/222 [00:00<00:00, 651.35 examples/s] 


In [8]:
def sum_list(x: list[int]) -> int:
    return sum(x)

print(sum_list([1, 2, 3]))

6


In [13]:
a, b = [1, 2]