Memory leak with grain.experimental.device_put and ThreadPrefetchIterDataset #1021

@jjyyxx

Description

It’s easy to leak memory when using ThreadPrefetchIterDataset (and more so when combined with grain.experimental.device_put).

  • When the iterator (ThreadPrefetchDatasetIterator) is not explicitly .close()d, its worker thread never exits and continues consuming memory.
  • If new iterators are repeatedly created in a loop (e.g., re-creating a validation iterator every N training steps), memory usage grows without bound.
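The mechanism in the two bullets above can be reproduced without Grain. The following is a minimal sketch using only the standard library; `ToyPrefetchIterator` is a hypothetical stand-in, not Grain's implementation, and the endless source is chosen only to make the effect visible. The worker thread and its buffer stay alive until something calls `close()`.

```python
import itertools
import queue
import threading

_DONE = object()  # sentinel marking the end of the source


class ToyPrefetchIterator:
    """Toy thread-prefetching iterator (hypothetical, not Grain's code)."""

    def __init__(self, source, buffer_size=2):
        self._items = itertools.chain(source, [_DONE])
        self._buffer = queue.Queue(maxsize=buffer_size)
        self._stop = threading.Event()
        self._finished = False
        self._thread = threading.Thread(target=self._produce, daemon=True)
        self._thread.start()

    def _produce(self):
        for item in self._items:
            # Retry the put so a stop request is noticed while the buffer is full.
            while not self._stop.is_set():
                try:
                    self._buffer.put(item, timeout=0.05)
                    break
                except queue.Full:
                    pass
            else:
                return  # stop was requested

    def __iter__(self):
        return self

    def __next__(self):
        if self._finished:
            raise StopIteration
        item = self._buffer.get()
        if item is _DONE:
            self._finished = True
            raise StopIteration
        return item

    def close(self):
        self._stop.set()
        self._thread.join()

    def worker_alive(self):
        return self._thread.is_alive()


# An iterator that is dropped without close() keeps its worker and buffer.
abandoned = ToyPrefetchIterator(itertools.count())
first_abandoned = next(abandoned)

# An iterator that is closed releases its worker.
closed = ToyPrefetchIterator(itertools.count())
first_closed = next(closed)
closed.close()
```

Creating an `abandoned`-style iterator every N steps adds one more live thread and one more buffer each time, which matches the unbounded growth described above.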

It gets worse when grain.experimental.device_put (or, more generally, any nested ThreadPrefetchIterDataset) is involved:

  • Calling .close() on the final GPU ThreadPrefetchDatasetIterator does not close its parent CPU iterator.
  • There seems to be no public API to properly close these parent iterators.

This makes it very difficult to use ThreadPrefetchIterDataset ergonomically in practice.

Example

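# Import paths are assumed and may differ between Grain versions.
import jax
import tqdm
from grain import MapDataset, ReadOptions
from grain.experimental import ThreadPrefetchIterDataset

# transform, seed, batch_size, and sharding are defined elsewhere.
val_dataset = MapDataset.source(val_dataset)
val_dataset = val_dataset.map(transform)
val_dataset = val_dataset.shuffle(seed=seed)
val_dataset = val_dataset.batch(batch_size, drop_remainder=True)
val_dataset = val_dataset.to_iter_dataset(ReadOptions(num_threads=0, prefetch_buffer_size=0))
# CPU-side prefetch thread.
val_dataset = ThreadPrefetchIterDataset(val_dataset, prefetch_buffer_size=4)
val_dataset = val_dataset.map(lambda x: jax.device_put(x, sharding))
# Device-side prefetch thread, nested on top of the CPU one.
val_dataset = ThreadPrefetchIterDataset(val_dataset, prefetch_buffer_size=2)

for i in range(100000):
    # (training step omitted)
    if i % 500 == 0:
        # Each pass creates a new iterator chain that is never explicitly closed.
        for batch in tqdm.tqdm(val_dataset):
            ...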
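For the non-nested case, one option is to hold on to the iterator and close it when the validation pass ends. The helper below is a minimal sketch; `closing_iter` is a hypothetical name, and it relies only on the iterator exposing a `.close()` method as described above. Per the report, this would still not reach the parent CPU iterator in a nested chain.

```python
from contextlib import contextmanager


@contextmanager
def closing_iter(dataset):
    """Iterate over `dataset` and close the iterator on exit, if it supports close()."""
    it = iter(dataset)
    try:
        yield it
    finally:
        close = getattr(it, "close", None)
        if callable(close):
            close()
```

Usage in the loop above would look like `with closing_iter(val_dataset) as it:` followed by `for batch in tqdm.tqdm(it): ...`, so the iterator is closed even if the pass exits early or raises.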

Workaround

Currently, I avoid the leak by:

  • Maintaining a single infinite ThreadPrefetchDatasetIterator per data source (via .repeat()).
  • Wrapping it with an iterator that tracks epoch boundaries and raises StopIteration manually.
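A minimal sketch of that workaround, in plain Python: `EpochView` and `steps_per_epoch` are hypothetical names, and the number of batches per epoch is assumed to be known up front (for example, from the length of the batched dataset before `.repeat()`).

```python
class EpochView:
    """Yields `steps_per_epoch` items per pass from one shared, endless iterator."""

    def __init__(self, endless_iterator, steps_per_epoch):
        self._it = endless_iterator
        self._steps = steps_per_epoch

    def __iter__(self):
        for _ in range(self._steps):
            try:
                yield next(self._it)
            except StopIteration:
                # The source was not endless after all; end this pass early.
                return
```

The endless iterator is created once, e.g. `val_iter = iter(val_dataset)` on a repeated dataset, and each validation pass does `for batch in EpochView(val_iter, steps_per_epoch): ...`. Only one prefetch thread chain exists for the lifetime of the program, so nothing accumulates.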

Questions

  1. Am I misusing the API? (e.g., are iterators not meant to be created frequently?)
  2. Could the ergonomics be improved? For example:
    • Automatically closing parent iterators when the child is closed.
    • More explicit guidance in the docs on lifecycle management.
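To illustrate the first suggestion, here is a sketch of what close propagation could look like. `Stage` is a hypothetical class written for this example and does not reflect how Grain's iterators are structured; the point is only that each stage keeps a reference to the stage it reads from and closes it after closing itself.

```python
class Stage:
    """Hypothetical iterator stage that closes its parent when it is closed."""

    def __init__(self, name, parent, log):
        self.name = name
        self._parent = parent
        self._log = log
        self.closed = False

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._parent)

    def close(self):
        if self.closed:
            return  # closing twice is harmless
        self.closed = True
        self._log.append(self.name)
        # Close the child first, then walk up to the parent if it can be closed.
        parent_close = getattr(self._parent, "close", None)
        if callable(parent_close):
            parent_close()


close_order = []
cpu_stage = Stage("cpu_prefetch", iter(range(3)), close_order)
device_stage = Stage("device_prefetch", cpu_stage, close_order)
device_stage.close()
```

With this shape, a single `close()` on the outermost iterator is enough, and the child stops reading before its parent shuts down.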

Metadata

Labels: type:bug