Closed
Description
It’s easy to leak memory when using ThreadPrefetchIterDataset (and more so when combined with grain.experimental.device_put).
- When the iterator (ThreadPrefetchDatasetIterator) is not explicitly .close()d, its worker thread never exits and continues consuming memory.
- If new iterators are repeatedly created in a loop (e.g., re-creating a validation iterator every N training steps), memory usage grows without bound.
Even worse, when device_put (or generally, any nested ThreadPrefetchIterDataset) is involved:
- Calling .close() on the final GPU ThreadPrefetchDatasetIterator does not close its parent CPU iterator.
- There seems to be no public API to properly close these parent iterators.
This makes it very difficult to use ThreadPrefetchIterDataset ergonomically in practice.
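The leak mechanism described above can be reproduced with the standard library alone. The sketch below is a toy stand-in, not Grain's implementation: `ToyPrefetchIterator`, `live_workers`, and `leak_then_close` are hypothetical names, and the assumption is simply that a prefetch worker blocks on a full buffer until someone tells it to stop.

```python
import contextlib
import itertools
import queue
import threading

_END = object()  # sentinel the worker enqueues when the source is exhausted


class ToyPrefetchIterator:
    """Toy thread-prefetching iterator (hypothetical, not Grain's code)."""

    def __init__(self, source, buffer_size=2):
        self._source = iter(source)
        self._buffer = queue.Queue(maxsize=buffer_size)
        self._closed = threading.Event()
        self._exhausted = False
        self._thread = threading.Thread(
            target=self._work, name="toy-prefetch", daemon=True
        )
        self._thread.start()

    def _offer(self, item):
        # Retry with a timeout so the worker can notice a close request
        # even while the buffer is full.
        while not self._closed.is_set():
            try:
                self._buffer.put(item, timeout=0.05)
                return True
            except queue.Full:
                pass
        return False

    def _work(self):
        for item in self._source:
            if not self._offer(item):
                return  # closed: stop producing and let the thread exit
        self._offer(_END)

    def __iter__(self):
        return self

    def __next__(self):
        while not self._exhausted and not self._closed.is_set():
            try:
                item = self._buffer.get(timeout=0.05)
            except queue.Empty:
                continue
            if item is _END:
                self._exhausted = True
                break
            return item
        raise StopIteration

    def close(self):
        # Without this call, a worker over an endless source never exits.
        self._closed.set()
        self._thread.join()


def live_workers():
    """Counts toy prefetch worker threads that are still running."""
    return sum(t.name == "toy-prefetch" for t in threading.enumerate())


def leak_then_close(n):
    """Creates n iterators, reads one item each, then closes them all."""
    baseline = live_workers()
    iterators = [ToyPrefetchIterator(itertools.count()) for _ in range(n)]
    for it in iterators:
        next(it)
    before = live_workers() - baseline  # workers parked on a full buffer
    for it in iterators:
        it.close()
    after = live_workers() - baseline
    return before, after
```

In this toy, `contextlib.closing(ToyPrefetchIterator(...))` is enough to guarantee cleanup for a single stage. That only helps when every stage's iterator is reachable, which is exactly what the nested case above lacks.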
Example
val_dataset = MapDataset.source(val_dataset)
val_dataset = val_dataset.map(transform)
val_dataset = val_dataset.shuffle(seed=seed)
val_dataset = val_dataset.batch(batch_size, drop_remainder=True)
val_dataset = val_dataset.to_iter_dataset(ReadOptions(num_threads=0, prefetch_buffer_size=0))
val_dataset = ThreadPrefetchIterDataset(val_dataset, prefetch_buffer_size=4)
val_dataset = val_dataset.map(lambda x: jax.device_put(x, sharding))
val_dataset = ThreadPrefetchIterDataset(val_dataset, prefetch_buffer_size=2)
for i in range(100000):
    if i % 500 == 0:
        for batch in tqdm.tqdm(val_dataset):
            ...

Workaround
Currently, I avoid the leak by:
- Maintaining a single infinite ThreadPrefetchDatasetIterator per data source (via .repeat()).
- Wrapping it with an iterator that tracks epoch boundaries and raises StopIteration manually.
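A minimal sketch of such a wrapper, assuming the number of batches per epoch is known up front. `EpochView` and `steps_per_epoch` are hypothetical names, and `itertools.cycle` stands in for the single long-lived, repeated prefetch iterator:

```python
import itertools


class EpochView:
    """Iterable that exposes fixed-length epochs over one long-lived iterator.

    Hypothetical helper: the underlying iterator is created once and never
    re-created, so no new worker threads are spawned per validation pass.
    """

    def __init__(self, infinite_iterator, steps_per_epoch):
        self._it = infinite_iterator
        self._steps = steps_per_epoch

    def __iter__(self):
        # islice pulls exactly `steps_per_epoch` items, then stops; the
        # next call to __iter__ resumes where the previous epoch ended.
        return itertools.islice(self._it, self._steps)


# Stand-in for iter(dataset.repeat()) wrapped in thread prefetching.
long_lived = itertools.cycle(range(3))
val_epochs = EpochView(long_lived, steps_per_epoch=3)

first_epoch = list(val_epochs)
second_epoch = list(val_epochs)
```

One caveat of this design: breaking out of an epoch early leaves the underlying iterator mid-epoch, so later epochs no longer line up with the data's boundaries.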
Questions
- Am I misusing the API? (e.g., are iterators not meant to be created frequently?)
- Could the ergonomics be improved? For example:
  - Automatically closing parent iterators when the child is closed.
  - More explicit guidance in the docs on lifecycle management.
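To make the first suggestion concrete, here is a hedged sketch of what close propagation could look like. `ToyStage` is a hypothetical class for illustration only and does not reflect Grain's actual iterator classes or API:

```python
class ToyStage:
    """Hypothetical pipeline stage; `parent` is the upstream stage, if any."""

    def __init__(self, name, parent=None):
        self.name = name
        self.parent = parent
        self.closed = False

    def close(self):
        if self.closed:  # idempotent: closing twice is harmless
            return
        self.closed = True
        # A real stage would stop and join its worker thread here.
        if self.parent is not None:
            self.parent.close()  # propagate the close upstream


cpu_stage = ToyStage("cpu_prefetch")
gpu_stage = ToyStage("device_put_prefetch", parent=cpu_stage)
gpu_stage.close()
```

With this shape, closing only the outermost iterator releases every upstream worker, so callers never need a handle to the parent stages.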