Commit

Merge pull request #4 from ShadenSmith/olruwase/partition_activation
Avoid partitioning small activations
tjruwase committed Jun 11, 2021
2 parents a096d32 + 9b4093b commit 182be7b
Showing 4 changed files with 3 additions and 5 deletions.
@@ -424,7 +424,7 @@ def save_args_for_backward(*all_args):
 
     inputs = []
     for i, item in enumerate(args[:-1]):
-        if not torch.is_tensor(item):
+        if not torch.is_tensor(item) or mp_size > item.numel():
            inputs.append(item)
            continue
 
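
The guard added above is the substance of this commit: an activation is only partitioned across the model-parallel group when it has at least as many elements as that group has ranks, so small tensors (and non-tensors) are passed through whole. A minimal standalone sketch of the same check; maybe_partition and the chunk-based partition_fn below are hypothetical stand-ins for the surrounding DeepSpeed activation-checkpointing code, while mp_size plays the same role as in the diff:

import torch

def maybe_partition(item, mp_size, partition_fn):
    # Mirror of the check above: anything that is not a tensor, or that has
    # fewer elements than the model-parallel world size, is kept whole.
    if not torch.is_tensor(item) or mp_size > item.numel():
        return item
    return partition_fn(item)

# With a model-parallel size of 4, a 2-element activation is left intact,
# while a 1024-element activation is split into 4 chunks of 256 elements.
small = torch.ones(2)
large = torch.ones(1024)
assert maybe_partition(small, 4, lambda t: list(t.chunk(4))) is small
chunks = maybe_partition(large, 4, lambda t: list(t.chunk(4)))
assert len(chunks) == 4 and chunks[0].numel() == 256
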
1 change: 0 additions & 1 deletion deepspeed/runtime/engine.py
@@ -1283,7 +1283,6 @@ def _take_model_step(self, lr_kwargs, block_eigenvalue={}):
                 # pipe_engine.train_batch()
                 self.lr_scheduler.step(increment=self.train_batch_size())
 
-
         if report_progress and (self.global_steps + 1) % self.steps_per_print() == 0:
             self._report_progress(self.global_steps + 1)
 
4 changes: 2 additions & 2 deletions deepspeed/runtime/pipe/engine.py
@@ -983,7 +983,7 @@ def _exec_recv_grads(self, buffer_id):
                s = list(outputs.size())
                self.grad_layer = self._allocate_buffer(s, num_buffers=1)[0]
            else:
-                sizes = [list(t.size()) for t in outputs]# if t.is_floating_point()]
+                sizes = [list(t.size()) for t in outputs]  # if t.is_floating_point()]
                self.grad_layer = self._allocate_buffers(sizes, num_buffers=1)[0]
 
        if isinstance(self.grad_layer, torch.Tensor):
@@ -997,7 +997,7 @@ def _exec_recv_grads(self, buffer_id):
                                     dtype=torch.long,
                                     device=self.device)
                p2p.recv(buffer, self.next_stage)
-
+
        if self.wall_clock_breakdown():
            self.timers('pipe_recv_grad').stop()
 
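
For context, _exec_recv_grads allocates gradient buffers shaped like the stage's forward outputs and then receives the incoming gradients into them from the next pipeline stage. A rough standalone sketch of that buffer-allocation pattern, using plain torch.distributed point-to-point receive rather than DeepSpeed's p2p module; recv_grads_like is a hypothetical helper for illustration only:

import torch
import torch.distributed as dist

def recv_grads_like(outputs, src_rank, device):
    # Allocate one buffer per forward output, matching its shape, then
    # receive the corresponding gradient tensor from the sending rank.
    # Assumes the default process group has already been initialized.
    if torch.is_tensor(outputs):
        outputs = [outputs]
    sizes = [list(t.size()) for t in outputs]
    buffers = [torch.zeros(s, device=device) for s in sizes]
    for buf in buffers:
        dist.recv(buf, src=src_rank)
    return buffers
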
1 change: 0 additions & 1 deletion deepspeed/runtime/utils.py
@@ -87,7 +87,6 @@ def bwc_tensor_model_parallel_rank(mpu=None):
        return mpu.get_model_parallel_rank()
 
 
-
 def move_to_device(item, device):
     """
     Move tensor onto device. Works on individual tensors, and tensors contained/nested in lists, tuples, and dicts.
