Skip to content

Commit

Permalink
[NFC] polish colossalai/kernel/cuda_native/scaled_softmax.py code style (#955)
Browse files Browse the repository at this point in the history
  • Loading branch information
wangbo-zhao authored and binmakeswell committed May 17, 2022
1 parent f6970ef commit 8ca2a85
Showing 1 changed file with 13 additions and 23 deletions.
36 changes: 13 additions & 23 deletions colossalai/kernel/cuda_native/scaled_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ def forward(ctx, inputs, scale):
raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')

scale_t = torch.tensor([scale])
softmax_results = colossal_scaled_upper_triang_masked_softmax.forward(
inputs, scale_t[0]
)
softmax_results = colossal_scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])

ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
Expand All @@ -43,9 +41,7 @@ def backward(ctx, output_grads):
raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')

softmax_results, scale_t = ctx.saved_tensors
input_grads = colossal_scaled_upper_triang_masked_softmax.backward(
output_grads, softmax_results, scale_t[0]
)
input_grads = colossal_scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0])

return input_grads, None

Expand Down Expand Up @@ -81,9 +77,7 @@ def backward(ctx, output_grads):

softmax_results, scale_t = ctx.saved_tensors

input_grads = colossal_scaled_masked_softmax.backward(
output_grads, softmax_results, scale_t[0]
)
input_grads = colossal_scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
return input_grads, None, None


Expand Down Expand Up @@ -114,19 +108,16 @@ def __init__(
super(FusedScaleMaskSoftmax, self).__init__()
self.input_in_fp16 = input_in_fp16
self.input_in_bf16 = input_in_bf16
assert not (
self.input_in_fp16 and self.input_in_bf16
), "both fp16 and bf16 flags cannot be active at the same time."
assert not (self.input_in_fp16
and self.input_in_bf16), "both fp16 and bf16 flags cannot be active at the same time."
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
self.mask_func = mask_func
self.softmax_in_fp32 = softmax_in_fp32
self.scale = scale

assert (
self.scale is None or softmax_in_fp32
), "softmax should be in fp32 when scaled"
assert (self.scale is None or softmax_in_fp32), "softmax should be in fp32 when scaled"

def forward(self, input, mask):
# [b, np, sq, sk]
Expand All @@ -140,14 +131,13 @@ def forward(self, input, mask):
def is_kernel_available(self, mask, b, np, sq, sk):
attn_batches = b * np

if (
self.scaled_masked_softmax_fusion # user want to fuse
and self.input_in_float16 # input must be fp16
and mask is not None # mask tensor must not be None
and 16 < sk <= 2048 # sk must be 16 ~ 2048
and sq % 4 == 0 # sq must be divisor of 4
and attn_batches % 4 == 0 # np * b must be divisor of 4
):
if (self.scaled_masked_softmax_fusion # user want to fuse
and self.input_in_float16 # input must be fp16
and mask is not None # mask tensor must not be None
and 16 < sk <= 2048 # sk must be 16 ~ 2048
and sq % 4 == 0 # sq must be divisor of 4
and attn_batches % 4 == 0 # np * b must be divisor of 4
):
if 0 <= sk <= 2048:
batch_per_block = self.get_batch_per_block(sq, sk, b, np)

Expand Down

0 comments on commit 8ca2a85

Please sign in to comment.