### Testing optimization setup
By: Rebecca Gjini
03/21/2025

In [1]:
from jcm.model import SpeedyModel
import jax
import jax.numpy as jnp
from jcm.params import Parameters
import jcm.optimization as opt

In [None]:
# Reference ("truth") run: the true parameters are simply the package
# defaults, and the model is integrated forward from its initial state.
true_params = Parameters.default()

# Original note: the model created here runs for 5 days.
# NOTE(review): the run length is configured inside opt.create_model — confirm there.
model = opt.create_model(true_params)

state = model.get_initial_state()
final_state, predictions = model.unroll(state)  # forward integration

In [None]:
# Synthetic observations: the final temperature_variation field of the
# truth run, flattened into a 1-D target vector.
y = final_state.temperature_variation.flatten()

# Per-element weights handed to the loss alongside y; every entry is 100.
# NOTE(review): by its name this looks like the inverse square root of the
# observation-error variance (sigma = 0.01) — confirm in opt.loss_function.
R_inv_sqrt = jnp.full_like(y, 100.0)

(16632,)


In [None]:
# Which parameters are estimated: a mapping from parameter group to the
# names inside that group that are free. Here only "albsea" in "mod_radcon".
theta_keys = {"mod_radcon": ["albsea"]}

# Initial guess for the free parameter(s).
# Presumably ordered to match the names in theta_keys — confirm in opt.
theta = jnp.array([jnp.array(0.19)])

# Baseline values for all remaining (non-estimated) parameters.
params = Parameters.default()

In [None]:
# Gradient-descent settings: learning rate and number of update steps.
step_size, num_iters = 0.1, 2

In [None]:
# Plain gradient descent on theta. Each step gets a gradient from
# opt.grad_fn, given the forward-model wrapper, the synthetic data y and the
# weights R_inv_sqrt, then steps against it.
# NOTE(review): presumably the gradient of opt.loss_function w.r.t. theta — confirm.
for i in range(num_iters):
    grad = opt.grad_fn(
        theta,
        opt.forward_model_wrapper,
        y,
        R_inv_sqrt,
        args=(theta_keys, state, params),
    )
    # jax arrays are immutable, so the update rebinds theta to a new array.
    theta = theta - step_size * grad
    print(f"Iteration {i+1}: theta = {theta}")

print("Optimized theta:", theta)

In [8]:
# Cost function at the current (optimized) theta, using the same forward
# wrapper, data, weights and extra arguments as the optimization loop.
loss = opt.loss_function(
    theta,
    opt.forward_model_wrapper,
    y,
    R_inv_sqrt,
    args=(theta_keys, state, params),
)
print(loss)

0.24769543


In [None]:
from pprint import pprint

# Check how several values are mapped through forward_model_wrapper:
# two "condensation" entries and one "shortwave_radiation" entry.
params = Parameters.default()
theta_keys = {"condensation": ["trlsc", "rhlsc"], "shortwave_radiation": ["albcl"]}
theta = jnp.array([jnp.array(4.0), jnp.array(0.9), jnp.array(0.43)])

# NOTE(review): unlike the optimization cells, this call passes neither
# `state` nor the `params` assigned above (which is otherwise unused in this
# cell). The result is named new_params, presumably the parameter set with
# theta substituted — confirm against forward_model_wrapper's signature.
new_params = opt.forward_model_wrapper(theta, theta_keys)


def to_readable_format(x):
    """Return jax arrays as (nested) Python lists; leave every other leaf unchanged."""
    return x.tolist() if isinstance(x, jnp.ndarray) else x


pprint(jax.tree_util.tree_map(to_readable_format, new_params))