diff --git a/torchrec/distributed/tensor_pool.py b/torchrec/distributed/tensor_pool.py index 80d0abbd3..07dac5e32 100644 --- a/torchrec/distributed/tensor_pool.py +++ b/torchrec/distributed/tensor_pool.py @@ -305,7 +305,6 @@ def __init__( @torch.jit.export def set_device(self, device_str: str) -> None: self.current_device = torch.device(device_str) - self._shard.to(self.current_device) def forward(self, rank_ids: torch.Tensor) -> torch.Tensor: """