Add support of OPT models #2205

Merged
17 commits, merged Aug 15, 2022
60 changes: 49 additions & 11 deletions csrc/transformer/inference/csrc/pt_binding.cpp
@@ -8,6 +8,11 @@

std::array<int, 3> gemm_algos = std::array<int, 3>({99, 99, 99});

// NOTE: This activation function type enum must always be kept in sync
// with its Python counterpart, otherwise the cast from the Python binding
// will be incorrect.
enum class ActivationFuncType { UNKNOWN = 0, GELU = 1, ReLU = 2 };

template <typename T>
at::Tensor ds_softmax(at::Tensor& attn_scores,
at::Tensor& attn_mask,
@@ -464,9 +469,9 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
1);

if (layer_id == num_layers - 1) Context::Instance().advance_tokens();
auto prev_key = torch::from_blob(workspace + offset, {bsz, all_tokens, hidden_dim}, options);
auto prev_key = torch::from_blob(workspace + offset, {bsz, heads, all_tokens, k}, options);
auto prev_value =
torch::from_blob(workspace + offset + value_offset, {bsz, all_tokens, hidden_dim}, options);
torch::from_blob(workspace + offset + value_offset, {bsz, heads, all_tokens, k}, options);
return {output, prev_key, prev_value};
}
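The reshaped prev_key and prev_value views above only change how the cached K/V workspace is exposed; the element count is identical because hidden_dim equals heads * k. A quick arithmetic sketch with illustrative sizes (not taken from any real model):

    bsz, heads, all_tokens, k = 2, 16, 8, 64
    hidden_dim = heads * k
    # old view: {bsz, all_tokens, hidden_dim}; new view: {bsz, heads, all_tokens, k}
    assert bsz * all_tokens * hidden_dim == bsz * heads * all_tokens * k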

@@ -486,6 +491,22 @@ at::Tensor ds_bias_gelu(at::Tensor& input, at::Tensor& bias)
return input_cont;
}

template <typename T>
at::Tensor ds_bias_relu(at::Tensor& input, at::Tensor& bias)
{
auto input_cont = input.contiguous();

int bsz = input_cont.size(0) * input_cont.size(1);
int intermediate_size = input_cont.size(2);

launch_bias_relu((T*)input_cont.data_ptr(),
(T*)bias.data_ptr(),
intermediate_size,
bsz,
Context::Instance().GetCurrentStream());
return input_cont;
}

template <typename T>
at::Tensor ds_bias_residual(at::Tensor& input, at::Tensor& residual, at::Tensor& bias)
{
@@ -840,7 +861,8 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
at::Tensor& beta,
const float epsilon,
bool preLayerNorm,
bool mlp_after_attn)
bool mlp_after_attn,
ActivationFuncType act_func_type)
{
int bsz = input.size(0) * input.size(1);
auto inp_norm = at::empty_like(input);
@@ -878,13 +900,24 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
launch_bias_gelu((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());

if (act_func_type == ActivationFuncType::GELU) {
launch_bias_gelu((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
} else if (act_func_type == ActivationFuncType::ReLU) {
launch_bias_relu((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
}

return inp_norm;
}

template <typename T>
std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
at::Tensor& residual,
@@ -895,7 +928,8 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
at::Tensor& beta,
const float epsilon,
bool preLayerNorm,
bool mlp_after_attn)
bool mlp_after_attn,
int activation_type)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
@@ -907,6 +941,7 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);

auto act_func_type = static_cast<ActivationFuncType>(activation_type);
auto res_add = mlp_unfused_cublas<T>(output,
mlp_after_attn ? input : residual,
residual,
@@ -917,7 +952,8 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
beta,
epsilon,
preLayerNorm,
mlp_after_attn);
mlp_after_attn,
act_func_type);

return {output, res_add};
}
@@ -1205,7 +1241,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
&ds_softmax_context1<__half>,
"DeepSpeed attention with fp32 (CUDA)");
m.def("bias_gelu_fp32", &ds_bias_gelu<float>, "DeepSpeed Gelu with fp32 (CUDA)");
m.def("bias_gelu_fp16", &ds_bias_gelu<__half>, "DeepSpeed Gelu with fp32 (CUDA)");
m.def("bias_gelu_fp16", &ds_bias_gelu<__half>, "DeepSpeed Gelu with fp16 (CUDA)");
m.def("bias_relu_fp32", &ds_bias_relu<float>, "DeepSpeed ReLU with fp32 (CUDA)");
m.def("bias_relu_fp16", &ds_bias_relu<__half>, "DeepSpeed ReLU with fp16 (CUDA)");
m.def("bias_residual_fp32",
&ds_bias_residual<float>,
"DeepSpeed residual-bias add with fp32 (CUDA)");
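The updated ds_mlp_gemm binding receives the activation as a plain int and casts it to ActivationFuncType, so callers on the Python side must pass values that match the C++ enum above. A minimal sketch of what that Python counterpart looks like; the diff does not show where it lives in the DeepSpeed tree, so the class is defined inline here as an assumption:

    from enum import IntEnum

    class ActivationFuncType(IntEnum):  # must mirror the enum in pt_binding.cpp
        UNKNOWN = 0
        GELU = 1
        ReLU = 2

    # Passing int(ActivationFuncType.ReLU) as the final argument of ds_mlp_gemm
    # selects the fused bias+ReLU path used for OPT; GELU keeps the old behaviour.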
95 changes: 95 additions & 0 deletions csrc/transformer/inference/csrc/relu.cu
@@ -0,0 +1,95 @@
#include "custom_cuda_layers.h"

#define MAX_CAP 4
#define MAX_SEQ 2048

inline __device__ float relu(const float x) { return x < 0 ? 0 : x; }

__global__ void fused_bias_relu(float* input,
const float* bias,
int total_count,
int intermediate_size)
{
float4* input_cast = reinterpret_cast<float4*>(input);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;

if (offset < total_count) {
float4 data = input_cast[offset];
float4 bias_data = bias_cast[offset % intermediate_size];

data.x += bias_data.x;
data.y += bias_data.y;
data.z += bias_data.z;
data.w += bias_data.w;

data.x = relu(data.x);
data.y = relu(data.y);
data.z = relu(data.z);
data.w = relu(data.w);

input_cast[offset] = data;
}
}

__global__ void fused_bias_relu(__half* input,
const __half* bias,
int total_count,
int intermediate_size)
{
#ifdef HALF_PRECISION_AVAILABLE

float2* input_cast = reinterpret_cast<float2*>(input);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);

int offset = blockIdx.x * blockDim.x + threadIdx.x;

if (offset < total_count) {
float2 vals_vec = input_cast[offset];
float2 bias_vec = bias_cast[offset % intermediate_size];

__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);

float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);

float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);

low_data.x += low_bias.x;
low_data.y += low_bias.y;
high_data.x += high_bias.x;
high_data.y += high_bias.y;

low_data.x = relu(low_data.x);
low_data.y = relu(low_data.y);
high_data.x = relu(high_data.x);
high_data.y = relu(high_data.y);

vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);

input_cast[offset] = vals_vec;
}
#endif
}

template <typename T>
void launch_bias_relu(T* input,
const T* bias,
int intermediate_size,
int batch_size,
cudaStream_t stream)
{
int total_count = batch_size * (intermediate_size / 4);
int threads = 1024; // intermediate_size / iterations / 4;
dim3 block_dims(threads);
dim3 grid_dims(((total_count - 1) / 1024 + 1)); // (batch_size);

fused_bias_relu<<<grid_dims, block_dims, 0, stream>>>(
input, bias, total_count, intermediate_size / 4);
}

template void launch_bias_relu<float>(float*, const float*, int, int, cudaStream_t);
template void launch_bias_relu<__half>(__half*, const __half*, int, int, cudaStream_t);
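Both kernels process four elements per thread (one float4 for fp32, two packed __half2 for fp16), which is why launch_bias_relu divides intermediate_size by 4 before computing total_count. A NumPy reference of the fused computation, useful as a sanity check against the CUDA path; names and sizes here are illustrative:

    import numpy as np

    def bias_relu_reference(x, bias):
        # x: (batch, intermediate_size), bias: (intermediate_size,)
        return np.maximum(x + bias, 0.0)

    # Launch geometry mirrored from launch_bias_relu: with batch_size = 8 and
    # intermediate_size = 4096, total_count = 8 * 4096 // 4 = 8192 vectorized
    # elements, so the grid is (8192 - 1) // 1024 + 1 = 8 blocks of 1024 threads.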
9 changes: 9 additions & 0 deletions csrc/transformer/inference/includes/custom_cuda_layers.h
@@ -50,6 +50,15 @@ void launch_bias_gelu(T* input,
int intermediate_size,
int batch_size,
cudaStream_t stream);

// Fused bias add with relu activation
template <typename T>
void launch_bias_relu(T* input,
const T* bias,
int intermediate_size,
int batch_size,
cudaStream_t stream);

template <typename T>
void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size, cudaStream_t stream);

4 changes: 4 additions & 0 deletions deepspeed/inference/engine.py
@@ -19,6 +19,7 @@
from ..moe.utils import has_moe_layers
from ..runtime.zero import GatheredParameters
from ..module_inject import LinearAllreduce, LinearLayer, Normalize, ReplaceWithTensorSlicing
from ..module_inject.replace_policy import DSPolicy

DS_INFERENCE_ENABLED = False
from torch import nn
@@ -77,6 +78,9 @@ def __init__(self,

self._get_model_config_generate(config)

if hasattr(self.module, "config"):
DSPolicy.hf_model_config = self.module.config

self.mp_world_size = mp_size
self.checkpoint = checkpoint
self.dtype = dtype
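Stashing the Hugging Face model config on DSPolicy lets injection policies read model-specific fields, for OPT notably which activation the MLP uses and whether LayerNorm runs before attention, without threading the config through every call. A sketch of how a policy could consume it; the class name is hypothetical and the config attribute names follow Hugging Face's OPTConfig, which is an assumption here:

    from deepspeed.module_inject.replace_policy import DSPolicy

    class OPTPolicySketch:  # hypothetical, for illustration only
        def __init__(self):
            cfg = DSPolicy.hf_model_config  # set in InferenceEngine.__init__ above
            act = getattr(cfg, "activation_function", "gelu")
            # 1 == GELU, 2 == ReLU in the C++ ActivationFuncType enum
            self.mlp_act_func_type = 2 if act == "relu" else 1
            self.pre_attn_norm = getattr(cfg, "do_layer_norm_before", True)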
22 changes: 6 additions & 16 deletions deepspeed/module_inject/replace_module.py
@@ -147,7 +147,6 @@ def replace_transformer_layer(orig_layer_impl,
mp_group=None,
ep_group=None,
expert_mp_group=None,
preln=True,
fp16=True,
local_rank=-1,
stochastic_mode=True,
@@ -204,13 +203,8 @@ def replace_with_policy(child,
policy_cls,
triangular_masking,
inference=False,
preln=True,
layer_id=0):
preln = False if policy_cls is HFBertLayerPolicy else preln
if policy_cls is HFBertLayerPolicy:
policy = policy_cls(child, inference=inference, preln=preln)
else:
policy = policy_cls(child, inference=inference)
policy = policy_cls(child, inference=inference)

if inference:
hidden_size, num_attention_heads = policy.get_hidden_heads()
@@ -275,7 +269,7 @@ def replace_with_policy(child,
config,
'layer_norm_eps') else 1e-12,
fp16=fp16,
pre_layer_norm=preln,
pre_layer_norm=policy.pre_attn_norm,
mp_size=mp_size,
q_int8=quantize,
moe_experts=local_ep_size,
@@ -297,7 +291,7 @@ if hasattr(config,
if hasattr(config,
'layernorm_epsilon') else 1.0e-12),
fp16=fp16,
pre_layer_norm=preln,
pre_layer_norm=policy.pre_attn_norm,
mp_size=mp_size,
q_int8=quantize,
return_tuple=(return_tuple or (policy_cls is HFBertLayerPolicy)),
@@ -309,6 +303,7 @@ 'window_size') else 1),
'window_size') else 1),
rotary_dim=rotary_dim,
mlp_after_attn=(rotary_dim is None or rotary_dim < 0),
mlp_act_func_type=policy.mlp_act_func_type,
training_mp_size=training_mp_size,
bigscience_bloom=bigscience_bloom)

@@ -594,7 +589,7 @@ def _transpose(x):
'layer_norm_eps') else 1e-12,
seed=seed,
fp16=fp16,
pre_layer_norm=(False if policy_cls is HFBertLayerPolicy else preln),
pre_layer_norm=policy.pre_attn_norm,
return_tuple=return_tuple,
local_rank=local_rank,
stochastic_mode=stochastic_mode,
@@ -758,10 +753,7 @@ def _replace_module(r_module, prev_name=''):
def replace_fn(child, _policy, layer_id=0):
if training:
# copy relevant state from child -> new module
new_module = replace_with_policy(child,
_policy,
triangular_masking,
preln=preln)
new_module = replace_with_policy(child, _policy, triangular_masking)

else:
# copy relevant state from child -> new module
@@ -770,8 +762,6 @@ def replace_fn(child, _policy, layer_id=0):
_policy,
triangular_masking,
inference=True,
preln=(_policy
is not HFBertLayerPolicy),
layer_id=layer_id)
else:
new_module = replace_wo_policy(child, _policy)
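With the preln flag removed, replace_with_policy now reads the normalization and activation choices straight from the policy object, so every policy class is expected to expose pre_attn_norm and mlp_act_func_type. A minimal sketch of that attribute surface, with the values an OPT-style decoder would typically report (most released OPT checkpoints use pre-attention LayerNorm; this is an assumption, not something shown in the diff):

    class PolicyAttributeSketch:  # illustrative only
        pre_attn_norm = True   # LayerNorm applied before attention (pre-LN decoder)
        mlp_act_func_type = 2  # ActivationFuncType.ReLU from pt_binding.cpp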