Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 644555087
  • Loading branch information
langmore authored and NeuralGCM authors committed Jun 19, 2024
1 parent 301537a commit 3749311
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 17 deletions.
27 changes: 23 additions & 4 deletions neuralgcm/stochastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,12 @@ def __init__(
correlation_length = maybe_nondimensionalize(
correlation_length, physics_specs
)

# In sampling, phi appears as 1 - phi**2 = 1 - exp(-2 dt / tau)
self.one_minus_phi2 = -jnp.expm1(-2 * dt / tau)

self.phi = jnp.exp(-dt / tau)

self._variance = maybe_nondimensionalize(variance, physics_specs) # σ²

# [Palmer] states correlation_length = sqrt(2κT) / R, therefore
Expand Down Expand Up @@ -348,7 +353,9 @@ def _sigma_array(self) -> jax.Array:
# We do not include the extra fator of 2 in the denominator. I do not know
# why [Palmer] has this factor.
normalization = jnp.sqrt(
self._integrated_grf_variance() * (1 - self.phi**2) / sum_unnormed_vars
self._integrated_grf_variance()
* self.one_minus_phi2
/ sum_unnormed_vars
)

# The factor of coords.horizontal.radius appears because our basis vectors
Expand All @@ -373,7 +380,7 @@ def unconditional_sample(self, rng: typing.PRNGKeyArray) -> RandomnessState:
jax.random.truncated_normal(rng, -self.clip, self.clip, modal_shape),
jnp.zeros(modal_shape),
)
core = (1 - self.phi**2) ** (-0.5) * sigmas * weights
core = self.one_minus_phi2 ** (-0.5) * sigmas * weights
return RandomnessState(
core=core,
nodal_value=self.to_nodal_values(core),
Expand Down Expand Up @@ -620,6 +627,7 @@ def __init__(
initial_correlation_lengths: Sequence[Quantity | str] = gin.REQUIRED,
variances: Sequence[Quantity | str] = gin.REQUIRED,
field_subset: Optional[Sequence[int]] = None,
n_fixed_fields: Optional[int] = None,
clip: float = 6.0,
name: Optional[str] = None,
):
Expand All @@ -643,6 +651,10 @@ def __init__(
Specifies which fields to construct. If None, use all fields. E.g.,
field_subset=[0, 5] means form 3 GRFs from the 0th and 5th parameter
values.
n_fixed_fields: Number of fields that use fixed parameters. These will
be fixed at the trailing `n_fixed_fields` initial correlations. The
total number of fields is unchanged, since these fixed fields replace
learnable fields.
clip: number of standard deviations at which to clip randomness to ensure
numerical stability.
name: Name to show in xprof.
Expand All @@ -657,6 +669,7 @@ def __init__(
]
if len(set(lengths)) != 1:
raise ValueError(f'Argument lengths differed: {lengths=}')
n_fixed_fields = n_fixed_fields or 0

# Get subset of args using `field_subset`
if field_subset is not None:
Expand Down Expand Up @@ -687,9 +700,12 @@ def __init__(
])
correlation_lengths_raw = hk.get_parameter(
'correlation_lengths_raw',
shape=(self.n_fields,),
shape=(self.n_fields - n_fixed_fields,),
init=hk.initializers.Constant(0.0),
)
if n_fixed_fields:
correlation_lengths_raw = jnp.concatenate([
correlation_lengths_raw, jnp.zeros([n_fixed_fields])])
self._correlation_lengths = convert_hk_param_to_positive_scalar(
correlation_lengths_raw, initial_correlation_lengths
)
Expand All @@ -699,9 +715,12 @@ def __init__(
)
correlation_times_raw = hk.get_parameter(
'correlation_times_raw',
shape=(self.n_fields,),
shape=(self.n_fields - n_fixed_fields,),
init=hk.initializers.Constant(0.0),
)
if n_fixed_fields:
correlation_times_raw = jnp.concatenate([
correlation_times_raw, jnp.zeros([n_fixed_fields])])
self._correlation_times = convert_hk_param_to_positive_scalar(
correlation_times_raw, initial_correlation_times
)
Expand Down
149 changes: 136 additions & 13 deletions neuralgcm/stochastic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@

tree_map = jax.tree_util.tree_map

# Effectively constant correlation time/length.
CONSTANT_CORRELATION_TIME_HRS = 24 * 365 * 1000 # 1000 years in hours
CONSTANT_CORRELATION_LENGTH_KM = 40_075 * 10 # 10x circumference of earth in km


@absltest.skipThisClass('Base class')
class BaseRandomFieldTest(parameterized.TestCase):
Expand Down Expand Up @@ -461,28 +465,40 @@ def _make_grf(
initial_correlation_lengths,
initial_correlation_times,
field_subset=None,
n_fixed_fields=None,
):
physics_specs = primitive_equations.PrimitiveEquationsSpecs.from_si()
self.dt = physics_specs.nondimensionalize(1 * scales.units.hour)
self.physics_specs = primitive_equations.PrimitiveEquationsSpecs.from_si()
self.dt = self.physics_specs.nondimensionalize(1 * scales.units.hour)

return stochastic.BatchGaussianRandomFieldModule(
self.coords,
self.dt,
physics_specs,
self.physics_specs,
aux_features={},
initial_correlation_times=initial_correlation_times,
initial_correlation_lengths=initial_correlation_lengths,
variances=variances,
field_subset=field_subset,
n_fixed_fields=n_fixed_fields,
)

def nondimensionalize(self, x):
return stochastic.nondimensionalize(x, self.physics_specs)

@parameterized.named_parameters(
dict(
testcase_name='reasonable_corrs',
variances=(1.0, 2.7),
initial_correlation_lengths=(0.15, 0.2),
initial_correlation_times=(1, 2.1),
),
dict(
testcase_name='one_fixed_field',
variances=(1.0, 2.7),
initial_correlation_lengths=(0.15, 0.2),
initial_correlation_times=(1, 2.1),
n_fixed_fields=1,
),
dict(
testcase_name='reasonable_corrs_skip_middle',
# Using NaN for the one that should be skipped as a extra means to
Expand All @@ -499,6 +515,7 @@ def test_stats(
initial_correlation_lengths,
initial_correlation_times,
field_subset=None,
n_fixed_fields=None,
):
unroll_length = 10

Expand All @@ -510,23 +527,25 @@ def make_field_trajectory(key):
initial_correlation_times=initial_correlation_times,
# Do not specify the field names... Let the default naming happen.
field_subset=field_subset,
n_fixed_fields=n_fixed_fields,
)
sample = grf.unconditional_sample(key)
initial_value = grf.unconditional_sample(key)

def step_fn(c, _):
next_c = grf.advance(c)
next_output = next_c.nodal_value
return (next_c, next_output)

_, trajectory = jax.lax.scan(
step_fn, sample, xs=None, length=unroll_length
step_fn, initial_value, xs=None, length=unroll_length
)
return sample, jax.device_get(trajectory)
return initial_value, jax.device_get(trajectory)

n_fixed_fields = n_fixed_fields or 0
n_samples = 2000
rngs = jax.random.split(jax.random.PRNGKey(802701), n_samples)
params = make_field_trajectory.init(rng=rngs[0], key=rngs[0])
sample, trajectory = jax.vmap(
initial_value, trajectory = jax.vmap(
lambda rng: make_field_trajectory.apply(params, rng, rng)
)(rngs)

Expand All @@ -546,15 +565,15 @@ def step_fn(c, _):

self.assertEqual(
(n_samples, n_fields) + self.coords.horizontal.modal_shape,
sample.core.shape,
initial_value.core.shape,
)
self.assertEqual(
(n_samples, n_fields) + self.coords.horizontal.modal_shape,
sample.modal_value.shape,
initial_value.modal_value.shape,
)
self.assertEqual(
(n_samples, n_fields) + self.coords.horizontal.nodal_shape,
sample.nodal_value.shape,
initial_value.nodal_value.shape,
)

# Check stochastic params were initialized with hk parameters.
Expand All @@ -566,9 +585,18 @@ def step_fn(c, _):
['correlation_times_raw', 'correlation_lengths_raw'],
params['batch_gaussian_random_field_module'].keys(),
)
for name in ['correlation_times_raw', 'correlation_lengths_raw']:
self.assertEqual(
(n_fields - n_fixed_fields,),
params['batch_gaussian_random_field_module'][name].shape,
)

# Core should be modal
tree_map(np.testing.assert_array_equal, sample.core, sample.modal_value)
tree_map(
np.testing.assert_array_equal,
initial_value.core,
initial_value.modal_value,
)

# Nodal values should have the right statistics.
for i, (variance, correlation_length) in enumerate(
Expand All @@ -578,7 +606,7 @@ def step_fn(c, _):
strict=True,
)
):
for x in [sample.nodal_value, final_nodal_value]:
for x in [initial_value.nodal_value, final_nodal_value]:
self.check_mean(
x[:, i],
self.coords,
Expand Down Expand Up @@ -615,9 +643,104 @@ def step_fn(c, _):
# Initial and final sample should be independent as well, since we unroll
# for much longer than the correlation time.
self.check_independent(
sample.nodal_value[:, 0, 50:55, 60], final_nodal_value[:, 0, 50:55, 60]
initial_value.nodal_value[:, 0, 50:55, 60],
final_nodal_value[:, 0, 50:55, 60],
)

def test_giant_correlations_give_constant_fields(self):
unroll_length = 10

initial_correlation_lengths = [
# Include a moderate correlation batch member just to check that extreme
# correlations in other batch members don't mess this up.
0.2,
# Include the default "CONSTANT" correlations
f'{CONSTANT_CORRELATION_LENGTH_KM} km',
# Include a much larger correlation, to check for numerical stability.
f'{1000 * CONSTANT_CORRELATION_LENGTH_KM} km',
]
initial_correlation_times = [
1,
f'{CONSTANT_CORRELATION_TIME_HRS} hours',
f'{1000 * CONSTANT_CORRELATION_TIME_HRS} hours',
]
variances = [1., 1., 1.]

@hk.transform
def make_field_trajectory(key):
grf = self._make_grf(
variances=variances,
initial_correlation_lengths=initial_correlation_lengths,
initial_correlation_times=initial_correlation_times,
# Do not specify the field names... Let the default naming happen.
)
initial_value = grf.unconditional_sample(key)

def step_fn(c, _):
next_c = grf.advance(c)
next_output = next_c.nodal_value
return (next_c, next_output)

_, trajectory = jax.lax.scan(
step_fn, initial_value, xs=None, length=unroll_length
)
return initial_value, jax.device_get(trajectory)

n_samples = 100
rngs = jax.random.split(jax.random.PRNGKey(802701), n_samples)
params = make_field_trajectory.init(rng=rngs[0], key=rngs[0])
initial_value, trajectory = jax.vmap(
lambda rng: make_field_trajectory.apply(params, rng, rng)
)(rngs)
final_nodal_value = trajectory[:, -1]
initial_nodal_value = initial_value.nodal_value

self.assertTrue(np.all(np.isfinite(initial_nodal_value)))
self.assertTrue(np.all(np.isfinite(final_nodal_value)))

# All fields have the correct mean and variance.
for i in range(3):
for x in [initial_value.nodal_value, final_nodal_value]:
self.check_mean(
x[:, i],
self.coords,
expected_mean=0.0,
variance=variances[i],
correlation_length=self.nondimensionalize(
initial_correlation_lengths[i]),
mean_tol_in_standard_errs=5,
)
self.check_variance(
x[:, i],
self.coords,
correlation_length=self.nondimensionalize(
initial_correlation_lengths[i]
),
expected_variance=variances[i],
var_tol_in_standard_errs=5,
)

# Field 0 (moderate correlation) has correct correlation length.
for x in [initial_value.nodal_value, final_nodal_value]:
self.check_correlation_length(
x[:, 0],
expected_correlation_length=initial_correlation_lengths[0],
coords=self.coords,
)

# The variation in index 0 (moderate correlation) is much larger than that
# in index 1 (CONSTANT_CORRELATION_*) or 2.
# This checks the correlation length/time of fields 1, 2 is huge.
diff = final_nodal_value - initial_nodal_value
for i in [1, 2]:
self.assertGreater( # Variation in time
np.max(np.abs(diff[:, 0])), 100 * np.max(np.abs(diff[:, i]))
)
self.assertGreater( # Variation in lat/lon
np.std(initial_nodal_value[:, 0], axis=(-1, -2)).max(),
10000 * np.std(initial_nodal_value[:, i], axis=(-1, -2)).max(),
)


class DictOfGaussianRandomFieldModulesTest(BaseRandomFieldTest):

Expand Down

0 comments on commit 3749311

Please sign in to comment.