Skip to content

Commit

Permalink
DeepSpeed4Science (#569)
Browse files Browse the repository at this point in the history
* Integrating evoformer attention

* add cutlass version check

* Updaate error message

* add benchmark

* Update

* Update evoformer_attn.py

* Update run_evoformer_test.py

* Update evoformer_attn.py

* Update run_evoformer_test.py

* support more GPU archs

* add copyright

* add tests

* Fix bugs

* Update benchmark

* update

* Fix nvcc macro

* clean code

* fix formatting

* fix yaml import

* skip unit test when not compatible

* fix yaml requirement

* revert changes

* update tutorial

* update

* fix formatting

* fix format

* skip evoformer attn in pre-compile-ops

* revert changes

* update tutorial

* fix cutlass check

* update tutorial

* refactor tutorial

* revise

* Updated the Megatron-DS section (#565)

* Updated the Megatron-DS section

* minor fix

* minor fix

* minor fix

* separate evoformer tutorial

* Revised the ds4science landing page (#566)

* Updated the Megatron-DS section

* minor fix

* minor fix

* minor fix

* Revised the landing page

* Revised the landing page

* Removing unused file

* fix links image position

* modify main page

* fix doc

---------

Co-authored-by: Shiyang Chen <csycfl@gmail.com>
Co-authored-by: Minjia Zhang <33713995+minjiaz@users.noreply.github.com>
  • Loading branch information
3 people committed Sep 18, 2023
1 parent 00dfab9 commit a5552a6
Show file tree
Hide file tree
Showing 42 changed files with 15,421 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/nv-pre-compile-ops.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
#python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
- name: Compile DeepSpeed Ops
run: |
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 pip3 install .
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install .
- name: DS Report
run: |
ds_report
10 changes: 7 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
## Latest News
<b> <span style="color:orange" > DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat)</span>.</b>

* [2023/09] Announcing the DeepSpeed4Science Initiative: Enabling large-scale scientific discovery through sophisticated AI system technologies [[DeepSpeed4Science website](https://deepspeed4science.ai/)] [[Tutorials](https://www.deepspeed.ai/deepspeed4science/)]
* [2023/08] [DeepSpeed ZeRO-Inference: 20X faster inference through weight quantization and KV cache offloading](https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/zero_inference/README.md)
* [2023/08] [DeepSpeed-Chat: Llama/Llama-2 system support, efficiency boost, and training stability improvements](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/ds-chat-release-8-31/README.md)
* [2023/08] [DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ulysses)
* [2023/06] [ZeRO++: A leap in speed for LLM and chat model training with 4X less communication](https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/)[[English](https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/zeropp/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/zeropp/japanese/README.md)]
* [2023/04] 🚀 [DeepSpeed Chat: Easy, Fast and Affordable RLHF Training of ChatGPT-like Models at All Scales](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/japanese/README.md)]🚀

---

Expand All @@ -35,9 +35,9 @@

---

# DeepSpeed's three innovation pillars
# DeepSpeed's four innovation pillars

<img src="docs/assets/images/3pillars.png" width="800px">
<img src="docs/assets/images/DeepSpeed-pillars.png" width="800px">


## DeepSpeed-Training
Expand All @@ -53,6 +53,10 @@ DeepSpeed brings together innovations in parallelism technology such as tensor,

To further increase the inference efficiency, DeepSpeed offers easy-to-use and flexible-to-compose compression techniques for researchers and practitioners to compress their models while delivering faster speed, smaller model size, and significantly reduced compression cost. Moreover, SoTA innovations on compression like ZeroQuant and XTC are included under the compression pillar. Learn more: [DeepSpeed-Compression](https://www.deepspeed.ai/compression)

## DeepSpeed4Science

In line with Microsoft's mission to solve humanity's most pressing challenges, the DeepSpeed team at Microsoft is responding to this opportunity by launching a new initiative called *DeepSpeed4Science*, aiming to build unique capabilities through AI system technology innovations to help domain experts to unlock today's biggest science mysteries. Learn more: [DeepSpeed4Science website](https://deepspeed4science.ai/) and [tutorials](https://www.deepspeed.ai/deepspeed4science/)

---

# DeepSpeed Software Suite
Expand Down
62 changes: 62 additions & 0 deletions csrc/deepspeed4science/evoformer_attn/attention.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <torch/extension.h>

void attention_impl(torch::Tensor& q,
torch::Tensor& k,
torch::Tensor& v,
torch::Tensor& bias1,
torch::Tensor& bias2,
torch::Tensor& o,
torch::Tensor& lse);
void attention(torch::Tensor& q,
torch::Tensor& k,
torch::Tensor& v,
torch::Tensor& bias1,
torch::Tensor& bias2,
torch::Tensor& o,
torch::Tensor& lse)
{
attention_impl(q, k, v, bias1, bias2, o, lse);
}

void attention_back_impl(torch::Tensor& go,
torch::Tensor& q,
torch::Tensor& k,
torch::Tensor& v,
torch::Tensor& o,
torch::Tensor& lse,
torch::Tensor& delta,
torch::Tensor& bias1,
torch::Tensor& bias2,
torch::Tensor& gq,
torch::Tensor& gk,
torch::Tensor& gv,
torch::Tensor& gb1,
torch::Tensor& gb2);
void attention_bwd(torch::Tensor& go,
torch::Tensor& q,
torch::Tensor& k,
torch::Tensor& v,
torch::Tensor& o,
torch::Tensor& lse,
torch::Tensor& delta,
torch::Tensor& bias1,
torch::Tensor& bias2,
torch::Tensor& gq,
torch::Tensor& gk,
torch::Tensor& gv,
torch::Tensor& gb1,
torch::Tensor& gb2)
{
attention_back_impl(go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("attention", &attention, "");
m.def("attention_bwd", &attention_bwd, "");
}
160 changes: 160 additions & 0 deletions csrc/deepspeed4science/evoformer_attn/attention.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "gemm_kernel_utils.h"
#include "kernel_forward.h"
#include "transform/bias_broadcast.h"

template <typename arch,
typename scalar_t,
typename torch_scalar_t,
template <typename, typename, typename>
class Broadcast1_,
template <typename, typename, typename>
class Broadcast2_>
typename std::enable_if<!CheckArch<arch, scalar_t>::value>::type attention_impl_template(
torch::Tensor& q,
torch::Tensor& k,
torch::Tensor& v,
torch::Tensor& bias1,
torch::Tensor& bias2,
torch::Tensor& o,
float* lse_ptr)
{
EVOFORMER_CHECK(false, "Unsupported GPU and data type combination")
}

template <typename arch,
typename scalar_t,
typename torch_scalar_t,
template <typename, typename, typename>
class Broadcast1_,
template <typename, typename, typename>
class Broadcast2_>
typename std::enable_if<CheckArch<arch, scalar_t>::value>::type attention_impl_template(
torch::Tensor& q,
torch::Tensor& k,
torch::Tensor& v,
torch::Tensor& bias1,
torch::Tensor& bias2,
torch::Tensor& o,
float* lse_ptr)
{
// Attention definition goes here, replaced with BroadcastType1 and
// BroadcastType2
using Attention = AttentionKernel<scalar_t, /* scalar_t */
arch, /* ArchTag */
true, /* Memory is aligned */
64,
64,
true,
true, /* Supports bias */
Broadcast1_,
Broadcast2_>;

static_assert(!Attention::kNeedsOutputAccumulatorBuffer,
"This test does not support output accumulator buffer");
int head_size = q.size(-1);
int head_number = q.size(-2);
int seq_length = q.size(-3);
auto q_view = q.view({-1, seq_length, head_number, head_size});
auto k_view = k.view({-1, seq_length, head_number, head_size});
auto v_view = v.view({-1, seq_length, head_number, head_size});
auto o_view = o.view({-1, seq_length, head_number, head_size});
int batch_size = q_view.size(0);
auto q_ptr = reinterpret_cast<scalar_t*>(q.data_ptr<torch_scalar_t>());
auto k_ptr = reinterpret_cast<scalar_t*>(k.data_ptr<torch_scalar_t>());
auto v_ptr = reinterpret_cast<scalar_t*>(v.data_ptr<torch_scalar_t>());
auto o_ptr = reinterpret_cast<scalar_t*>(o.data_ptr<torch_scalar_t>());

auto bias1_ptr = reinterpret_cast<scalar_t*>(bias1.data_ptr<torch_scalar_t>());
auto bias2_ptr = reinterpret_cast<scalar_t*>(bias2.data_ptr<torch_scalar_t>());

typename Attention::Params p;
{ // set parameters
p.query_ptr = q_ptr;
p.key_ptr = k_ptr;
p.value_ptr = v_ptr;
p.logsumexp_ptr = lse_ptr; // Only needed for bw
p.output_accum_ptr = nullptr;
p.output_ptr = o_ptr;
p.scale = 1.0f / sqrt(float(head_size));

p.bias1_ptr = bias1_ptr;
p.bias2_ptr = bias2_ptr;
p.B = q.size(0);
p.N = q.size(1);

p.num_heads = head_number;
p.num_batches = batch_size;
p.head_dim = head_size;
p.head_dim_value = head_size;
p.num_queries = seq_length;
p.num_keys = seq_length;

// All tensors are in BMHK shapes
p.q_strideH = q_view.stride(-2);
p.k_strideH = k_view.stride(-2);
p.v_strideH = v_view.stride(-2);
p.q_strideM = q_view.stride(-3);
p.k_strideM = k_view.stride(-3);
p.v_strideM = v_view.stride(-3);
p.o_strideM = o_view.stride(-3);
p.q_strideB = q_view.stride(-4);
p.k_strideB = k_view.stride(-4);
p.v_strideB = v_view.stride(-4);
}

constexpr auto kernel_fn = attention_kernel_batched_impl<Attention>;
int smem_bytes = sizeof(typename Attention::SharedStorage);
if (smem_bytes > 0xc000) {
cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
}
if (!Attention::check_supported(p)) { throw std::runtime_error("Parameters not supported"); }
kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
}

#define CODE(scalar_t, torch_scalar_t) \
do { \
if (bias1.size(0) == 0 && bias2.size(0) == 0) { \
attention_impl_template<ArchTag, \
scalar_t, \
torch_scalar_t, \
BroadcastNoLoad, \
BroadcastNoLoad>(q, k, v, bias1, bias2, o, lse_ptr); \
} else if (bias1.size(0) == 0) { \
attention_impl_template<ArchTag, \
scalar_t, \
torch_scalar_t, \
BroadcastNoLoad, \
BroadcastB>(q, k, v, bias1, bias2, o, lse_ptr); \
} else if (bias2.size(0) == 0) { \
attention_impl_template<ArchTag, \
scalar_t, \
torch_scalar_t, \
BroadcastA, \
BroadcastNoLoad>(q, k, v, bias1, bias2, o, lse_ptr); \
} else { \
attention_impl_template<ArchTag, scalar_t, torch_scalar_t, BroadcastA, BroadcastB>( \
q, k, v, bias1, bias2, o, lse_ptr); \
} \
} while (0)

// Function to select and call the correct template based on biases sizes
void attention_impl(torch::Tensor& q,
torch::Tensor& k,
torch::Tensor& v,
torch::Tensor& bias1,
torch::Tensor& bias2,
torch::Tensor& o,
torch::Tensor& lse)
{
auto lse_ptr = lse.size(0) == 0 ? nullptr : reinterpret_cast<float*>(lse.data_ptr<float>());
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
DISPATCH_ARCHTAG(prop->major * 10 + prop->minor,
DISPATCH_TYPES(q, ([&]() { CODE(scalar_t, torch_scalar_t); })));
}
Loading

0 comments on commit a5552a6

Please sign in to comment.