Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Remi-Gau committed Sep 1, 2023
1 parent 4a23b6b commit a6b848a
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 41 deletions.
133 changes: 94 additions & 39 deletions nilearn/glm/first_level/experimental_paradigm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
import pandas as pd
from pandas.api.types import is_numeric_dtype

VALID_FIELDS = {"onset", "duration", "trial_type", "modulation"}


def check_events(events):
"""Test that the events data describes a valid experimental paradigm.
It is valid if the events data has an 'onset' key.
It is valid if the events data has ``'onset'`` and ``'duration'`` keys
with numeric non NaN values.
This function also handles duplicate events
by summing their modulation if they have one.
Parameters
----------
Expand All @@ -47,6 +49,34 @@ def check_events(events):
Per-event modulation,
defaults to ones(n_events) when no modulation is provided.
Raises
------
TypeError
If the events data is not a pandas DataFrame.
ValueError
If the events data has:
- no ``'onset'`` or ``'duration'`` column,
- non-numeric values
in the ``'onset'`` or ``'duration'`` columns,
- NaN values in the ``'onset'`` or ``'duration'`` columns.
Warns
-----
UserWarning
If the events data:
- has no ``'trial_type'`` column,
- has any event with a duration equal to 0,
- contains columns other than ``'onset'``, ``'duration'``,
``'trial_type'`` or ``'modulation'``,
- contains duplicated events, meaning events with the same:
- ``'trial_type'``
- ``'onset'``
- ``'duration'``
"""
# Check that events is a Pandas DataFrame
if not isinstance(events, pd.DataFrame):
Expand All @@ -64,79 +94,104 @@ def check_events(events):
raise ValueError(
f"The following column must not contain nan values: {col_name}"
)
# Make sure we have a numeric type for duration
if not is_numeric_dtype(events[col_name]):
try:
events = events.astype({col_name: float})
except ValueError as e:
raise ValueError(
f"Could not cast {col_name} to float in events data."
) from e

# Make a copy of the dataframe
events_copy = events.copy()

# Handle missing trial types
if "trial_type" not in events_copy.columns:
events_copy = _handle_missing_trial_types(events_copy)

_check_null_duration(events_copy)

_check_unexpected_columns(events_copy)

events_copy = _handle_modulation(events_copy)

cleaned_events = _handle_duplicate_events(events_copy)

trial_type = cleaned_events["trial_type"].values
onset = cleaned_events["onset"].values
duration = cleaned_events["duration"].values
modulation = cleaned_events["modulation"].values
return trial_type, onset, duration, modulation


def _handle_missing_trial_types(events):
if "trial_type" not in events.columns:
warnings.warn(
"'trial_type' column not found in the given events data."
)
events_copy["trial_type"] = "dummy"
events["trial_type"] = "dummy"
return events

conditions_with_null_duration = events_copy["trial_type"][
events_copy["duration"] == 0

def _check_null_duration(events):
conditions_with_null_duration = events["trial_type"][
events["duration"] == 0
].unique()
if len(conditions_with_null_duration) > 0:
warnings.warn(
"The following conditions contain events with null duration:\n"
f"{', '.join(conditions_with_null_duration)}."
)

# Handle modulation
if "modulation" in events_copy.columns:

def _handle_modulation(events):
if "modulation" in events.columns:
print(
"A 'modulation' column was found in "
"the given events data and is used."
)
else:
events_copy["modulation"] = 1
events["modulation"] = 1
return events


VALID_FIELDS = {"onset", "duration", "trial_type", "modulation"}


def _check_unexpected_columns(events):
    """Warn about columns of the events data that will not be used.

    Parameters
    ----------
    events : pandas.DataFrame
        Events data.

    Warns
    -----
    UserWarning
        If ``events`` contains columns that are not in ``VALID_FIELDS``.
        A single warning lists all of them.

    """
    # Follow the order of the columns in the dataframe so that the message
    # is deterministic, and cast the labels to str so that non-string
    # column names cannot break the join below.
    unexpected_columns = [
        str(column)
        for column in events.columns
        if column not in VALID_FIELDS
    ]
    if unexpected_columns:
        warnings.warn(
            "The following unexpected columns "
            "in events data will be ignored: "
            f"{', '.join(unexpected_columns)}"
        )

# Make sure we have a numeric type for duration
if not is_numeric_dtype(events_copy["duration"]):
try:
events_copy = events_copy.astype({"duration": float})
except ValueError:
raise ValueError(
"Could not cast duration to float in events data."
)

# Handle duplicate events
# Two events are duplicates if they have the same:
# - trial type
# - onset
COLUMN_DEFINING_EVENT_IDENTITY = ["trial_type", "onset", "duration"]
# Two events are duplicates if they have the same:
# - trial type
# - onset
# - duration
COLUMN_DEFINING_EVENT_IDENTITY = ["trial_type", "onset", "duration"]

# Duplicate handling strategy
# Sum the modulation values of duplicate events
STRATEGY = {"modulation": "sum"}
# Duplicate handling strategy
# Sum the modulation values of duplicate events
STRATEGY = {"modulation": "sum"}


def _handle_duplicate_events(events):
    """Merge duplicated events and warn when any were found.

    Events are grouped on the columns listed in
    ``COLUMN_DEFINING_EVENT_IDENTITY`` and aggregated following ``STRATEGY``.

    Parameters
    ----------
    events : pandas.DataFrame
        Events data with the identity columns and the columns
        named in ``STRATEGY``.

    Returns
    -------
    pandas.DataFrame
        One row per group of identical events, in order of first appearance
        (``sort=False``), with the identity columns restored as regular
        columns alongside the aggregated ones.

    Warns
    -----
    UserWarning
        If the number of rows changed, meaning duplicates were merged.

    """
    cleaned_events = (
        events.groupby(COLUMN_DEFINING_EVENT_IDENTITY, sort=False)
        .agg(STRATEGY)
        .reset_index()
    )

    # If there are duplicates, give a warning
    if len(cleaned_events) != len(events):
        warnings.warn(
            "Duplicated events were detected. "
            "Amplitudes of these events will be summed. "
            "You might want to verify your inputs."
        )

    return cleaned_events
7 changes: 6 additions & 1 deletion nilearn/glm/tests/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,11 @@ def duplicate_events_paradigm():
onsets = [10, 30, 70, 70, 10, 30]
durations = [1.0, 1.0, 1.0, 1.0, 1.0, 1]
events = pd.DataFrame(
{"trial_type": conditions, "onset": onsets, "duration": durations}
{
"trial_type": conditions,
"onset": onsets,
"duration": durations,
"modulation": np.ones(len(onsets)),
}
)
return events
6 changes: 5 additions & 1 deletion nilearn/glm/tests/test_paradigm.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,11 @@ def test_check_events_warnings():
# An unexpected field is provided
events["foo"] = np.zeros(len(events))
with pytest.warns(
UserWarning, match="Unexpected column 'foo' in events data."
UserWarning,
match=(
"The following unexpected columns "
"in events data will be ignored: foo"
),
):
trial_type2, onset2, duration2, modulation2 = check_events(events)

Expand Down

0 comments on commit a6b848a

Please sign in to comment.