Gradients are nan for reasonable parameters in quasisep.Carma (pull #90) #164
Replies: 2 comments
-
@nevencaplar I can confirm the problem comes from the …
-
I have not had enough time to solve it, but I am pretty sure the problems come from `h2 = jnp.sqrt(h2_2)` and from the square root in `h1 = (c * h2 - jnp.sqrt(a * d2 - s2 * h2_2)) / (d + eta * real_mask)`. I presume the arguments of the square roots become numerically negative when the values are small (they are all 0. in the examples I tried), and this produces the NaNs. I will continue looking into it tomorrow. P.S. For the longest time I thought it was connected to this issue: https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where, because …
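The suspected failure mode can be reproduced in a few lines of plain JAX, independent of the pull itself (`guarded` and `safe` below are illustrative names, not code from the pull):

```python
import jax
import jax.numpy as jnp

# Case 1: the quantity under the root underflows to a tiny negative
# number, so sqrt returns nan and so does its gradient.
print(jax.grad(jnp.sqrt)(-1e-12))  # nan

# Case 2 (the FAQ issue linked above): guarding with jnp.where does NOT
# fix the gradient, because the untaken sqrt branch still contributes a
# 0 * inf = nan term in the backward pass.
def guarded(x):
    return jnp.where(x > 0.0, jnp.sqrt(x), 0.0)

print(jax.grad(guarded)(0.0))  # nan

# The standard fix is the "double where": make the argument safe before
# taking the root, so the untaken branch never produces inf or nan.
def safe(x):
    x_safe = jnp.where(x > 0.0, x, 1.0)
    return jnp.where(x > 0.0, jnp.sqrt(x_safe), 0.0)

print(jax.grad(safe)(0.0))  # 0.0
```

If the roots in the CARMA code can receive arguments at or below zero, they would likely need the same double-`where` treatment.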
-
@ywx649999311 and others, I hope you can have a look. I have been testing pull #90. As one of the tests, I generated a simple light curve with a damped random walk process in EzTao and tried to fit it with tinygp. I am unable to obtain a finite gradient of the model for any reasonable parameters, which leads to wrong solutions when trying to find the optimal parameters. I also tried slightly more complex models (CARMA(2,0)), but I seem to have similar problems. The same problem occurs without applying jit. Please find below the most minimal example that I have managed to create.
I hesitate to open this as a bug, since this is still only a pull request, and it could be that I am doing something wrong. I would greatly appreciate your help.
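For reference, the general shape of the diagnostic described above (checking whether any gradient entry is NaN) can be sketched without tinygp or EzTao; the `loss` below is a hypothetical stand-in for the GP log-likelihood, not the actual example from this thread:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in loss: an exponential-decay least-squares fit,
# used only to illustrate the nan check on a gradient pytree.
def loss(params, t, y):
    model = params["amp"] * jnp.exp(-t / params["tau"])
    return jnp.sum((y - model) ** 2)

t = jnp.linspace(0.0, 10.0, 50)
y = jnp.exp(-t / 3.0)
params = {"amp": 1.0, "tau": 3.0}

grads = jax.grad(loss)(params, t, y)

# Flatten the gradient pytree and test every leaf for NaNs.
has_nan = any(bool(jnp.isnan(g).any()) for g in jax.tree_util.tree_leaves(grads))
print(has_nan)  # False for this well-behaved loss
```

When a NaN does show up, enabling `jax.config.update("jax_debug_nans", True)` makes JAX raise at the first operation that produces one, which helps locate the offending line.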