diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp
index 493714cc0767..358d2035bdce 100644
--- a/csrc/transformer/inference/csrc/pt_binding.cpp
+++ b/csrc/transformer/inference/csrc/pt_binding.cpp
@@ -849,20 +849,6 @@ at::Tensor ds_linear_layer(at::Tensor& input,
     int head_size = input_cont.size(2) / num_heads;
     int bsz = input.size(0) * input.size(1);
     T* workspace = (T*)Context::Instance().GetWorkSpace();
-    // Reallocate memory if we received a new prompt
-    if (!workspace) {
-        cublasSetStream(Context::Instance().GetCublasHandle(),
-                        Context::Instance().GetCurrentStream());
-        allocate_workspace(input.size(2),
-                           input.size(0),
-                           input.size(1),
-                           num_layers,
-                           num_heads,
-                           1,
-                           external_cache,
-                           0);
-        workspace = (T*)Context::Instance().GetWorkSpace();
-    }
     auto output = at::from_blob(workspace, {input.size(0), input.size(1), weight.size(1)}, options);
 
     float alpha = (T)1.0;
diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index bbb60b6b3ff2..b689372a1afd 100755
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -210,6 +210,7 @@ def replace_attn(child, policy, layer_id):
             heads=heads,
             fp16=fp16,
             triangular_masking=False,
+            max_out_tokens=4096,
         )
 
         attn_module = transformer_inference.DeepSpeedAttention(config)
diff --git a/deepspeed/ops/transformer/inference/attention.py b/deepspeed/ops/transformer/inference/attention.py
index f8bad34d5bff..85fab6e6c28a 100644
--- a/deepspeed/ops/transformer/inference/attention.py
+++ b/deepspeed/ops/transformer/inference/attention.py
@@ -233,6 +233,9 @@ def __init__(
                 inference_cuda_module.linear_layer_fp32
         self.cuda_graph_created = False
         self.enable_cuda_graph = False
+        self.allocate_workspace = inference_cuda_module.allocate_workspace_fp32 if (not config.fp16) else \
+            inference_cuda_module.allocate_workspace_fp16
+        self.iter = 0
 
     def _graph_replay(self, *inputs, **kwargs):
         for i in range(len(inputs)):
@@ -275,6 +278,18 @@ def forward(self, *inputs, **kwargs):
         return outputs
 
     def _forward(self, input, context=None, input_mask=None):
+        # Allocate memory only on first layer forward
+        if self.config.layer_id == 0 and self.iter == 0:
+            self.iter += 1
+            self.allocate_workspace(self.config.hidden_size,
+                                    input.size()[0],
+                                    input.size()[1],
+                                    DeepSpeedAttention.layer_id,
+                                    self.config.heads,
+                                    self.config.mp_size,
+                                    self.config.bigscience_bloom,
+                                    0,
+                                    self.config.max_out_tokens)
         output = DeepSpeedAttentionFunction.apply(input, context, input_mask,