diff --git a/generate.py b/generate.py
index bb20d6c6..5356799a 100644
--- a/generate.py
+++ b/generate.py
@@ -260,7 +260,6 @@ def main(
     rank = maybe_init_dist()
     use_tp = rank is not None
     if use_tp:
-        torch.cuda.set_device(rank)
         if rank != 0:
             # only print on rank 0
             print = lambda *args, **kwargs: None
diff --git a/tp.py b/tp.py
index f320e6b6..5a39a5e5 100644
--- a/tp.py
+++ b/tp.py
@@ -42,6 +42,7 @@ def maybe_init_dist() -> Optional[int]:
         # not run via torchrun, no-op
         return None
 
+    torch.cuda.set_device(rank)
     dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
     return rank
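
For reference, a minimal sketch of how tp.py's maybe_init_dist could read after this change, with the GPU bound via torch.cuda.set_device before the NCCL process group is created. Only the set_device/init_process_group ordering comes from the diff above; the environment-variable handling is an assumption based on torchrun conventions (single-node launch using LOCAL_RANK/LOCAL_WORLD_SIZE), not code shown in this patch.

import os
from typing import Optional

import torch
import torch.distributed as dist


def maybe_init_dist() -> Optional[int]:
    try:
        # provided by torchrun; single-node launch assumed here
        rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["LOCAL_WORLD_SIZE"])
        if world_size < 2:
            # too few GPUs to parallelize, tensor parallelism is a no-op
            return None
    except KeyError:
        # not run via torchrun, no-op
        return None

    # bind this process to its GPU before initializing the process group,
    # so the NCCL communicators are created against the correct device
    torch.cuda.set_device(rank)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    return rank

Doing the device binding inside maybe_init_dist keeps the rank-to-GPU mapping next to the process-group setup, so callers such as generate.py no longer need to remember to call torch.cuda.set_device themselves.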