Caveat: unfortunately, I have not been able to produce a minimal working example of the issue below, as I only encounter it deep inside a jitted training loop containing several nested vmaps.
Define

```python
from functools import partial
import jax.numpy as jnp
from jax import jit, vmap

@partial(jit, static_argnames=['A_fun'])
def fsai2(A_fun, G0):
    n = G0.shape[0]

    def _calc_G_i(i, G0_i):
        d_ii = jnp.dot(G0_i, A_fun(G0_i)) ** -0.5
        return d_ii

    G = vmap(_calc_G_i, (0, 0))(jnp.arange(n), G0)
    return G
```
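The matrices `A` and `K` used in the calls below are not defined in this report. A purely hypothetical setup they can be read against (assuming `A` is a small symmetric positive-definite matrix and `K` has the same shape) would be something like:

```python
# Hypothetical setup, not from the original report: A and K are assumed to be
# small SPD matrices of the same shape, only so the calls below are self-contained.
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
M = random.normal(key, (8, 8))
A = M @ M.T + 8 * jnp.eye(8)   # symmetric positive definite
K = A                          # K is only used for its shape via K.shape[0]
```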
then

```python
A_mvp = lambda b: A @ b
P = fsai2(A_mvp, jnp.eye(K.shape[0]))
```

gives `Segmentation fault (core dumped)`.
Similarly,

```python
from functools import partial

A_mvp = partial(jnp.matmul, A)
P = fsai2(A_mvp, jnp.eye(K.shape[0]))
```

gives `Segmentation fault (core dumped)`.
As does this:

```python
from jax.tree_util import Partial

A_mvp = Partial(jnp.matmul, A)
P = fsai2(A_mvp, jnp.eye(K.shape[0]))
```
However, the last block of code, with tree_util.Partial, works and gives the expected results if I remove the @partial(jit, static_argnames=['A_fun']) decorator from the fsai2 definition, i.e. this works:
```python
def fsai2(A_fun, G0):
    n = G0.shape[0]

    def _calc_G_i(i, G0_i):
        d_ii = jnp.dot(G0_i, A_fun(G0_i)) ** -0.5
        return d_ii

    G = vmap(_calc_G_i, (0, 0))(jnp.arange(n), G0)
    return G

A_mvp = Partial(jnp.matmul, A)
fsai2(A_mvp, jnp.eye(K.shape[0]))
```
I tried this last approach because it was mentioned in #1443. Possibly also relevant: #5609.
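For what it's worth, another variant I would expect to work (a sketch only, not something I have verified inside the full training loop) is to keep the jit decorator but drop static_argnames, since jax.tree_util.Partial instances are registered pytrees and can be passed as ordinary, non-static arguments:

```python
# Sketch only: keep jit, but pass A_fun as a regular (non-static) argument.
# jax.tree_util.Partial is a pytree, so the bound matrix A is traced as a leaf
# while jnp.matmul stays in the treedef.
from jax import jit, vmap
from jax.tree_util import Partial
import jax.numpy as jnp

@jit
def fsai2_nonstatic(A_fun, G0):
    n = G0.shape[0]

    def _calc_G_i(i, G0_i):
        return jnp.dot(G0_i, A_fun(G0_i)) ** -0.5

    return vmap(_calc_G_i, (0, 0))(jnp.arange(n), G0)

A_mvp = Partial(jnp.matmul, A)
P = fsai2_nonstatic(A_mvp, jnp.eye(K.shape[0]))
```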
What jax/jaxlib version are you using?
jax and jaxlib are both 0.4.1.
Which accelerator(s) are you using?
GPU
Additional system info
Python 3.9.12
NVIDIA GPU info
NVIDIA-SMI 525.60.13, Driver Version 525.60.13, CUDA Version 12.0
4x NVIDIA A100-SXM (81920 MiB each), all idle (0 MiB used, no running processes)
GPUs 0-2 in Exclusive Process compute mode, GPU 3 in Default mode; MIG disabled on all GPUs