Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Remi-Gau committed Aug 30, 2023
1 parent a7a331c commit ce51036
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 37 deletions.
11 changes: 11 additions & 0 deletions nilearn/glm/tests/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,17 @@ def _modulated_block_paradigm():
return events


def _spm_paradigm(block_duration):
frame_times = np.linspace(0, 99, 100)
conditions = ["c0", "c0", "c0", "c1", "c1", "c1", "c2", "c2", "c2"]
onsets = [30, 50, 70, 10, 30, 80, 30, 40, 60]
durations = block_duration * np.ones(len(onsets))
events = pd.DataFrame(
{"trial_type": conditions, "onset": onsets, "duration": durations}
)
return events, frame_times


def _design_with_null_duration():
durations = _durations()
durations[2] = 0
Expand Down
48 changes: 11 additions & 37 deletions nilearn/glm/tests/test_dmtx.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
_block_paradigm,
_modulated_block_paradigm,
_modulated_event_paradigm,
_spm_paradigm,
)

# load the spm file to test cosine basis
Expand Down Expand Up @@ -145,18 +146,10 @@ def test_design_matrix_regressors_provided_manually_errors(
)


def test_convolve_regressors():
def test_convolve_regressors(frame_times):
# tests for convolve_regressors helper function
conditions = ["c0", "c1"]
onsets = [20, 40]
duration = [1, 1]
events = pd.DataFrame(
{"trial_type": conditions, "onset": onsets, "duration": duration}
)
# names not passed -> default names
frame_times = np.arange(100)
f, names = _convolve_regressors(events, "glover", frame_times)
assert names == ["c0", "c1"]
_, names = _convolve_regressors(basic_paradigm(), "glover", frame_times)
assert names == ["c0", "c1", "c2"]


def test_design_matrix_basic_paradigm_glover_hrf(frame_times):
Expand Down Expand Up @@ -287,7 +280,7 @@ def test_design_matrix_FIR_column_1_3_and_11(frame_times):
assert_array_almost_equal(X[onset + 4, 11], np.ones(3))


def test_design_matrix_FIR_time_shift(frame_times, n_frames):
def test_design_matrix_FIR_time_shift(frame_times):
# Check that the first column of FIR design matrix is OK after a 1/2
# time shift
tr = 1.0
Expand Down Expand Up @@ -453,37 +446,18 @@ def test_csv_io(tmp_path, frame_times):
assert names == names_


def test_spm_1():
@pytest.mark.parametrize(
"block_duration, array", [(1, "arr_0"), (10, "arr_1")]
)
def test_compare_design_matrix_to_spm(block_duration, array):
# Check that the nistats design matrix is close enough to the SPM one
# (it cannot be identical, because the hrf shape is different)
frame_times = np.linspace(0, 99, 100)
conditions = ["c0", "c0", "c0", "c1", "c1", "c1", "c2", "c2", "c2"]
onsets = [30, 50, 70, 10, 30, 80, 30, 40, 60]
durations = 1 * np.ones(9)
events = pd.DataFrame(
{"trial_type": conditions, "onset": onsets, "duration": durations}
)
events, frame_times = _spm_paradigm(block_duration=block_duration)
X1 = make_first_level_design_matrix(frame_times, events, drift_model=None)
_, matrix, _ = check_design_matrix(X1)
spm_design_matrix = DESIGN_MATRIX["arr_0"]
assert ((spm_design_matrix - matrix) ** 2).sum() / (
spm_design_matrix**2
).sum() < 0.1

spm_design_matrix = DESIGN_MATRIX[array]

def test_spm_2():
# Check that the nistats design matrix is close enough to the SPM one
# (it cannot be identical, because the hrf shape is different)
frame_times = np.linspace(0, 99, 100)
conditions = ["c0", "c0", "c0", "c1", "c1", "c1", "c2", "c2", "c2"]
onsets = [30, 50, 70, 10, 30, 80, 30, 40, 60]
durations = 10 * np.ones(9)
events = pd.DataFrame(
{"trial_type": conditions, "onset": onsets, "duration": durations}
)
X1 = make_first_level_design_matrix(frame_times, events, drift_model=None)
spm_design_matrix = DESIGN_MATRIX["arr_1"]
_, matrix, _ = check_design_matrix(X1)
assert ((spm_design_matrix - matrix) ** 2).sum() / (
spm_design_matrix**2
).sum() < 0.1
Expand Down

0 comments on commit ce51036

Please sign in to comment.