Skip to content

Commit

Permalink
Fix xla_bridge_test.py test failures.
Browse files Browse the repository at this point in the history
We are splitting the plugins in the enviroment variable using os.pathsep; we should make sure to use that as the separator in the test.
  • Loading branch information
hawkinsp committed Jun 16, 2023
1 parent ba10e94 commit a05ffab
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/wheel_win_x64.yml
Expand Up @@ -38,7 +38,7 @@ jobs:
run: |
python -m pip install -r build/test-requirements.txt
"C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH
python.exe build\build.py
python.exe build\build.py --bazel_options=--color=yes
- uses: actions/upload-artifact@v3
with:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/windows_ci.yml
Expand Up @@ -47,7 +47,7 @@ jobs:
cd jax
python -m pip install -r build/test-requirements.txt
"C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH
python.exe build\build.py ('--bazel_options=--override_repository=xla=${{ github.workspace }}\xla' -replace '\\','\\')
python.exe build\build.py ('--bazel_options=--override_repository=xla=${{ github.workspace }}\xla' -replace '\\','\\') --bazel_options=--color=yes
- uses: actions/upload-artifact@v3
with:
Expand Down
8 changes: 6 additions & 2 deletions tests/xla_bridge_test.py
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import os
import platform
import time
import warnings

Expand Down Expand Up @@ -92,7 +93,10 @@ def _mock_tpu_client():

def test_register_plugin(self):
with self.assertLogs(level="WARNING") as log_output:
os.environ['PJRT_NAMES_AND_LIBRARY_PATHS'] = "name1:path1,name2:path2,name3"
if platform.system() == "windows":
os.environ['PJRT_NAMES_AND_LIBRARY_PATHS'] = "name1;path1,name2;path2,name3"
else:
os.environ['PJRT_NAMES_AND_LIBRARY_PATHS'] = "name1:path1,name2:path2,name3"
xb.register_pjrt_plugin_factories_from_env()
client_factory, priotiy = xb._backend_factories["name1"]
with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make:
Expand All @@ -104,7 +108,7 @@ def test_register_plugin(self):
self.assertRegex(
log_output[1][0],
r"invalid value name3 in env var PJRT_NAMES_AND_LIBRARY_PATHS"
r" name1:path1,name2:path2,name3",
r" name1.path1,name2.path2,name3",
)
self.assertIn("name1", xb._backend_factories)
self.assertIn("name2", xb._backend_factories)
Expand Down

0 comments on commit a05ffab

Please sign in to comment.