forked from NVIDIA/Megatron-LM
-
Notifications
You must be signed in to change notification settings - Fork 328
/
transformer.py
2076 lines (1807 loc) · 92.7 KB
/
transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Transformer."""
from contextlib import nullcontext
import math
import numpy as np
import torch
import torch.nn.functional as F
from typing import Optional
from megatron import get_timers, get_args, get_retro_args, core, get_num_microbatches
from .module import MegatronModule
from megatron.core import parallel_state, tensor_parallel, mpu
from megatron.core.enums import ModelType
from megatron.model import LayerNorm, RMSNorm
from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.rotary_pos_embedding import apply_rotary_pos_emb
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
import deepspeed
from deepspeed.moe.layer import MoE
from deepspeed.accelerator import get_accelerator
try:
from deepspeed.sequence.layer import DistributedAttention
dist_attn_supported = True
except ImportError:
dist_attn_supported = False
try:
from einops import rearrange
except ImportError:
rearrange = None
try:
# FlashAttention (1.x)
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
from flash_attn.flash_attn_triton import flash_attn_func
except ImportError:
flash_attn_unpadded_func = None
flash_attn_func = None
try:
# FlashAttention-2
from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
flash_attn_varlen_func = None
# Op-builder object looked up by name on the current DeepSpeed accelerator.
# NOTE(review): presumably None on accelerators without this op, which would
# explain the `except TypeError` guard in ParallelAttention.__init__ - confirm.
FlashAttentionBuilder = get_accelerator().get_op_builder("FlashAttentionBuilder")
# Loaded flash-attention op module. Stays None until ParallelAttention.__init__
# assigns the result of FlashAttentionBuilder().load() to it.
flash_attn_builder = None
""" We use the following notation throughout this file:
h: hidden size
n: number of attention heads
p: number of model parallel partitions
np: n/p
hp: h/p
hn: h/n
b: batch size
s: sequence length
l: number of layers
Transformer takes input of size [s, b, h] and returns a
tensor of the same size. We use the following arguments:
hyperparameters: transformer hyperparameters
"""
class DropPath(MegatronModule):
    """Drop paths (Stochastic Depth) per sample
    (when applied in main path of residual blocks).

    Args:
        drop_prob: probability of dropping the branch for a given sample;
            expected to lie in [0, 1].
    """

    def __init__(self, drop_prob=0.):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_state):
        """Randomly zero the branch per sample and rescale the survivors.

        Args:
            hidden_state: tensor whose dimension 1 is the per-sample
                (batch) dimension, e.g. [s, b, h].

        Returns:
            `hidden_state` unchanged when `drop_prob` is 0 or the module is
            in eval mode; otherwise a tensor of the same shape in which each
            sample is either zeroed or scaled by 1 / (1 - drop_prob).
        """
        if self.drop_prob == 0. or not self.training:
            return hidden_state
        keep_prob = 1 - self.drop_prob
        # work with diff dim tensors, not just 2D ConvNets
        # hidden_state: [s, b, h] -> one mask value per sample (dim 1)
        shape = (1,) + (hidden_state.shape[1],) + (1,) * (hidden_state.ndim - 2)
        random_tensor = keep_prob + \
            torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device)
        random_tensor.floor_()  # binarize
        if keep_prob > 0.:
            # Rescale so the expected value of the output matches the input.
            hidden_state = hidden_state.div(keep_prob)
        # When keep_prob == 0 the mask is all zeros; skipping the division
        # above avoids inf * 0 = NaN and yields a clean all-zero output.
        output = hidden_state * random_tensor
        return output
class ParallelMLP(MegatronModule):
    """Tensor-parallel feed-forward block.

    The input with hidden size h is projected up to config.ffn_hidden_size
    (twice that width when config.gated_linear_unit is set), passed through
    the configured non-linearity, and projected back down to h.
    """

    def __init__(self, config, moe=False, enable_expert_tensor_parallelism=False):
        super(ParallelMLP, self).__init__()
        args = get_args()

        self.add_bias = config.add_bias_linear

        # Up-projection width. Gated units (e.g. swiglu) need twice the
        # width, see https://arxiv.org/pdf/2002.05202.pdf
        up_proj_width = config.ffn_hidden_size
        if config.gated_linear_unit:
            up_proj_width = up_proj_width * 2

        self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
            config.hidden_size,
            up_proj_width,
            config=config,
            init_method=config.init_method,
            bias=self.add_bias,
            gather_output=False,
            skip_bias_add=True,
            moe=moe,
            enable_expert_tensor_parallelism=enable_expert_tensor_parallelism
        )

        self.swiglu = args.swiglu
        self.activation_func, self.bias_gelu_fusion = self._pick_activation(args)

        # Down-projection back to h.
        self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
            config.ffn_hidden_size,
            config.hidden_size,
            config=config,
            init_method=config.output_layer_init_method,
            bias=self.add_bias,
            input_is_parallel=True,
            moe=moe,
            enable_expert_tensor_parallelism=enable_expert_tensor_parallelism
        )

    @staticmethod
    def _pick_activation(args):
        """Return (activation callable, bias-gelu-fusion flag) for the run args.

        Precedence is openai_gelu, onnx_safe, swiglu, squared_relu, and
        finally plain GELU, the only case in which the fused bias+GELU
        kernel may be enabled.
        """
        if args.openai_gelu:
            return openai_gelu, False
        if args.onnx_safe:
            return erf_gelu, False
        if args.swiglu:
            def swiglu(x):
                halves = torch.chunk(x, 2, dim=-1)
                return F.silu(halves[0]) * halves[1]
            return swiglu, False
        if args.squared_relu:
            def squared_relu(x):
                return torch.pow(F.relu(x), 2)
            return squared_relu, False
        return F.gelu, args.bias_gelu_fusion

    def forward(self, hidden_states):
        """Apply the MLP and return (output, output_bias) from the down-projection."""
        # [s, b, 4hp]
        projected, projected_bias = self.dense_h_to_4h(hidden_states)

        if not self.bias_gelu_fusion:
            # Unfused path: add the skipped bias (if any), then activate.
            if projected_bias is not None:
                projected = projected + projected_bias
            projected = self.activation_func(projected)
        else:
            assert self.add_bias is True
            # DeepSpeed FLOPS profiler temporarily substitutes functions like F.gelu to calculate the throughput
            assert hasattr(self, "__flops__") or self.activation_func == F.gelu
            projected = bias_gelu_impl(projected, projected_bias)

        # [s, b, h]
        result, result_bias = self.dense_4h_to_h(projected)
        return result, result_bias
class SwitchMLP(MegatronModule):
    """
    Routes input to one of N MLP "experts"

    A linear router scores every token; each token is sent to its
    highest-probability expert and the expert output is scaled by that
    probability.
    """

    def __init__(self, config):
        super(SwitchMLP, self).__init__()
        args = get_args()
        self.router = torch.nn.Linear(config.hidden_size, args.num_experts_switch)
        self.experts = torch.nn.ModuleList()
        for _ in range(args.num_experts_switch):
            self.experts.append(ParallelMLP(config))

    def forward(self, hidden_states):
        """Route each token to its top-1 expert.

        Args:
            hidden_states: tensor of shape [s, b, h].

        Returns:
            Tuple (output, output_bias), both [s, b, h] and scaled by the
            routing probability. `output_bias` is None when the experts
            return no bias.
        """
        # hidden_states: [s, b, h]
        s = hidden_states.size(0)
        b = hidden_states.size(1)
        h = hidden_states.size(2)
        route = self.router(hidden_states)
        route = torch.nn.functional.softmax(route, dim=2)
        max_prob, max_ind = torch.max(route, dim=2)
        max_prob = torch.unsqueeze(max_prob, 2)  # [s b 1]

        # TODO (rprenger) TODO this could be made easier to read
        # Converting [s, b, h] to [s*b, h].
        # Each vector could be routed differently
        hidden_states = hidden_states.view(-1, hidden_states.size(2))  # [s*b h]
        max_prob = max_prob.view(-1, max_prob.size(2))  # [s*b 1]
        max_ind = max_ind.view(-1)  # [s*b]

        output_total = torch.empty_like(hidden_states)
        # Allocated lazily: experts may return no bias at all (None), in
        # which case calling .expand_as() on it would raise.
        output_bias_total = None

        # TODO (rprenger) This does each expert in serial, but it could be parallelized
        for expert_num, expert in enumerate(self.experts):
            # nonzero() yields [k, 1], so the gathered slice is [k, 1, h]
            local_indices = (max_ind == expert_num).nonzero()
            hidden = hidden_states[local_indices, :]
            output, output_bias = expert(hidden)
            output_total[local_indices, :] = output
            if output_bias is not None:
                if output_bias_total is None:
                    output_bias_total = torch.zeros_like(hidden_states)
                output_bias_total[local_indices, :] = output_bias.expand_as(output)

        output_total = output_total * max_prob
        output_total = output_total.view(s, b, h)
        if output_bias_total is not None:
            output_bias_total = output_bias_total * max_prob
            output_bias_total = output_bias_total.view(s, b, h)

        return output_total, output_bias_total
class CoreAttention(MegatronModule):
    """Unfused attention core: scaled QK^T, masked softmax, dropout, then
    the weighted sum over the values.

    Takes query/key/value laid out as [s, b, np, hn] and returns a context
    tensor of shape [sq, b, hp].
    """

    def __init__(self, layer_number, config,
                 attn_mask_type=AttnMaskType.padding):
        super(CoreAttention, self).__init__()
        self.fp16 = config.fp16
        self.bf16 = config.bf16

        self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
        # Query-key layer scaling forces the softmax to run in fp32.
        self.attention_softmax_in_fp32 = (
            True if self.apply_query_key_layer_scaling
            else config.attention_softmax_in_fp32)
        self.layer_number = max(1, layer_number)
        self.attn_mask_type = attn_mask_type
        self.sequence_parallel = config.sequence_parallel

        projection_size = config.kv_channels * config.num_attention_heads

        # Per attention head and per partition values. The sequence-parallel
        # world size is used when it is larger than 1, otherwise the
        # tensor-model-parallel one.
        sp_world_size = 1
        if parallel_state.sequence_parallel_is_initialized():
            sp_world_size = parallel_state.get_sequence_parallel_world_size()
        if sp_world_size > 1:
            world_size = sp_world_size
        else:
            world_size = parallel_state.get_tensor_model_parallel_world_size()

        self.hidden_size_per_partition = core.utils.divide(projection_size,
                                                           world_size)
        self.hidden_size_per_attention_head = core.utils.divide(
            projection_size, config.num_attention_heads)
        self.num_attention_heads_per_partition = core.utils.divide(
            config.num_attention_heads, world_size)

        # With layer scaling the scores are divided by an extra layer_number
        # factor, and the same factor is handed to the softmax as its scale.
        softmax_coeff = self.layer_number if self.apply_query_key_layer_scaling else None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if softmax_coeff is not None:
            self.norm_factor *= softmax_coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            config.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            softmax_coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(config.attention_dropout)

    def forward(self, query_layer, key_layer,
                value_layer, attention_mask):
        """Compute attention and return the context layer [sq, b, hp]."""
        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================
        batch, heads = query_layer.size(1), query_layer.size(2)
        sq, sk = query_layer.size(0), key_layer.size(0)
        flat_heads = batch * heads

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(sq, flat_heads, -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(sk, flat_heads, -1)

        # Preallocated input tensor: [b * np, sq, sk]. With beta=0.0 below
        # its contents do not contribute to the result.
        scores_buffer = parallel_state.get_global_memory_buffer().get_tensor(
            (flat_heads, sq, sk), query_layer.dtype, "mpu")

        # Raw attention scores. [b * np, sq, sk]
        raw_scores = torch.baddbmm(
            scores_buffer,
            query_layer.transpose(0, 1),                # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0 / self.norm_factor))

        # change view to [b, np, sq, sk]
        attention_scores = raw_scores.view(batch, heads, sq, sk)

        # ===========================
        # Attention probs and dropout
        # ===========================
        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        if self.sequence_parallel:
            attention_probs = self.attention_dropout(attention_probs)
        else:
            with tensor_parallel.get_cuda_rng_tracker().fork():
                attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================
        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]
        v_batch = value_layer.size(1)
        v_heads = value_layer.size(2)
        head_dim = value_layer.size(3)
        # query_layer was only re-viewed above, so size(0) is still sq
        context_shape = (v_batch, v_heads, query_layer.size(0), head_dim)

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       v_batch * v_heads, -1)
        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(v_batch * v_heads,
                                               context_shape[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*context_shape)
        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        merged_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        return context_layer.view(*merged_shape)
class FlashSelfAttention(torch.nn.Module):
    """Implement the scaled dot product attention with softmax.

    The kernel is chosen at construction time from the run arguments:
    a DeepSpeed op-builder kernel (v1 or v2 entry point) when
    args.use_flash_attn_builder is set, otherwise the flash-attn package
    (varlen function for args.use_flash_attn_v2, unpadded function
    otherwise).

    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
                 device=None, dtype=None):
        # `device` and `dtype` are accepted for interface compatibility and
        # are not used.
        super().__init__()
        assert flash_attn_unpadded_func is not None or flash_attn_varlen_func is not None or flash_attn_builder is not None, \
            ('Please install FlashAttention first, e.g., with pip install flash-attn or implement your own flash attention')
        assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

        # Use FlashAttention-2 when args.use_flash_attn_v2 is True
        args = get_args()
        self.use_flash_attn_builder_v1 = False
        self.use_flash_attn_builder_v2 = False
        self.use_flash_attn = False
        if args.use_flash_attn_builder:
            if hasattr(flash_attn_builder, 'flash_attn_func'):
                self.flash_attn_func = flash_attn_builder.flash_attn_func
                self.use_flash_attn_builder_v1 = True
            else:
                self.flash_attn_func = flash_attn_builder.flash_attn_func_v2
                self.use_flash_attn_builder_v2 = True
        else:
            self.flash_attn_func = flash_attn_varlen_func if args.use_flash_attn_v2 else flash_attn_unpadded_func
            self.use_flash_attn = True

    def forward(self, q, k, v):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            q, k, v: The tensor containing the query, key, and value. (B, S, H, D)

        Returns the attention output laid out as (B, S, H, D). Dropout is
        applied only in training mode.
        """
        assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v)))
        assert all((get_accelerator().on_accelerator(i) for i in (q, k, v)))

        batch_size, seqlen_q = q.shape[0], q.shape[1]
        seqlen_k = k.shape[1]

        # Only the flash-attn package path consumes cumulative sequence
        # lengths; keep the name bound on the op-builder paths as well so the
        # training branch below cannot hit an unbound local.
        cu_seqlens_q = None
        if self.use_flash_attn:
            q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
            cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
                                        device=q.device)
        elif self.use_flash_attn_builder_v1:
            q, k, v = [rearrange(x, 'b s h d -> b h s d').contiguous() for x in [q, k, v]]
        else:
            # use_flash_attn_builder_v2
            q, k, v = [rearrange(x, 'b s h d -> b h s d') for x in [q, k, v]]

        if self.training:
            # during training q,k,v always have same seqlen
            assert seqlen_k == seqlen_q

            is_causal = self.causal
            cu_seqlens_k = cu_seqlens_q if get_accelerator().device_name() == 'cuda' else None
            dropout_p = self.dropout_p
        else:
            # turn off FA causal mask after first inference autoregressive iteration
            # only on first autoregressive step q,k,v have same seqlen
            is_causal = seqlen_q == seqlen_k
            cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32,
                                        device=q.device) if get_accelerator().device_name() == 'cuda' else None
            dropout_p = 0

        if self.use_flash_attn:
            output = self.flash_attn_func(
                q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
                dropout_p,
                softmax_scale=self.softmax_scale, causal=is_causal
            )
        else:
            # use_flash_attn_builder: pass the mode-dependent dropout_p so
            # that no dropout is applied in eval mode.
            output = self.flash_attn_func(
                q, k, v, dropout_p, self.softmax_scale, is_causal
            )

        if self.use_flash_attn:
            output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
        elif self.use_flash_attn_builder_v1:
            output = rearrange(output, 'b h s d -> b s h d').contiguous()
        else:
            # use_flash_attn_builder_v2:
            output = rearrange(output, 'b h s d -> b s h d')

        return output
class FlashSelfAttentionTriton(torch.nn.Module):
    """Implement the scaled dot product attention with softmax, using the
    Triton FlashAttention kernel.

    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0). Stored as `dropout_p`; this
                           module's forward does not pass it to the kernel.
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
                 device=None, dtype=None):
        # `device` and `dtype` are accepted for interface compatibility and
        # are not used.
        super().__init__()
        assert flash_attn_func is not None, ('Triton version of FlashAttention is not installed.')
        assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, q, k, v):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            q, k, v: The tensor containing the query, key, and value,
                laid out sequence-first as (S, B, H, D).

        Returns the attention output with heads merged, shaped (S, B, H * D).
        """
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda
        # The kernel works batch-first: (S, B, ...) -> (B, S, ...)
        q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous()
                   for x in (q, k, v)]

        # Positional arguments are (q, k, v, bias, causal, softmax_scale).
        # Forward the configured scale instead of silently dropping it; the
        # default None keeps the kernel's own 1/sqrt(d) scaling.
        output = flash_attn_func(q, k, v, None, self.causal, self.softmax_scale)
        output = rearrange(output, 'b s h d -> s b (h d)').contiguous()
        return output
class ParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.

    The attention core is one of FlashSelfAttentionTriton,
    FlashSelfAttention or CoreAttention. The flash variants are used only
    for causal self-attention with one of the flash-attention run arguments
    set. When DeepSpeed sequence parallelism is enabled the core is wrapped
    in DistributedAttention.
    """

    def __init__(self, config, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
        self.params_dtype = config.params_dtype
        self.sequence_parallel = config.sequence_parallel
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.use_gqa = (self.num_attention_heads != self.num_key_value_heads)

        self.use_flash_attn = (args.use_flash_attn_v1 or args.use_flash_attn_triton or args.use_flash_attn_v2 or \
            args.use_flash_attn_builder) \
            and attention_type == AttnType.self_attn \
            and self.attn_mask_type == AttnMaskType.causal
        # The Triton kernel is only usable when the flash-attention
        # preconditions hold. Without this gate a cross-attention or
        # non-causal layer would build FlashSelfAttentionTriton as its
        # `core_attention` and then call it with an attention mask it does
        # not accept.
        self.use_flash_attn_triton = self.use_flash_attn and args.use_flash_attn_triton
        if self.use_flash_attn:
            global flash_attn_builder
            try:
                flash_attn_builder = FlashAttentionBuilder().load()
            except TypeError:
                flash_attn_builder = None

            if args.use_flash_attn_v1:
                assert flash_attn_unpadded_func is not None, "Cannot import FlashAttention v1 "
            if args.use_flash_attn_v2:
                assert flash_attn_varlen_func is not None, "Cannot import FlashAttention v2 "
            if args.use_flash_attn_triton:
                assert flash_attn_func is not None, "Cannot import FlashAttention triton "
            if args.use_flash_attn_builder:
                assert flash_attn_builder is not None, "Cannot find FlashAttention op builder "

            assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports '
                                                          'self-attention for now')
            assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only '
                                                                'supports causal mask for now')
            if rearrange is None:
                raise ImportError('einops is not installed, please install with pip install einops')

        projection_size = config.kv_channels * config.num_attention_heads

        # Per attention head and per partition values.
        world_size = parallel_state.get_tensor_model_parallel_world_size()
        self.hidden_size_per_attention_head = core.utils.divide(
            projection_size, config.num_attention_heads)
        self.num_attention_heads_per_partition = core.utils.divide(
            config.num_attention_heads, world_size)

        # Per GQA head and per partition values
        self.num_key_value_heads_per_partition = core.utils.divide(
            config.num_key_value_heads, world_size)
        self.num_key_value_groups = core.utils.divide(
            config.num_attention_heads, config.num_key_value_heads)
        kv_projection_size = config.kv_channels * config.num_key_value_heads
        assert self.hidden_size_per_attention_head == core.utils.divide(
            kv_projection_size, config.num_key_value_heads)

        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            self.query_key_value = tensor_parallel.ColumnParallelLinear(
                config.hidden_size,
                projection_size + 2 * kv_projection_size,
                config=config,
                init_method=config.init_method,
                bias=args.add_bias_linear,
                gather_output=False)
        else:
            assert attention_type == AttnType.cross_attn
            self.query = tensor_parallel.ColumnParallelLinear(
                config.hidden_size,
                projection_size,
                config=config,
                init_method=config.init_method,
                bias=config.add_bias_linear,
                gather_output=False)

            self.key_value = tensor_parallel.ColumnParallelLinear(
                config.hidden_size,
                2 * projection_size,
                config=config,
                init_method=config.init_method,
                bias=config.add_bias_linear,
                gather_output=False)

        # Currently FlashAttention only works with causal mask
        if self.use_flash_attn_triton:
            local_attn = FlashSelfAttentionTriton(causal=True, attention_dropout=args.attention_dropout)
        elif self.use_flash_attn:
            local_attn = FlashSelfAttention(causal=True, attention_dropout=config.attention_dropout)
        else:
            local_attn = CoreAttention(self.layer_number, config, self.attn_mask_type)

        self.enable_ds_sequence_parallel = parallel_state.get_sequence_parallel_world_size() > 1 \
            or args.force_ds_sequence_parallel
        if self.enable_ds_sequence_parallel:
            assert dist_attn_supported, 'Distributed attention is not supported in this DeepSpeed version'
            assert args.num_attention_heads % parallel_state.get_sequence_parallel_world_size() == 0
            self.dist_attn = DistributedAttention(local_attn, parallel_state.get_sequence_parallel_group())
        else:
            if self.use_flash_attn:
                self.core_attention_flash = local_attn
            else:
                self.core_attention = local_attn
                self.checkpoint_core_attention = config.recompute_granularity == 'selective'

        # Output.
        self.dense = tensor_parallel.RowParallelLinear(
            projection_size,
            config.hidden_size,
            config=config,
            init_method=config.output_layer_init_method,
            bias=args.add_bias_linear,
            input_is_parallel=True,
            skip_bias_add=True)

    def _checkpointed_attention_forward(self, query_layer, key_layer,
                                        value_layer, attention_mask,
                                        rotary_pos_emb=None):
        """Forward method with activation checkpointing.

        The rotary embeddings are handed to the checkpoint call as extra
        inputs, but `custom_forward` only reads the first four.
        """
        def custom_forward(*inputs):
            query_layer = inputs[0]
            key_layer = inputs[1]
            value_layer = inputs[2]
            attention_mask = inputs[3]
            output_ = self.core_attention(query_layer, key_layer,
                                          value_layer, attention_mask)
            return output_

        q_pos_emb, k_pos_emb = (None, None) if rotary_pos_emb is None \
            else rotary_pos_emb

        hidden_states = tensor_parallel.checkpoint(
            custom_forward,
            False, query_layer, key_layer, value_layer, attention_mask,
            q_pos_emb, k_pos_emb)

        return hidden_states

    def _allocate_memory(self, inference_max_sequence_len, batch_size):
        """Return an uninitialized [max_seq_len, batch, np, hn] cache buffer
        in the parameter dtype on the current accelerator device."""
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
            dtype=self.params_dtype,
            device=get_accelerator().current_device_name())

    def repeat_kv(self, hidden_states, n_rep):
        """Repeat every key/value head `n_rep` times along the head dimension.

        [s, b, nkv, hn] -> [s, b, nkv * n_rep, hn]; the input is returned
        unchanged when `n_rep` is 1.
        """
        slen, batch, num_key_value_heads_per_partition, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        hidden_states = hidden_states[:, :, :, None, :].expand(
            slen, batch, num_key_value_heads_per_partition, n_rep, head_dim)
        return hidden_states.reshape(slen, batch,
                                     num_key_value_heads_per_partition * n_rep,
                                     head_dim)

    def split_tensor(self, mixed_x_layer):
        """Split the fused projection into query, key and value.

        Along dimension 3 the last two slots are the key and the value; all
        preceding slots are query heads, which are flattened into the head
        dimension.
        """
        query_layer = mixed_x_layer[:, :, :, :-2, :].reshape(mixed_x_layer.shape[:2] + (-1, self.hidden_size_per_attention_head))
        key_layer = mixed_x_layer[:, :, :, -2, :]
        value_layer = mixed_x_layer[:, :, :, -1, :]

        return query_layer, key_layer, value_layer

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, inference_params=None,
                rotary_pos_emb=None):
        """Run attention over `hidden_states`.

        Args:
            hidden_states: input of shape [sq, b, h].
            attention_mask: mask forwarded to the non-flash attention core.
            encoder_output: source of keys/values for cross-attention.
            inference_params: when truthy, per-layer key/value caches are
                kept in its `key_value_memory_dict`.
            rotary_pos_emb: a single rotary embedding (used for both query
                and key) or a (query, key) tuple; None disables it.

        Returns:
            Tuple (output, bias) produced by the output projection.
        """
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        is_first_step = False
        if inference_params:
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_len
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_value_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory, inference_value_memory)
                is_first_step = True
            else:
                inference_key_memory, inference_value_memory = \
                    inference_params.key_value_memory_dict[self.layer_number]

        # =====================
        # Query, Key, and Value
        # =====================

        if self.attention_type == AttnType.self_attn:
            # Attention heads [sq, b, h] --> [sq, b, ((nq + 2 * nkv) * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            # [sq, b, ((nq + 2 * nkv) * hn)] --> [sq, b, nkv, (nq // nkv + 2), hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                (-1, (self.num_key_value_groups + 2),
                 self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, nkv, (nq // nkv + 2), hn] --> 3 [sq, b, np, hn]
            (query_layer,
             key_layer,
             value_layer) = self.split_tensor(mixed_x_layer)

            # Repeat kv
            if self.use_gqa:
                key_layer = self.repeat_kv(key_layer, self.num_key_value_groups)
                value_layer = self.repeat_kv(value_layer,
                                             self.num_key_value_groups)
        else:
            assert not self.use_gqa, 'GQA + cross-attn not tested yet'

            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 2 * self.hidden_size_per_attention_head)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer,
             value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_tensor_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        # duplicate the pos_emb for self attention
        if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple):
            rotary_pos_emb = ((rotary_pos_emb,) * 2)

        if inference_params:
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)
            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
            assert sequence_end <= inference_key_memory.size(0)
            # Copy key and values.
            inference_key_memory[sequence_start:sequence_end,
                                 batch_start:batch_end, ...] = key_layer
            inference_value_memory[sequence_start:sequence_end,
                                   batch_start:batch_end, ...] = value_layer
            key_layer = inference_key_memory[
                :sequence_end, batch_start:batch_end, ...]
            value_layer = inference_value_memory[
                :sequence_end, batch_start:batch_end, ...]

            # adjust the key rotary positional embedding
            if rotary_pos_emb is not None:
                q_pos_emb, k_pos_emb = rotary_pos_emb
                # need to cross check this condition during inference
                # if not set_inference_key_value_memory:
                if not is_first_step:
                    # In inference, we compute one token at a time.
                    # Select the correct positional embedding
                    # (only the last token in the sequence)
                    q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end]
                else:
                    # In the first forward pass of inference,
                    # we use the entire provided prefix.
                    # q_pos_emb here has the rope embeddings of the entire
                    # prefix + to-be-generated output so
                    # we slice to just the prefix.
                    q_pos_emb = q_pos_emb[:sequence_end, :, :, :]
                k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
                rotary_pos_emb = (q_pos_emb, k_pos_emb)

        # ==================================
        # core attention computation
        # ==================================

        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb
            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb)
            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb)
            # TODO, can apply positional embedding to value_layer so it has
            # absolute positional embedding.
            # otherwise, only relative positional embedding takes effect
            # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)

        if self.enable_ds_sequence_parallel:
            if self.use_flash_attn:
                if not self.use_flash_attn_triton:
                    # FlashSelfAttention expects batch-first input; the
                    # Triton wrapper does this rearrangement itself.
                    query_layer, key_layer, value_layer = [rearrange(x, 's b ... -> b s ...').contiguous()
                                                           for x in (query_layer, key_layer, value_layer)]

                context_layer = self.dist_attn(query_layer, key_layer, value_layer)

                if not self.use_flash_attn_triton:
                    context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
            else:
                context_layer = self.dist_attn(query_layer, key_layer, value_layer, attention_mask)
        else:
            if self.use_flash_attn:
                if not self.use_flash_attn_triton:
                    query_layer, key_layer, value_layer = [rearrange(x, 's b ... -> b s ...').contiguous()
                                                           for x in (query_layer, key_layer, value_layer)]

                if self.sequence_parallel:
                    context_layer = self.core_attention_flash(query_layer, key_layer, value_layer)
                else:
                    with tensor_parallel.get_cuda_rng_tracker().fork():
                        context_layer = self.core_attention_flash(query_layer, key_layer, value_layer)

                if not self.use_flash_attn_triton:
                    context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
            else:
                if self.checkpoint_core_attention:
                    context_layer = self._checkpointed_attention_forward(
                        query_layer, key_layer, value_layer, attention_mask)
                else:
                    context_layer = self.core_attention(
                        query_layer, key_layer, value_layer, attention_mask)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        return output, bias
def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Optional[Tensor], Tensor, float, bool) -> Tensor
    """Add an optional bias to ``x``, apply dropout, then add ``residual``.

    Args:
        x: Input tensor.
        bias: Tensor added to ``x`` before dropout; skipped when ``None``.
        residual: Tensor added to the dropout output.
        prob: Dropout probability, forwarded as ``p``.
        training: Forwarded to dropout as its ``training`` flag.

    Returns:
        ``residual + dropout(x + bias)``, or ``residual + dropout(x)`` when
        ``bias`` is ``None``.
    """
    # The ``# type:`` comment above is kept as-is: the script-compiled
    # wrappers in this file call this function.
    if bias is not None:
        # Out-of-place add so the caller's tensor is left untouched.
        x = x + bias
    dropped = torch.nn.functional.dropout(x, p=prob, training=training)
    return residual + dropped
def get_bias_dropout_add(training):
    """Return a ``bias_dropout_add`` callable with ``training`` pre-bound.

    Args:
        training: Value passed as the ``training`` argument on every call of
            the returned function.

    Returns:
        A function ``(x, bias, residual, prob)`` that forwards to
        ``bias_dropout_add`` with the captured ``training`` flag.
    """

    def _bound_bias_dropout_add(x, bias, residual, prob):
        """Forward to ``bias_dropout_add`` using the captured flag."""
        result = bias_dropout_add(x, bias, residual, prob, training)
        return result

    return _bound_bias_dropout_add
@torch.jit.script
def bias_dropout_add_fused_train(
    x: torch.Tensor,
    bias: Optional[torch.Tensor],
    residual: torch.Tensor,
    prob: float,
) -> torch.Tensor:
    """Script-compiled ``bias_dropout_add`` with the training flag set to True.

    Args:
        x: Input tensor.
        bias: Optional tensor forwarded to ``bias_dropout_add``.
        residual: Tensor forwarded to ``bias_dropout_add``.
        prob: Dropout probability forwarded to ``bias_dropout_add``.

    Returns:
        The tensor returned by ``bias_dropout_add``.
    """
    # Training mode is fixed here, so callers only supply the tensors and
    # the dropout probability.
    training = True
    return bias_dropout_add(x, bias, residual, prob, training)
@torch.jit.script
def bias_dropout_add_fused_inference(
    x: torch.Tensor,
    bias: Optional[torch.Tensor],
    residual: torch.Tensor,
    prob: float,
) -> torch.Tensor:
    """Script-compiled ``bias_dropout_add`` with the training flag set to False.

    Args:
        x: Input tensor.
        bias: Optional tensor forwarded to ``bias_dropout_add``.
        residual: Tensor forwarded to ``bias_dropout_add``.
        prob: Dropout probability forwarded to ``bias_dropout_add``.

    Returns:
        The tensor returned by ``bias_dropout_add``.
    """
    # Training mode is fixed to off here; ``prob`` is still forwarded
    # unchanged.
    training = False
    return bias_dropout_add(x, bias, residual, prob, training)
class ParallelTransformerLayer(MegatronModule):
"""A single transformer layer.
Transformer layer takes input with size [s, b, h] and returns an
output of the same size.
"""
def __init__(self, config,
layer_number, layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
drop_path_rate=0., num_experts=1):
# retriever=None):
args = get_args()
super(ParallelTransformerLayer, self).__init__()
self.layer_number = layer_number
self.layer_type = layer_type
self.apply_residual_connection_post_layernorm \
= config.apply_residual_connection_post_layernorm
self.bf16 = config.bf16
self.fp32_residual_connection = config.fp32_residual_connection
# Layernorm on the input data.
if args.normalization == 'layernorm':
if get_accelerator().device_name() == 'cuda':
self.input_layernorm = LayerNorm(
config.hidden_size,
eps=config.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm,
sequence_parallel=config.sequence_parallel,
apply_layernorm_1p=args.apply_layernorm_1p,
mem_efficient_ln=args.mem_efficient_ln)
else:
self.input_layernorm = LayerNorm(
config.hidden_size,
eps=config.layernorm_epsilon)
else:
self.input_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon)
# Self attention.
self.self_attention = ParallelAttention(
config,
layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=self_attn_mask_type)
self.hidden_dropout = config.hidden_dropout
self.bias_dropout_fusion = config.bias_dropout_fusion
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None
# Layernorm on the attention output
if args.normalization == 'layernorm':
if get_accelerator().device_name() == 'cuda':
self.post_attention_layernorm = LayerNorm(
config.hidden_size,
eps=config.layernorm_epsilon,
no_persist_layer_norm=not config.persist_layer_norm,
sequence_parallel=config.sequence_parallel,
apply_layernorm_1p=args.apply_layernorm_1p,
mem_efficient_ln=args.mem_efficient_ln)
else:
self.post_attention_layernorm = LayerNorm(
config.hidden_size,
eps=config.layernorm_epsilon)
else:
self.post_attention_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon)
# Cross attention.
if self.layer_type in (LayerType.decoder,
LayerType.retro_decoder,
LayerType.retro_decoder_with_retriever,
LayerType.retro_encoder):
self.inter_attention = ParallelAttention(
config,
layer_number,
attention_type=AttnType.cross_attn)
# Layernorm on the attention output.
if args.normalization == 'layernorm':
self.post_inter_attention_layernorm = LayerNorm(
config.hidden_size,
eps=config.layernorm_epsilon,
no_persist_layer_norm=not config.persist_layer_norm,
sequence_parallel=config.sequence_parallel,
apply_layernorm_1p=args.apply_layernorm_1p,
mem_efficient_ln=args.mem_efficient_ln)
else:
self.post_inter_attention_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon)
# MLP
self.num_experts = num_experts
if args.num_experts_switch is not None:
self.mlp = SwitchMLP(config) # Megatron-LM's MoE
else:
if self.num_experts <= 1: # dense, not MoE
self.mlp = ParallelMLP(config)
else: # DeepSpeed's MoE
enable_expert_tensor_parallelism = args.enable_expert_tensor_parallelism
self.mlp = MoE(args.hidden_size,
ParallelMLP(config,
moe=True,
enable_expert_tensor_parallelism=enable_expert_tensor_parallelism),
num_experts=self.num_experts,
ep_size=args.moe_expert_parallel_size,
k=args.topk,
use_residual=(args.mlp_type == 'residual'),
capacity_factor=args.moe_train_capacity_factor,
eval_capacity_factor=args.moe_eval_capacity_factor,
min_capacity=args.moe_min_capacity,
drop_tokens=args.moe_token_dropping,
use_tutel=args.use_tutel,
enable_expert_tensor_parallelism=enable_expert_tensor_parallelism,
top2_2nd_expert_sampling=args.moe_top2_2nd_expert_sampling)
# Set bias+dropout+add fusion grad_enable execution handler.
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
self.bias_dropout_add_exec_handler = \
nullcontext if use_nvfuser else torch.enable_grad
if args.retro_add_retriever:
retro_args = get_retro_args()
self.retro_num_neighbors = args.retro_num_neighbors