forked from MESMER-group/mesmer
/
train_gt.py
287 lines (223 loc) · 8.81 KB
/
train_gt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
# MESMER, land-climate dynamics group, S.I. Seneviratne
# Copyright (c) 2021 ETH Zurich, MESMER contributors listed in AUTHORS.
# Licensed under the GNU General Public License v3.0 or later see LICENSE or
# https://www.gnu.org/licenses/
"""
Functions to train global trend module of MESMER.
"""
import warnings
import numpy as np
import xarray as xr
from mesmer.io import load_strat_aod
from mesmer.io.save_mesmer_bundle import save_mesmer_data
from mesmer.stats.linear_regression import LinearRegression
from mesmer.stats.smoothing import lowess
def train_gt(var, targ, esm, time, cfg, save_params=True):
    """
    Derive global trend (emissions + volcanoes) parameters from specified ensemble type
    with specified method.

    Parameters
    ----------
    var : dict
        nested global mean variable dictionary with keys for each scenario employed for
        training

        - [scen] (2d array (run, time) of globally-averaged variable time series)
    targ : str
        target variable (e.g., "tas")
    esm : str
        associated Earth System Model (e.g., "CanESM2" or "CanESM5")
    time : dict
        time axis for each scenario, indexed with the same keys as ``var``

        - [scen] (1d array of years)
    cfg : module
        config file containing metadata. ``cfg.gen``, ``cfg.methods[targ]["gt"]`` and
        ``cfg.preds[targ]["gt"]`` are read; ``cfg.dir_mesmer_params`` is read only if
        ``save_params`` is True
    save_params : bool, default True
        determines if parameters are saved or not, default = True

    Returns
    -------
    params_gt : dict
        dictionary containing the trained parameters for the chosen method / ensemble
        type

        - ["targ"] (emulated variable, str)
        - ["esm"] (Earth System Model, str)
        - ["method"] (applied method, str)
        - ["preds"] (predictors, list of strs)
        - ["scenarios"] (emission scenarios used for training, list of strs)
        - ["frac_lowess"] (label describing the LOWESS fraction, str)
        - ["time"] (dict of time axes; key "hist" if historical data is included, plus
          one key per future scenario)
        - ["hist"] (historical global trend; set for the methods "LOWESS" and
          "LOWESS_OLSVOLC" if historical data is included)
        - ["saod"] (stratospheric AOD coefficient; only for "LOWESS_OLSVOLC" with
          historical data)
        - [scen_f] (global trend of each future scenario)

    Raises
    ------
    ValueError
        - if ``var`` does not contain any scenario
        - if the configured method does not contain "LOWESS"
        - if historical data is included (first scenario name starts with "h-") and
          ``cfg.gen`` is neither 5 nor 6

    Notes
    -----
    - Assumptions:
        - All scens start at the same point in time
        - If historical data is present, historical data and future scenarios are
          transmitted as single time series
    - No perfect smoothness enforced at transition from historical to future scenario
    - No perfect overlap between future scenarios which share the same forcing in the
      beginning is enforced
    """

    # specify necessary variables from config file
    gen = cfg.gen
    method_gt = cfg.methods[targ]["gt"]
    preds_gt = cfg.preds[targ]["gt"]

    scenarios_tr = list(var.keys())

    # validate all inputs up front so that misconfiguration fails with a clear message
    # before any (comparatively expensive) smoothing is carried out
    if not scenarios_tr:
        raise ValueError("`var` must contain at least one scenario.")

    if "LOWESS" not in method_gt:
        raise ValueError("No alternative method to LOWESS is implemented for now.")

    # historical data is included if the scenario names carry the "h-" prefix
    hist_included = scenarios_tr[0][:2] == "h-"

    if hist_included:
        # year at which the historical slice ends (inclusive), per value of cfg.gen
        start_years_fut = {5: 2005, 6: 2014}
        if gen not in start_years_fut:
            raise ValueError(
                f"Unsupported `cfg.gen` ({gen!r}): the start of the future period is "
                "only defined for generations 5 and 6."
            )
        start_year_fut = start_years_fut[gen]

    # initialize parameters dictionary and fill in the metadata which does not depend on
    # the applied method
    params_gt = {
        "targ": targ,
        "esm": esm,
        "method": method_gt,
        "preds": preds_gt,
        "scenarios": scenarios_tr,  # single entry in case of ic ensemble
    }

    # apply the chosen method to the type of ensemble,
    # i.e. derive gt for each scen individually
    gt = {}
    for scen in scenarios_tr:
        gt[scen], frac_lowess_name = train_gt_ic_LOWESS(var[scen])
    params_gt["frac_lowess"] = frac_lowess_name

    params_gt["time"] = {}

    if hist_included:
        # all scens are assumed to share the same historical time axis (see Notes);
        # the last scenario is used as the reference
        ref_scen = scenarios_tr[-1]

        # + 1 so that slicing with this index includes the year `start_year_fut`
        idx_start_year_fut = np.where(time[ref_scen] == start_year_fut)[0][0] + 1

        time_hist = time[ref_scen][:idx_start_year_fut]
        params_gt["time"]["hist"] = time_hist

        # compute median LOWESS estimate of historical part across all scenarios
        gt_lowess_hist_all = np.zeros([len(scenarios_tr), len(time_hist)])
        for i, scen in enumerate(scenarios_tr):
            gt_lowess_hist_all[i] = gt[scen][:idx_start_year_fut]
        gt_lowess_hist = np.median(gt_lowess_hist_all, axis=0)

        if method_gt == "LOWESS_OLSVOLC":
            # stack the historical part of the runs of all scenarios
            var_all = np.vstack(
                [var[scen][:, :idx_start_year_fut] for scen in scenarios_tr]
            )

            # check for duplicates & exclude those runs
            var_all = np.unique(var_all, axis=0)

            params_gt["saod"], params_gt["hist"] = train_gt_ic_OLSVOLC(
                var_all, gt_lowess_hist, time_hist
            )
        elif method_gt == "LOWESS":
            params_gt["hist"] = gt_lowess_hist
        # NOTE: any other method name containing "LOWESS" leaves "hist" unset

        # isolate future scen names
        scenarios_tr_f = [scen.replace("h-", "") for scen in scenarios_tr]
    else:
        # because first year would be already in future
        idx_start_year_fut = 0
        # because only future covered anyways
        scenarios_tr_f = scenarios_tr

    for scen_f, scen in zip(scenarios_tr_f, scenarios_tr):
        params_gt["time"][scen_f] = time[scen][idx_start_year_fut:]
        params_gt[scen_f] = gt[scen][idx_start_year_fut:]

    # save the global trend parameters if requested
    if save_params:
        save_mesmer_data(
            params_gt,
            cfg.dir_mesmer_params,
            "global",
            "global_trend",
            filename_parts=(
                "params_gt",
                method_gt,
                *preds_gt,
                targ,
                esm,
                *scenarios_tr,
            ),
        )

    return params_gt
def train_gt_ic_LOWESS(data):
    """
    Smooth the ensemble-mean time series of a single ESM ic ensemble with LOWESS.

    Parameters
    ----------
    data : np.ndarray
        2d array (run, time) of globally-averaged time series

    Returns
    -------
    gt_lowess : np.ndarray
        1d array of smooth global trend of variable
    frac_lowess_name : str
        label describing how the LOWESS fraction was chosen (always "50/nr_ts",
        i.e. 50 divided by the number of time steps)
    """
    time_dim = "time"

    # a first smoothing is obtained by averaging across all runs
    ens_mean = xr.DataArray(data, dims=("ensemble", time_dim)).mean("ensemble")

    # fraction of the data used when estimating each y-value: a rather arbitrarily
    # chosen value that gives a smooth enough trend. Open to changes, but if much
    # smaller, the trend ends up very wiggly
    frac_lowess_name = "50/nr_ts"
    frac = 50 / ens_mean.sizes[time_dim]

    # the LOWESS smoother further smooths the ensemble-mean time series
    smoothed = lowess(ens_mean, dim=time_dim, frac=frac)

    return smoothed.values, frac_lowess_name
def train_gt_ic_OLSVOLC(var, gt_lowess, time, cfg=None):
    """
    Add volcanic spikes to the LOWESS trend of a single ESM ic ensemble.

    The residuals of all runs with respect to the smooth trend are regressed (without
    intercept) on the stratospheric aerosol optical depth (AOD) loaded for ``time``;
    the fitted AOD contribution is then added to the smooth trend.

    Parameters
    ----------
    var : np.ndarray
        2d array (run, time) of globally-averaged time series
    gt_lowess : np.ndarray
        1d array of smooth global trend of variable
    time : np.ndarray
        1d array of years
    cfg : None
        Passing cfg is no longer required; a FutureWarning is emitted if it is given.

    Returns
    -------
    coef_saod : np.ndarray
        values of the fitted stratospheric AOD regression coefficient
    gt : np.ndarray
        1d array of global trend with volcanic spikes

    Raises
    ------
    ValueError
        if the number of time steps of ``var`` differs from the length of the loaded
        AOD series

    Notes
    -----
    - Assumptions:
        - only historical time period data is passed
    """
    if cfg is not None:
        warnings.warn(
            "Passing ``cfg`` to ``train_gt_ic_OLSVOLC`` is no longer necessary",
            FutureWarning,
        )

    n_runs, n_steps = var.shape

    # stratospheric aerosol optical depth, used to account for volcanic eruptions in
    # the historical time period; the "year" coordinate is dropped - presumably so it
    # lines up with the coordinate-less residuals built below
    saod = load_strat_aod(time).drop_vars("year")

    # one copy of the aod time series per available run, chained along "year"
    saod_tiled = xr.concat([saod] * n_runs, dim="year")

    n_saod = saod.shape[0]
    if n_steps != n_saod:
        raise ValueError(
            f"The number of time steps of the variable ({n_steps}) and the saod "
            f"({n_saod}) do not match."
        )

    # global variability (still containing the volcanic signal): each run minus the
    # smooth trend (broadcast over runs), flattened run after run so that it matches
    # the layout of the tiled aod series
    residuals = (var - gt_lowess).ravel()
    residuals = xr.DataArray(residuals, dims="year").expand_dims("x")

    # regress the variability on the aod (some ESMs react very strongly to volcanoes);
    # no intercept is fitted so that the time series is not artificially shifted
    regression = LinearRegression()
    regression.fit(
        predictors={"aod_obs": saod_tiled},
        target=residuals,
        dim="year",
        fit_intercept=False,
    )

    # the saod coefficient
    coef_saod = regression.params["aod_obs"].values

    # volcanic spikes predicted for a single (non-tiled) aod series
    volc_spikes = regression.predict(predictors={"aod_obs": saod})

    # merge the lowess trend with the volcanic contribution
    gt = gt_lowess + volc_spikes.values.squeeze()

    return coef_saod, gt