Skip to content

Commit

Permalink
Merge pull request #21202 from superbobry:pallas
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 633176367
  • Loading branch information
jax authors committed May 13, 2024
2 parents 1fed784 + 8094d0d commit 54ca3d4
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions tests/pallas/pallas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import functools
import itertools
import os
import sys
import unittest

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"
Expand All @@ -33,7 +34,6 @@
from jax._src.lax.control_flow.for_loop import for_loop
from jax._src.pallas.pallas_call import _trace_to_jaxpr
from jax.experimental import pallas as pl
from jax.experimental.pallas import gpu as plgpu
from jax.experimental.pallas.ops import attention
from jax.experimental.pallas.ops import layer_norm
from jax.experimental.pallas.ops import rms_norm
Expand All @@ -42,6 +42,11 @@
import jax.numpy as jnp
import numpy as np

if sys.platform != "win32":
from jax.experimental.pallas import gpu as plgpu
else:
plgpu = None


# TODO(sharadmv): Update signatures of pallas_call to correct inputs/outputs.
# pylint: disable=no-value-for-parameter
Expand Down Expand Up @@ -129,7 +134,9 @@ def setUp(self):
self.skipTest("Only works on GPU")
if (jtu.test_device_matches(["cuda"]) and
not jtu.is_cuda_compute_capability_at_least("8.0")):
self.skipTest("Only works on GPUs with capability >= sm80")
self.skipTest("Only works on GPU with capability >= sm80")
if sys.platform == "win32":
self.skipTest("Only works on non-Windows platforms")

super().setUp()
_trace_to_jaxpr.cache_clear()
Expand Down

0 comments on commit 54ca3d4

Please sign in to comment.