diff --git a/torchrec/optim/clipping.py b/torchrec/optim/clipping.py
index bf5467209..94c11ac66 100644
--- a/torchrec/optim/clipping.py
+++ b/torchrec/optim/clipping.py
@@ -201,7 +201,7 @@ def _compute_total_norm(
     """

     ## compute the norm |W|^p corresponding to all sharded params W
-    sharded_grad_norm: torch.Tensor = torch.tensor(0.0)
+    sharded_grad_norm: torch.Tensor = torch.tensor(0.0, pin_memory=True)
     combine_norm_operator = torch.maximum if norm_type == torch.inf else torch.add

     # We need to move sharded_grad_norm to the same device as the first shard so that we can do addition (or take max)
@@ -216,7 +216,8 @@ def _compute_total_norm(
             process_groups=pgs,
         )
         sharded_grad_norm = combine_norm_operator(
-            sharded_grad_norm.to(current_shard_norm.device), current_shard_norm
+            sharded_grad_norm.to(current_shard_norm.device, non_blocking=True),
+            current_shard_norm,
         )
     # compute |W|^p corresponding to all replicate params W
     # Similar to the case above, we move replicate_grad_norm to the same device as sharded_grad_norm so that we can do addition.
@@ -226,7 +227,7 @@
         )
         if replicate_grads
         else torch.tensor(0.0)
-    ).to(sharded_grad_norm.device)
+    ).to(sharded_grad_norm.device, non_blocking=True)

     # In the p-norm case, we are given norms |W_sharded|^p and |W_replicate|^p. To compute the total norm, we need to
     # sum them and take the p-th root. In the inf-norm case, we are given max(|W_sharded|) and max(|W_replicate|).
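
For context, a minimal standalone sketch of the pinned-memory plus non-blocking transfer pattern this patch relies on. It is not part of the patch; the helper name `combine_with_shard_norm` and the variables below are illustrative only. A CPU tensor allocated in pinned (page-locked) memory can be copied to the GPU asynchronously with respect to the host thread, which is what the `pin_memory=True` / `non_blocking=True` pairing in the diff is after.

```python
# Illustrative sketch only; assumes a CUDA device is available.
import torch


def combine_with_shard_norm(
    accumulator: torch.Tensor, current_shard_norm: torch.Tensor
) -> torch.Tensor:
    # Move the CPU accumulator to the shard's device. With a pinned source
    # tensor and non_blocking=True, the host-to-device copy can proceed
    # asynchronously instead of stalling the host thread.
    return torch.add(
        accumulator.to(current_shard_norm.device, non_blocking=True),
        current_shard_norm,
    )


if torch.cuda.is_available():
    # Pinned (page-locked) CPU scalar, mirroring torch.tensor(0.0, pin_memory=True).
    accumulator = torch.tensor(0.0, pin_memory=True)
    # Stand-in for a per-shard gradient norm already living on the GPU.
    shard_norm = torch.tensor(3.0, device="cuda")
    total = combine_with_shard_norm(accumulator, shard_norm)
```

Note that `non_blocking=True` only overlaps the host-to-device copy when the source CPU memory is pinned; for ordinary pageable memory the transfer is effectively synchronous, which is why the patch changes the accumulator's allocation as well as the `.to()` calls.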