### Testing optimization setup
By: Rebecca Gjini
03/21/2025

In [1]:
from jcm.model import SpeedyModel
import jax
import jax.numpy as jnp
from jcm.params import Parameters
import jcm.optimization as opt

In [None]:
# Create synthetic data: run the model with the default ("true") parameters
# and use its flattened temperature_variation output as the observations y
# that the optimization should try to reproduce.
true_params = Parameters.default()
model = opt.create_model(true_params)
state = model.get_initial_state()
final_state, predictions = model.unroll(state)
y = predictions['dynamics'].temperature_variation.flatten()

print(y.shape)

# Observation-error weights: 1 / sigma with sigma = 10% of each observation.
# Use |y| so sigma (a standard deviation) is never negative, and floor it at
# machine epsilon so observations that are exactly zero do not produce
# infinite weights (the original 1.0/(0.1*y) divided by zero there).
sigma = 0.1 * jnp.abs(y)
sigma = jnp.maximum(sigma, jnp.finfo(sigma.dtype).eps)
R_inv_sqrt = 1.0 / sigma

In [None]:
# Set the parameter to estimate: "albsea" in the "mod_radco" group,
# starting from an initial guess of 0.09 (one entry of theta per key).
params = Parameters.default()
theta_keys = {"mod_radco": ["albsea"]}
theta = jnp.array([0.09])

In [None]:
# Evaluate the cost function at the current theta against the synthetic
# observations y, weighted by R_inv_sqrt. The tuple in `args` is presumably
# forwarded to opt.forward_model_wrapper after theta — confirm in opt.
loss = opt.loss_function(theta, opt.forward_model_wrapper, y, R_inv_sqrt,
                         args=(theta_keys, state, params))
# Print the computed value (previously only the literal string "loss" was printed).
print("loss", loss)

In [None]:
# Apply three parameter overrides through the forward-model wrapper and
# pretty-print the resulting pytree with arrays shown as plain lists.
from pprint import pprint

params = Parameters.default()
# theta values are presumably matched to theta_keys in listed order
# (trlsc, rhlsc, albcl) — confirm in opt.forward_model_wrapper.
theta = jnp.array([4.0, 0.9, 0.43])
theta_keys = {"condensation": ["trlsc", "rhlsc"], "shortwave_radiation": ["albcl"]}

# NOTE(review): `params` above is unused, and this call omits the `state` and
# `params` that the loss_function cell forwards via args — confirm the
# intended signature of opt.forward_model_wrapper.
new_params = opt.forward_model_wrapper(theta, theta_keys)


def to_readable_format(x):
    """Return JAX arrays as (nested) Python lists; leave other leaves unchanged."""
    return x.tolist() if isinstance(x, jnp.ndarray) else x


pprint(jax.tree_util.tree_map(to_readable_format, new_params))