Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support pyarrow LargeListType #6835

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/datasets/features/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -1418,7 +1418,7 @@ def generate_from_arrow_type(pa_type: pa.DataType) -> FeatureType:
return {field.name: generate_from_arrow_type(field.type) for field in pa_type}
elif isinstance(pa_type, pa.FixedSizeListType):
return Sequence(feature=generate_from_arrow_type(pa_type.value_type), length=pa_type.list_size)
elif isinstance(pa_type, pa.ListType):
elif isinstance(pa_type, (pa.ListType, pa.LargeListType)):
feature = generate_from_arrow_type(pa_type.value_type)
if isinstance(feature, (dict, tuple, list)):
return [feature]
Expand Down
16 changes: 10 additions & 6 deletions src/datasets/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1811,13 +1811,13 @@ def _are_list_values_of_length(array: pa.ListArray, length: int) -> bool:
return pc.all(pc.equal(array.value_lengths(), length)).as_py() or array.null_count == len(array)


def _combine_list_array_offsets_with_mask(array: pa.ListArray) -> pa.Array:
"""Add the null bitmap to the offsets of a `pa.ListArray`."""
def _combine_list_array_offsets_with_mask(array: Union[pa.ListArray, pa.LargeListArray]) -> pa.Array:
"""Add the null bitmap to the offsets of a `pa.ListArray` or `pa.LargeListArray`."""
offsets = array.offsets
if array.null_count > 0:
offsets = pa.concat_arrays(
[
pc.replace_with_mask(offsets[:-1], array.is_null(), pa.nulls(len(array), pa.int32())),
pc.replace_with_mask(offsets[:-1], array.is_null(), pa.nulls(len(array), offsets.type)),
offsets[-1:],
]
)
Expand Down Expand Up @@ -2012,7 +2012,7 @@ def cast_array_to_feature(
return array
arrays = [_c(array.field(name), subfeature) for name, subfeature in feature.items()]
return pa.StructArray.from_arrays(arrays, names=list(feature), mask=array.is_null())
elif pa.types.is_list(array.type):
elif pa.types.is_list(array.type) or pa.types.is_large_list(array.type):
# feature must be either [subfeature] or Sequence(subfeature)
if isinstance(feature, list):
casted_array_values = _c(array.values, feature[0])
Expand All @@ -2021,7 +2021,9 @@ def cast_array_to_feature(
else:
# Merge offsets with the null bitmap to avoid the "Null bitmap with offsets slice not supported" ArrowNotImplementedError
array_offsets = _combine_list_array_offsets_with_mask(array)
return pa.ListArray.from_arrays(array_offsets, casted_array_values)
if pa.types.is_list(array.type):
return pa.ListArray.from_arrays(array_offsets, casted_array_values)
return pa.LargeListArray.from_arrays(array_offsets, casted_array_values)
elif isinstance(feature, Sequence):
if feature.length > -1:
if _are_list_values_of_length(array, feature.length):
Expand Down Expand Up @@ -2071,7 +2073,9 @@ def cast_array_to_feature(
else:
# Merge offsets with the null bitmap to avoid the "Null bitmap with offsets slice not supported" ArrowNotImplementedError
array_offsets = _combine_list_array_offsets_with_mask(array)
return pa.ListArray.from_arrays(array_offsets, casted_array_values)
if pa.types.is_list(array.type):
return pa.ListArray.from_arrays(array_offsets, casted_array_values)
return pa.LargeListArray.from_arrays(array_offsets, casted_array_values)
elif pa.types.is_fixed_size_list(array.type):
# feature must be either [subfeature] or Sequence(subfeature)
if isinstance(feature, list):
Expand Down