Skip to content

Commit

Permalink
Merge pull request #255 from nlesc-nano/err_func
Browse files Browse the repository at this point in the history
BUG: Fixed an issue wherein dataframes were summed along the wrong axis
  • Loading branch information
BvB93 committed Oct 21, 2021
2 parents 2942072 + 88c2313 commit 562ab45
Showing 1 changed file with 29 additions and 7 deletions.
36 changes: 29 additions & 7 deletions FOX/armc/err_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

from typing import TYPE_CHECKING, overload
import numpy as np
import pandas as pd

if TYPE_CHECKING:
from numpy.typing import ArrayLike, NDArray
Expand Down Expand Up @@ -67,8 +68,15 @@ def mse_normalized_weighted(qm: ArrayLike, mm: ArrayLike) -> f8:
>1D array-likes are herein treated as stacks of flattened arrays.
"""
qm = np.array(qm, dtype=np.float64, ndmin=1, copy=False)
mm = np.array(mm, dtype=np.float64, ndmin=1, copy=False)
if isinstance(qm, pd.DataFrame):
qm = np.array(qm, dtype=np.float64, ndmin=1, copy=False).T
else:
qm = np.array(qm, dtype=np.float64, ndmin=1, copy=False)

if isinstance(mm, pd.DataFrame):
mm = np.array(mm, dtype=np.float64, ndmin=1, copy=False).T
else:
mm = np.array(mm, dtype=np.float64, ndmin=1, copy=False)

axes_qm = tuple(range(1, qm.ndim))
axes_qm_mm = tuple(range(1, max(qm.ndim, mm.ndim)))
Expand All @@ -84,8 +92,15 @@ def mse_normalized_max(qm: ArrayLike, mm: ArrayLike) -> f8:
>1D array-likes are herein treated as stacks of flattened arrays.
""" # noqa: E501
qm = np.array(qm, dtype=np.float64, ndmin=1, copy=False)
mm = np.array(mm, dtype=np.float64, ndmin=1, copy=False)
if isinstance(qm, pd.DataFrame):
qm = np.array(qm, dtype=np.float64, ndmin=1, copy=False).T
else:
qm = np.array(qm, dtype=np.float64, ndmin=1, copy=False)

if isinstance(mm, pd.DataFrame):
mm = np.array(mm, dtype=np.float64, ndmin=1, copy=False).T
else:
mm = np.array(mm, dtype=np.float64, ndmin=1, copy=False)

axes_qm = tuple(range(1, qm.ndim))
axes_qm_mm = tuple(range(1, max(qm.ndim, mm.ndim)))
Expand All @@ -100,7 +115,7 @@ def mse_normalized_v2(qm: ArrayLike, mm: ArrayLike) -> f8:
Normalize before squaring the error.
""" # noqa: E501
"""
qm = np.asarray(qm, dtype=np.float64)
mm = np.asarray(mm, dtype=np.float64)

Expand All @@ -117,8 +132,15 @@ def mse_normalized_weighted_v2(qm: ArrayLike, mm: ArrayLike) -> f8:
Normalize before squaring the error.
"""
qm = np.array(qm, dtype=np.float64, ndmin=1, copy=False)
mm = np.array(mm, dtype=np.float64, ndmin=1, copy=False)
if isinstance(qm, pd.DataFrame):
qm = np.array(qm, dtype=np.float64, ndmin=1, copy=False).T
else:
qm = np.array(qm, dtype=np.float64, ndmin=1, copy=False)

if isinstance(mm, pd.DataFrame):
mm = np.array(mm, dtype=np.float64, ndmin=1, copy=False).T
else:
mm = np.array(mm, dtype=np.float64, ndmin=1, copy=False)

axes_qm_mm = tuple(range(1, max(qm.ndim, mm.ndim)))
axes_qm = tuple(range(1, qm.ndim))
Expand Down

0 comments on commit 562ab45

Please sign in to comment.