# Epochraft Tutorial (beta)


## Installation

```bash
pip install epochraft
pip install smart_open[s3] # Install S3 deps
```

# Quick Start

## Dataset Construction

This is an example of building a typical pretraining dataset.

In [1]:
from epochraft import CheckpointableDataset
from transformers import LlamaTokenizer
import torch

tokenizer = LlamaTokenizer.from_pretrained("NovelAI/nerdstash-tokenizer-v1")

url = "s3://polyglot-ja-west/2_quality_filter/v2/cc-100/cc-100_00.jsonl"

dataset = (
    CheckpointableDataset
    .from_files(url, repeat=True, shuffle_shards=True)
    .tokenize(tokenizer)        # Tokenize the texts
    .ensure_bos_eos(tokenizer)  # Add BOS and EOS tokens where necessary
    .concat_chunk(1024)         # Concatenate and chunk the tokens into a fixed length of 1024 tokens
    .shuffle(1000)              # Shuffle the sequences using a buffer of size 1000
    .batch(8)                   # Group the data into mini-batches with a batch size of 8
)

it = dataset.iter()
batch = next(it)
type(batch["input_ids"]), batch["input_ids"].shape

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. If you see this, DO NOT PANIC! This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


(torch.Tensor, torch.Size([8, 1024]))

## Checkpointing

Normally, you would obtain and save the `state_dict` of the model and optimizer. In addition, obtain and save the `state_dict` of the iterator.

In [2]:
state_dict = {}
state_dict["it"] = it.state_dict()
torch.save(state_dict, "checkpoint.pth")

## Resumption

You can restore the state of the iterator by passing the `state_dict` to the `iter` method of the `CheckpointableDataset` instance.

In [3]:
state_dict = torch.load("checkpoint.pth")
it = dataset.iter(state_dict=state_dict["it"])
batch = next(it)
type(batch["input_ids"]), batch["input_ids"].shape

(torch.Tensor, torch.Size([8, 1024]))

# Overview

## Design
Epochraft is designed to achieve the following three features:

1. Streaming from Cloud Storage
2. On-the-Fly Tokenization
3. Data Loader Checkpointing

To my knowledge, only epochraft offers all three of these features. For more details, please refer to the README.md at https://github.com/iwiwi/epochraft/.

## Main Components
The main class in epochraft is `CheckpointableDataset`. It is constructed as follows:

* Use class methods like `from_files` to create an instance of `CheckpointableDataset`.
* Apply transformations to the dataset by calling functions such as `tokenize` or `shuffle`.
* Combine multiple `CheckpointableDataset` instances using methods like `interleave_datasets`.

When you want to actually read data from the `CheckpointableDataset`, you call the `iter` method to obtain a `CheckpointableIterator`. This provides a standard Python iterator interface, so you can retrieve elements using the `next` function or loop through them using a `for` loop.

Furthermore, by calling the `state_dict` method of the `CheckpointableIterator`, you can capture its current state. This state can be passed back to the `iter` method to restore the iterator's position. Epochraft is fully deterministic, ensuring an exact restoration of the state.
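
The save-and-restore contract described above can be illustrated with a toy iterator in plain Python. This is only a sketch: `ToyCheckpointableIterator` is a hypothetical class, not part of epochraft, and its whole state is a position in a list.

```python
from typing import Any, Dict, List, Optional


class ToyCheckpointableIterator:
    """Hypothetical stand-in for a checkpointable iterator over a list."""

    def __init__(self, items: List[Any], state_dict: Optional[Dict[str, int]] = None) -> None:
        self._items = items
        # Resume from the saved position, or start from the beginning.
        self._pos = state_dict["pos"] if state_dict is not None else 0

    def __iter__(self) -> "ToyCheckpointableIterator":
        return self

    def __next__(self) -> Any:
        if self._pos >= len(self._items):
            raise StopIteration
        item = self._items[self._pos]
        self._pos += 1
        return item

    def state_dict(self) -> Dict[str, int]:
        # Everything needed to reproduce the remaining stream.
        return {"pos": self._pos}


items = list(range(10))
it = ToyCheckpointableIterator(items)
consumed = [next(it) for _ in range(3)]
state = it.state_dict()

# A fresh iterator built from the saved state continues where the first one stopped.
resumed = ToyCheckpointableIterator(items, state_dict=state)
remaining_original = list(it)
remaining_resumed = list(resumed)
print(remaining_original == remaining_resumed)
```

Because the toy iterator is deterministic, the resumed stream is identical to the stream the original iterator would have produced, which is the property the real library aims for.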



# Data Source

## `from_files` Method

To stream data from cloud storage, use the `CheckpointableDataset.from_files` method. The first argument, `urls`, takes either a single URL string or a list of such strings.

Any URL that is supported by `smart_open` should work with this function. Naturally, it can also read local files.

In [4]:
# A single URL
dataset: CheckpointableDataset = CheckpointableDataset.from_files("s3://polyglot-ja-west/2_quality_filter/v2/cc-100/cc-100_00.jsonl")
sample = next(iter(dataset))
print(sample)


{'text': '嬉しい事に、お料理だけでなく日々のおやつもワタシのレシピを活用して下さっている方もいらっしゃるとのこと。'}


In [5]:
# A list of URLs
dataset: CheckpointableDataset = CheckpointableDataset.from_files([
    "s3://polyglot-ja-west/2_quality_filter/v2/cc-100/cc-100_00.jsonl",
    "s3://polyglot-ja-west/2_quality_filter/v2/cc-100/cc-100_01.jsonl"
])
sample = next(iter(dataset))
print(sample)

{'text': '嬉しい事に、お料理だけでなく日々のおやつもワタシのレシピを活用して下さっている方もいらっしゃるとのこと。'}


In [6]:
# `braceexpand` is automatically applied to the URL
dataset: CheckpointableDataset = CheckpointableDataset.from_files(
    "s3://polyglot-ja-west/2_quality_filter/v2/cc-100/cc-100_{00..99}.jsonl",
)
sample = next(iter(dataset))
print(sample)

{'text': '嬉しい事に、お料理だけでなく日々のおやつもワタシのレシピを活用して下さっている方もいらっしゃるとのこと。'}
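
To make the effect of brace expansion concrete, here is a simplified standard-library stand-in. It is an assumption-laden sketch: `expand_numeric_range` is a hypothetical helper that handles only a single zero-padded `{start..end}` range, whereas the real `braceexpand` package supports more patterns.

```python
import re
from typing import List


def expand_numeric_range(url: str) -> List[str]:
    """Expand a single `{start..end}` numeric range, keeping zero padding."""
    match = re.search(r"\{(\d+)\.\.(\d+)\}", url)
    if match is None:
        return [url]  # Nothing to expand
    start, end = match.group(1), match.group(2)
    width = len(start)  # Pad to the width of the start bound, e.g. "00" -> 2
    return [
        url[: match.start()] + str(i).zfill(width) + url[match.end():]
        for i in range(int(start), int(end) + 1)
    ]


urls = expand_numeric_range(
    "s3://polyglot-ja-west/2_quality_filter/v2/cc-100/cc-100_{00..99}.jsonl"
)
print(len(urls), urls[0], urls[-1])
```

In other words, a single pattern string stands for the whole list of shard URLs.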


### Data Preparation

* Currently, we support JSONL and CBOR formats. While we attempt to infer the format from the file extension, if this fails you can specify the format explicitly using the `format` argument.
* Each sample is expected to be a `dict`. In typical LLM (large language model) training, it should contain a `text` field.
* Large datasets should be split into multiple shard files. This allows for operations like shuffling and facilitates data partitioning in data-parallel training.
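
The points above can be sketched with the standard library alone. The helper `write_jsonl_shards` and the `shard_XX.jsonl` naming are hypothetical choices for illustration; any layout that yields one JSON object per line across several files should serve the same purpose.

```python
import json
import tempfile
from pathlib import Path
from typing import Dict, List


def write_jsonl_shards(samples: List[Dict[str, str]], out_dir: Path, n_shards: int) -> List[Path]:
    """Distribute samples round-robin over `n_shards` JSONL files."""
    paths = [out_dir / f"shard_{i:02d}.jsonl" for i in range(n_shards)]
    files = [p.open("w", encoding="utf-8") for p in paths]
    try:
        for i, sample in enumerate(samples):
            # One JSON object per line; keep non-ASCII text readable.
            files[i % n_shards].write(json.dumps(sample, ensure_ascii=False) + "\n")
    finally:
        for f in files:
            f.close()
    return paths


samples = [{"text": f"document {i}"} for i in range(10)]
out_dir = Path(tempfile.mkdtemp())
shard_paths = write_jsonl_shards(samples, out_dir, n_shards=3)
line_counts = [len(p.read_text(encoding="utf-8").splitlines()) for p in shard_paths]
print(line_counts)
```

The resulting paths could then be passed to `from_files` as a list of strings.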





### Important Arguments

* `repeat` and `shuffle_shards` are arguably the most important arguments. Typically, for a training dataset, both would be set to `True`, while for a validation dataset, both would be set to `False`.
* `n_active_shards` specifies the number of shards that are opened and read simultaneously (called *active shards*). Samples will be alternately read from these shards.
* `n_standby_shards` defines the number of shards that are pre-opened and pre-fetched in the background (called *standby shards*). This is used to cover the time taken to open or read files. When one of the active shards reaches its end, a standby shard becomes an active shard, and a new standby shard is opened.
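
The reading order implied by active shards can be modeled in a few lines of plain Python. This sketch is a simplified model, not the library's implementation: `read_round_robin` is a hypothetical function, shards are in-memory lists, and the queue of waiting shards only imitates the ordering of standby shards without any background prefetching.

```python
from collections import deque
from typing import Deque, Iterator, List


def read_round_robin(shards: List[List[str]], n_active_shards: int) -> Iterator[str]:
    """Yield samples by alternating over up to `n_active_shards` open shards."""
    pending: Deque[Iterator[str]] = deque(iter(s) for s in shards)
    active: List[Iterator[str]] = [
        pending.popleft() for _ in range(min(n_active_shards, len(pending)))
    ]
    i = 0
    while active:
        i %= len(active)
        try:
            yield next(active[i])
            i += 1
        except StopIteration:
            if pending:
                # An exhausted active shard is replaced by the next waiting shard.
                active[i] = pending.popleft()
            else:
                del active[i]


shards = [["a0", "a1", "a2"], ["b0"], ["c0", "c1"]]
order = list(read_round_robin(shards, n_active_shards=2))
print(order)
```

With two active shards, samples from the first two shards alternate until the short shard runs out, at which point the third shard takes its slot.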



## Other Construction Methods

You can also construct datasets using methods like `CheckpointableDataset.from_sequence` or `CheckpointableDataset.from_iterable`. These are particularly handy during development and debugging phases. Moreover, if you want to use HuggingFace Dataset as your data source, these methods are applicable as well.

In [7]:
samples = [
    {"text": "Hello world!"},
    {"text": "こんにちは世界！"},
    {"text": "你好世界！"},
    {"text": "안녕하세요 세계!"},
]

dataset = CheckpointableDataset.from_sequence(samples, shuffle=True)
for sample in dataset:
    print(sample)

{'text': '你好世界！'}
{'text': 'こんにちは世界！'}
{'text': '안녕하세요 세계!'}
{'text': 'Hello world!'}


# Transforms

## General Transforms: `map`, `filter`, and `filter_map`

We can arbitrarily modify or filter samples by using these methods.

In [8]:
def f(sample):
    sample = sample.copy()
    sample["text"] = sample["text"].upper()
    return sample

dataset = CheckpointableDataset.from_sequence(samples).map(f)
print(next(iter(dataset)))

{'text': 'HELLO WORLD!'}


In [9]:
def f(sample):
    return len(sample["text"]) < 10

dataset = CheckpointableDataset.from_sequence(samples).filter(f)
for sample in dataset:
    print(sample)

{'text': 'こんにちは世界！'}
{'text': '你好世界！'}
{'text': '안녕하세요 세계!'}


## Parallel Transforms: `parallel_map`, `parallel_filter`, and `parallel_filter_map`

They are the parallel versions of `map`, `filter`, and `filter_map`, respectively. While `map`, `filter`, and `filter_map` apply the given function in the main thread, the parallel versions spawn workers and apply the function in the background.

By specifying `"thread"` for `executor_type`, it runs in multithreading mode, while specifying `"process"` runs it in multiprocessing mode. The number of workers can be defined with `max_workers`.

Using processes won't be limited by the GIL (Global Interpreter Lock), but there's a significant overhead in starting up workers. For tasks that aren't hampered by the GIL, it's recommended to use threads. This includes operations like IO, image decoding, and native tokenizers. Furthermore, it's advisable to keep the number of workers to the essential minimum.

In [10]:
def f(sample):
    sample = sample.copy()
    sample["text"] = sample["text"].upper()
    return sample

dataset = CheckpointableDataset.from_sequence(samples).parallel_map(f, executor_type="thread", max_workers=1)
print(next(iter(dataset)))

{'text': 'HELLO WORLD!'}


If the speed is insufficient, besides increasing the values of `max_workers` and `prefetch_factor`, you can also set `ordered` to `False`. In this scenario, the dataset's order will not be preserved. While this will improve throughput, it will reduce reproducibility.
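
The trade-off between ordered and unordered results can be seen with Python's own `concurrent.futures`. This is an analogy rather than epochraft's code: `slow_upper` is a hypothetical function whose sleep time is chosen so that later inputs finish first.

```python
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List


def slow_upper(text: str) -> str:
    # Shorter texts sleep longer, so completion order differs from input order.
    time.sleep(0.01 * (10 - len(text)))
    return text.upper()


texts = ["a", "bbbb", "ccccccc"]

with ThreadPoolExecutor(max_workers=3) as pool:
    # Ordered: results come back in input order, even if later items finish first.
    ordered: List[str] = list(pool.map(slow_upper, texts))

    # Unordered: results are taken as soon as each worker finishes.
    futures = [pool.submit(slow_upper, t) for t in texts]
    unordered: List[str] = [f.result() for f in as_completed(futures)]

print(ordered)
print(unordered)
```

The ordered variant may wait on a slow early item while finished results sit idle. The unordered variant avoids that wait, but its output order depends on timing and can change from run to run.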


## Tokenization

Using the `tokenize` method, you can apply a tokenizer to the text. The tokenizer is expected to adhere to the HuggingFace interface.

You can specify the field name containing the text using `target_column` (default is `"text"`). Internally, it's implemented using `parallel_map`, and thus has similar arguments related to parallelization (if you set `parallel=False`, it will use `map`). Typically, tokenizers release the GIL (Global Interpreter Lock), so `executor_type` being set to `"thread"` should be sufficient. If the speed is inadequate, consider increasing the `max_workers` value.

In [11]:
dataset = (
    CheckpointableDataset
    .from_sequence(samples)
    .tokenize(tokenizer, max_workers=2)
)
print(next(iter(dataset)))

{'text': 'Hello world!', 'input_ids': [2, 13071, 1190, 49338], 'attention_mask': [1, 1, 1, 1]}


## BOS and EOS

In pretraining, documents are split or combined so that the sequence length matches the specified context length exactly (in epochraft, this process is referred to as *chunking*). Before this, it's necessary to add BOS (beginning of sentence) and EOS (end of sentence) tokens to the start and end of each document.

The `ensure_bos_eos` method checks for BOS and EOS tokens at the beginning and end of each document and adds them only if they are absent. However, if BOS and EOS share the same token ID and the token is already present at one end, it is not added at the other end. This ensures that the tokens don't appear consecutively when documents are concatenated.

The process handles the `target_column` field of each sample (default: `"input_ids"`).

In [12]:
dataset = (
    CheckpointableDataset
    .from_sequence(samples)
    .tokenize(tokenizer, max_workers=2)
    .ensure_bos_eos(tokenizer)
)
print(next(iter(dataset)))

{'text': 'Hello world!', 'input_ids': tensor([    2, 13071,  1190, 49338,     3]), 'attention_mask': [1, 1, 1, 1]}
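
The rule described in this section can be written down as a small pure-Python function. This is one plausible reading, not the library's source: `ensure_bos_eos_sketch` is a hypothetical name, it works on plain lists, and the behavior when BOS and EOS share an ID and neither end has the token (append one at the end) is an assumption of this sketch.

```python
from typing import List


def ensure_bos_eos_sketch(ids: List[int], bos_id: int, eos_id: int) -> List[int]:
    """Add BOS/EOS only where they are missing (illustrative logic only)."""
    has_bos = len(ids) > 0 and ids[0] == bos_id
    has_eos = len(ids) > 0 and ids[-1] == eos_id
    if bos_id == eos_id:
        # Shared ID: one boundary token per document is enough, so that
        # concatenated documents are separated by a single token.
        # ASSUMPTION: when neither end has the token, append it at the end.
        return list(ids) if (has_bos or has_eos) else list(ids) + [eos_id]
    out = list(ids)
    if not has_bos:
        out = [bos_id] + out
    if not has_eos:
        out = out + [eos_id]
    return out


# Distinct IDs, BOS already present: only EOS is added.
distinct = ensure_bos_eos_sketch([2, 13071, 1190, 49338], bos_id=2, eos_id=3)
# Shared ID, token already at the start: nothing is added.
shared_present = ensure_bos_eos_sketch([2, 5, 6], bos_id=2, eos_id=2)
# Shared ID, token absent: a single token is appended (assumption).
shared_absent = ensure_bos_eos_sketch([5, 6], bos_id=2, eos_id=2)
print(distinct, shared_present, shared_absent)
```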


## Chunking


For pretraining purposes, the original documents vary in length, but they need to be split or combined to match the model's context length exactly.

There are two methods available: `chunk` and `concat_chunk`.

* `chunk` splits each sample at intervals of `chunk_length`. Trailing parts shorter than `chunk_length` are discarded (this behavior can be changed with the `drop_remainder` argument). A sequence never contains more than one document.
* `concat_chunk` concatenates the trailing part shorter than `chunk_length` with the next sample. Nothing is discarded in this method, and a sequence may contain more than one document.

In [13]:
samples = [
    {"input_ids": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]},
    {"input_ids": [10, 11, 12]},
    {"input_ids": [20, 21, 22, 23, 24, 25, 26]}
]

dataset = CheckpointableDataset.from_sequence(samples).chunk(6)
for sample in dataset:
    print(sample)

{'input_ids': tensor([0, 1, 2, 3, 4, 5])}
{'input_ids': tensor([20, 21, 22, 23, 24, 25])}


In [14]:
dataset = CheckpointableDataset.from_sequence(samples).concat_chunk(6)
for sample in dataset:
    print(sample)

{'input_ids': tensor([0, 1, 2, 3, 4, 5])}
{'input_ids': tensor([ 6,  7,  8,  9, 10, 11])}
{'input_ids': tensor([12, 20, 21, 22, 23, 24])}
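
For readers who want to see the two strategies without the library, here is a minimal sketch on plain lists. The functions `chunk` and `concat_chunk` below are hypothetical re-implementations written for this tutorial, not epochraft's code. In the `concat_chunk` sketch, tokens left over at the end of a finite input simply stay in the buffer waiting for more data.

```python
from typing import Iterable, Iterator, List


def chunk(samples: Iterable[List[int]], chunk_length: int) -> Iterator[List[int]]:
    """Split each sample independently; drop remainders shorter than chunk_length."""
    for ids in samples:
        for start in range(0, len(ids) - chunk_length + 1, chunk_length):
            yield ids[start : start + chunk_length]


def concat_chunk(samples: Iterable[List[int]], chunk_length: int) -> Iterator[List[int]]:
    """Carry remainders over to the next sample, so documents can share a chunk."""
    buffer: List[int] = []
    for ids in samples:
        buffer.extend(ids)
        while len(buffer) >= chunk_length:
            yield buffer[:chunk_length]
            buffer = buffer[chunk_length:]


docs = [
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
    [10, 11, 12],
    [20, 21, 22, 23, 24, 25, 26],
]
chunked = list(chunk(docs, 6))
concatenated = list(concat_chunk(docs, 6))
print(chunked)
print(concatenated)
```

On these inputs the sketch yields the same sequences as the two cells above.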


## Other

There are several transformation functions available, with the following two being the most essential:

* `batch` groups the specified number of samples into a single batch. You will likely use this function all the time.
* `shuffle` buffers the specified number of samples and randomly samples from this buffer to partially shuffle the order. Although this isn't a full shuffle, it follows the same approach used by successful libraries such as `tf.data` and `webdataset`.

Other useful functions include `take`, `stride`, and `cache`. The `cache` function is particularly handy for short and repeatedly accessed datasets, like validation datasets.

In SFT (Supervised Fine-Tuning), padding is often performed instead of chunking. In such cases, you can use the `pad` function. There's also a `pack_chunk` function designed for packing, as described in the Orca paper. 
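
The buffer-based partial shuffle mentioned above can be sketched as follows. This is a generic illustration of the technique, assuming a seeded `random.Random`; `buffered_shuffle` is a hypothetical function and not the library's implementation.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def buffered_shuffle(samples: Iterable[T], buffer_size: int, seed: int = 0) -> Iterator[T]:
    """Partially shuffle a stream using a fixed-size buffer."""
    rng = random.Random(seed)
    buffer: List[T] = []
    for sample in samples:
        buffer.append(sample)
        if len(buffer) >= buffer_size:
            # Move a random element to the end of the buffer and emit it.
            idx = rng.randrange(len(buffer))
            buffer[idx], buffer[-1] = buffer[-1], buffer[idx]
            yield buffer.pop()
    # The input is exhausted: emit what is left in random order.
    rng.shuffle(buffer)
    yield from buffer


shuffled = list(buffered_shuffle(range(100), buffer_size=10, seed=0))
shuffled_again = list(buffered_shuffle(range(100), buffer_size=10, seed=0))
print(shuffled[:10])
```

Nothing is emitted until the buffer is full, and an element can only move as far as the buffer allows, which is why a larger buffer gives a better shuffle at the cost of a slower start.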








# Combinations
You can combine multiple `CheckpointableDataset` instances to create a single `CheckpointableDataset`.

* By using `interleave_datasets`, you can sample from multiple datasets alternately. It's also possible to specify weights for each dataset. This method is primarily used for training data.
* With `concat_datasets`, you can concatenate multiple datasets in sequence. 

In [15]:
from epochraft import interleave_datasets

dataset1: CheckpointableDataset = CheckpointableDataset.from_sequence([{"text": "Hello world!"}], repeat=True)
dataset2: CheckpointableDataset = CheckpointableDataset.from_sequence([{"text": "こんにちは世界！"}], repeat=True)

dataset = interleave_datasets([dataset1, dataset2], weights=[1.5, 1])
it = iter(dataset)
for _ in range(10):
    print(next(it))

{'text': 'Hello world!'}
{'text': 'こんにちは世界！'}
{'text': 'Hello world!'}
{'text': 'こんにちは世界！'}
{'text': 'Hello world!'}
{'text': 'Hello world!'}
{'text': 'こんにちは世界！'}
{'text': 'Hello world!'}
{'text': 'こんにちは世界！'}
{'text': 'Hello world!'}
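
One way to picture weighted interleaving is deterministic stride scheduling, sketched below in plain Python. This is an illustration of the idea only: `interleave_weighted` is a hypothetical function, breaking ties toward the later source is an arbitrary choice of this sketch, and the library may schedule samples differently.

```python
from fractions import Fraction
from itertools import cycle
from typing import Iterator, Sequence, TypeVar

T = TypeVar("T")


def interleave_weighted(iterators: Sequence[Iterator[T]], weights: Sequence[float]) -> Iterator[T]:
    """Deterministic weighted interleaving via stride scheduling (a sketch)."""
    exact = [Fraction(w) for w in weights]  # exact arithmetic avoids float ties
    counts = [0] * len(iterators)
    while True:
        # Pick the source whose next sample is due earliest: (count + 1) / weight.
        # Ties go to the later source (hence the negated index).
        i = min(range(len(iterators)), key=lambda j: ((counts[j] + 1) / exact[j], -j))
        try:
            yield next(iterators[i])
        except StopIteration:
            return  # Stop as soon as any source is exhausted
        counts[i] += 1


mixed = interleave_weighted([cycle(["en"]), cycle(["ja"])], weights=[1.5, 1])
first_ten = [next(mixed) for _ in range(10)]
print(first_ten)
```

With weights `[1.5, 1]`, six of the first ten samples come from the first source and four from the second, matching the 3:2 ratio.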


In [16]:
from epochraft import concat_datasets

dataset1: CheckpointableDataset = CheckpointableDataset.from_sequence([{"text": "Hello world!"}], repeat=True).take(3)
dataset2: CheckpointableDataset = CheckpointableDataset.from_sequence([{"text": "こんにちは世界！"}], repeat=True).take(3)

dataset = concat_datasets([dataset1, dataset2])
for sample in dataset:
    print(sample)

{'text': 'Hello world!'}
{'text': 'Hello world!'}
{'text': 'Hello world!'}
{'text': 'こんにちは世界！'}
{'text': 'こんにちは世界！'}
{'text': 'こんにちは世界！'}


# Distributed Training

When conducting data-parallel (DP) training, it's crucial that each DP worker handles different data. There are two methods to achieve this:

1. Use the `stride` method of `CheckpointableDataset`. This approach keeps only the samples whose index is congruent to `offset` modulo `interval` and discards the rest. By setting `interval` to the world size and `offset` to the rank, data can be distributed among the workers. An advantage of this method is that, even if you change the number of DP workers, the order of samples remains unchanged, ensuring high reproducibility. However, every worker has to read the whole dataset, which increases the demand on dataset loading speed.

2. Provide different URLs to `from_files` for each DP worker. For instance, each DP worker uses `urls[rank::world_size]`, where `urls` is the full URL list of the shards. The advantage here is that each DP worker reads a different file, which is efficient. However, changing the number of DP workers alters the order of the samples, and it's no longer possible to load a previously saved `state_dict`.

Choosing between the two methods will depend on your training setup and the priorities of your workload, whether it's more crucial to maintain reproducibility or to optimize for loading speed and efficiency.
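
The two partitioning schemes can be compared on toy data without any distributed setup. The sketch below uses plain lists, and the helper names and shard file names are hypothetical. In real training, `rank` and `world_size` would come from your distributed framework.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def stride_partition(samples: Sequence[T], interval: int, offset: int) -> List[T]:
    """Method 1: every worker scans all samples and keeps index % interval == offset."""
    return [s for i, s in enumerate(samples) if i % interval == offset]


def shard_partition(urls: Sequence[str], rank: int, world_size: int) -> List[str]:
    """Method 2: every worker opens only its own subset of shard URLs."""
    return list(urls[rank::world_size])


world_size = 4
samples = list(range(10))
urls = [f"shard_{i:02d}.jsonl" for i in range(8)]  # hypothetical shard names

per_rank_samples = [stride_partition(samples, world_size, rank) for rank in range(world_size)]
per_rank_urls = [shard_partition(urls, rank, world_size) for rank in range(world_size)]
print(per_rank_samples)
print(per_rank_urls)
```

In both cases the workers receive disjoint data. The difference is whether each worker reads everything and discards most of it (method 1) or reads only its own files (method 2).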

TODO: write some examples

# Performance Optimization

TODO: write more

### Overlapping GPU Computations and Data Fetching

The primary objective is to keep GPU computation on the critical path. GPU computation and the enqueueing of CUDA kernels from the CPU run asynchronously, so it's vital to keep the CUDA queue full to achieve optimal training performance.

To ensure that data preparation doesn't slow down the training process, the specific place in the training loop to fetch the next data batch becomes pivotal. Initiating the fetch immediately after enqueueing many computationally intensive CUDA kernels ensures that the data is prepared well before the CUDA queue is exhausted. Such a strategy effectively offsets the time spent on data retrieval. Procuring data immediately after either the forward or backward pass is probably a good strategy, like the following.

```python
it = iter(dataset)
batch = next(it)  # Fetching the first batch
for _ in range(n_steps):
    ...
    loss = model(batch["input_ids"])
    batch = next(it)  # Fetching the next batch
    loss.backward()
    ...
```

In addition, using pinned memory allows data transfers between the CPU and GPUs to run asynchronously as well.


### Reducing Fetch Time

If the time taken by `next(it)` becomes longer than GPU computations, performance enhancements are needed.

If data loading is slow, try increasing the `n_active_shards` value in `from_files`.

Additionally, if transforms like tokenization are slow, make sure to utilize parallelism. Ensure also that the related arguments are set correctly.
Besides increasing the values of `max_workers` and `prefetch_factor`, you can also set `ordered` to `False`. In this scenario, the dataset's order will not be maintained. Although this boosts throughput, it compromises reproducibility.

For validation datasets, which use the same data repeatedly, it's advisable to use the `cache` method.
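
Before tuning anything, it helps to measure how long `next(it)` actually takes. The following is a minimal timing sketch using only the standard library. `timed_fetch` and `slow_batches` are hypothetical names, and the generator merely simulates a data source that needs about 10 ms per batch.

```python
import time
from typing import Iterator, List, Tuple, TypeVar

T = TypeVar("T")


def timed_fetch(it: Iterator[T], n_batches: int) -> Tuple[List[T], List[float]]:
    """Fetch `n_batches` items and record the wall-clock time of each `next` call."""
    batches: List[T] = []
    seconds: List[float] = []
    for _ in range(n_batches):
        start = time.perf_counter()
        batch = next(it)
        seconds.append(time.perf_counter() - start)
        batches.append(batch)
    return batches, seconds


def slow_batches() -> Iterator[int]:
    """Hypothetical data source that takes about 10 ms per batch."""
    i = 0
    while True:
        time.sleep(0.01)
        yield i
        i += 1


batches, seconds = timed_fetch(slow_batches(), n_batches=3)
print(batches, [f"{s * 1000:.1f} ms" for s in seconds])
```

Comparing these fetch times with the duration of a training step shows whether data loading is the bottleneck.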





### Reducing Start Time

Depending on the configuration, instantiating the dataset or iterator, as well as fetching the first batch, can take a very long time. If it feels like it's taking forever, try the following solutions:

* Reduce the `buffer_size` argument of the `shuffle`. Since the shuffle buffer must be filled before starting, a large buffer size can significantly slow down initialization.
* Set the `executor_type` of parallel transforms to `"thread"`. Launching child processes can be slow, whereas starting child threads is faster. Furthermore, many operations like IO and tokenization are not hampered by the GIL, making threads suitable for these tasks.
* Reduce the value of `max_workers` in parallel transforms.








# Working with PyTorch's DataLoader

The `CheckpointableDataset` class inherits from `torch.utils.data.IterableDataset`. Therefore, it can be used with the `torch.utils.data.DataLoader` if you want.

However, this is *not recommended* for the following reasons:

1. **Loss of Checkpointing**: The `DataLoader` class is designed to handle all data preparation using child processes. As a result, it becomes impossible to retrieve and save the `state_dict` from the main process.
2. **It's Unnecessary**: Epochraft is designed to carry out slow operations, such as file reading and tokenization, in child threads or processes. Thus, there is no benefit to using `DataLoader`.

Given these reasons, while it's technically possible to use `CheckpointableDataset` with `DataLoader`, it's best to avoid doing so to make the most of its features and maintain the intended workflow.




