Elastic world size deterministic shuffle with mid-epoch resumption (#37)
knighton committed Dec 7, 2022
1 parent 594f1cb commit c9c8273
Showing 48 changed files with 2,046 additions and 1,159 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/install.yaml
@@ -17,9 +17,9 @@ jobs:
    strategy:
      matrix:
        python_version:
          - "3.7"
          - "3.8"
          - "3.9"
          - "3.10"
        install_version:
          - ""
    steps:
4 changes: 2 additions & 2 deletions STYLE_GUIDE.md
@@ -208,9 +208,9 @@ For example, from [streaming/base/dataset.py](streaming/base/dataset.py)
from torch.utils.data import IterableDataset

from streaming.base.format import reader_from_json
from streaming.base.index import Index, Partition
from streaming.base.index import Index

__all__ = ["Dataset"] # export only the Dataset, not other imports like `Index`, `Partition`, or `reader_from_json`
__all__ = ["Dataset"] # export only the Dataset, not other imports like `Index` or `reader_from_json`


class Dataset(IterableDataset):
6 changes: 3 additions & 3 deletions docs/source/getting_started/quick_start.md
@@ -48,18 +48,18 @@ Start training your model with the Streaming dataset in a few steps!
$ aws s3 cp dirname s3://mybucket/myfolder --recursive
```

3. Replace the original {class}`torch.utils.data.IterableDataset` with your new {class}`streaming.Dataset`.
3. Replace the original {class}`torch.utils.data.IterableDataset` with your new {class}`streaming.StreamingDataset`.
<!--pytest.mark.skip-->
```python
from torch.utils.data import DataLoader
from streaming import Dataset
from streaming import StreamingDataset

# Remote directory (S3 or local filesystem) where dataset is stored
remote_dir = 's3://datapath'

# Local directory where dataset is cached during operation
local_dir = 'local_dir'
dataset = Dataset(local=local_dir, remote=remote_dir, split=None, shuffle=True)
dataset = StreamingDataset(local=local_dir, remote=remote_dir, split=None, shuffle=True)

# Create PyTorch DataLoader
dataloader = DataLoader(dataset)
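# A hedged follow-up (illustrative, not part of the original snippet): iterate the
# DataLoader as usual; the actual training step is omitted here.
for batch in dataloader:
    pass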
14 changes: 7 additions & 7 deletions docs/source/getting_started/user_guide.md
@@ -1,6 +1,6 @@
# 🖼️ User Guide

At a very high level, one needs to convert a raw dataset into streaming format files and then load those streaming format files with the {class}`streaming.Dataset` class for model training.
At a very high level, one needs to convert a raw dataset into streaming format files and then load those streaming format files with the {class}`streaming.StreamingDataset` class for model training.

Streaming supports different dataset writers, based on your needs, for converting raw datasets into a streaming format, such as
- {class}`streaming.MDSWriter`: Writes the dataset to files with the `.mds` (Mosaic Data Shard) extension. It supports various encoding/decoding formats (`str`, `int`, `bytes`, `jpeg`, `png`, `pil`, `pkl`, and `json`) which convert the data from that format to bytes and vice versa.
@@ -11,15 +11,15 @@ Streaming supports different dataset writers based on your need for conversion o

For more information about writers and their parameters, look at the [API reference doc](../api_reference/streaming.rst).

After the dataset has been converted to one of our streaming formats, one just needs to instantiate the {class}`streaming.Dataset` class by providing the path to the streaming dataset, and use that dataset object with the PyTorch {class}`torch.utils.data.DataLoader` class. For more information about `streaming.Dataset` and its parameters, look at the {class}`streaming.Dataset` API reference doc.
After the dataset has been converted to one of our streaming formats, one just needs to instantiate the {class}`streaming.StreamingDataset` class by providing the path to the streaming dataset, and use that dataset object with the PyTorch {class}`torch.utils.data.DataLoader` class. For more information about `streaming.StreamingDataset` and its parameters, look at the {class}`streaming.StreamingDataset` API reference doc.

Streaming supports various dataset compression formats (Brotli, Bzip2, Gzip, Snappy, and Zstandard) that reduce download time and cloud egress fees. Additionally, Streaming supports various hashing algorithms (SHA2, SHA3, MD5, xxHash, etc.) that ensure data integrity through cryptographic and non-cryptographic hashing.

Let's jump right into an example on how to convert a raw dataset into a streaming format and load the same streaming format dataset for model training.

## Writing a dataset to streaming format

This guide shows you how to use your custom Dataset with {class}`streaming.MDSWriter`, but the steps remain the same for other writers.
This guide shows you how to use your custom StreamingDataset with {class}`streaming.MDSWriter`, but the steps remain the same for other writers.

The {class}`streaming.MDSWriter` takes the raw dataset and converts it into a sharded `.mds` format for fast data access.

@@ -133,12 +133,12 @@ $ aws s3 cp dirname s3://mybucket/myfolder --recursive

After writing a dataset in the streaming format in the previous step and uploading it to a cloud object store such as S3, we are ready to start loading the data.

To load the same dataset files that were created in the above steps, create a `CustomDataset` class that inherits from the {class}`streaming.Dataset` class and overrides the `__getitem__(idx: int)` method to get the samples. The {class}`streaming.Dataset` class requires two mandatory parameters: `remote`, the remote directory (S3 or local filesystem) where the dataset is stored, and `local`, the local directory where the dataset is cached during operation.
To load the same dataset files that were created in the above steps, create a `CustomDataset` class that inherits from the {class}`streaming.StreamingDataset` class and overrides the `__getitem__(idx: int)` method to get the samples. The {class}`streaming.StreamingDataset` class requires two mandatory parameters: `remote`, the remote directory (S3 or local filesystem) where the dataset is stored, and `local`, the local directory where the dataset is cached during operation.
<!--pytest-codeblocks:cont-->
```python
from streaming.base import Dataset
from streaming import StreamingDataset

class CustomDataset(Dataset):
class CustomDataset(StreamingDataset):
    def __init__(self, local, remote):
        super().__init__(local, remote)

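    # A hedged sketch of the override described above (the full example is collapsed
    # in this diff view): fetch the raw sample and return it, applying any custom
    # processing here. The pass-through body is an illustrative assumption.
    def __getitem__(self, idx: int):
        obj = super().__getitem__(idx)
        return obj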
@@ -171,4 +171,4 @@ You've now seen an in-depth look at how to prepare and use streaming datasets wi

## Other options

Please look at the API reference page for the complete list of {class}`streaming.Dataset` supported parameters.
Please look at the API reference page for the complete list of {class}`streaming.StreamingDataset` supported parameters.
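For context, a minimal sketch of wiring the `CustomDataset` above into a PyTorch `DataLoader`; the paths and batch size are illustrative assumptions, and the collapsed portion of the guide covers this step in full.

```python
from torch.utils.data import DataLoader

# Hypothetical remote and local paths; see the quick start for parameter details.
remote_dir = 's3://mybucket/myfolder'
local_dir = 'local_dir'
dataset = CustomDataset(local=local_dir, remote=remote_dir)

dataloader = DataLoader(dataset, batch_size=32, num_workers=8)
for batch in dataloader:
    pass  # training step goes here
```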
2 changes: 1 addition & 1 deletion scripts/compression/bench.py
@@ -9,7 +9,7 @@

import numpy as np

from streaming.base.compression.compression import compress, decompress, get_compressions
from streaming.base.compression import compress, decompress, get_compressions


def parse_args() -> Namespace:
2 changes: 1 addition & 1 deletion scripts/hashing/bench.py
@@ -9,7 +9,7 @@

import numpy as np

from streaming.base.hashing.hashing import get_hash, get_hashes
from streaming.base.hashing import get_hash, get_hashes


def parse_args() -> Namespace:
2 changes: 1 addition & 1 deletion setup.py
@@ -36,9 +36,9 @@

classifiers = [
    'Programming Language :: Python :: 3',
    'Programming Language :: Python :: 3.7',
    'Programming Language :: Python :: 3.8',
    'Programming Language :: Python :: 3.9',
    'Programming Language :: Python :: 3.10',
]

install_requires = [
8 changes: 4 additions & 4 deletions streaming/__init__.py
@@ -6,10 +6,10 @@
import streaming.text as text
import streaming.vision as vision
from streaming._version import __version__
from streaming.base import (CSVWriter, Dataset, JSONWriter, LocalDataset, MDSWriter, TSVWriter,
                            XSVWriter)
from streaming.base import (CSVWriter, JSONWriter, LocalDataset, MDSWriter, StreamingDataLoader,
                            StreamingDataset, TSVWriter, XSVWriter)

__all__ = [
    'Dataset', 'CSVWriter', 'JSONWriter', 'MDSWriter', 'TSVWriter', 'XSVWriter', 'LocalDataset',
    'vision', 'text'
    'StreamingDataLoader', 'StreamingDataset', 'CSVWriter', 'JSONWriter', 'MDSWriter', 'TSVWriter',
    'XSVWriter', 'LocalDataset', 'vision', 'text'
]
6 changes: 4 additions & 2 deletions streaming/base/__init__.py
@@ -3,10 +3,12 @@

"""MosaicML Streaming Datasets for cloud-native model training."""

from streaming.base.dataset import Dataset
from streaming.base.dataloader import StreamingDataLoader
from streaming.base.dataset import StreamingDataset
from streaming.base.format import CSVWriter, JSONWriter, MDSWriter, TSVWriter, XSVWriter
from streaming.base.local import LocalDataset

__all__ = [
    'Dataset', 'CSVWriter', 'JSONWriter', 'LocalDataset', 'MDSWriter', 'TSVWriter', 'XSVWriter'
    'StreamingDataLoader', 'StreamingDataset', 'CSVWriter', 'JSONWriter', 'LocalDataset',
    'MDSWriter', 'TSVWriter', 'XSVWriter'
]
File renamed without changes.
12 changes: 0 additions & 12 deletions streaming/base/compression/__init__.py

This file was deleted.

92 changes: 92 additions & 0 deletions streaming/base/dataloader.py
@@ -0,0 +1,92 @@
# Copyright 2022 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

"""Streaming DataLoader."""

from typing import Any, Dict, Iterator, Optional

from torch import Tensor
from torch.utils.data import DataLoader
from transformers.tokenization_utils_base import BatchEncoding

from streaming.base.dataset import StreamingDataset
from streaming.base.world import World


class StreamingDataLoader(DataLoader):
    """A streaming data loader.

    Provides an additional checkpoint/resumption interface, for which it tracks the number of
    samples seen by the model on this rank.

    Args:
        *args: List arguments.
        **kwargs: Keyword arguments.
    """

    def __init__(self, *args, **kwargs) -> None:  # pyright: ignore
        super().__init__(*args, **kwargs)
        self.num_samples_yielded = 0

    def _get_batch_size(self, batch: Any) -> int:
        """Get the number of samples in a batch.

        Args:
            batch (Any): The batch.

        Returns:
            int: Number of samples.
        """
        if isinstance(batch, (dict, BatchEncoding)):
            # Use the length of the first value (all values share the batch dimension).
            for value in batch.values():
                return len(value)
            raise ValueError('Batch is empty')
        elif isinstance(batch, Tensor):
            return len(batch)
        else:
            return len(batch[0])

    def __iter__(self) -> Iterator[Any]:
        """Iterate over this DataLoader, yielding batches.

        Also tracks the number of samples seen by this rank.

        Returns:
            Iterator[Any]: Each batch.
        """
        self.num_samples_yielded = 0
        for batch in super().__iter__():
            self.num_samples_yielded += self._get_batch_size(batch)
            yield batch

    def state_dict(self) -> Optional[Dict[str, Any]]:
        """Get a dict containing training state (called from non-worker process).

        This is called on rank zero.

        Returns:
            Optional[Dict[str, Any]]: The state, if a streaming dataset.
        """
        if isinstance(self.dataset, StreamingDataset):
            world = World()
            return self.dataset.state_dict(self.num_samples_yielded * world.num_ranks)
        return None

    def load_state_dict(self, obj: Dict[str, Any]) -> None:
        """Load a dict containing training state (called from non-worker process).

        This is called on each copy of the dataset when resuming.

        Args:
            obj (Dict[str, Any]): The state.
        """
        if isinstance(self.dataset, StreamingDataset):
            return self.dataset.load_state_dict(obj)

    def __del__(self) -> None:
        """Terminate the workers during cleanup."""
        if self._iterator is not None:
            self._iterator._shutdown_workers()  # type: ignore [reportGeneralTypeIssues]

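To make the checkpoint/resumption interface above concrete, here is a minimal sketch of mid-epoch resumption with the new classes. The paths, batch size, and break point are illustrative assumptions, and in a real training job the state would be saved to and restored from a checkpoint file.

```python
from streaming import StreamingDataLoader, StreamingDataset

# Hypothetical local cache and remote dataset paths.
dataset = StreamingDataset(local='/tmp/cache', remote='s3://mybucket/myfolder', shuffle=True)
dataloader = StreamingDataLoader(dataset, batch_size=32)

# Train partway through an epoch, then capture the loader state (on rank zero).
state = None
for i, batch in enumerate(dataloader):
    if i == 100:
        state = dataloader.state_dict()
        break

# After a restart, restore the state so the deterministic shuffle resumes
# from the same point in the epoch instead of starting over.
if state is not None:
    dataloader.load_state_dict(state)
for batch in dataloader:
    pass  # training continues mid-epoch
```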