In [1]:
from dinosaur.xarray_utils import data_to_xarray
from dinosaur import primitive_equations as pe
from dinosaur import primitive_equations_states as pes
from dinosaur import spherical_harmonic as sh
from dinosaur.scales import units

from jcm.model import SpeedyModel

import numpy as np
import jax

# JAX runtime flags for this notebook, applied in the order listed.
jax_flags = (
    # False leaves JIT compilation ENABLED (the flag name is "disable").
    # NOTE(review): this used to be described as turning JIT off because of an
    # issue in shortwave_radiation.py:169 -- set to True if that workaround is
    # still required.
    ('jax_disable_jit', False),
    # Turn on JAX's inf debugging. Author's note: believed not to add net
    # runtime, since the time is otherwise spent getting the nodal quantities.
    ("jax_debug_infs", True),
    # NaN debugging stays off because some physics fields might be nan.
    ("jax_debug_nans", False),
)
for flag_name, flag_value in jax_flags:
    jax.config.update(flag_name, flag_value)


In [2]:
# Settings for this run, gathered in one place so they are easy to tweak.
# (The units of the time arguments are defined by SpeedyModel.)
run_config = dict(
    time_step=120,
    save_interval=1,
    total_time=1,
    layers=7,
)
model = SpeedyModel(**run_config)

# Start from the model's own initial state; `unroll` hands back the final
# state together with the saved predictions.
state = model.get_initial_state()
final_state, predictions = model.unroll(state)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [3]:
# Show which fields `model.unroll` returned in `predictions`.
predictions.keys()

dict_keys(['div_nodal', 'evap', 'gse', 'icltop', 'iptop', 'lwftop', 'precls', 'slr', 'specific_humidity_nodal', 'ssr', 'swftop', 't_nodal', 'u_nodal', 'v_nodal', 'vor_nodal'])

In [8]:
# Recover nodal u and v from the final state's vorticity and divergence.
horizontal_grid = model.coords.horizontal
u_nodal, v_nodal = sh.vor_div_to_uv_nodal(
    horizontal_grid,
    final_state.vorticity,
    final_state.divergence,
)

# Vertical velocity is left disabled: it doesn't currently work for
# `predictions` because of the explicit padding shape in that routine.
# w_nodal = -pe.compute_vertical_velocity(predictions, coords)

# Diagnostic state computed from the final state on the model coordinates.
nodal_predictions = pe.compute_diagnostic_state(final_state, model.coords)


In [9]:
# Do the unit conversion before converting to xarray.
#
# `final_state` carries spectral (modal) coefficients -- its fields have to go
# through `to_nodal` before they can be treated as grid-point values -- so each
# field is transformed to the grid first and only then converted to physical
# units. Reading from `final_state` (rather than from `nodal_predictions`)
# keeps this cell idempotent: re-running it does not apply the conversion twice.
t_ref = 288  # K, constant offset added back onto the temperature variation
to_nodal = model.coords.horizontal.to_nodal
dimensionalize = model.physics_specs.dimensionalize

# `.m` takes the magnitude of the dimensionalized quantity.
nodal_predictions.temperature_variation = t_ref + dimensionalize(
    to_nodal(final_state.temperature_variation), units.kelvin).m

nodal_predictions.tracers['specific_humidity'] = dimensionalize(
    to_nodal(final_state.tracers['specific_humidity']), units.gram / units.kilogram).m

In [12]:
# Assemble the xarray dataset `pred_ds` used by the plotting cells below.
#
# NOTE(review): as recorded in this cell's output, the `data_to_xarray` call on
# `predictions` currently raises
#   ValueError: Value of shape (1, 128, 64, 3) is not in shape_to_dims=...
# i.e. at least one entry of `predictions` has a layout the converter does not
# accept. Until that entry is excluded or reshaped, nothing below the call runs.
#
# NOTE(review): none of the names in `broken_keys` appear among the keys printed
# by `predictions.keys()` earlier in the notebook (div_nodal, evap, ..., t_nodal,
# u_nodal, v_nodal, vor_nodal), so the filter presumably removes nothing; the
# list looks like it was written for a different dict -- TODO confirm.
# 'cos_lat_grad_log_sp' is also listed twice.
broken_keys = ['cos_lat_u', 'cos_lat_grad_log_sp', 'cos_lat_grad_log_sp', ] # These are tuples which are not supported by xarray
broken_keys += ['sigma_dot_explicit', 'sigma_dot_full'] # These only have four time steps for some reason...
pred_ds = data_to_xarray({k: v for k, v in predictions.items() if k not in broken_keys}, 
                         coords=model.coords, times=model.times)

# NOTE(review): 'temperature_variation' is not among the `predictions` keys
# printed earlier ('t_nodal' is), while the plotting cells expect 'temperature'
# and 'specific_humidity' -- verify the variable names once the conversion
# above succeeds.
pred_ds = pred_ds.rename_vars({'temperature_variation': 'temperature'})

# Skip this for now
# log_surface_pressure_nodal = coords.horizontal.to_nodal(final_state.log_surface_pressure)
# surface_pressure_nodal = np.exp(log_surface_pressure_nodal) * 1e5

# Attach u and v, converted to m/s, as extra variables. They come from
# `final_state` (a single state), yet are converted with `times=model.times`.
# NOTE(review): confirm their shape carries the leading time axis the converter
# expects.
pred_ds['u'] = data_to_xarray({'u': model.physics_specs.dimensionalize(np.asarray(u_nodal), units.meter / units.second).m}, coords=model.coords, times=model.times)['u']
pred_ds['v'] = data_to_xarray({'v': model.physics_specs.dimensionalize(np.asarray(v_nodal), units.meter / units.second).m}, coords=model.coords, times=model.times)['v']

# Flip the vertical dimension so that it goes from the surface to the top of the atmosphere
pred_ds = pred_ds.isel(level=slice(None, None, -1))

# Display the assembled dataset.
pred_ds

ValueError: Value of shape (1, 128, 64, 3) is not in shape_to_dims={(1,): ('time',), (1, 7, 85, 44): ('time', 'level', 'longitudinal_mode', 'total_wavenumber'), (1, 7, 128, 64): ('time', 'level', 'lon', 'lat'), (1, 128, 64): ('time', 'lon', 'lat'), (1, 85, 44): ('time', 'longitudinal_mode', 'total_wavenumber'), (1, 1, 128, 64): ('time', 'surface', 'lon', 'lat'), (1, 1, 85, 44): ('time', 'surface', 'longitudinal_mode', 'total_wavenumber'), (1, 1): ('time', 'surface')}

In [None]:
# Longitude-mean of `u` on lat/level axes, one panel per time.
u_lon_mean = pred_ds['u'].mean('lon')
u_lon_mean.plot(x='lat', y='level', col='time', col_wrap=4)

In [None]:
# Longitude-mean of `v` on lat/level axes, one panel per time.
v_lon_mean = pred_ds['v'].mean('lon')
v_lon_mean.plot(x='lat', y='level', col='time', col_wrap=4)

In [None]:
# Longitude-mean of `temperature` on lat/level axes, one panel per time.
temperature_lon_mean = pred_ds['temperature'].mean('lon')
temperature_lon_mean.plot(x='lat', y='level', col='time', col_wrap=4)

In [None]:
# Longitude-mean of `specific_humidity` on lat/level axes, one panel per time.
humidity_lon_mean = pred_ds['specific_humidity'].mean('lon')
humidity_lon_mean.plot(x='lat', y='level', col='time', col_wrap=4)

In [None]:
# Lon/lat map of `specific_humidity` at level index 0, one panel per time.
humidity_level0 = pred_ds['specific_humidity'].isel(level=0)
humidity_level0.plot(x='lon', y='lat', col='time', col_wrap=4)