diff --git a/colossalai/kernel/cuda_native/scaled_softmax.py b/colossalai/kernel/cuda_native/scaled_softmax.py
index 786b922c6905..cb36da8a130d 100644
--- a/colossalai/kernel/cuda_native/scaled_softmax.py
+++ b/colossalai/kernel/cuda_native/scaled_softmax.py
@@ -28,9 +28,7 @@ def forward(ctx, inputs, scale):
             raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')
 
         scale_t = torch.tensor([scale])
-        softmax_results = colossal_scaled_upper_triang_masked_softmax.forward(
-            inputs, scale_t[0]
-        )
+        softmax_results = colossal_scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])
 
         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results
@@ -43,9 +41,7 @@ def backward(ctx, output_grads):
             raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')
 
         softmax_results, scale_t = ctx.saved_tensors
-        input_grads = colossal_scaled_upper_triang_masked_softmax.backward(
-            output_grads, softmax_results, scale_t[0]
-        )
+        input_grads = colossal_scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
 
         return input_grads, None
 
@@ -81,9 +77,7 @@ def backward(ctx, output_grads):
 
         softmax_results, scale_t = ctx.saved_tensors
 
-        input_grads = colossal_scaled_masked_softmax.backward(
-            output_grads, softmax_results, scale_t[0]
-        )
+        input_grads = colossal_scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
         return input_grads, None, None
 
 
@@ -114,9 +108,8 @@ def __init__(
         super(FusedScaleMaskSoftmax, self).__init__()
         self.input_in_fp16 = input_in_fp16
         self.input_in_bf16 = input_in_bf16
-        assert not (
-            self.input_in_fp16 and self.input_in_bf16
-        ), "both fp16 and bf16 flags cannot be active at the same time."
+        assert not (self.input_in_fp16
+                    and self.input_in_bf16), "both fp16 and bf16 flags cannot be active at the same time."
         self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
         self.attn_mask_type = attn_mask_type
         self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
@@ -124,9 +117,7 @@ def __init__(
         self.softmax_in_fp32 = softmax_in_fp32
         self.scale = scale
 
-        assert (
-            self.scale is None or softmax_in_fp32
-        ), "softmax should be in fp32 when scaled"
+        assert (self.scale is None or softmax_in_fp32), "softmax should be in fp32 when scaled"
 
     def forward(self, input, mask):
         # [b, np, sq, sk]
@@ -140,14 +131,13 @@ def forward(self, input, mask):
     def is_kernel_available(self, mask, b, np, sq, sk):
         attn_batches = b * np
 
-        if (
-            self.scaled_masked_softmax_fusion  # user want to fuse
-            and self.input_in_float16  # input must be fp16
-            and mask is not None  # mask tensor must not be None
-            and 16 < sk <= 2048  # sk must be 16 ~ 2048
-            and sq % 4 == 0  # sq must be divisor of 4
-            and attn_batches % 4 == 0  # np * b must be divisor of 4
-        ):
+        if (self.scaled_masked_softmax_fusion  # user want to fuse
+                and self.input_in_float16  # input must be fp16
+                and mask is not None  # mask tensor must not be None
+                and 16 < sk <= 2048  # sk must be 16 ~ 2048
+                and sq % 4 == 0  # sq must be divisor of 4
+                and attn_batches % 4 == 0  # np * b must be divisor of 4
+        ):
             if 0 <= sk <= 2048:
                 batch_per_block = self.get_batch_per_block(sq, sk, b, np)
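
Note on the last hunk: the reformatted `if` is purely cosmetic, but it is also the gate that decides whether the fused CUDA softmax kernel can be used at all. Below is a minimal standalone sketch of that predicate for illustration; the helper name `fused_softmax_eligible` is hypothetical, and the real check lives on `FusedScaleMaskSoftmax.is_kernel_available`, which additionally consults `get_batch_per_block`.

```python
from typing import Optional

import torch


def fused_softmax_eligible(fusion_enabled: bool, input_in_float16: bool,
                           mask: Optional[torch.Tensor], b: int, np: int,
                           sq: int, sk: int) -> bool:
    """Mirrors the conditions in is_kernel_available (illustration only)."""
    attn_batches = b * np
    return (fusion_enabled  # user asked for the fused kernel
            and input_in_float16  # kernel only handles fp16/bf16 input
            and mask is not None  # a mask tensor must be supplied
            and 16 < sk <= 2048  # key length must be in (16, 2048]
            and sq % 4 == 0  # query length must be divisible by 4
            and attn_batches % 4 == 0)  # b * np must be divisible by 4
```

For example, an fp16 attention-score tensor of shape [b=8, np=12, sq=1024, sk=1024] with a mask passes every condition, while sk=4096 would send the input down the regular (unfused) softmax path.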