diff --git a/deepspeed/runtime/zero/linear.py b/deepspeed/runtime/zero/linear.py
index b97a833beacb..175013361fb8 100644
--- a/deepspeed/runtime/zero/linear.py
+++ b/deepspeed/runtime/zero/linear.py
@@ -26,8 +26,6 @@
 from deepspeed import comm as dist
 from deepspeed.accelerator import get_accelerator
 
-tensor_map = {}
-
 
 def print_rank_0(message, debug=False, force=False):
     if dist.get_rank() == 0 and (debug or force):
@@ -50,14 +48,7 @@ class LinearFunctionForZeroStage3(torch.autograd.Function):
     # bias is an optional argument
     def forward(ctx, input, weight, bias=None):
 
-        weight_id = id(weight)
-        bias_id = id(bias)
-
-        #ctx.save_for_backward(input, weight, bias)
-        ctx.save_for_backward(input, torch.tensor(weight_id), torch.tensor(bias_id))
-
-        tensor_map[weight_id] = weight
-        tensor_map[bias_id] = bias
+        ctx.save_for_backward(input, weight, bias)
 
         if input.dim() == 2 and bias is not None:
             # fused op is marginally faster
@@ -79,11 +70,7 @@ def backward(ctx, grad_output):
         # None. Thanks to the fact that additional trailing Nones are
         # ignored, the return statement is simple even when the function has
         # optional inputs.
-        #input, weight, bias = ctx.saved_tensors
-
-        input, weight_id, bias_id = ctx.saved_tensors
-        weight = tensor_map[weight_id.item()]
-        bias = tensor_map[bias_id.item()]
+        input, weight, bias = ctx.saved_tensors
 
         grad_input = grad_weight = grad_bias = None
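
The diff returns to the standard custom-autograd pattern: tensors needed in backward are passed to `ctx.save_for_backward` and unpacked from `ctx.saved_tensors`, instead of being stashed in a module-level `tensor_map` keyed by `id()` (which keeps tensors alive indefinitely and can break under ZeRO-3 parameter gathering). Below is a minimal, self-contained sketch of that pattern; the class name `LinearFunctionSketch` and the `__main__` usage are illustrative assumptions, not part of the DeepSpeed source.

```python
import torch


class LinearFunctionSketch(torch.autograd.Function):
    # Sketch of a linear op as a custom autograd.Function (assumed example,
    # not the DeepSpeed implementation). Handles 2-D inputs only.

    @staticmethod
    def forward(ctx, input, weight, bias=None):
        # Save the tensors themselves; autograd manages their lifetime, so no
        # global id -> tensor map is needed.
        ctx.save_for_backward(input, weight, bias)
        output = input.matmul(weight.t())
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Unpack saved tensors and initialize all gradients to None; trailing
        # Nones for inputs that don't need grads are ignored by autograd.
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.matmul(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().matmul(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias


if __name__ == "__main__":
    x = torch.randn(4, 8, requires_grad=True)
    w = torch.randn(16, 8, requires_grad=True)
    b = torch.randn(16, requires_grad=True)
    y = LinearFunctionSketch.apply(x, w, b)
    y.sum().backward()
    print(x.grad.shape, w.grad.shape, b.grad.shape)
```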