Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1065,19 +1065,19 @@ def create_reduce_and_remove_grad_hooks(self):
param.all_gather()

#print(f"After all gather {param.device}, {param.shape}")
def wrapper(param, i):
def wrapper(param):
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]

@instrument_w_nvtx
def reduce_partition_and_remove_grads(*notneeded):
self.reduce_ready_partitions_and_remove_grads(param, i)
self.reduce_ready_partitions_and_remove_grads(param)

grad_acc.register_hook(reduce_partition_and_remove_grads)
self.grad_accs.append(grad_acc)

#print(f"param grad fn {param.expand_as(param).grad_fn}")
wrapper(param, i)
wrapper(param)

# Partition the parameter after creating the hook
param.partition()
Expand All @@ -1095,7 +1095,7 @@ def report_ipg_memory_usage(self, tag, param_elems):
force=False)

###############Independent Partition Gradient ########################
def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
def reduce_independent_p_g_buckets_and_remove_grads(self, param):
#print_rank_0(f"Inside reduce ipg buckets. {debug_param2name_id_shape(param)}, ipg elements {self.elements_in_ipg_bucket}, reduce bucket size {self.reduce_bucket_size}", force=True)

# Because the ipg bucket is initialized with a random place holder tensor, we must
Expand Down Expand Up @@ -1361,9 +1361,9 @@ def partition_grads(self, params_to_release: List[Parameter], grad_partitions: L
gradient_tensors=offload_fp32_gradients[i])
return buffers

def reduce_ready_partitions_and_remove_grads(self, param, i):
def reduce_ready_partitions_and_remove_grads(self, param):
#print_rank_0(f"Backward {debug_param2name_id_shape(param)}", force=True)
self.reduce_independent_p_g_buckets_and_remove_grads(param, i)
self.reduce_independent_p_g_buckets_and_remove_grads(param)

def zero_reduced_gradients(self, partition_id, i):

Expand Down