Skip to content

Commit

Permalink
Merge pull request #256 from nlesc-nano/err_func
Browse files Browse the repository at this point in the history
ENH: Add error functions based on non-mse errors
  • Loading branch information
BvB93 committed Oct 21, 2021
2 parents 562ab45 + 6d733a2 commit 3c893ee
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 1 deletion.
4 changes: 3 additions & 1 deletion FOX/armc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
mse_normalized_max,
mse_normalized_v2,
mse_normalized_weighted_v2,
err_normalized,
err_normalized_weighted,
)

__all__ = [
Expand All @@ -27,5 +29,5 @@
'PackageManager', 'PackageManagerABC',
'ParamMapping', 'ParamMappingABC',
'default_error_func', 'mse_normalized', 'mse_normalized_weighted', 'mse_normalized_max',
'mse_normalized_v2', 'mse_normalized_weighted_v2',
'mse_normalized_v2', 'mse_normalized_weighted_v2', 'err_normalized', 'err_normalized_weighted',
]
44 changes: 44 additions & 0 deletions FOX/armc/err_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
.. autofunction:: mse_normalized_max
.. autofunction:: mse_normalized_v2
.. autofunction:: mse_normalized_weighted_v2
.. autofunction:: err_normalized
.. autofunction:: err_normalized_weighted
.. data:: default_error_func
:value: FOX.armc.mse_normalized
Expand All @@ -41,6 +43,8 @@
"default_error_func",
"mse_normalized_v2",
"mse_normalized_weighted_v2",
"err_normalized",
"err_normalized_weighted",
]


Expand Down Expand Up @@ -152,4 +156,44 @@ def mse_normalized_weighted_v2(qm: ArrayLike, mm: ArrayLike) -> f8:
return (err_vec**2).sum() / err_vec.size


def err_normalized(qm: ArrayLike, mm: ArrayLike) -> f8:
"""Return a normalized wrror over the flattened input.
Normalize before taking the exponent - 1 of the error.
"""
qm = np.asarray(qm, dtype=np.float64)
mm = np.asarray(mm, dtype=np.float64)

delta = np.abs(qm - mm)
delta /= np.abs(qm).sum()
return delta.sum()


def err_normalized_weighted(qm: ArrayLike, mm: ArrayLike) -> f8:
"""Return a normalized error over the flattened subarrays of the input.
>1D array-likes are herein treated as stacks of flattened arrays.
"""
if isinstance(qm, pd.DataFrame):
qm = np.asarray(qm, dtype=np.float64).T
else:
qm = np.array(qm, dtype=np.float64, ndmin=1, copy=False)

if isinstance(mm, pd.DataFrame):
mm = np.asarray(mm, dtype=np.float64).T
else:
mm = np.array(mm, dtype=np.float64, ndmin=1, copy=False)

axes_qm_mm = tuple(range(1, max(qm.ndim, mm.ndim)))
axes_qm = tuple(range(1, qm.ndim))
padding_qm = len(axes_qm) * (None,)

delta = np.abs(qm - mm)
delta /= np.abs(qm).sum(axis=axes_qm)[(..., *padding_qm)]
err_vec = delta.sum(axis=axes_qm_mm)
return (err_vec**2).sum() / err_vec.size


default_error_func = mse_normalized

0 comments on commit 3c893ee

Please sign in to comment.