20 changes: 15 additions & 5 deletions deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py
@@ -29,6 +29,7 @@ def __init__(self,
metric_types=[],
metric_dtypes=[],
save_path="./",
collate_fn=None,
custom_map_init=None,
custom_map_update=None,
custom_map_finalize=None,
@@ -46,6 +47,7 @@ def __init__(self,
self.metric_types = metric_types
self.metric_dtypes = metric_dtypes
self.save_path = save_path
self.collate_fn = collate_fn
self.custom_map_init = custom_map_init
self.custom_map_update = custom_map_update
self.custom_map_finalize = custom_map_finalize
@@ -153,11 +155,19 @@ def run_map_helper(self, thread_id):
sampler = BatchSampler(SequentialSampler(thread_dataset),
batch_size=self.batch_size,
drop_last=False)
iterator = iter(
DataLoader(thread_dataset,
batch_sampler=sampler,
num_workers=0,
pin_memory=False))
if self.collate_fn is None:
iterator = iter(
DataLoader(thread_dataset,
batch_sampler=sampler,
num_workers=0,
pin_memory=False))
else:
iterator = iter(
DataLoader(thread_dataset,
batch_sampler=sampler,
num_workers=0,
collate_fn=self.collate_fn,
pin_memory=False))
if self.custom_map_init is None:
metric_results = self.init_metric_results(thread_id,
self.metric_names,
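
A note on the new `collate_fn` plumbing above: `torch.utils.data.DataLoader` already treats `collate_fn=None` as "use the default collation", so the two branches behave the same as forwarding the argument unconditionally. A minimal sketch of that equivalent form (the `build_iterator` helper is illustrative, not part of the PR):

```python
# Sketch: DataLoader falls back to its default collation when collate_fn is None,
# so an optional collate_fn can be forwarded without branching.
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler


def build_iterator(dataset, batch_size, collate_fn=None):
    sampler = BatchSampler(SequentialSampler(dataset), batch_size=batch_size, drop_last=False)
    return iter(
        DataLoader(dataset,
                   batch_sampler=sampler,
                   num_workers=0,
                   collate_fn=collate_fn,  # None -> PyTorch's default collation
                   pin_memory=False))
```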
15 changes: 11 additions & 4 deletions deepspeed/runtime/dataloader.py
@@ -126,10 +126,17 @@ def __next__(self):

def _create_dataloader(self):
if self.curriculum_learning_enabled:
self.dataloader = DataLoader(self.dataset,
pin_memory=self.pin_memory,
batch_sampler=self.data_sampler,
num_workers=self.num_local_io_workers)
if self.collate_fn is None:
self.dataloader = DataLoader(self.dataset,
pin_memory=self.pin_memory,
batch_sampler=self.data_sampler,
num_workers=self.num_local_io_workers)
else:
self.dataloader = DataLoader(self.dataset,
pin_memory=self.pin_memory,
batch_sampler=self.data_sampler,
collate_fn=self.collate_fn,
num_workers=self.num_local_io_workers)
self.data_iterator = iter(self.dataloader)
return self.dataloader
else:
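
For context on how this path is reached: `deepspeed.initialize` accepts a `collate_fn` argument alongside `training_data`, and with this change it is honored even when curriculum learning builds the DataLoader. A hedged usage sketch, where `model`, `train_dataset`, `ds_config`, and `pad_collate` are placeholders rather than names from the PR:

```python
import deepspeed
from torch.nn.utils.rnn import pad_sequence


def pad_collate(batch):
    # Hypothetical collate_fn: pad variable-length 1-D samples to a common length.
    return pad_sequence(batch, batch_first=True)


# training_data triggers engine.deepspeed_io(...), which builds the curriculum-learning
# DataLoader shown above; collate_fn was previously dropped on that path.
model_engine, _, train_loader, _ = deepspeed.initialize(model=model,
                                                        model_parameters=model.parameters(),
                                                        training_data=train_dataset,
                                                        collate_fn=pad_collate,
                                                        config=ds_config)
```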
14 changes: 12 additions & 2 deletions deepspeed/runtime/engine.py
@@ -1697,8 +1697,8 @@ def deepspeed_io(self,
deepspeed_io_timer = self.tput_timer

# If mpu is provided, forward world size and parallel rank to sampler.
data_parallel_world_size = None
data_parallel_rank = None
data_parallel_world_size = self.dp_world_size
data_parallel_rank = self.global_rank
if self.mpu is not None:
data_parallel_world_size = self.mpu.get_data_parallel_world_size()
data_parallel_rank = self.mpu.get_data_parallel_rank()
@@ -3201,6 +3201,9 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}):
global_expert_id,
tag,
self.mpu)
if self.random_ltd_enabled():
expert_state_dict = remove_random_ltd_state_dict(
expert_state_dict)
self.checkpoint_engine.save(expert_state_dict, moe_save_path)
moe_layer_id += 1

@@ -3237,6 +3240,13 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}):
'lr_scheduler':
self.lr_scheduler.state_dict()
if self.lr_scheduler is not None else None,
'data_sampler':
self.training_dataloader.data_sampler.state_dict() if
(self.training_dataloader is not None
and self.curriculum_learning_enabled()) else None,
'random_ltd':
self.random_ltd_scheduler.state_dict()
if self.random_ltd_enabled() else None,
'sparse_tensor_module_names':
self.sparse_tensor_module_names,
'skipped_steps':
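
With these additions, the saved model-state checkpoint also carries the curriculum-learning data sampler state and the random-LTD scheduler state. A small sketch of inspecting them offline (the checkpoint path is illustrative; the exact layout depends on your checkpoint settings):

```python
import torch

# Load a DeepSpeed model-states file on CPU and look at the new entries.
state = torch.load("checkpoints/global_step1000/mp_rank_00_model_states.pt",
                   map_location="cpu")
print(state.get("data_sampler"))  # None unless curriculum learning was enabled
print(state.get("random_ltd"))    # None unless random-LTD was enabled
```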
33 changes: 17 additions & 16 deletions docs/_posts/2022-12-12-data-efficiency.md
@@ -41,7 +41,7 @@ To solve the limitation of existing solutions, we design and implement a general

## Evaluation Results

Using this general and extensible curriculum learning solution for GPT-3 and BERT-Large model pretraining, we are able to easily analyze and index the huge training data based on up to 7 difficulty metrics and enable better data and training efficiency. For GPT-3 pretraining, our solution with the best difficulty metric (combination of truncation-based sequence length and vocabulary rarity) achieves 1.5x data and training cost saving while still maintaining model quality as baseline (Table 1 Case (8) vs. (1)). For BERT-Large pretraining, our solution with the best difficulty metric (vocabulary rarity) achieves 1.4x saving while still maintaining model quality (Table 2 Case (7) vs. (1)). On the other hand, our solutions can further improve model quality when using the same amount of data as baseline (Table 1 Case (2) to (6), Table 2 Case (2) to (6)).
Using this general and extensible curriculum learning solution for GPT-3 and BERT-Large model pretraining, we are able to easily analyze and index the huge training data based on up to 7 difficulty metrics and enable better data and training efficiency. For GPT-3 pretraining, our solution with the best difficulty metric (combination of truncation-based sequence length and vocabulary rarity) achieves 1.5x data and training cost saving while still maintaining model quality as baseline (Table 1 Case (8) vs. (1)). For BERT-Large pretraining, our solution with the best difficulty metric (vocabulary rarity) achieves 1.5x saving while still maintaining model quality (Table 2 Case (8) vs. (1)). On the other hand, our solutions can further improve model quality when using the same amount of data as baseline (Table 1 Case (2) to (6), Table 2 Case (2) to (6)).

| **Case** | **Pretrain data** | **Avg 0-shot accuracy** | **Avg 10-shot accuracy** |
| ---------- |---------- |---------- |---------- |
@@ -56,17 +56,18 @@ Using this general and extensible curriculum learning solution for GPT-3 and BER

*Table 1: GPT-3 1.3B pretraining data consumption and average evaluation accuracy on 19 tasks.*

| **Case** | **Pretrain data** | **Avg finetune accuracy** |
| **Case** | **Pretrain data** | **GLUE finetune score** |
| ---------- |---------- |---------- |
| (1) Baseline | 1049B | 85.42 |
| (2) CL truncation-based sequence length | 1049B | 85.77 |
| (3) CL reorder-based sequence length | 1049B | 85.46 |
| (4) CL vocabulary rarity | 1049B | **86.13** |
| (5) CL combining (2) and (4) | 1049B | 85.8 |
| (6) CL combining (3) and (4) | 1049B | 85.61 |
| (7) CL vocabulary rarity | **734B (1.4x)** | 85.59 |
| (1) Baseline | 1049B | 87.29 |
| (2) CL truncation-based sequence length | 1049B | 87.31 |
| (3) CL reorder-based sequence length | 1049B | 87.48 |
| (4) CL vocabulary rarity | 1049B | 87.36 |
| (5) CL combining (2) and (4) | 1049B | **87.60** |
| (6) CL combining (3) and (4) | 1049B | 87.06 |
| (7) Baseline | 703B (1.5x) | 87.19 |
| (8) CL combining (2) and (4) | **703B (1.5x)** | 87.29 |

*Table 2: BERT-Large pretraining data consumption and average finetuning accuracy on 4 tasks.*
*Table 2: BERT-Large pretraining data consumption and average GLUE finetuning score on 8 tasks.*

# Efficient Data Routing via Random Layerwise Token Dropping

@@ -88,7 +89,7 @@ Random-LTD is simple yet very effective. Particularly, compared to other existin

## Evaluation Results

Thanks to its great flexibility, we were able to apply random-LTD method to broader applications, including BERT and GPT pretraining as well as ViT and GPT finetuning tasks. For all cases, random-LTD achieves similar model quality as baseline while using less data, and/or achieve better model quality while using the same amount of data (Table 3 to 6). For GPT-3 and BERT-Large pretraining, random-LTD achieves 1.5x data saving while still maintaining the same model quality. For GPT-3 we also tested random-LTD with full data which further improves the model quality compared to baseline.
Thanks to its great flexibility, we were able to apply random-LTD method to broader applications, including BERT and GPT pretraining as well as ViT and GPT finetuning tasks. For all cases, random-LTD achieves similar model quality as baseline while using less data, and/or achieve better model quality while using the same amount of data (Table 3 to 6). For GPT-3 and BERT-Large pretraining, random-LTD achieves 1.5-2x data saving while still maintaining the same model quality. For GPT-3 we also tested random-LTD with full data which further improves the model quality compared to baseline.

| **Case** | **Pretrain data** | **Avg 0-shot accuracy** |
| ---------- |---------- |---------- |
@@ -98,12 +99,12 @@ Thanks to its great flexibility, we were able to apply random-LTD method to broa

*Table 3: GPT-3 1.3B pretraining data consumption and average evaluation accuracy on 19 tasks.*

| **Case** | **Pretrain data** | **Avg finetune accuracy** |
| **Case** | **Pretrain data** | **GLUE finetune score** |
| ---------- |---------- |---------- |
| (1) Baseline | 1049B | 85.42 |
| (2) Random-LTD | **723B (1.5x)** | **86.42** |
| (1) Baseline | 1049B | 87.29 |
| (2) Random-LTD | **524B (2x)** | **87.32** |

*Table 4: BERT-Large pretraining data consumption and average finetuning accuracy on 4 tasks.*
*Table 4: BERT-Large pretraining data consumption and average GLUE finetuning score on 8 tasks.*

| **Case** | **Train data** | **ImageNet Top-1 Acc** |
| ---------- |---------- |---------- |
@@ -123,7 +124,7 @@ Thanks to its great flexibility, we were able to apply random-LTD method to broa

The curriculum learning and random-LTD techniques are complementary. Inside DeepSpeed Data Efficiency framework, we seamlessly compose the two techniques as shown in Figure 2 above, where curriculum learning helps to sample the next data batch and random-LTD helps to decide how to route each sampled data inside the model. DeepSpeed Data Efficiency solves several complexities when composing the two techniques so that users can easily apply each technique or both to their training pipeline. The composability of DeepSpeed Data Efficiency also applies to data sampling and routing techniques in general, so that it provides a platform to implement and compose additional data efficiency techniques.

The composed DeepSpeed Data Efficiency solution leverages both data efficiency techniques and achieves even better data and training efficiency. Take the GPT-3 pretraining task as an example, composing CL and random-LTD, with 100% data, leads to the best model quality in our experiments (Table 7 Case (1) to (4)). When pretraining with 50% data, the baseline training results in worse zero-shot and 10-shot evaluation accuracy, and using either CL or random-LTD can only recover part of the 10-shot accuracy loss. On the other hand, the composed data efficiency solution achieves the same or better accuracy results as baseline with 100% data, demonstrating a 2x data and 2x time saving (Case (5) to (8)).
The composed DeepSpeed Data Efficiency solution leverages both data efficiency techniques and achieves even better data and training efficiency. Take the GPT-3 pretraining task as an example, composing CL and random-LTD, with 100% data, leads to the best model quality in our experiments (Table 7 Case (1) to (4)). When pretraining with 50% data, the baseline training results in worse zero-shot and 10-shot evaluation accuracy, and using either CL or random-LTD can only recover part of the 10-shot accuracy loss. On the other hand, the composed data efficiency solution achieves the same or better accuracy results as baseline with 100% data, demonstrating a 2x data and 2x time saving (Case (5) to (8)). Similar benefit such as 2x data saving was also observed when applying our solution to BERT pretraining.

| **Case** | **Pretrain data** | **Pretrain time (on 64 V100)** | **Avg 0-shot accuracy** | **Avg 10-shot accuracy** |
| ---------- |---------- |---------- |---------- |---------- |
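
For readers following the composability discussion in this post: curriculum learning sits under `data_sampling` and random-LTD under `data_routing` in the DeepSpeed config, so enabling both composes them. A rough structural sketch follows; the top-level key names match the Data Efficiency documentation and the test config in this PR, while the per-metric and per-layer settings are elided and should be filled in from the tutorial.

```python
# Structural sketch only; see the DeepSpeed Data Efficiency config docs for the full schema.
ds_config = {
    "train_batch_size": 2048,
    "data_efficiency": {
        "enabled": True,
        "data_sampling": {
            "enabled": True,
            "curriculum_learning": {
                "enabled": True,
                "data_cluster_path": "/tmp/data_clusters",
                "curriculum_metrics": {
                    # per-metric index paths and difficulty schedules go here
                },
            },
        },
        "data_routing": {
            "enabled": True,
            "random_ltd": {
                "enabled": True,
                # layer ids and the token-dropping schedule go here
            },
        },
    },
}
```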
2 changes: 1 addition & 1 deletion docs/_tutorials/curriculum-learning.md
@@ -4,7 +4,7 @@ tags: training pre-training
---

**Watch out!**
On 12/12/2022, we released DeepSpeed Data Efficiency Library which provides a more general curriculum learning support. This legacy curriculum learning feature below is still supported but we recommend to use the Data Efficiency Library.
On 12/12/2022, we released DeepSpeed Data Efficiency Library which provides a more general curriculum learning support. This legacy curriculum learning feature below is still supported but we recommend to use the Data Efficiency Library ([tutorial](/tutorials/data-efficiency/)).
{: .notice--warning}

**Note:**
18 changes: 12 additions & 6 deletions docs/_tutorials/data-efficiency.md
@@ -30,6 +30,9 @@ The `examples/data_efficiency` directory in our [Megatron-DeepSpeed repo](https:

**Eval/finetuning** `examples/data_efficiency/gpt/eval/` and `examples/data_efficiency/bert/finetune` include the example scripts for GPT-3 model's zero-/few-shot evaluation and BERT model's finetuning. Our [paper](https://arxiv.org/abs/2212.03597) includes the reference eval/finetune results if you follow our example scripts to perform the pretraining/eval/finetuning.

#### 1.3.2 GPT-2 finetuning
The `data_efficiency/gpt_finetuning` directory in our [DeepSpeedExamples repo](https://github.com/microsoft/DeepSpeedExamples) includes our examples of how to apply curriculum learning to GPT-2 finetuning. `data_efficiency/gpt_finetuning/finetune/ds_finetune_gpt2_run.sh` is the example finetuning script. For CL metrics that require data analysis (e.g., the vocabulary rarity metric), you need to first use ```data_efficiency/gpt_finetuning/finetune/ds_analyze_gpt_data_*``` to analyze and index the dataset, similar to the GPT-3 pre-training case described above in 1.3.1.

## 2. Random layerwise token dropping (random-LTD)

### 2.1 What is random-LTD
@@ -56,18 +59,18 @@ One can run our GPT finetuning example by:

```shell
DeepSpeedExamples/data_efficiency/gpt_finetuning$ pip install -r requirement.txt
DeepSpeedExamples/data_efficiency/gpt_finetuning$ bash ./bash_script/run_base.sh
DeepSpeedExamples/data_efficiency/gpt_finetuning$ bash ./bash_script/run_medium.sh
DeepSpeedExamples/data_efficiency/gpt_finetuning$ bash ./bash_script/run_base_random_ltd.sh
DeepSpeedExamples/data_efficiency/gpt_finetuning$ bash ./bash_script/run_medium_random_ltd.sh
```

And the reference final result is:

```shell
For run_base.sh:
'step':1047, 'ppl': 23.9859276900444, 'seq_len': 1024, 'consume layer-tokens': 19534848
For run_base_random_ltd.sh:
End of training epoch 3 step 1344 consumed_token 2148032 best perplexity 22.552324221233757 time 0.17486039188173083 hr

For run_medium.sh:
'step':1047, 'ppl': 18.569010769928337, 'seq_len': 1024, 'consume layer-tokens': 35567104
For run_medium_random_ltd.sh:
End of training epoch 3 step 1373 consumed_token 2147024 best perplexity 17.332243199130996 time 0.4661190489927928 hr
```

One can run our ViT finetuning example by:
@@ -92,3 +95,6 @@ iter 5474 | LR [0.0001]| val_acc 97.97000122070312 | layer_token 305784192
The `examples/data_efficiency` directory in our [Megatron-DeepSpeed repo](https://github.com/microsoft/Megatron-DeepSpeed) includes our examples of how to compose curriculum learning and random-LTD, and apply both of them to GPT-3 and BERT pretraining.

The changes needed are the same as described in previous two sections, since DeepSpeed Data Efficiency already handles the complexity when composing the two techniques. However, one thing to note is that since both random-LTD and some of the curriculum learning metrics will change the sequence length, it could require some extra code to calculate the effective sequence length at each step. We provide an example implementation of this change in `megatron/training.py` function `train` where we calculate the `actual_seq_length`.
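
A hedged illustration of that bookkeeping (not the exact Megatron-DeepSpeed code): with curriculum learning every layer sees the truncated sequence, while random-LTD layers see only the reserved subset of tokens, so the per-step effective sequence length can be computed as a per-layer average.

```python
def effective_seq_length(curriculum_seqlen, full_seqlen, num_layers,
                         random_ltd_layer_num, reserved_length):
    # Curriculum learning truncates every layer's input to curriculum_seqlen.
    seqlen = min(curriculum_seqlen, full_seqlen)
    # Random-LTD layers only process the reserved (kept) tokens.
    kept = min(reserved_length, seqlen)
    layer_tokens = (num_layers - random_ltd_layer_num) * seqlen + random_ltd_layer_num * kept
    # Average per-layer sequence length for this training step.
    return layer_tokens / num_layers
```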

#### 3.2 GPT-2 finetuning
The `data_efficiency/gpt_finetuning` directory in our [DeepSpeedExamples repo](https://github.com/microsoft/DeepSpeedExamples) includes our examples of how to compose curriculum learning and random-LTD for GPT-2 finetuning. `data_efficiency/gpt_finetuning/finetune/ds_finetune_gpt2_run.sh` is the example finetuning script.
7 changes: 3 additions & 4 deletions tests/unit/runtime/test_data_efficiency.py
@@ -70,7 +70,7 @@ def test_curriculum_learning(self):
"num_workers": 0,
"curriculum_learning": {
"enabled": True,
"data_cluster_path": "data_clusters",
"data_cluster_path": "/tmp",
"curriculum_metrics": {
"dummy_metric": {
"index_to_sample_path": "dummy",
@@ -104,9 +104,8 @@ def data_post_process(data, data_sampler_state_dict):
training_data=dataset,
model_parameters=model.parameters(),
mpu=MPU(1))
if model.mpu.get_data_parallel_rank(
) == 0 and not os.path.exists('data_clusters'):
os.makedirs('data_clusters')
if model.mpu.get_data_parallel_rank() == 0 and not os.path.exists('/tmp'):
os.makedirs('/tmp')
model.set_data_post_process_func(data_post_process)
for n, batch in enumerate(data_loader):
x = batch[0].to(torch.cuda.current_device())