Skip to content

Commit

Permalink
Introduce param_correlation_sets
Browse files Browse the repository at this point in the history
Fix typing
  • Loading branch information
dafeda committed Oct 17, 2023
1 parent d80f169 commit 270170c
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,32 +478,38 @@ def analysis_ES(
)
c_bool = c_AY > correlation_threshold
# Some parameters might be significantly correlated
# to the exact same responses,
# making up what we call a `parameter group``.
# to the exact same responses.
# We want to call the update only once per such parameter group
# to speed up computation.
param_groups = np.unique(c_bool, axis=0)

# Drop the parameter group that does not correlate to any responses.
row_with_all_false = np.all(~param_groups, axis=1)
param_groups = param_groups[~row_with_all_false]
# Here we create a collection of unique sets of parameter-to-observation
# correlations.
param_correlation_sets: npt.NDArray[np.bool_] = np.unique(
c_bool, axis=0
)
# Drop the correlation set that does not correlate to any responses.
row_with_all_false = np.all(~param_correlation_sets, axis=1)
param_correlation_sets = param_correlation_sets[~row_with_all_false]

for grp in param_groups:
for param_correlation_set in param_correlation_sets:
# Find the rows matching the parameter group
matching_rows = np.all(c_bool == grp, axis=1)
matching_rows = np.all(c_bool == param_correlation_set, axis=1)
# Get the indices of the matching rows
row_indices = np.where(matching_rows)[0]
X_chunk = temp_storage[param_group.name][param_batch_idx, :][
row_indices, :
]
S_chunk = S[grp, :]
observation_errors_loc = observation_errors[grp]
observation_values_loc = observation_values[grp]
S_chunk = S[param_correlation_set, :]
observation_errors_loc = observation_errors[
param_correlation_set
]
observation_values_loc = observation_values[
param_correlation_set
]
smoother.fit(
S_chunk,
observation_errors_loc,
observation_values_loc,
noise=noise[grp],
noise=noise[param_correlation_set],
truncation=truncation,
inversion=ies.InversionType(module.inversion),
param_ensemble=param_ensemble,
Expand Down

0 comments on commit 270170c

Please sign in to comment.