Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/changes/newsfragments/429.change
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
:func:`.read_transform` has a new parameter ``nan_policy`` to handle NaN values when transforming feature data by `Fede Raimondo`_
1 change: 1 addition & 0 deletions docs/changes/newsfragments/429.enh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Implement a NaN policy for :func:`.read_transform` by `Fede Raimondo`_
74 changes: 70 additions & 4 deletions junifer/onthefly/read_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from typing import Optional

import numpy as np
import pandas as pd

from ..typing import StorageLike
Expand All @@ -19,6 +20,7 @@ def read_transform(
transform: str,
feature_name: Optional[str] = None,
feature_md5: Optional[str] = None,
nan_policy: Optional[str] = "bypass",
transform_args: Optional[tuple] = None,
transform_kw_args: Optional[dict] = None,
) -> pd.DataFrame:
Expand All @@ -35,6 +37,16 @@ def read_transform(
Name of the feature to read (default None).
feature_md5 : str, optional
MD5 hash of the feature to read (default None).
nan_policy : str, optional
The policy to handle NaN values (default "bypass").
Options are:

* "bypass": Do nothing and pass NaN values to the transform function.
* "drop_element": Drop (skip) elements with NaN values.
* "drop_rows": Drop (skip) rows with NaN values.
* "drop_columns": Drop (skip) columns with NaN values.
* "drop_symmetric": Drop (skip) symmetric pairs with NaN values.

transform_args : tuple, optional
The positional arguments for the callable of ``transform``
(default None).
Expand All @@ -47,6 +59,18 @@ def read_transform(
pandas.DataFrame
The transformed feature as a dataframe.

Raises
------
ValueError
If ``nan_policy`` is invalid or
if *package* is invalid.
RuntimeError
If *package* is ``bctpy`` and stored data kind is not ``matrix``.
ImportError
If ``bctpy`` cannot be imported.
AttributeError
If *function* to be invoked is invalid.

Notes
-----
This function has been only tested for:
Expand All @@ -63,6 +87,18 @@ def read_transform(
transform_args = transform_args or ()
transform_kw_args = transform_kw_args or {}

if nan_policy not in [
"bypass",
"drop_element",
"drop_rows",
"drop_columns",
"drop_symmetric",
]:
raise_error(
f"Unknown nan_policy: {nan_policy}",
klass=ValueError,
)

# Read storage
stored_data = storage.read(
feature_name=feature_name, feature_md5=feature_md5
Expand Down Expand Up @@ -106,22 +142,52 @@ def read_transform(
except AttributeError as err:
raise_error(msg=str(err), klass=AttributeError)

# Apply function and store subject-wise
# Apply function and store element-wise
output_list = []
element_list = []
logger.debug(
f"Computing '{package}.{func_str}' for feature "
f"{feature_name or feature_md5} ..."
)
for subject in range(stored_data["data"].shape[2]):
for i_element, element in enumerate(stored_data["element"]):
t_data = stored_data["data"][:, :, i_element]
has_nan = np.isnan(np.min(t_data))
if nan_policy == "drop_element" and has_nan:
logger.debug(
f"Skipping element {element} due to NaN values ..."
)
continue
elif nan_policy == "drop_rows" and has_nan:
logger.debug(
f"Skipping rows with NaN values in element {element} ..."
)
t_data = t_data[~np.isnan(t_data).any(axis=1)]
elif nan_policy == "drop_columns" and has_nan:
logger.debug(
f"Skipping columns with NaN values in element {element} "
"..."
)
t_data = t_data[:, ~np.isnan(t_data).any(axis=0)]
elif nan_policy == "drop_symmetric":
logger.debug(
f"Skipping pairs of rows/columns with NaN values in "
f"element {element}..."
)
good_rows = ~np.isnan(t_data).any(axis=1)
good_columns = ~np.isnan(t_data).any(axis=0)
good_idx = np.logical_and(good_rows, good_columns)
t_data = t_data[good_idx][:, good_idx]

output = func(
stored_data["data"][:, :, subject],
t_data,
*transform_args,
**transform_kw_args,
)
output_list.append(output)
element_list.append(element)

# Create dataframe for index
idx_df = pd.DataFrame(data=stored_data["element"])
idx_df = pd.DataFrame(data=element_list)
# Create multiindex from dataframe
logger.debug(
"Generating pandas.MultiIndex for feature "
Expand Down
84 changes: 84 additions & 0 deletions junifer/onthefly/tests/test_read_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,36 @@ def matrix_storage(tmp_path: Path) -> HDF5FeatureStorage:
return storage


@pytest.fixture
def matrix_storage_with_nan(tmp_path: Path) -> HDF5FeatureStorage:
    """Return a HDF5FeatureStorage with matrix data containing NaN values.

    Four 3x3 matrices are stored, one per element (``test1`` to ``test4``).
    The matrix stored for ``test3`` has NaN at positions ``[1, 1]`` and
    ``[1, 2]``; the other three matrices contain no NaN.

    Parameters
    ----------
    tmp_path : pathlib.Path
        The path to the test directory.

    Returns
    -------
    HDF5FeatureStorage
        The storage backed by ``matrix_store_nan.hdf5`` inside ``tmp_path``.

    """
    n_elements = 4
    storage = HDF5FeatureStorage(tmp_path / "matrix_store_nan.hdf5")

    # Stack of 3x3 matrices, the last axis indexes the element
    matrices = np.arange(36, dtype=float).reshape(3, 3, n_elements)
    # Poison two cells of the same row in the third element only
    matrices[1, 1:3, 2] = np.nan

    for idx in range(n_elements):
        element_meta = {
            "element": {"subject": f"test{idx + 1}"},
            "dependencies": [],
            "marker": {"name": "matrix"},
            "type": "BOLD",
        }
        storage.store(
            kind="matrix",
            meta=element_meta,
            data=matrices[:, :, idx],
            col_names=["f1", "f2", "f3"],
            row_names=["g1", "g2", "g3"],
        )
    return storage


def test_incorrect_package(matrix_storage: HDF5FeatureStorage) -> None:
"""Test error check for incorrect package name.

Expand Down Expand Up @@ -176,3 +206,57 @@ def test_bctpy_function(
)
assert "Computing" in caplog.text
assert "Generating" in caplog.text


@pytest.mark.parametrize(
    "nan_policy, error_msg",
    [
        ("drop_element", None),
        ("drop_rows", "square"),
        ("drop_columns", "square"),
        ("drop_symmetric", None),
        ("bypass", "NaNs"),
        ("wrong", "Unknown"),
    ],
)
def test_bctpy_nans(
    matrix_storage_with_nan: HDF5FeatureStorage,
    caplog: pytest.LogCaptureFixture,
    nan_policy: str,
    error_msg: str,
) -> None:
    """Test NaN policies of read_transform with a bctpy function.

    Each policy is applied to a storage whose data contains NaN values.
    Depending on the policy, the call is expected either to complete and
    log its progress, or to raise a ``ValueError`` whose message matches
    ``error_msg``.

    Parameters
    ----------
    matrix_storage_with_nan : HDF5FeatureStorage
        The HDF5FeatureStorage with matrix data containing NaN values,
        as fixture.
    caplog : pytest.LogCaptureFixture
        The pytest.LogCaptureFixture object.
    nan_policy : str
        The NaN policy to test.
    error_msg : str or None
        The expected error message snippet. If None, no error should be
        raised.

    """
    # Skip test if import fails
    pytest.importorskip("bct")

    # Same call for every case; only the expectation differs
    call_kwargs = {
        "storage": matrix_storage_with_nan,
        "feature_name": "BOLD_matrix",
        "transform": "bctpy_eigenvector_centrality_und",
        "nan_policy": nan_policy,
    }

    with caplog.at_level(logging.DEBUG):
        if error_msg is not None:
            # Failure path: the policy must surface a ValueError
            with pytest.raises(ValueError, match=error_msg):
                read_transform(**call_kwargs)  # type: ignore
        else:
            # Success path: the transform runs and logs both stages
            read_transform(**call_kwargs)  # type: ignore
            assert "Computing" in caplog.text
            assert "Generating" in caplog.text