ZeRO++ #3784

Merged: 21 commits from HeyangQin/staging-zero-pp-v1 into master on Jun 23, 2023.
Commits (changes from all 21 commits)
1e7a41c  ZeRO++ clean release (#526) (samadejacobs, Jun 8, 2023)
440d4fb  Merge remote-tracking branch 'origin' into HeyangQin/staging-zero-pp-v1 (HeyangQin, Jun 21, 2023)
b011295  Update zeropp.md (GuanhuaWang, Jun 21, 2023)
97726c9  catch an edge case where MiCS and ZeRO++ both enable (HeyangQin, Jun 21, 2023)
3a37eca  Merge branch 'HeyangQin/staging-zero-pp-v1' of https://github.com/mic… (HeyangQin, Jun 21, 2023)
9f70a56  fix format (HeyangQin, Jun 21, 2023)
7c764d8  fix assert (HeyangQin, Jun 22, 2023)
36f082f  Merge branch 'master' into HeyangQin/staging-zero-pp-v1 (HeyangQin, Jun 22, 2023)
9d0d5d8  rearrange positional param (HeyangQin, Jun 22, 2023)
5a300bd  Merge branch 'HeyangQin/staging-zero-pp-v1' of https://github.com/mic… (HeyangQin, Jun 22, 2023)
1111e10  Merge branch 'master' into HeyangQin/staging-zero-pp-v1 (HeyangQin, Jun 22, 2023)
520d1d4  Merge branch 'master' into HeyangQin/staging-zero-pp-v1 (loadams, Jun 22, 2023)
b77fdad  fix images and tags (jeffra, Jun 22, 2023)
559c730  update news items (jeffra, Jun 22, 2023)
13c52be  update link for msr blog (jeffra, Jun 22, 2023)
489c409  update links (HeyangQin, Jun 22, 2023)
29f26af  add default values for AllGatherCoalescedHandle (HeyangQin, Jun 22, 2023)
8a31056  Merge branch 'master' into HeyangQin/staging-zero-pp-v1 (HeyangQin, Jun 22, 2023)
0140ae1  Merge branch 'master' into HeyangQin/staging-zero-pp-v1 (HeyangQin, Jun 23, 2023)
60806d4  Merge branch 'master' into HeyangQin/staging-zero-pp-v1 (conglongli, Jun 23, 2023)
a5bf0ba  Merge branch 'master' into HeyangQin/staging-zero-pp-v1 (jeffra, Jun 23, 2023)
README.md (7 changes: 2 additions & 5 deletions)
@@ -15,14 +15,11 @@
## Latest News
<b> <span style="color:orange" > DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat)</span>.</b>

-* ***[2023/04] 🚀 [DeepSpeed Chat: Easy, Fast and Affordable RLHF Training of ChatGPT-like Models at All Scales](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat)*** [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/japanese/README.md)]🚀
+* [2023/06] [ZeRO++: A leap in speed for LLM and chat model training with 4X less communication](https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/)
+* [2023/04] 🚀 [DeepSpeed Chat: Easy, Fast and Affordable RLHF Training of ChatGPT-like Models at All Scales](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/japanese/README.md)]🚀
* [2023/03] [Scaling Large-Scale Generative Mixture-of-Expert Multimodal Model With VL-MoE](https://www.deepspeed.ai/2023/03/30/multi-modal.html)
* [2023/02] [Automatic Tensor Parallelism: Enables tensor parallelism by default without an injection policy](https://www.deepspeed.ai/tutorials/automatic-tensor-parallelism/)
* [2022/12] [DeepSpeed Data Efficiency: A composable library that makes better use of data, increases training efficiency, and improves model quality](https://www.deepspeed.ai/2022/12/11/data-efficiency.html)
-* [2022/11] [Stable Diffusion Image Generation under 1 second w. DeepSpeed MII](https://github.com/microsoft/DeepSpeed-MII/tree/main/examples/benchmark/txt2img)
-* [2022/10] [DeepSpeed-MII: instant speedup on 24,000+ open-source DL models with up to 40x cheaper inference](https://www.deepspeed.ai/2022/10/10/mii.html)
-* [2022/09] [ZeRO-Inference: Democratizing massive model inference](https://www.deepspeed.ai/2022/09/09/zero-inference.html)
-* [2022/07] [Azure and DeepSpeed empower easy-to-use and high-performance model training](https://azure.microsoft.com/en-us/blog/azure-empowers-easytouse-highperformance-and-hyperscale-model-training-using-deepspeed/)

---

csrc/includes/quantization.h (26 changes: 26 additions & 0 deletions)
@@ -40,6 +40,32 @@ void launch_dequantize_kernel(T* dequant_data,
int total_elems,
cudaStream_t stream);

void launch_swizzled_quant(int8_t* q_data,
float* q_scales,
const __half* input_data,
int num_bits,
quantize::Type q_type,
int groups,
int elems_per_group,
int pipelining,
int nodes,
int devices_per_node,
cudaStream_t stream);

void launch_dequant_reduce(int8_t* reduced_data,
float* reduced_scales,
const int8_t* input_data,
const float* input_scales,
int num_gpus,
int num_bits,
quantize::Type quant_type,
int out_groups,
int elems_per_out_group,
int elems_per_in_tensor,
int groups_per_in_tensor,
int elems_per_in_group,
cudaStream_t stream);

template <typename T>
void launch_fake_quantize_kernel(T* vals,
int total_count,
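To make the new launcher arguments concrete, here is a small, hypothetical host-side sketch that is not part of this PR: it derives the parameters launch_swizzled_quant expects for an 8-bit quantization of a half-precision buffer and calls it with the signature declared above. The buffer size, group size, cluster topology, and the quantize::Type::Symmetric enum value are assumptions made for illustration.

// Hypothetical host-side setup for launch_swizzled_quant (illustrative sketch, not PR code).
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include "quantization.h"  // assumed include path for the declarations above

void swizzled_quant_example(const __half* d_input, cudaStream_t stream)
{
    const int total_elems = 1 << 20;                   // 1M half values (assumption)
    const int elems_per_group = 2048;                  // quantization group size (assumption)
    const int groups = total_elems / elems_per_group;  // 512 groups
    const int num_bits = 8;                            // 8-bit -> one int8 byte per element

    // Topology: the swizzle reorders quantized data so that elements headed to the
    // same node are contiguous, which is what the hierarchical collectives expect.
    const int nodes = 4;             // assumption
    const int devices_per_node = 8;  // assumption
    const int pipelining = 1;        // no compute/communication overlap in this sketch

    int8_t* d_quantized = nullptr;  // total_elems bytes for the 8-bit output
    float* d_scales = nullptr;      // one fp32 scale per group (symmetric quantization)
    cudaMalloc(&d_quantized, total_elems * sizeof(int8_t));
    cudaMalloc(&d_scales, groups * sizeof(float));

    launch_swizzled_quant(d_quantized, d_scales, d_input, num_bits,
                          quantize::Type::Symmetric,  // enum value name assumed
                          groups, elems_per_group, pipelining,
                          nodes, devices_per_node, stream);
}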
csrc/quantization/pt_binding.cpp (91 changes: 91 additions & 0 deletions)
@@ -136,6 +136,95 @@ at::Tensor dequantize(at::Tensor& quantized_data,
return output;
}

std::vector<at::Tensor> ds_swizzle_quant(at::Tensor& input_vals,
int groups,
int num_bits,
quantize::Type quant_type,
int pipeline_size,
int nodes,
int devices_per_node)
{
auto scales_options = at::TensorOptions()
.dtype(at::kFloat)
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
const int scales_elems = (quantize::requires_offset(quant_type)) ? 2 : 1;
auto scales = torch::empty({groups, scales_elems}, scales_options);

auto output_options = at::TensorOptions()
.dtype(at::kChar)
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);

const int quantization_scalar = 8 / num_bits;
const int compressed_vals = at::numel(input_vals) / quantization_scalar;

auto output = torch::empty({compressed_vals}, output_options);
const int elems_per_group = at::numel(input_vals) / groups;

launch_swizzled_quant((int8_t*)output.data_ptr(),
(float*)scales.data_ptr(),
(__half*)input_vals.data_ptr(),
num_bits,
quant_type,
groups,
elems_per_group,
pipeline_size,
nodes,
devices_per_node,
at::cuda::getCurrentCUDAStream());

return {output, scales};
}

std::vector<at::Tensor> quantized_reduction(at::Tensor& input_vals,
at::Tensor& input_scales,
int in_groups,
int out_groups,
int num_bits,
quantize::Type quant_type)
{
auto scales_options = at::TensorOptions()
.dtype(at::kFloat)
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
const int scales_elems = (quantize::requires_offset(quant_type)) ? 2 : 1;
auto scales = torch::empty({out_groups, scales_elems}, scales_options);

auto output_options = at::TensorOptions()
.dtype(at::kChar)
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);

std::vector<long int> sz(input_vals.sizes().begin(), input_vals.sizes().end());
const int gpu_per_node = 16; // hard-coded GPUs per node; machine-dependent (could be derived from in_groups / out_groups)
sz[sz.size() - 1] = sz.back() / gpu_per_node; // shrink the last dimension by the number of GPUs per node
const int elems_per_in_tensor = at::numel(input_vals) / gpu_per_node;
auto output = torch::empty(sz, output_options);

const int elems_per_in_group = elems_per_in_tensor / (in_groups / gpu_per_node);
const int elems_per_out_group = elems_per_in_tensor / out_groups;

launch_dequant_reduce((int8_t*)output.data_ptr(),
(float*)scales.data_ptr(),
(const int8_t*)input_vals.data_ptr(),
(const float*)input_scales.data_ptr(),
gpu_per_node,
num_bits,
quant_type,
out_groups,
elems_per_out_group,
elems_per_in_tensor,
in_groups / gpu_per_node,
elems_per_in_group,
at::cuda::getCurrentCUDAStream());
return {output, scales};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("ds_quantize_fp32", &ds_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
@@ -158,4 +247,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("quantize", &quantize_kernel);
m.def("dequantize", &dequantize<__half>);
m.def("dequantize_fp32", &dequantize<float>);
m.def("swizzle_quant", &ds_swizzle_quant);
m.def("quantized_reduction", &quantized_reduction);
}
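
As a rough, end-to-end illustration of how the two new bindings compose, the sketch below (not part of this PR) swizzle-quantizes a half-precision tensor and feeds the int8 payload plus scales into quantized_reduction. The tensor shape, group counts, cluster topology, and the quantize::Type::Symmetric enum value are assumptions; note that quantized_reduction hard-codes 16 GPUs per node above, so in_groups and the last tensor dimension must be divisible by 16.

// Hypothetical usage sketch (illustrative only); assumes it is built into the same
// extension so the ds_swizzle_quant / quantized_reduction definitions are visible.
#include <torch/extension.h>
#include <vector>

std::vector<at::Tensor> example_quantized_reduce_path()
{
    // Fake gradient data: 16 ranks' worth of values, 8192 half elements per rank.
    auto grads = torch::randn({16, 8192},
                              at::TensorOptions().dtype(at::kHalf).device(at::kCUDA));

    const int num_bits = 8;      // 8-bit quantization
    const int groups = 256;      // 512 elements per quantization group (assumption)
    const int pipeline_size = 1; // no pipelining in this sketch
    const int nodes = 2;         // example topology (assumption)
    const int devices_per_node = 8;

    // Quantize to int8 and swizzle the layout for hierarchical communication.
    auto quantized = ds_swizzle_quant(grads, groups, num_bits, quantize::Type::Symmetric,
                                      pipeline_size, nodes, devices_per_node);
    at::Tensor q_data = quantized[0];   // int8 payload, numel(grads) bytes at 8 bits
    at::Tensor q_scales = quantized[1]; // one fp32 scale per group (symmetric)

    // Dequantize the 16 per-GPU chunks, reduce them, and re-quantize the result.
    const int in_groups = groups;
    const int out_groups = groups / 16;  // keeps the 512-element group size on the output
    return quantized_reduction(q_data, q_scales, in_groups, out_groups,
                               num_bits, quantize::Type::Symmetric);
}

In a real run the reduction input would be the quantized buffer gathered from the other ranks rather than the local swizzled output; the sketch only exercises the shapes and argument order implied by the bindings.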