## Model Gradient Test

By: Rebecca Gjini 

About: This notebook tests taking gradients through the entire SpeedyModel.

Notes:
1) Shape of the output state goes from 3 dimensions to 4, with an extra dimension added to the beginning (of size 0)
2) `sim_time` in the dinosaur object stored in the `predictions` dictionary is empty in the model output.
3) `lfluxland` and `date.tyear` are also empty arrays in the `predictions` dictionary (the cause is not yet clear). Because of this, a `PhysicsData` object of ones cannot be used as the input to the gradient function.

In [None]:
from jcm.model import Model
import jax
import jax.numpy as jnp
from dinosaur import primitive_equations
from jcm.physics.speedy.params import Parameters

# JIT stays ENABLED: passing False to `jax_disable_jit` does not turn JIT off.
# If the issue in shortwave_radiation.py:169 needs to be worked around or
# debugged without compilation, set this to True.
jax.config.update('jax_disable_jit', False)
# Raise as soon as an inf is produced. This doesn't add any time, since that
# time is otherwise spent getting the nodal quantities.
jax.config.update("jax_debug_infs", True)
# NaN checking stays off because some physics fields might be NaN.
jax.config.update("jax_debug_nans", False)

In [None]:
# Smoke test: build the model, keep its initial state for the gradient tests
# below, and run a short forward simulation.
model = Model()
state = model._prepare_initial_state()
# NOTE(review): the time arguments are presumably in days (1/48 and 1/24);
# confirm against Model.run.
save_every = 1.0 / 48.0
run_length = 1.0 / 24.0
predictions = model.run(save_interval=save_every, total_time=run_length)
# Model.run leaves the final state on the model object.
final_state = model._final_state_internal
for run_result in (final_state, predictions):
    print(type(run_result))
shape_check = primitive_equations.validate_state_shape(final_state, model.coords)
print(shape_check)

<class 'dinosaur.primitive_equations.State'>
<class 'dict'>
None


In [None]:
# `predictions` is a plain dict (its type is printed in the cell above), so its
# entries must be read by key; attribute access raises AttributeError.
# NOTE(review): assumes the keys are 'dynamics' and 'physics'; confirm against
# the keys printed on the first line.
print(predictions.keys())
print(type(predictions['dynamics']))
print(type(predictions['physics']))

In [None]:
def fn(state):
    """Run the model from `state` and return (final state, run output).

    This is the function differentiated by `jax.vjp` below.
    """
    run_output = model.run(initial_state=state)
    # Model.run stores its final state on the model object; read it back so it
    # is part of the differentiated output.
    # NOTE(review): after jax.vjp returns, model._final_state_internal likely
    # still holds the value assigned during tracing. Avoid reusing it outside
    # this function.
    return model._final_state_internal, run_output


# Build the VJP function and the primal outputs in one call.
primals, f_vjp = jax.vjp(fn, state)

In [None]:
from jcm.utils import ones_like

# Calculate the gradient with respect to the state. The VJP is seeded with a
# ones cotangent for each output of `fn` (final state, predictions). The name
# `cotangents` avoids shadowing the Python builtin `input`.
cotangents = (ones_like(primals[0]), ones_like(primals[1]))
# f_vjp returns a tuple with one entry per input of `fn`, so df_dstate[0] is
# the gradient with respect to `state`.
df_dstate = f_vjp(cotangents)

In [None]:
# Report whether any entry of each prognostic-field gradient is NaN.
# df_dstate[0] is the gradient with respect to the initial `state`.
state_grad = df_dstate[0]
for field_label, field_grad in (
    ('vorticity', state_grad.vorticity),
    ('divergence', state_grad.divergence),
    ('temperature variation', state_grad.temperature_variation),
    ('log surface pressure', state_grad.log_surface_pressure),
):
    print(f'Is the {field_label} nan?', jnp.any(jnp.isnan(field_grad)))

In [None]:
# Inspect the tracer gradients and check specific humidity for NaNs.
state_tracers = df_dstate[0].tracers
print(type(state_tracers))
print(state_tracers.keys())
humidity_grad = state_tracers['specific_humidity']
print('Is the specific humidity nan?', jnp.any(jnp.isnan(humidity_grad)))

In [None]:
# The state gradient also has a sim_time entry; check it for NaNs as well.
sim_time_grad = df_dstate[0].sim_time
print('Is the simulation time nan?', jnp.any(jnp.isnan(sim_time_grad)))

### Gradients with respect to the model parameters (the rest of this notebook)

In [None]:
from jcm.physics.speedy.speedy_physics import SpeedyPhysics


def run_model_wrt_parameters(parameters):
    """Run the forward model with the given physics parameters.

    A fresh model is built from `parameters` and run from the notebook-level
    `state`. This is the function differentiated by `jax.vjp` with respect to
    the parameters.

    Args:
        parameters: a `Parameters` object passed to `SpeedyPhysics`.

    Returns:
        Tuple of (final model state, predictions returned by `Model.run`).
    """
    # A separate local model, so the notebook-level `model` is not shadowed.
    param_model = Model(physics=SpeedyPhysics(parameters=parameters))
    # Pass the initial state by keyword, matching the other Model.run calls in
    # this notebook. Passed positionally, it would bind to whatever Model.run's
    # first parameter is.
    run_output = param_model.run(initial_state=state)
    return param_model._final_state_internal, run_output


In [None]:
# Build the Parameters object from its default values. The parameter gradients
# below are taken at this point.
parameters = Parameters.default()

In [None]:
# Take the VJP with respect to the default parameters. jcm_primals is the
# (final state, predictions) output of run_model_wrt_parameters. grad_jcm maps
# output cotangents to a tuple holding the gradient with respect to `parameters`.
jcm_primals, grad_jcm = jax.vjp(run_model_wrt_parameters, parameters)

In [None]:
# Seed the VJP with a ones cotangent for each output (final state,
# predictions), built the same way as the cotangents for the state gradient.
# The name avoids shadowing the Python builtin `input`.
param_cotangents = (ones_like(jcm_primals[0]), ones_like(jcm_primals[1]))
# grad_jcm returns a tuple with one entry per differentiated input, so
# df_dparameters[0] holds the gradient with respect to `parameters`.
df_dparameters = grad_jcm(param_cotangents)