Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for reference channel specification #75

Merged
merged 15 commits into from Aug 14, 2021
Merged
4 changes: 4 additions & 0 deletions docs/changelog.rst
Expand Up @@ -29,6 +29,10 @@ Here we list a changelog of pybv.
Current (unreleased)
====================

Changelog
~~~~~~~~~

- :func:`pybv.write_brainvision` gained a new parameter, ``ref_ch_names``, to specify the reference channels used during recording, by `Richard Höchenberger`_ and `Stefan Appelhoff`_ (:gh:`75`)

API
~~~
Expand Down
73 changes: 62 additions & 11 deletions pybv/io.py
Expand Up @@ -11,7 +11,6 @@
#
# License: BSD-3-Clause

import codecs
import datetime
import os
import shutil
Expand All @@ -35,7 +34,10 @@
}


def write_brainvision(*, data, sfreq, ch_names, fname_base, folder_out,
def write_brainvision(*, data, sfreq, ch_names,
ref_ch_names=None,
fname_base,
folder_out,
overwrite=False,
events=None,
resolution=0.1,
Expand All @@ -53,7 +55,17 @@ def write_brainvision(*, data, sfreq, ch_names, fname_base, folder_out,
sfreq : int | float
The sampling frequency of the data.
ch_names : list of {str | int}, len (n_channels)
The name of each channel.
The names of the channels.
ref_ch_names : str | list of str, len (n_channels) | None
The name of the channel used as a reference during the recording. If
references differed between channels, you may supply a list of
reference channel names corresponding to each channel in ``ch_names``.
If ``None`` (default), assume that all channels are referenced to a
common channel that is not further specified (BrainVision default).

.. note:: The reference channel name specified here does not need to
appear in ``ch_names``. It is permissible to specify a
reference channel that is not present in ``data``.
fname_base : str
The base name for the output files. Three files will be created
(.vhdr, .vmrk, .eeg) and all will share this base name.
Expand Down Expand Up @@ -161,6 +173,38 @@ def write_brainvision(*, data, sfreq, ch_names, fname_base, folder_out,
if len(set(ch_names)) != nchan:
raise ValueError("Channel names must be unique, found duplicate name.")

# Ensure we have a list of strings as reference channel names
if ref_ch_names is None:
ref_ch_names = [''] * nchan # common but unspecified reference
elif isinstance(ref_ch_names, str):
ref_ch_names = [ref_ch_names] * nchan
else:
if "" in ref_ch_names:
msg = (f"ref_ch_names contains an empty string: {ref_ch_names}\n"
f"Empty strings are reserved values and not permitted "
f"as reference channel names.")
raise ValueError(msg)
ref_ch_names = [str(ref_ch_name) for ref_ch_name in ref_ch_names]

if len(ref_ch_names) != nchan:
raise ValueError(
f'The number of reference channel names ({len(ref_ch_names)})'
f'must match the number of channels in your data ({nchan})'
)

# ensure ref chs that are in data are zero
for ref_ch_name in list(set(ref_ch_names) & set(ch_names)):
if not np.allclose(data[ch_names.index(ref_ch_name), :], 0):
raise ValueError(
f"The provided data for the reference channel "
f"{ref_ch_name} does not appear to be zero across "
f"all time points. This indicates that this channel "
f"either did not serve as a reference during the recording, "
f"or the data has been altered since. Please either pick a "
f"different reference channel, or omit the "
f"ref_ch_name parameter."
)

if not isinstance(sfreq, (int, float)):
raise ValueError("sfreq must be one of (float | int)")
sfreq = float(sfreq)
Expand Down Expand Up @@ -230,9 +274,11 @@ def write_brainvision(*, data, sfreq, ch_names, fname_base, folder_out,
_write_bveeg_file(eeg_fname, data, orientation='multiplexed',
format=fmt, resolution=resolution, units=units)
_write_vmrk_file(vmrk_fname, eeg_fname, events, meas_date)
_write_vhdr_file(vhdr_fname, vmrk_fname, eeg_fname, data, sfreq,
ch_names, orientation='multiplexed', format=fmt,
resolution=resolution, units=units)
_write_vhdr_file(vhdr_fname=vhdr_fname, vmrk_fname=vmrk_fname,
eeg_fname=eeg_fname, data=data, sfreq=sfreq,
ch_names=ch_names, ref_ch_names=ref_ch_names,
orientation='multiplexed',
format=fmt, resolution=resolution, units=units)
except ValueError:
if folder_out_created:
# if this is a new folder, remove everything
Expand Down Expand Up @@ -266,7 +312,7 @@ def _chk_multiplexed(orientation):

def _write_vmrk_file(vmrk_fname, eeg_fname, events, meas_date):
"""Write BrainvVision marker file."""
with codecs.open(vmrk_fname, 'w', encoding='utf-8') as fout:
with open(vmrk_fname, 'w', encoding='utf-8') as fout:
print('Brain Vision Data Exchange Marker File, Version 1.0', file=fout)
print(f';Exported using pybv {__version__}', file=fout)
print('', file=fout)
Expand Down Expand Up @@ -348,14 +394,15 @@ def _scale_data_to_unit(data, units):
return data * scales


def _write_vhdr_file(vhdr_fname, vmrk_fname, eeg_fname, data, sfreq, ch_names,
orientation, format, resolution, units):
def _write_vhdr_file(*, vhdr_fname, vmrk_fname, eeg_fname, data, sfreq,
ch_names, ref_ch_names, orientation, format, resolution,
units):
"""Write BrainvVision header file."""
bvfmt, _ = _chk_fmt(format)

multiplexed = _chk_multiplexed(orientation)

with codecs.open(vhdr_fname, 'w', encoding='utf-8') as fout:
with open(vhdr_fname, 'w', encoding='utf-8') as fout:
print('Brain Vision Data Exchange Header File Version 1.0', file=fout)
print(f'; Written using pybv {__version__}', file=fout)
print('', file=fout)
Expand Down Expand Up @@ -392,10 +439,14 @@ def _write_vhdr_file(vhdr_fname, vmrk_fname, eeg_fname, data, sfreq, ch_names,
resolutions = resolution * np.ones((nchan,))

for i in range(nchan):
# take care of commas in the channel names
_ch_name = ch_names[i].replace(',', r'\1')
_ref_ch_name = ref_ch_names[i].replace(',', r'\1')

resolution = np.format_float_positional(resolutions[i], trim="-")
unit = units[i]
print(f'Ch{i + 1}={_ch_name},,{resolution},{unit}', file=fout)
print(f'Ch{i + 1}={_ch_name},{_ref_ch_name},{resolution},{unit}',
file=fout)

print('', file=fout)
print('[Comment]', file=fout)
Expand Down
79 changes: 68 additions & 11 deletions pybv/tests/test_bv_writer.py
Expand Up @@ -13,6 +13,7 @@

import os
import os.path as op
import re
from datetime import datetime, timezone

import mne
Expand All @@ -38,6 +39,10 @@
# scale random data to reasonable EEG signal magnitude in V
data = rng.randn(n_chans, n_times) * 10 * 1e-6

# add reference channel
ref_ch_name = ch_names[-1]
data[-1, :] = 0.


@pytest.mark.parametrize(
"events_errormsg",
Expand Down Expand Up @@ -95,6 +100,24 @@ def test_bv_writer_inputs(tmpdir):
with pytest.raises(ValueError, match='overwrite must be a boolean'):
write_brainvision(data=data[1:, :], sfreq=sfreq, ch_names=ch_names,
fname_base=fname, folder_out=tmpdir, overwrite=1)
with pytest.raises(ValueError, match='number of reference channel names'):
write_brainvision(data=data, sfreq=sfreq, ch_names=ch_names,
ref_ch_names=['foo', 'bar'], fname_base=fname,
folder_out=tmpdir)
# Passing data that's not all-zero for a reference channel should raise
# an exception
data_ = data.copy()
data_[ch_names.index(ref_ch_name), :] = 5
with pytest.raises(ValueError, match='reference channel.*not.*zero'):
write_brainvision(data=data_, sfreq=sfreq, ch_names=ch_names,
ref_ch_names=ref_ch_name, fname_base=fname,
folder_out=tmpdir)
# Empty str is a reserved value for ref_ch_names
with pytest.raises(ValueError, match='Empty strings are reserved values'):
_ref_ch_names = [""] + ch_names[1:]
write_brainvision(data=data_, sfreq=sfreq, ch_names=ch_names,
ref_ch_names=_ref_ch_names, fname_base=fname,
folder_out=tmpdir)


def test_bv_bad_format(tmpdir):
Expand All @@ -104,13 +127,16 @@ def test_bv_bad_format(tmpdir):
eeg_fname = tmpdir / fname + ".eeg"

with pytest.raises(ValueError, match='Orientation bad not supported'):
_write_vhdr_file(vhdr_fname, vmrk_fname, eeg_fname, data=data,
sfreq=sfreq, ch_names=ch_names, orientation='bad',
_write_vhdr_file(vhdr_fname=vhdr_fname, vmrk_fname=vmrk_fname,
eeg_fname=eeg_fname, data=data,
sfreq=sfreq, ch_names=ch_names, ref_ch_names=None,
orientation='bad',
format="binary_float32", resolution=1e-6,
units=["V"] * n_chans)
with pytest.raises(ValueError, match='Data format bad not supported'):
_write_vhdr_file(vhdr_fname, vmrk_fname, eeg_fname, data=data,
sfreq=sfreq, ch_names=ch_names,
_write_vhdr_file(vhdr_fname=vhdr_fname, vmrk_fname=vmrk_fname,
eeg_fname=eeg_fname, data=data,
sfreq=sfreq, ch_names=ch_names, ref_ch_names=None,
orientation='multiplexed', format="bad",
resolution=1e-6,
units=["V"] * n_chans)
Expand Down Expand Up @@ -156,24 +182,30 @@ def test_comma_in_ch_name(tmpdir, ch_names_tricky):
assert_allclose(data, raw_written._data) # data round-trip


@pytest.mark.parametrize("meas_date",
[('20000101120000000000'),
(datetime(2000, 1, 1, 12, 0, 0, 0))])
def test_write_read_cycle(tmpdir, meas_date):
@pytest.mark.parametrize(
"meas_date, ref_ch_names",
[
('20000101120000000000', ref_ch_name),
(datetime(2000, 1, 1, 12, 0, 0, 0), None)
]
)
def test_write_read_cycle(tmpdir, meas_date, ref_ch_names):
"""Test that a write/read cycle produces identical data."""
# First fail writing due to wrong unit
unsupported_unit = "rV"
with pytest.warns(UserWarning, match='Encountered unsupported '
'non-voltage unit'):
write_brainvision(data=data, sfreq=sfreq, ch_names=ch_names,
fname_base=fname, folder_out=tmpdir,
unit=unsupported_unit, overwrite=True)
ref_ch_names=ref_ch_names, fname_base=fname,
folder_out=tmpdir, unit=unsupported_unit,
overwrite=True)

# write and read data to BV format
# ensure that greek small letter mu gets converted to micro sign
with pytest.warns(UserWarning, match="Encountered small Greek letter mu"):
write_brainvision(data=data, sfreq=sfreq, ch_names=ch_names,
fname_base=fname, folder_out=tmpdir, events=events,
ref_ch_names=ref_ch_names, fname_base=fname,
folder_out=tmpdir, events=events,
resolution=np.power(10., -np.arange(10)),
unit='μV', meas_date=meas_date, overwrite=True)
vhdr_fname = tmpdir / fname + '.vhdr'
Expand Down Expand Up @@ -336,6 +368,31 @@ def test_write_unsupported_units(tmpdir):
assert orig_units[-1] == '°C'


@pytest.mark.parametrize(
    'ref_ch_names', (
        None,
        ref_ch_name,
        [ref_ch_name] * n_chans,
        'foobar'
    )
)
def test_ref_ch(tmpdir, ref_ch_names):
    """Test reference channel writing.

    Each parametrized ``ref_ch_names`` value is passed through to
    ``write_brainvision`` and the reference field of every ``Ch<N>=`` line
    in the written header is checked against it:

    - ``None``: the reference field is left empty.
    - a single str: that name is written for every channel.
    - a list of str: every entry of the parametrized list is the same name,
      so that one name is expected for every channel.
    """
    # these are the default values
    resolution = '0.1'
    unit = 'µV'
    vhdr_fname = tmpdir / fname + '.vhdr'

    # Pass the parametrized value (not the module-level ``ref_ch_name``),
    # otherwise all parametrizations exercise the very same code path.
    write_brainvision(data=data, sfreq=sfreq, ch_names=ch_names,
                      ref_ch_names=ref_ch_names, fname_base=fname,
                      folder_out=tmpdir)

    vhdr = vhdr_fname.read_text(encoding='utf-8')

    # Work out which reference name should appear in each channel line
    if ref_ch_names is None:
        expected_ref = ''  # common but unspecified reference
    elif isinstance(ref_ch_names, str):
        expected_ref = ref_ch_names
    else:
        # the parametrized list repeats one single name for all channels
        assert len(set(ref_ch_names)) == 1
        expected_ref = ref_ch_names[0]

    regexp = f'Ch.*=ch.*,{re.escape(expected_ref)},{resolution},{unit}'
    matches = re.findall(pattern=regexp, string=vhdr)
    assert len(matches) == len(ch_names)


def test_cleanup(tmpdir):
"""Test cleaning up intermediate data upon a writing failure."""
folder_out = tmpdir / "my_output"
Expand Down