# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import os
from dataclasses import replace
from itertools import zip_longest
from typing import Any, List, Optional, Set, Tuple, Union
import torch
from ..common import _get_storage_base, get_operator, register_operator
from .attn_bias import (
AttentionBias,
BlockDiagonalCausalFromBottomRightMask,
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
BlockDiagonalCausalLocalAttentionMask,
BlockDiagonalCausalMask,
BlockDiagonalCausalWithOffsetGappyKeysMask,
BlockDiagonalCausalWithOffsetPaddedKeysMask,
BlockDiagonalGappyKeysMask,
BlockDiagonalMask,
BlockDiagonalPaddedKeysMask,
LocalAttentionFromBottomRightMask,
LowerTriangularFromBottomRightLocalAttentionMask,
LowerTriangularFromBottomRightMask,
LowerTriangularMask,
)
from .common import (
AttentionBwOpBase,
AttentionFwOpBase,
Context,
Gradients,
Inputs,
check_lastdim_alignment_stride1,
)
FLASH_VERSION = "0.0.0"
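# Prefer the flash-attention build bundled with xFormers; otherwise fall back
# to a standalone `flash_attn` installation, enforcing the supported version
# range below.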
try:
try:
from ... import _C_flashattention # type: ignore[attr-defined]
from ..._cpp_lib import _build_metadata
if _build_metadata is not None:
FLASH_VERSION = _build_metadata.flash_version
except ImportError:
import flash_attn
from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
FLASH_VERSION = flash_attn.__version__
FLASH_VER_MIN = (2, 5, 2)
FLASH_VER_LAST = (2, 5, 6) # last supported, inclusive
flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
if (
flash_ver_parsed < FLASH_VER_MIN or flash_ver_parsed > FLASH_VER_LAST
) and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1":
raise ImportError(
f"Requires Flash-Attention version >={'.'.join([str(i) for i in FLASH_VER_MIN])},"
f"<={'.'.join([str(i) for i in FLASH_VER_LAST])} "
f"but got {FLASH_VERSION}."
)
# create library so that flash-attn goes through the PyTorch Dispatcher
_flash_lib = torch.library.Library("xformers_flash", "DEF")
_flash_lib.define(
"flash_fwd(Tensor query, Tensor key, Tensor value, "
"Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
"int max_seqlen_q, int max_seqlen_k, "
"float p, float softmax_scale, "
"bool is_causal, int window_left, "
"int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
)
_flash_lib.define(
"flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
"Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
"Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
"int max_seqlen_q, int max_seqlen_k, "
"float p, float softmax_scale, bool is_causal, "
"int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
)
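

# Python wrappers around the C++ kernels; registered below as the CUDA
# implementations of the two ops defined above.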
def _flash_fwd(
query,
key,
value,
cu_seq_lens_q,
cu_seq_lens_k,
seqused_k,
max_seq_len_q,
max_seq_len_k,
p,
softmax_scale,
is_causal,
window_left,
window_right,
return_softmax,
):
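    # cu_seq_lens_q is None for dense (uniform-length) batches; otherwise the
    # inputs are packed and the variable-length kernel is used.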
if cu_seq_lens_q is None:
assert cu_seq_lens_k is None
assert seqused_k is None
(
out,
q_padded,
k_padded,
v_padded,
out_padded,
softmax_lse,
p,
rng_state,
) = _C_flashattention.fwd(
query,
key,
value,
None, # out
None, # alibi_slopes
p,
softmax_scale,
is_causal,
window_left, # window_size_left
window_right, # window_size_right
return_softmax,
None, # rng
)
else:
(
out,
q_padded,
k_padded,
v_padded,
out_padded,
softmax_lse,
p,
rng_state,
) = _C_flashattention.varlen_fwd(
query,
key,
value,
None, # out
cu_seq_lens_q,
cu_seq_lens_k,
seqused_k,
None, # alibi_slopes
max_seq_len_q,
max_seq_len_k,
p,
softmax_scale,
            False, # zero_tensors
is_causal,
window_left,
window_right,
return_softmax,
            None, # gen
)
return out, softmax_lse, rng_state


def _flash_bwd(
grad,
query,
key,
value,
out,
lse,
dq,
dk,
dv,
cu_seq_lens_q,
cu_seq_lens_k,
max_seq_len_q,
max_seq_len_k,
p,
softmax_scale,
is_causal,
window_left,
window_right,
rng_state,
):
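    # As in _flash_fwd: dense batches use bwd, packed variable-length batches
    # use varlen_bwd.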
if cu_seq_lens_k is None:
assert cu_seq_lens_q is None
_C_flashattention.bwd(
grad,
query,
key,
value,
out,
lse,
dq,
dk,
dv,
None, # alibi_slopes
p,
softmax_scale,
is_causal,
window_left,
window_right,
False, # deterministic
            None, # gen
rng_state,
)
else:
_C_flashattention.varlen_bwd(
grad,
query,
key,
value,
out,
lse,
dq,
dk,
dv,
cu_seq_lens_q,
cu_seq_lens_k,
None, # alibi_slopes
max_seq_len_q,
max_seq_len_k,
p,
softmax_scale,
False, # zero_tensors
is_causal,
window_left,
window_right,
False, # deterministic
            None, # gen
rng_state,
)
return dq, dk, dv
_flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
_flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
except ImportError:
pass


def _convert_input_format(
inp: Inputs,
supports_mqa: bool,
) -> Tuple[
Inputs,
Optional[torch.Tensor],
int,
Optional[torch.Tensor],
int,
Optional[torch.Tensor],
]:
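    """Massage ``inp`` into the layout the flash-attn kernels expect.

    BMGHK (5D) inputs get their group/head dims folded together, and
    block-diagonal biases are lowered to the packed "varlen" layout. As a
    sketch, two sequences of lengths 3 and 5 (on both the query and key side)
    are packed into one batch with ``cu_seqlen_q = cu_seqlen_k = [0, 3, 8]``
    and ``max_seqlen = 5``, with query/key/value viewed as
    ``[total_tokens, num_heads, head_dim]``.

    Returns ``(inp, cu_seqlen_q, max_seqlen_q, cu_seqlen_k, max_seqlen_k,
    seqused_k)``.
    """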
assert inp.query.ndim in [4, 5]
query, key, value = inp.query, inp.key, inp.value
batch = query.shape[0]
seqlen_q = query.shape[1]
seqlen_kv = key.shape[1]
head_dim_q = query.shape[-1]
head_dim_v = value.shape[-1]
attn_bias = inp.attn_bias
if isinstance(attn_bias, BlockDiagonalMask):
# BlockDiagonalMask or BlockDiagonalCausalMask
attn_bias.k_seqinfo.seqstart = attn_bias.k_seqinfo.seqstart.to(
inp.query.device, non_blocking=True
)
attn_bias.q_seqinfo.seqstart = attn_bias.q_seqinfo.seqstart.to(
inp.query.device, non_blocking=True
)
cu_seqlen_k = attn_bias.k_seqinfo.seqstart
cu_seqlen_q = attn_bias.q_seqinfo.seqstart
max_seqlen_q = attn_bias.q_seqinfo.max_seqlen
max_seqlen_k = attn_bias.k_seqinfo.max_seqlen
seqused_k = None
elif isinstance(
attn_bias,
(
BlockDiagonalGappyKeysMask,
BlockDiagonalPaddedKeysMask,
),
):
attn_bias.k_seqinfo.seqstart = attn_bias.k_seqinfo.seqstart.to(
inp.query.device, non_blocking=True
)
attn_bias.q_seqinfo.seqstart = attn_bias.q_seqinfo.seqstart.to(
inp.query.device, non_blocking=True
)
attn_bias.k_seqinfo.seqlen = attn_bias.k_seqinfo.seqlen.to(
inp.query.device, non_blocking=True
)
cu_seqlen_k = attn_bias.k_seqinfo.seqstart
cu_seqlen_q = attn_bias.q_seqinfo.seqstart
max_seqlen_q = attn_bias.q_seqinfo.max_seqlen
max_seqlen_k = attn_bias.k_seqinfo.max_seqlen
seqused_k = attn_bias.k_seqinfo.seqlen
else:
cu_seqlen_k = None
cu_seqlen_q = None
seqused_k = None
max_seqlen_q = inp.query.shape[1]
max_seqlen_k = inp.key.shape[1]
if query.ndim == 5: # GQA
assert supports_mqa
# Fold the group/head_in_group dimensions together
def fold(x):
# Either the head is replicated
if x.stride(3) == 0:
return x[:, :, :, 0]
# Or we reshape
return x.reshape(
[
x.shape[0],
x.shape[1],
-1,
x.shape[4],
]
)
query = fold(query)
key = fold(key)
value = fold(value)
    # Optimize for MQA: when every head shares the same K/V (broadcast with
    # stride 0 over the head dim), pass a single K/V head to the kernel.
    if key.ndim == 4 and key.stride(2) == 0 and value.stride(2) == 0 and supports_mqa:
key = key[:, :, :1]
value = value[:, :, :1]
# Initially we have `query.shape = [batch, seqlen, num_heads, head_dim_q]`
# We want format `[batch * seqlen, num_heads, head_dim_q]`
if cu_seqlen_k is not None:
query = query.reshape([batch * seqlen_q, -1, head_dim_q])
key = key.reshape([batch * seqlen_kv, -1, head_dim_q])
value = value.reshape([batch * seqlen_kv, -1, head_dim_v])
new_inp = replace(
inp,
query=query,
key=key,
value=value,
)
return new_inp, cu_seqlen_q, max_seqlen_q, cu_seqlen_k, max_seqlen_k, seqused_k
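

# Biases that flash-attn implements natively via its `is_causal` flag.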
def _is_causal(attn_bias: Optional[Union[torch.Tensor, AttentionBias]]) -> bool:
return isinstance(
attn_bias,
(
LowerTriangularMask,
LowerTriangularFromBottomRightMask,
LowerTriangularFromBottomRightLocalAttentionMask,
BlockDiagonalCausalMask,
BlockDiagonalCausalLocalAttentionMask,
BlockDiagonalCausalFromBottomRightMask,
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
BlockDiagonalCausalWithOffsetGappyKeysMask,
BlockDiagonalCausalWithOffsetPaddedKeysMask,
),
)
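

# Translate local-attention biases into flash-attn's (window_left,
# window_right) convention, where -1 means unbounded on that side.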
def _window_size(
attn_bias: Optional[Union[torch.Tensor, AttentionBias]]
) -> Tuple[int, int]:
win_left = -1
win_right = -1
if isinstance(
attn_bias,
(
BlockDiagonalCausalLocalAttentionMask,
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
LowerTriangularFromBottomRightLocalAttentionMask,
),
):
win_left = attn_bias._window_size - 1
if isinstance(attn_bias, LocalAttentionFromBottomRightMask):
win_left = attn_bias.window_left
win_right = attn_bias.window_right
return (win_left, win_right)


def _check_needs_no_topleft(d: Inputs, reasons: List[str]) -> None:
    # Flash does not support TopLeft alignment, so only allow causal masks
    # when each batch element has an equal number of queries and keys.
    if isinstance(d.attn_bias, BlockDiagonalCausalMask):
for k_start, q_start in zip_longest(
d.attn_bias.k_seqinfo.seqstart_py, d.attn_bias.q_seqinfo.seqstart_py
):
if k_start != q_start:
reasons.append(
"Only support BlockDiagonalCausalMask if equal"
" numbers of keys and queries"
)
break
elif isinstance(d.attn_bias, LowerTriangularMask):
if d.query.shape[1] != d.key.shape[1]:
reasons.append(
"Only support LowerTriangularMask if equal number of" "keys and queries"
)


def _check_strides_for_bmghk(x: torch.Tensor, name: str, reasons: List[str]) -> None:
    """We want to be able to collapse the G/H dimensions together."""
if x.ndim == 5:
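        # For shape [B, M, G, H, K] this requires stride(G) == H * stride(H),
        # so that G and H can be viewed as one merged axis; broadcast heads
        # (stride(H) == 0) and singleton G/H dims are accepted as well.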
stride_g, stride_h = x.stride(2), x.stride(3)
if x.shape[2] == 1:
return
if x.shape[3] == 1 or stride_h == 0:
return
if stride_g != stride_h * x.shape[-2]:
reasons.append(
f"GQA is only supported when the G/H dimensions are contiguous\n"
f" {name}.stride: {x.stride()}\n"
f" {name}.shape : {list(x.shape)}"
)
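

# Massage the LSE returned by the kernel back into the layout callers expect
# for the given bias and original query shape.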
def _post_process_lse(
lse: torch.Tensor, inp: Inputs, original_query_shape: Tuple[int, ...]
) -> torch.Tensor:
if not isinstance(
inp.attn_bias,
(
BlockDiagonalGappyKeysMask,
BlockDiagonalPaddedKeysMask,
),
):
if inp.is_partial and len(original_query_shape) == 5:
# [B, GH, M] => [B, G, H, M]
return lse.unflatten(1, original_query_shape[2:4])
return lse
q_seqinfo = inp.attn_bias.q_seqinfo
B = len(q_seqinfo.seqstart_py) - 1
if q_seqinfo.max_seqlen * B != original_query_shape[1]:
# Heterogeneous batch. We can't fix it.
return lse
# reshape from (B, G*H, max_seqlen) to (1, G*H, B*max_seqlen)
# Unfortunately this flatten is not just a view.
lse_hkm = lse.permute(1, 0, 2).flatten(start_dim=1)[None]
if len(original_query_shape) == 5:
return lse_hkm.unflatten(1, original_query_shape[2:4])
return lse_hkm


@register_operator
class FwOp(AttentionFwOpBase):
"""Operator that computes memory-efficient attention using \
`Flash-Attention <https://github.com/HazyResearch/flash-attention>`_ \
implementation.
"""
OPERATOR = get_operator("xformers_flash", "flash_fwd")
SUPPORTED_DEVICES: Set[str] = {"cuda"}
CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16}
SUPPORTED_MAX_K = 256
SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {
type(None),
LowerTriangularMask,
LowerTriangularFromBottomRightMask,
LowerTriangularFromBottomRightLocalAttentionMask,
BlockDiagonalMask,
BlockDiagonalCausalMask,
BlockDiagonalCausalLocalAttentionMask,
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
BlockDiagonalCausalFromBottomRightMask,
BlockDiagonalCausalWithOffsetGappyKeysMask,
BlockDiagonalCausalWithOffsetPaddedKeysMask,
BlockDiagonalGappyKeysMask,
BlockDiagonalPaddedKeysMask,
LocalAttentionFromBottomRightMask,
}
SUPPORTS_DROPOUT = True
SUPPORTS_CUSTOM_SCALE = True
SUPPORTS_DIFFERENT_VALUE_EMBED = False
SUPPORTS_BMGHK = True
SUPPORTS_PARTIAL = True
NAME = f"flshattF@{FLASH_VERSION}"
VERSION = FLASH_VERSION
@classmethod
def not_supported_reasons(cls, d: Inputs) -> List[str]:
reasons = super(FwOp, cls).not_supported_reasons(d)
check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
_check_needs_no_topleft(d, reasons)
_check_strides_for_bmghk(d.query, "query", reasons)
_check_strides_for_bmghk(d.key, "key", reasons)
_check_strides_for_bmghk(d.value, "value", reasons)
if d.is_partial and isinstance(
d.attn_bias,
(
BlockDiagonalGappyKeysMask,
BlockDiagonalPaddedKeysMask,
),
):
q_seqinfo = d.attn_bias.q_seqinfo
if q_seqinfo.min_seqlen != q_seqinfo.max_seqlen:
# Flash provides padded LSE which we don't handle.
reasons.append("partial attention with heterogeneous queries")
return reasons
@classmethod
def apply(
cls, inp: Inputs, needs_gradient: bool
) -> Tuple[torch.Tensor, Optional[Context]]:
return_softmax = False
original_query_shape = inp.query.shape
out_shape = [
*inp.query.shape[:-1],
inp.value.shape[-1],
]
        # Convert to the kernel's layout (packed varlen when the bias is
        # block-diagonal; otherwise the cu_seqlens stay None).
(
inp,
cu_seqlens_q,
max_seqlen_q,
cu_seqlens_k,
max_seqlen_k,
seqused_k,
) = _convert_input_format(inp, supports_mqa=True)
if inp.query.numel() > 0 and inp.key.numel() > 0:
win_left, win_right = _window_size(inp.attn_bias)
out, softmax_lse, rng_state = cls.OPERATOR(
inp.query,
inp.key,
inp.value,
cu_seqlens_q,
cu_seqlens_k,
seqused_k,
max_seqlen_q,
max_seqlen_k,
inp.p,
inp.scale_float,
_is_causal(inp.attn_bias),
window_left=win_left,
window_right=win_right,
return_softmax=return_softmax,
)
out = out.reshape(out_shape)
else:
out = torch.zeros(out_shape, device=inp.query.device, dtype=inp.query.dtype)
rng_state = None
softmax_lse = torch.empty(
[inp.query.shape[0], inp.query.shape[2], inp.query.shape[1]],
device=inp.query.device,
dtype=torch.float32,
)
if not needs_gradient:
return out, None
ctx = Context(
out=out, lse=_post_process_lse(softmax_lse, inp, original_query_shape)
)
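        # With dropout, the backward pass must replay the forward pass's RNG
        # state, so pin the matching backward op and stash that state.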
if inp.p != 0.0:
ctx.op_bw = BwOp
ctx.rng_state = rng_state
return (out, ctx)
@classmethod
# type: ignore
def operator_flop(
cls,
query,
key,
value,
cu_seq_lens_q,
cu_seq_lens_k,
max_seq_len_q,
max_seq_len_k,
p,
softmax_scale,
causal,
return_softmax,
) -> int:
return cls.attn_operator_flop(
query.unsqueeze(0),
key.unsqueeze(0),
value.unsqueeze(0),
causal=causal,
seqstart_k=cu_seq_lens_k,
seqstart_q=cu_seq_lens_q,
)


@register_operator
class BwOp(AttentionBwOpBase):
__doc__ = FwOp.__doc__
OPERATOR = get_operator("xformers_flash", "flash_bwd")
SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
CUDA_MINIMUM_COMPUTE_CAPABILITY = FwOp.CUDA_MINIMUM_COMPUTE_CAPABILITY
SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K
SUPPORTED_ATTN_BIAS_TYPES = FwOp.SUPPORTED_ATTN_BIAS_TYPES.difference(
{
BlockDiagonalCausalWithOffsetGappyKeysMask,
BlockDiagonalCausalWithOffsetPaddedKeysMask,
BlockDiagonalGappyKeysMask,
BlockDiagonalPaddedKeysMask,
}
)
SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED
IS_DETERMINISTIC = False
SUPPORTS_BMGHK = False # NOTE: Don't forget to update fmha doc when changing this!
NAME = f"flshattB@{FLASH_VERSION}"
VERSION = FLASH_VERSION
MAX_HEADDIM_DROPOUT_SM8x = 224
@classmethod
def not_supported_reasons(cls, d: Inputs) -> List[str]:
reasons = super(BwOp, cls).not_supported_reasons(d)
check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
_check_needs_no_topleft(d, reasons)
if d.device.type == "cuda":
# Due to limited shared-memory, some GPUs are limited in head dimension
device_capability = torch.cuda.get_device_capability(d.device)
is_sm80_or_sm90 = device_capability in [(8, 0), (9, 0)]
if (
max(d.key.shape[-1], d.query.shape[-1]) > cls.MAX_HEADDIM_DROPOUT_SM8x
and not is_sm80_or_sm90
and d.p != 0.0
):
reasons.append(
"requires a GPU with compute capability 8.0 "
f"(A100) or 9.0 (H100) for dropout when 'query.shape[-1] > {cls.MAX_HEADDIM_DROPOUT_SM8x}'"
)
return reasons
@classmethod
def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
dq_shape, dk_shape, dv_shape = inp.query.shape, inp.key.shape, inp.value.shape
(
inp,
cu_seqlens_q,
max_seqlen_q,
cu_seqlens_k,
max_seqlen_k,
seqused_k,
) = _convert_input_format(inp, supports_mqa=False)
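        # Gappy/padded key masks are excluded from this op's supported bias
        # types, so seqused_k is always None here.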
assert ctx.lse.is_contiguous()
assert seqused_k is None
ctx_lse = ctx.lse
assert ctx_lse.shape[2] >= max_seqlen_q
if max_seqlen_q != ctx_lse.shape[2]:
ctx_lse = ctx_lse[:, :, :max_seqlen_q].contiguous()
kernel_out_shape = [
*inp.query.shape[:-1],
inp.value.shape[-1],
]
# Create dq,dk,dv
# If Q/K/V come from a single QKV tensor, let's put the gradient in the
# right strides, so we can avoid a `cat`
if (
inp.query.shape[0] == inp.key.shape[0]
and inp.query.shape[-1] == inp.value.shape[-1]
and _get_storage_base(inp.query) == _get_storage_base(inp.key)
and _get_storage_base(inp.query) == _get_storage_base(inp.value)
):
# Create one big contiguous chunk
# This is because q, k and v usually come from a single
# output of a linear layer that is chunked.
# Creating the gradients with the right layout saves us
# a `torch.cat` call in the backward pass
chunk = torch.empty(
(*inp.query.shape[0:-2], 3, inp.query.shape[-2], inp.query.shape[-1]),
dtype=inp.query.dtype,
device=inp.device,
)
grads = Gradients(
dq=chunk.select(-3, 0),
dk=chunk.select(-3, 1),
dv=chunk.select(-3, 2),
)
else:
grads = Gradients(
dq=torch.empty_like(inp.query),
dk=torch.empty_like(inp.key),
dv=torch.empty_like(inp.value),
)
assert grad.dtype in cls.SUPPORTED_DTYPES
if grads.dq.numel() == 0:
grads.dk.zero_()
grads.dv.zero_()
if grads.dv.numel() == 0:
grads.dq.zero_()
if grads.dq.numel() and grads.dk.numel():
win_left, win_right = _window_size(inp.attn_bias)
cls.OPERATOR(
grad.reshape(kernel_out_shape).contiguous(),
inp.query,
inp.key,
inp.value,
ctx.out.reshape(kernel_out_shape),
ctx_lse,
grads.dq,
grads.dk,
grads.dv,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
inp.p,
inp.scale_float,
_is_causal(inp.attn_bias),
window_left=win_left,
window_right=win_right,
rng_state=ctx.rng_state,
)
grads.dq = grads.dq.reshape(dq_shape)
grads.dk = grads.dk.reshape(dk_shape)
grads.dv = grads.dv.reshape(dv_shape)
return grads
@classmethod
# type: ignore
def operator_flop(
cls,
grad,
query,
key,
value,
out,
lse,
dq,
dk,
dv,
cu_seq_lens_q,
cu_seq_lens_k,
max_seq_len_q,
max_seq_len_k,
p,
softmax_scale,
causal,
) -> int:
return cls.attn_operator_flop(
query.unsqueeze(0),
key.unsqueeze(0),
value.unsqueeze(0),
causal=causal,
seqstart_k=cu_seq_lens_k,
seqstart_q=cu_seq_lens_q,
)