## Model Gradient Test

By: Rebecca Gjini 

About: This notebook will cover testing taking the gradient over the entire SpeedyModel.  

Notes:
1) Shape of the output state goes from 3 dimensions to 4, with an extra dimension added to the beginning (of size 0)
2) sim_time in the dinosaur object in the preditions dictionary object is empty in the model output
3) lfluxland and date.tyear are also empty arrays in the predictions dictionary object (not sure exactly why) and this makes our PhysicsData ones object not work as the input to the gradient function

In [1]:
from jcm.model import SpeedyModel
import jax
import jax.numpy as jnp
from dinosaur import primitive_equations
from jcm.physics_data import PhysicsData

jax.config.update('jax_disable_jit', False) # Turn off JIT because of an issue in shortwave_radiation.py:169
jax.config.update("jax_debug_infs", True) # doesn't add any time since the saved time is otherwise spent getting the nodal quantities
jax.config.update("jax_debug_nans", False) # some physics fields might be nan

In [None]:
model = SpeedyModel(time_step=30, save_interval=3, total_time=(2/48.0), layers=8) # takes 40 seconds on laptop gpu
state = model.get_initial_state()
final_state, predictions = model.unroll(state)
print(type(final_state))
print(type(predictions))
print(primitive_equations.validate_state_shape(final_state, model.coords))

<class 'dinosaur.primitive_equations.StateWithTime'>
<class 'dict'>
None
(1, 85, 44)


In [50]:
print(predictions.keys())
print(type(predictions['dynamics']))
print(type(predictions['physics']))

dict_keys(['dynamics', 'physics'])
<class 'dinosaur.primitive_equations.StateWithTime'>
<class 'jcm.physics_data.PhysicsData'>


In [47]:
def make_ones_dinosaur_StateWithTime_object(state, choose_sim_time = jnp.float32(1.0)):
    vorticity = jnp.ones_like(state.vorticity)
    divergence = jnp.ones_like(state.divergence)
    temperature_variation = jnp.ones_like(state.temperature_variation)
    log_surface_pressure = jnp.ones_like(state.log_surface_pressure)
    tracers = {'specific_humidity' : jnp.ones_like(state.tracers['specific_humidity'])}
    sim_time = choose_sim_time
    return primitive_equations.StateWithTime(vorticity = vorticity, divergence = divergence,
                                             temperature_variation = temperature_variation,
                                             log_surface_pressure = log_surface_pressure,
                                             tracers = tracers, sim_time = sim_time)

def make_ones_prediction_object(pred): 
    (additional, ix, il, kx) = pred['physics'].shortwave_rad.dfabs.shape
    physics_data = PhysicsData.ones((additional, ix, il), kx)
    physics_data.surface_flux.lfluxland = jnp.array([])
    physics_data.date.tyear = jnp.array([])
    return{'dynamics': make_ones_dinosaur_StateWithTime_object(pred['dynamics'], jnp.array([])), 
           'physics' : physics_data}  

In [5]:
test1 = make_ones_dinosaur_StateWithTime_object(final_state)
test2 = make_ones_prediction_object(predictions)

In [6]:
primals, f_vjp = jax.vjp(model.unroll, state) 

In [48]:
input = (make_ones_dinosaur_StateWithTime_object(primals[0]), make_ones_prediction_object(primals[1]))
df_dstate = f_vjp(input) 

In [49]:
print(make_ones_prediction_object(primals[1])['physics'].date)
print(primals[1]['physics'].date)

DateData(tyear=Array([], shape=(0,), dtype=float32))
DateData(tyear=Array([], shape=(0,), dtype=float32))


In [5]:
print(jnp.isnan(df_dstate[0].vorticity))
print(jnp.isnan(df_dstate[0].divergence))
print(jnp.isnan(df_dstate[0].temperature_variation))
print(jnp.isnan(df_dstate[0].log_surface_pressure))

[[[False False False ... False False False]
  [False False False ... False False False]
  [False False False ... False False False]
  ...
  [False False False ... False False False]
  [False False False ... False False False]
  [False False False ... False False False]]

 [[False False False ... False False False]
  [False False False ... False False False]
  [False False False ... False False False]
  ...
  [False False False ... False False False]
  [False False False ... False False False]
  [False False False ... False False False]]

 [[False False False ... False False False]
  [False False False ... False False False]
  [False False False ... False False False]
  ...
  [False False False ... False False False]
  [False False False ... False False False]
  [False False False ... False False False]]

 ...

 [[False False False ... False False False]
  [False False False ... False False False]
  [False False False ... False False False]
  ...
  [False False False ... False False Fal

In [6]:
print(type(df_dstate[0].tracers))
print(df_dstate[0].tracers.keys())
print(df_dstate[0].tracers['specific_humidity'])

<class 'dict'>
dict_keys(['specific_humidity'])
[[[ 3.5213697e-04  3.8431087e-04  8.8828390e-05 ... -8.0235037e-12
   -2.7421168e-11 -3.6952073e-11]
  [ 0.0000000e+00 -4.3038293e-04 -5.9442065e-04 ... -2.4946635e-11
   -6.0831808e-12  1.8411803e-11]
  [ 0.0000000e+00 -1.5664661e-04 -2.1635143e-04 ... -7.1181433e-12
   -1.0811508e-12  6.9932363e-12]
  ...
  [ 0.0000000e+00  0.0000000e+00  0.0000000e+00 ... -1.6312738e-14
   -5.7825530e-14 -1.4095756e-13]
  [ 0.0000000e+00  0.0000000e+00  0.0000000e+00 ...  0.0000000e+00
    2.8841819e-15  2.6970626e-15]
  [ 0.0000000e+00  0.0000000e+00  0.0000000e+00 ...  0.0000000e+00
    2.5424914e-15  1.2648489e-14]]

 [[ 3.5213697e-04  3.8431087e-04  8.8828390e-05 ... -8.0235037e-12
   -2.7421168e-11 -3.6952073e-11]
  [ 0.0000000e+00 -4.3038293e-04 -5.9442065e-04 ... -2.4946635e-11
   -6.0831808e-12  1.8411803e-11]
  [ 0.0000000e+00 -1.5664661e-04 -2.1635143e-04 ... -7.1181433e-12
   -1.0811508e-12  6.9932363e-12]
  ...
  [ 0.0000000e+00  0.0000000e

In [7]:
print(df_dstate[0].sim_time)

0.0


In [None]:
# Once the updated model version is up to date, use this setup
# def derivative_model(parameters):
#     model = SpeedyModel(
#         parameters
#     )
    
#     state = model.get_initial_state()
            
#     final_state, predictions = model.unroll(state)
#     return predictions


# def derivative_model(state):
#     model = SpeedyModel(
#     )
    
#     final_state, predictions = model.unroll(state)
#     return predictions


# # get an initial state
# state = SpeedyModel().get_initial_state()


# grad_jcm = vjp(derivative_model, state)
# d_predictions_d_parameters = grad_jcm()
