-
Notifications
You must be signed in to change notification settings - Fork 4.2k
/
Copy pathcheckpointing.py
1142 lines (907 loc) · 43.3 KB
/
checkpointing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Partitions the activations stored for backward propagation,
which reduces memory consumption.
Also implements CPU checkpointing and contiguous memory checkpointing,
which reduce memory consumption and memory fragmentation.
Code for rng checkpointing taken from NVIDIA Megatron-LM mpu/random.py
b886b7bb972afe72bac0f5de4f42a4a7bae8ebef
"""
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
import copy
import torch
import contextlib
from deepspeed import comm as dist
import weakref
import mmap
from torch import _C
from deepspeed.runtime.config import DeepSpeedConfig
from deepspeed.utils import logger
from deepspeed.runtime.utils import copy_to_device, move_to_device, see_memory_usage
from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers, FORWARD_GLOBAL_TIMER
from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime import compiler
# DeepSpeed Checkpointing Enabled or Disabled
deepspeed_checkpointing_enabled = False

# MP parameters
# Model-parallel unit object; passed to `bwc_tensor_model_parallel_rank` when seeding.
mpu = None

# Default model-parallel values (a single rank, no process group).
# They are read when slicing activations into partitions and when gathering them back.
mp_rank = 0
mp_size = 1
mp_group = None

# Model Parameters
# Assigned through `set_num_layers`; determines how many per-layer buffers are
# allocated for contiguous checkpointing.
num_layers = None

# Checkpointing buffers
# contiguous_data_buffers[i] is a list of `num_layers` tensors for the i-th
# checkpointed activation; data_offsets[i] is the next unused entry in that list.
contiguous_data_buffers = []
data_offsets = []

# contiguous_size_buffers[i] is a flat tensor that stores activation shapes;
# size_offsets[i] is the next unused element offset in it.
contiguous_size_buffers = []
size_offsets = []

# Timer collection; created on first use when PROFILE_TIME is set.
timers = None

# optimization flags
PARTITION_ACTIVATIONS = False  # keep only this rank's slice of each checkpointed activation
CPU_CHECKPOINT = False  # store checkpointed activations on the CPU
CONTIGUOUS_CHECKPOINTING = False  # copy partitions into the preallocated buffers above
SYNCHRONIZE = False  # call the accelerator's synchronize() around forward/backward
PROFILE_TIME = False  # time forward/backward with `timers`

# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
def detach_variable(inputs, device=None):
    """Return a tuple whose tensors are detached copies of those in ``inputs``.

    Non-tensor entries are passed through unchanged. Each tensor is optionally
    moved to ``device``, detached from its autograd graph and given the same
    ``requires_grad`` setting as the tensor it came from.

    Arguments:
        inputs (tuple): Objects to process; only ``torch.Tensor`` entries are detached.
        device: Optional device to move every tensor to before detaching.

    Returns:
        tuple: Same length and order as ``inputs``.

    Raises:
        RuntimeError: If ``inputs`` is not a tuple.
    """
    if not isinstance(inputs, tuple):
        # Build one message string; passing the type name as a second exception
        # argument would render the error as a tuple instead of readable text.
        raise RuntimeError(
            f"Only tuple of tensors is supported. Got Unsupported input type: {type(inputs).__name__}")

    out = []
    for inp in inputs:
        if not isinstance(inp, torch.Tensor):
            out.append(inp)
            continue

        requires_grad = inp.requires_grad
        x = inp if device is None else inp.to(device=device)

        # detach() yields a tensor outside the graph; restore the original flag
        # so gradients can be collected on it during recomputation.
        x = x.detach()
        x.requires_grad = requires_grad
        out.append(x)
    return tuple(out)
def _set_cuda_rng_state(new_state, device=-1):
    """Sets the random number generator state of the current GPU.

    Arguments:
        new_state (torch.ByteTensor): The desired state
        device: ``-1`` (the default), a device string or an integer device
            index; any other value is used as given.

    This function is adapted from PyTorch repo (torch.cuda.set_rng_state) #ignore-cuda
    with a single change: the input state is not cloned. Cloning caused
    major performance issues for +4 GPU cases.

    The state update itself is handed to ``get_accelerator().lazy_call``.
    """
    legacy_setter_available = hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState)

    if legacy_setter_available:
        # older PyTorch
        def apply_state():
            with get_accelerator().device(device):
                _C._cuda_setRNGState(new_state)

    else:
        # newer PyTorch: turn `device` into a torch.device before building the callback
        if device == -1:
            device = torch.device(get_accelerator().device_name())
        elif isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device(get_accelerator().device_name(), device)

        def apply_state():
            device_index = device.index
            if device_index is None:
                # A device without an explicit index means the current device.
                device_index = get_accelerator().current_device()
            generator = get_accelerator().default_generator(device_index)
            generator.set_state(new_state)

    get_accelerator().lazy_call(apply_state)
class CudaRNGStatesTracker:
    """Tracker for the cuda RNG states.

    Using the `add` method, a cuda rng state is initialized based on
    the input `seed` and is assigned to `name`. Later, by forking the
    rng state, we can perform operations and return to our starting
    cuda state.
    """

    def __init__(self):
        # Map from a string name to the cuda rng state.
        self.states_ = {}
        # Seeds are just for book keeping and ensure no seed is set twice.
        self.seeds_ = set()

    def reset(self):
        """Set to the initial state (no tracker)."""
        self.states_ = {}
        self.seeds_ = set()

    def get_states(self):
        """Get rng states. Copy the dictionary so we have direct
        pointers to the states, not just a pointer to the dictionary."""
        return copy.copy(self.states_)

    def set_states(self, states):
        """Set the rng states. For efficiency purposes, we do not check
        the size of seed for compatibility."""
        self.states_ = states

    def add(self, name, seed):
        """Track a new rng state called `name`, initialized from `seed`.

        Raises:
            Exception: If `seed` was already used or `name` is already tracked.
        """
        # Validate both inputs before recording anything, so that a rejected
        # call does not leave `seed` registered as used.
        if seed in self.seeds_:
            raise Exception('seed {} already exists'.format(seed))
        if name in self.states_:
            raise Exception('cuda rng state {} already exists'.format(name))
        self.seeds_.add(seed)
        # Get the current rng state.
        orig_rng_state = get_accelerator().get_rng_state()
        # Set the new state and store it.
        get_accelerator().manual_seed(seed)
        self.states_[name] = get_accelerator().get_rng_state()
        # Reset rng state to what it was.
        _set_cuda_rng_state(orig_rng_state)

    @contextlib.contextmanager
    def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
        """Fork the cuda rng state, perform operations, and exit with
        the original state.

        Raises:
            Exception: If no state called `name` has been added.
        """
        # Check if we have added the state
        if name not in self.states_:
            raise Exception('cuda rng state {} is not added'.format(name))
        # Store current rng state.
        orig_cuda_rng_state = get_accelerator().get_rng_state()
        # Set rng state to the desired one
        _set_cuda_rng_state(self.states_[name])
        # Do the stuff we wanted to do.
        try:
            yield
        finally:
            # Update the current rng state for later use.
            self.states_[name] = get_accelerator().get_rng_state()
            # And set the state to the original state we started with.
            _set_cuda_rng_state(orig_cuda_rng_state)
# RNG tracker object shared by the whole module.
_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()


def get_cuda_rng_tracker():
    """Return the module-level `CudaRNGStatesTracker` instance."""
    return _CUDA_RNG_STATE_TRACKER
def model_parallel_cuda_manual_seed(seed):
    """Initialize model parallel cuda seed.

    Call this after model parallelism has been initialized, and do not call
    get_accelerator().manual_seed afterwards: this function takes its place.

    Two sets of RNG states are tracked:
        default state: used for data parallelism. It is the same among a
                       set of model parallel GPUs but different across
                       different model parallel groups, e.g. for dropout
                       in the non-model-parallel regions.
        model-parallel state: different among a set of model parallel GPUs
                              but the same across data parallel groups,
                              e.g. for dropout in model parallel regions.
    """
    global mpu

    tp_rank = bwc_tensor_model_parallel_rank(mpu)

    # Data parallel gets the original seed.
    data_parallel_seed = seed
    # 2718 is just for fun and any POSITIVE value will work.
    model_parallel_seed = (seed + 2718) + tp_rank

    global_rank = dist.get_rank()
    if global_rank == 0:
        logger.info(f'> initializing model parallel cuda seeds on global rank {global_rank}, '
                    f'model parallel rank {tp_rank}, and data parallel rank {mpu.get_data_parallel_rank()} with '
                    f'model parallel seed: {model_parallel_seed} and data parallel seed: {data_parallel_seed}')

    _CUDA_RNG_STATE_TRACKER.reset()
    # The default state is seeded first ...
    get_accelerator().manual_seed(data_parallel_seed)
    # ... then the model parallel state is registered with the tracker.
    _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, model_parallel_seed)
def model_parallel_reconfigure_tp_seed(seed):
    """Re-seed the tracked model-parallel rng state from `seed`.

    The seed is offset by 2718 plus the tensor-parallel rank. Seeding happens
    inside the tracker's `fork` context, so the resulting state is stored under
    the default tracker name and the surrounding rng state is restored on exit.
    """
    tp_seed = seed + 2718 + bwc_tensor_model_parallel_rank(mpu)
    with _CUDA_RNG_STATE_TRACKER.fork():
        get_accelerator().manual_seed(tp_seed)
def get_partition_start(item):
    """Return the flat element offset at which this rank's partition of `item` starts.

    The offset is `mp_rank` times the per-rank partition size. Integer floor
    division is used so the result is exact for any element count and always
    equals `mp_rank` whole partitions.

    Arguments:
        item: Object exposing `numel()`.

    Returns:
        int: Start offset into the flattened `item`.
    """
    partition_size = item.numel() // mp_size
    return int(partition_size * mp_rank)
def get_partition_size(item):
    """Return how many elements of `item` each model-parallel rank holds.

    Arguments:
        item: Object exposing `numel()`.

    Returns:
        int: `item.numel()` divided by `mp_size`.

    Raises:
        AssertionError: If the element count is not divisible by `mp_size`.
    """
    size = item.numel()
    assert size % mp_size == 0, "Doesn't handle if partition activation if item is not divisible by mp size"
    # Floor division keeps the result an exact integer; a float round-trip
    # would lose precision for very large element counts.
    return int(size // mp_size)
def gather_partitioned_activations(tensors, device=None):
    """Rebuild full activations from saved (partition, size) pairs.

    `tensors` is read as consecutive pairs: the saved item followed by a
    tensor holding its original shape. Items that are not checkpointable
    activations are returned as they are. The other items are viewed with their
    saved shape; when model parallelism is active the partitions of all ranks
    are first all-gathered into one flat tensor.

    Arguments:
        tensors: Sequence with an even number of entries.
        device: Optional device on which the rebuilt activations are placed.

    Returns:
        tuple: One entry per (item, size) pair.
    """
    global mp_rank, mp_size, mp_group
    assert len(tensors) % 2 == 0, f'Expected even count of tensors, instead got {len(tensors)}'

    gathered = []
    for pair_index in range(len(tensors) // 2):
        item = tensors[2 * pair_index]
        size = tensors[2 * pair_index + 1]

        if not is_activation_to_checkpoint(item):
            gathered.append(item)
            continue

        # don't need to do all_gather if model parallel is not enabled
        if mp_group is None or mp_size == 1:
            item = item.view(list(size.numpy()))
            if device is not None:
                item = item.to(device)
            gathered.append(item)
            continue

        partition_size = item.numel()
        target_device = device if device is not None else item.device
        flat_tensor = torch.zeros([partition_size * mp_size], dtype=item.dtype, device=target_device)

        # Put the local partition into its slot, then gather the other ranks' slots.
        local_slot = flat_tensor.narrow(0, partition_size * mp_rank, partition_size)
        local_slot.copy_(item)
        dist.all_gather_into_tensor(flat_tensor, local_slot, group=mp_group)

        # Re-point the saved tensor at the gathered data, viewed with the saved shape.
        item.data = flat_tensor.view(list(size.numpy())).data
        gathered.append(item)

    return tuple(gathered)
def extract_tensors(all_objects):
    """
    Split a list/tuple into its tensors and its non-tensors, and record which positions held tensors.

    Both output groups keep the relative order of the input.

    Parameters:
        all_objects (list/tuple): Objects containing tensors and non-tensors to be split.

    Returns:
        tuple: Tensors, non-tensors, and one bool per input position telling whether it was a tensor.
        The three groups are tuples when the input is exactly a tuple, otherwise lists.
    """
    tensor_flags = [torch.is_tensor(obj) for obj in all_objects]

    tensor_objects = []
    non_tensor_objects = []
    for obj, is_tensor in zip(all_objects, tensor_flags):
        if is_tensor:
            tensor_objects.append(obj)
        else:
            non_tensor_objects.append(obj)

    if type(all_objects) is not tuple:
        return tensor_objects, non_tensor_objects, tensor_flags
    return tuple(tensor_objects), tuple(non_tensor_objects), tuple(tensor_flags)
def merge_tensors(tensor_objects, non_tensor_objects, tensor_flags):
    """
    Interleave tensors and non-tensors back into one tuple, following the per-position flags.

    When PARTITION_ACTIVATIONS is set, the flag that directly follows a True
    flag is dropped first, because it belongs to the size tensor saved next to
    each tensor.

    NOTE(review): a False flag does not cause the following flag to be dropped,
    so a `None` size placeholder saved after a non-tensor argument stays in the
    merged output - confirm this is intended.

    Parameters:
        tensor_objects (list/tuple): Tensors to merge.
        non_tensor_objects (list/tuple): Non-tensors to merge.
        tensor_flags (list/tuple): Indicates whether each position in output is a tensor.

    Returns:
        tuple: Merge of tensors and non-tensors
    """
    if PARTITION_ACTIVATIONS:
        # remove the flags that are assigned to the size of the flattened tensors
        position_flags = []
        skip_next = False
        for flag in tensor_flags:
            if skip_next:
                skip_next = False
            else:
                position_flags.append(flag)
                skip_next = flag
    else:
        position_flags = tensor_flags

    merged = []
    next_tensor = 0
    next_non_tensor = 0
    for takes_tensor in position_flags:
        if takes_tensor:
            merged.append(tensor_objects[next_tensor])
            next_tensor += 1
        else:
            merged.append(non_tensor_objects[next_non_tensor])
            next_non_tensor += 1

    return tuple(merged)
def is_activation_to_checkpoint(item):
    """
    Tell whether `item` is an activation to be checkpointed.

    That is the case for a floating point tensor with at least `mp_size`
    elements, unless it carries a `no_checkpointing` attribute that does not
    compare equal to False.
    """
    global mp_size
    # A missing attribute counts as "checkpointing allowed". The `== False`
    # comparison is kept on purpose instead of a truthiness test.
    checkpointing_allowed = getattr(item, 'no_checkpointing', False) == False
    return torch.is_tensor(item) and item.is_floating_point() and item.numel() >= mp_size and checkpointing_allowed
def partition_activations(args, cpu_checkpoint, contiguous_checkpoint):
    """Replace every checkpointable activation in `args` by this rank's partition of it.

    Arguments:
        args: Iterable of forward inputs.
        cpu_checkpoint: When true, partitions are stored on the CPU.
        contiguous_checkpoint: When true, each partition is copied into a
            preallocated buffer from `contiguous_data_buffers` (one list of
            `num_layers` buffers per checkpointed position) instead of being
            kept as a freshly allocated tensor.

    Returns:
        list: One entry per element of `args`; items that are not
        checkpointable activations are passed through unchanged.
    """
    global contiguous_data_buffers, data_offsets

    inputs = []
    num_non_fp_tensors = 0

    for arg_index, item in enumerate(args):
        if not is_activation_to_checkpoint(item):
            inputs.append(item)
            num_non_fp_tensors += 1
            continue

        # Index of this activation among the checkpointed ones seen so far.
        i = arg_index - num_non_fp_tensors
        partition_size = get_partition_size(item)
        partition = item.detach().contiguous().view(-1).narrow(0, get_partition_start(item), partition_size).clone()

        if not contiguous_checkpoint:
            # Honour the argument (not the module-level flag) so both branches
            # of this function agree on where the partition is stored.
            inputs.append(partition.cpu() if cpu_checkpoint else partition)
            continue

        buffer_device = torch.device('cpu') if cpu_checkpoint else partition.device

        def allocate_layer_buffers():
            """Allocate one uninitialized buffer per layer for this activation."""
            return [
                torch.tensor(()).new_empty([partition_size], dtype=partition.dtype, device=buffer_device)
                for _ in range(num_layers)
            ]

        if i >= len(contiguous_data_buffers):
            contiguous_data_buffers.append(allocate_layer_buffers())
            data_offsets.append(0)
        elif contiguous_data_buffers[i] is None:
            contiguous_data_buffers[i] = allocate_layer_buffers()
            data_offsets[i] = 0

        target = contiguous_data_buffers[i][data_offsets[i]].data

        # Because the 'new_empty' returns uninitialized pages,
        # the pages need to be populated during the cudaMemcpy time
        # which increases the data copy time. To avoid this, we
        # pre-populate these pages by simply writing 0 ahead of
        # the actual cudaMemcpy operation time. Due to the
        # previously launched GPU kernels, there is a small
        # window of time here for CPUs to populate pages asynchronously.
        target[range(0, target.shape[0], int(mmap.PAGESIZE / target.element_size()))] = 0

        contiguous_partition = target.copy_(partition.data)
        data_offsets[i] = data_offsets[i] + 1
        inputs.append(contiguous_partition)

    return inputs
def get_partitioned_activations_for_backward(args, inputs, contiguous_checkpoint):
    """Build the (argument, size) list that is saved for the backward pass.

    Every element of `args` contributes two entries: the argument itself and a
    tensor holding its shape (`None` for non-tensors). For checkpointable
    activations the argument's data is replaced by an empty placeholder and the
    matching partition from `inputs` is attached as `saved_data`. With
    `contiguous_checkpoint` the shape is copied into a slice of a preallocated
    buffer in `contiguous_size_buffers`.

    Returns:
        list: Twice as many entries as there are (arg, input) pairs.
    """
    global contiguous_size_buffers, size_offsets

    saved_args = []
    skipped = 0

    for position, (arg, partition) in enumerate(zip(args, inputs)):
        shape = torch.tensor(arg.size()) if torch.is_tensor(arg) else None

        if not is_activation_to_checkpoint(arg):
            saved_args.append(arg)
            saved_args.append(shape)
            skipped += 1
            continue

        # Drop the full activation and remember only the partition.
        arg.data = torch.empty([], device=arg.device).data
        arg.saved_data = partition.data
        saved_args.append(arg)

        if not contiguous_checkpoint:
            saved_args.append(shape)
            continue

        slot = position - skipped
        ndims = shape.numel()

        if slot >= len(contiguous_size_buffers):
            contiguous_size_buffers.append(
                torch.tensor(()).new_empty([ndims * num_layers], dtype=shape.dtype, device=shape.device))
            size_offsets.append(0)
        elif contiguous_size_buffers[slot] is None:
            contiguous_size_buffers[slot] = torch.tensor(()).new_empty([ndims * num_layers],
                                                                       dtype=shape.dtype,
                                                                       device=shape.device)
            size_offsets[slot] = 0

        offset = size_offsets[slot]
        packed_shape = contiguous_size_buffers[slot].narrow(0, offset, ndims).data.copy_(shape.data)
        packed_shape = packed_shape.view_as(shape)
        size_offsets[slot] = offset + ndims
        saved_args.append(packed_shape)

    return saved_args
def get_cpu_activations_for_backward(args, inputs):
    """Return `args` with checkpointable activations swapped for placeholders.

    For each checkpointable activation the data is replaced by an empty tensor
    and the data of the matching entry of `inputs` is attached as
    `saved_data`. All other arguments are returned untouched.
    """
    new_args = []
    for arg, stored_copy in zip(args, inputs):
        if is_activation_to_checkpoint(arg):
            arg.data = torch.empty([], device=arg.device).data
            arg.saved_data = stored_copy.data
        new_args.append(arg)
    return new_args
class CheckpointFunction(torch.autograd.Function):
    """This function is adapted from torch.utils.checkpoint with
       the following changes:
           1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state` #ignore-cuda
           2) the states in the model parallel tracker are also properly
              tracked/set/reset.
           3) Performance activation partitioning, contiguous memory optimization
           4) CPU Checkpointing
           5) Profile forward and backward functions
    """

    @staticmethod
    def forward(ctx, run_function, all_outputs, *args):
        """Run `run_function(*args)` without building a graph and stash what backward needs.

        Arguments:
            ctx: autograd context; the inputs and the rng states are stored on it.
            run_function: callable that is executed here and again in `backward`.
            all_outputs (list): extended in place with every output of `run_function`,
                including non-tensor ones.
            *args: inputs of `run_function`.

        Returns:
            The single tensor output, or a tuple with only the tensor outputs.
        """
        global mpu, timers, SYNCHRONIZE, PROFILE_TIME

        def save_args_for_backward(*all_args):
            # Tensors and non-tensors are kept apart; the flags allow re-merging them later.
            tensor_args, non_tensor_args, tensor_flags = extract_tensors(all_objects=all_args)
            ctx.deepspeed_saved_tensors = tensor_args
            ctx.non_tensor_args = non_tensor_args
            ctx.tensor_flags = tensor_flags

        if SYNCHRONIZE:
            get_accelerator().synchronize()

        # The timer collection is created on first use.
        if timers is None and PROFILE_TIME:
            timers = Timers()

        if PROFILE_TIME:
            timers(FORWARD_GLOBAL_TIMER).start()

        ctx.run_function = run_function
        global num_layers
        global mp_rank, mp_size, mp_group
        global contiguous_data_buffers, contiguous_size_buffers
        global data_offsets, size_offsets
        global PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset

        cuda_device = get_accelerator().current_device_name()
        # NOTE(review): this stream is created but not used in this method.
        transport_stream = get_accelerator().Stream(device=cuda_device)

        # `inputs` is only assigned (and later only read) when one of these two modes is on.
        if PARTITION_ACTIVATIONS:
            inputs = partition_activations(args, CPU_CHECKPOINT, CONTIGUOUS_CHECKPOINTING)
        elif CPU_CHECKPOINT:
            inputs = copy_to_device(args, device=torch.device('cpu'), criterion_func=is_activation_to_checkpoint)

        # just in case something funky is happening such as reuse of inputs
        inputs_cuda = copy_to_device(args, device=cuda_device, criterion_func=is_activation_to_checkpoint)

        # Copy the rng states so that backward can replay the forward with the same randomness.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        ctx.fwd_cuda_rng_state = get_accelerator().get_rng_state()
        ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        see_memory_usage("Before running forward on the layer", force=False)
        # ctx.save_for_backward(*args)
        with torch.no_grad():
            outputs = run_function(*inputs_cuda)

        see_memory_usage("After running forward on the layer", force=False)
        del inputs_cuda

        if PARTITION_ACTIVATIONS:
            new_args = get_partitioned_activations_for_backward(args, inputs, CONTIGUOUS_CHECKPOINTING)
            assert len(new_args) % 2 == 0, f'save_for_backward called with odd number of args, {len(new_args)}'
            save_args_for_backward(*new_args)
        elif CPU_CHECKPOINT:
            new_args = get_cpu_activations_for_backward(args, inputs)
            save_args_for_backward(*new_args)
        else:
            save_args_for_backward(*args)

        if PROFILE_TIME:
            timers(FORWARD_GLOBAL_TIMER).stop()
            timers.log([FORWARD_GLOBAL_TIMER])
        if SYNCHRONIZE:
            get_accelerator().synchronize()

        # Tensors returned from forward() may not be differentiable.
        if torch.is_tensor(outputs):
            non_grad_outputs = [outputs] if not outputs.is_floating_point() else []
        else:
            non_grad_outputs = [o for o in outputs if torch.is_tensor(o) and not o.is_floating_point()]
        ctx.mark_non_differentiable(*non_grad_outputs)

        # `all_outputs` receives every output; only tensors are returned to autograd.
        if torch.is_tensor(outputs):
            all_outputs += [outputs]
            return outputs
        else:
            all_outputs += outputs
            outputs, _, _ = extract_tensors(all_objects=outputs)
            return tuple(outputs)

    @staticmethod
    def backward(ctx, *grads):
        """Recompute the forward pass with gradients enabled and backpropagate `grads` through it.

        Returns:
            tuple: `None` for `run_function` and `all_outputs`, followed by one
            gradient (or `None` for non-tensors) per original input.
        """
        global timers
        see_memory_usage("In backward", force=False)
        # removing pointers to the contiguous buffer memory
        # so that they can be garbage collected once the checkpoints
        # have been used
        if SYNCHRONIZE:
            get_accelerator().synchronize()
        if PROFILE_TIME:
            timers('backward').start()

        if CONTIGUOUS_CHECKPOINTING:
            global data_offsets, size_offsets
            global contiguous_data_buffers, contiguous_size_buffers

            # NOTE(review): this loop only rebinds the loop variable and does not
            # change the list; the reassignments below are what drop the buffers.
            for buffers in contiguous_data_buffers:
                buffers = []

            # frees up all the pointers to the checkpoints except for the ones
            # stored by save for backward
            contiguous_data_buffers = []
            contiguous_size_buffers = []
            data_offsets = []
            size_offsets = []

        see_memory_usage("In backward checkpointing code", force=False)
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), "
                               "please use .backward() if possible")

        global PARTITION_ACTIVATIONS
        cuda_device = get_accelerator().current_device_name()
        # NOTE(review): this stream is created but only referenced in commented-out code.
        transport_stream = get_accelerator().Stream(device=cuda_device)

        # Rebuild deepspeed_saved_tensors: put the stashed data back in place of the placeholder.
        for t in ctx.deepspeed_saved_tensors:
            if t is not None and hasattr(t, 'saved_data') and t.saved_data is not None:
                t.data = t.saved_data.to(t.device)
                t.saved_data = None

        if PARTITION_ACTIVATIONS:
            # with get_accelerator().stream(transport_stream):
            inputs = gather_partitioned_activations(ctx.deepspeed_saved_tensors,
                                                    device=cuda_device if CPU_CHECKPOINT else None)
            detached_inputs = detach_variable(inputs)
        elif CPU_CHECKPOINT:
            inputs = move_to_device(ctx.deepspeed_saved_tensors, cuda_device, is_activation_to_checkpoint)
            detached_inputs = detach_variable(inputs)
        else:
            inputs = ctx.deepspeed_saved_tensors
            detached_inputs = detach_variable(inputs)

        # Add non tensor input args
        detached_inputs = merge_tensors(tensor_objects=detached_inputs,
                                        non_tensor_objects=ctx.non_tensor_args,
                                        tensor_flags=ctx.tensor_flags)

        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
        bwd_cuda_rng_state = get_accelerator().get_rng_state()
        bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        # Set the states to what it used to be before the forward pass.
        torch.set_rng_state(ctx.fwd_cpu_rng_state)
        _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

        # if PARTITION_ACTIVATIONS:
        #     current_stream=get_accelerator().current_stream()
        #     current_stream.wait_stream(transport_stream)

        see_memory_usage("In backward checkpointing code before forward", force=False)

        with torch.enable_grad():
            outputs = ctx.run_function(*detached_inputs)

        see_memory_usage("In backward checkpointing code after forward", force=False)
        # Set the states back to what it was at the start of this function.
        torch.set_rng_state(bwd_cpu_rng_state)
        _set_cuda_rng_state(bwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs, )

        # Filter out non tensor outputs
        outputs, _, _ = extract_tensors(all_objects=outputs)

        # Construct arguments to autograd.backward().
        # This is usually just outputs and grads, but forward() can return tensors that
        # are not differentiable.
        output_tensors = []
        grad_tensors = []
        for out, grad in zip(outputs, grads):
            if out.requires_grad:
                output_tensors.append(out)
                grad_tensors.append(grad)

        see_memory_usage("In backward checkpointing code before backward", force=False)

        torch.autograd.backward(output_tensors, grad_tensors)

        # Force clear our stashed tensors to prevent a memory leak in certain scenarios
        ctx.deepspeed_saved_tensors = None
        ctx.non_tensor_args = None
        ctx.tensor_flags = None

        see_memory_usage("After backward checkpointing code after backward", force=False)

        if PROFILE_TIME:
            timers('backward').stop()
            timers.log(['backward'])
        if SYNCHRONIZE:
            get_accelerator().synchronize()
        # The two leading entries stand for forward's `run_function` and `all_outputs`
        # arguments, which receive no gradient.
        ret_list = [None, None]
        for inp in detached_inputs:
            if torch.is_tensor(inp):
                ret_list.append(inp.grad)
            else:
                ret_list.append(None)

        return tuple(ret_list)
def non_reentrant_checkpoint(function, *args):
    """This function is a union of `torch.utils.checkpoint._checkpoint_without_reentrant` and `CheckpointFunction` in this module

    It aims to solve the back propagation error raised when none of the inputs requires grad.
    * this approach has been implemented in pytorch for a while; the solution is stable most of the time except for jit module mode.
    * it can help to solve the issue which is worked around by `deepspeed.runtime.pipe.module.PipelineModule._is_checkpointable`

    Main modifications compared to the implementation of torch:
    1. adapt to the signature of `checkpoint` function in this module
    2. solve the non-determinism by random state management consistent with deepspeed `CheckpointFunction`
    3. when there is partition or cpu checkpointing, gather them in the unpack_hook during back propagation
    4. put all after-backward blocks in a hook which will be executed after all leaf nodes' backward execution.
    5. above 4. is inspired by `torch.autograd.graph.register_multi_grad_hook`, which is only implemented after 2.0.0

    Arguments:
        function: callable to run now and to re-run when its saved tensors are first needed in backward.
        *args: inputs of `function`.

    Returns:
        The single output of `function`, or a tuple of all its outputs.
    """
    global mpu, timers, SYNCHRONIZE, PROFILE_TIME

    # Inputs stashed for recomputation; shared with the nested functions through `nonlocal`.
    deepspeed_saved_tensors = None
    non_tensor_args = None
    tensor_flags = None

    def save_args_for_backward(*all_args):
        """keep this function to reduce the modification from original implementation"""
        nonlocal deepspeed_saved_tensors, non_tensor_args, tensor_flags
        tensor_args, non_tensor_args, tensor_flags = extract_tensors(all_objects=all_args)
        deepspeed_saved_tensors = tensor_args
        # The next two assignments are self-assignments; the values were already bound above.
        non_tensor_args = non_tensor_args
        tensor_flags = tensor_flags

    if SYNCHRONIZE:
        get_accelerator().synchronize()

    # The timer collection is created on first use.
    if timers is None and PROFILE_TIME:
        timers = Timers()

    if PROFILE_TIME:
        timers(FORWARD_GLOBAL_TIMER).start()

    global num_layers
    global mp_rank, mp_size, mp_group
    global contiguous_data_buffers, contiguous_size_buffers
    global data_offsets, size_offsets
    global PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset

    cuda_device = get_accelerator().current_device_name()
    # NOTE(review): this stream is created but not used afterwards.
    transport_stream = get_accelerator().Stream(device=cuda_device)

    # `inputs` is only assigned (and later only read) when one of these two modes is on.
    if PARTITION_ACTIVATIONS:
        inputs = partition_activations(args, CPU_CHECKPOINT, CONTIGUOUS_CHECKPOINTING)
    elif CPU_CHECKPOINT:
        inputs = copy_to_device(args, device=torch.device('cpu'), criterion_func=is_activation_to_checkpoint)

    # just in case something funky is happening such as reuse of inputs
    inputs_cuda = copy_to_device(args, device=cuda_device, criterion_func=is_activation_to_checkpoint)

    # Copy the rng states so the recomputation can replay the forward with the same randomness.
    fwd_cpu_rng_state = torch.get_rng_state()
    fwd_cuda_rng_state = get_accelerator().get_rng_state()
    fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

    if PARTITION_ACTIVATIONS:
        new_args = get_partitioned_activations_for_backward(args, inputs, CONTIGUOUS_CHECKPOINTING)
        assert len(new_args) % 2 == 0, f'save_for_backward called with odd number of args, {len(new_args)}'
        save_args_for_backward(*new_args)
    elif CPU_CHECKPOINT:
        new_args = get_cpu_activations_for_backward(args, inputs)
        save_args_for_backward(*new_args)
    else:
        save_args_for_backward(*args)

    class Holder():
        """the placeholder object saved instead of an activation to save memory"""
        pass

    # weakref seems utilized to discover the tensor deletion before a whole
    # forward backward pair loop finished
    storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
    weak_holder_list = []
    leaf_tensors = []
    backward_visited_leaf_nodes = 0

    def checkpoint_pack(tensor_from_forward):
        """used to record the activation order in the `weak_holder_list`

        the activation order in holder list is consistent between the first forward and recomputing forward.
        * the jit compiled forward will break the order consistency *
        """
        res = Holder()
        weak_holder_list.append(weakref.ref(res))

        # if this is a leaf tensor, save it for backward progression trace
        # leaf tensor used to be input or parameters, which is not activations and
        # has no memory overhead
        if tensor_from_forward.requires_grad and tensor_from_forward.is_leaf:
            leaf_tensors.append(tensor_from_forward)
        return res

    def checkpoint_unpack(holder_from_backward):
        """retrieve the activations from recompute"""
        nonlocal deepspeed_saved_tensors, non_tensor_args, tensor_flags

        # if this is the first step of backward propagation, recompute the graph and save
        # all the activations with the same order as `checkpoint_pack` does
        if len(storage) == 0:
            unpack_counter = 0

            def replay_pack(tensor_from_replay):
                """save recompute activations"""
                nonlocal unpack_counter
                unpack_counter += 1

                # The holder recorded at this position is gone, so nothing needs to be stored for it.
                if weak_holder_list[unpack_counter - 1]() is None:
                    return

                detached_activations = tensor_from_replay.detach()
                storage[weak_holder_list[unpack_counter - 1]()] = detached_activations

                return

            def replay_unpack(none_value):
                """the recomputed graph is never meant to be back propagated through"""
                raise RuntimeError("You are calling backwards on a tensor that is never exposed.")

            global timers
            see_memory_usage("In backward", force=False)
            # removing pointers to the contiguous buffer memory
            # so that they can be garbage collected once the checkpoints
            # have been used
            if SYNCHRONIZE:
                get_accelerator().synchronize()
            if PROFILE_TIME:
                timers('backward').start()

            if CONTIGUOUS_CHECKPOINTING:
                global data_offsets, size_offsets
                global contiguous_data_buffers, contiguous_size_buffers

                # NOTE(review): this loop only rebinds the loop variable and does not
                # change the list; the reassignments below are what drop the buffers.
                for buffers in contiguous_data_buffers:
                    buffers = []

                # frees up all the pointers to the checkpoints except for the ones
                # stored by save for backward
                contiguous_data_buffers = []
                contiguous_size_buffers = []
                data_offsets = []
                size_offsets = []

            see_memory_usage("In backward checkpointing code", force=False)
            if not torch.autograd._is_checkpoint_valid():
                raise RuntimeError("Checkpointing is not compatible with .grad(), "
                                   "please use .backward() if possible")

            global PARTITION_ACTIVATIONS
            cuda_device = get_accelerator().current_device_name()
            # NOTE(review): this stream is created but only referenced in commented-out code.
            transport_stream = get_accelerator().Stream(device=cuda_device)

            # gather inputs which were partitioned or checkpointed before the first forward
            if PARTITION_ACTIVATIONS:
                # with get_accelerator().stream(transport_stream):
                inputs = gather_partitioned_activations(deepspeed_saved_tensors,
                                                        device=cuda_device if CPU_CHECKPOINT else None)
                detached_inputs = detach_variable(inputs)
            elif CPU_CHECKPOINT:
                inputs = move_to_device(deepspeed_saved_tensors, cuda_device, is_activation_to_checkpoint)
                detached_inputs = detach_variable(inputs)
            else:
                inputs = deepspeed_saved_tensors
                detached_inputs = detach_variable(inputs)

            # Add non tensor input args
            detached_inputs = merge_tensors(tensor_objects=detached_inputs,
                                            non_tensor_objects=non_tensor_args,
                                            tensor_flags=tensor_flags)

            # Store the current states.
            bwd_cpu_rng_state = torch.get_rng_state()
            bwd_cuda_rng_state = get_accelerator().get_rng_state()
            bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

            # Set the states to what it used to be before the forward pass.
            torch.set_rng_state(fwd_cpu_rng_state)
            _set_cuda_rng_state(fwd_cuda_rng_state)
            get_cuda_rng_tracker().set_states(fwd_cuda_rng_state_tracker)

            see_memory_usage("In backward checkpointing code before forward", force=False)
            # Re-run the function; `replay_pack` fills `storage` with the recomputed activations.
            with torch.enable_grad(), torch.autograd.graph.saved_tensors_hooks(replay_pack, replay_unpack):
                _unused = function(*detached_inputs)

            see_memory_usage("In backward checkpointing code after forward", force=False)
            # Set the states back to what it was at the start of this function.
            torch.set_rng_state(bwd_cpu_rng_state)
            _set_cuda_rng_state(bwd_cuda_rng_state)
            get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)

            # Release the stashed inputs now that the recomputation is done.
            deepspeed_saved_tensors = None
            non_tensor_args = None
            tensor_flags = None

        if holder_from_backward not in storage:
            raise RuntimeError("Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
                               " recomputation being triggered in between, this is not currently supported.")

        return storage[holder_from_backward]

    def after_backward_hook(_nonuse_grads):
        """the hook registered to all leaf tensors"""
        nonlocal leaf_tensors, backward_visited_leaf_nodes
        backward_visited_leaf_nodes += 1

        # Only the call that completes the count of recorded leaf tensors does the wrap-up work.
        if backward_visited_leaf_nodes == len(leaf_tensors):
            see_memory_usage("After backward checkpointing code after backward", force=False)

            if PROFILE_TIME:
                timers('backward').stop()
                timers.log(['backward'])
            if SYNCHRONIZE:
                get_accelerator().synchronize()

    with torch.autograd.graph.saved_tensors_hooks(checkpoint_pack, checkpoint_unpack):
        outputs = function(*inputs_cuda)

    # The hook only does profiling/synchronization work, so it is registered only when needed.
    if PROFILE_TIME or SYNCHRONIZE:
        for leaf_tensor in leaf_tensors:
            leaf_tensor.register_hook(after_backward_hook)

    see_memory_usage("After running forward on the layer", force=False)

    if PROFILE_TIME:
        timers(FORWARD_GLOBAL_TIMER).stop()
        timers.log([FORWARD_GLOBAL_TIMER])
    if SYNCHRONIZE:
        get_accelerator().synchronize()

    all_outputs = []
    if torch.is_tensor(outputs):
        all_outputs += [outputs]
    else:
        all_outputs += outputs

    if len(all_outputs) == 1:
        return all_outputs[0]
    else:
        return tuple(all_outputs)
@compiler.disable  # WA from Pytorch repo for compile + zero 3 accuracy issue
def checkpoint(function, *args):
    """Checkpoint a model or part of the model.

    Runs `function(*args)` through `CheckpointFunction`, which collects the
    outputs in a list passed to it.

    Returns:
        The only output when there is exactly one, otherwise a tuple of all outputs.
    """
    collected_outputs = []
    CheckpointFunction.apply(function, collected_outputs, *args)
    return collected_outputs[0] if len(collected_outputs) == 1 else tuple(collected_outputs)
def partition_activations_in_checkpoint(partition_activation):
    """Set the module-level PARTITION_ACTIVATIONS flag and log the new value on rank 0."""
    global PARTITION_ACTIVATIONS
    PARTITION_ACTIVATIONS = partition_activation

    on_first_rank = dist.get_rank() == 0
    if on_first_rank:
        logger.info(f"**************Partition Activations {PARTITION_ACTIVATIONS}************")
def set_num_layers(nlayers):
    """Store `nlayers` in the module-level `num_layers`.

    `num_layers` determines how many per-layer buffers are allocated for
    contiguous checkpointing.
    """
    global num_layers
    num_layers = nlayers
def reset():
    """Resets memory buffers related to contiguous memory optimizations.

    Should be called during eval when multiple forward propagations are
    computed without any backward propagation that usually clears these
    buffers. Does nothing unless CONTIGUOUS_CHECKPOINTING is enabled.

    Arguments:
        None

    Return:
        None
    """
    global data_offsets, size_offsets
    global contiguous_data_buffers, contiguous_size_buffers

    if not CONTIGUOUS_CHECKPOINTING:
        return

    # Rebinding the globals frees up all the pointers to the checkpoints except
    # for the ones stored by save for backward. No per-buffer loop is needed.
    contiguous_data_buffers = []
    contiguous_size_buffers = []
    data_offsets = []
    size_offsets = []
def _configure_using_config_file(config, mpu=None):
global num_layers, PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
CPU_CHECKPOINT, SYNCHRONIZE, PROFILE_TIME