From 9b4093b37ab0596fef1743c62e4c446bbb1faeab Mon Sep 17 00:00:00 2001
From: Tunji Ruwase
Date: Fri, 11 Jun 2021 18:48:02 +0000
Subject: [PATCH] Avoid partitioning small activations

---
 deepspeed/runtime/activation_checkpointing/checkpointing.py | 2 +-
 deepspeed/runtime/engine.py                                 | 1 -
 deepspeed/runtime/pipe/engine.py                            | 4 ++--
 deepspeed/runtime/utils.py                                  | 1 -
 4 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/deepspeed/runtime/activation_checkpointing/checkpointing.py b/deepspeed/runtime/activation_checkpointing/checkpointing.py
index d1f92146a38f..efe95f91bac5 100644
--- a/deepspeed/runtime/activation_checkpointing/checkpointing.py
+++ b/deepspeed/runtime/activation_checkpointing/checkpointing.py
@@ -424,7 +424,7 @@ def save_args_for_backward(*all_args):
 
             inputs = []
             for i, item in enumerate(args[:-1]):
-                if not torch.is_tensor(item):
+                if not torch.is_tensor(item) or mp_size > item.numel():
                     inputs.append(item)
                     continue
 
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 24b64fbb039b..146a6c6e931c 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -1283,7 +1283,6 @@ def _take_model_step(self, lr_kwargs, block_eigenvalue={}):
                 # pipe_engine.train_batch()
                 self.lr_scheduler.step(increment=self.train_batch_size())
 
-
         if report_progress and (self.global_steps + 1) % self.steps_per_print() == 0:
             self._report_progress(self.global_steps + 1)
 
diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py
index 14601eb0a007..09231bf3e96a 100644
--- a/deepspeed/runtime/pipe/engine.py
+++ b/deepspeed/runtime/pipe/engine.py
@@ -983,7 +983,7 @@ def _exec_recv_grads(self, buffer_id):
                 s = list(outputs.size())
                 self.grad_layer = self._allocate_buffer(s, num_buffers=1)[0]
             else:
-                sizes = [list(t.size()) for t in outputs]# if t.is_floating_point()]
+                sizes = [list(t.size()) for t in outputs]  # if t.is_floating_point()]
                 self.grad_layer = self._allocate_buffers(sizes, num_buffers=1)[0]
 
         if isinstance(self.grad_layer, torch.Tensor):
@@ -997,7 +997,7 @@ def _exec_recv_grads(self, buffer_id):
                                      dtype=torch.long,
                                      device=self.device)
                 p2p.recv(buffer, self.next_stage)
-
+
         if self.wall_clock_breakdown():
             self.timers('pipe_recv_grad').stop()
 
diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py
index c3ababa43c3e..f3a4c4c0f1b3 100755
--- a/deepspeed/runtime/utils.py
+++ b/deepspeed/runtime/utils.py
@@ -87,7 +87,6 @@ def bwc_tensor_model_parallel_rank(mpu=None):
         return mpu.get_model_parallel_rank()
 
 
-
 def move_to_device(item, device):
     """
     Move tensor onto device. Works on individual tensors, and tensors contained/nested in lists, tuples, and dicts.