Skip to content

Commit 2e5391d

Browse files
authored
feat: add set_extrapolation_info function in morph.Morph (#255)
* feat: add `checkExtrapolation` function in `morph.Morph` * refactor: rename the function to `set_extrapolation_info`
1 parent ad976f6 commit 2e5391d

File tree

7 files changed

+124
-59
lines changed

7 files changed

+124
-59
lines changed

news/extrap-warnings.rst

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
**Added:**
2+
3+
* Enable ``diffpy.morph`` to detect extrapolation.
4+
5+
**Changed:**
6+
7+
* <news item>
8+
9+
**Deprecated:**
10+
11+
* <news item>
12+
13+
**Removed:**
14+
15+
* <news item>
16+
17+
**Fixed:**
18+
19+
* <news item>
20+
21+
**Security:**
22+
23+
* <news item>

src/diffpy/morph/morph_io.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -410,29 +410,35 @@ def tabulate_results(multiple_morph_results):
410410

411411
def handle_warnings(squeeze_morph):
412412
if squeeze_morph is not None:
413-
eil = squeeze_morph.extrap_index_low
414-
eih = squeeze_morph.extrap_index_high
415-
416-
if eil is not None or eih is not None:
417-
if eih is None:
418-
wmsg = (
419-
"Warning: points with grid value below "
420-
f"{squeeze_morph.squeeze_cutoff_low} "
421-
f"will be extrapolated."
422-
)
423-
elif eil is None:
424-
wmsg = (
425-
"Warning: points with grid value above "
426-
f"{squeeze_morph.squeeze_cutoff_high} "
427-
f"will be extrapolated."
428-
)
429-
else:
430-
wmsg = (
431-
"Warning: points with grid value below "
432-
f"{squeeze_morph.squeeze_cutoff_low} and above "
433-
f"{squeeze_morph.squeeze_cutoff_high} "
434-
f"will be extrapolated."
435-
)
413+
extrapolation_info = squeeze_morph.extrapolation_info
414+
is_extrap_low = extrapolation_info["is_extrap_low"]
415+
is_extrap_high = extrapolation_info["is_extrap_high"]
416+
cutoff_low = extrapolation_info["cutoff_low"]
417+
cutoff_high = extrapolation_info["cutoff_high"]
418+
419+
if is_extrap_low and is_extrap_high:
420+
wmsg = (
421+
"Warning: points with grid value below "
422+
f"{cutoff_low} and above "
423+
f"{cutoff_high} "
424+
f"are extrapolated."
425+
)
426+
elif is_extrap_low:
427+
wmsg = (
428+
"Warning: points with grid value below "
429+
f"{cutoff_low} "
430+
f"are extrapolated."
431+
)
432+
elif is_extrap_high:
433+
wmsg = (
434+
"Warning: points with grid value above "
435+
f"{cutoff_high} "
436+
f"are extrapolated."
437+
)
438+
else:
439+
wmsg = None
440+
441+
if wmsg:
436442
warnings.warn(
437443
wmsg,
438444
UserWarning,

src/diffpy/morph/morphapp.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -610,10 +610,12 @@ def single_morph(
610610
config["smear"] = smear_in
611611
# Shift
612612
# Only enable hshift is squeeze is not enabled
613+
shift_morph = None
613614
if (
614615
opts.hshift is not None and squeeze_poly_deg < 0
615616
) or opts.vshift is not None:
616-
chain.append(morphs.MorphShift())
617+
shift_morph = morphs.MorphShift()
618+
chain.append(shift_morph)
617619
if opts.hshift is not None and squeeze_poly_deg < 0:
618620
hshift_in = opts.hshift
619621
config["hshift"] = hshift_in
@@ -700,6 +702,7 @@ def single_morph(
700702

701703
# THROW ANY WARNINGS HERE
702704
io.handle_warnings(squeeze_morph)
705+
io.handle_warnings(shift_morph)
703706

704707
# Get Rw for the morph range
705708
rw = tools.getRw(chain)

src/diffpy/morph/morphs/morph.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,8 @@
1212
# See LICENSE.txt for license information.
1313
#
1414
##############################################################################
15-
"""Morph -- base class for defining a morph.
16-
"""
17-
15+
"""Morph -- base class for defining a morph."""
16+
import numpy
1817

1918
LABEL_RA = "r (A)" # r-grid
2019
LABEL_GR = "G (1/A^2)" # PDF G(r)
@@ -246,6 +245,36 @@ def plotOutputs(self, xylabels=True, **plotargs):
246245
ylabel(self.youtlabel)
247246
return rv
248247

248+
def set_extrapolation_info(self, x_true, x_extrapolate):
249+
"""Set extrapolation information of the concerned morphing
250+
process.
251+
252+
Parameters
253+
----------
254+
x_true : array
255+
original x values
256+
x_extrapolate : array
257+
x values after a morphing process
258+
"""
259+
260+
cutoff_low = min(x_true)
261+
extrap_low_x = numpy.where(x_extrapolate < cutoff_low)[0]
262+
is_extrap_low = False if len(extrap_low_x) == 0 else True
263+
cutoff_high = max(x_true)
264+
extrap_high_x = numpy.where(x_extrapolate > cutoff_high)[0]
265+
is_extrap_high = False if len(extrap_high_x) == 0 else True
266+
extrap_index_low = extrap_low_x[-1] if is_extrap_low else 0
267+
extrap_index_high = extrap_high_x[0] if is_extrap_high else -1
268+
extrapolation_info = {
269+
"is_extrap_low": is_extrap_low,
270+
"cutoff_low": cutoff_low,
271+
"extrap_index_low": extrap_index_low,
272+
"is_extrap_high": is_extrap_high,
273+
"cutoff_high": cutoff_high,
274+
"extrap_index_high": extrap_index_high,
275+
}
276+
self.extrapolation_info = extrapolation_info
277+
249278
def __getattr__(self, name):
250279
"""Obtain the value from self.config, when normal lookup fails.
251280

src/diffpy/morph/morphs/morphshift.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def morph(self, x_morph, y_morph, x_target, y_target):
5757
r = self.x_morph_in - hshift
5858
self.y_morph_out = numpy.interp(r, self.x_morph_in, self.y_morph_in)
5959
self.y_morph_out += vshift
60+
self.set_extrapolation_info(self.x_morph_in, r)
6061
return self.xyallout
6162

6263

src/diffpy/morph/morphs/morphsqueeze.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Class MorphSqueeze -- Apply a polynomial to squeeze the morph
22
function."""
33

4-
import numpy as np
54
from numpy.polynomial import Polynomial
65
from scipy.interpolate import CubicSpline
76

@@ -83,14 +82,9 @@ def morph(self, x_morph, y_morph, x_target, y_target):
8382
coeffs = [self.squeeze[f"a{i}"] for i in range(len(self.squeeze))]
8483
squeeze_polynomial = Polynomial(coeffs)
8584
x_squeezed = self.x_morph_in + squeeze_polynomial(self.x_morph_in)
86-
self.squeeze_cutoff_low = min(x_squeezed)
87-
self.squeeze_cutoff_high = max(x_squeezed)
8885
self.y_morph_out = CubicSpline(x_squeezed, self.y_morph_in)(
8986
self.x_morph_in
9087
)
91-
low_extrap = np.where(self.x_morph_in < self.squeeze_cutoff_low)[0]
92-
high_extrap = np.where(self.x_morph_in > self.squeeze_cutoff_high)[0]
93-
self.extrap_index_low = low_extrap[-1] if low_extrap.size else None
94-
self.extrap_index_high = high_extrap[0] if high_extrap.size else None
88+
self.set_extrapolation_info(x_squeezed, self.x_morph_in)
9589

9690
return self.xyallout

tests/test_morphsqueeze.py

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -46,47 +46,56 @@
4646
@pytest.mark.parametrize("squeeze_coeffs", squeeze_coeffs_dic)
4747
def test_morphsqueeze(x_morph, x_target, squeeze_coeffs):
4848
y_target = np.sin(x_target)
49+
y_morph = np.sin(x_morph)
50+
# expected output
51+
y_morph_expected = y_morph
52+
x_morph_expected = x_morph
53+
x_target_expected = x_target
54+
y_target_expected = y_target
55+
# actual output
4956
coeffs = [squeeze_coeffs[f"a{i}"] for i in range(len(squeeze_coeffs))]
5057
squeeze_polynomial = Polynomial(coeffs)
5158
x_squeezed = x_morph + squeeze_polynomial(x_morph)
5259
y_morph = np.sin(x_squeezed)
53-
low_extrap = np.where(x_morph < x_squeezed[0])[0]
54-
high_extrap = np.where(x_morph > x_squeezed[-1])[0]
55-
extrap_index_low_expected = low_extrap[-1] if low_extrap.size else None
56-
extrap_index_high_expected = high_extrap[0] if high_extrap.size else None
57-
x_morph_expected = x_morph
58-
y_morph_expected = np.sin(x_morph)
5960
morph = MorphSqueeze()
6061
morph.squeeze = squeeze_coeffs
6162
x_morph_actual, y_morph_actual, x_target_actual, y_target_actual = morph(
6263
x_morph, y_morph, x_target, y_target
6364
)
64-
extrap_index_low = morph.extrap_index_low
65-
extrap_index_high = morph.extrap_index_high
66-
if extrap_index_low is None:
67-
extrap_index_low = 0
68-
elif extrap_index_high is None:
69-
extrap_index_high = -1
65+
66+
extrap_low = np.where(x_morph < min(x_squeezed))[0]
67+
extrap_high = np.where(x_morph > max(x_squeezed))[0]
68+
extrap_index_low_expected = extrap_low[-1] if extrap_low.size else 0
69+
extrap_index_high_expected = extrap_high[0] if extrap_high.size else -1
70+
71+
extrapolation_info = morph.extrapolation_info
72+
extrap_index_low_actual = extrapolation_info["extrap_index_low"]
73+
extrap_index_high_actual = extrapolation_info["extrap_index_high"]
74+
7075
assert np.allclose(
71-
y_morph_actual[extrap_index_low + 1 : extrap_index_high],
72-
y_morph_expected[extrap_index_low + 1 : extrap_index_high],
76+
y_morph_actual[
77+
extrap_index_low_expected + 1 : extrap_index_high_expected
78+
],
79+
y_morph_expected[
80+
extrap_index_low_expected + 1 : extrap_index_high_expected
81+
],
7382
atol=1e-6,
7483
)
7584
assert np.allclose(
76-
y_morph_actual[:extrap_index_low],
77-
y_morph_expected[:extrap_index_low],
85+
y_morph_actual[:extrap_index_low_expected],
86+
y_morph_expected[:extrap_index_low_expected],
7887
atol=1e-3,
7988
)
8089
assert np.allclose(
81-
y_morph_actual[extrap_index_high:],
82-
y_morph_expected[extrap_index_high:],
90+
y_morph_actual[extrap_index_high_expected:],
91+
y_morph_expected[extrap_index_high_expected:],
8392
atol=1e-3,
8493
)
85-
assert morph.extrap_index_low == extrap_index_low_expected
86-
assert morph.extrap_index_high == extrap_index_high_expected
8794
assert np.allclose(x_morph_actual, x_morph_expected)
88-
assert np.allclose(x_target_actual, x_target)
89-
assert np.allclose(y_target_actual, y_target)
95+
assert np.allclose(x_target_actual, x_target_expected)
96+
assert np.allclose(y_target_actual, y_target_expected)
97+
assert extrap_index_low_actual == extrap_index_low_expected
98+
assert extrap_index_high_actual == extrap_index_high_expected
9099

91100

92101
@pytest.mark.parametrize(
@@ -97,23 +106,23 @@ def test_morphsqueeze(x_morph, x_target, squeeze_coeffs):
97106
{"a0": 0.01},
98107
lambda x: (
99108
"Warning: points with grid value below "
100-
f"{x[0]} will be extrapolated."
109+
f"{x[0]} are extrapolated."
101110
),
102111
),
103112
# extrapolate above
104113
(
105114
{"a0": -0.01},
106115
lambda x: (
107116
"Warning: points with grid value above "
108-
f"{x[1]} will be extrapolated."
117+
f"{x[1]} are extrapolated."
109118
),
110119
),
111120
# extrapolate below and above
112121
(
113122
{"a0": 0.01, "a1": -0.002},
114123
lambda x: (
115124
"Warning: points with grid value below "
116-
f"{x[0]} and above {x[1]} will be "
125+
f"{x[0]} and above {x[1]} are "
117126
"extrapolated."
118127
),
119128
),

0 commit comments

Comments
 (0)