lambda function argument leads to segfault; Partial alternative works #14978

Open
HHalva opened this issue Mar 14, 2023 · 1 comment
Labels
bug Something isn't working

Comments


HHalva commented Mar 14, 2023

Description

Caveat: unfortunately, I have not been able to produce a minimal working example of the issue below, as I only encounter it deep inside a jitted training loop containing several nested vmaps.

Define

    from functools import partial

    import jax.numpy as jnp
    from jax import jit, vmap

    @partial(jit, static_argnames=['A_fun'])
    def fsai2(A_fun, G0):
        n = G0.shape[0]

        # per-row computation: d_ii = (g_i . A_fun(g_i))**-0.5
        def _calc_G_i(i, G0_i):
            d_ii = jnp.dot(G0_i, A_fun(G0_i))**-0.5
            return d_ii

        G = vmap(_calc_G_i, (0, 0))(jnp.arange(n), G0)
        return G

then

    A_mvp = lambda b: A @ b
    P = fsai2(A_mvp, jnp.eye(K.shape[0]))

gives
Segmentation fault (core dumped).
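
As an aside, my understanding (not verified here) is that jit keys its compilation cache on the hash/equality of static arguments, and a lambda hashes and compares by object identity, so every freshly created lambda registers as a brand-new static value. A minimal sketch of that identity behaviour, which motivates why lambdas are awkward static arguments, though it does not by itself explain a segfault:

    # Two structurally identical lambdas are still distinct objects, so they
    # never compare equal and a jit cache keyed on them misses every time.
    f1 = lambda b: b
    f2 = lambda b: b
    print(f1 == f2)  # False: distinct objects
    print(f1 == f1)  # True: only the very same object compares equal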

Similarly,

    from functools import partial

    A_mvp = partial(jnp.matmul, A)
    P = fsai2(A_mvp, jnp.eye(K.shape[0]))

gives
Segmentation fault (core dumped).

As does this:

    from jax.tree_util import Partial

    A_mvp = Partial(jnp.matmul, A)
    P = fsai2(A_mvp, jnp.eye(K.shape[0]))

However, the last block of code, with tree_util.Partial, works and gives the expected results if I remove the @partial(jit, static_argnames=['A_fun']) decorator from the fsai2 definition, i.e. this works:

    def fsai2(A_fun, G0):
        n = G0.shape[0]

        def _calc_G_i(i, G0_i):
            d_ii = jnp.dot(G0_i, A_fun(G0_i))**-0.5
            return d_ii

        G = vmap(_calc_G_i, (0, 0))(jnp.arange(n), G0)
        return G


    A_mvp = Partial(jnp.matmul, A)
    fsai2(A_mvp, jnp.eye(K.shape[0]))

I tried this last approach because it was mentioned in #1443. Possibly also relevant: #5609.
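
For reference, a minimal sketch of the pattern suggested there (standard jax.tree_util behaviour, not code from this issue): because Partial is registered as a pytree, it can be passed through jit as an ordinary dynamic argument, with no static_argnames at all:

    # Minimal sketch, assuming standard Partial semantics: the wrapped array A
    # travels as a traced pytree leaf, while jnp.matmul lives in the treedef.
    import jax
    import jax.numpy as jnp
    from jax.tree_util import Partial

    @jax.jit
    def apply_fun(fun, x):
        return fun(x)

    A = jnp.eye(3)
    A_mvp = Partial(jnp.matmul, A)
    print(apply_fun(A_mvp, jnp.ones(3)))  # runs under jit, no static_argnames

The design point is that the callable part of a Partial is hashable metadata in the treedef, while its bound arrays remain traced values.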

What jax/jaxlib version are you using?

0.4.1 for both jax and jaxlib

Which accelerator(s) are you using?

GPU

Additional system info

Python 3.9.12

NVIDIA GPU info

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.60.13    Driver Version: 525.60.13    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-SXM...   On  | 00000000:01:00.0 Off |                    0 |
| N/A   36C    P0    62W / 500W |      0MiB / 81920MiB |      0%   E. Process |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM...   On  | 00000000:41:00.0 Off |                    0 |
| N/A   36C    P0    60W / 500W |      0MiB / 81920MiB |      0%   E. Process |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   2  NVIDIA A100-SXM...   On  | 00000000:81:00.0 Off |                    0 |
| N/A   36C    P0    58W / 500W |      0MiB / 81920MiB |      0%   E. Process |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   3  NVIDIA A100-SXM...   On  | 00000000:C1:00.0 Off |                    0 |
| N/A   35C    P0    61W / 500W |      0MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

@HHalva HHalva added the bug Something isn't working label Mar 14, 2023
@HHalva HHalva closed this as completed Mar 14, 2023
@HHalva HHalva reopened this Mar 14, 2023
@HHalva HHalva closed this as completed Mar 14, 2023

HHalva commented Mar 14, 2023

I spoke too soon; the issue still exists.

@HHalva HHalva reopened this Mar 14, 2023