Commit 545bb42

Merge branch 'master' into arashb/opt

jeffra committed Aug 9, 2022
2 parents 84d954f + 28dfca8 commit 545bb42
Showing 19 changed files with 231 additions and 55 deletions.
7 changes: 5 additions & 2 deletions .github/workflows/amd.yml
@@ -33,6 +33,9 @@ jobs:
python --version
which hipcc
hipcc --version
pip install --upgrade pip
pip uninstall --yes torch torchvision
pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.1.1
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
sudo apt-get update
@@ -63,5 +66,5 @@ jobs:
run: |
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
cd tests
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -x -n 4 unit/
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -x -m 'sequential' unit/
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --verbose -n 4 unit/{comm,inference,monitor,ops,profiling,runtime}
#TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --verbose -m 'sequential' unit/{comm,inference,monitor,ops,profiling,runtime}
2 changes: 1 addition & 1 deletion .github/workflows/nv-lightning-v100.yml
@@ -31,7 +31,7 @@ jobs:
nvcc --version
pip install --upgrade pip
pip uninstall --yes torch torchvision
pip install torch==1.8.2+cu111 torchvision==0.9.2+cu111 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
12 changes: 8 additions & 4 deletions README.md
@@ -27,7 +27,7 @@
* Train/Inference dense or sparse models with billions or trillions of parameters
* Achieve excellent system throughput and efficiently scale to thousands of GPUs
* Train/Inference on resource constrained GPU systems
* Achieve unprecedented low latency and high thoughput for inference
* Achieve unprecedented low latency and high throughput for inference
* Achieve extreme compression for an unparalleled inference latency and model size reduction with low costs

---
@@ -43,7 +43,7 @@ DeepSpeed offers a confluence of system innovations, that has made large scale D

## DeepSpeed-Inference

DeepSpeed brings together innovations in parallelism technology such as tensor, pipeline, expert and ZeRO-parallelism, and combines them with high performance custom inference kernels, communication optimizations and heterogeneous memory technologies to enable inference at an unprecedented scale, while achieving unparalleled latency, thoughput and cost reduction. This systematic composition of system technologies for inference falls under the inference pillar. Learn more: [DeepSpeed-Inference](https://www.deepspeed.ai/inference)
DeepSpeed brings together innovations in parallelism technology such as tensor, pipeline, expert and ZeRO-parallelism, and combines them with high performance custom inference kernels, communication optimizations and heterogeneous memory technologies to enable inference at an unprecedented scale, while achieving unparalleled latency, throughput and cost reduction. This systematic composition of system technologies for inference falls under the inference pillar. Learn more: [DeepSpeed-Inference](https://www.deepspeed.ai/inference)
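
For orientation, a minimal sketch of how this pillar is typically invoked from user code — the model name, `mp_size`, and dtype below are illustrative assumptions, not part of this change:

```python
# Minimal sketch (illustrative values, not from this commit): wrap a Hugging Face
# model with DeepSpeed-Inference to inject fused kernels and optional tensor parallelism.
import torch
import deepspeed
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # assumed example model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

engine = deepspeed.init_inference(model,
                                  mp_size=1,                        # tensor-parallel degree
                                  dtype=torch.half,                 # run inference in fp16
                                  replace_with_kernel_inject=True)  # use DeepSpeed's kernels

inputs = tokenizer("DeepSpeed is", return_tensors="pt").to("cuda")
outputs = engine.module.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0]))
```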


## DeepSpeed-Compression
@@ -56,7 +56,7 @@ To further increase the inference efficiency, DeepSpeed offers easy-to-use and f

## DeepSpeed Library

The [DeepSpeed](https://github.com/microsoft/deepspeed) library (this repository) implements and packages the innovations and technologies in DeepSpeed Training, Inference and Compression Pillars into a single easy-to-use, open-sourced repository. It allows for easy composition of multitude of features within a single training, infernece or compression pipeline. The DeepSpeed Library is heavily adopted by the DL community, and has been used to enable some of the most powerful models (see [DeepSpeed Adoption](#deepspeed-adoption)).
The [DeepSpeed](https://github.com/microsoft/deepspeed) library (this repository) implements and packages the innovations and technologies in DeepSpeed Training, Inference and Compression Pillars into a single easy-to-use, open-sourced repository. It allows for easy composition of multitude of features within a single training, inference or compression pipeline. The DeepSpeed Library is heavily adopted by the DL community, and has been used to enable some of the most powerful models (see [DeepSpeed Adoption](#deepspeed-adoption)).
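
As a rough orientation (not part of this change), a typical way the library is driven for training — the tiny model and config values below are placeholders:

```python
# Minimal training sketch with placeholder model/config (not from this commit).
# Intended to be launched with the `deepspeed` launcher so that the distributed
# backend is initialized for you.
import torch
import deepspeed

model = torch.nn.Linear(512, 10)  # stand-in for a real network

ds_config = {
    "train_micro_batch_size_per_gpu": 8,
    "gradient_accumulation_steps": 1,
    "fp16": {"enabled": True},
    "zero_optimization": {"stage": 2},
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
}

engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)

for step in range(10):
    x = torch.randn(8, 512, device=engine.device, dtype=torch.half)
    y = torch.randint(0, 10, (8,), device=engine.device)
    loss = torch.nn.functional.cross_entropy(engine(x), y)
    engine.backward(loss)  # handles loss scaling and gradient averaging
    engine.step()
```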

## Model Implementations for Inference (MII)

@@ -80,8 +80,12 @@ DeepSpeed has been used to train many different large-scale models, below is a l
* [Megatron-Turing NLG (530B)](https://www.microsoft.com/en-us/research/blog/using-deepspeed-and-megatron-to-train-megatron-turing-nlg-530b-the-worlds-largest-and-most-powerful-generative-language-model/)
* [Jurassic-1 (178B)](https://uploads-ssl.webflow.com/60fd4503684b466578c0d307/61138924626a6981ee09caf6_jurassic_tech_paper.pdf)
* [BLOOM (176B)](https://huggingface.co/blog/bloom-megatron-deepspeed)
* [GLM (130B)](https://github.com/THUDM/GLM-130B)
* [YaLM (100B)](https://github.com/yandex/YaLM-100B)
* [GPT-NeoX (20B)](https://github.com/EleutherAI/gpt-neox)
* [AlexaTM (20B)](https://www.amazon.science/blog/20b-parameter-alexa-model-sets-new-marks-in-few-shot-learning)
* [Turing NLG (17B)](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft/)
* [METRO-LM (5.4B)](https://arxiv.org/pdf/2204.06644.pdf)

DeepSpeed has been integrated with several different popular open-source DL frameworks such as:

@@ -90,7 +94,7 @@ DeepSpeed has been integrated with several different popular open-source DL fram
<img src="docs/assets/images/transformers-light.png#gh-light-mode-only" width="250px"><img src="docs/assets/images/transformers-dark.png#gh-dark-mode-only" width="250px"> | [Transformers with DeepSpeed](https://huggingface.co/docs/transformers/main/main_classes/deepspeed) |
| <img src="docs/assets/images/accelerate-light.png#gh-light-mode-only" width="250px"><img src="docs/assets/images/accelerate-dark.png#gh-dark-mode-only" width="250px"> | [Accelerate with DeepSpeed](https://huggingface.co/docs/accelerate/main/en/deepspeed) |
| <img src="docs/assets/images/lightning-light.svg#gh-light-mode-only" width="200px"><img src="docs/assets/images/lightning-dark.svg#gh-dark-mode-only" width="200px"> | [Lightning with DeepSpeed](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.strategies.DeepSpeedStrategy.html) |
| <img src="docs/assets/images/mosaicml.svg" width="200px"> | [MosaicML with DeepSpeed](https://docs.mosaicml.com/en/v0.8.0/trainer/using_the_trainer.html?highlight=deepspeed#deepspeed-integration) |
| <img src="docs/assets/images/mosaicml.svg" width="200px"> | [MosaicML with DeepSpeed](https://docs.mosaicml.com/en/latest/trainer/using_the_trainer.html?highlight=deepspeed#deepspeed-integration) |

---

2 changes: 1 addition & 1 deletion csrc/lamb/fused_lamb_cuda_kernel.cu
@@ -105,7 +105,7 @@ __device__ void reduce_block_in_shared_memory(T* s_a, T* s_b, T* g_a, T* g_b)

cg::sync(cta);

#if (__CUDA_ARCH__ >= 300)
#if (__CUDA_ARCH__ >= 300) || (defined(__HIP_PLATFORM_HCC__) && HIP_VERSION >= 502)
if (tid < 32) {
cg::coalesced_group active = cg::coalesced_threads();

1 change: 1 addition & 0 deletions deepspeed/autotuning/autotuner.py
@@ -4,6 +4,7 @@
import time
import datetime
import math
import hjson

from ..runtime.config_utils import dict_raise_error_on_duplicate_keys
from ..runtime.constants import *
8 changes: 5 additions & 3 deletions deepspeed/module_inject/replace_module.py
@@ -695,7 +695,8 @@ def _replace(child, name, conv_linear_layer):
bias_data = None if child.bias is None else mp_replace.copy(
new_bias,
child.bias.data).to(torch.cuda.current_device())
return LinearLayer(data.to(torch.cuda.current_device()), bias_data)
return LinearLayer(weight=data.to(torch.cuda.current_device()),
bias=bias_data)

def _slice_embedding(child, name, conv_linear_layer):
mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
@@ -807,8 +808,9 @@ def replace_fn(child, _policy, layer_id=0):
assert world_size >= ckpt_mp_size,\
"Currently, merging checkpoints is not supported (when world_size is smaller than #checkpoints)!"
checkpoint_stride = world_size // ckpt_mp_size
pbar = tqdm.tqdm(total=num_checkpoints,
desc=f"Loading {num_checkpoints} checkpoint shards")
if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0:
pbar = tqdm.tqdm(total=num_checkpoints,
desc=f"Loading {num_checkpoints} checkpoint shards")
for i in range(num_checkpoints):
if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0:
pbar.update(1)
4 changes: 4 additions & 0 deletions deepspeed/ops/transformer/inference/transformer_inference.py
@@ -821,6 +821,10 @@ def forward(self,
past_key_value=None):
get_present = (get_present or get_key_value or use_cache)
input_mask = input_mask if attention_mask is None else attention_mask

# We set the prev key/value to None when there is a prompt
if input.shape[1] > 1:
self.layer_past = None
layer_past = layer_past if layer_past is not None else self.layer_past
head_mask = layer_head_mask if layer_head_mask is not None else head_mask
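
The guard added above relies on the convention that a multi-token input is a fresh prompt, while single-token inputs are incremental decoding steps that should reuse the cache. A standalone sketch of that policy (illustrative class, not DeepSpeed's internal API):

```python
# Illustrative sketch of the cache-reset policy: sequence length > 1 means a new
# prompt, so cached key/values are discarded; length 1 reuses the cache.
class KVCachePolicy:
    def __init__(self):
        self.layer_past = None  # (key, value) tensors from earlier steps

    def select_past(self, input_ids, layer_past=None):
        if input_ids.shape[1] > 1:   # new prompt: drop any stale cache
            self.layer_past = None
        return layer_past if layer_past is not None else self.layer_past
```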

20 changes: 12 additions & 8 deletions deepspeed/runtime/config.py
@@ -794,7 +794,8 @@ def __init__(self, config: Union[str, dict], mpu=None):
self._param_dict[TRAIN_MICRO_BATCH_SIZE_PER_GPU] = micro_batch_size
self._param_dict[GRADIENT_ACCUMULATION_STEPS] = gradient_accu_steps

self._initialize_params(self._param_dict)
# Pass a copy so that user json is unmodified, e.g. for logging
self._initialize_params(copy.copy(self._param_dict))
self._configure_train_batch_size()
self._do_sanity_check()

@@ -971,13 +972,7 @@ def _do_sanity_check(self):

self._do_warning_check()

def print(self, name):
logger.info("{}:".format(name))
for arg in sorted(vars(self)):
if arg != "_param_dict":
dots = "." * (29 - len(arg))
logger.info(" {} {} {}".format(arg, dots, getattr(self, arg)))

def print_user_config(self):
logger.info(" json = {}".format(
json.dumps(
self._param_dict,
@@ -988,6 +983,15 @@ def print(self, name):
":"),
)))

def print(self, name):
logger.info("{}:".format(name))
for arg in sorted(vars(self)):
if arg != "_param_dict":
dots = "." * (29 - len(arg))
logger.info(" {} {} {}".format(arg, dots, getattr(self, arg)))

self.print_user_config()

def _do_error_check(self):
assert (
self.train_micro_batch_size_per_gpu
10 changes: 7 additions & 3 deletions deepspeed/runtime/engine.py
@@ -2292,9 +2292,6 @@ def sparse_allreduce_bucket(self, bucket, dp_group):
return sparse_list

def sparse_allreduce(self, sparse, dp_group):
# Pre-divide for fp16 stability
sparse.values.mul_(1.0 / dist.get_world_size(group=dp_group))

original_data_type = sparse.values.dtype
if self.communication_data_type != sparse.values.dtype:
if self.communication_data_type in (torch.float16, torch.bfloat16):
@@ -2306,6 +2303,13 @@ def sparse_allreduce(self, sparse, dp_group):
indices = sparse.indices
values = sparse.values

if self.postscale_gradients():
if self.gradient_average:
values.mul_(self.gradient_predivide_factor() /
dist.get_world_size(group=dp_group))
else:
values.mul_(1. / dist.get_world_size(group=dp_group))

indices_device_list = self.sparse_all_gather(indices, dp_group)
values_device_list = self.sparse_all_gather(values, dp_group)
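
For reference, the averaging that the new `postscale_gradients` branch performs can be sketched in isolation as follows; the helper name and defaults are illustrative, and an initialized `torch.distributed` process group is assumed:

```python
# Illustrative helper (not a DeepSpeed API): scale each rank's local sparse-gradient
# values so that the values gathered across the data-parallel group form an average
# rather than a sum. Assumes torch.distributed is already initialized.
import torch
import torch.distributed as dist

def prescale_for_average(values: torch.Tensor,
                         group=None,
                         predivide_factor: float = 1.0) -> torch.Tensor:
    world_size = dist.get_world_size(group=group)
    values.mul_(predivide_factor / world_size)  # in-place, before the all-gather
    return values
```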

8 changes: 3 additions & 5 deletions deepspeed/runtime/zero/stage3.py
@@ -1384,18 +1384,16 @@ def set_none_gradients_to_zero(self, i, partition_id):

######################Reduction Related Methods##############################

def allreduce_bucket(self,
bucket,
communication_data_type=torch.float16,
rank=None,
log=None):
def allreduce_bucket(self, bucket, rank=None, log=None):
rank = None
tensor = self.flatten(bucket)

tensor_to_allreduce = tensor

if pg_correctness_test:
communication_data_type = torch.float32
else:
communication_data_type = self.communication_data_type

if communication_data_type != tensor.dtype:
tensor_to_allreduce = tensor.to(communication_data_type)
17 changes: 11 additions & 6 deletions deepspeed/runtime/zero/stage_1_and_2.py
@@ -966,9 +966,13 @@ def average_tensor(self, tensor):
if not self.ipg_bucket_has_moe_params:
tensor.div_(dist.get_world_size(group=self.dp_process_group))

tensor_to_reduce = tensor
if self.communication_data_type != tensor.dtype:
tensor_to_reduce = tensor.to(self.communication_data_type)

async_handles = []
for i, (dst, bucket_offset, numel) in enumerate(rank_and_offsets):
grad_slice = tensor.narrow(0, int(bucket_offset), int(numel))
grad_slice = tensor_to_reduce.narrow(0, int(bucket_offset), int(numel))
# if dist.get_rank() == 0:
# print(f"Rank {dist.get_rank()} rank offset id {i} real dp size {dist.get_world_size(group=real_dp_process_group[i])} and dst: {dst}")
# dist.barrier()
@@ -983,6 +987,9 @@ def average_tensor(self, tensor):
for handle in async_handles:
handle.wait()

if self.communication_data_type != tensor.dtype:
tensor.copy_(tensor_to_reduce)
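
The pattern added above — cast the flattened gradient to the configured communication dtype, reduce, then copy the result back — can be sketched on its own; the helper below is illustrative and assumes an initialized process group:

```python
# Illustrative sketch (not the engine's method): all-reduce in a lower-precision
# communication dtype and copy the reduced values back into the original tensor.
import torch
import torch.distributed as dist

def allreduce_in_comm_dtype(tensor: torch.Tensor,
                            comm_dtype: torch.dtype = torch.float16,
                            group=None) -> torch.Tensor:
    tensor_to_reduce = tensor
    if comm_dtype != tensor.dtype:
        tensor_to_reduce = tensor.to(comm_dtype)

    dist.all_reduce(tensor_to_reduce, group=group)

    if tensor_to_reduce is not tensor:
        tensor.copy_(tensor_to_reduce)  # write results back in the original dtype
    return tensor
```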

##############################################################################
############################# CPU Offload Methods#############################
##############################################################################
@@ -1337,18 +1344,16 @@ def set_none_gradients_to_zero(self, i, partition_id):
param.grad = torch.zero_like(param)

######################Reduction Related Methods##############################
def allreduce_bucket(self,
bucket,
communication_data_type=torch.float16,
rank=None,
log=None):
def allreduce_bucket(self, bucket, rank=None, log=None):
rank = None
tensor = self.flatten(bucket)

tensor_to_allreduce = tensor

if pg_correctness_test:
communication_data_type = torch.float32
else:
communication_data_type = self.communication_data_type

if communication_data_type != tensor.dtype:
tensor_to_allreduce = tensor.to(communication_data_type)
8 changes: 4 additions & 4 deletions docs/_pages/config-json.md
@@ -427,7 +427,7 @@ Enabling and configuring ZeRO memory optimizations

| Description | Default |
| ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| Stage 2 optimization for CPU offloading that parallelizes gradient copying to CPU memory among ranks by fine-grained gradient partitioning. Performance benefit grows with gradient accumulation steps (more copying between optimizer steps) or GPU count (increased parallelism). | `False` |
| Stage 1 and 2 optimization for CPU offloading that parallelizes gradient copying to CPU memory among ranks by fine-grained gradient partitioning. Performance benefit grows with gradient accumulation steps (more copying between optimizer steps) or GPU count (increased parallelism). | `False` |

***offload_param***: [dictionary]

@@ -439,7 +439,7 @@ Enabling and configuring ZeRO memory optimizations

| Description | Default |
| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ------- |
| Enable offloading of optimizer state to CPU or NVMe, and optimizer computation to CPU. This frees up GPU memory for larger models or batch sizes. Valid only with stage 2 and 3. See [here](#optimizer-offloading) for more details. | `False` |
| Enable offloading of optimizer state to CPU or NVMe, and optimizer computation to CPU. This frees up GPU memory for larger models or batch sizes. Valid for ZeRO stage 1, 2, 3. See [here](#optimizer-offloading) for more details. | `False` |
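
For context, `offload_optimizer` sits inside the `zero_optimization` block; a minimal sketch of such a configuration as a Python dict (values are illustrative) that can be passed to `deepspeed.initialize(..., config=ds_config)`:

```python
# Minimal sketch of a config enabling ZeRO optimizer offload to CPU (illustrative values).
ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": True
        }
    }
}
```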

***stage3_max_live_parameters***: [integer]

@@ -481,7 +481,7 @@ Enabling and configuring ZeRO memory optimizations

| Description | Default |
| ------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| Enable offloading of optimizer memory and computation to CPU. This frees up GPU memory for larger models or batch sizes. Valid only with stage 2. | `False` |
| Enable offloading of optimizer memory and computation to CPU. This frees up GPU memory for larger models or batch sizes. Valid with stage 1 and 2. | `False` |


### Parameter offloading
@@ -536,7 +536,7 @@ Note that if the value of "device" is not specified or not supported, an asserti
| Number of parameter elements to maintain in CPU memory when offloading to NVMe is enabled. | 1e9 |

### Optimizer offloading
Enabling and configuring ZeRO optimization of offloading optimizer computation to CPU and state to CPU/NVMe. CPU offloading is available with ZeRO stage 2 or 3. NVMe offloading is available only with ZeRO stage 3.
Enabling and configuring ZeRO optimization of offloading optimizer computation to CPU and state to CPU/NVMe. CPU offloading is available with ZeRO stage 1, 2, 3. NVMe offloading is available only with ZeRO stage 3.
Note that if the value of "device" is not specified or not supported, an assertion will be triggered.
```json
"offload_optimizer": {
