From 5501aef89710a68e079e87880d3436dc4426d572 Mon Sep 17 00:00:00 2001
From: Lev Kurilenko
Date: Thu, 14 Sep 2023 19:44:15 +0000
Subject: [PATCH] DS-Chat BLOOM: Fix Attention mask

---
 deepspeed/ops/transformer/inference/ds_attention.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/deepspeed/ops/transformer/inference/ds_attention.py b/deepspeed/ops/transformer/inference/ds_attention.py
index 967f1d4b8d9d..99fc3daa9292 100644
--- a/deepspeed/ops/transformer/inference/ds_attention.py
+++ b/deepspeed/ops/transformer/inference/ds_attention.py
@@ -247,6 +247,11 @@ def compute_attention(self, qkv_out, input_mask, layer_past, alibi):
         offset = dist.get_rank() * self.num_attention_heads_per_partition if dist.is_initialized() else 0
 
         target_dtype = torch.float16 if self.config.dtype == torch.int8 else self.config.dtype
+
+        # When using the hybrid engine with BLOOM, input_mask needs to be converted from torch.bool -> torch.int64
+        if input_mask.dtype == torch.bool:
+            input_mask = input_mask.long()
+
         attention_probs = self.softmax_func(attn_scores=attention_scores,
                                             attn_mask=((1 - input_mask).to(target_dtype) * minus_inf),
                                             alibi=alibi,
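
Why the cast is needed, as a minimal sketch (not part of the patch): PyTorch rejects the `-` operator on bool tensors, so the downstream expression `(1 - input_mask).to(target_dtype) * minus_inf` raises a RuntimeError when the hybrid engine passes BLOOM a torch.bool mask. Casting to int64 first makes the arithmetic well defined. The mask semantics and the -10000.0 value for minus_inf below are assumptions for illustration only.

    import torch

    # Hypothetical bool mask as the hybrid engine might pass it
    # (assumption: True marks valid tokens, False marks padding).
    input_mask = torch.tensor([[True, True, False]])

    # Without the fix, building the additive mask fails:
    #   1 - input_mask
    #   RuntimeError: Subtraction, the `-` operator, with a bool tensor
    #   is not supported. ...

    input_mask = input_mask.long()  # the fix: torch.bool -> torch.int64
    minus_inf = -10000.0            # assumed sentinel value for masked positions
    attn_mask = (1 - input_mask).to(torch.float16) * minus_inf
    # Valid positions stay ~0; the padding position becomes -10000,
    # so softmax drives its attention probability to ~0.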