In [None]:
import jax.numpy as jnp
import xarray as xr
from jcm.boundaries import boundaries_from_file
from jcm.physics_interface import PhysicsState

# Read the orography into memory as a JAX array. The context manager closes the
# underlying netCDF file as soon as the values have been copied, instead of
# leaving the file handle open for the lifetime of the kernel.
with xr.open_dataarray('../jcm/data/bc/t30/clim/orography.nc') as orography_da:
    realistic_orography = jnp.asarray(orography_da)

# Boundary conditions, read by jcm's own loader from the same data directory.
realistic_boundaries = boundaries_from_file('../jcm/data/bc/t30/clim/boundaries.nc')


In [None]:
from jcm.model import Model

# Build the model on the orography loaded in the first cell.
model = Model(orography=realistic_orography)

# Settings for the initial run, gathered in one place so they are easy to tweak.
# NOTE(review): the narrative further down describes these as days (30 days,
# saved every 5 days) — confirm the units against the jcm documentation.
run_settings = dict(
    save_interval=5,
    total_time=30,
    output_averages=True,
    boundaries=realistic_boundaries,
)
predictions = model.run(**run_settings)

In [None]:
# Convert the output of `model.run` into `pred_ds`, which the cells below index
# by variable name and plot.
# NOTE(review): presumably an xarray Dataset — confirm against jcm.
pred_ds = model.predictions_to_xarray(predictions)

In [None]:
# Bare expression: the notebook shows the repr of `pred_ds` as this cell's output.
pred_ds

In [None]:
# One lon/lat panel per entry along 'time', wrapped two panels per row.
surface_pressure = pred_ds['normalized_surface_pressure']
surface_pressure.plot(x='lon', y='lat', col='time', col_wrap=2, aspect=2)

In [None]:
u_wind = pred_ds['u_wind']

# Section averaged over 'lon', drawn with the 'level' axis inverted.
u_wind_lon_mean = u_wind.mean('lon')
u_wind_lon_mean.plot(x='lat', y='level', col='time', col_wrap=3, aspect=6, yincrease=False)

# Lon/lat map of the last entry along 'level'.
u_wind_last_level = u_wind.isel(level=-1)
u_wind_last_level.plot(x='lon', y='lat', col='time', col_wrap=3, aspect=2)

In [None]:
specific_humidity = pred_ds['specific_humidity']

# Section averaged over 'lon', drawn with the 'level' axis inverted.
humidity_lon_mean = specific_humidity.mean('lon')
humidity_lon_mean.plot(x='lat', y='level', col='time', col_wrap=3, aspect=6, yincrease=False)

# Lon/lat map at index 3 along 'level'.
humidity_level_3 = specific_humidity.isel(level=3)
humidity_level_3.plot(x='lon', y='lat', col='time', col_wrap=3, aspect=2)

### clouds!

In [None]:
# Plot each `shortwave_rad.*` cloud variable as lon/lat maps, one panel per
# entry along 'time'. Looping over the names keeps the plot settings in a
# single place instead of repeating them on every line, and avoids echoing the
# repr of the last plot object as the cell's text output.
cloud_variables = (
    'shortwave_rad.cloudc',
    'shortwave_rad.qcloud',
    'shortwave_rad.icltop',
    'shortwave_rad.cloudstr',
)
for cloud_variable in cloud_variables:
    pred_ds[cloud_variable].plot(x='lon', y='lat', col='time', col_wrap=3, aspect=2)

## Continue running

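Continue the simulation from the previous state for an additional 30 days, outputting averages. The intent is to keep saving every 5 days, but note that `save_interval` is not passed to `resume` in the cell below — confirm that `resume` reuses (or defaults to) the 5-day interval, or set it explicitly if `resume` supports it.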


In [None]:
# Continue from where the previous run stopped. The call is the last expression
# in the cell, so its return value is shown as the cell output.
resume_settings = dict(
    total_time=30,
    output_averages=True,
    boundaries=realistic_boundaries,
)
model.resume(**resume_settings)

Sometimes, for example in a coupled run, we want a pure JAX interface to enable compilation of a larger model. We can do this using the `run_from_state` method:

In [None]:
# Start explicitly from the state stored on the model object.
# NOTE(review): `_final_modal_state` is a private attribute; presumably it is
# set by the earlier run/resume calls — confirm, and prefer a public accessor
# if jcm provides one.
restart_state = model._final_modal_state
model.run_from_state(
    initial_state=restart_state,
    total_time=30,
    output_averages=True,
    boundaries=realistic_boundaries,
)