PyTorch
=======

The Squirrel API is designed to support fast streaming of datasets to a multi-rank, distributed system, as often encountered in modern deep learning applications involving multiple GPUs. To this end, we can use :py:meth:`split_by_rank_pytorch` and :py:meth:`split_by_worker_pytorch` and wrap the final iterator in a torch ``DataLoader`` object:

.. code-block:: python

    import torch.utils.data as tud

    from squirrel.iterstream.source import IterableSource


    def times_two(x: float) -> float:
        return x * 2


    samples = list(range(100))
    batch_size = 5
    num_workers = 4

    it = (
        IterableSource(samples)
        .split_by_rank_pytorch()
        .async_map(times_two)
        .split_by_worker_pytorch()
        .batched(batch_size)
        .to_torch_iterable()
    )
    dl = tud.DataLoader(it, num_workers=num_workers)

Note that the rank of the distributed system depends on the torch distributed process group and is automatically determined.
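For illustration, a minimal sketch of a process group setup under which the rank becomes available is shown below; the ``gloo`` backend and the ``RANK``/``WORLD_SIZE`` environment variables are assumptions of this sketch (e.g. as set by a launcher such as ``torchrun``), not requirements of Squirrel:

.. code-block:: python

    import os

    import torch.distributed as dist

    # Assumed setup: RANK and WORLD_SIZE are provided by the launcher.
    # Once the process group is initialized, split_by_rank_pytorch
    # determines the rank from it automatically.
    dist.init_process_group(
        backend="gloo",
        rank=int(os.environ.get("RANK", "0")),
        world_size=int(os.environ.get("WORLD_SIZE", "1")),
    )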

.. note::

    :py:meth:`split_by_rank_pytorch`, :py:meth:`split_by_worker_pytorch`, and :py:meth:`to_torch_iterable` are simply convenience methods that chain your iterator with PyTorch-specific iterators. These are implemented as dedicated Composables. An example of such a PyTorch-specific Composable is :py:class:`SplitByWorker`, used below. To see how to chain Composables, see :ref:`advanced/iterstream:Custom Composable`.
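As a sketch of what chaining a custom Composable can look like, consider the example below; the ``TimesTwo`` class and its doubling logic are purely illustrative and not part of Squirrel, only the :py:class:`Composable` base class and :py:meth:`compose` are assumed from the Squirrel API:

.. code-block:: python

    from typing import Iterator

    from squirrel.iterstream.base import Composable
    from squirrel.iterstream.source import IterableSource


    class TimesTwo(Composable):
        """Illustrative Composable that doubles every item of its source."""

        def __iter__(self) -> Iterator:
            for item in self.source:
                yield item * 2


    it = IterableSource(range(5)).compose(TimesTwo)
    print(list(it))  # [0, 2, 4, 6, 8]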

And using the :py:mod:`squirrel.driver` API:

.. code-block:: python

    from torch.utils.data import DataLoader

    from squirrel.driver import MessagepackDriver
    from squirrel.iterstream.torch_composables import SplitByWorker, TorchIterable

    url = ""
    it = (
        MessagepackDriver(url)
        .get_iter(key_hooks=[SplitByWorker])
        .async_map(times_two)
        .batched(batch_size)
        .compose(TorchIterable)
    )
    dl = DataLoader(it, num_workers=num_workers)

In this example, ``key_hooks=[SplitByWorker]`` ensures that keys are split between workers before the data is fetched. This way we achieve two levels of parallelism: multi-processing provided by :py:class:`torch.utils.data.DataLoader`, and multi-threading inside each worker process for efficiently fetching samples via :py:meth:`get_iter`.
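A brief usage sketch of the resulting loader follows; the loop body is a placeholder for the actual training step:

.. code-block:: python

    # Each DataLoader worker process iterates over its own shard of keys
    # (via SplitByWorker) and fetches the corresponding samples with
    # multi-threading inside get_iter/async_map.
    for batch in dl:
        ...  # placeholder for the training step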