### Testing optimization setup
By: Rebecca Gjini
03/21/2025

In [1]:
from jcm.model import SpeedyModel
import jax
import jax.numpy as jnp
from jcm.params import Parameters
import jcm.optimization as opt

In [2]:
# Run the forward model with the true parameters to produce the reference run
true_params = Parameters.default()  # True parameters are the default parameters
model = opt.create_model(true_params) # Run length is configured inside opt.create_model (original note: 5 days — TODO confirm)
state = model.get_initial_state()  # Initial state; also passed to the gradient and loss calls in later cells
final_state, predictions = model.unroll(state)  # Run the model forward from the initial state

In [None]:
# Build the synthetic observations from the reference run:
# one scalar per output variable, taken as the mean over all dimensions.
xr_predictions = model.predictions_to_xarray(predictions)
means = {
    name: field.mean().values
    for name, field in xr_predictions.data_vars.items()
}
y = jnp.array([means[name] for name in means])
# Unit weights, same shape as y
# (presumably the inverse square root of the observation-error covariance — confirm).
R_inv_sqrt = jnp.ones_like(y)

# # Optional variable check: the same means computed in a single call
# mean_dataset = xr_predictions.mean(dim=list(xr_predictions.dims))

In [6]:
# Select what to estimate and where to start from.
# theta_keys maps a parameter group name to the list of parameter names being estimated.
theta_keys = {"mod_radcon": ["albsea"]}
# Initial guess, one entry per estimated parameter (same order as theta_keys)
theta = jnp.array([jnp.array(0.08)])
# Every parameter that is not being estimated keeps its default value
params = Parameters.default()

In [7]:
# Hyperparameters for the optimization loop
num_iters = 1    # number of passes through the optimization loop
step_size = 0.1  # multiplier applied to the gradient in the descent update

In [6]:
# Optimization loop
# Extra arguments handed to opt.grad_fn through its `args` keyword.
# NOTE(review): the loss-function cells pass (theta_keys, state, params) instead —
# confirm which tuple the opt helpers expect.
grad_args = (theta_keys, state)
for iteration in range(num_iters):
    # Compute gradient
    grad = opt.grad_fn(theta, opt.forward_model_wrapper, y, R_inv_sqrt, args=grad_args)
    print("The gradient is: ", grad[0][0])
    # The gradient-descent update is currently disabled, so theta is never modified here
    # and the final print below reports the initial guess.
    # theta -= step_size * grad  # Gradient descent update
    # print(f"Iteration {iteration+1}: theta = {theta}")

print("Optimized theta:", theta)

2025-04-03 16:48:35.026704: E external/xla/xla/service/slow_operation_alarm.cc:65] 
********************************
[Compiling module jit_multistep] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2025-04-03 16:50:21.166725: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 3m46.144529s

********************************
[Compiling module jit_multistep] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************


The gradient is:  -6.6480563e+18
Optimized theta: [0.08]


In [8]:
# Evaluate the cost function once, at the current theta
# Extra arguments handed to opt.loss_function through its `args` keyword
loss_args = (theta_keys, state, params)
loss = opt.loss_function(theta, opt.forward_model_wrapper, y, R_inv_sqrt, args=loss_args)
print(loss)

9.561319


In [None]:
# Loss function sweep: evaluate the cost at 40 evenly spaced candidate values
# of the estimated parameter, from 0.0 up to 0.975 in steps of 0.025.
thetas = jnp.arange(0.0, 1.0, 0.025)
# Extra arguments handed to opt.loss_function through its `args` keyword
sweep_args = (theta_keys, state, params)
# A comprehension is used on purpose: its loop variable does not leak, so the
# notebook-level `theta` (the optimizer's current estimate) is NOT overwritten
# by the sweep and the gradient/loss cells stay safe to re-run afterwards.
# Each candidate is wrapped into a length-1 vector, matching the shape of `theta`.
losses = [
    opt.loss_function(jnp.array([theta_value]), opt.forward_model_wrapper, y, R_inv_sqrt,
                      args=sweep_args)
    for theta_value in thetas
]

In [None]:
# # Example of how to define multiple parameters to estimate
# params = Parameters.default()
# theta = jnp.array([jnp.array(4.0), jnp.array(0.9), jnp.array(0.43)])
# theta_keys = {"condensation": ["trlsc", "rhlsc"], "shortwave_radiation": ["albcl"]}

# new_params = opt.forward_model_wrapper(theta, theta_keys)

# from pprint import pprint
# def to_readable_format(x):
#     if isinstance(x, jnp.ndarray):
#         return x.tolist()
#     return x
# pprint(jax.tree_util.tree_map(to_readable_format, new_params))

In [None]:
# def test_wrapping(theta): 
#     parameters = Parameters.default()
#     ii = 0
#     for attr, params in theta_keys.items():
#         for param in params:
#             setattr(getattr(parameters, attr), param, theta[ii])
#             ii += 1
#     model = opt.create_model(parameters) 
#     state = model.get_initial_state()
#     final_state, predictions = model.unroll(state)
#     return jnp.linalg.norm(y - final_state.temperature_variation.flatten())

# test_theta = jnp.array([jnp.array(0.19)]) 

# primal, f_vjp = jax.vjp(test_wrapping, test_theta)
# df_dparams = f_vjp(opt.make_ones_prediction_object(primal))

# print(df_dparams[0][0])

### Current Notes: 
- Gradients work for save interval = 1/24 and total time = 5/24, but not for save interval = 1 and total time = 5.
- For some reason, the gradients work on the predictions but not on the final state?
- Ideas to try: 
    - Maybe try same total time but smaller save interval
    - Maybe try smaller time_step to reduce chance of instability
    - Check to see how far you can push the total_time 