From a15a3c6675a9dbef56540eca178d1727ee060eb5 Mon Sep 17 00:00:00 2001
From: Tristan Rice
Date: Sat, 9 Nov 2024 12:00:56 -0800
Subject: [PATCH] ci: use stable rust and gate on number of gpus

---
 .github/workflows/unittest.yaml | 9 +++++++--
 torchft/process_group_test.py   | 4 ++--
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml
index b603a4d2..586336da 100644
--- a/.github/workflows/unittest.yaml
+++ b/.github/workflows/unittest.yaml
@@ -25,12 +25,17 @@ jobs:
       gpu-arch-type: ${{ matrix.gpu-arch-type }}
       gpu-arch-version: ${{ matrix.gpu-arch-version }}
       script: |
+        set -ex
+
+        # install python and protobuf
         conda create -n venv python=3.10 protobuf -y
         conda activate venv
+        python -m pip install --upgrade pip
 
-        yum install -y rust cargo
+        # install recent version of Rust via rustup
+        curl https://sh.rustup.rs -sSf | sh -s -- --default-toolchain=stable --profile=default -y
+        . "$HOME/.cargo/env"
 
-        python -m pip install --upgrade pip
         pip install -e .[dev] -v
 
         pytest -v
diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py
index 5bed565f..98ad6aee 100644
--- a/torchft/process_group_test.py
+++ b/torchft/process_group_test.py
@@ -110,8 +110,8 @@ def test_dummy(self) -> None:
         m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
         m(torch.rand(2, 3))
 
-    @skipUnless(torch.cuda.is_available(), "needs CUDA")
-    def test_baby_nccl(self) -> None:
+    @skipUnless(torch.cuda.device_count() >= 2, "need two CUDA devices")
+    def test_baby_nccl_2gpu(self) -> None:
         store = TCPStore(
             host_name="localhost", port=0, is_master=True, wait_for_workers=False
         )