/
arrow_dataset.py
6254 lines (5503 loc) · 282 KB
/
arrow_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2020 The HuggingFace Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
""" Simple Dataset wrapping an Arrow Table."""
import contextlib
import copy
import itertools
import json
import os
import posixpath
import re
import shutil
import sys
import tempfile
import time
import warnings
import weakref
from collections import Counter
from collections.abc import Mapping
from copy import deepcopy
from fnmatch import fnmatch
from functools import partial, wraps
from io import BytesIO
from math import ceil, floor
from pathlib import Path
from random import sample
from typing import (
TYPE_CHECKING,
Any,
BinaryIO,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Tuple,
Union,
overload,
)
from typing import Sequence as Sequence_
import fsspec
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.compute as pc
from huggingface_hub import DatasetCard, DatasetCardData, HfApi, HfFolder
from multiprocess import Pool
from requests import HTTPError
from . import config
from .arrow_reader import ArrowReader
from .arrow_writer import ArrowWriter, OptimizedTypedSequence
from .data_files import sanitize_patterns
from .download.download_config import DownloadConfig
from .download.streaming_download_manager import xgetsize
from .features import Audio, ClassLabel, Features, Image, Sequence, Value
from .features.features import (
FeatureType,
_align_features,
_check_if_features_can_be_aligned,
generate_from_arrow_type,
pandas_types_mapper,
require_decoding,
)
from .filesystems import extract_path_from_uri, is_remote_filesystem
from .fingerprint import (
fingerprint_transform,
format_kwargs_for_fingerprint,
format_transform_for_fingerprint,
generate_fingerprint,
generate_random_fingerprint,
get_temporary_cache_files_directory,
is_caching_enabled,
maybe_register_dataset_for_temp_dir_deletion,
update_fingerprint,
validate_fingerprint,
)
from .formatting import format_table, get_format_type_from_alias, get_formatter, query_table
from .formatting.formatting import LazyDict, _is_range_contiguous
from .info import DatasetInfo, DatasetInfosDict
from .naming import _split_re
from .search import IndexableMixin
from .splits import NamedSplit, Split, SplitDict, SplitInfo
from .table import (
InMemoryTable,
MemoryMappedTable,
Table,
_memory_mapped_record_batch_reader_from_file,
cast_array_to_feature,
concat_tables,
embed_table_storage,
list_table_cache_files,
table_cast,
table_iter,
table_visitor,
)
from .tasks import TaskTemplate
from .utils import logging
from .utils.deprecation_utils import deprecated
from .utils.file_utils import _retry, cached_path, estimate_dataset_size
from .utils.hub import hf_hub_url
from .utils.info_utils import is_small_dataset
from .utils.metadata import MetadataConfigs
from .utils.py_utils import (
Literal,
asdict,
convert_file_size_to_int,
glob_pattern_to_regex,
iflatmap_unordered,
string_to_dict,
unique_values,
)
from .utils.stratify import stratified_shuffle_split_generate_indices
from .utils.tf_utils import dataset_to_tf, minimal_tf_collate_fn, multiprocess_dataset_to_tf
from .utils.typing import ListLike, PathLike
if TYPE_CHECKING:
import sqlite3
import pyspark
import sqlalchemy
from .dataset_dict import DatasetDict
from .iterable_dataset import IterableDataset
logger = logging.get_logger(__name__)

# Glob-style pattern for sharded parquet files such as "data/train-00000-of-00004.parquet".
# NOTE(review): `{split}` is a placeholder, presumably filled in with `str.format(split=...)` by the
# push-to-hub code when no metadata configs exist - confirm against the callers.
PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED = (
    "data/{split}-[0-9][0-9][0-9][0-9][0-9]"
    "-of-[0-9][0-9][0-9][0-9][0-9]*.parquet"
)
class DatasetInfoMixin:
    """Mixin exposing the fields of a [`DatasetInfo`] as properties of the dataset object itself.

    The wrapped info object and split are stored as `self._info` and `self._split`; every property below reads
    straight from them (no setters are defined).
    """

    def __init__(self, info: DatasetInfo, split: Optional[NamedSplit]):
        self._split = split
        self._info = info

    @property
    def info(self):
        """[`~datasets.DatasetInfo`] object containing all the metadata in the dataset."""
        return self._info

    @property
    def split(self):
        """[`~datasets.NamedSplit`] object corresponding to a named dataset split."""
        return self._split

    @property
    def builder_name(self) -> str:
        """Shortcut for `self.info.builder_name`."""
        return self._info.builder_name

    @property
    def citation(self) -> str:
        """Shortcut for `self.info.citation`."""
        return self._info.citation

    @property
    def config_name(self) -> str:
        """Shortcut for `self.info.config_name`."""
        return self._info.config_name

    @property
    def dataset_size(self) -> Optional[int]:
        """Shortcut for `self.info.dataset_size`."""
        return self._info.dataset_size

    @property
    def description(self) -> str:
        """Shortcut for `self.info.description`."""
        return self._info.description

    @property
    def download_checksums(self) -> Optional[dict]:
        """Shortcut for `self.info.download_checksums`."""
        return self._info.download_checksums

    @property
    def download_size(self) -> Optional[int]:
        """Shortcut for `self.info.download_size`."""
        return self._info.download_size

    @property
    def features(self) -> Optional[Features]:
        """A `.copy()` of `self.info.features`, or `None` when the info carries no features.

        NOTE(review): the copy is presumably there so that callers cannot alter the stored features through
        the returned object.
        """
        stored_features = self._info.features
        if stored_features is None:
            return None
        return stored_features.copy()

    @property
    def homepage(self) -> Optional[str]:
        """Shortcut for `self.info.homepage`."""
        return self._info.homepage

    @property
    def license(self) -> Optional[str]:
        """Shortcut for `self.info.license`."""
        return self._info.license

    @property
    def size_in_bytes(self) -> Optional[int]:
        """Shortcut for `self.info.size_in_bytes`."""
        return self._info.size_in_bytes

    @property
    def supervised_keys(self):
        """Shortcut for `self.info.supervised_keys`."""
        return self._info.supervised_keys

    @property
    def task_templates(self):
        """Shortcut for `self.info.task_templates`."""
        return self._info.task_templates

    @property
    def version(self):
        """Shortcut for `self.info.version`."""
        return self._info.version
class TensorflowDatasetMixin:
    """Mixin providing `to_tf_dataset()`, which turns the dataset into a `tf.data.Dataset`."""

    # Weak references to the `tf.data.Dataset` objects returned by `to_tf_dataset()`. Each reference carries a
    # callback that calls `__del__()` on the dataset it was built from and then drops the reference itself.
    _TF_DATASET_REFS = set()

    @staticmethod
    def _numpy_dtype_for(array: np.ndarray):
        """Return the numpy dtype used to represent a collated column in the tf.data output signature.

        Integer and boolean arrays map to `np.int64`, other numeric arrays to `np.float32` and unicode string
        arrays to `np.str_`.

        Args:
            array (`np.ndarray`): A collated column, as produced by the `collate_fn`.

        Returns:
            One of `np.int64`, `np.float32` or `np.str_`.

        Raises:
            RuntimeError: if the array dtype is none of the supported kinds (e.g. object arrays).
        """
        if np.issubdtype(array.dtype, np.integer) or array.dtype == bool:
            return np.int64
        if np.issubdtype(array.dtype, np.number):
            return np.float32
        if array.dtype.kind == "U":  # Unicode strings
            # `np.str_` is the canonical name; the `np.unicode_` alias was removed in NumPy 2.0
            return np.str_
        raise RuntimeError(
            f"Unrecognized array dtype {array.dtype}. \n"
            "Nested types and image/audio types are not supported yet."
        )

    @staticmethod
    def _get_output_signature(
        dataset: "Dataset",
        collate_fn: Callable,
        collate_fn_args: dict,
        cols_to_retain: Optional[List[str]] = None,
        batch_size: Optional[int] = None,
        num_test_batches: int = 20,
    ):
        """Private method used by `to_tf_dataset()` to find the shapes and dtypes of samples from this dataset
        after being passed through the collate_fn. Tensorflow needs an exact signature for tf.numpy_function, so
        the only way to do this is to run test batches - the collator may add or rename columns, so we can't figure
        it out just by inspecting the dataset.

        Args:
            dataset (`Dataset`): Dataset to load samples from.
            collate_fn(`Callable`): A function or callable object (such as a `DataCollator`) that will collate
                lists of samples into a batch.
            collate_fn_args (`Dict`): A `dict` of keyword arguments to be passed to the
                `collate_fn`.
            cols_to_retain (`List[str]`, optional): Columns to keep in each test batch before collation. The
                `"label_ids"`, `"label"` and `"labels"` columns are always retained as well. `None` keeps every
                column.
            batch_size (`int`, optional): The size of batches loaded from the dataset. Used for shape inference.
                Can be None, which indicates that batch sizes can be variable.
            num_test_batches (`int`): The number of batches to load from the dataset for shape inference.

        Returns:
            `dict`: Dict mapping column names to tf.Tensorspec objects
            `dict`: Dict mapping column names to np.dtype objects

        Raises:
            ImportError: if TensorFlow is not installed.
            ValueError: if the dataset is empty.
            RuntimeError: if a collated column has an unsupported dtype.
        """
        if config.TF_AVAILABLE:
            import tensorflow as tf
        else:
            raise ImportError("Called a Tensorflow-specific function but Tensorflow is not installed.")

        if len(dataset) == 0:
            raise ValueError("Unable to get the output signature because the dataset is empty.")
        if batch_size is not None:
            batch_size = min(len(dataset), batch_size)
        test_batch_size = 1

        if cols_to_retain is not None:
            cols_to_retain = list(set(cols_to_retain + ["label_ids", "label", "labels"]))

        test_batches = []
        for _ in range(num_test_batches):
            indices = sample(range(len(dataset)), test_batch_size)
            test_batch = dataset[indices]
            if cols_to_retain is not None:
                test_batch = {key: value for key, value in test_batch.items() if key in cols_to_retain}
            test_batch = [{key: value[i] for key, value in test_batch.items()} for i in range(test_batch_size)]
            test_batch = collate_fn(test_batch, **collate_fn_args)
            test_batches.append(test_batch)

        tf_columns_to_signatures = {}
        np_columns_to_dtypes = {}
        for column in test_batches[0].keys():
            raw_arrays = [batch[column] for batch in test_batches]
            # In case the collate_fn returns something strange
            np_arrays = []
            for array in raw_arrays:
                if isinstance(array, np.ndarray):
                    np_arrays.append(array)
                elif isinstance(array, tf.Tensor):
                    np_arrays.append(array.numpy())
                else:
                    np_arrays.append(np.array(array))

            np_dtype = TensorflowDatasetMixin._numpy_dtype_for(np_arrays[0])
            if np_dtype is np.int64:
                tf_dtype = tf.int64
            elif np_dtype is np.float32:
                tf_dtype = tf.float32
            else:  # np.str_
                tf_dtype = tf.string

            shapes = [array.shape for array in np_arrays]
            static_shape = []
            for dim in range(len(shapes[0])):
                sizes = {shape[dim] for shape in shapes}
                if dim == 0:
                    # The batch dimension is the requested batch size (None means variable)
                    static_shape.append(batch_size)
                    continue
                if len(sizes) == 1:  # This dimension looks constant
                    static_shape.append(sizes.pop())
                else:  # Use None for variable dimensions
                    static_shape.append(None)
            tf_columns_to_signatures[column] = tf.TensorSpec(shape=static_shape, dtype=tf_dtype)
            np_columns_to_dtypes[column] = np_dtype

        return tf_columns_to_signatures, np_columns_to_dtypes

    def to_tf_dataset(
        self,
        batch_size: Optional[int] = None,
        columns: Optional[Union[str, List[str]]] = None,
        shuffle: bool = False,
        collate_fn: Optional[Callable] = None,
        drop_remainder: bool = False,
        collate_fn_args: Optional[Dict[str, Any]] = None,
        label_cols: Optional[Union[str, List[str]]] = None,
        prefetch: bool = True,
        num_workers: int = 0,
        num_test_batches: int = 20,
    ):
        """Create a `tf.data.Dataset` from the underlying Dataset. This `tf.data.Dataset` will load and collate batches from
        the Dataset, and is suitable for passing to methods like `model.fit()` or `model.predict()`. The dataset will yield
        `dicts` for both inputs and labels unless the `dict` would contain only a single key, in which case a raw
        `tf.Tensor` is yielded instead.

        Args:
            batch_size (`int`, *optional*):
                Size of batches to load from the dataset. Defaults to `None`, which implies that the dataset won't be
                batched, but the returned dataset can be batched later with `tf_dataset.batch(batch_size)`.
            columns (`List[str]` or `str`, *optional*):
                Dataset column(s) to load in the `tf.data.Dataset`.
                Column names that are created by the `collate_fn` and that do not exist in the original dataset can be used.
            shuffle(`bool`, defaults to `False`):
                Shuffle the dataset order when loading. Recommended `True` for training, `False` for
                validation/evaluation.
            drop_remainder(`bool`, defaults to `False`):
                Drop the last incomplete batch when loading. Ensures
                that all batches yielded by the dataset will have the same length on the batch dimension.
            collate_fn(`Callable`, *optional*):
                A function or callable object (such as a `DataCollator`) that will collate
                lists of samples into a batch.
            collate_fn_args (`Dict`, *optional*):
                An optional `dict` of keyword arguments to be passed to the
                `collate_fn`.
            label_cols (`List[str]` or `str`, defaults to `None`):
                Dataset column(s) to load as labels.
                Note that many models compute loss internally rather than letting Keras do it, in which case
                passing the labels here is optional, as long as they're in the input `columns`.
            prefetch (`bool`, defaults to `True`):
                Whether to run the dataloader in a separate thread and maintain
                a small buffer of batches for training. Improves performance by allowing data to be loaded in the
                background while the model is training.
            num_workers (`int`, defaults to `0`):
                Number of workers to use for loading the dataset. Only supported on Python versions >= 3.8.
            num_test_batches (`int`, defaults to `20`):
                Number of batches to use to infer the output signature of the dataset.
                The higher this number, the more accurate the signature will be, but the longer it will take to
                create the dataset.

        Returns:
            `tf.data.Dataset`

        Raises:
            ImportError: if TensorFlow is not installed.
            ValueError: if `label_cols` is given without `columns`, if `columns`/`label_cols` contain duplicates
                or names missing from the collated output, or if `num_workers` is negative.
            NotImplementedError: if `num_workers > 0` and `batch_size` is `None`.

        Example:

        ```py
        >>> ds_train = ds["train"].to_tf_dataset(
        ...    columns=['input_ids', 'token_type_ids', 'attention_mask', 'label'],
        ...    shuffle=True,
        ...    batch_size=16,
        ...    collate_fn=data_collator,
        ... )
        ```
        """
        if config.TF_AVAILABLE:
            import tensorflow as tf
        else:
            raise ImportError("Called a Tensorflow-specific function but Tensorflow is not installed.")

        if (isinstance(columns, list) and len(columns) == 1) or (
            isinstance(label_cols, list) and len(label_cols) == 1
        ):
            warnings.warn(
                "The output of `to_tf_dataset` will change when a passing single element list for `labels` or "
                "`columns` in the next datasets version. To return a tuple structure rather than dict, pass a "
                "single string.\n"
                "Old behaviour: columns=['a'], labels=['labels'] -> (tf.Tensor, tf.Tensor) \n"
                " : columns='a', labels='labels' -> (tf.Tensor, tf.Tensor) \n"
                "New behaviour: columns=['a'],labels=['labels'] -> ({'a': tf.Tensor}, {'labels': tf.Tensor}) \n"
                " : columns='a', labels='labels' -> (tf.Tensor, tf.Tensor) ",
                FutureWarning,
            )

        if isinstance(tf.distribute.get_strategy(), tf.distribute.TPUStrategy):
            logger.warning(
                "Note that to_tf_dataset() loads the data with a generator rather than a full tf.data "
                "pipeline and is not compatible with remote TPU connections. If you encounter errors, please "
                "try using a TPU VM or, if your data can fit in memory, loading it into memory as a dict of "
                "Tensors instead of streaming with to_tf_dataset()."
            )

        if collate_fn is None:
            # Set a very simple default collator that just stacks things together
            collate_fn = minimal_tf_collate_fn
        if collate_fn_args is None:
            collate_fn_args = {}
        if label_cols and not columns:
            raise ValueError("Cannot specify label_cols without specifying columns!")
        if label_cols is None:
            label_cols = []
        elif isinstance(label_cols, str):
            label_cols = [label_cols]
        if len(set(label_cols)) < len(label_cols):
            raise ValueError("List of label_cols contains duplicates.")
        if columns:
            if isinstance(columns, str):
                columns = [columns]
            if len(set(columns)) < len(columns):
                raise ValueError("List of columns contains duplicates.")
            cols_to_retain = list(set(columns + label_cols))
        else:
            cols_to_retain = None  # Indicates keeping all valid columns
            columns = []

        if self.format["type"] not in ["custom", "numpy"]:
            dataset = self.with_format("numpy")
        else:
            dataset = self

        # TODO(Matt, QL): deprecate the retention of label_ids and label

        output_signature, columns_to_np_types = dataset._get_output_signature(
            dataset,
            collate_fn=collate_fn,
            collate_fn_args=collate_fn_args,
            cols_to_retain=cols_to_retain,
            batch_size=batch_size if drop_remainder and batch_size is not None else None,
            num_test_batches=num_test_batches,
        )

        if "labels" in output_signature:
            if ("label_ids" in columns or "label" in columns) and "labels" not in columns:
                columns = [col for col in columns if col not in ["label_ids", "label"]] + ["labels"]
            if ("label_ids" in label_cols or "label" in label_cols) and "labels" not in label_cols:
                label_cols = [col for col in label_cols if col not in ["label_ids", "label"]] + ["labels"]

        for col in columns:
            if col not in output_signature:
                raise ValueError(f"Column {col} not found in dataset!")

        for col in label_cols:
            if col not in output_signature:
                raise ValueError(f"Label column {col} not found in dataset!")

        if num_workers == 0:
            tf_dataset = dataset_to_tf(
                dataset=dataset,
                cols_to_retain=cols_to_retain,
                collate_fn=collate_fn,
                collate_fn_args=collate_fn_args,
                columns_to_np_types=columns_to_np_types,
                output_signature=output_signature,
                shuffle=shuffle,
                batch_size=batch_size,
                drop_remainder=drop_remainder,
            )
        elif num_workers > 0:
            if batch_size is None:
                raise NotImplementedError(
                    "`batch_size` must be specified when using multiple workers, as unbatched multiprocessing "
                    "is not supported yet. Please provide a `batch_size` if `num_workers` is greater than 0."
                )
            tf_dataset = multiprocess_dataset_to_tf(
                dataset=dataset,
                cols_to_retain=cols_to_retain,
                collate_fn=collate_fn,
                collate_fn_args=collate_fn_args,
                columns_to_np_types=columns_to_np_types,
                output_signature=output_signature,
                shuffle=shuffle,
                batch_size=batch_size,
                drop_remainder=drop_remainder,
                num_workers=num_workers,
            )
        else:
            raise ValueError("num_workers must be >= 0")

        def split_features_and_labels(input_batch):
            # TODO(Matt, QL): deprecate returning the dict content when there's only one key
            features = {key: tensor for key, tensor in input_batch.items() if key in columns}
            labels = {key: tensor for key, tensor in input_batch.items() if key in label_cols}
            if len(features) == 1:
                features = list(features.values())[0]
            if len(labels) == 1:
                labels = list(labels.values())[0]
            if isinstance(labels, dict) and len(labels) == 0:
                return features
            else:
                return features, labels

        if cols_to_retain is not None:
            tf_dataset = tf_dataset.map(split_features_and_labels)

        if prefetch:
            tf_dataset = tf_dataset.prefetch(tf.data.experimental.AUTOTUNE)

        # Remove a reference to the open Arrow file on delete
        def cleanup_callback(ref):
            dataset.__del__()
            self._TF_DATASET_REFS.remove(ref)

        self._TF_DATASET_REFS.add(weakref.ref(tf_dataset, cleanup_callback))

        return tf_dataset
class DatasetTransformationNotAllowedError(Exception):
    """Exception type for dataset transformations that are not allowed."""
def transmit_format(func):
    """Decorator for dataset transforms that build new datasets: carry the input dataset's format over.

    After `func` runs, every returned dataset (or every value of a returned `dict`, such as a `DatasetDict`) gets
    the input's format type, format kwargs and `output_all_columns` flag. If the input restricted its format to a
    list of columns, each output's formatted columns become its own columns minus those that were unformatted in
    the input. `set_format` is only called when the output's format differs, and the output's fingerprint is
    restored afterwards.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # The dataset is either the first positional argument or passed as `self=...`.
        self, args = (args[0], args[1:]) if args else (kwargs.pop("self"), args)
        # Use the private attributes rather than `self.format`: the latter lists every column under "columns"
        # even when no column restriction was set.
        unformatted = set(self.column_names) - set(self._format_columns or [])
        source_format = dict(
            type=self._format_type,
            format_kwargs=self._format_kwargs,
            columns=self._format_columns,
            output_all_columns=self._output_all_columns,
        )

        out = func(self, *args, **kwargs)

        produced = list(out.values()) if isinstance(out, dict) else [out]
        for dataset in produced:
            expected = dict(source_format)
            if expected["columns"] is not None:
                # Sorted so that the list compares deterministically with the output's current columns.
                expected["columns"] = sorted(set(dataset.column_names) - unformatted)
            current = {
                "type": dataset._format_type,
                "format_kwargs": dataset._format_kwargs,
                "columns": None if dataset._format_columns is None else sorted(dataset._format_columns),
                "output_all_columns": dataset._output_all_columns,
            }
            if current == expected:
                continue
            # Keep the fingerprint the transform produced; `set_format` presumably updates it.
            saved_fingerprint = dataset._fingerprint
            dataset.set_format(**expected)
            dataset._fingerprint = saved_fingerprint
        return out

    wrapper._decorator_name_ = "transmit_format"
    return wrapper
def transmit_tasks(func):
    """Decorator for dataset transforms that build new datasets: carry the input's task templates over.

    Each returned dataset (or each value of a returned `dict`) receives the input dataset's task templates,
    minus any template that maps a column whose feature differs between input and output. Outputs are left
    untouched when the input has no task templates (`None`).
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # The dataset is either the first positional argument or passed as `self=...`.
        self, args = (args[0], args[1:]) if args else (kwargs.pop("self"), args)

        out = func(self, *args, **kwargs)

        produced = list(out.values()) if isinstance(out, dict) else [out]
        for dataset in produced:
            source_templates = self.info.task_templates
            if source_templates is None:
                continue
            kept = []
            for template in source_templates:
                columns_unchanged = all(
                    dataset._info.features.get(column) == self._info.features.get(column)
                    for column in template.column_mapping.keys()
                )
                if columns_unchanged:
                    kept.append(template)
            dataset.info.task_templates = kept
        return out

    wrapper._decorator_name_ = "transmit_tasks"
    return wrapper
def update_metadata_with_features(table: Table, features: Features):
    """Return `table` with the features stored in its `huggingface` schema metadata set to `features`.

    To be used in dataset transforms that modify the features of the dataset. Only the entries of `features` for
    the table's columns are kept, in the table's column order. When the schema already has `huggingface` metadata,
    only its `info`/`features` part is replaced (or `info` is created). Otherwise fresh metadata is built with
    `ArrowWriter._build_metadata`. The result is applied with `replace_schema_metadata`.
    """
    table_features = Features({name: features[name] for name in table.column_names})
    schema_metadata = table.schema.metadata
    if schema_metadata is not None and b"huggingface" in schema_metadata:
        hf_metadata = json.loads(schema_metadata[b"huggingface"].decode())
        info_dict = asdict(DatasetInfo(features=table_features))
        if "info" in hf_metadata:
            hf_metadata["info"]["features"] = info_dict["features"]
        else:
            hf_metadata["info"] = info_dict
        new_metadata = {"huggingface": json.dumps(hf_metadata)}
    else:
        new_metadata = ArrowWriter._build_metadata(DatasetInfo(features=table_features))
    return table.replace_schema_metadata(new_metadata)
def _check_table(table) -> Table:
    """Return `table` as a :class:`datasets.table.Table`.

    A bare `pyarrow.Table` is wrapped in an `InMemoryTable` (kept for backward compatibility). A
    :class:`datasets.table.Table` is returned unchanged, and anything else raises `TypeError`.
    """
    if isinstance(table, pa.Table):
        return InMemoryTable(table)
    if not isinstance(table, Table):
        raise TypeError(f"Expected a pyarrow.Table or a datasets.table.Table object, but got {table}.")
    return table
def _check_column_names(column_names: List[str]):
"""Check the column names to make sure they don't contain duplicates."""
counter = Counter(column_names)
if not all(count == 1 for count in counter.values()):
duplicated_columns = [col for col in counter if counter[col] > 1]
raise ValueError(f"The table can't have duplicated columns but columns {duplicated_columns} are duplicated.")
def _check_valid_indices_value(index, size):
if (index < 0 and index + size < 0) or (index >= size):
raise IndexError(f"Index {index} out of range for dataset of size {size}.")
class NonExistentDatasetError(Exception):
    """Exception type used when a dataset is expected to exist."""
class Dataset(DatasetInfoMixin, IndexableMixin, TensorflowDatasetMixin):
"""A Dataset backed by an Arrow table."""
def __init__(
    self,
    arrow_table: Table,
    info: Optional[DatasetInfo] = None,
    split: Optional[NamedSplit] = None,
    indices_table: Optional[Table] = None,
    fingerprint: Optional[str] = None,
):
    """Build a Dataset around an Arrow table.

    Args:
        arrow_table: The data, as a `datasets.table.Table` or a bare `pyarrow.Table`
            (the latter is wrapped in an `InMemoryTable` by `_check_table`).
        info: Dataset metadata. A copy is stored; a fresh `DatasetInfo()` is used when `None`.
        split: Optional split name.
        indices_table: Optional indices mapping; its first column must have an unsigned integer type.
        fingerprint: Optional fingerprint. When `None`, it is read from the table's `huggingface` schema
            metadata if present, otherwise generated with `generate_fingerprint(self)`.

    Raises:
        ValueError: if the stored features cannot be reordered like the table's inferred features, do not
            match the table's type, if the indices column is not unsigned, or if column names are duplicated.
    """
    info = info.copy() if info is not None else DatasetInfo()
    DatasetInfoMixin.__init__(self, info=info, split=split)
    IndexableMixin.__init__(self)
    self._data: Table = _check_table(arrow_table)
    self._indices: Optional[Table] = _check_table(indices_table) if indices_table is not None else None
    maybe_register_dataset_for_temp_dir_deletion(self)

    # Formatting state: no format / no column restriction by default
    self._format_type: Optional[str] = None
    self._format_kwargs: dict = {}
    self._format_columns: Optional[list] = None
    self._output_all_columns: bool = False
    self._fingerprint: str = fingerprint

    # Read metadata
    if self._data.schema.metadata is not None and b"huggingface" in self._data.schema.metadata:
        metadata = json.loads(self._data.schema.metadata[b"huggingface"].decode())
        if (
            "fingerprint" in metadata and self._fingerprint is None
        ):  # try to load fingerprint from the arrow file metadata
            self._fingerprint = metadata["fingerprint"]

    # Infer features if None
    inferred_features = Features.from_arrow_schema(arrow_table.schema)
    if self.info.features is None:
        self.info.features = inferred_features
    else:  # make sure the nested columns are in the right order
        try:
            self.info.features = self.info.features.reorder_fields_as(inferred_features)
        except ValueError as e:
            raise ValueError(
                f"{e}\nThe 'source' features come from dataset_info.json, and the 'target' ones are those of the dataset arrow file."
            )

    # Infer fingerprint if None
    # (must run after the features are settled above, since it is computed from `self`)
    if self._fingerprint is None:
        self._fingerprint = generate_fingerprint(self)

    # Sanity checks
    if self._info.features is None:
        raise ValueError("Features can't be None in a Dataset object")
    if self._fingerprint is None:
        raise ValueError("Fingerprint can't be None in a Dataset object")
    if self.info.features.type != inferred_features.type:
        raise ValueError(
            f"External features info don't match the dataset:\nGot\n{self.info.features}\nwith type\n{self.info.features.type}\n\nbut expected something like\n{inferred_features}\nwith type\n{inferred_features.type}"
        )

    if self._indices is not None:
        if not pa.types.is_unsigned_integer(self._indices.column(0).type):
            raise ValueError(
                f"indices must be an Arrow table of unsigned integers, current type is {self._indices.column(0).type}"
            )
    _check_column_names(self._data.column_names)

    # Store the final features in the table's schema metadata
    self._data = update_metadata_with_features(self._data, self._info.features)
@property
def features(self) -> Features:
    """The [`Features`] of the dataset; unlike the mixin's property, this never returns `None`."""
    inherited_features = super().features
    if inherited_features is not None:
        return inherited_features
    # Already enforced in __init__; kept as a defensive check
    raise ValueError("Features can't be None in a Dataset object")
@classmethod
def from_file(
    cls,
    filename: str,
    info: Optional[DatasetInfo] = None,
    split: Optional[NamedSplit] = None,
    indices_filename: Optional[str] = None,
    in_memory: bool = False,
) -> "Dataset":
    """Instantiate a Dataset backed by the Arrow table stored at `filename`.

    Args:
        filename (`str`):
            File name of the dataset.
        info (`DatasetInfo`, *optional*):
            Dataset information, like description, citation, etc.
        split (`NamedSplit`, *optional*):
            Name of the dataset split.
        indices_filename (`str`, *optional*):
            File name of the indices table, read with the same `in_memory` setting.
        in_memory (`bool`, defaults to `False`):
            Whether to copy the data in-memory.

    Returns:
        [`Dataset`]
    """
    data_table = ArrowReader.read_table(filename, in_memory=in_memory)
    indices_table = None
    if indices_filename is not None:
        indices_table = ArrowReader.read_table(indices_filename, in_memory=in_memory)
    return cls(arrow_table=data_table, info=info, split=split, indices_table=indices_table)
@classmethod
def from_buffer(
    cls,
    buffer: pa.Buffer,
    info: Optional[DatasetInfo] = None,
    split: Optional[NamedSplit] = None,
    indices_buffer: Optional[pa.Buffer] = None,
) -> "Dataset":
    """Instantiate a Dataset backed by an Arrow buffer.

    Args:
        buffer (`pyarrow.Buffer`):
            Arrow buffer holding the data table.
        info (`DatasetInfo`, *optional*):
            Dataset information, like description, citation, etc.
        split (`NamedSplit`, *optional*):
            Name of the dataset split.
        indices_buffer (`pyarrow.Buffer`, *optional*):
            Arrow buffer holding the indices table.

    Returns:
        [`Dataset`]
    """
    table = InMemoryTable.from_buffer(buffer)
    # The indices table must be read from `indices_buffer`; reading it from the data `buffer` would
    # silently use the data table as the indices mapping.
    indices_table = InMemoryTable.from_buffer(indices_buffer) if indices_buffer is not None else None
    return cls(table, info=info, split=split, indices_table=indices_table)
@classmethod
def from_pandas(
    cls,
    df: pd.DataFrame,
    features: Optional[Features] = None,
    info: Optional[DatasetInfo] = None,
    split: Optional[NamedSplit] = None,
    preserve_index: Optional[bool] = None,
) -> "Dataset":
    """
    Convert `pandas.DataFrame` to a `pyarrow.Table` to create a [`Dataset`].

    Arrow column types are inferred from the dtypes of the DataFrame's `pandas.Series`. For non-object Series the
    NumPy dtype is translated to its Arrow equivalent. For `object` Series the type has to be guessed from the
    Python objects they contain.

    Such guesses are not always meaningful. When no type can be inferred, the Arrow type is set to `null`; this
    happens for instance with a DataFrame of length 0, or a Series that only contains `None/nan` objects. Pass
    explicit `features` to avoid this.

    Args:
        df (`pandas.DataFrame`):
            Dataframe that contains the dataset.
        features ([`Features`], *optional*):
            Dataset features. Must equal `info.features` when both are given.
        info (`DatasetInfo`, *optional*):
            Dataset information, like description, citation, etc.
        split (`NamedSplit`, *optional*):
            Name of the dataset split.
        preserve_index (`bool`, *optional*):
            Whether to store the index as an additional column in the resulting Dataset.
            The default of `None` will store the index as a column, except for `RangeIndex` which is stored as metadata only.
            Use `preserve_index=True` to force it to be stored as a column.

    Returns:
        [`Dataset`]

    Example:

    ```py
    >>> ds = Dataset.from_pandas(df)
    ```
    """
    if features is None:
        # Fall back on the features stored in `info`, if any
        features = info.features if info is not None else None
    elif info is not None and info.features != features:
        raise ValueError(
            f"Features specified in `features` and `info.features` can't be different:\n{features}\n{info.features}"
        )
    info = DatasetInfo() if info is None else info
    info.features = features
    table = InMemoryTable.from_pandas(df=df, preserve_index=preserve_index)
    if features is not None:
        # Casting after the conversion is more expensive than `InMemoryTable.from_pandas(..., schema=...)`,
        # but it is needed to support e.g. the str to Audio conversion
        table = table.cast(features.arrow_schema)
    return cls(table, info=info, split=split)
@classmethod
def from_dict(
    cls,
    mapping: dict,
    features: Optional[Features] = None,
    info: Optional[DatasetInfo] = None,
    split: Optional[NamedSplit] = None,
) -> "Dataset":
    """
    Convert `dict` to a `pyarrow.Table` to create a [`Dataset`].

    Args:
        mapping (`Mapping`):
            Mapping of strings to Arrays or Python lists.
        features ([`Features`], *optional*):
            Dataset features. When given, every column of `mapping` must have an entry in it.
        info (`DatasetInfo`, *optional*):
            Dataset information, like description, citation, etc.
            The object passed in is not modified; a copy is attached to the returned dataset.
        split (`NamedSplit`, *optional*):
            Name of the dataset split.

    Returns:
        [`Dataset`]

    Raises:
        ValueError: If both `features` and `info.features` are given and differ.
    """
    if info is not None and features is not None and info.features != features:
        raise ValueError(
            f"Features specified in `features` and `info.features` can't be different:\n{features}\n{info.features}"
        )
    # Explicit `features` win; otherwise fall back to the features carried by `info` (if any).
    features = features if features is not None else info.features if info is not None else None
    arrow_typed_mapping = {}
    for col, data in mapping.items():
        if isinstance(data, (pa.Array, pa.ChunkedArray)):
            # Arrow data already carries a type: only cast it when target features are known.
            data = cast_array_to_feature(data, features[col]) if features is not None else data
        else:
            # Python values are encoded with the features (if any) and wrapped so that the Arrow
            # array is built with the requested type, or with an inferred one when no features exist.
            data = OptimizedTypedSequence(
                features.encode_column(data, col) if features is not None else data,
                type=features[col] if features is not None else None,
                col=col,
            )
        arrow_typed_mapping[col] = data
    pa_table = InMemoryTable.from_pydict(mapping=arrow_typed_mapping)
    # Work on a copy: the feature inference below assigns `info.features`, and doing that on the
    # caller's object would leak the inferred features into later calls reusing the same `DatasetInfo`.
    info = info.copy() if info is not None else DatasetInfo()
    info.features = features
    if info.features is None:
        # No features were provided anywhere: derive them from the Arrow types of the built columns.
        info.features = Features(
            {
                col: generate_from_arrow_type(data.type)
                if isinstance(data, (pa.Array, pa.ChunkedArray))
                else data.get_inferred_type()
                for col, data in arrow_typed_mapping.items()
            }
        )
    return cls(pa_table, info=info, split=split)
@classmethod
def from_list(
    cls,
    mapping: List[dict],
    features: Optional[Features] = None,
    info: Optional[DatasetInfo] = None,
    split: Optional[NamedSplit] = None,
) -> "Dataset":
    """
    Convert a list of dicts to a `pyarrow.Table` to create a [`Dataset`].

    Note that the keys of the first entry will be used to determine the dataset columns,
    regardless of what is passed to features. Later entries that lack one of those keys get `None`
    for it, and keys that only appear in later entries are ignored.
    If `mapping` is empty, the columns are taken from `features` (or `info.features`) when available,
    so the resulting empty dataset still has the declared columns.

    Args:
        mapping (`List[dict]`): A list of mappings of strings to row values.
        features (`Features`, optional): Dataset features.
        info (`DatasetInfo`, optional): Dataset information, like description, citation, etc.
        split (`NamedSplit`, optional): Name of the dataset split.

    Returns:
        [`Dataset`]
    """
    # for simplicity and consistency wrt OptimizedTypedSequence we do not use InMemoryTable.from_pylist here
    if mapping:
        columns = {k: [r.get(k) for r in mapping] for k in mapping[0]}
    else:
        # No rows to read column names from: fall back to the declared features so the empty
        # table's columns agree with the features attached to the dataset.
        declared = features if features is not None else info.features if info is not None else None
        columns = {k: [] for k in declared} if declared is not None else {}
    return cls.from_dict(columns, features, info, split)
@staticmethod
def from_csv(
path_or_paths: Union[PathLike, List[PathLike]],
split: Optional[NamedSplit] = None,
features: Optional[Features] = None,
cache_dir: str = None,
keep_in_memory: bool = False,
num_proc: Optional[int] = None,
**kwargs,
):
"""Create Dataset from CSV file(s).
Args:
path_or_paths (`path-like` or list of `path-like`):
Path(s) of the CSV file(s).
split ([`NamedSplit`], *optional*):
Split name to be assigned to the dataset.
features ([`Features`], *optional*):
Dataset features.
cache_dir (`str`, *optional*, defaults to `"~/.cache/huggingface/datasets"`):
Directory to cache data.
keep_in_memory (`bool`, defaults to `False`):
Whether to copy the data in-memory.
num_proc (`int`, *optional*, defaults to `None`):
Number of processes when downloading and generating the dataset locally.
This is helpful if the dataset is made of multiple files. Multiprocessing is disabled by default.
<Added version="2.8.0"/>
**kwargs (additional keyword arguments):
Keyword arguments to be passed to [`pandas.read_csv`].
Returns:
[`Dataset`]
Example:
```py
>>> ds = Dataset.from_csv('path/to/dataset.csv')
```
"""
# Dynamic import to avoid circular dependency
from .io.csv import CsvDatasetReader
return CsvDatasetReader(
path_or_paths,
split=split,
features=features,
cache_dir=cache_dir,