From d9ae9e55f9c830d7941aaef7bd0460907b91c250 Mon Sep 17 00:00:00 2001 From: Yifu Wang Date: Sat, 16 Dec 2023 15:34:59 -0800 Subject: [PATCH] Set cuda device before init_process_group --- generate.py | 1 - tp.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/generate.py b/generate.py index bb20d6c6..5356799a 100644 --- a/generate.py +++ b/generate.py @@ -260,7 +260,6 @@ def main( rank = maybe_init_dist() use_tp = rank is not None if use_tp: - torch.cuda.set_device(rank) if rank != 0: # only print on rank 0 print = lambda *args, **kwargs: None diff --git a/tp.py b/tp.py index f320e6b6..5a39a5e5 100644 --- a/tp.py +++ b/tp.py @@ -42,6 +42,7 @@ def maybe_init_dist() -> Optional[int]: # not run via torchrun, no-op return None + torch.cuda.set_device(rank) dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) return rank