-
Notifications
You must be signed in to change notification settings - Fork 842
/
func_utils.py
386 lines (303 loc) · 15 KB
/
func_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
# coding: utf-8
# Copyright (c) Pymatgen Development Team.
# Distributed under the terms of the MIT License.
"""
Utility functions and classes shared by the chemenv package.

The ratio-function classes below select, by name, one of the step or decay
helpers imported from ``pymatgen.analysis.chemenv.utils.math_utils`` and
evaluate it with user-supplied parameters.
"""
__date__ = "Feb 20, 2016"
__email__ = "david.waroquiers@gmail.com"
__maintainer__ = "David Waroquiers"
__version__ = "2.0"
__credits__ = "Geoffroy Hautier"
__copyright__ = "Copyright 2012, The Materials Project"
__author__ = "David Waroquiers"
from typing import Dict
import numpy as np
from pymatgen.analysis.chemenv.utils.math_utils import (
power2_decreasing_exp,
power2_inverse_decreasing,
power2_inverse_power2_decreasing,
smootherstep,
smoothstep,
)
class AbstractRatioFunction:
    """
    Abstract class for all ratio functions.

    Concrete subclasses declare ``ALLOWED_FUNCTIONS``, a mapping from the name of one
    of their methods to the list of option names that method requires. The method
    selected at construction is bound to ``self.eval`` and every option is stored as
    an instance attribute of the same name.
    """

    ALLOWED_FUNCTIONS = {}  # type: Dict[str, list]

    def __init__(self, function, options_dict=None):
        """Constructor for AbstractRatioFunction.

        :param function: Ratio function name. Must be a key of ALLOWED_FUNCTIONS and the
            name of a method of this class.
        :param options_dict: Dictionary containing the parameters for the ratio function.
        :raise ValueError: If the function is not allowed, or if its options are missing
            or not allowed.
        """
        if function not in self.ALLOWED_FUNCTIONS:
            raise ValueError(
                'Function "{}" is not allowed in RatioFunction of '
                'type "{}"'.format(function, self.__class__.__name__)
            )
        # Bind the method named by "function" so that evaluate() dispatches to it.
        self.eval = object.__getattribute__(self, function)
        self.function = function
        self.setup_parameters(options_dict=options_dict)

    @staticmethod
    def _quoted_enumeration(names):
        """Format names as '"a"', '"a" and "b"' or '"a", "b" and "c"' (used in error messages)."""
        quoted = ['"{}"'.format(name) for name in names]
        if len(quoted) == 1:
            return quoted[0]
        return "{} and {}".format(", ".join(quoted[:-1]), quoted[-1])

    def setup_parameters(self, options_dict):
        """Set up the parameters for this ratio function.

        Every option required by ALLOWED_FUNCTIONS[self.function] must be present in
        options_dict; each provided option is then stored as an instance attribute.
        Functions that declare no options ignore options_dict entirely.

        :param options_dict: Dictionary containing the parameters for the ratio function.
        :return: None.
        :raise ValueError: If a required option is missing, or if an option that is not
            allowed for this function is provided.
        """
        function_options = self.ALLOWED_FUNCTIONS[self.function]
        if len(function_options) == 0:
            # NOTE: options are neither validated nor stored for option-less functions.
            return
        # Check that every required option is present.
        missing_options = options_dict is None or any(op not in options_dict for op in function_options)
        if missing_options:
            label = "Option" if len(function_options) == 1 else "Options"
            opts = "{} {}".format(label, self._quoted_enumeration(function_options))
            if not options_dict:
                missing = "no option was provided."
            else:
                given = list(options_dict.keys())
                verb = "was" if len(given) == 1 else "were"
                missing = "only {} {} provided.".format(self._quoted_enumeration(given), verb)
            raise ValueError(
                '{} should be provided for function "{}" in RatioFunction of '
                'type "{}" while {}'.format(opts, self.function, self.__class__.__name__, missing)
            )
        # Store the options, rejecting any that this function does not accept.
        for key, val in options_dict.items():
            if key not in function_options:
                raise ValueError(
                    'Option "{}" not allowed for function "{}" in RatioFunction of '
                    'type "{}"'.format(key, self.function, self.__class__.__name__)
                )
            setattr(self, key, val)

    def evaluate(self, value):
        """Evaluate the ratio function for the given value.

        :param value: Value for which ratio function has to be evaluated.
        :return: Ratio function corresponding to the value.
        """
        return self.eval(value)

    @classmethod
    def from_dict(cls, dd):
        """Construct ratio function from dict.

        :param dd: Dict representation of the ratio function, with a "function" key and
            an "options" key. If "options" is absent, no options are passed, so a
            function that requires options raises a descriptive ValueError instead of a
            bare KeyError.
        :return: Ratio function object.
        """
        return cls(function=dd["function"], options_dict=dd.get("options"))
class RatioFunction(AbstractRatioFunction):
    """Concrete ratio functions parametrized by generic option names.

    The step-like functions rescale their input between the "lower" and "upper"
    options, while the decreasing functions rescale it between 0.0 and the "max"
    option ("power2_decreasing_exp" additionally needs "alpha").
    """

    ALLOWED_FUNCTIONS = {
        "power2_decreasing_exp": ["max", "alpha"],
        "smoothstep": ["lower", "upper"],
        "smootherstep": ["lower", "upper"],
        "inverse_smoothstep": ["lower", "upper"],
        "inverse_smootherstep": ["lower", "upper"],
        "power2_inverse_decreasing": ["max"],
        "power2_inverse_power2_decreasing": ["max"],
    }

    def _step_edges(self):
        """Return the [lower, upper] edges shared by the step-like functions."""
        opts = self.__dict__
        return [opts["lower"], opts["upper"]]

    def _max_edges(self):
        """Return the [0.0, max] edges shared by the decreasing functions."""
        return [0.0, self.__dict__["max"]]

    def power2_decreasing_exp(self, vals):
        """Evaluate f(x)=exp(-a*x)*(x-1)^2 on the given values.

        Values ("x") are scaled with respect to the "max" option; "a" is the "alpha" option.

        :param vals: Values at which the ratio function is evaluated.
        :return: Ratio function values.
        """
        return power2_decreasing_exp(vals, edges=self._max_edges(), alpha=self.__dict__["alpha"])

    def smootherstep(self, vals):
        """Evaluate the smootherstep function f(x)=6*x^5-15*x^4+10*x^3.

        Values ("x") are scaled between the "lower" and "upper" options.

        :param vals: Values at which the ratio function is evaluated.
        :return: Ratio function values.
        """
        return smootherstep(vals, edges=self._step_edges())

    def smoothstep(self, vals):
        """Evaluate the smoothstep function f(x)=3*x^2-2*x^3.

        Values ("x") are scaled between the "lower" and "upper" options.

        :param vals: Values at which the ratio function is evaluated.
        :return: Ratio function values.
        """
        return smoothstep(vals, edges=self._step_edges())

    def inverse_smootherstep(self, vals):
        """Evaluate the inverse smootherstep function f(x)=1-(6*x^5-15*x^4+10*x^3).

        Values ("x") are scaled between the "lower" and "upper" options.

        :param vals: Values at which the ratio function is evaluated.
        :return: Ratio function values.
        """
        return smootherstep(vals, edges=self._step_edges(), inverse=True)

    def inverse_smoothstep(self, vals):
        """Evaluate the inverse smoothstep function f(x)=1-(3*x^2-2*x^3).

        Values ("x") are scaled between the "lower" and "upper" options.

        :param vals: Values at which the ratio function is evaluated.
        :return: Ratio function values.
        """
        return smoothstep(vals, edges=self._step_edges(), inverse=True)

    def power2_inverse_decreasing(self, vals):
        """Evaluate f(x)=(x-1)^2 / x on the given values.

        Values ("x") are scaled with respect to the "max" option.

        :param vals: Values at which the ratio function is evaluated.
        :return: Ratio function values.
        """
        return power2_inverse_decreasing(vals, edges=self._max_edges())

    def power2_inverse_power2_decreasing(self, vals):
        """Evaluate f(x)=(x-1)^2 / x^2 on the given values.

        Values ("x") are scaled with respect to the "max" option.

        :param vals: Values at which the ratio function is evaluated.
        :return: Ratio function values.
        """
        return power2_inverse_power2_decreasing(vals, edges=self._max_edges())
class CSMFiniteRatioFunction(AbstractRatioFunction):
    """Concrete implementation of a series of ratio functions applied to the continuous symmetry measure (CSM).

    Uses "finite" ratio functions.

    See the following reference for details:
    ChemEnv: a fast and robust coordination environment identification tool,
    D. Waroquiers et al., Acta Cryst. B 76, 683 (2020).
    """

    ALLOWED_FUNCTIONS = {
        "power2_decreasing_exp": ["max_csm", "alpha"],
        "smoothstep": ["lower_csm", "upper_csm"],
        "smootherstep": ["lower_csm", "upper_csm"],
    }

    def power2_decreasing_exp(self, vals):
        """Get the evaluation of the ratio function f(x)=exp(-a*x)*(x-1)^2.

        The CSM values (i.e. "x"), are scaled to the "max_csm" parameter. The "a" constant
        corresponds to the "alpha" parameter.

        :param vals: CSM values for which the ratio function has to be evaluated.
        :return: Result of the ratio function applied to the CSM values.
        """
        return power2_decreasing_exp(vals, edges=[0.0, self.__dict__["max_csm"]], alpha=self.__dict__["alpha"])

    def smootherstep(self, vals):
        """Get the evaluation of the inverse smootherstep ratio function: f(x)=1-(6*x^5-15*x^4+10*x^3).

        The CSM values (i.e. "x"), are scaled between the "lower_csm" and "upper_csm" parameters.
        The step is evaluated in its inverse form (inverse=True).

        :param vals: CSM values for which the ratio function has to be evaluated.
        :return: Result of the ratio function applied to the CSM values.
        """
        return smootherstep(
            vals,
            edges=[self.__dict__["lower_csm"], self.__dict__["upper_csm"]],
            inverse=True,
        )

    def smoothstep(self, vals):
        """Get the evaluation of the inverse smoothstep ratio function: f(x)=1-(3*x^2-2*x^3).

        The CSM values (i.e. "x"), are scaled between the "lower_csm" and "upper_csm" parameters.
        The step is evaluated in its inverse form (inverse=True), like smootherstep above.

        :param vals: CSM values for which the ratio function has to be evaluated.
        :return: Result of the ratio function applied to the CSM values.
        """
        # Previously this called smootherstep(), which made "smoothstep" an alias of
        # "smootherstep" and contradicted the documented formula.
        return smoothstep(
            vals,
            edges=[self.__dict__["lower_csm"], self.__dict__["upper_csm"]],
            inverse=True,
        )

    def fractions(self, data):
        """Get the fractions from the CSM ratio function applied to the data.

        :param data: List of CSM values to estimate fractions.
        :return: List of fractions (ratio values normalized by their sum), one per CSM, or
            None if data is empty or the ratio values do not sum to a positive number.
        """
        if len(data) == 0:
            return None
        # Evaluate the ratio function once per CSM and reuse the values for normalization.
        values = [self.eval(dd) for dd in data]
        total = np.sum(values)
        if total > 0.0:
            return [value / total for value in values]
        return None

    def mean_estimator(self, data):
        """Get the weighted CSM using this CSM ratio function applied to the data.

        :param data: List of CSM values to estimate the weighted CSM.
        :return: Weighted CSM from this ratio function: None for empty data (or when no
            fractions can be computed), the single value when data has one element.
        """
        if len(data) == 0:
            return None
        if len(data) == 1:
            return data[0]
        fractions = self.fractions(data)
        if fractions is None:
            return None
        return np.sum(np.array(fractions) * np.array(data))

    ratios = fractions
class CSMInfiniteRatioFunction(AbstractRatioFunction):
    """Concrete implementation of a series of ratio functions applied to the continuous symmetry measure (CSM).

    Uses "infinite" ratio functions.

    See the following reference for details:
    ChemEnv: a fast and robust coordination environment identification tool,
    D. Waroquiers et al., Acta Cryst. B 76, 683 (2020).
    """

    ALLOWED_FUNCTIONS = {
        "power2_inverse_decreasing": ["max_csm"],
        "power2_inverse_power2_decreasing": ["max_csm"],
    }

    def power2_inverse_decreasing(self, vals):
        """Get the evaluation of the ratio function f(x)=(x-1)^2 / x.

        The CSM values (i.e. "x"), are scaled to the "max_csm" parameter.

        :param vals: CSM values for which the ratio function has to be evaluated.
        :return: Result of the ratio function applied to the CSM values.
        """
        return power2_inverse_decreasing(vals, edges=[0.0, self.__dict__["max_csm"]])

    def power2_inverse_power2_decreasing(self, vals):
        """Get the evaluation of the ratio function f(x)=(x-1)^2 / x^2.

        The CSM values (i.e. "x"), are scaled to the "max_csm" parameter.

        :param vals: CSM values for which the ratio function has to be evaluated.
        :return: Result of the ratio function applied to the CSM values.
        """
        return power2_inverse_power2_decreasing(vals, edges=[0.0, self.__dict__["max_csm"]])

    def fractions(self, data):
        """Get the fractions from the CSM ratio function applied to the data.

        If exactly one CSM value is (numerically) zero, it receives the whole weight;
        otherwise the ratio function values are normalized by their sum.

        :param data: List of CSM values to estimate fractions.
        :return: Corresponding fractions for each CSM: a list when exactly one CSM is
            zero, otherwise a numpy array; None when data is empty or the ratio values
            do not sum to a positive number.
        :raise RuntimeError: If more than one CSM value is zero.
        """
        if len(data) == 0:
            return None
        # Both formulas above divide by x, so a zero CSM is presumably handled explicitly
        # here to avoid an infinite ratio. Values within 1e-10 of 0.0 count as zero.
        close_to_zero = np.isclose(data, 0.0, atol=1e-10).tolist()
        nzeros = close_to_zero.count(True)
        if nzeros == 1:
            weights = [0.0] * len(data)
            weights[close_to_zero.index(True)] = 1.0
            return weights
        if nzeros > 1:
            raise RuntimeError("Should not have more than one continuous symmetry measure with value equal to 0.0")
        weights = self.eval(np.array(data))
        total = np.sum(weights)
        if total > 0.0:
            return weights / total
        return None

    def mean_estimator(self, data):
        """Get the weighted CSM using this CSM ratio function applied to the data.

        :param data: List of CSM values to estimate the weighted CSM.
        :return: Weighted CSM from this ratio function: None for empty data (or when no
            fractions can be computed), the single value when data has one element.
        :raise RuntimeError: If more than one CSM value is zero (raised by fractions).
        """
        if len(data) == 0:
            return None
        if len(data) == 1:
            return data[0]
        fractions = self.fractions(data)
        if fractions is None:
            return None
        return np.sum(np.array(fractions) * np.array(data))

    ratios = fractions
class DeltaCSMRatioFunction(AbstractRatioFunction):
    """
    Ratio functions applied to differences of continuous symmetry measures (DeltaCSM).

    Uses "finite" ratio functions. Only "smootherstep" is available, parametrized by
    the "delta_csm_min" and "delta_csm_max" options.

    See the following reference for details:
    ChemEnv: a fast and robust coordination environment identification tool,
    D. Waroquiers et al., Acta Cryst. B 76, 683 (2020).
    """

    ALLOWED_FUNCTIONS = {"smootherstep": ["delta_csm_min", "delta_csm_max"]}

    def smootherstep(self, vals):
        """Evaluate the smootherstep function f(x)=6*x^5-15*x^4+10*x^3 on DeltaCSM values.

        The DeltaCSM values ("x") are scaled between the "delta_csm_min" and
        "delta_csm_max" options.

        :param vals: DeltaCSM values at which the ratio function is evaluated.
        :return: Ratio function values for the DeltaCSM values.
        """
        params = self.__dict__
        lower_edge = params["delta_csm_min"]
        upper_edge = params["delta_csm_max"]
        return smootherstep(vals, edges=[lower_edge, upper_edge])