Fix iter_batches #5115
Conversation
The documentation is not available anymore as the PR was closed or merged.
I also ran the code in #5111 and it works fine now :)
This is ready for review :)
Thanks for the fix.
Just a few comments below.
```python
chunks_buffer_size = 0
for chunk in pa_table.to_reader(max_chunksize=batch_size):
    if len(chunk) == 0:
        continue
    elif chunks_buffer_size + len(chunk) < batch_size:
        chunks_buffer.append(chunk)
        chunks_buffer_size += len(chunk)
        continue
    elif chunks_buffer_size + len(chunk) == batch_size:
        chunks_buffer.append(chunk)
        yield pa.Table.from_batches(chunks_buffer)
        chunks_buffer = []
        chunks_buffer_size = 0
    else:
        cropped_chunk_length = batch_size - chunks_buffer_size
        chunks_buffer.append(chunk.slice(0, cropped_chunk_length))
        yield pa.Table.from_batches(chunks_buffer)
        chunks_buffer = [chunk.slice(cropped_chunk_length, len(chunk) - cropped_chunk_length)]
        chunks_buffer_size = len(chunk) - cropped_chunk_length
```
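For readers skimming the diff, here is a minimal, self-contained sketch of the buffering technique this hunk implements. The function name `iter_exact_batches` and the toy table are made up for illustration; the loop body mirrors the hunk above, plus the trailing flush that appears further down in the diff:

```python
import pyarrow as pa

def iter_exact_batches(pa_table: pa.Table, batch_size: int):
    # Coalesce the (possibly tiny) chunks yielded by to_reader() into
    # batches of exactly batch_size rows; the last batch may be smaller.
    chunks_buffer = []
    chunks_buffer_size = 0
    for chunk in pa_table.to_reader(max_chunksize=batch_size):
        if len(chunk) == 0:
            continue
        elif chunks_buffer_size + len(chunk) < batch_size:
            chunks_buffer.append(chunk)  # not enough rows yet: keep buffering
            chunks_buffer_size += len(chunk)
        elif chunks_buffer_size + len(chunk) == batch_size:
            chunks_buffer.append(chunk)  # exactly full: flush one batch
            yield pa.Table.from_batches(chunks_buffer)
            chunks_buffer = []
            chunks_buffer_size = 0
        else:
            # Overfull: split the chunk, flush a full batch, keep the rest.
            cropped = batch_size - chunks_buffer_size
            chunks_buffer.append(chunk.slice(0, cropped))
            yield pa.Table.from_batches(chunks_buffer)
            chunks_buffer = [chunk.slice(cropped, len(chunk) - cropped)]
            chunks_buffer_size = len(chunk) - cropped
    if chunks_buffer:  # the drop_last_batch=False path in the PR
        yield pa.Table.from_batches(chunks_buffer)

# Worst case: a table stored as ten 1-row chunks, which to_reader() alone
# would hand back as 1-row batches.
table = pa.concat_tables([pa.table({"a": [i]}) for i in range(10)])
print([len(b) for b in iter_exact_batches(table, batch_size=4)])  # [4, 4, 2]
```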
Maybe we can remove the variable `chunks_buffer_size`:
Suggested change:

```diff
-chunks_buffer_size = 0
 for chunk in pa_table.to_reader(max_chunksize=batch_size):
     if len(chunk) == 0:
         continue
-    elif chunks_buffer_size + len(chunk) < batch_size:
+    elif len(chunks_buffer) + len(chunk) < batch_size:
         chunks_buffer.append(chunk)
-        chunks_buffer_size += len(chunk)
         continue
-    elif chunks_buffer_size + len(chunk) == batch_size:
+    elif len(chunks_buffer) + len(chunk) == batch_size:
         chunks_buffer.append(chunk)
         yield pa.Table.from_batches(chunks_buffer)
         chunks_buffer = []
-        chunks_buffer_size = 0
     else:
-        cropped_chunk_length = batch_size - chunks_buffer_size
+        cropped_chunk_length = batch_size - len(chunks_buffer)
         chunks_buffer.append(chunk.slice(0, cropped_chunk_length))
         yield pa.Table.from_batches(chunks_buffer)
         chunks_buffer = [chunk.slice(cropped_chunk_length, len(chunk) - cropped_chunk_length)]
-        chunks_buffer_size = len(chunk) - cropped_chunk_length
```
`chunks_buffer_size` is the sum of the lengths of all the chunks in the buffer, not just the length of the buffer (i.e. the number of chunks), so the two are not interchangeable.
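A quick illustration of the difference, with made-up values:

```python
import pyarrow as pa

# A hypothetical buffer holding two chunks of 3 rows each:
chunks_buffer = [
    pa.RecordBatch.from_pydict({"a": [1, 2, 3]}),
    pa.RecordBatch.from_pydict({"a": [4, 5, 6]}),
]
print(len(chunks_buffer))                  # 2 -> number of chunks
print(sum(len(c) for c in chunks_buffer))  # 6 -> rows buffered, i.e. chunks_buffer_size
```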
```python
        chunks_buffer = [chunk.slice(cropped_chunk_length, len(chunk) - cropped_chunk_length)]
        chunks_buffer_size = len(chunk) - cropped_chunk_length
if not drop_last_batch and chunks_buffer:
    yield pa.Table.from_batches(chunks_buffer)
```
I'm just wondering if this function may have a performance impact, compared with just calling `for batch in self.data.to_reader(max_chunksize=batch_size)` as before.
If so, we should check how much impact, so that we do not lose the performance gain introduced by #5030.
The code is roughly the same as in #5030.
Also note that the worst-case scenario for this implementation is when the dataset is made of chunks of length 1, but even in that case it is faster than calling `__getitem__` for each item:
```python
from datasets import Dataset, concatenate_datasets

# 100 rows stored as 100 one-row chunks (the worst case)
ds = concatenate_datasets([Dataset.from_dict({"a": [i]}) for i in range(100)])

%time list(ds._iter_batches(batch_size=10))
# <1ms
%time [ds[i:i+10] for i in range(0, len(ds), 10)]
# 1ms
%time list(ds)
# 3ms
%time [ds[i] for i in range(len(ds))]
# 5ms
```
It's even better for big datasets, since `__getitem__` is not O(1): it does an interpolation search to find the right chunk. Here, getting the next item is O(1).
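For intuition, a hedged sketch of why random access pays an extra cost per item: `__getitem__` must first locate the chunk containing row `i` (the library uses interpolation search; plain bisection is used below for simplicity), whereas sequential iteration just walks the chunks in order:

```python
import bisect

# Toy layout: 100 one-row chunks, like the benchmark dataset above.
chunk_lengths = [1] * 100
cumulative = []  # cumulative end offsets: [1, 2, ..., 100]
total = 0
for n in chunk_lengths:
    total += n
    cumulative.append(total)

def locate(i: int):
    # Cost paid on every random access: find which chunk holds row i.
    chunk_idx = bisect.bisect_right(cumulative, i)
    row_in_chunk = i - (cumulative[chunk_idx - 1] if chunk_idx else 0)
    return chunk_idx, row_in_chunk

print(locate(42))  # (42, 0): row 42 lives at position 0 of chunk 42
```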
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Thanks, LGTM!
The `pa.Table.to_reader()` method available in `pyarrow>=8.0.0` may return chunks of size < `max_chunksize`, therefore `iter_batches` can return batches smaller than the `batch_size` specified by the user. As a result, batched `map` couldn't always use batches of the right size, e.g. it could fail because it ran only on one batch of one element. This was introduced in #5030.
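The original repro snippet was not preserved in this thread; the following is a hedged reconstruction of the failure mode described, where the dataset, batch size, and assertion are illustrative:

```python
from datasets import Dataset, concatenate_datasets

# Illustrative: a dataset whose Arrow table is stored as ten 1-row chunks.
ds = concatenate_datasets([Dataset.from_dict({"a": [i]}) for i in range(10)])

def check_full_batch(batch):
    # Before the fix, to_reader() passed the 1-row chunks through as-is,
    # so this could receive a batch of 1 element instead of the requested 10.
    assert len(batch["a"]) == 10
    return batch

ds = ds.map(check_full_batch, batched=True, batch_size=10)
```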
Close #5111
This will require a patch release along with #5113
TODO: