-
Notifications
You must be signed in to change notification settings - Fork 3.9k
/
deepspeed_light.py
1034 lines (837 loc) · 40 KB
/
deepspeed_light.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
'''
Copyright 2019 The Microsoft DeepSpeed Team
'''
import logging
import torch
import os
import torch.distributed as dist
from torch.nn.modules import Module
from tensorboardX import SummaryWriter
from deepspeed.pt.deepspeed_timer import ThroughputTimer, SynchronizedWallClockTimer
from deepspeed.pt.deepspeed_zero_optimizer import FP16_DeepSpeedZeroOptimizer
from deepspeed.pt.fp16_optimizer import FP16_Optimizer
from deepspeed.pt.fp16_unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.pt.deepspeed_fused_lamb import FusedLamb
from deepspeed.pt.deepspeed_config import DeepSpeedConfig, \
ADAM_OPTIMIZER, LAMB_OPTIMIZER, DEEPSPEED_OPTIMIZERS
from deepspeed.pt.deepspeed_dataloader import DeepSpeedDataLoader
from deepspeed.pt.deepspeed_constants import ROUTE_TRAIN, ROUTE_PREDICT, \
ROUTE_EVAL
import deepspeed.pt.deepspeed_lr_schedules as lr_schedules
from deepspeed.pt.deepspeed_csr_tensor import CSRTensor
from apex import amp
from apex.optimizers.fused_adam import FusedAdam
# Default bucket size (in elements) used by DeepSpeedLight.allreduce_gradients.
MEMORY_OPT_ALLREDUCE_SIZE = 500000000
# Default directory component for the tensorboard log path.
SUMMARY_WRITER_DIR_NAME = "JobId"

# Prefer apex's compiled flatten/unflatten; fall back to the pure-PyTorch
# helpers when apex was built without its C++ extension.
try:
    from apex_C import flatten, unflatten
except ImportError:
    # Warn once: `warned_flatten` only exists after the first warning.
    try:
        _ = warned_flatten
    except NameError:
        print(
            "Warning: apex was installed without --cpp_ext. Falling back to Python flatten and unflatten."
        )
        warned_flatten = True
    from torch._utils import (_flatten_dense_tensors as flatten,
                              _unflatten_dense_tensors as unflatten)
def split_half_float_double_csr(tensors):
    """Group tensors by their ``type()`` string.

    Arguments:
        tensors: sequence of objects exposing ``type()``; it is scanned once
            per supported type.

    Return:
        A list of ``(type_name, members)`` pairs, in the fixed order half,
        float, double, CSR. Types with no members are omitted, and tensors
        whose type is not one of the four are dropped.
    """
    supported_types = (
        "torch.cuda.HalfTensor",
        "torch.cuda.FloatTensor",
        "torch.cuda.DoubleTensor",
        CSRTensor.type(),
    )
    grouped = []
    for type_name in supported_types:
        members = [t for t in tensors if t.type() == type_name]
        if members:
            grouped.append((type_name, members))
    return grouped
def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
    """Create contiguous process groups and return the one containing this rank.

    Arguments:
        parameter_parallel_size: ranks per group. Defaults to the world size,
            which yields a single group spanning every rank.

    Return:
        The process group this rank belongs to.
    """
    world_size = int(dist.get_world_size())
    if parameter_parallel_size is None:
        parameter_parallel_size = int(world_size)
    print(world_size, parameter_parallel_size)
    assert world_size % parameter_parallel_size == 0, \
        'world size should be divisible by parameter parallel size'

    my_rank = dist.get_rank()
    own_group = None
    group_count = dist.get_world_size() // parameter_parallel_size
    # Every rank creates every group, not only the one it belongs to.
    for group_idx in range(group_count):
        first = group_idx * parameter_parallel_size
        members = range(first, first + parameter_parallel_size)
        group = torch.distributed.new_group(members)
        if my_rank in members:
            own_group = group
    return own_group
def print_configuration(args, name):
    """Print the attributes of ``args`` under a ``name:`` heading.

    Attributes are listed in sorted order, one per line, with a run of dots
    between the attribute name and its value.
    """
    print('{}:'.format(name), flush=True)
    for attr_name in sorted(vars(args)):
        padding = '.' * (29 - len(attr_name))
        attr_value = getattr(args, attr_name)
        print(' {} {} {}'.format(attr_name, padding, attr_value), flush=True)
class DeepSpeedLight(Module):
r"""DeepSpeed engine for training.
"""
def __init__(self,
             args,
             model,
             optimizer=None,
             model_parameters=None,
             training_data=None,
             lr_scheduler=None,
             mpu=None,
             dist_init_required=True,
             collate_fn=None):
    """Build the engine: distributed setup, config, model, optimizer, timers.

    Arguments:
        args: argument namespace; ``local_rank`` and ``deepspeed_config``
            (or the deprecated ``deepscale_config``) are read from it.
        model: module to wrap; stored as ``self.module``.
        optimizer: optional client optimizer used as the basic optimizer.
        model_parameters: parameters handed to a config-defined optimizer.
        training_data: optional dataset; when truthy a dataloader is built
            through ``deepspeed_io``.
        lr_scheduler: client LR scheduler, used when the config defines none.
        mpu: optional model-parallel unit; queried for data/model parallel
            ranks and groups.
        dist_init_required: when True, ``dist.init_process_group`` is called
            with the NCCL backend.
        collate_fn: default collate function for ``deepspeed_io``.
    """
    super(DeepSpeedLight, self).__init__()

    logging.basicConfig(level=logging.INFO,
                        format="[%(levelname)s %(asctime)s] %(message)s",
                        datefmt="%Y-%m-%d %H:%M:%S")

    # Keep the client-supplied objects; the sanity checks below read them.
    self.client_optimizer = optimizer
    self.client_model_parameters = model_parameters
    self.client_lr_scheduler = lr_scheduler
    self.training_data = training_data
    self.collate_fn = collate_fn
    self.mpu = mpu
    self.data_parallel_group = None
    # Step counters: global_steps advances at accumulation boundaries,
    # micro_steps on every step() call.
    self.global_steps = 0
    self.micro_steps = 0
    self.skipped_steps = 0
    self.gradient_predivide_factor = 1.0
    self.gradient_average = True
    self.warn_unscaled_loss = True

    if dist_init_required:
        dist.init_process_group(backend="nccl")

    # Argument validation must precede config parsing, which in turn must
    # precede every accessor that reads self._config.
    self._do_args_sanity_check(args)
    self._configure_with_arguments(args, mpu)
    self._do_sanity_check()

    self.sample_count = 0
    if self.tensorboard_enabled():
        self.summary_writer = self.get_summary_writer()

    # Sets self.device, self.world_size and self.global_rank.
    self._init_distributed(dist_init_required)

    # Throughput timer
    self.tput_timer = ThroughputTimer(
        batch_size=self.train_micro_batch_size_per_gpu(),
        num_workers=self.world_size,
        monitor_memory=False)

    self.training_dataloader = self.deepspeed_io(
        training_data) if training_data else None

    # Configure distributed model
    self._configure_distributed_model(model)

    # Configure optimizer and scheduler
    self.optimizer = None
    self.lr_scheduler = None
    if model_parameters or optimizer:
        self._configure_optimizer(optimizer, model_parameters)
        self._configure_lr_scheduler(lr_scheduler)
        self._report_progress(0)

    # Configure wall clock timer
    self.timers = SynchronizedWallClockTimer()

    # Bookkeeping for csr support
    self.csr_tensor_module_names = set()
    if self.sparse_gradients_enabled():
        for name, module in self.module.named_modules():
            if isinstance(module, torch.nn.Embedding):
                self.csr_tensor_module_names.add(name)
                logging.info("Will convert {} to sparse (csr) "
                             "tensor during training".format(name))

    # Defaults; _configure_checkpointing decides which ranks actually save.
    self.save_non_zero_checkpoint = False
    self.save_zero_checkpoint = False
    self._configure_checkpointing(dist_init_required)

    if self.global_rank == 0:
        self._config.print('DeepSpeedLight configuration')
        if self.dump_state():
            print_configuration(self, 'DeepSpeedLight')
def tensorboard_enabled(self):
    """Return the ``tensorboard_enabled`` value of the DeepSpeed config."""
    config = self._config
    return config.tensorboard_enabled
def tensorboard_output_path(self):
    """Return the ``tensorboard_output_path`` value of the DeepSpeed config."""
    config = self._config
    return config.tensorboard_output_path
def tensorboard_job_name(self):
    """Return the ``tensorboard_job_name`` value of the DeepSpeed config."""
    config = self._config
    return config.tensorboard_job_name
def get_summary_writer(self,
                       name="DeepSpeedJobName",
                       base=os.environ["HOME"] + "/tensorboard"):
    """Create the ``SummaryWriter`` used for tensorboard logging.

    Arguments:
        name: last path component of the log directory; overridden by the
            configured tensorboard job name when one is set.
        base: root directory of the log path. Only used when the config
            gives no explicit tensorboard output path.

    Return:
        A ``SummaryWriter`` writing to the configured output path if set,
        otherwise to ``base/<job dir>/name``. ``<job dir>`` is
        ``$DLWS_JOB_ID/logs`` when that variable is set, else the module
        constant ``SUMMARY_WRITER_DIR_NAME``.
    """
    if self.tensorboard_job_name():
        name = self.tensorboard_job_name()
    if self.tensorboard_output_path():
        return SummaryWriter(log_dir=self.tensorboard_output_path())
    # Use a distinct local name. Assigning to SUMMARY_WRITER_DIR_NAME inside
    # the function would make it local and raise UnboundLocalError whenever
    # DLWS_JOB_ID is absent, instead of falling back to the module constant.
    if 'DLWS_JOB_ID' in os.environ:
        job_dir = os.environ['DLWS_JOB_ID'] + "/logs"
    else:
        job_dir = SUMMARY_WRITER_DIR_NAME
    return SummaryWriter(log_dir=os.path.join(base, job_dir, name))
def wall_clock_breakdown(self):
    """Return the ``wall_clock_breakdown`` value of the DeepSpeed config."""
    config = self._config
    return config.wall_clock_breakdown
def sparse_gradients_enabled(self):
    """Return the ``sparse_gradients_enabled`` value of the DeepSpeed config."""
    config = self._config
    return config.sparse_gradients_enabled
def train_batch_size(self):
    """Return the ``train_batch_size`` value of the DeepSpeed config."""
    config = self._config
    return config.train_batch_size
def train_micro_batch_size_per_gpu(self):
    """Return the ``train_micro_batch_size_per_gpu`` value of the DeepSpeed config."""
    config = self._config
    return config.train_micro_batch_size_per_gpu
def optimizer_name(self):
    """Return the ``optimizer_name`` value of the DeepSpeed config."""
    config = self._config
    return config.optimizer_name
def optimizer_params(self):
    """Return the ``optimizer_params`` object of the DeepSpeed config.

    The config's own object is returned, not a copy.
    """
    config = self._config
    return config.optimizer_params
def optimizer_legacy_fusion(self):
    """Return the ``optimizer_legacy_fusion`` value of the DeepSpeed config."""
    config = self._config
    return config.optimizer_legacy_fusion
def scheduler_name(self):
    """Return the ``scheduler_name`` value of the DeepSpeed config."""
    config = self._config
    return config.scheduler_name
def scheduler_params(self):
    """Return the ``scheduler_params`` object of the DeepSpeed config."""
    config = self._config
    return config.scheduler_params
def zero_optimization(self):
    """Return the ``zero_enabled`` value of the DeepSpeed config."""
    config = self._config
    return config.zero_enabled
def allgather_size(self):
    """Return the ``allgather_size`` value of the DeepSpeed config."""
    config = self._config
    return config.allgather_size
def fp16_enabled(self):
    """Return the ``fp16_enabled`` value of the DeepSpeed config."""
    config = self._config
    return config.fp16_enabled
def loss_scale(self):
    """Return the ``loss_scale`` value of the DeepSpeed config."""
    config = self._config
    return config.loss_scale
def gradient_accumulation_steps(self):
    """Return the ``gradient_accumulation_steps`` value of the DeepSpeed config."""
    config = self._config
    return config.gradient_accumulation_steps
def allreduce_always_fp32(self):
    """Return the ``allreduce_always_fp32`` value of the DeepSpeed config."""
    config = self._config
    return config.allreduce_always_fp32
def postscale_gradients(self):
    """Return the logical negation of the config's ``prescale_gradients``."""
    prescale = self._config.prescale_gradients
    return not prescale
def steps_per_print(self):
    """Return the ``steps_per_print`` value of the DeepSpeed config."""
    config = self._config
    return config.steps_per_print
def disable_allgather(self):
    """Return the ``disable_allgather`` value of the DeepSpeed config."""
    config = self._config
    return config.disable_allgather
def dump_state(self):
    """Return the ``dump_state`` value of the DeepSpeed config."""
    config = self._config
    return config.dump_state
def gradient_clipping(self):
    """Return the ``gradient_clipping`` value of the DeepSpeed config."""
    config = self._config
    return config.gradient_clipping
def dynamic_loss_scale(self):
    """Return True when the configured ``loss_scale`` equals 0."""
    configured_scale = self._config.loss_scale
    return configured_scale == 0
def initial_dynamic_scale(self):
    """Return the ``initial_dynamic_scale`` value of the DeepSpeed config."""
    config = self._config
    return config.initial_dynamic_scale
def dynamic_loss_scale_args(self):
    """Return the ``dynamic_loss_scale_args`` object of the DeepSpeed config."""
    config = self._config
    return config.dynamic_loss_scale_args
def _configure_lr_scheduler(self, client_lr_scheduler):
# First check for scheduler in json configuration
lr_scheduler = self._scheduler_from_config(self.optimizer)
if lr_scheduler:
logging.info(
f'DeepSpeed using configured LR scheduler = {self.scheduler_name()}')
self.lr_scheduler = lr_scheduler
else:
logging.warning('DeepSpeed using client LR scheduler')
self.lr_scheduler = client_lr_scheduler
logging.info(f'DeepSpeed LR Scheduler = {self.lr_scheduler}')
def _configure_checkpointing(self, dist_init_required):
dp_rank = torch.distributed.get_rank(
) if self.mpu is None else self.mpu.get_data_parallel_rank()
#only the first data parallel process needs to store the model checkpoint
self.save_non_zero_checkpoint = True if dp_rank == 0 else False
if self.zero_optimization():
pp_rank = torch.distributed.get_rank(group=self.optimizer.dp_process_group)
#only the first parameter parallel process needs to store the optimizer state checkpoints for zero
self.save_zero_checkpoint = True if pp_rank == dp_rank else False
def _scheduler_from_config(self, optimizer):
scheduler_name = self.scheduler_name()
if scheduler_name is not None:
if hasattr(lr_schedules, scheduler_name):
scheduler = getattr(lr_schedules, scheduler_name)
else:
assert hasattr(torch.optim.lr_scheduler, scheduler_name), \
f"DeepSpeed does not recognize LR scheduler {scheduler_name}"
scheduler = getattr(torch.optim.lr_scheduler, scheduler_name)
scheduler_params = self.scheduler_params()
instantiated_scheduler = scheduler(optimizer, **scheduler_params)
return instantiated_scheduler
else:
return None
def _init_distributed(self, dist_init_required):
if self.local_rank >= 0:
torch.cuda.set_device(self.local_rank)
self.device = torch.device("cuda", self.local_rank)
self.world_size = dist.get_world_size()
self.global_rank = dist.get_rank()
logging.info("Set device to local rank {} within node.".format(
self.local_rank))
else:
self.world_size = 1
self.global_rank = 0
self.device = torch.device("cuda")
# Configure based on command line arguments
def _configure_with_arguments(self, args, mpu):
    """Read ``local_rank`` from ``args`` and parse the DeepSpeed config file.

    ``local_rank`` defaults to 0 when ``args`` has no such attribute.
    """
    self.local_rank = getattr(args, 'local_rank', 0)
    config_path = args.deepspeed_config
    self._config = DeepSpeedConfig(config_path, mpu)
# Validate command line arguments
def _do_args_sanity_check(self, args):
if hasattr(args, 'deepscale_config') and args.deepscale_config is not None:
logging.warning(
"************ --deepscale_config is deprecated, please use --deepspeed_config ************"
)
if hasattr(args, 'deepspeed_config'):
assert args.deepspeed_config is None, "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config"
args.deepspeed_config = args.deepscale_config
assert hasattr(args, 'local_rank') and type(args.local_rank) == int, \
'DeepSpeed requires integer command line parameter --local_rank'
assert hasattr(args, 'deepspeed_config') and args.deepspeed_config is not None, \
'DeepSpeed requires --deepspeed_config to specify configuration file'
assert os.path.isfile(args.deepspeed_config), \
'DeepSpeed configuration file: {} is not an existing file'.format(args.deepspeed_config)
def _is_supported_optimizer(self, optimizer_name):
    """Return True when ``optimizer_name`` names a usable optimizer.

    A name is supported if it is listed in ``DEEPSPEED_OPTIMIZERS`` or is
    an attribute of ``torch.optim``.

    Arguments:
        optimizer_name: configured optimizer name; may be None when the
            configuration declares no optimizer.
    """
    if optimizer_name in DEEPSPEED_OPTIMIZERS:
        return True
    # getattr() raises TypeError for a non-string attribute name. Returning
    # False lets the caller's assertion report the unsupported optimizer.
    if not isinstance(optimizer_name, str):
        return False
    return getattr(torch.optim, optimizer_name, None) is not None
# Validate configuration based on command line arguments
def _do_sanity_check(self):
    """Validate the optimizer configuration against the client arguments.

    Raises:
        AssertionError: no client optimizer was given and the configured
            one is unsupported or has no model parameters, or the LAMB
            optimizer is selected without dynamic loss scaling.
    """
    needs_config_optimizer = not self.client_optimizer
    if needs_config_optimizer:
        assert self._is_supported_optimizer(self.optimizer_name()), \
            '{} is not a supported DeepSpeed Optimizer'.format(self.optimizer_name())
        assert self.client_model_parameters, \
            'DeepSpeed {} optimizer requires parameters in initialize() call'.format(self.optimizer_name())

    uses_lamb = self.optimizer_name() == LAMB_OPTIMIZER
    if uses_lamb:
        assert self.dynamic_loss_scale(), \
            'DeepSpeed {} optimizer requires dynamic loss scaling'.format(self.optimizer_name())
def _configure_distributed_model(self, model):
self.module = model
if self.fp16_enabled():
self.module.half()
self.module.to(self.device)
if self.mpu is None:
self.data_parallel_group = _initialize_parameter_parallel_groups()
self.dp_world_size = dist.get_world_size()
src_rank = 0
else:
self.data_parallel_group = self.mpu.get_data_parallel_group()
self.dp_world_size = self.mpu.get_data_parallel_world_size()
src_rank = self.mpu.get_model_parallel_rank()
for p in self.module.parameters():
if torch.is_tensor(p):
dist.broadcast(p, src_rank, group=self.data_parallel_group)
# TODO: support new AMP optimizer
# self.module.half()
# self.module.to(self.local_rank)
#self.module, self.optimizer = amp.initialize(self.module, self.optimizer, opt_level="O2")
# Configure optimizer
def _configure_optimizer(self, client_optimizer, model_parameters):
if client_optimizer is not None:
basic_optimizer = client_optimizer
logging.info('Using client Optimizer as basic optimizer')
else:
basic_optimizer = self._configure_basic_optimizer(model_parameters)
logging.info(
'Using DeepSpeed Optimizer param name {} as basic optimizer'.format(
self.optimizer_name()))
logging.info('DeepSpeed Basic Optimizer = {}'.format(basic_optimizer))
if self.zero_optimization() and self.optimizer_name() == ADAM_OPTIMIZER:
self.optimizer = self._configure_zero_optimizer(basic_optimizer)
elif self.fp16_enabled():
self.optimizer = self._configure_fp16_optimizer(basic_optimizer)
else:
self.optimizer = basic_optimizer
# logging.info('DeepSpeed Final Optimizer = {}'.format(self.optimizer.state_dict()))
def _configure_basic_optimizer(self, model_parameters):
    """Build the optimizer named in the configuration.

    Note that under fp16 a configured ``max_grad_norm`` entry is
    overwritten with 0.0 in the config's own parameter dict.

    Arguments:
        model_parameters: parameters handed to the optimizer constructor.

    Return:
        A ``FusedAdam``, ``FusedLamb`` or ``torch.optim`` optimizer.
    """
    params = self.optimizer_params()
    if self.fp16_enabled() and 'max_grad_norm' in params.keys():
        params['max_grad_norm'] = 0.0

    if self.optimizer_name() == ADAM_OPTIMIZER:
        return FusedAdam(model_parameters, **params)
    if self.optimizer_name() == LAMB_OPTIMIZER:
        return FusedLamb(model_parameters, **params)

    # Any other name is looked up on torch.optim.
    optimizer_cls = getattr(torch.optim, self.optimizer_name())
    return optimizer_cls(model_parameters, **params)
def _configure_fp16_optimizer(self, optimizer):
    """Wrap ``optimizer`` in the matching fp16 optimizer.

    Adam uses ``FP16_Optimizer`` with dynamic or static loss scaling; every
    other optimizer uses ``FP16_UnfusedOptimizer``.

    Arguments:
        optimizer: basic optimizer to wrap.

    Return:
        The wrapping fp16 optimizer.
    """
    initial_scale = self.initial_dynamic_scale()
    scale_args = self.dynamic_loss_scale_args()
    grad_clip = self.gradient_clipping()

    if self.optimizer_name() != ADAM_OPTIMIZER:
        logging.info('Creating fp16 unfused optimizer with dynamic loss scale')
        return FP16_UnfusedOptimizer(
            optimizer,
            dynamic_loss_scale=self.dynamic_loss_scale(),
            dynamic_loss_args=scale_args,
            mpu=self.mpu,
            clip_grad=grad_clip,
            fused_lamb_legacy=self.optimizer_legacy_fusion()
            if self.optimizer_name() == LAMB_OPTIMIZER else False)

    if self.dynamic_loss_scale():
        logging.info('Creating fp16 optimizer with dynamic loss scale')
        return FP16_Optimizer(
            optimizer,
            dynamic_loss_scale=True,
            initial_dynamic_scale=initial_scale,
            dynamic_loss_args=scale_args,
            mpu=self.mpu,
            clip_grad=grad_clip,
            fused_adam_legacy=self.optimizer_legacy_fusion())

    logging.info('Creating fp16 optimizer with static loss scale: {}'.format(
        self.loss_scale()))
    return FP16_Optimizer(
        optimizer,
        static_loss_scale=self.loss_scale(),
        mpu=self.mpu,
        clip_grad=grad_clip,
        fused_adam_legacy=self.optimizer_legacy_fusion())
def _configure_zero_optimizer(self, optimizer):
    """Wrap ``optimizer`` in ``FP16_DeepSpeedZeroOptimizer``.

    Arguments:
        optimizer: basic optimizer to wrap.

    Return:
        The ZeRO optimizer, configured from the engine's loss-scale,
        clipping and all-gather settings and its data parallel group.
    """
    logging.info('Creating fp16 zero optimizer')
    zero_kwargs = dict(
        static_loss_scale=self.loss_scale(),
        dynamic_loss_scale=self.dynamic_loss_scale(),
        dynamic_loss_args=self.dynamic_loss_scale_args(),
        dp_process_group=self.data_parallel_group,
        clip_grad=self.gradient_clipping(),
        all_gather_partitions=not self.disable_allgather(),
        allgather_size=self.allgather_size(),
        mpu=self.mpu)
    return FP16_DeepSpeedZeroOptimizer(optimizer, **zero_kwargs)
def deepspeed_io(self,
                 dataset,
                 batch_size=None,
                 route=ROUTE_TRAIN,
                 pin_memory=True,
                 data_sampler=None,
                 collate_fn=None,
                 num_local_io_workers=None):
    """Build a ``DeepSpeedDataLoader`` for ``dataset``.

    Arguments:
        dataset: a ``torch.utils.data.Dataset``.
        batch_size: defaults to the configured micro batch size per GPU.
        route: ROUTE_TRAIN, ROUTE_PREDICT or ROUTE_EVAL.
        pin_memory: forwarded to the data loader.
        data_sampler: optional sampler; predict and eval routes default to
            a ``SequentialSampler``.
        collate_fn: defaults to the collate function given at construction.
        num_local_io_workers: forwarded to the data loader.

    Raises:
        ValueError: ``dataset`` is not a torch ``Dataset``.
    """
    if not isinstance(dataset, torch.utils.data.Dataset):
        raise ValueError("Training data must be a torch Dataset")

    if data_sampler is None and route in (ROUTE_PREDICT, ROUTE_EVAL):
        data_sampler = torch.utils.data.SequentialSampler(dataset)

    if batch_size is None:
        batch_size = self.train_micro_batch_size_per_gpu()

    if collate_fn is None:
        collate_fn = self.collate_fn

    # The throughput timer is only attached on the train route.
    io_timer = self.tput_timer if route == ROUTE_TRAIN else None

    return DeepSpeedDataLoader(dataset=dataset,
                               batch_size=batch_size,
                               pin_memory=pin_memory,
                               collate_fn=collate_fn,
                               local_rank=self.local_rank,
                               tput_timer=io_timer,
                               num_local_io_workers=num_local_io_workers,
                               data_sampler=data_sampler)
def train(self):
    """Put the wrapped module in training mode.

    Also re-arms the one-shot warning about losses that cannot be scaled.
    """
    self.warn_unscaled_loss = True
    wrapped = self.module
    wrapped.train()
def eval(self):
    """Put the wrapped module in evaluation mode via ``train(False)``.

    Also re-arms the one-shot warning about losses that cannot be scaled.
    """
    self.warn_unscaled_loss = True
    wrapped = self.module
    wrapped.train(False)
def _scale_loss(self, loss):
if isinstance(loss, torch.Tensor):
loss = loss / self.gradient_accumulation_steps()
elif isinstance(loss, tuple) and isinstance(loss[0], torch.Tensor):
loss = (l / self.gradient_accumulation_steps() for l in loss)
elif isinstance(loss, list) and isinstance(loss[0], torch.Tensor):
loss = [l / self.gradient_accumulation_steps() for l in loss]
else:
if self.warn_unscaled_loss:
logging.warning(
f'DeepSpeed unable to scale loss because of type: {type(loss)}')
self.warn_unscaled_loss = False
return loss
def forward(self, *inputs, **kwargs):
    """Run the wrapped module and return its (possibly scaled) loss.

    Arguments:
        *inputs: positional arguments forwarded to the module
        **kwargs: keyword arguments forwarded to the module

    Return:
        The module output, divided by the gradient accumulation steps when
        more than one step is configured.
    """
    if self.wall_clock_breakdown():
        for timer_name in ('forward_microstep', 'forward'):
            self.timers(timer_name).start()

    # Without an engine-owned dataloader the throughput timer starts here.
    if self.training_dataloader is None:
        self.tput_timer.start()

    result = self.module(*inputs, **kwargs)

    # scale loss w.r.t. gradient accumulation if needed
    if self.gradient_accumulation_steps() > 1:
        result = self._scale_loss(result)

    if self.wall_clock_breakdown():
        for timer_name in ('forward', 'forward_microstep'):
            self.timers(timer_name).stop()

    return result
def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
    """All-reduce gradients, but only at a gradient accumulation boundary.

    Arguments:
        bucket_size: elements per buffer passed to the buffered all-reduce.
    """
    if not self.is_gradient_accumulation_boundary():
        return
    self.buffered_allreduce_fallback(elements_per_buffer=bucket_size)
def backward(self, loss, allreduce_gradients=True):
    """Run the backward pass for ``loss`` and optionally all-reduce gradients.

    Arguments:
        loss: Torch tensor on which to execute backward propagation
        allreduce_gradients: when False the gradient all-reduce is skipped.
            Default is True.
    """
    # deepspeed tensorboard support for loss: rank 0 only, at boundaries
    log_loss = (self.is_gradient_accumulation_boundary()
                and self.tensorboard_enabled()
                and torch.distributed.get_rank() == 0)
    if log_loss:
        self.sample_count += (self.train_micro_batch_size_per_gpu() *
                              torch.distributed.get_world_size() *
                              self.gradient_accumulation_steps())
        # The loss was divided by the accumulation steps in forward();
        # multiply back so the logged value is the unscaled mean.
        self.summary_events = [
            ('Train/Samples/train_loss',
             loss.mean().item() * self.gradient_accumulation_steps(),
             self.sample_count)
        ]
        for tag, value, sample_idx in self.summary_events:
            self.summary_writer.add_scalar(tag, value, sample_idx)
        self.summary_writer.flush()

    if self.wall_clock_breakdown():
        self.timers('backward_microstep').start()
        self.timers('backward').start()

    assert self.optimizer is not None, "must provide optimizer during " \
                                       "init in order to use backward"

    if self.wall_clock_breakdown():
        self.timers('backward_inner_microstep').start()
        self.timers('backward_inner').start()

    # ZeRO and fp16 optimizers drive the backward pass themselves.
    if self.zero_optimization() or self.fp16_enabled():
        self.optimizer.backward(loss)
        # TODO: Use new AMP semantics as below
        # with amp.scale_loss(loss, self.optimizer) as scaled_loss:
        #    scaled_loss.backward()
    else:
        loss.backward()

    if self.wall_clock_breakdown():
        self.timers('backward_inner').stop()
        self.timers('backward_inner_microstep').stop()
        self.timers('backward_allreduce_microstep').start()
        self.timers('backward_allreduce').start()

    if allreduce_gradients:
        self.allreduce_gradients()

    if self.wall_clock_breakdown():
        for timer_name in ('backward_allreduce',
                           'backward_allreduce_microstep',
                           'backward',
                           'backward_microstep'):
            self.timers(timer_name).stop()
def is_gradient_accumulation_boundary(self):
    """Return True when the current micro step completes an accumulation cycle."""
    upcoming_micro_step = self.micro_steps + 1
    return upcoming_micro_step % self.gradient_accumulation_steps() == 0
def step(self):
    """Advance one micro step; update weights at accumulation boundaries.

    At a boundary the optimizer is stepped and zeroed. The LR scheduler
    advances unless the optimizer reports an overflow, in which case the
    step is counted as skipped.
    """
    if self.wall_clock_breakdown():
        self.timers('step_microstep').start()
        self.timers('step').start()

    assert self.optimizer is not None, "must provide optimizer during " \
                                       "init in order to use step"
    # True for global rank 0 (or an unset rank), False for any other rank.
    report_progress = not self.global_rank

    if self.is_gradient_accumulation_boundary():
        self.optimizer.step()
        self.optimizer.zero_grad()

        # Check overflow here since in DS fp16 optimizer, the overflow is
        # updated in above step() function.
        overflow = getattr(self.optimizer, 'overflow', False)

        if overflow:
            self.skipped_steps += 1
        else:
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            upcoming_step = self.global_steps + 1
            if report_progress and upcoming_step % self.steps_per_print() == 0:
                self._report_progress(upcoming_step)

        self.global_steps += 1

    self.tput_timer.stop(report_progress)

    # deepspeed tensorboard support for lr
    if self.is_gradient_accumulation_boundary() and self.tensorboard_enabled(
    ) and torch.distributed.get_rank() == 0:
        self.summary_events = [('Train/Samples/lr',
                                self.get_lr()[0],
                                self.sample_count)]
        for tag, value, sample_idx in self.summary_events:
            self.summary_writer.add_scalar(tag, value, sample_idx)
        self.summary_writer.flush()

    if self.wall_clock_breakdown():
        self.timers('step').stop()
        self.timers('step_microstep').stop()
        self.timers.log([
            'forward_microstep',
            'backward_microstep',
            'backward_inner_microstep',
            'backward_allreduce_microstep',
            'step_microstep'
        ])
        if self.is_gradient_accumulation_boundary():
            phase_timers = [
                'forward',
                'backward',
                'backward_inner',
                'backward_allreduce',
                'step'
            ]
            # this is done before the log because log resets timers
            if self.tensorboard_enabled() and torch.distributed.get_rank() == 0:
                self.summary_events = [
                    (f'Train/Samples/elapsed_time_ms_{phase}',
                     self.timers(phase).elapsed(reset=False) * 1000.0,
                     self.sample_count) for phase in phase_timers
                ]
                for tag, value, sample_idx in self.summary_events:
                    self.summary_writer.add_scalar(tag, value, sample_idx)
                self.summary_writer.flush()
            self.timers.log(list(phase_timers))

    self.micro_steps += 1
def _get_optimizer_param(self, param_name):
result = []
if not self.optimizer:
return result
for group in self.optimizer.param_groups:
if param_name in group:
result.append(group[param_name])
else:
result.append(0.0)
return result
def get_lr(self):
    """Return the ``'lr'`` entry of every optimizer parameter group."""
    learning_rates = self._get_optimizer_param('lr')
    return learning_rates
def get_mom(self):
    """Return the ``'betas'`` entry of every optimizer parameter group."""
    betas = self._get_optimizer_param('betas')
    return betas
def _report_progress(self, step):
lr = self.get_lr()
mom = self.get_mom()
logging.info('rank:{} step={}, skipped={}, lr={}, mom={}'.format(
self.global_rank,
step,
self.skipped_steps,
lr,
mom))
def allreduce_bucket(self, bucket):
    """Flatten ``bucket``, all-reduce it over the data parallel group, and return it.

    Arguments:
        bucket: list of tensors to reduce together.

    Return:
        The flattened tensor holding the reduced values.
    """
    tensor = flatten(bucket)

    tensor_to_allreduce = tensor

    if self.allreduce_always_fp32():
        # float() may return a new tensor; the result is copied back below.
        tensor_to_allreduce = tensor.float()

    if self.postscale_gradients():
        # Post-scale path: optional pre-division, reduce, then rescale.
        if self.gradient_predivide_factor != 1.0:
            tensor_to_allreduce.mul_(1. / self.gradient_predivide_factor)

        dist.all_reduce(tensor_to_allreduce, group=self.data_parallel_group)

        if self.gradient_average:
            # Skipped when the pre-division already equals the world size.
            if self.gradient_predivide_factor != self.dp_world_size:
                tensor_to_allreduce.mul_(self.gradient_predivide_factor /
                                         self.dp_world_size)
    else:
        # Pre-scale path: divide by the world size before reducing.
        tensor_to_allreduce.div_(self.dp_world_size)
        dist.all_reduce(tensor_to_allreduce, group=self.data_parallel_group)

    # Copy back only if the fp32 conversion actually produced a new tensor.
    if self.allreduce_always_fp32() and tensor is not tensor_to_allreduce:
        tensor.copy_(tensor_to_allreduce)

    return tensor
def allreduce_and_copy(self, small_bucket):
    """All-reduce ``small_bucket`` and write the results back in place.

    Arguments:
        small_bucket: list of tensors; each is overwritten with its reduced
            value.
    """
    reduced_flat = self.allreduce_bucket(small_bucket)
    reduced_views = unflatten(reduced_flat, small_bucket)
    for original, reduced in zip(small_bucket, reduced_views):
        original.copy_(reduced)
def allreduce_no_retain(self, bucket, numel_per_bucket=500000000):
    """All-reduce ``bucket`` in sub-buckets of bounded element count.

    Tensors are accumulated until their combined element count exceeds
    ``numel_per_bucket``; that sub-bucket is then reduced and a new one is
    started. Any remainder is reduced at the end.

    Arguments:
        bucket: iterable of tensors to reduce in place.
        numel_per_bucket: element-count threshold that triggers a flush.
    """
    small_bucket = []
    numel = 0
    for tensor in bucket:
        small_bucket.append(tensor)
        numel += tensor.numel()
        if numel > numel_per_bucket:
            self.allreduce_and_copy(small_bucket)
            small_bucket = []
            # Restart the count with the new sub-bucket. Without this, every
            # tensor after the first flush is all-reduced on its own.
            numel = 0
    if small_bucket:
        self.allreduce_and_copy(small_bucket)
def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000):
    """All-reduce every available gradient of the wrapped module.

    Arguments:
        grads: ignored; the gradient list is always rebuilt from the
            module's parameters.
        elements_per_buffer: element-count threshold per dense sub-bucket.
    """
    grads = []
    for full_name, parameter in self.module.named_parameters():
        if parameter.grad is None:
            continue
        grad_data = parameter.grad.data
        # Only the first dotted component of the parameter name is matched
        # against the registered csr module names.
        root_name = full_name.split('.', 1)[0]
        if self.sparse_gradients_enabled(
        ) and root_name in self.csr_tensor_module_names:
            grads.append(CSRTensor(grad_data))
        else:
            grads.append(grad_data)

    for bucket_type, bucket in split_half_float_double_csr(grads):
        if bucket_type == CSRTensor.type():
            self.csr_allreduce_no_retain(bucket)
        else:
            self.allreduce_no_retain(bucket, numel_per_bucket=elements_per_buffer)
def csr_allreduce_no_retain(self, bucket):
    """All-reduce a bucket of CSR tensors and write the dense results back.

    Arguments:
        bucket: CSR tensors; each one's ``orig_dense_tensor`` is overwritten
            with the densified reduced value.
    """
    for reduced in self.csr_allreduce_bucket(bucket):
        # Densify csr tensor and copy back to original location
        reduced.orig_dense_tensor.copy_(reduced.to_dense())
def csr_allreduce_bucket(self, bucket):
    """Apply ``csr_allreduce`` to each tensor and return the results as a list."""
    return [self.csr_allreduce(csr) for csr in bucket]
def csr_allreduce(self, csr):
    """Average one CSR tensor across ranks by gathering indices and values.

    The values are divided by the data parallel world size in place, then
    indices and values from every rank are gathered and concatenated onto
    ``csr``.

    Return:
        The same ``csr`` object, updated in place.
    """
    # Pre-divide for fp16 stability
    csr.values.div_(self.dp_world_size)

    gathered_indices = self.csr_all_gather(csr.indices)
    gathered_values = self.csr_all_gather(csr.values)

    csr.indices = torch.cat(gathered_indices)
    csr.values = torch.cat(gathered_values)
    return csr
def csr_all_gather(self, value):
    """Gather a variable-length 1-D or 2-D tensor from every data parallel rank.

    The tensor is zero-padded along dimension 0 to the largest size in the
    group, gathered, and each received tensor is trimmed back to its
    sender's size.

    Arguments:
        value: tensor with 1 or 2 dimensions. The size tensor and index
            tensors are created on CUDA.

    Return:
        A list with one tensor per data parallel rank, each trimmed to that
        rank's original length.
    """
    my_size = torch.LongTensor([value.size()[0]]).cuda()
    all_sizes = self.all_gather_scalar(my_size)
    max_size = torch.cat(all_sizes).max()
    fill_size = (max_size - my_size)

    # The receive list must hold one entry per member of the group being
    # gathered over. Use dp_world_size, as all_gather_scalar does, rather
    # than the global world size; the two differ when an mpu is in use.
    group_size = self.dp_world_size

    assert value.dim() in [1, 2]
    if value.dim() == 1:
        if fill_size > 0:
            value = torch.cat([value, value.new_zeros(fill_size)])
        tensor_list = [value.new_zeros(max_size) for _ in range(group_size)]
    else:
        if fill_size > 0:
            value = torch.cat([value, value.new_zeros(fill_size, value.size()[1])])
        tensor_list = [
            value.new_zeros(max_size,
                            value.size()[1]) for _ in range(group_size)
        ]

    dist.all_gather(tensor_list, value, group=self.data_parallel_group)

    tensors = []
    for dev_idx, t in enumerate(tensor_list):
        # Drop the zero padding added for ranks with fewer rows.
        size = all_sizes[dev_idx][0]
        tensors.append(t.index_select(0, torch.LongTensor(range(size)).cuda()))

    return tensors
def all_gather_scalar(self, value):
    """Gather ``value`` from every rank of the data parallel group.

    Return:
        A list of ``dp_world_size`` tensors shaped like ``value``.
    """
    shape = value.size()
    gathered = [value.new_zeros(shape) for _ in range(self.dp_world_size)]
    dist.all_gather(gathered, value, group=self.data_parallel_group)
    return gathered
def module_state_dict(self, destination=None, prefix='', keep_vars=False):
    """Return the wrapped module's ``state_dict``; arguments are forwarded positionally."""
    return self.module.state_dict(destination, prefix, keep_vars)
def load_module_state_dict(self, state_dict, strict=True):
    """Load ``state_dict`` into the wrapped module, forwarding ``strict``."""
    wrapped = self.module
    wrapped.load_state_dict(state_dict, strict=strict)
def _get_zero_ckpt_name(self, checkpoints_path, tag):
    """Build this rank's ZeRO optimizer-state checkpoint path.

    Return:
        ``checkpoints_path/<tag>/zero_pp_rank_<pp>_mp_rank_<mp>optim_states.pt``
        where ``<mp>`` is zero-padded to two digits.
    """
    if self.mpu is None:
        mp_rank = 0
    else:
        mp_rank = self.mpu.get_model_parallel_rank()
    pp_rank = torch.distributed.get_rank(group=self.optimizer.dp_process_group)

    basename = ('zero_pp_rank_{}'.format(pp_rank) +
                '_mp_rank_{:02d}'.format(mp_rank) + 'optim_states.pt')
    return os.path.join(checkpoints_path, str(tag), basename)
def _get_ckpt_name(self, checkpoints_path, tag):
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
ckpt_name = os.path.join(checkpoints_path,
str(tag),
'mp_rank_{:02d}'.format(mp_rank) + '_model_states.pt')
return ckpt_name
def _ensure_directory_exists(self, filename):
dirname = os.path.dirname(filename)
if not os.path.exists(dirname):
os.makedirs(dirname)
def load_checkpoint(self, load_dir, tag):
    """Load a training checkpoint.

    Arguments:
        load_dir: Required. Directory to load the checkpoint from
        tag: Required. Checkpoint tag used as a unique identifier for the
            checkpoint. Ex. Global Step.

    Return:
        load_path: Path of the loaded checkpoint. None if loading failed
        client_state: State dictionary used for loading required training
            states in the client code.
    """
    load_path, client_states = self._load_checkpoint(load_dir, tag)

    # ZeRO optimizer state lives in its own file and is only loaded once
    # the model checkpoint itself was found.
    zero_state_needed = self.zero_optimization() and load_path is not None
    if zero_state_needed:
        self._load_zero_checkpoint(load_dir, tag)

    return load_path, client_states
def _load_checkpoint(self, load_dir, tag):
load_path = self._get_ckpt_name(load_dir, tag)
if not os.path.exists(load_path):
logging.warn(
'Client provided checkpoint load path: {} does not exist ... skip checkpoint load'
.format(load_path))
return None, None
logging.info('Loading checkpoint: {}'.format(load_path))
checkpoint = torch.load(load_path, map_location=lambda storage, loc: storage)
self.load_module_state_dict(checkpoint['module'])
if not self.zero_optimization():
self.optimizer.load_state_dict(checkpoint['optimizer'])
if self.lr_scheduler is not None:
self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
self.csr_tensor_module_names = checkpoint['csr_tensor_module_names']
self.global_steps = checkpoint['global_steps']
self.skipped_steps = checkpoint['skipped_steps']
deepspeed_states = [
'module',
'optimizer',
'csr_tensor_module_names',
'skipped_steps',
'global_step'
]
client_state = {
key: value
for key,
value in checkpoint.items() if not key in deepspeed_states
}
return load_path, client_state
def _load_zero_checkpoint(self, load_dir, tag):
zero_checkpoint_name = self._get_zero_ckpt_name(load_dir, tag)
if not os.path.exists(zero_checkpoint_name):
logging.warn(
'Client provided checkpoint load path: {} does not exist ... skip checkpoint load'
.format(zero_checkpoint_name))
return None
zero_sd = torch.load(zero_checkpoint_name, map_location='cpu')
self.optimizer.load_state_dict(zero_sd['optimizer_state_dict'])
logging.info('loading zero checkpoint {}'.format(zero_checkpoint_name))
def save_checkpoint(self, save_dir, tag, client_state=None):
    """Save a training checkpoint.

    Arguments:
        save_dir: Required. Directory for saving the checkpoint
        tag: Required. Checkpoint tag used as a unique identifier for the
            checkpoint. Ex. Global Step.
        client_state: Optional. State dictionary used for saving required
            training states in the client code. Defaults to an empty dict.

    Return:
        True if the checkpoint files this rank is responsible for were
        written, False if writing raised an exception.
    """
    # A fresh dict per call avoids sharing one mutable default between calls.
    if client_state is None:
        client_state = {}

    #This is to make sure the checkpoint names are created without collision
    #There seems to be issue creating them in parallel
    self._create_checkpoint_files(save_dir, tag)

    try:
        if self.save_non_zero_checkpoint:
            self._save_checkpoint(save_dir, tag, client_state=client_state)

        if self.save_zero_checkpoint:
            self._save_zero_checkpoint(save_dir, tag)
    except Exception:
        # Catch Exception only, so KeyboardInterrupt and SystemExit still
        # propagate; logging.exception keeps the traceback for diagnosis.
        logging.exception(
            f'Failed Saving model checkpoint to {save_dir} with tag {tag}')
        return False
    return True
def _create_checkpoint_files(self, save_dir, tag):
    """Create this rank's checkpoint directories, one rank at a time.

    Ranks take turns, separated by barriers, so that directory creation
    does not happen in parallel.

    Arguments:
        save_dir: directory for the checkpoint.
        tag: checkpoint tag.

    Return:
        True when this rank created its directories, False when creation
        raised an exception (which is logged).
    """
    success = True
    my_rank = dist.get_rank()
    #checkpoint files are created sequentially
    for rank in range(dist.get_world_size()):
        if rank == my_rank:
            try:
                if self.save_non_zero_checkpoint:
                    checkpoint_name = self._get_ckpt_name(save_dir, tag)
                    self._ensure_directory_exists(checkpoint_name)

                if self.save_zero_checkpoint:
                    checkpoint_name = self._get_zero_ckpt_name(save_dir, tag)
                    self._ensure_directory_exists(checkpoint_name)
            except Exception:
                logging.exception(
                    f'Failed Saving model checkpoint to {save_dir} with tag {tag}')
                # Record the failure but keep looping. Returning here would
                # skip the remaining barriers and leave every other rank
                # blocked in dist.barrier().
                success = False
        dist.barrier()
    return success
def _save_checkpoint(self, save_dir, tag, client_state={}):
save_path = self._get_ckpt_name(save_dir, tag)