Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose model structure and params as public API #65

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 16 additions & 1 deletion docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,27 @@ The user-facing API for NeuralGCM models centers around `PressureLevelModel`:

### Constructor

Use this class method to create a new model:
Use this class method to create a new model from a saved checkpoint:

```{eval-rst}
.. automethod:: PressureLevelModel.from_checkpoint
```

### Attributes

These attributes intentionally expose the internal state of a
`PressureLevelModel`. You can use them to introspect how a model is built.

`structure` and `params` completely define a model. In the current version of
NeuralGCM, `structure` is equal to the tuple `(model.gin_config, model.aux_data)`.

```{eval-rst}
.. autoproperty:: PressureLevelModel.structure
.. autoproperty:: PressureLevelModel.params
.. autoproperty:: PressureLevelModel.gin_config
.. autoproperty:: PressureLevelModel.aux_data
```

### Properties

These properties describe the coordinate system and variables for which a model
Expand Down
161 changes: 93 additions & 68 deletions neuralgcm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Public API for NeuralGCM models."""
from __future__ import annotations

import copy
from collections import abc
import datetime
import functools
Expand Down Expand Up @@ -83,7 +84,7 @@ def _static_gin_config(method):

@functools.wraps(method)
def _method(self, *args, **kwargs):
with gin_utils.specific_config(self.gin_config):
with gin_utils.specific_config(self._gin_config):
return method(self, *args, **kwargs)

return _method
Expand Down Expand Up @@ -164,12 +165,14 @@ def _rename_if_found(
_FULL_NAMES = {v: k for k, v in _ABBREVIATED_NAMES.items()}


def _expand_tracers(inputs: dict) -> dict:
inputs = inputs.copy()
inputs.update(inputs.pop('tracers'))
assert not inputs['diagnostics']
del inputs['diagnostics']
return inputs
class _Opaque:
"""Wrapper to make Python data opaque to jax.tree.

This object implements equality and hashing by identity.
"""

def __init__(self, data):
self.data = data


@tree_util.register_pytree_node_class
Expand All @@ -183,15 +186,13 @@ class PressureLevelModel:
hence should remain stable even for future NeuralGCM models.
"""

def __init__(
self,
structure: model_builder.WhirlModel,
params: Params,
gin_config: str,
):
self._structure = structure
def __init__(self, structure: Any, params: Params):
# internal model state
self._structure = _Opaque(structure)
self._params = params
self.gin_config = gin_config

# calculated variables
self._gin_config, self._aux_data = structure

self._tracer_variables = [
'specific_humidity',
Expand All @@ -210,7 +211,7 @@ def __init__(
'specific_cloud_liquid_water_content',
]
for variable in cloud_variables:
if variable in self.gin_config:
if variable in self._gin_config:
self._tracer_variables.append(variable)
self._input_variables.append(variable)

Expand All @@ -221,23 +222,59 @@ def __init__(

def __repr__(self):
return (
f'{self.__class__.__name__}(structure={self._structure},'
f'{self.__class__.__name__}(structure={self._structure.data},'
f' params={self._params})'
)

@functools.cached_property
@_static_gin_config
def _whirl_model(self) -> model_builder.WhirlModel:
  """Build the internal WhirlModel for this model. Not public API.

  The model is assembled from `self._aux_data` inside the gin configuration
  context applied by `_static_gin_config`. The result is cached on the
  instance, so construction happens at most once per model object.
  """
  specs = physics_specifications.get_physics_specs()
  dataset = xarray.Dataset.from_dict(self._aux_data)
  coords = model_builder.coordinate_system_from_dataset(dataset)
  aux_features = {xarray_utils.XARRAY_DS_KEY: dataset}
  model_specs = model_builder.get_model_specs(coords, specs, aux_features)
  # Inputs and outputs share the coordinate system read from the aux data.
  return model_builder.WhirlModel(
      coords=model_specs.coords,
      dt=model_specs.dt,
      physics_specs=model_specs.physics_specs,
      aux_features=model_specs.aux_features,
      input_coords=coords,
      output_coords=coords,
  )

@property
def structure(self) -> Any:
  """The internal configuration structure of this model.

  This is the object that was passed to the constructor as `structure`,
  unwrapped from the internal `_Opaque` holder. The constructor unpacks it
  into the gin config and the auxiliary data. It is returned as-is, without
  copying.
  """
  return self._structure.data

@property
def gin_config(self) -> str:
  """Gin configuration string used to construct this model.

  This is the first element of the `structure` passed to the constructor.
  """
  return self._gin_config

@property
def aux_data(self) -> dict:  # pylint: disable=g-bare-generic
  """Auxiliary data used to construct this model.

  This is the second element of the `structure` passed to the constructor.
  A deep copy is returned, so mutating the result does not affect the model.
  """
  return copy.deepcopy(self._aux_data)

@property
def params(self) -> Params:
  """The learnable parameters of this model.

  Returns the `params` object given to the constructor, without copying.
  """
  return self._params

def tree_flatten(self):
leaves, params_def = tree_util.tree_flatten(self.params)
return (leaves, (params_def, self._structure, self.gin_config))
leaves, params_def = tree_util.tree_flatten(self._params)
return (leaves, (params_def, self._structure))

@classmethod
def tree_unflatten(cls, aux_data, leaves):
params_def, structure, gin_config = aux_data
def tree_unflatten(cls, tree_def, leaves):
params_def, structure = tree_def
params = tree_util.tree_unflatten(params_def, leaves)
return cls(structure, params, gin_config)
return cls(structure.data, params)

@property
def input_variables(self) -> list[str]:
Expand All @@ -253,29 +290,29 @@ def forcing_variables(self) -> list[str]:
def timestep(self) -> np.timedelta64:
"""Spacing between internal model timesteps."""
to_timedelta = (
self._structure.specs.physics_specs.dimensionalize_timedelta64
self._whirl_model.specs.physics_specs.dimensionalize_timedelta64
)
return to_timedelta(self._structure.specs.dt)
return to_timedelta(self._whirl_model.specs.dt)

@property
def data_coords(self) -> coordinate_systems.CoordinateSystem:
"""Coordinate system for input and output data."""
return self._structure.data_coords
return self._whirl_model.data_coords

@property
def model_coords(self) -> coordinate_systems.CoordinateSystem:
"""Coordinate system for internal model state."""
return self._structure.coords
return self._whirl_model.coords

def _check_coords(self, dataset: xarray.Dataset):
  """Check `dataset`'s coordinate system against `self.data_coords`.

  The coordinate system is derived from `dataset`, and the comparison is
  delegated to the module-level `_check_coords` helper.
  """
  _check_coords(
      model_builder.coordinate_system_from_dataset(dataset),
      self.data_coords,
  )

def _dataset_with_sim_time(self, dataset: xarray.Dataset) -> xarray.Dataset:
ref_datetime = self._structure.specs.aux_features['reference_datetime']
ref_datetime = self._whirl_model.specs.aux_features['reference_datetime']
return xarray_utils.ds_with_sim_time(
dataset,
self._structure.specs.physics_specs,
self._whirl_model.specs.physics_specs,
reference_datetime=ref_datetime,
)

Expand All @@ -295,31 +332,31 @@ def _from_abbreviated_names_and_tracers(self, outputs: dict) -> dict:

def to_nondim_units(self, value: Numeric, units: str) -> Numeric:
"""Scale a value to the model's internal non-dimensional units."""
scale_ = self._structure.specs.physics_specs.scale
scale_ = self._whirl_model.specs.physics_specs.scale
units_ = scales.parse_units(units)
return scale_.nondimensionalize(value * units_)

def from_nondim_units(self, value: Numeric, units: str) -> Numeric:
"""Scale a value from the model's internal non-dimensional units."""
scale_ = self._structure.specs.physics_specs.scale
scale_ = self._whirl_model.specs.physics_specs.scale
units_ = scales.parse_units(units)
return scale_.dimensionalize(value, units_).magnitude

def datetime64_to_sim_time(self, datetime64: np.ndarray) -> np.ndarray:
"""Converts a datetime64 array to sim_time."""
ref_datetime = self._structure.specs.aux_features['reference_datetime']
ref_datetime = self._whirl_model.specs.aux_features['reference_datetime']
return xarray_utils.datetime64_to_nondim_time(
datetime64,
self._structure.specs.physics_specs,
self._whirl_model.specs.physics_specs,
reference_datetime=ref_datetime,
)

def sim_time_to_datetime64(self, sim_time: np.ndarray) -> np.ndarray:
"""Converts a sim_time array to datetime64."""
ref_datetime = self._structure.specs.aux_features['reference_datetime']
ref_datetime = self._whirl_model.specs.aux_features['reference_datetime']
return xarray_utils.nondim_time_to_datetime64(
sim_time,
self._structure.specs.physics_specs,
self._whirl_model.specs.physics_specs,
reference_datetime=ref_datetime,
)

Expand Down Expand Up @@ -364,9 +401,9 @@ def data_to_xarray(

Args:
data: dict of arrays with shapes matching input/outputs or encoded model
state for this model, i.e., with shape
`([time,] level, longitude, latitude)`,
where `[time,]` indicates an optional leading time dimension.
state for this model, i.e., with shape `([time,] level, longitude,
latitude)`, where `[time,]` indicates an optional leading time
dimension.
times: either `None` indicating no leading time dimension on any
variables, or a coordinate array of times with shape `(time,)`.
decoded: if `True`, use `self.data_coords` to determine the output
Expand Down Expand Up @@ -419,8 +456,8 @@ def encode(
inputs = _prepend_dummy_time_axis(inputs)
forcings = self._squeeze_level_from_forcings(forcings)
forcings = _prepend_dummy_time_axis(forcings)
f = self._structure.forcing_fn(self.params, None, forcings, sim_time)
return self._structure.encode_fn(self.params, rng_key, inputs, f)
f = self._whirl_model.forcing_fn(self.params, None, forcings, sim_time)
return self._whirl_model.encode_fn(self.params, rng_key, inputs, f)

@jax.jit
@_static_gin_config
Expand All @@ -442,8 +479,8 @@ def advance(self, state: State, forcings: Forcings) -> State:
sim_time = _sim_time_from_state(state)
forcings = self._squeeze_level_from_forcings(forcings)
forcings = _prepend_dummy_time_axis(forcings)
f = self._structure.forcing_fn(self.params, None, forcings, sim_time)
state = self._structure.advance_fn(self.params, None, state, f)
f = self._whirl_model.forcing_fn(self.params, None, forcings, sim_time)
state = self._whirl_model.advance_fn(self.params, None, state, f)
return state

@jax.jit
Expand All @@ -467,8 +504,8 @@ def decode(self, state: State, forcings: Forcings) -> Outputs:
sim_time = _sim_time_from_state(state)
forcings = self._squeeze_level_from_forcings(forcings)
forcings = _prepend_dummy_time_axis(forcings)
f = self._structure.forcing_fn(self.params, None, forcings, sim_time)
outputs = self._structure.decode_fn(self.params, None, state, f)
f = self._whirl_model.forcing_fn(self.params, None, forcings, sim_time)
outputs = self._whirl_model.decode_fn(self.params, None, state, f)
outputs = self._from_abbreviated_names_and_tracers(outputs)
return outputs

Expand Down Expand Up @@ -512,13 +549,13 @@ def unroll(
steps: number of time-steps to take.
timedelta: size of each time-step to take, which must be a multiple of the
internal model timestep. By default uses the internal model timestep.
start_with_input: if ``True``, outputs are at times ``[0, ...,
(steps - 1) * timestep]`` relative to the initial time; if ``False``,
outputs are at times ``[timestep, ..., steps * timestep]``.
start_with_input: if ``True``, outputs are at times ``[0, ..., (steps - 1)
* timestep]`` relative to the initial time; if ``False``, outputs are at
times ``[timestep, ..., steps * timestep]``.
post_process_fn: optional function to apply to each advanced state and
current forcings to create outputs like
``post_process_fn(state, forcings)``, where ``forcings`` does not
include a time axis. By default, uses ``model.decode``.
current forcings to create outputs like ``post_process_fn(state,
forcings)``, where ``forcings`` does not include a time axis. By
default, uses ``model.decode``.

Returns:
A tuple of the advanced state at time ``steps * timestamp``, and outputs
Expand All @@ -540,6 +577,7 @@ def wrapped(state):
sim_time = _sim_time_from_state(state)
forcings = get_nearest_forcings(sim_time)
return func(state, forcings)

return wrapped

if post_process_fn is None:
Expand All @@ -557,7 +595,7 @@ def wrapped(state):
return state, outputs

@classmethod
def from_checkpoint(cls, checkpoint: Any) -> PressureLevelModel:
def from_checkpoint(cls, checkpoint: dict) -> PressureLevelModel:
"""Creates a PressureLevelModel from a checkpoint.

Args:
Expand All @@ -569,21 +607,8 @@ def from_checkpoint(cls, checkpoint: Any) -> PressureLevelModel:
Instance of a `PressureLevelModel` with weights and configuration
specified by the checkpoint.
"""
with gin_utils.specific_config(checkpoint['model_config_str']):
physics_specs = physics_specifications.get_physics_specs()
aux_ds = xarray.Dataset.from_dict(checkpoint['aux_ds_dict'])
data_coords = model_builder.coordinate_system_from_dataset(aux_ds)
model_specs = model_builder.get_model_specs(
data_coords, physics_specs, {xarray_utils.XARRAY_DS_KEY: aux_ds}
)
whirl_model = model_builder.WhirlModel(
coords=model_specs.coords,
dt=model_specs.dt,
physics_specs=model_specs.physics_specs,
aux_features=model_specs.aux_features,
input_coords=data_coords,
output_coords=data_coords,
)
return cls(
whirl_model, checkpoint['params'], checkpoint['model_config_str']
)
gin_config = checkpoint['model_config_str']
aux_data = checkpoint['aux_ds_dict']
structure = (gin_config, aux_data)
params = checkpoint['params']
return cls(structure, params)
17 changes: 17 additions & 0 deletions neuralgcm/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,23 @@ def test_model_properties(self):

self.assertEqual(model.model_coords.nodal_shape, (32, 128, 64))

def test_pytree(self):
  """Checks that a model flattens consistently and is a stable jit argument."""
  model = load_tl63_stochastic_model()

  # Flatten twice: comparing a treedef against itself can never fail, whereas
  # two independent flattenings only match if the model's static metadata
  # compares equal from one flattening to the next.
  first_def = jax.tree.structure(model)
  second_def = jax.tree.structure(model)
  self.assertEqual(first_def, second_def)

  trace_count = 0

  @jax.jit
  def f(model):  # pylint: disable=unused-argument
    nonlocal trace_count
    trace_count += 1
    return 0

  # The second call must hit the jit cache instead of tracing again.
  f(model)
  f(model)
  self.assertEqual(trace_count, 1)

def test_to_and_from_nondim_units(self):
model = load_tl63_stochastic_model()

Expand Down