Skip to content

Commit

Permalink
CI: Update deprecated actions, don't measure network RPS (#215)
Browse files Browse the repository at this point in the history
* CI: Switch to actions/cache@v3 (v2 is deprecated)
* Don't run measure_network_rps() in tests since it doesn't work well in
CI
  • Loading branch information
borzunov committed Jan 13, 2023
1 parent 825f5db commit 702bb5a
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/run-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
uses: actions/checkout@v2
- name: Check if the model is cached
id: cache-model
uses: actions/cache@v2
uses: actions/cache@v3
with:
path: ~/.dummy
key: model-v1-${{ hashFiles('setup.cfg', 'src/petals/cli/convert_model.py') }}
Expand All @@ -27,7 +27,7 @@ jobs:
python-version: 3.9
- name: Cache dependencies
if: steps.cache-model.outputs.cache-hit != 'true'
uses: actions/cache@v2
uses: actions/cache@v3
with:
path: ~/.cache/pip
key: Key-v1-3.9-${{ hashFiles('setup.cfg') }}
Expand Down Expand Up @@ -70,7 +70,7 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Cache dependencies
uses: actions/cache@v2
uses: actions/cache@v3
with:
path: ~/.cache/pip
key: Key-v1-${{ matrix.python-version }}-${{ hashFiles('setup.cfg') }}
Expand Down
4 changes: 1 addition & 3 deletions tests/test_aux_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

@pytest.mark.forked
@pytest.mark.parametrize("tensor_parallel", [False, True])
def test_throughput_basic(tensor_parallel: bool):
def test_compute_throughput(tensor_parallel: bool):
config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
tensor_parallel_devices = ("cpu", "cpu") if tensor_parallel else ()
compute_rps = measure_compute_rps(
Expand All @@ -20,5 +20,3 @@ def test_throughput_basic(tensor_parallel: bool):
n_steps=10,
)
assert isinstance(compute_rps, float) and compute_rps > 0
network_rps = measure_network_rps(config)
assert isinstance(network_rps, float) and network_rps > 0

0 comments on commit 702bb5a

Please sign in to comment.