In [7]:
import jax
import jax.numpy as jnp
from jax import grad
from jax import vjp
from jcm.humidity import get_qsat, rel_hum_to_spec_hum, spec_hum_to_rel_hum
from jcm.physics_data import ConvectionData
from jcm.physics import PhysicsData, PhysicsState
from jcm.params import kx 

# Note:  fix the functions because they're gonna be different now
# Use the unit tests to see how to run the functions and then use those to rework the 
# testing notebooks that I made
# Easy to: some just update the physics, ex: clouds might be good to try (gradient should
# always be 0)
# Try not to do 64 bit

### Notes about additional JAX capabilities
- jax.jvp is the forward differentiation method, but I don't really know how it works or what we would use it for
- jax.vjp is good... but it sums over the derivative values for each function partial, so we can only get the gradient with respect to the parameters for the entire function outputs
-jax.jacobian I think will give us all the partials and we can use that if we need it
- I'm assuming we will be taking gradients of cost functions (or forward functions that concatinate their output into a single vector)

In [5]:
#Testing functions in humidity.py

def get_qsat_gradient(f, ta, ps, sig): 
    primals, fun_vjp = vjp(f, ta, ps, sig)
    input = jnp.ones((ta.shape))
    dqsat_dta, dqsat_dps, dqsat_dsig = fun_vjp(input)
    return dqsat_dta, dqsat_dps, dqsat_dsig

def get_rel_hum_to_spec_hum(f, ta, ps, sig, rh): 
    primals, fun_vjp = vjp(f, ta, ps, sig, rh)
    input = (jnp.ones(ta.shape), jnp.ones(ta.shape))
    df_dta, df_dps, df_dsig, df_drh = fun_vjp(input)
    return df_dta, df_dps, df_dsig, df_drh




In [27]:
#Testing values for input variables
# Initialize test parameters here if needed
ta_ = jnp.ones((96,48,8))*273.0
ps_ = jnp.ones((96,48))*0.5  #supposed to be 2D 
sig_ = 4  #Going to assume that we need sigma to be an integer bc it gets passed in as 
          #an index... can always change it later....  
qg_ = jnp.ones((96,48,8))*2.0

convection_data = ConvectionData((96,48), kx, psa=ps_)
physics_data = PhysicsData((96,48), kx, convection=convection_data)
rh_ = physics_data.humidity.rh[:,:,0]

dqsat_dta, dqsat_dps, dqsat_dsig = get_qsat_gradient(get_qsat, ta_[:,:,sig_], ps_, sig_)
print((dqsat_dta, dqsat_dps, dqsat_dsig))

df_dta, df_dps, df_dsig, _ = get_rel_hum_to_spec_hum(rel_hum_to_spec_hum, ta_[:,:,0], ps_, sig_, rh_)
print(df_dta)

[[0.15499867 0.15499867 0.15499867 ... 0.15499867 0.15499867 0.15499867]
 [0.15499867 0.15499867 0.15499867 ... 0.15499867 0.15499867 0.15499867]
 [0.15499867 0.15499867 0.15499867 ... 0.15499867 0.15499867 0.15499867]
 ...
 [0.15499867 0.15499867 0.15499867 ... 0.15499867 0.15499867 0.15499867]
 [0.15499867 0.15499867 0.15499867 ... 0.15499867 0.15499867 0.15499867]
 [0.15499867 0.15499867 0.15499867 ... 0.15499867 0.15499867 0.15499867]]


In [27]:
####### Mini test functions ########
def my_func(y, x): 
    return y**2*x**2, y**2  #When there are multiple outputs, gradient gets summed over 
                         #Might need to you forward auto-differentiation?
                         #Maybe also check old laptop with other examples of using jax.vjp

def grad_my_func(f, y, x): 
    primals, f_vjp = vjp(f, y, x)   #primals are the values of the function evaluated at the variables
    print(primals)
    input = (jnp.ones(y.shape), jnp.ones(x.shape))
    ybar, xbar= f_vjp(input)
    return ybar, xbar  #so here I think it's summing over all the partials for y 
                       # and all the partials for x and returns that?

xx = 3*jnp.ones(3)
yy = 3*jnp.ones(3)
ybar, xbar = grad_my_func(my_func, yy, xx)

print(jax.jvp(my_func, (yy, xx), my_func(yy, xx)))
print(ybar)
print(xbar)

(Array([81., 81., 81.], dtype=float32), Array([9., 9., 9.], dtype=float32))
((Array([81., 81., 81.], dtype=float32), Array([9., 9., 9.], dtype=float32)), (Array([4860., 4860., 4860.], dtype=float32), Array([486., 486., 486.], dtype=float32)))
[60. 60. 60.]
[54. 54. 54.]


In [12]:
####### Mini test functions ########
def my_func(x): 
    return x**2, x[:-1]

def grad_my_func(f, x): 
    primals, f_vjp = vjp(f, x)
    print(primals)
    input = jnp.ones(x.shape), jnp.ones(x[:-1].shape)
    xbar = f_vjp(input)
    return xbar

xx = 2*jnp.ones(3)
yy = jnp.ones(3)
xbar = grad_my_func(my_func, xx)
print(xbar)

#Notes: Figure out how to get the gradient of both function outputs with respect to the input variable

(Array([4., 4., 4.], dtype=float32), Array([2., 2.], dtype=float32))
(Array([5., 5., 4.], dtype=float32),)


In [13]:
##### Nested Functions working ##### 

def g(x, y): 
    return 2*my_func(x, y)

def grad_g(x, y): 
    primals, f_vjp = vjp(g, x, y)
    print(primals)
    input = jnp.ones(x.shape)
    xbar, ybar = f_vjp(input)
    return xbar, ybar

xbar, ybar = grad_g(xx, yy)
print(xbar)
print(ybar)


TypeError: my_func() takes 1 positional argument but 2 were given