
Read/update model states in native (sigma) coordinates #8

jswhit opened this issue Mar 9, 2024 · 10 comments


jswhit commented Mar 9, 2024

In order to use neuralgcm for data assimilation, we need access to the model variables on sigma levels (plus surface pressure). The api.PressureLevelModel provides access to decoded data on pressure levels. I'd like to add the ability to extract sigma-level data and surface pressure from a prediction, then restart a model prediction using updated sigma-level/surface pressure fields. How would you recommend doing this with the public API?

@nniraj123 @frolovsa


shoyer commented Mar 13, 2024

Are you interested in a particular version of NeuralGCM?

The internal state of the model (the output of PressureLevelModel.encode, the input/output of advance, and the input/first output of unroll) is actually already on sigma coordinates, e.g.,

>>> jax.tree_util.tree_map(np.shape, final_state)
ModelState(state=StateWithTime(vorticity=(32, 256, 129), divergence=(32, 256, 129), temperature_variation=(32, 256, 129), log_surface_pressure=(1, 256, 129), sim_time=(), tracers={'specific_cloud_ice_water_content': (32, 256, 129), 'specific_cloud_liquid_water_content': (32, 256, 129), 'specific_humidity': (32, 256, 129)}), memory=StateWithTime(vorticity=(32, 256, 129), divergence=(32, 256, 129), temperature_variation=(32, 256, 129), log_surface_pressure=(1, 256, 129), sim_time=(), tracers={'specific_cloud_ice_water_content': (32, 256, 129), 'specific_cloud_liquid_water_content': (32, 256, 129), 'specific_humidity': (32, 256, 129)}), diagnostics={}, randomness=RandomnessState(core=(256, 128), nodal_value=(256, 128), modal_value=(256, 129)))

These variables are stored as spherical harmonic coefficients, so they need to be transformed back into real space for visualization, e.g., to visualize temperature near the surface in the demo notebook:

temp_variation = neural_gcm_model.model_coords.horizontal.to_nodal(final_state.state.temperature_variation)
xarray.DataArray(temp_variation[-1, :, :], dims=['x', 'y']).plot.imshow(size=4, aspect=2, x='x', y='y')

[image: near-surface temperature variation plotted on the model grid]

There are a few tricks for conversion (e.g., to handle units and reference offsets for temperature). We'll work on documenting these as part of #11.
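
In the meantime, a minimal sketch of pulling sigma-level fields out of a state like final_state above. This is not the documented recipe: in particular, converting surface pressure with from_nondim_units and units='pascals' is an assumption, and the temperature reference offset is not handled here.

import jax.numpy as jnp

to_nodal = neural_gcm_model.model_coords.horizontal.to_nodal

# Prognostic fields on sigma levels, transformed to grid space, shape (level, x, y).
vorticity = to_nodal(final_state.state.vorticity)
divergence = to_nodal(final_state.state.divergence)
temperature_variation = to_nodal(final_state.state.temperature_variation)
specific_humidity = to_nodal(final_state.state.tracers['specific_humidity'])

# Surface pressure is stored as log(p_s) in nondimensional units; the unit
# conversion below is an assumption, not a documented recipe.
log_sp = to_nodal(final_state.state.log_surface_pressure)[0]
surface_pressure = neural_gcm_model.from_nondim_units(jnp.exp(log_sp), units='pascals')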


jswhit commented Mar 13, 2024

@shoyer we're most interested in the stochastic version, since we'll be running ensembles for data assimilation. What we need is to:

  1. Run the model ensemble for 9 hours, saving the trajectory state every 3 hours (or perhaps every hour).
  2. Extract the model state on sigma levels from the trajectory and use it to compute the "observation equivalents" for observations valid between 3 and 9 hours.
  3. Use the model states (in grid space and also in observation space) to compute an update to the model state at 6 hours using the Ensemble Kalman Filter algorithm.
  4. Use these updated model states to re-initialize the model, then rinse and repeat.

Your example looks like exactly what we need for steps 1-3. For step 4, we need a way to re-encode the updated sigma-level states; roughly, the cycle we have in mind is sketched below.
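
The sketch below is not working code: enkf_update and run_9h_forecast are hypothetical placeholders for our own EnKF analysis and for however the ensemble is advanced, and it assumes that ModelState/StateWithTime behave as dataclasses (so dataclasses.replace can rebuild them), that the grid object also provides to_modal, and that an updated ModelState can be fed straight back into the model.

import dataclasses

horizontal = neural_gcm_model.model_coords.horizontal  # to_nodal / to_modal (to_modal assumed)

def state_to_grid(model_state):
  # sigma-level prognostic field in grid space, for the EnKF update
  return horizontal.to_nodal(model_state.state.temperature_variation)

def grid_to_state(model_state, analysed_field):
  # write the analysed grid-space field back into the spectral ModelState
  new_inner = dataclasses.replace(
      model_state.state,
      temperature_variation=horizontal.to_modal(analysed_field))
  return dataclasses.replace(model_state, state=new_inner)

background_states = [run_9h_forecast(s) for s in ensemble_states]           # step 1
analyses = enkf_update([state_to_grid(s) for s in background_states], obs)  # steps 2-3
ensemble_states = [grid_to_state(s, a)                                      # step 4
                   for s, a in zip(background_states, analyses)]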


jswhit commented Apr 9, 2024

@shoyer we've got an initial version of a cycling EnKF for neuralgcm running and assimilating surface pressure observations. The results suggest that the orography we are using for the forward observation operator (from aux_features['xarray_dataset']['geopotential_at_surface']) is perhaps not the actual orography the dycore uses (and is therefore not consistent with the surface pressure in the model state). Does the dycore use a filtered version of that orography, and if so, how can I access it?


shoyer commented Apr 18, 2024

@jswhit Indeed, for the dycore we use a learned filtered version of orography that is a bit smoother. You can extract it from the trained models:

import dataclasses
import functools
import pickle

import gcsfs
from dinosaur import spherical_harmonic
import haiku as hk
from neuralgcm import api
from neuralgcm import orographies
import numpy as np
import matplotlib.pyplot as plt
import xarray

gcs = gcsfs.GCSFileSystem(token='anon')


# Rebuild the LearnedOrography module on the dycore grid; neural_gcm_model is
# loaded below and is only looked up when get_orography.apply is called.
@hk.transform
def get_orography():
  base_orography = functools.partial(
      orographies.FilteredCustomOrography,
      orography_data_path=None,
      renaming_dict=dict(longitude='lon', latitude='lat'),
  )
  orography_coords = dataclasses.replace(
        neural_gcm_model.model_coords,
        horizontal=spherical_harmonic.Grid.with_wavenumbers(longitude_wavenumbers=126)
  )
  return orographies.LearnedOrography(
      orography_coords,
      neural_gcm_model._structure.specs.dt,
      neural_gcm_model._structure.specs.physics_specs,
      neural_gcm_model._structure.specs.aux_features,
      base_orography_module=base_orography,
      correction_scale=1e-5,
  )()

model_name = 'neural_gcm_stochastic_1_4_deg_v0.pkl'

with gcs.open(f'gs://gresearch/neuralgcm/03_04_2024/{model_name}', 'rb') as f:
  ckpt = pickle.load(f)

neural_gcm_model = api.PressureLevelModel.from_checkpoint(ckpt)

dycore_coords = spherical_harmonic.Grid.with_wavenumbers(longitude_wavenumbers=126)

# Learned spectral correction coefficients stored in the checkpoint params.
learned_correction = neural_gcm_model.params[
    'stochastic_modular_step_model/~/stochastic_physics_parameterization_step/'
    '~/custom_coords_corrector/~/dycore_with_physics_corrector/~/learned_orography'
]['orography']

learned_orography_modal = get_orography.apply(
    {'learned_orography': {'orography': learned_correction}}, rng=None
)
learned_orography = neural_gcm_model.from_nondim_units(
    dycore_coords.to_nodal(learned_orography_modal), units='meters'
)

default_orography_modal = get_orography.apply(
    {'learned_orography': {'orography': np.zeros_like(learned_correction)}}, rng=None
)
default_orography = neural_gcm_model.from_nondim_units(
    dycore_coords.to_nodal(default_orography_modal), units='meters'
)
# learned orography
xarray.DataArray(learned_orography, dims=['x', 'y']).plot.imshow(x='x', y='y', aspect=2, size=4, robust=True, vmin=0)

[image: learned orography, meters]

# default orography
xarray.DataArray(default_orography, dims=['x', 'y']).plot.imshow(x='x', y='y', aspect=2, size=4, robust=True, vmin=0)

[image: default orography, meters]
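
As a usage note for the data assimilation case above: if what the observation operator needs is a surface geopotential consistent with the model's surface pressure, a minimal sketch (assuming constant standard gravity) is simply:

g = 9.80665  # standard gravity, m s**-2
geopotential_at_surface = g * learned_orography  # grid-space field on dycore_coords, m**2 s**-2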


jswhit commented May 1, 2024

Thanks @shoyer, using the learned orography really moved the needle for the data assimilation.


shoyer commented May 9, 2024

For future reference, details of how to work with our model on sigma coordinates are now described in the NeuralGCM documentation: https://neuralgcm.readthedocs.io/en/latest/trained_models.html


jswhit commented May 9, 2024

@shoyer for the deterministic 0.7 degree model it looks like the learned orography is in stochastic_modular_step_model/~/dimensional_learned_weatherbench_to_primitive_with_memory_encoder/~/learned_weatherbench_to_primitive_encoder_1/~/learned_orography and correction_scale should be 2.e-6. Is this correct?

@kochkov92

@jswhit good catch! Yes, we did switch to using slightly different values for that resolution.

Probably worth noting that in the model there are up to 3 different orography values (one in the encoder, one for the simulation, and one in the decoder).

I believe the model that @shoyer showed has a fixed orography (a modal representation of conservatively regridded ERA5 orography) used by the encoder and decoder, and a learned orography in the advance step (extracted by @shoyer's snippet).

In the 0.7 degree model all three components were learned. In a few experiments we found that in this setting the decoder tends to keep sharper features in the orography compared to the encoder. The variables that you've tracked down correspond to the encoder orography. Please let me know if the other one under "dycore_with_physics_corrector" doesn't show up.

We will also prioritize easier access to these in the new API that we are starting to work on.


jswhit commented May 9, 2024

Thanks @kochkov92, I was able to find 'stochastic_modular_step_model/~/stochastic_physics_parameterization_step/~/custom_coords_corrector/~/dycore_with_physics_corrector/~/learned_orography'. However, when I use it with the get_orography function @shoyer provided

@hk.transform
def get_orography():
  base_orography = functools.partial(
      orographies.FilteredCustomOrography,
      orography_data_path=None,
      renaming_dict=dict(longitude='lon', latitude='lat'),
  )
  orography_coords = dataclasses.replace(
      neural_gcm_model.model_coords,
      horizontal=spherical_harmonic.Grid.with_wavenumbers(longitude_wavenumbers=256)
  )
  return orographies.LearnedOrography(
      orography_coords,
      neural_gcm_model._structure.specs.dt,
      neural_gcm_model._structure.specs.physics_specs,
      neural_gcm_model._structure.specs.aux_features,
      base_orography_module=base_orography,
      correction_scale=2.e-6,
  )()

orogpath = (
    'stochastic_modular_step_model/~/stochastic_physics_parameterization_step/'
    '~/custom_coords_corrector/~/dycore_with_physics_corrector/~/learned_orography')
learned_correction = neural_gcm_model.params[orogpath]['orography']
learned_orography_modal = get_orography.apply(
    {'learned_orography': {'orography': learned_correction}}, rng=None
)

I get

ValueError: 'learned_orography/orography' with retrieved shape (65023,) does not match shape=(66047,) dtype=<class 'jax.numpy.float32'>


jswhit commented May 9, 2024

Please disregard my previous question; it works if I use longitude_wavenumbers=254.
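
For reference, putting the pieces from this thread together for the deterministic 0.7 degree model: the same get_orography recipe works when built with longitude_wavenumbers=254 and correction_scale=2.e-6, and the two parameter paths quoted above (the decoder orography is not covered here) are:

# Advance-step (dycore) learned orography for the 0.7 degree model.
advance_path = (
    'stochastic_modular_step_model/~/stochastic_physics_parameterization_step/'
    '~/custom_coords_corrector/~/dycore_with_physics_corrector/~/learned_orography')

# Encoder learned orography for the same model.
encoder_path = (
    'stochastic_modular_step_model/~/dimensional_learned_weatherbench_to_primitive_with_memory_encoder/'
    '~/learned_weatherbench_to_primitive_encoder_1/~/learned_orography')

learned_correction = neural_gcm_model.params[advance_path]['orography']
learned_orography_modal = get_orography.apply(
    {'learned_orography': {'orography': learned_correction}}, rng=None
)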
