7 changes: 4 additions & 3 deletions torchrec/optim/clipping.py
@@ -201,7 +201,7 @@ def _compute_total_norm(
         """

         ## compute the norm |W|^p corresponding to all sharded params W
-        sharded_grad_norm: torch.Tensor = torch.tensor(0.0)
+        sharded_grad_norm: torch.Tensor = torch.tensor(0.0, pin_memory=True)
         combine_norm_operator = torch.maximum if norm_type == torch.inf else torch.add

         # We need to move sharded_grad_norm to the same device as the first shard so that we can do addition (or take max)
@@ -216,7 +216,8 @@ def _compute_total_norm(
                 process_groups=pgs,
             )
             sharded_grad_norm = combine_norm_operator(
-                sharded_grad_norm.to(current_shard_norm.device), current_shard_norm
+                sharded_grad_norm.to(current_shard_norm.device, non_blocking=True),
+                current_shard_norm,
             )
         # compute |W|^p corresponding to all replicate params W
         # Similar to the case above, we move replicate_grad_norm to the same device as sharded_grad_norm so that we can do addition.
@@ -226,7 +227,7 @@
             )
             if replicate_grads
             else torch.tensor(0.0)
-        ).to(sharded_grad_norm.device)
+        ).to(sharded_grad_norm.device, non_blocking=True)

         # In the p-norm case, we are given norms |W_sharded|^p and |W_replicate|^p. To compute the total norm, we need to
         # sum them and take the p-th root. In the inf-norm case, we are given max(|W_sharded|) and max(|W_replicate|).
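Note on the pattern in this diff (a minimal sketch, not the TorchRec implementation): allocating the CPU accumulator in pinned (page-locked) memory is what makes the later `.to(device, non_blocking=True)` copy genuinely asynchronous; with a pageable CPU tensor the flag silently degrades to a blocking copy. The helper below, `combine_grad_norms`, its `shard_norms` argument, and the CUDA-availability guard are all assumptions introduced here for illustration only.

import torch

def combine_grad_norms(shard_norms, norm_type=2.0):
    """Hypothetical helper: fold per-shard norms |W_i|^p into a total gradient norm."""
    # inf-norm combines by max; any finite p-norm combines by addition of |W_i|^p.
    combine = torch.maximum if norm_type == torch.inf else torch.add

    # Pinned allocation lets the non_blocking host-to-device copy below overlap
    # with other work instead of synchronizing the stream. Pinning needs CUDA,
    # so fall back to a plain CPU scalar otherwise (assumption for portability).
    total = (
        torch.tensor(0.0, pin_memory=True)
        if torch.cuda.is_available()
        else torch.tensor(0.0)
    )

    for norm in shard_norms:
        # Move the running total onto the shard's device without forcing a sync,
        # then combine (max for inf-norm, add for p-norm).
        total = combine(total.to(norm.device, non_blocking=True), norm)

    # p-norm case: the shards hold |W_i|^p, so sum then take the p-th root.
    # inf-norm case: the running max is already the total norm.
    return total if norm_type == torch.inf else total ** (1.0 / norm_type)

This mirrors the intent of the change: pairing pin_memory=True at allocation with non_blocking=True at transfer is what removes the implicit synchronization; either flag alone does not.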