-
Notifications
You must be signed in to change notification settings - Fork 4.1k
/
stage_1_and_2.py
executable file
·2418 lines (1950 loc) · 113 KB
/
stage_1_and_2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
import os
from deepspeed import comm as dist
from packaging import version as pkg_version
from collections import OrderedDict
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from deepspeed.runtime import ZeROOptimizer
from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
from deepspeed.runtime.utils import (bwc_tensor_model_parallel_rank, get_global_norm, empty_cache, see_memory_usage,
inf, is_model_parallel_parameter, align_dense_tensors, all_gather_dp_groups)
from deepspeed.runtime.zero.config import ZeroStageEnum
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.utils import logger
from deepspeed.moe.utils import is_moe_param
from deepspeed.git_version_info import version
from deepspeed.runtime.constants import PIPE_REPLICATED
from deepspeed.accelerator import get_accelerator
from deepspeed.checkpoint.constants import (DS_VERSION, GROUP_PADDINGS, PARTITION_COUNT, LOSS_SCALER,
SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE,
BASE_OPTIMIZER_STATE_STEP, CLIP_GRAD, ZERO_STAGE, PARAM_SLICE_MAPPINGS)
from deepspeed.utils import link_hp_params
from deepspeed.checkpoint import enable_universal_checkpoint
from deepspeed.utils import groups
# Flip to True to enable the correctness test that compares training
# with gradient partitioning against training without it.
# NOTE(review): the flag's consumers are not visible in this chunk.
pg_correctness_test = False

# Names of the optimizer timers (all-gather, gradients, step phases).
# NOTE(review): presumably keys into the `timers` object passed to the
# optimizer; confirm against the call sites.
OPTIMIZER_ALLGATHER_TIMER, OPTIMIZER_GRADIENTS_TIMER, OPTIMIZER_STEP_TIMER = (
    'optimizer_allgather',
    'optimizer_gradients',
    'optimizer_step',
)
OPTIMIZER_TIMERS = [OPTIMIZER_ALLGATHER_TIMER, OPTIMIZER_GRADIENTS_TIMER, OPTIMIZER_STEP_TIMER]
def input(msg):
    """No-op replacement that shadows the builtin ``input`` within this module.

    Any ``input(...)`` call resolved in this module returns ``None`` right away
    instead of reading from stdin; ``msg`` is ignored.
    """
    del msg  # deliberately unused
    return None
def split_half_float_double(tensors):
    """Bucket ``tensors`` by their accelerator tensor type string.

    The candidate types are tried in the fixed order Half, Float, Double,
    BFloat16 on the current accelerator (``torch.<device>.<Kind>Tensor``).
    Empty buckets are omitted. Tensors whose ``type()`` matches none of the
    four strings are not returned in any bucket.
    """
    device = get_accelerator().device_name()
    type_names = [f"torch.{device}.{kind}Tensor" for kind in ("Half", "Float", "Double", "BFloat16")]
    candidate_buckets = ([t for t in tensors if t.type() == name] for name in type_names)
    return [bucket for bucket in candidate_buckets if bucket]
def isclose(a, b, rtol=1e-09, atol=0.0):
    """Return whether ``a`` and ``b`` are within a relative or absolute tolerance.

    The allowed difference is the larger of ``rtol`` times the larger magnitude
    and the absolute floor ``atol`` (the same formula as ``math.isclose``, but
    without that function's special handling of infinities).
    """
    tolerance = max(rtol * max(abs(a), abs(b)), atol)
    return abs(a - b) <= tolerance
def lcm(x, y):
    """Return the least common multiple of the integers ``x`` and ``y``.

    Uses ``math.gcd``. The previous ``fractions.gcd`` import was deprecated
    in Python 3.5 and removed in 3.9, so it raised ImportError on modern
    interpreters.

    Returns 0 when either argument is 0; in particular ``lcm(0, 0)`` no longer
    divides by ``gcd(0, 0) == 0``. For negative inputs the sign of the result
    follows ``x * y``.
    """
    from math import gcd
    if x == 0 or y == 0:
        return 0
    return x * y // gcd(x, y)
def get_alignment_padding(tensor_list, alignment):
    """Return how many elements must be appended so the total is a multiple of ``alignment``.

    The total is the sum of ``numel()`` over ``tensor_list``; the result is 0
    when that total is already aligned.
    """
    total = sum(t.numel() for t in tensor_list)
    # (-total) % alignment is the distance from total up to the next multiple
    # of alignment, and is 0 when total is already a multiple.
    return -total % alignment
def move_to_cpu(tensor_list):
    """Move each tensor's storage to CPU in place.

    Rebinding ``.data`` keeps the same tensor objects (so existing references
    remain valid) while their payload now lives on the CPU.
    """
    for t in tensor_list:
        cpu_payload = t.data.cpu()
        t.data = cpu_payload
def print_rank_msg(msg):
    """Print ``msg`` to stdout, prefixed with this process's global rank."""
    rank = dist.get_rank()
    print("rank {} - {}".format(rank, msg))
def _get_padded_tensor(src_tensor, size):
if src_tensor.numel() >= size:
return src_tensor
padded_tensor = torch.zeros(size, dtype=src_tensor.dtype, device=src_tensor.device)
slice_tensor = torch.narrow(padded_tensor, 0, 0, src_tensor.numel())
slice_tensor.data.copy_(src_tensor.data)
return padded_tensor
class DeepSpeedZeroOptimizer(ZeROOptimizer):
"""
DeepSpeedZeroOptimizer designed to reduce the memory footprint
required for training large deep learning models.
For more details please see ZeRO: Memory Optimization Towards Training A Trillion Parameter Models
https://arxiv.org/abs/1910.02054
For usage examples, refer to TODO: DeepSpeed Tutorial
"""
def __init__(self,
             init_optimizer,
             param_names,
             timers,
             static_loss_scale=1.0,
             dynamic_loss_scale=False,
             dynamic_loss_args=None,
             verbose=True,
             contiguous_gradients=True,
             reduce_bucket_size=500000000,
             allgather_bucket_size=5000000000,
             dp_process_group=None,
             expert_parallel_group=None,
             expert_data_parallel_group=None,
             reduce_scatter=True,
             overlap_comm=False,
             offload_optimizer_config=None,
             mpu=None,
             clip_grad=0.0,
             gradient_accumulation_dtype=torch.float32,
             communication_data_type=torch.float16,
             postscale_gradients=True,
             gradient_predivide_factor=1.0,
             gradient_accumulation_steps=1,
             ignore_unused_parameters=True,
             partition_grads=True,
             round_robin_gradients=False,
             has_moe_layers=False,
             fp16_master_weights_and_gradients=False,
             elastic_checkpoint=False):
    """Wrap ``init_optimizer`` with ZeRO stage 1 or 2 partitioning.

    Each param group's trainable bit16 parameters are flattened into one
    aligned buffer on the accelerator and split evenly across the group's
    data-parallel ranks. This rank keeps a master copy of only its own
    partition: fp32, or fp16 when ``fp16_master_weights_and_gradients`` is set.
    The wrapped optimizer's ``param_groups`` are rewritten to hold just that
    partition.

    Args (selected):
        init_optimizer: torch optimizer over the model's bit16 parameters; its
            ``param_groups`` are modified in place.
        param_names: mapping from parameter object to name, used to build the
            universal-checkpoint slice mapping.
        timers: timer collection stored as ``self.timers``.
        partition_grads: True for ZeRO-2 (gradient partitioning), False for ZeRO-1.
        offload_optimizer_config: when given with a device other than
            ``OffloadDeviceEnum.none``, master partitions and optimizer state
            are kept on CPU.
        reduce_bucket_size / allgather_bucket_size: bucket sizes;
            ``allgather_bucket_size`` must be a multiple of 2.
        has_moe_layers: enables MoE handling; MoE param groups then use
            ``expert_data_parallel_group[group['name']]`` as their DP group.

    Raises:
        SystemError: if no accelerator is available.
        AssertionError: on unsupported combinations of options.
    """
    if offload_optimizer_config is not None and offload_optimizer_config.device != OffloadDeviceEnum.none:
        self.cpu_offload = True
        self.cpu_offload_pin_memory = offload_optimizer_config.pin_memory
    else:
        self.cpu_offload = False
        self.cpu_offload_pin_memory = False

    if dist.get_rank() == 0:
        logger.info(f"Reduce bucket size {reduce_bucket_size}")
        logger.info(f"Allgather bucket size {allgather_bucket_size}")
        logger.info(f"CPU Offload: {self.cpu_offload}")
        logger.info(f'Round robin gradient partitioning: {round_robin_gradients}')
    # The fused optimizer does all the work. We need this layer for two reasons:
    # 1. maintain same user API from apex.fp16_utils
    # 2. keep common stuff here in case we need to add a new fused optimizer later

    self.elastic_checkpoint = elastic_checkpoint
    self.param_names = param_names
    self.mpu = mpu
    # differences from apex.fp16_utils:
    # - assume all model params in fp16
    # - assume all params requires grad
    # - flat by groups, not keeping state. TODO: remove state explicitly?
    # - master grad and unflat master weight never exist. TODO: a way to save out unflat master?
    if not get_accelerator().is_available():
        raise SystemError("Accelerator is not detected, cannot perform low precision training (e.g., fp16, bf16).")
    self.optimizer = init_optimizer

    # Use torch (un)flatten ops
    self.flatten = _flatten_dense_tensors
    self.unflatten = _unflatten_dense_tensors

    # ZeRO stage 1 (False) or 2 (True)
    self.partition_gradients = partition_grads
    self.zero_stage_string = "ZeRO-2" if partition_grads else "ZeRO-1"

    self.timers = timers

    self.reduce_scatter = reduce_scatter

    self.overlap_comm = overlap_comm

    self.deepspeed_adam_offload = self.cpu_offload

    # Master partitions live on CPU when offloading, otherwise on the current accelerator.
    self.device = get_accelerator().current_device_name() if not self.cpu_offload else 'cpu'

    self.dp_process_group = dp_process_group
    self.sequence_parallel_size = groups._get_sequence_parallel_world_size()
    # expert parallel group
    self.ep_process_group = expert_parallel_group

    # data parallel group for experts
    self.expert_dp_process_group = expert_data_parallel_group

    # data parallel size for non-experts
    dp_size = dist.get_world_size(group=self.dp_process_group)

    # For MoE models this may be different for different param groups;
    # _configure_moe_settings() below overwrites the entries of MoE groups.
    self.real_dp_process_group = [dp_process_group for _ in range(len(self.optimizer.param_groups))]
    self.partition_count = [dp_size for _ in range(len(self.optimizer.param_groups))]

    self.is_gradient_accumulation_boundary = True

    # CPU-Offload requires contiguous gradients
    self.contiguous_gradients = contiguous_gradients or self.cpu_offload

    self.has_moe_layers = has_moe_layers
    if self.has_moe_layers:
        self._configure_moe_settings()
    self._global_grad_norm = 0.

    if mpu is None:
        self.model_parallel_group = None
        self.model_parallel_world_size = 1
        self.model_parallel_rank = 0
    else:
        self.model_parallel_group = mpu.get_model_parallel_group()
        self.model_parallel_world_size = mpu.get_model_parallel_world_size()
        self.model_parallel_rank = bwc_tensor_model_parallel_rank(mpu)

    self.overflow = False
    self.clip_grad = clip_grad
    self.communication_data_type = communication_data_type
    self.gradient_predivide_factor = gradient_predivide_factor
    self.postscale_gradients = postscale_gradients
    self.gradient_accumulation_steps = gradient_accumulation_steps
    self.micro_step_id = 0
    self.ignore_unused_parameters = ignore_unused_parameters
    self.round_robin_gradients = round_robin_gradients

    self.extra_large_param_to_reduce = None
    self.fp16_master_weights_and_gradients = fp16_master_weights_and_gradients

    if self.fp16_master_weights_and_gradients:
        assert self.cpu_offload and type(self.optimizer) in [DeepSpeedCPUAdam], \
        f"fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32."\
        f"Currently only supported using ZeRO-Offload with DeepSpeedCPUAdam. But current setting is ZeRO-Offload:{self.cpu_offload} and optimizer type {type(self.optimizer)}." \
        f"Either disable fp16_master_weights_and_gradients or enable {self.zero_stage_string} Offload with DeepSpeedCPUAdam."

    if self.reduce_scatter:
        valid_reduce_scatter_dtypes = (torch.float16, torch.bfloat16, torch.float32)
        assert self.communication_data_type in valid_reduce_scatter_dtypes, f"{self.zero_stage_string} supports {valid_reduce_scatter_dtypes} communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'"
        # These two messages were previously plain strings, so the stage name
        # placeholder was printed literally; they are now f-strings.
        assert self.gradient_predivide_factor == 1.0, f"gradient_predivide_factor != 1.0 is not yet supported with {self.zero_stage_string} with reduce scatter enabled"
        assert self.postscale_gradients, f"pre-scale gradients is not yet supported with {self.zero_stage_string} with reduce scatter enabled"

    # param flattened by groups
    self.bit16_groups = []
    self.bit16_groups_flat = []

    # param partitioned by data parallel degree
    # this will contain a list of equal sized tensors
    # each of which will be updated by a different process
    self.parallel_partitioned_bit16_groups = []

    # a single 32-bit partition of the parallel partitioned parameters
    # that this process will update
    self.single_partition_of_fp32_groups = []

    # param partition info

    # These are the parameters in each group that will not be updated by this process directly
    self.params_not_in_partition = []

    # These are the parameters that will be updated by this process directly
    self.params_in_partition = []

    # Offset from the first parameter in the self.params_in_partition
    # the parameter boundaries may not align with partition boundaries
    # so we need to keep track of the offset
    self.first_offset = []

    # number of elements per partition in each group
    self.partition_size = []

    # align nccl all-gather send buffers to 4-byte boundary
    self.nccl_start_alignment_factor = 2  # 4-byte alignment/sizeof(fp16) = 2

    assert (
        allgather_bucket_size % self.nccl_start_alignment_factor == 0
    ), f"allgather_bucket_size must be a multiple of nccl_start_alignment_factor, {self.nccl_start_alignment_factor} "

    self.all_reduce_print = False
    # The working dtype is taken from the first parameter of the first param group.
    self.dtype = self.optimizer.param_groups[0]['params'][0].dtype
    self.gradient_accumulation_dtype = gradient_accumulation_dtype

    if self.dtype != self.gradient_accumulation_dtype:
        self.use_separate_grad_accum = True
    else:
        self.use_separate_grad_accum = False
    if self.use_separate_grad_accum and not self.partition_gradients:
        self.use_grad_accum_attribute = True
    else:
        self.use_grad_accum_attribute = False

    self.round_robin_bit16_groups = []
    self.round_robin_bit16_indices = []

    # Use different parallel to do all_to_all_reduce related things
    # padding on each partition for alignment purposes
    self.groups_padding = []
    # loop to deal with groups
    for i, param_group in enumerate(self.optimizer.param_groups):
        partition_id = dist.get_rank(group=self.real_dp_process_group[i])

        # push this group to list before modify
        # TODO: Explore simplification that avoids the extra book-keeping by pushing the reordered group
        trainable_parameters = []
        for param in param_group['params']:
            if param.requires_grad:
                param.grad_accum = None
                trainable_parameters.append(param)
        self.bit16_groups.append(trainable_parameters)

        # not sure why apex was cloning the weights before flattening
        # removing cloning here

        see_memory_usage(f"Before moving param group {i} to CPU")
        # move all the parameters to cpu to free up GPU space for creating flat buffer
        move_to_cpu(self.bit16_groups[i])
        empty_cache()
        see_memory_usage(f"After moving param group {i} to CPU", force=False)

        # Reorder group parameters for load balancing of gradient partitioning during backward among ranks.
        # This ensures that gradients are reduced in a fashion such that ownership round robins among the ranks.
        # For example, rather than 3 gradients (g_n+2, g_n+1, g_n) that are reduced consecutively belonging
        # to the same rank, instead they will belong to 3 ranks (r_m+2, r_m+1, r_m).
        if self.round_robin_gradients:
            round_robin_tensors, round_robin_indices = self._round_robin_reorder(
                self.bit16_groups[i], dist.get_world_size(group=self.real_dp_process_group[i]))
        else:
            round_robin_tensors = self.bit16_groups[i]
            round_robin_indices = list(range(len(self.bit16_groups[i])))

        self.round_robin_bit16_groups.append(round_robin_tensors)
        self.round_robin_bit16_indices.append(round_robin_indices)

        # create flat buffer in CPU and move to GPU
        self.bit16_groups_flat.append(
            self.flatten_dense_tensors_aligned(
                self.round_robin_bit16_groups[i],
                self.nccl_start_alignment_factor * dist.get_world_size(group=self.real_dp_process_group[i])).to(
                    get_accelerator().current_device_name()))
        see_memory_usage(f"After flattening and moving param group {i} to GPU", force=False)

        # Record padding required for alignment; only the last rank's partition holds the padding.
        if partition_id == dist.get_world_size(group=self.real_dp_process_group[i]) - 1:
            padding = self.bit16_groups_flat[i].numel() - sum(
                [t.numel() for t in self.round_robin_bit16_groups[i]])
        else:
            padding = 0
        self.groups_padding.append(padding)

        if dist.get_rank(group=self.real_dp_process_group[i]) == 0:
            see_memory_usage(f"After Flattening and after emptying param group {i} cache", force=False)

        # set model bit16 weight to slices of flattened buffer
        self._update_model_bit16_weights(i)

        # divide the flat weights into near equal partition equal to the data parallel degree
        # each process will compute on a different part of the partition
        data_parallel_partitions = self.get_data_parallel_partitions(self.bit16_groups_flat[i], i)
        self.parallel_partitioned_bit16_groups.append(data_parallel_partitions)

        # verify that data partition start locations are 4-byte aligned
        for partitioned_data in data_parallel_partitions:
            assert (partitioned_data.data_ptr() % (2 * self.nccl_start_alignment_factor) == 0)

        # A partition of the fp32 master weights that will be updated by this process.
        # Note that the params in single_partition_of_fp32_groups is cloned and detached
        # from the origin params of the model.
        if not fp16_master_weights_and_gradients:
            self.single_partition_of_fp32_groups.append(self.parallel_partitioned_bit16_groups[i][partition_id].to(
                self.device).clone().float().detach())
        else:
            self.single_partition_of_fp32_groups.append(self.parallel_partitioned_bit16_groups[i][partition_id].to(
                self.device).clone().half().detach())

        # Set local optimizer to have flat params of its own partition.
        # After this, the local optimizer will only contain its own partition of params.
        # In that case, the local optimizer only saves the states(momentum, variance, etc.) related to its partition's params(zero stage1).
        self.single_partition_of_fp32_groups[
            i].requires_grad = True  # keep this in case internal optimizer uses it
        param_group['params'] = [self.single_partition_of_fp32_groups[i]]

        # True division: this is a float; consumers such as
        # initialize_optimizer_states convert it with int().
        partition_size = len(self.bit16_groups_flat[i]) / dist.get_world_size(group=self.real_dp_process_group[i])
        params_in_partition, params_not_in_partition, first_offset = self.get_partition_info(
            self.round_robin_bit16_groups[i], partition_size, partition_id)

        self.partition_size.append(partition_size)
        self.params_in_partition.append(params_in_partition)
        self.params_not_in_partition.append(params_not_in_partition)
        self.first_offset.append(first_offset)

    self.reduce_bucket_size = int(reduce_bucket_size)
    self.allgather_bucket_size = int(allgather_bucket_size)

    self.reduction_stream = None if get_accelerator().is_synchronized_device() else get_accelerator().Stream()
    #self.copy_grad_stream = get_accelerator().Stream()
    self.callback_queued = False

    self.param_dict = {}

    # map between param_id and bool to specify if a param is in this partition
    self.is_param_in_current_partition = {}

    self.grads_in_ipg_bucket = []
    self.params_in_ipg_bucket = []
    self.elements_in_ipg_bucket = 0
    self.params_already_reduced = []
    self._release_ipg_buffers()
    self.previous_reduced_grads = None
    self.ipg_bucket_has_moe_params = False

    # simplified param id
    self.param_id = {}

    # Assign a sequential id to every trainable parameter (keyed by Python id())
    # and track the largest parameter size for the offload staging buffers.
    largest_param_numel = 0
    count = 0
    for i, params_group in enumerate(self.bit16_groups):
        for param in params_group:
            unique_id = id(param)
            self.param_id[unique_id] = count
            self.param_dict[count] = param
            self.params_already_reduced.append(False)
            if param.numel() > largest_param_numel:
                largest_param_numel = param.numel()
            count = count + 1

    for param_group in self.params_in_partition:
        for param in param_group:
            self.is_param_in_current_partition[self.get_param_id(param)] = True

    for param_group in self.params_not_in_partition:
        for param in param_group:
            self.is_param_in_current_partition[self.get_param_id(param)] = False

    if self.cpu_offload:
        self.accumulated_grads_in_cpu = {}
        self.norm_for_param_grads = {}
        self.local_overflow = False
        self.grad_position = {}
        self.temp_grad_buffer_for_cpu_offload = torch.zeros(largest_param_numel,
                                                            device=self.device,
                                                            dtype=self.dtype)
        if self.cpu_offload_pin_memory:
            self.temp_grad_buffer_for_cpu_offload = get_accelerator().pin_memory(
                self.temp_grad_buffer_for_cpu_offload)
        self.temp_grad_buffer_for_gpu_offload = torch.zeros(largest_param_numel,
                                                            device=get_accelerator().current_device_name(),
                                                            dtype=self.dtype)
        for i, params_group in enumerate(self.bit16_groups):
            self.get_grad_position(i, self.params_in_partition[i], self.first_offset[i], self.partition_size[i])

    # mapping from parameter to partition that it belongs to
    self.param_to_partition_ids = {}

    # stores if a partition has been reduced in this step
    self.is_partition_reduced = {}

    # number of grads in partition that still need to be computed
    self.remaining_grads_in_partition = {}

    # total number of grads in partition
    self.total_grads_in_partition = {}

    # stores if a grad in a partition has been computed or not
    self.is_grad_computed = {}

    # stores the offset at which a parameter gradient needs to be inserted in a partition
    self.grad_partition_insertion_offset = {}

    # the offset in the gradient at which it must be inserted at the beginning of the partition
    self.grad_start_offset = {}

    # will store the averaged gradients required by this partition
    self.averaged_gradients = {}

    # For cpu_offload, will store the averaged gradients required by this partition
    self.offload_gradient_dict = {}

    # store index of first parameter in each partition
    self.first_param_index_in_partition = {}

    # initializes all data structures for implementing gradient partitioning
    self.initialize_gradient_partitioning_data_structures()

    # resets the data structure value for the next backward propagation
    self.reset_partition_gradient_structures()

    # creates backward hooks for gradient partitioning
    if self.partition_gradients or self.overlap_comm:
        self.create_reduce_and_remove_grad_hooks()

    self.custom_loss_scaler = False
    self.external_loss_scale = None

    # we may have a way of fusing dynamic scale. Do not support for now
    self.loss_scaler = CreateLossScaler(dtype=self.dtype,
                                        static_loss_scale=static_loss_scale,
                                        dynamic_scaling=dynamic_loss_scale,
                                        dynamic_loss_args=dynamic_loss_args)
    self.dynamic_loss_scale = self.loss_scaler.dynamic

    if self.dtype != torch.float16:
        # Only fp16 should use dynamic loss scaling
        assert self.loss_scaler.cur_scale == 1.0
        assert not self.dynamic_loss_scale

    see_memory_usage("Before initializing optimizer states", force=True)
    self.initialize_optimizer_states()
    see_memory_usage("After initializing optimizer states", force=True)

    if dist.get_rank() == 0:
        logger.info("optimizer state initialized")

    if dist.get_rank(group=self.dp_process_group) == 0:
        see_memory_usage("After initializing ZeRO optimizer", force=True)

    self._link_all_hp_params()
    self._enable_universal_checkpoint()
    self._param_slice_mappings = self._create_param_mapping()
def _enable_universal_checkpoint(self):
    """Call ``enable_universal_checkpoint`` once for each bit16 parameter group."""
    for bit16_params in self.bit16_groups:
        enable_universal_checkpoint(param_list=bit16_params)
def _create_param_mapping(self):
param_mapping = []
for i, _ in enumerate(self.optimizer.param_groups):
param_mapping_per_group = OrderedDict()
for lp in self.bit16_groups[i]:
if lp._hp_mapping is not None:
lp_name = self.param_names[lp]
param_mapping_per_group[lp_name] = lp._hp_mapping.get_hp_fragment_address()
param_mapping.append(param_mapping_per_group)
return param_mapping
def _link_all_hp_params(self):
    """Link each bit16 model parameter to its fragment of this rank's flat fp32 partition.

    For every optimizer param group, ``link_hp_params`` receives:
      * the group's bit16 params;
      * this rank's flat master partition;
      * the gradient dicts and the partition's optimizer state;
      * the start offset and size of this rank's partition.

    The partition size is derived from the world size of the group's own
    data-parallel group (``self.real_dp_process_group[i]``), the same group
    used for ``partition_id`` and for the partition sizing in ``__init__``.
    For MoE param groups that is the expert data-parallel group, whose size can
    differ from the global data-parallel group. The previous code used the
    global group, which mis-sized MoE partitions.
    """
    if self.cpu_offload:
        self._get_offload_gradient_dict()

    for i, _ in enumerate(self.optimizer.param_groups):
        # Link bit16 and fp32 params in partition
        group_dp = self.real_dp_process_group[i]
        partition_id = dist.get_rank(group=group_dp)
        partition_size = self.bit16_groups_flat[i].numel() // dist.get_world_size(group=group_dp)
        flat_hp_partition = self.single_partition_of_fp32_groups[i]
        link_hp_params(lp_param_list=self.bit16_groups[i],
                       flat_hp_partition=flat_hp_partition,
                       gradient_dict=self.averaged_gradients,
                       offload_gradient_dict=self.offload_gradient_dict,
                       use_offload=self.cpu_offload,
                       param_group_index=i,
                       partition_start=partition_id * partition_size,
                       partition_size=partition_size,
                       partition_optimizer_state=self.optimizer.state[flat_hp_partition],
                       dp_group=group_dp)
def is_moe_group(self, group):
    """Return the param group's ``'moe'`` entry if it has one, otherwise False.

    The stored value is returned as-is (not coerced to bool), so callers should
    only rely on its truthiness.
    """
    if 'moe' not in group:
        return False
    return group['moe']
def _configure_moe_settings(self):
# if we're using ZeRO stage 2, ensure contiguous gradients are used
if self.partition_gradients:
assert self.contiguous_gradients, "Contiguous Gradients in ZeRO Stage 2 must be set to True for MoE. Other code paths are not tested with MoE"
# NOTE: To run ZeRO stage 1 with MoE, we need to set self.contiguous_gradients to True or ignore the assertion
if not self.partition_gradients and not self.contiguous_gradients:
logger.warn(
"ZeRO Stage 1 has not been thoroughly tested with MoE. This configuration is still experimental.")
assert self.reduce_scatter, "Reduce Scatter in ZeRO Stage 2 must be set to True for MoE. Other code paths are not tested with MoE"
assert any(
[self.is_moe_group(group) for group in self.optimizer.param_groups]
), "The model has moe layers, but None of the param groups are marked as MoE. Create a param group with 'moe' key set to True before creating optimizer"
self.is_moe_param_group = []
for i, group in enumerate(self.optimizer.param_groups):
if self.is_moe_group(group):
assert all([is_moe_param(param)
for param in group['params']]), "All params in MoE group must be MoE params"
self.real_dp_process_group[i] = self.expert_dp_process_group[group['name']]
self.partition_count[i] = dist.get_world_size(group=self.expert_dp_process_group[group['name']])
self.is_moe_param_group.append(True)
else:
self.is_moe_param_group.append(False)
assert self.expert_dp_process_group is not None, "Expert data parallel group should be configured with MoE"
assert self.ep_process_group is not None, "Expert parallel group should be configured with MoE"
def _update_model_bit16_weights(self, group_index):
updated_params = self.unflatten(self.bit16_groups_flat[group_index],
self.round_robin_bit16_groups[group_index])
for p, q in zip(self.round_robin_bit16_groups[group_index], updated_params):
p.data = q.data
# set model fp16 weight to slices of reordered flattened buffer
for param_index, param in enumerate(self.bit16_groups[group_index]):
new_index = self.round_robin_bit16_indices[group_index][param_index]
param.data = self.round_robin_bit16_groups[group_index][new_index].data
def _round_robin_reorder(self, tensor_list, num_partitions):
# disable round robin if need to debug something
# return tensor_list, list(range(len(tensor_list)))
partition_tensors = {}
for i, tensor in enumerate(tensor_list):
j = i % num_partitions
if not j in partition_tensors:
partition_tensors[j] = []
partition_tensors[j].append((i, tensor))
reordered_tensors = []
reordered_indices = {}
for partition_index in partition_tensors.keys():
for i, (original_index, tensor) in enumerate(partition_tensors[partition_index]):
reordered_indices[original_index] = len(reordered_tensors)
reordered_tensors.append(tensor)
return reordered_tensors, reordered_indices
def _release_ipg_buffers(self):
if self.contiguous_gradients:
self.ipg_buffer = None
self.grads_in_partition = None
self.grads_in_partition_offset = 0
def initialize_optimizer_states(self):
    """Give each master partition a zero gradient and make the wrapped optimizer create its state.

    Each partition in ``self.single_partition_of_fp32_groups`` gets a zero
    gradient of ``int(self.partition_size[i])`` elements, pinned when
    ``cpu_offload_pin_memory`` is set. Adagrad is then rebuilt over the master
    partitions, because it creates state at construction. Any other optimizer
    runs one ``step()``, which lazily creates its state. Without CPU offload
    the temporary gradients are dropped again afterwards.
    """
    for idx in range(len(self.bit16_groups)):
        master_partition = self.single_partition_of_fp32_groups[idx]
        zero_grad = torch.zeros(int(self.partition_size[idx]),
                                dtype=master_partition.dtype,
                                device=self.device)
        if self.cpu_offload_pin_memory:
            zero_grad = get_accelerator().pin_memory(zero_grad)
        master_partition.grad = zero_grad

    if isinstance(self.optimizer, torch.optim.Adagrad):
        # NOTE(review): rebuilding from ``defaults`` creates a single param group,
        # so per-group hyperparameters of the original optimizer appear to be
        # dropped here; confirm this is intended.
        self.optimizer = torch.optim.Adagrad(self.single_partition_of_fp32_groups, **self.optimizer.defaults)
    else:
        self.optimizer.step()

    if self.cpu_offload:
        return None
    for master_partition in self.single_partition_of_fp32_groups:
        master_partition.grad = None  #class init
    return None
#########################################################################
#################### ZeRO Stage 1 - reduce gradients ####################
#########################################################################
def reduce_gradients(self, pipeline_parallel=False):
    """Reduce gradients across the data-parallel group (ZeRO stage 1 entry point).

    Args:
        pipeline_parallel: when True and contiguous gradients are enabled, a
            fresh IPG buffer of ``reduce_bucket_size`` elements is allocated
            first, since with pipeline parallelism backward runs outside ZeRO.

    When ``overlap_comm`` is off, every bit16 parameter with a gradient for
    reduction is handed to ``reduce_ready_partitions_and_remove_grads``. In all
    cases any pending reductions are then flushed by
    ``overlapping_partition_gradients_reduce_epilogue``.

    The former unused ``world_size``/``my_rank`` locals (two needless
    process-group queries) were removed.
    """
    # with PP we must create ipg buffer, since backward is handled outside zero
    if pipeline_parallel and self.contiguous_gradients:
        self.ipg_buffer = []
        buf_0 = torch.empty(int(self.reduce_bucket_size),
                            dtype=self.dtype,
                            device=get_accelerator().current_device_name())
        self.ipg_buffer.append(buf_0)
        self.ipg_index = 0

    if not self.overlap_comm:
        for group_idx, group in enumerate(self.bit16_groups):
            for param in group:
                if self.get_gradient_for_reduction(param) is not None:
                    self.reduce_ready_partitions_and_remove_grads(param, group_idx)

    # reduce any pending grads in either hook/non-hook case
    self.overlapping_partition_gradients_reduce_epilogue()
#########################################################################
#########################ZeRO Partition Gradients########################
#########################################################################
def get_first_param_index(self, group_id, param_group, partition_id):
    """Return the position in ``param_group`` of the first param mapped to ``partition_id``.

    Membership is read from ``self.param_to_partition_ids[group_id]``, keyed by
    ``self.get_param_id(param)``. Returns None when no parameter matches.
    """

    def _belongs(param):
        pid = self.get_param_id(param)
        return partition_id in self.param_to_partition_ids[group_id][pid]

    return next((position for position, param in enumerate(param_group) if _belongs(param)), None)
def initialize_gradient_partitioning_data_structures(self):
    """Create the per-group and per-partition bookkeeping dicts for gradient partitioning.

    For every round-robin ordered param group, all bookkeeping tables get a
    fresh dict for that group. Then, for each rank of the group's data-parallel
    group:
      * the per-partition tables are set up;
      * ``initialize_gradient_partition`` is called;
      * the first param index owned by that partition is recorded.
    """
    for group_idx, rr_params in enumerate(self.round_robin_bit16_groups):
        num_partitions = dist.get_world_size(group=self.real_dp_process_group[group_idx])

        per_group_tables = (
            self.param_to_partition_ids,
            self.is_partition_reduced,
            self.total_grads_in_partition,
            self.remaining_grads_in_partition,
            self.is_grad_computed,
            self.grad_partition_insertion_offset,
            self.grad_start_offset,
            self.first_param_index_in_partition,
        )
        for table in per_group_tables:
            table[group_idx] = {}

        for pid in range(num_partitions):
            self.is_grad_computed[group_idx][pid] = {}
            self.grad_partition_insertion_offset[group_idx][pid] = {}
            self.grad_start_offset[group_idx][pid] = {}
            self.total_grads_in_partition[group_idx][pid] = 0
            # Must run before get_first_param_index, which reads the
            # param -> partition mapping for this group.
            self.initialize_gradient_partition(group_idx, rr_params, pid)
            self.is_partition_reduced[group_idx][pid] = False
            self.first_param_index_in_partition[group_idx][pid] = self.get_first_param_index(
                group_idx, rr_params, pid)
def independent_gradient_partition_epilogue(self):
    """Finish a backward pass: reduce what is left in the IPG bucket and collect the gradients.

    Steps:
      1. Calls ``reduce_ipg_grads`` for the gradients still queued in the
         independent-partition-gradient (IPG) bucket.
      2. Resets ``params_already_reduced`` to ``False`` for every parameter.
      3. With ``overlap_comm``, synchronizes the accelerator before clearing
         previously reduced gradients of other partitions.
      4. When ``cpu_offload is False``, builds this rank's flat gradient partition
         for each group (as a tensor list in ``gradient_accumulation_dtype`` on the
         current device). It stores that in ``averaged_gradients[i]`` the first time,
         and otherwise adds it in place to the existing entry.
      5. Releases the IPG buffers and drops all parameter gradients
         (``zero_grad(set_to_none=True)``).
    """
    self.report_ipg_memory_usage("In ipg_epilogue before reduce_ipg_grads", 0)
    self.reduce_ipg_grads()
    self.report_ipg_memory_usage("In ipg_epilogue after reduce_ipg_grads", 0)

    # Every parameter becomes eligible for reduction again in the next pass.
    for param_id in range(len(self.params_already_reduced)):
        self.params_already_reduced[param_id] = False

    if self.overlap_comm:
        # With overlap_comm, reductions are issued on a separate stream (see
        # average_tensor), so wait for them before touching reduced grads.
        get_accelerator().synchronize()
        # It is safe to clear previously reduced grads of other partitions
        self._clear_previous_reduced_grads()

    if self.cpu_offload is False:
        for i, _ in enumerate(self.bit16_groups):
            first_accumulation = i not in self.averaged_gradients or self.averaged_gradients[i] is None
            # Built once here; previously the identical call was duplicated in both branches.
            new_partition_grads = self.get_flat_partition(self.params_in_partition[i],
                                                          self.first_offset[i],
                                                          self.partition_size[i],
                                                          dtype=self.gradient_accumulation_dtype,
                                                          device=get_accelerator().current_device_name(),
                                                          return_tensor_list=True)
            if first_accumulation:
                self.averaged_gradients[i] = new_partition_grads
            else:
                for accumulated_grad, new_avg_grad in zip(self.averaged_gradients[i], new_partition_grads):
                    accumulated_grad.add_(new_avg_grad)

    self._release_ipg_buffers()

    # No need to keep the gradients anymore.
    # All gradients required by the step
    # are in self.averaged_gradients
    self.zero_grad(set_to_none=True)
    see_memory_usage("End ipg_epilogue")
def reset_partition_gradient_structures(self):
    """Reset per-partition reduction state before a new round of gradient reduction.

    For every parameter group and every data-parallel partition of it:
      * the partition is marked as not reduced,
      * its remaining-gradient counter is restored to the total number of
        gradients it owns (``total_grads_in_partition``),
      * every gradient it tracks is flagged as not yet computed.
    """
    for group_idx in range(len(self.bit16_groups)):
        num_partitions = dist.get_world_size(group=self.real_dp_process_group[group_idx])
        for part in range(num_partitions):
            self.is_partition_reduced[group_idx][part] = False
            self.remaining_grads_in_partition[group_idx][part] = self.total_grads_in_partition[group_idx][part]
            # Overwrite every existing flag in place (dict identity and key order preserved).
            grad_flags = self.is_grad_computed[group_idx][part]
            grad_flags.update(dict.fromkeys(grad_flags, False))
def initialize_gradient_partition(self, i, param_group, partition_id):
    """Record which parameters of group ``i`` overlap partition ``partition_id``.

    The group's parameters are laid out back to back. Partition ``partition_id``
    covers elements ``[partition_size * partition_id, partition_size * (partition_id + 1))``
    with ``partition_size = self.partition_size[i]``. A parameter belongs to the
    partition if it starts inside that range, or if it starts before the range and
    extends into it. At most one parameter per partition can be in the second case.

    For each member parameter this:
      * appends ``partition_id`` to ``self.param_to_partition_ids[i][param_id]``;
      * increments ``self.total_grads_in_partition[i][partition_id]``;
      * sets ``self.is_grad_computed[i][partition_id][param_id] = False``;
      * records the parameter's insertion offset inside the partition
        (``grad_partition_insertion_offset``);
      * records how many of its leading elements fall before the partition start
        (``grad_start_offset``).

    The dicts ``is_grad_computed[i][partition_id]``,
    ``grad_partition_insertion_offset[i][partition_id]`` and
    ``grad_start_offset[i][partition_id]`` must already exist.
    """
    partition_size = self.partition_size[i]
    start_index = partition_size * partition_id
    end_index = partition_size * (partition_id + 1)

    def register(pid, insertion_offset, start_offset):
        self.param_to_partition_ids[i].setdefault(pid, []).append(partition_id)
        counts = self.total_grads_in_partition[i]
        counts[partition_id] = counts.get(partition_id, 0) + 1
        self.is_grad_computed[i][partition_id][pid] = False
        self.grad_partition_insertion_offset[i][partition_id][pid] = insertion_offset
        self.grad_start_offset[i][partition_id][pid] = start_offset

    cursor = 0  # start offset of the current parameter in the flattened group
    leading_skip = 0  # elements of the straddling parameter that precede start_index
    for param in param_group:
        numel = param.numel()
        pid = self.get_param_id(param)
        if start_index <= cursor < end_index:
            # Parameter starts inside this partition.
            register(pid, cursor - start_index, 0)
        elif cursor < start_index < cursor + numel:
            # Parameter starts earlier and runs into this partition.
            assert leading_skip == 0, "This can happen either zero or only once as this must be the first tensor in the partition"
            leading_skip = start_index - cursor
            register(pid, 0, leading_skip)
        cursor += numel
def overlapping_partition_gradients_reduce_epilogue(self):
    """Finish gradient reduction by delegating to ``independent_gradient_partition_epilogue``."""
    epilogue = self.independent_gradient_partition_epilogue
    epilogue()
def fill_grad_accum_attribute(self):
    """Move every pending ``param.grad`` into ``param.grad_accum``.

    Each gradient is cast to ``self.gradient_accumulation_dtype``. The first time,
    it becomes ``grad_accum``. Later, it is reshaped to the accumulator's shape and
    added to it in place. ``param.grad`` is set to ``None`` once consumed.
    Parameters without a gradient are left untouched.
    """
    accum_dtype = self.gradient_accumulation_dtype
    for param in (p for group in self.bit16_groups for p in group):
        if param.grad is None:
            continue
        converted = param.grad.to(accum_dtype)
        if param.grad_accum is None:
            param.grad_accum = converted
        else:
            param.grad_accum.add_(converted.view(param.grad_accum.shape))
        param.grad = None
def get_gradient_for_reduction(self, param):
    """Return the gradient of ``param`` that should be reduced.

    With ``use_grad_accum_attribute`` set, this is ``param.grad_accum`` converted to
    ``self.dtype``, or ``None`` if nothing has been accumulated. Otherwise it is
    ``param.grad`` unchanged.
    """
    if not self.use_grad_accum_attribute:
        return param.grad
    accumulated = param.grad_accum
    if accumulated is None:
        return None
    return accumulated.to(self.dtype)
def get_param_gradient_attribute(self, param):
    """Return the tensor currently holding ``param``'s gradient.

    That is ``param.grad_accum`` when ``use_grad_accum_attribute`` is set, and
    ``param.grad`` otherwise. No dtype conversion is applied.
    """
    if self.use_grad_accum_attribute:
        return param.grad_accum
    return param.grad
def clear_grad_attribute(self, param):
    """Drop the reference held by ``param``'s reduction-gradient attribute.

    Sets ``param.grad_accum`` to ``None`` when ``use_grad_accum_attribute`` is set,
    and ``param.grad`` otherwise. The other attribute is left untouched.
    """
    target = "grad_accum" if self.use_grad_accum_attribute else "grad"
    setattr(param, target, None)
def create_reduce_and_remove_grad_hooks(self):
    """Hook every trainable parameter so its gradient is handed off once computed.

    For each parameter with ``requires_grad`` in ``self.bit16_groups``, a hook is
    registered on the autograd node reached via ``param.expand_as(param).grad_fn``.
    The hook calls ``self.reduce_ready_partitions_and_remove_grads(param, group_index)``.
    The hooked nodes are collected in ``self.grad_accs``.
    """
    self.grad_accs = []

    def register(param, group_idx):
        # A throw-away expand gives a grad_fn whose first input is the node that
        # accumulates the leaf parameter's gradient.
        accumulator = param.expand_as(param).grad_fn.next_functions[0][0]

        # Defined per call so param/group_idx are bound here (no late-binding closure).
        def on_grad_ready(*unused):
            self.reduce_ready_partitions_and_remove_grads(param, group_idx)

        accumulator.register_hook(on_grad_ready)
        # NOTE(review): presumably stored so the hooked node stays alive — confirm.
        self.grad_accs.append(accumulator)

    for group_idx, group in enumerate(self.bit16_groups):
        for param in filter(lambda p: p.requires_grad, group):
            register(param, group_idx)
def get_param_id(self, param):
    """Return the id registered for ``param`` in ``self.param_id`` (keyed by ``id(param)``)."""
    return self.param_id[id(param)]
def report_ipg_memory_usage(self, tag, param_elems):
    """Report memory usage together with how full the IPG bucket would be.

    ``max_percent`` is ``(100.0 * (elements_in_ipg_bucket + param_elems)) // reduce_bucket_size``.
    That is the bucket fill level, floored, if ``param_elems`` more elements were added.
    """
    in_bucket = self.elements_in_ipg_bucket
    max_percent = (100.0 * (in_bucket + param_elems)) // self.reduce_bucket_size
    see_memory_usage(f"{tag}: elems in_bucket {in_bucket} param {param_elems} max_percent {max_percent}")
def flatten_dense_tensors_aligned(self, tensor_list, alignment):
    """Create one flat tensor from ``tensor_list``, aligned to ``alignment``.

    The list is first passed through ``align_dense_tensors`` and then flattened
    with ``self.flatten``.
    """
    aligned_tensors = align_dense_tensors(tensor_list, alignment)
    return self.flatten(aligned_tensors)
############### Independent Partition Gradient ########################
def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
    """Queue ``param``'s gradient in the independent-partition-gradient (IPG) bucket.

    If adding the parameter would exceed ``reduce_bucket_size``, the bucket is
    reduced first. With ``contiguous_gradients`` and ``overlap_comm`` the active
    buffer index is then toggled between 0 and 1.

    With ``contiguous_gradients``, a parameter larger than the whole bucket is
    recorded in ``extra_large_param_to_reduce``. Otherwise its gradient is copied
    into the active ``ipg_buffer`` slice and the gradient tensor's ``.data`` is
    rebound to that slice.

    The gradient and ``(i, param, param_id)`` are then appended to the bucket
    lists. ``ipg_bucket_has_moe_params`` is set for MoE parameters.

    Args:
        param: parameter whose gradient is ready for reduction.
        i: index of the parameter group ``param`` belongs to.

    Raises:
        AssertionError: if the gradient is ``None`` or the parameter was already reduced.
    """
    grad_reduc = self.get_gradient_for_reduction(param)
    param_id = self.get_param_id(param)
    numel = param.numel()

    # Validate before the gradient is used. This check previously ran only after
    # grad_reduc.view(-1), so a missing gradient surfaced as an AttributeError instead.
    assert grad_reduc is not None, f"rank {dist.get_rank()} - Invalid to reduce Param {param_id} with None gradient"

    if self.elements_in_ipg_bucket + numel > self.reduce_bucket_size:
        self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", numel)
        self.reduce_ipg_grads()
        if self.contiguous_gradients and self.overlap_comm:
            # Swap ipg_index between 0 and 1
            self.ipg_index = 1 - self.ipg_index
        self.report_ipg_memory_usage("In ipg_remove_grads after reduce_ipg_grads", numel)

    # Kept after the flush above, as in the original ordering.
    # NOTE(review): presumably reduce_ipg_grads updates params_already_reduced — confirm.
    assert not self.params_already_reduced[param_id], (
        f"The parameter {param_id} has already been reduced. "
        "Gradient computed twice for this partition. "
        "Multiple gradient reduction is currently not supported")

    if self.contiguous_gradients:
        if numel > self.reduce_bucket_size:
            self.extra_large_param_to_reduce = param
        else:
            # keeping the gradients contiguous to prevent memory fragmentation, and avoid flattening
            new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow(0, self.elements_in_ipg_bucket, numel)
            new_grad_tensor.copy_(grad_reduc.view(-1))
            grad_reduc.data = new_grad_tensor.data.view_as(grad_reduc)

    self.elements_in_ipg_bucket += numel
    self.grads_in_ipg_bucket.append(grad_reduc)
    self.params_in_ipg_bucket.append((i, param, param_id))

    # make sure the average tensor function knows how to average the gradients
    if is_moe_param(param):
        self.ipg_bucket_has_moe_params = True

    self.report_ipg_memory_usage("End ipg_remove_grads", 0)
def print_rank_0(self, message):
    """Log ``message`` through ``logger.info``, only on rank 0 (per ``dist.get_rank()``)."""
    if dist.get_rank() != 0:
        return
    logger.info(message)
def gradient_reduction_w_predivide(self, tensor):
    """All-reduce ``tensor`` over ``self.dp_process_group`` and scale the sum.

    The reduction runs in ``self.communication_data_type``. If that differs from
    ``tensor.dtype``, a converted copy is reduced and the result is copied back
    into ``tensor``. In both branches, the summed gradient ends up scaled by
    ``sequence_parallel_size / dp_world_size``:

      * ``postscale_gradients``: divide by ``gradient_predivide_factor`` before the
        all-reduce, then multiply by
        ``gradient_predivide_factor / (dp_world_size / sequence_parallel_size)``
        after it.
      * otherwise: divide by ``dp_world_size / sequence_parallel_size`` before the
        all-reduce.

    Returns:
        ``tensor``, reduced in place.
    """
    dp_world_size = dist.get_world_size(group=self.dp_process_group)
    averaging_divisor = dp_world_size / float(self.sequence_parallel_size)

    tensor_to_allreduce = tensor
    if self.communication_data_type != tensor.dtype:
        tensor_to_allreduce = tensor.to(self.communication_data_type)

    if self.postscale_gradients:
        if self.gradient_predivide_factor != 1.0:
            tensor_to_allreduce.mul_(1. / self.gradient_predivide_factor)

        dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)

        # Skip the post-scale only when it is actually a no-op. The previous check
        # (predivide_factor != dp_world_size) also skipped it when
        # sequence_parallel_size > 1, leaving the result off by that factor.
        postscale_factor = self.gradient_predivide_factor / averaging_divisor
        if postscale_factor != 1.0:
            tensor_to_allreduce.mul_(postscale_factor)
    else:
        tensor_to_allreduce.div_(averaging_divisor)
        dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)

    if self.communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce:
        tensor.copy_(tensor_to_allreduce)

    return tensor
def average_tensor(self, tensor):
if self.overlap_comm:
stream = self.reduction_stream
if not get_accelerator().is_synchronized_device():
stream.wait_stream(get_accelerator().current_stream())
else:
stream = get_accelerator().current_stream()
with get_accelerator().stream(stream):
if not self.reduce_scatter:
self.gradient_reduction_w_predivide(tensor)
return
# Accumulate destination ranks and bucket offsets for each gradient slice.
# Note: potential future optimization, record access pattern of parameters
# in backward pass and partition gradients w.r.t. access pattern so that our
# bucket is guaranteed to be contiguous w.r.t. ranks
rank_and_offsets = []
real_dp_process_group = []
curr_size = 0
prev_id, prev_process_group = -1, None
process_group = self.dp_process_group
# count = 0
for i, param, param_id in self.params_in_ipg_bucket:
process_group = self.dp_process_group
grad_reduc = self.get_gradient_for_reduction(param)
#Averages gradients at parameter level if ipg has a moe param
#Otherwise averaging is done at the entire buffer level at the end of the loop
# MoE param have different groups
if self.ipg_bucket_has_moe_params:
process_group = self.expert_dp_process_group[param.group_name] if is_moe_param(
param) else self.dp_process_group
grad_reduc.data.div_(dist.get_world_size(group=process_group) / float(self.sequence_parallel_size))
partition_ids = self.param_to_partition_ids[i][param_id]
assert all([p_id < dist.get_world_size(group=process_group) for p_id in partition_ids
]), f"world size {dist.get_world_size(group=process_group)} and p_ids: {partition_ids}"
partition_size = self.partition_size[i]
# Get all partition ids + their offsets
partition_ids_w_offsets = []
for partition_id in partition_ids:
offset = self.grad_start_offset[i][partition_id][param_id]
partition_ids_w_offsets.append((partition_id, offset))