Skip to content

Commit

Permalink
eric review
Browse files Browse the repository at this point in the history
  • Loading branch information
alexrockhill committed Mar 31, 2022
1 parent 9424143 commit 3f6c46d
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 105 deletions.
40 changes: 10 additions & 30 deletions mne_bids/dig.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,29 +685,19 @@ def convert_montage_to_ras(montage, subject, subjects_dir=None, verbose=None):
'formatted, T1.mgz not found')
T1 = nib.load(T1_fname)

# scale from mm to m
scale_t = np.eye(4)
scale_t[:3, :3] *= 1000
scale_trans = mne.transforms.Transform(fro='mri', to='mri', trans=scale_t)
montage.apply_trans(scale_trans)

# transform from "mri" (Freesurfer surface RAS) to "ras" (scanner RAS)
mri_vox_t = np.linalg.inv(T1.header.get_vox2ras_tkr())
mri_vox_t[:3, :3] *= 1000 # scale from mm to m
mri_vox_trans = mne.transforms.Transform(
fro='mri', to='mri_voxel', trans=mri_vox_t)
montage.apply_trans(mri_vox_trans) # mri->vox

vox_ras_t = T1.header.get_vox2ras()
vox_ras_t[:3] /= 1000 # scale from mm to m
vox_ras_trans = mne.transforms.Transform(
fro='mri_voxel', to='ras', trans=vox_ras_t)
montage.apply_trans(vox_ras_trans) # vox->ras

# finally, need to put back in m
scale_inv_t = np.eye(4)
scale_inv_t[:3, :3] /= 1000
scale_inv_trans = mne.transforms.Transform(
fro='ras', to='ras', trans=scale_inv_t)
montage.apply_trans(scale_inv_trans)
montage.apply_trans( # mri->vox + vox->ras = mri->ras
mne.transforms.combine_transforms(mri_vox_trans, vox_ras_trans,
fro='mri', to='ras'))


@verbose
Expand Down Expand Up @@ -740,26 +730,16 @@ def convert_montage_to_mri(montage, subject, subjects_dir=None, verbose=None):
'formatted, T1.mgz not found')
T1 = nib.load(T1_fname)

# scale from mm to m
scale_t = np.eye(4)
scale_t[:3, :3] *= 1000
scale_trans = mne.transforms.Transform(fro='ras', to='ras', trans=scale_t)
montage.apply_trans(scale_trans)

# transform from "ras" (scanner RAS) to "mri" (Freesurfer surface RAS)
ras_vox_t = T1.header.get_ras2vox()
ras_vox_t[:3, :3] *= 1000 # scale from mm to m
ras_vox_trans = mne.transforms.Transform(
fro='ras', to='mri_voxel', trans=ras_vox_t)
montage.apply_trans(ras_vox_trans) # ras->vox

vox_mri_t = T1.header.get_vox2ras_tkr()
vox_mri_t[:3] /= 1000 # scale from mm to m
vox_mri_trans = mne.transforms.Transform(
fro='mri_voxel', to='mri', trans=vox_mri_t)
montage.apply_trans(vox_mri_trans) # vox->mri

# finally, need to put back in m
scale_inv_t = np.eye(4)
scale_inv_t[:3, :3] /= 1000
scale_inv_trans = mne.transforms.Transform(
fro='mri', to='mri', trans=scale_inv_t)
montage.apply_trans(scale_inv_trans)
montage.apply_trans( # ras->vox + vox->mri = ras->mri
mne.transforms.combine_transforms(ras_vox_trans, vox_mri_trans,
fro='ras', to='mri'))
158 changes: 83 additions & 75 deletions mne_bids/tests/test_dig.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,48 +35,52 @@
task=task)

data_path = testing.data_path()
raw_fname = op.join(data_path, 'MEG', 'sample',
'sample_audvis_trunc_raw.fif')
trans = mne.read_trans(op.join(data_path, 'MEG', 'sample',
'sample_audvis_trunc-trans.fif'))
raw = mne.io.read_raw(raw_fname)
raw.drop_channels(raw.info['bads'])
raw.info['line_freq'] = 60
montage = raw.get_montage()


def _load_raw():
"""Load the sample raw data."""
raw_fname = op.join(data_path, 'MEG', 'sample',
'sample_audvis_trunc_raw.fif')
raw = mne.io.read_raw(raw_fname)
raw.drop_channels(raw.info['bads'])
raw.info['line_freq'] = 60
return raw


def test_dig_io(tmp_path):
"""Test passing different coordinate frames give proper warnings."""
bids_root = tmp_path / 'bids1'
raw_test = raw.copy()
raw = _load_raw()
for datatype in ('eeg', 'ieeg'):
os.makedirs(op.join(bids_root, 'sub-01', 'ses-01', datatype))

# test no coordinate frame in dig or in bids_path.space
mnt = montage.copy()
mnt.apply_trans(mne.transforms.Transform('head', 'unknown'))
montage = raw.get_montage()
montage.apply_trans(mne.transforms.Transform('head', 'unknown'))
for datatype in ('eeg', 'ieeg'):
bids_path = _bids_path.copy().update(root=bids_root, datatype=datatype,
space=None)
with pytest.warns(RuntimeWarning,
match='Coordinate frame could not be inferred'):
_write_dig_bids(bids_path, raw_test, mnt, acpc_aligned=True)
_write_dig_bids(bids_path, raw, montage, acpc_aligned=True)

# test coordinate frame-BIDSPath.space mismatch
mnt = montage.copy()
raw = _load_raw()
montage = raw.get_montage()
print(montage.get_positions()['coord_frame'])
bids_path = _bids_path.copy().update(
root=bids_root, datatype='eeg', space='fsaverage')
with pytest.raises(ValueError, match='Coordinates in the raw object '
'or montage are in the CapTrak '
'coordinate frame but '
'BIDSPath.space is fsaverage'):
_write_dig_bids(bids_path, raw_test, mnt)
_write_dig_bids(bids_path, raw, montage)

# test MEG space conflict fif (ElektaNeuromag) != CTF
bids_path = _bids_path.copy().update(
root=bids_root, datatype='meg', space='CTF')
with pytest.raises(ValueError, match='conflicts'):
write_raw_bids(raw_test, bids_path)
write_raw_bids(raw, bids_path)


def test_dig_pixels(tmp_path):
Expand All @@ -88,28 +92,28 @@ def test_dig_pixels(tmp_path):
root=bids_root, datatype='ieeg', space='Pixels')
os.makedirs(op.join(bids_root, 'sub-01', 'ses-01', bids_path.datatype),
exist_ok=True)
raw_test = raw.copy()
raw_test.pick_types(eeg=True)
raw_test.del_proj()
raw_test.set_channel_types({ch: 'ecog' for ch in raw_test.ch_names})
raw = _load_raw()
raw.pick_types(eeg=True)
raw.del_proj()
raw.set_channel_types({ch: 'ecog' for ch in raw.ch_names})

mnt = raw_test.get_montage()
montage = raw.get_montage()
# fake transform to pixel coordinates
mnt.apply_trans(mne.transforms.Transform('head', 'unknown'))
_write_dig_bids(bids_path, raw_test, mnt)
montage.apply_trans(mne.transforms.Transform('head', 'unknown'))
_write_dig_bids(bids_path, raw, montage)
electrodes_path = bids_path.copy().update(
task=None, run=None, suffix='electrodes', extension='.tsv')
coordsystem_path = bids_path.copy().update(
task=None, run=None, suffix='coordsystem', extension='.json')
with pytest.warns(RuntimeWarning,
match='not an MNE-Python coordinate frame'):
_read_dig_bids(electrodes_path, coordsystem_path,
bids_path.datatype, raw_test)
mnt2 = raw_test.get_montage()
assert mnt2.get_positions()['coord_frame'] == 'unknown'
bids_path.datatype, raw)
montage2 = raw.get_montage()
assert montage2.get_positions()['coord_frame'] == 'unknown'
assert_almost_equal(
np.array(list(mnt.get_positions()['ch_pos'].values())),
np.array(list(mnt2.get_positions()['ch_pos'].values()))
np.array(list(montage.get_positions()['ch_pos'].values())),
np.array(list(montage2.get_positions()['ch_pos'].values()))
)


Expand All @@ -120,21 +124,22 @@ def test_dig_template(tmp_path):
for datatype in ('eeg', 'ieeg'):
(bids_root / 'sub-01' / 'ses-01' / datatype).mkdir(parents=True)

raw_test = raw.copy().pick_types(eeg=True)

for datatype in ('eeg', 'ieeg'):
bids_path = _bids_path.copy().update(root=bids_root, datatype=datatype)
for coord_frame in BIDS_STANDARD_TEMPLATE_COORDINATE_SYSTEMS:
raw = _load_raw()
raw.pick_types(eeg=True)
bids_path.update(space=coord_frame)
mnt = montage.copy()
pos = mnt.get_positions()
montage = raw.get_montage()
pos = montage.get_positions()
mne_coord_frame = BIDS_TO_MNE_FRAMES.get(coord_frame, None)
if mne_coord_frame is None:
mnt.apply_trans(mne.transforms.Transform('head', 'unknown'))
montage.apply_trans(
mne.transforms.Transform('head', 'unknown'))
else:
mnt.apply_trans(mne.transforms.Transform(
montage.apply_trans(mne.transforms.Transform(
'head', mne_coord_frame))
_write_dig_bids(bids_path, raw_test, mnt, acpc_aligned=True)
_write_dig_bids(bids_path, raw, montage, acpc_aligned=True)
electrodes_path = bids_path.copy().update(
task=None, run=None, suffix='electrodes', extension='.tsv')
coordsystem_path = bids_path.copy().update(
Expand All @@ -143,15 +148,15 @@ def test_dig_template(tmp_path):
with pytest.warns(RuntimeWarning,
match='not an MNE-Python coordinate frame'):
_read_dig_bids(electrodes_path, coordsystem_path,
datatype, raw_test)
datatype, raw)
else:
if coord_frame == 'MNI305': # saved to fsaverage, same
electrodes_path.update(space='fsaverage')
coordsystem_path.update(space='fsaverage')
_read_dig_bids(electrodes_path, coordsystem_path,
datatype, raw_test)
mnt2 = raw_test.get_montage()
pos2 = mnt2.get_positions()
datatype, raw)
montage2 = raw.get_montage()
pos2 = montage2.get_positions()
np.testing.assert_array_almost_equal(
np.array(list(pos['ch_pos'].values())),
np.array(list(pos2['ch_pos'].values())))
Expand All @@ -161,13 +166,13 @@ def test_dig_template(tmp_path):
assert pos2['coord_frame'] == mne_coord_frame

# test MEG
raw_test = raw.copy()
raw = _load_raw()
for coord_frame in BIDS_STANDARD_TEMPLATE_COORDINATE_SYSTEMS:
bids_path = _bids_path.copy().update(root=bids_root, datatype='meg',
space=coord_frame)
write_raw_bids(raw_test, bids_path)
raw_test2 = read_raw_bids(bids_path)
for ch, ch2 in zip(raw.info['chs'], raw_test2.info['chs']):
write_raw_bids(raw, bids_path)
raw2 = read_raw_bids(bids_path)
for ch, ch2 in zip(raw.info['chs'], raw2.info['chs']):
np.testing.assert_array_equal(ch['loc'], ch2['loc'])
assert ch['coord_frame'] == ch2['coord_frame']

Expand Down Expand Up @@ -201,49 +206,50 @@ def _test_montage_trans(raw, montage, pos_test, space='fsaverage',
def test_template_to_head():
"""Test transforming a template montage to head."""
# test no montage
raw_test = raw.copy()
raw_test.set_montage(None)
raw = _load_raw()
raw.set_montage(None)
with pytest.raises(RuntimeError, match='No montage found'):
template_to_head(raw_test.info, 'fsaverage', coord_frame='auto')
template_to_head(raw.info, 'fsaverage', coord_frame='auto')

# test no channels
raw = _load_raw()
montage_empty = mne.channels.make_dig_montage(hsp=[[0, 0, 0]])
_set_montage_no_trans(raw_test, montage_empty)
_set_montage_no_trans(raw, montage_empty)
with pytest.raises(RuntimeError, match='No channel locations '
'found in the montage'):
template_to_head(raw_test.info, 'fsaverage', coord_frame='auto')
template_to_head(raw.info, 'fsaverage', coord_frame='auto')

# test unexpected coordinate frame
raw_test = raw.copy()
raw = _load_raw()
with pytest.raises(RuntimeError, match='not expected for a template'):
template_to_head(raw_test.info, 'fsaverage', coord_frame='auto')
template_to_head(raw.info, 'fsaverage', coord_frame='auto')

# test all coordinate frames
raw_test = raw.copy()
raw_test.set_montage(None)
raw_test.pick_types(eeg=True)
raw_test.drop_channels(raw_test.ch_names[3:])
raw = _load_raw()
raw.set_montage(None)
raw.pick_types(eeg=True)
raw.drop_channels(raw.ch_names[3:])
montage = mne.channels.make_dig_montage(
ch_pos={raw_test.ch_names[0]: [0, 0, 0],
raw_test.ch_names[1]: [0, 0, 0.1],
raw_test.ch_names[2]: [0, 0, 0.2]},
ch_pos={raw.ch_names[0]: [0, 0, 0],
raw.ch_names[1]: [0, 0, 0.1],
raw.ch_names[2]: [0, 0, 0.2]},
coord_frame='unknown')
for space in BIDS_STANDARD_TEMPLATE_COORDINATE_SYSTEMS:
for cf in ('mri', 'mri_voxel', 'ras'):
_set_montage_no_trans(raw_test, montage)
trans = template_to_head(raw_test.info, space, cf)[1]
_set_montage_no_trans(raw, montage)
trans = template_to_head(raw.info, space, cf)[1]
assert trans['from'] == MNE_STR_TO_FRAME['head']
assert trans['to'] == MNE_STR_TO_FRAME['mri']
montage_test = raw_test.get_montage()
montage_test = raw.get_montage()
pos = montage_test.get_positions()
assert pos['coord_frame'] == 'head'
assert pos['nasion'] is not None
assert pos['lpa'] is not None
assert pos['rpa'] is not None

# test that we get the right transform
_set_montage_no_trans(raw_test, montage)
trans = template_to_head(raw_test.info, 'fsaverage', 'mri')[1]
_set_montage_no_trans(raw, montage)
trans = template_to_head(raw.info, 'fsaverage', 'mri')[1]
trans2 = mne.read_trans(op.join(
op.dirname(op.dirname(mne_bids.__file__)), 'mne_bids', 'data',
'space-fsaverage_trans.fif'))
Expand All @@ -253,46 +259,48 @@ def test_template_to_head():

# test auto voxels
montage_vox = mne.channels.make_dig_montage(
ch_pos={raw_test.ch_names[0]: [2, 0, 10],
raw_test.ch_names[1]: [0, 0, 5.5],
raw_test.ch_names[2]: [0, 1, 3]},
ch_pos={raw.ch_names[0]: [2, 0, 10],
raw.ch_names[1]: [0, 0, 5.5],
raw.ch_names[2]: [0, 1, 3]},
coord_frame='unknown')
pos_test = np.array([[0.126, -0.118, 0.128],
[0.128, -0.1225, 0.128],
[0.128, -0.125, 0.127]])
_test_montage_trans(raw_test, montage_vox, pos_test,
_test_montage_trans(raw, montage_vox, pos_test,
coord_frame='auto', unit='mm')

# now negative values => scanner RAS
montage_ras = mne.channels.make_dig_montage(
ch_pos={raw_test.ch_names[0]: [-30.2, 20, -40],
raw_test.ch_names[1]: [10, 30, 53.5],
raw_test.ch_names[2]: [30, -21, 33]},
ch_pos={raw.ch_names[0]: [-30.2, 20, -40],
raw.ch_names[1]: [10, 30, 53.5],
raw.ch_names[2]: [30, -21, 33]},
coord_frame='unknown')
pos_test = np.array([[-0.0302, 0.02, -0.04],
[0.01, 0.03, 0.0535],
[0.03, -0.021, 0.033]])
_set_montage_no_trans(raw_test, montage_ras)
_test_montage_trans(raw_test, montage_ras, pos_test,
_set_montage_no_trans(raw, montage_ras)
_test_montage_trans(raw, montage_ras, pos_test,
coord_frame='auto', unit='mm')

# test auto unit
montage_mm = montage_ras.copy()
_set_montage_no_trans(raw_test, montage_mm)
_test_montage_trans(raw_test, montage_mm, pos_test,
_set_montage_no_trans(raw, montage_mm)
_test_montage_trans(raw, montage_mm, pos_test,
coord_frame='ras', unit='auto')

montage_m = montage_ras.copy()
for d in montage_m.dig:
d['r'] = np.array(d['r']) / 1000
_test_montage_trans(raw_test, montage_m, pos_test,
_test_montage_trans(raw, montage_m, pos_test,
coord_frame='ras', unit='auto')


def test_convert_montage():
"""Test the montage RAS conversion."""
raw_test = raw.copy()
montage = raw_test.get_montage()
raw = _load_raw()
montage = raw.get_montage()
trans = mne.read_trans(op.join(data_path, 'MEG', 'sample',
'sample_audvis_trunc-trans.fif'))
montage.apply_trans(trans)

subjects_dir = op.join(data_path, 'subjects')
Expand Down

0 comments on commit 3f6c46d

Please sign in to comment.