In [26]:
import datasets
from stanza.utils.conll import FIELD_TO_IDX, CoNLL

In [27]:
# Load the "coref-data/conll2012_conllu" dataset through `datasets.load_dataset`;
# no configuration name or split is passed here.
conllu_dataset = datasets.load_dataset("coref-data/conll2012_conllu")

In [28]:
# Restrict the train split to rows whose doc_name equals one specific document
# part, then walk the result with `.iter(1)` and look at the first item only.
# NOTE(review): the output recorded for this cell is `KeyError: 'document_id'`,
# a key this cell never references — cause not visible here, TODO investigate.
for ex in conllu_dataset["train"].filter(lambda x: x["doc_name"] == "mz/sinorama/10/ectb_1030/part_0").iter(1):
    # `[0]` is applied after the column lookup — presumably because `.iter(1)`
    # yields batches mapping column -> list of row values; confirm in datasets docs.
    print(ex["sentences"][0])
    # The same sentences passed through stanza's CoNLL.convert_conll.
    print(CoNLL.convert_conll(ex["sentences"][0]))
    break  # stop after the first item

Filter:   0%|          | 0/2802 [00:02<?, ? examples/s]


KeyError: 'document_id'

In [22]:
# Scratch experiment: map every row to {"sentences": "hi"}. The result is not
# assigned, so the name `conllu_dataset` is not rebound by this cell.
conllu_dataset.map(lambda _: {"sentences" : "hi"})

Map: 100%|██████████| 2802/2802 [00:00<00:00, 51273.68 examples/s]
Map: 100%|██████████| 343/343 [00:00<00:00, 43407.24 examples/s]
Map: 100%|██████████| 348/348 [00:00<00:00, 45178.22 examples/s]


DatasetDict({
    train: Dataset({
        features: ['doc_name', 'sentences', 'coref_chains'],
        num_rows: 2802
    })
    validation: Dataset({
        features: ['doc_name', 'sentences', 'coref_chains'],
        num_rows: 343
    })
    test: Dataset({
        features: ['doc_name', 'sentences', 'coref_chains'],
        num_rows: 348
    })
})

In [2]:
def split_doc_into_doc_parts(example):
    """Split every document in a batch into one row per document part.

    Intended for ``Dataset.map(..., batched=True)``: ``example`` is a batch,
    i.e. a mapping from column name to a list with one entry per document.

    Args:
        example: batch with at least two columns:
            ``'document_id'`` - list of document identifiers;
            ``'sentences'`` - list (one entry per document) of lists of
            sentence dicts, each sentence dict carrying a ``'part_id'`` key.

    Returns:
        dict with two equally long lists:
            ``'document_id'`` - ``f'{document_id}/part_{part_id}'`` for every
            part of every document;
            ``'sentences'`` - the sentence dicts belonging to that part, in
            their original order.
        Parts appear in order of first occurrence within their document, and
        documents keep their order within the batch.

    Unlike the earlier version, which only read index 0 of each column, all
    documents of the batch are processed, so any ``batch_size`` is safe and an
    empty batch yields empty lists instead of raising ``IndexError``.
    """
    part_document_ids = []
    part_sentences = []
    for document_id, sentences in zip(example['document_id'], example['sentences']):
        # Group this document's sentences by part id; dict insertion order
        # keeps parts in order of first appearance.
        sentences_by_part = {}  # {0: [...], 1: [...], ...}
        for sent_dict in sentences:
            sentences_by_part.setdefault(sent_dict['part_id'], []).append(sent_dict)
        for part_id, sents in sentences_by_part.items():
            part_document_ids.append(f'{document_id}/part_{part_id}')
            part_sentences.append(sents)
    return {'document_id': part_document_ids,
            'sentences': part_sentences}

In [3]:
# Load the "english_v4" configuration of "coref-data/conll2012_raw"; no split
# argument is passed here.
dataset = datasets.load_dataset("coref-data/conll2012_raw", "english_v4")



Downloading readme: 100%|██████████| 10.1k/10.1k [00:00<00:00, 12.5MB/s]
Downloading data: 100%|██████████| 16.8M/16.8M [00:01<00:00, 11.3MB/s]
Downloading data: 100%|██████████| 2.21M/2.21M [00:00<00:00, 6.16MB/s]
Downloading data: 100%|██████████| 2.24M/2.24M [00:00<00:00, 2.97MB/s]
Generating train split: 1940 examples [00:00, 8520.39 examples/s]
Generating validation split: 222 examples [00:00, 10351.59 examples/s]
Generating test split: 222 examples [00:00, 10224.62 examples/s]


In [4]:
# Apply split_doc_into_doc_parts in batched mode with batch_size=1 and no
# num_proc argument; the mapped result is bound to `d`.
d = dataset.map(
        split_doc_into_doc_parts,
        batched=True,
        batch_size=1,
    )

Map: 100%|██████████| 1940/1940 [00:07<00:00, 255.91 examples/s]
Map: 100%|██████████| 222/222 [00:00<00:00, 240.55 examples/s]
Map: 100%|██████████| 222/222 [00:01<00:00, 221.91 examples/s]


In [5]:
# Same batched mapping, this time with num_proc=4; rebinds `d`.
# Presumably one step of a throughput comparison across process counts.
d = dataset.map(
        split_doc_into_doc_parts,
        batched=True,
        batch_size=1,
        num_proc=4
    )

Map (num_proc=4): 100%|██████████| 1940/1940 [00:02<00:00, 765.25 examples/s] 
Map (num_proc=4): 100%|██████████| 222/222 [00:00<00:00, 616.99 examples/s]
Map (num_proc=4): 100%|██████████| 222/222 [00:00<00:00, 578.12 examples/s]


In [6]:
# Same batched mapping, this time with num_proc=12; rebinds `d`.
# Presumably one step of a throughput comparison across process counts.
d = dataset.map(
        split_doc_into_doc_parts,
        batched=True,
        batch_size=1,
        num_proc=12
    )

Map (num_proc=12): 100%|██████████| 1940/1940 [00:01<00:00, 1230.95 examples/s]
Map (num_proc=12): 100%|██████████| 222/222 [00:00<00:00, 845.29 examples/s] 
Map (num_proc=12): 100%|██████████| 222/222 [00:00<00:00, 804.42 examples/s] 


In [7]:
# Same batched mapping, this time with num_proc=18; rebinds `d`.
# Presumably one step of a throughput comparison across process counts.
d = dataset.map(
        split_doc_into_doc_parts,
        batched=True,
        batch_size=1,
        num_proc=18
    )

Map (num_proc=18): 100%|██████████| 1940/1940 [00:01<00:00, 1357.55 examples/s]
Map (num_proc=18): 100%|██████████| 222/222 [00:00<00:00, 653.50 examples/s] 
Map (num_proc=18): 100%|██████████| 222/222 [00:00<00:00, 651.35 examples/s] 


In [8]:
def sum_list(x: list[int]) -> int:
    """Add up the integers in ``x`` and return the total (0 for an empty list)."""
    total = 0
    for value in x:
        total = total + value
    return total

# Quick check of sum_list on a small literal list; the result goes to stdout.
print(sum_list([1, 2, 3]))

6


In [13]:
# Bind a to 1 and b to 2 via sequence unpacking.
a, b = 1, 2