diff --git a/docs/api.md b/docs/api.md index 7f50668..0449d38 100644 --- a/docs/api.md +++ b/docs/api.md @@ -14,12 +14,27 @@ The user-facing API for NeuralGCM models centers around `PressureLevelModel`: ### Constructor -Use this class method to create a new model: +Use this class method to create a new model from a saved checkpoint: ```{eval-rst} .. automethod:: PressureLevelModel.from_checkpoint ``` +### Attributes + +These intentionally attributes store the internal state of a +`PressureLevelModel`. You can use them to introspect how a model is built. + +`structure` and `params` completely define a model. In the current version of +NeuralGCM, structure is equal to the tuple `(model.gin_config, model.aux_data)`. + +```{eval-rst} +.. autoproperty:: PressureLevelModel.structure +.. autoproperty:: PressureLevelModel.params +.. autoproperty:: PressureLevelModel.gin_config +.. autoproperty:: PressureLevelModel.aux_data +``` + ### Properties These properties describe the coordinate system and variables for which a model diff --git a/neuralgcm/api.py b/neuralgcm/api.py index c1b910f..d1a1a7a 100644 --- a/neuralgcm/api.py +++ b/neuralgcm/api.py @@ -14,6 +14,7 @@ """Public API for NeuralGCM models.""" from __future__ import annotations +import copy from collections import abc import datetime import functools @@ -83,7 +84,7 @@ def _static_gin_config(method): @functools.wraps(method) def _method(self, *args, **kwargs): - with gin_utils.specific_config(self.gin_config): + with gin_utils.specific_config(self._gin_config): return method(self, *args, **kwargs) return _method @@ -164,12 +165,14 @@ def _rename_if_found( _FULL_NAMES = {v: k for k, v in _ABBREVIATED_NAMES.items()} -def _expand_tracers(inputs: dict) -> dict: - inputs = inputs.copy() - inputs.update(inputs.pop('tracers')) - assert not inputs['diagnostics'] - del inputs['diagnostics'] - return inputs +class _Opaque: + """Wrapper to make Python data opaque to jax.tree. + + This object implements equality and hashing by identity. + """ + + def __init__(self, data): + self.data = data @tree_util.register_pytree_node_class @@ -183,15 +186,13 @@ class PressureLevelModel: hence should remain stable even for future NeuralGCM models. """ - def __init__( - self, - structure: model_builder.WhirlModel, - params: Params, - gin_config: str, - ): - self._structure = structure + def __init__(self, structure: Any, params: Params): + # internal model state + self._structure = _Opaque(structure) self._params = params - self.gin_config = gin_config + + # calculated variables + self._gin_config, self._aux_data = structure self._tracer_variables = [ 'specific_humidity', @@ -210,7 +211,7 @@ def __init__( 'specific_cloud_liquid_water_content', ] for variable in cloud_variables: - if variable in self.gin_config: + if variable in self._gin_config: self._tracer_variables.append(variable) self._input_variables.append(variable) @@ -221,23 +222,59 @@ def __init__( def __repr__(self): return ( - f'{self.__class__.__name__}(structure={self._structure},' + f'{self.__class__.__name__}(structure={self._structure.data},' f' params={self._params})' ) + @functools.cached_property + @_static_gin_config + def _whirl_model(self) -> model_builder.WhirlModel: + """Construct an internal WhirlModle. Not public API.""" + physics_specs = physics_specifications.get_physics_specs() + aux_dataset = xarray.Dataset.from_dict(self._aux_data) + data_coords = model_builder.coordinate_system_from_dataset(aux_dataset) + model_specs = model_builder.get_model_specs( + data_coords, physics_specs, {xarray_utils.XARRAY_DS_KEY: aux_dataset} + ) + whirl_model = model_builder.WhirlModel( + coords=model_specs.coords, + dt=model_specs.dt, + physics_specs=model_specs.physics_specs, + aux_features=model_specs.aux_features, + input_coords=data_coords, + output_coords=data_coords, + ) + return whirl_model + + @property + def structure(self) -> Any: + """The internal configuration structure of this model.""" + return self._structure.data + + @property + def gin_config(self) -> str: + """Gin configuration string used to construct this model.""" + return self._gin_config + + @property + def aux_data(self) -> dict: # pylint: disable=g-bare-generic + """Auxilliary data used to construct this model.""" + return copy.deepcopy(self._aux_data) + @property def params(self) -> Params: + """The learnable parameters of this model.""" return self._params def tree_flatten(self): - leaves, params_def = tree_util.tree_flatten(self.params) - return (leaves, (params_def, self._structure, self.gin_config)) + leaves, params_def = tree_util.tree_flatten(self._params) + return (leaves, (params_def, self._structure)) @classmethod - def tree_unflatten(cls, aux_data, leaves): - params_def, structure, gin_config = aux_data + def tree_unflatten(cls, tree_def, leaves): + params_def, structure = tree_def params = tree_util.tree_unflatten(params_def, leaves) - return cls(structure, params, gin_config) + return cls(structure.data, params) @property def input_variables(self) -> list[str]: @@ -253,29 +290,29 @@ def forcing_variables(self) -> list[str]: def timestep(self) -> np.timedelta64: """Spacing between internal model timesteps.""" to_timedelta = ( - self._structure.specs.physics_specs.dimensionalize_timedelta64 + self._whirl_model.specs.physics_specs.dimensionalize_timedelta64 ) - return to_timedelta(self._structure.specs.dt) + return to_timedelta(self._whirl_model.specs.dt) @property def data_coords(self) -> coordinate_systems.CoordinateSystem: """Coordinate system for input and output data.""" - return self._structure.data_coords + return self._whirl_model.data_coords @property def model_coords(self) -> coordinate_systems.CoordinateSystem: """Coordinate system for internal model state.""" - return self._structure.coords + return self._whirl_model.coords def _check_coords(self, dataset: xarray.Dataset): dataset_coords = model_builder.coordinate_system_from_dataset(dataset) _check_coords(dataset_coords, self.data_coords) def _dataset_with_sim_time(self, dataset: xarray.Dataset) -> xarray.Dataset: - ref_datetime = self._structure.specs.aux_features['reference_datetime'] + ref_datetime = self._whirl_model.specs.aux_features['reference_datetime'] return xarray_utils.ds_with_sim_time( dataset, - self._structure.specs.physics_specs, + self._whirl_model.specs.physics_specs, reference_datetime=ref_datetime, ) @@ -295,31 +332,31 @@ def _from_abbreviated_names_and_tracers(self, outputs: dict) -> dict: def to_nondim_units(self, value: Numeric, units: str) -> Numeric: """Scale a value to the model's internal non-dimensional units.""" - scale_ = self._structure.specs.physics_specs.scale + scale_ = self._whirl_model.specs.physics_specs.scale units_ = scales.parse_units(units) return scale_.nondimensionalize(value * units_) def from_nondim_units(self, value: Numeric, units: str) -> Numeric: """Scale a value from the model's internal non-dimensional units.""" - scale_ = self._structure.specs.physics_specs.scale + scale_ = self._whirl_model.specs.physics_specs.scale units_ = scales.parse_units(units) return scale_.dimensionalize(value, units_).magnitude def datetime64_to_sim_time(self, datetime64: np.ndarray) -> np.ndarray: """Converts a datetime64 array to sim_time.""" - ref_datetime = self._structure.specs.aux_features['reference_datetime'] + ref_datetime = self._whirl_model.specs.aux_features['reference_datetime'] return xarray_utils.datetime64_to_nondim_time( datetime64, - self._structure.specs.physics_specs, + self._whirl_model.specs.physics_specs, reference_datetime=ref_datetime, ) def sim_time_to_datetime64(self, sim_time: np.ndarray) -> np.ndarray: """Converts a sim_time array to datetime64.""" - ref_datetime = self._structure.specs.aux_features['reference_datetime'] + ref_datetime = self._whirl_model.specs.aux_features['reference_datetime'] return xarray_utils.nondim_time_to_datetime64( sim_time, - self._structure.specs.physics_specs, + self._whirl_model.specs.physics_specs, reference_datetime=ref_datetime, ) @@ -364,9 +401,9 @@ def data_to_xarray( Args: data: dict of arrays with shapes matching input/outputs or encoded model - state for this model, i.e., with shape - `([time,] level, longitude, latitude)`, - where `[time,]` indicates an optional leading time dimension. + state for this model, i.e., with shape `([time,] level, longitude, + latitude)`, where `[time,]` indicates an optional leading time + dimension. times: either `None` indicating no leading time dimension on any variables, or a coordinate array of times with shape `(time,)`. decoded: if `True`, use `self.data_coords` to determine the output @@ -419,8 +456,8 @@ def encode( inputs = _prepend_dummy_time_axis(inputs) forcings = self._squeeze_level_from_forcings(forcings) forcings = _prepend_dummy_time_axis(forcings) - f = self._structure.forcing_fn(self.params, None, forcings, sim_time) - return self._structure.encode_fn(self.params, rng_key, inputs, f) + f = self._whirl_model.forcing_fn(self.params, None, forcings, sim_time) + return self._whirl_model.encode_fn(self.params, rng_key, inputs, f) @jax.jit @_static_gin_config @@ -442,8 +479,8 @@ def advance(self, state: State, forcings: Forcings) -> State: sim_time = _sim_time_from_state(state) forcings = self._squeeze_level_from_forcings(forcings) forcings = _prepend_dummy_time_axis(forcings) - f = self._structure.forcing_fn(self.params, None, forcings, sim_time) - state = self._structure.advance_fn(self.params, None, state, f) + f = self._whirl_model.forcing_fn(self.params, None, forcings, sim_time) + state = self._whirl_model.advance_fn(self.params, None, state, f) return state @jax.jit @@ -467,8 +504,8 @@ def decode(self, state: State, forcings: Forcings) -> Outputs: sim_time = _sim_time_from_state(state) forcings = self._squeeze_level_from_forcings(forcings) forcings = _prepend_dummy_time_axis(forcings) - f = self._structure.forcing_fn(self.params, None, forcings, sim_time) - outputs = self._structure.decode_fn(self.params, None, state, f) + f = self._whirl_model.forcing_fn(self.params, None, forcings, sim_time) + outputs = self._whirl_model.decode_fn(self.params, None, state, f) outputs = self._from_abbreviated_names_and_tracers(outputs) return outputs @@ -512,13 +549,13 @@ def unroll( steps: number of time-steps to take. timedelta: size of each time-step to take, which must be a multiple of the internal model timestep. By default uses the internal model timestep. - start_with_input: if ``True``, outputs are at times ``[0, ..., - (steps - 1) * timestep]`` relative to the initial time; if ``False``, - outputs are at times ``[timestep, ..., steps * timestep]``. + start_with_input: if ``True``, outputs are at times ``[0, ..., (steps - 1) + * timestep]`` relative to the initial time; if ``False``, outputs are at + times ``[timestep, ..., steps * timestep]``. post_process_fn: optional function to apply to each advanced state and - current forcings to create outputs like - ``post_process_fn(state, forcings)``, where ``forcings`` does not - include a time axis. By default, uses ``model.decode``. + current forcings to create outputs like ``post_process_fn(state, + forcings)``, where ``forcings`` does not include a time axis. By + default, uses ``model.decode``. Returns: A tuple of the advanced state at time ``steps * timestamp``, and outputs @@ -540,6 +577,7 @@ def wrapped(state): sim_time = _sim_time_from_state(state) forcings = get_nearest_forcings(sim_time) return func(state, forcings) + return wrapped if post_process_fn is None: @@ -557,7 +595,7 @@ def wrapped(state): return state, outputs @classmethod - def from_checkpoint(cls, checkpoint: Any) -> PressureLevelModel: + def from_checkpoint(cls, checkpoint: dict) -> PressureLevelModel: """Creates a PressureLevelModel from a checkpoint. Args: @@ -569,21 +607,8 @@ def from_checkpoint(cls, checkpoint: Any) -> PressureLevelModel: Instance of a `PressureLevelModel` with weights and configuration specified by the checkpoint. """ - with gin_utils.specific_config(checkpoint['model_config_str']): - physics_specs = physics_specifications.get_physics_specs() - aux_ds = xarray.Dataset.from_dict(checkpoint['aux_ds_dict']) - data_coords = model_builder.coordinate_system_from_dataset(aux_ds) - model_specs = model_builder.get_model_specs( - data_coords, physics_specs, {xarray_utils.XARRAY_DS_KEY: aux_ds} - ) - whirl_model = model_builder.WhirlModel( - coords=model_specs.coords, - dt=model_specs.dt, - physics_specs=model_specs.physics_specs, - aux_features=model_specs.aux_features, - input_coords=data_coords, - output_coords=data_coords, - ) - return cls( - whirl_model, checkpoint['params'], checkpoint['model_config_str'] - ) + gin_config = checkpoint['model_config_str'] + aux_data = checkpoint['aux_ds_dict'] + structure = (gin_config, aux_data) + params = checkpoint['params'] + return cls(structure, params) diff --git a/neuralgcm/api_test.py b/neuralgcm/api_test.py index 02eee7b..6b50565 100644 --- a/neuralgcm/api_test.py +++ b/neuralgcm/api_test.py @@ -83,6 +83,23 @@ def test_model_properties(self): self.assertEqual(model.model_coords.nodal_shape, (32, 128, 64)) + def test_pytree(self): + model = load_tl63_stochastic_model() + tree_def = jax.tree.structure(model) + self.assertEqual(tree_def, tree_def) + + trace_count = 0 + + @jax.jit + def f(model): # pylint: disable=unused-argument + nonlocal trace_count + trace_count += 1 + return 0 + + f(model) + f(model) + self.assertEqual(trace_count, 1) + def test_to_and_from_nondim_units(self): model = load_tl63_stochastic_model()