## Model Gradient Test

By: Rebecca Gjini 

About: This notebook tests taking the gradient through the entire SpeedyModel.  

Notes:
1) Shape of the output state goes from 3 dimensions to 4, with an extra dimension added to the beginning (of size 0)
2) `sim_time` in the dinosaur state object stored in the predictions dictionary is empty in the model output
3) `lfluxland` and `date.tyear` are also empty arrays in the predictions dictionary (not sure exactly why), so a `PhysicsData.ones` object does not work as the input to the gradient function unless those two fields are replaced with empty arrays

In [1]:
from jcm.model import SpeedyModel
import jax
import jax.numpy as jnp
from dinosaur import primitive_equations
from jcm.physics_data import PhysicsData

# JAX runtime flags used by this notebook.
# NOTE: False leaves JIT *enabled*. Set this flag to True to actually turn JIT off
# (the original comment here cited an issue in shortwave_radiation.py:169 as the reason to do so).
jax.config.update('jax_disable_jit', False)
# Error out as soon as an inf is produced. Author's note: doesn't add any time since the
# saved time is otherwise spent getting the nodal quantities.
jax.config.update("jax_debug_infs", True)
# NaN checking stays off because some physics fields might be nan
jax.config.update("jax_debug_nans", False)

In [2]:
# Create the model object and run it once to make sure a plain forward pass works
model = SpeedyModel(
    time_step=30,
    save_interval=3,
    total_time=(2/48.0),
    layers=8,
)  # takes 40 seconds on laptop gpu
state = model.get_initial_state()
final_state, predictions = model.unroll(state)
# Show the types of both outputs, one per line
print(type(final_state), type(predictions), sep='\n')
# Prints the return value of the shape validation for the final state
print(primitive_equations.validate_state_shape(final_state, model.coords))

<class 'dinosaur.primitive_equations.StateWithTime'>
<class 'dict'>
None


In [3]:
# Top-level entries of the predictions dictionary, then the type stored under each one
print(predictions.keys())
print(type(predictions['dynamics']), type(predictions['physics']), sep='\n')

dict_keys(['dynamics', 'physics'])
<class 'dinosaur.primitive_equations.StateWithTime'>
<class 'jcm.physics_data.PhysicsData'>


In [4]:
# Functions to create ones objects to feed into the jax gradient function
def make_ones_dinosaur_StateWithTime_object(state, choose_sim_time = jnp.float32(1.0)):
    """Build a ``StateWithTime`` filled with ones that mirrors ``state``.

    The result is intended to be fed to the function returned by ``jax.vjp``.

    Args:
        state: object exposing ``vorticity``, ``divergence``, ``temperature_variation``,
            ``log_surface_pressure`` and a ``tracers`` dict of arrays.
        choose_sim_time: value stored as ``sim_time`` in the result
            (defaults to ``jnp.float32(1.0)``).

    Returns:
        A ``primitive_equations.StateWithTime`` whose array fields are ``jnp.ones_like``
        of the matching fields of ``state`` and whose ``sim_time`` is ``choose_sim_time``.
    """
    # Mirror every tracer carried by `state` (not only 'specific_humidity') so the
    # result has the same tracer keys as the state it was built from.
    tracers = {name: jnp.ones_like(values) for name, values in state.tracers.items()}
    return primitive_equations.StateWithTime(
        vorticity=jnp.ones_like(state.vorticity),
        divergence=jnp.ones_like(state.divergence),
        temperature_variation=jnp.ones_like(state.temperature_variation),
        log_surface_pressure=jnp.ones_like(state.log_surface_pressure),
        tracers=tracers,
        sim_time=choose_sim_time,
    )

def make_ones_prediction_object(pred): 
    """Build a ones-filled counterpart of the model's predictions dictionary.

    Args:
        pred: dictionary with a 'dynamics' entry (passed on to
            ``make_ones_dinosaur_StateWithTime_object``) and a 'physics' entry whose
            ``shortwave_rad.dfabs`` array has four dimensions.

    Returns:
        Dictionary with the same two keys: 'dynamics' holds the ones state built with an
        empty ``sim_time``, 'physics' holds a ``PhysicsData.ones`` object whose
        ``surface_flux.lfluxland`` and ``date.tyear`` are replaced by empty arrays.
    """
    # The first three dims form the shape argument, the last one is passed separately
    lead_dim, nx, nl, n_levels = pred['physics'].shortwave_rad.dfabs.shape
    ones_physics = PhysicsData.ones((lead_dim, nx, nl), n_levels)
    # These two fields come back as empty arrays in the model output, so mirror that here
    ones_physics.surface_flux.lfluxland = jnp.array([])
    ones_physics.date.tyear = jnp.array([])
    ones_dynamics = make_ones_dinosaur_StateWithTime_object(pred['dynamics'], jnp.array([]))
    return {'dynamics': ones_dynamics, 'physics': ones_physics}

In [5]:
# jax.vjp runs model.unroll(state) once and returns two things:
#   primals -> the forward output, i.e. the (final_state, predictions) pair from model.unroll
#   f_vjp   -> a function that takes a cotangent with the same structure as `primals` and
#              returns a tuple holding the gradient with respect to `state`
primals, f_vjp = jax.vjp(model.unroll, state) 

In [6]:
# Calculate gradient with respect to the state.
# The cotangent mirrors the (final_state, predictions) structure returned by model.unroll.
# It is named `cotangent` rather than `input` so the Python builtin input() is not shadowed
# for the rest of the kernel session.
cotangent = (make_ones_dinosaur_StateWithTime_object(primals[0]), make_ones_prediction_object(primals[1]))
# f_vjp returns a tuple with one entry per differentiated argument; entry 0 belongs to `state`
df_dstate = f_vjp(cotangent)

In [7]:
# Check each array field of the state gradient for NaN entries
print('Is the vorticity nan?', jnp.isnan(df_dstate[0].vorticity).any())
print('Is the divergence nan?', jnp.isnan(df_dstate[0].divergence).any())
print('Is the temperature variation nan?', jnp.isnan(df_dstate[0].temperature_variation).any())
print('Is the log surcafe pressure nan?', jnp.isnan(df_dstate[0].log_surface_pressure).any())

Is the vorticity nan? False
Is the divergence nan? False
Is the temperature variation nan? False
Is the log surcafe pressure nan? False


In [8]:
# The tracer gradients live in a dict; show its type and keys, then check for NaN entries
print(type(df_dstate[0].tracers))
print(df_dstate[0].tracers.keys())
print('Is the specific humidity nan?', jnp.isnan(df_dstate[0].tracers['specific_humidity']).any())

<class 'dict'>
dict_keys(['specific_humidity'])
Is the specific humidity nan? False


In [9]:
# Same NaN check for the sim_time entry of the state gradient
print('Is the simulation time nan?', jnp.isnan(df_dstate[0].sim_time).any())

Is the simulation time nan? False


In [None]:
# TODO: once the updated model version is available, switch to this setup (sketch only, this cell has not been run)
# def derivative_model(parameters):
#     model = SpeedyModel(
#         parameters
#     )
    
#     state = model.get_initial_state()
            
#     final_state, predictions = model.unroll(state)
#     return predictions


# def derivative_model(state):
#     model = SpeedyModel(
#     )
    
#     final_state, predictions = model.unroll(state)
#     return predictions


# # get an initial state
# state = SpeedyModel().get_initial_state()


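# jax.vjp returns (outputs, pullback); the pullback needs a cotangent shaped like the outputs
# predictions, grad_jcm = jax.vjp(derivative_model, state)
# d_predictions_d_state = grad_jcm(cotangent)[0]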
