## Model Gradient Test

By: Rebecca Gjini 

About: This notebook tests taking gradients through the entire SpeedyModel.

Notes for self:
1) Import requirements.
2) Set up a single run of the model from the default initial state.
3) Set up a wrapper function that creates the model and runs one time step (with `advance()`).
4) Maybe set up another wrapper that converts our `PhysicsState`/`PhysicsData` objects to the dinosaur state object.
5) Try computing gradients.

In [1]:
from jcm.model import SpeedyModel
import jax

# JIT stays ENABLED: jax_disable_jit=False. An earlier note suggested disabling
# JIT because of an issue in shortwave_radiation.py:169. Set this to True if
# that issue needs to be debugged again (expect much slower runs).
jax.config.update("jax_disable_jit", False)
# Check for infs so an inf-producing computation fails loudly instead of
# propagating. Author's note: this adds no noticeable run time here, because
# the time is otherwise spent computing the nodal quantities.
jax.config.update("jax_debug_infs", True)
# NaN checking stays off: some physics fields might legitimately be NaN.
jax.config.update("jax_debug_nans", False)

In [5]:
# Build a small 8-layer model and unroll it from its default initial state.
# Author's timing note: takes about 40 seconds on a laptop GPU.
# NOTE(review): the units of time_step / save_interval / total_time are not
# visible here (total_time=1/48 may be a fraction of a day) — TODO confirm.
model = SpeedyModel(
    time_step=30,
    save_interval=3,
    total_time=1 / 48,
    layers=8,
)
state = model.get_initial_state()
final_state, predictions = model.unroll(state)

# NOTE(review): post-processing is applied to the *initial* `state`, so the
# `final_state` / `predictions` produced by the unroll are not inspected here.
testing = model.post_process(state)
physics_data = testing["physics"]
print(physics_data)
print(type(physics_data))

PhysicsData(shortwave_rad=SWRadiationData(qcloud=Array([[1.9349747e-08, 8.3812452e-08, 1.6057885e-07, ..., 8.4109897e-06,
        1.4494001e-06, 2.5418058e-07],
       [1.9707100e-08, 8.3610104e-08, 1.5806894e-07, ..., 9.2049295e-06,
        1.5441968e-06, 2.6046598e-07],
       [2.0202636e-08, 8.3787363e-08, 1.5559841e-07, ..., 9.9404087e-06,
        1.6298264e-06, 2.6575592e-07],
       ...,
       [1.9136555e-08, 8.6413081e-08, 1.6877374e-07, ..., 5.9279378e-06,
        1.1334106e-06, 2.3024766e-07],
       [1.9063295e-08, 8.5250072e-08, 1.6587219e-07, ..., 6.7493606e-06,
        1.2416781e-06, 2.3897601e-07],
       [1.9134461e-08, 8.4369844e-08, 1.6316152e-07, ..., 7.5843404e-06,
        1.3477736e-06, 2.4698622e-07]], dtype=float32), fsol=Array([[555.0461 , 553.40027, 550.4418 , ...,   0.     ,   0.     ,
          0.     ],
       [555.0461 , 553.40027, 550.4418 , ...,   0.     ,   0.     ,
          0.     ],
       [555.0461 , 553.40027, 550.4418 , ...,   0.     ,   0.     ,
 

In [None]:
# TODO: gradient sketch. Once the updated model version is up to date, use this
# setup. The whole cell stays commented out, so Restart & Run All skips it.
#
# NOTE(review): problems in the original sketch, to fix before enabling:
#   * `vjp` is not imported in the import cell at the top; use `jax.vjp`.
#   * Per the JAX docs, `jax.vjp(f, *primals)` returns a tuple
#     `(primals_out, vjp_fn)`, so calling the result directly (`grad_jcm()`)
#     fails. `vjp_fn` must be called with a cotangent that has the same
#     structure as `primals_out`. It returns a tuple with one entry per primal.
#   * Both wrappers were originally named `derivative_model`, so the second
#     silently shadowed the first. They get distinct names below.
#   * The result was named `d_predictions_d_parameters`, but the vjp was taken
#     with respect to `state`. The names below follow the input actually used.
#   * `SpeedyModel(parameters)` is an assumed future signature — TODO confirm.
#
#
# # Option A: gradients with respect to model parameters
# def derivative_model_wrt_parameters(parameters):
#     model = SpeedyModel(
#         parameters
#     )
#     state = model.get_initial_state()
#     final_state, predictions = model.unroll(state)
#     return predictions
#
#
# # Option B: gradients with respect to the initial state
# def derivative_model_wrt_state(state):
#     model = SpeedyModel(
#     )
#     final_state, predictions = model.unroll(state)
#     return predictions
#
#
# # get an initial state
# state = SpeedyModel().get_initial_state()
#
#
# predictions, vjp_fn = jax.vjp(derivative_model_wrt_state, state)
# # Seed the backward pass with ones shaped like `predictions`.
# # NOTE(review): integer-valued leaves in `predictions`, if any, need special
# # cotangents — confirm the output pytree's dtypes first.
# cotangent = jax.tree_util.tree_map(jax.numpy.ones_like, predictions)
# (d_predictions_d_state,) = vjp_fn(cotangent)
