Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Improve file handling, add overwrite parameter #78

Merged
merged 17 commits into from Aug 7, 2021
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/python_tests.yml
Expand Up @@ -54,7 +54,7 @@ jobs:
mne sys_info

- name: Check formatting
if: "matrix.platform == 'ubuntu-18.04'"
if: ${{ matrix.platform == 'ubuntu-18.04' && matrix.python-version == '3.9' }}
run: make pep

- name: Test with pytest
Expand All @@ -66,14 +66,14 @@ jobs:
make build-doc

- name: Upload artifacts
if: "matrix.platform == 'ubuntu-18.04'"
if: ${{ matrix.platform == 'ubuntu-18.04' && matrix.python-version == '3.9' }}
uses: actions/upload-artifact@v2
with:
name: docs-artifact
path: docs/_build/html

- name: Upload coverage report
if: "matrix.platform == 'ubuntu-18.04'"
if: ${{ matrix.platform == 'ubuntu-18.04' && matrix.python-version == '3.9' }}
uses: codecov/codecov-action@v2
with:
files: ./coverage.xml
5 changes: 4 additions & 1 deletion docs/changelog.rst
Expand Up @@ -29,7 +29,10 @@ Here we list a changelog of pybv.
Current (unreleased)
====================

- nothing so far

API
~~~
- :func:`pybv.write_brainvision` now has an ``overwrite`` parameter that defaults to ``False``, by `Stefan Appelhoff`_ (:gh:`78`).

0.5.0 (2021-01-03)
==================
Expand Down
61 changes: 42 additions & 19 deletions pybv/io.py
Expand Up @@ -14,7 +14,7 @@
import codecs
import datetime
import os
import warnings
import shutil
from os import path as op
from warnings import warn

Expand All @@ -36,6 +36,7 @@


def write_brainvision(*, data, sfreq, ch_names, fname_base, folder_out,
overwrite=False,
events=None, resolution=0.1, unit='µV',
fmt='binary_float32', meas_date=None):
sappelhoff marked this conversation as resolved.
Show resolved Hide resolved
"""Write raw data to BrainVision format [1]_.
Expand All @@ -48,14 +49,16 @@ def write_brainvision(*, data, sfreq, ch_names, fname_base, folder_out,
specified by ``unit``) are never scaled (e.g. ``'°C'``).
sfreq : int | float
The sampling frequency of the data.
ch_names : list of strings, shape (n_channels,)
ch_names : list of str, shape (n_channels,)
sappelhoff marked this conversation as resolved.
Show resolved Hide resolved
The name of each channel.
fname_base : str
The base name for the output files. Three files will be created
(.vhdr, .vmrk, .eeg) and all will share this base name.
folder_out : str
The folder where output files will be saved. Will be created if it does
not exist yet.
overwrite : bool
Whether or not to overwrite existing files. Defaults to False.
events : np.ndarray, shape (n_events, 2) or (n_events, 3) | None
Events to write in the marker file. This array has either two or three
columns. The first column is always the zero-based index of each event
Expand Down Expand Up @@ -125,6 +128,9 @@ def write_brainvision(*, data, sfreq, ch_names, fname_base, folder_out,

"""
# Input checks
if not isinstance(overwrite, bool) and overwrite not in [0, 1]:
sappelhoff marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError("overwrite must be a boolean (True or False).")

ev_err = ("events must be an ndarray of shape (n_events, 2) or "
"(n_events, 3) containing numeric values, or None")
if not isinstance(events, (np.ndarray, type(None))):
Expand All @@ -140,6 +146,10 @@ def write_brainvision(*, data, sfreq, ch_names, fname_base, folder_out,
raise ValueError(ev_err)

nchan = len(ch_names)
for ch in ch_names:
if not isinstance(ch, (str, int)):
raise ValueError("ch_names must be a list of str.")
sappelhoff marked this conversation as resolved.
Show resolved Hide resolved
ch_names = [str(ch) for ch in ch_names]

if len(data) != nchan:
raise ValueError(f"Number of channels in data ({len(data)}) does not "
Expand All @@ -162,8 +172,6 @@ def write_brainvision(*, data, sfreq, ch_names, fname_base, folder_out,
if np.any(resolution <= 0):
raise ValueError("Resolution should be > 0")

_chk_fmt(fmt)

# check unit is single str
if isinstance(unit, str):
# convert unit to list, assuming all units are the same
Expand All @@ -184,7 +192,7 @@ def write_brainvision(*, data, sfreq, ch_names, fname_base, folder_out,

# only show the warning once if a greek letter was encountered
if show_warning:
warnings.warn(
warn(
f"Encountered small Greek letter mu 'μ' or 'u' in unit: {unit}. "
f"Converting to micro sign 'µ'."
)
Expand All @@ -202,22 +210,37 @@ def write_brainvision(*, data, sfreq, ch_names, fname_base, folder_out,
'as expected. Please supply a str in the format: '
'"YYYYMMDDhhmmssuuuuuu".')

# Create output file names/paths
# Create output file names/paths, checking if they already exist
folder_out_created = not op.exists(folder_out)
os.makedirs(folder_out, exist_ok=True)

vhdr_fname = op.join(folder_out, fname_base + '.vhdr')
vmrk_fname = op.join(folder_out, fname_base + '.vmrk')
eeg_fname = op.join(folder_out, fname_base + '.eeg')

# Write output files
# NOTE: call _write_bveeg_file first, so that if it raises ValueError,
# no files are written.
_write_bveeg_file(eeg_fname, data, orientation='multiplexed', format=fmt,
resolution=resolution, units=units)
_write_vmrk_file(vmrk_fname, eeg_fname, events, meas_date)
_write_vhdr_file(vhdr_fname, vmrk_fname, eeg_fname, data, sfreq,
ch_names, orientation='multiplexed', format=fmt,
resolution=resolution, units=units)
vmrk_fname = op.join(folder_out, fname_base + '.vmrk')
vhdr_fname = op.join(folder_out, fname_base + '.vhdr')
msg = "File already exists: {}.\nConsider setting overwrite=True."
for fname in (eeg_fname, vmrk_fname, vhdr_fname):
if op.exists(fname) and not overwrite:
raise IOError(msg.format(fname))
sappelhoff marked this conversation as resolved.
Show resolved Hide resolved

# Write output files, but delete everything if we come across an error
try:

_write_bveeg_file(eeg_fname, data, orientation='multiplexed',
format=fmt, resolution=resolution, units=units)
_write_vmrk_file(vmrk_fname, eeg_fname, events, meas_date)
_write_vhdr_file(vhdr_fname, vmrk_fname, eeg_fname, data, sfreq,
ch_names, orientation='multiplexed', format=fmt,
resolution=resolution, units=units)
except ValueError:
if folder_out_created:
# if this is a new folder, remove everything
shutil.rmtree(folder_out)
else:
# else, only remove the files we might have created
for fname in (eeg_fname, vmrk_fname, vhdr_fname):
if op.exists(fname): # pragma: no cover
os.remove(fname)

raise


def _chk_fmt(fmt):
Expand Down
52 changes: 49 additions & 3 deletions pybv/tests/test_bv_writer.py
Expand Up @@ -11,6 +11,8 @@
#
# License: BSD-3-Clause

import os
import os.path as op
from datetime import datetime, timezone

import mne
Expand Down Expand Up @@ -74,6 +76,9 @@ def test_bv_writer_inputs(tmpdir):
with pytest.raises(ValueError, match='Channel names must be unique'):
write_brainvision(data=data[0:2, :], sfreq=sfreq, ch_names=['b', 'b'],
fname_base=fname, folder_out=tmpdir)
with pytest.raises(ValueError, match='ch_names must be a list of str.'):
write_brainvision(data=data[0:2, :], sfreq=sfreq, ch_names=['b', 2.3],
fname_base=fname, folder_out=tmpdir)
with pytest.raises(ValueError, match='sfreq must be one of '):
write_brainvision(data=data, sfreq='100', ch_names=ch_names,
fname_base=fname, folder_out=tmpdir)
Expand All @@ -87,6 +92,9 @@ def test_bv_writer_inputs(tmpdir):
write_brainvision(data=data, sfreq=sfreq, ch_names=ch_names,
fname_base=fname, folder_out=tmpdir,
resolution=np.arange(n_chans-1))
with pytest.raises(ValueError, match='overwrite must be a boolean'):
write_brainvision(data=data[1:, :], sfreq=sfreq, ch_names=ch_names,
fname_base=fname, folder_out=tmpdir, overwrite=2)


def test_bv_bad_format(tmpdir):
Expand Down Expand Up @@ -159,15 +167,15 @@ def test_write_read_cycle(tmpdir, meas_date):
'non-voltage unit'):
write_brainvision(data=data, sfreq=sfreq, ch_names=ch_names,
fname_base=fname, folder_out=tmpdir,
unit=unsupported_unit)
unit=unsupported_unit, overwrite=True)

# write and read data to BV format
# ensure that greek small letter mu gets converted to micro sign
with pytest.warns(UserWarning, match="Encountered small Greek letter mu"):
write_brainvision(data=data, sfreq=sfreq, ch_names=ch_names,
fname_base=fname, folder_out=tmpdir, events=events,
resolution=np.power(10., -np.arange(10)),
unit='μV', meas_date=meas_date)
unit='μV', meas_date=meas_date, overwrite=True)
vhdr_fname = tmpdir / fname + '.vhdr'
raw_written = mne.io.read_raw_brainvision(vhdr_fname=vhdr_fname,
preload=True)
Expand Down Expand Up @@ -288,7 +296,7 @@ def test_write_multiple_units(tmpdir, unit):
# write file and read back in
write_brainvision(data=data, sfreq=sfreq, ch_names=ch_names,
fname_base=fname, folder_out=tmpdir,
unit=units)
unit=units, overwrite=True)
raw_written = mne.io.read_raw_brainvision(vhdr_fname=vhdr_fname,
preload=True)

Expand Down Expand Up @@ -326,3 +334,41 @@ def test_write_unsupported_units(tmpdir):
assert len(set(orig_units)) == 2
assert all([orig_units[idx] == unit for idx in range(n_chans - 1)])
assert orig_units[-1] == '°C'


def test_cleanup(tmpdir):
    """Check that a failed write leaves no partial output files behind.

    The write is attempted twice with an invalid ``fmt`` so that it raises
    ``ValueError``: once into a folder that does not exist yet, and once
    into a folder that was created beforehand.
    """
    out_dir = tmpdir / "my_output"
    bad_kwargs = dict(data=data, sfreq=sfreq, ch_names=ch_names,
                      fname_base=fname, folder_out=out_dir,
                      fmt="binary_float999")
    # the three files a write would be expected to produce
    out_files = [out_dir / fname + ext for ext in (".eeg", ".vmrk", ".vhdr")]

    # case 1: the output folder does not exist before the failing call,
    # so neither the folder nor any file may be left over afterwards
    with pytest.raises(ValueError, match="Data format binary_float999"):
        write_brainvision(**bad_kwargs)
    assert not op.exists(out_dir)
    for out_file in out_files:
        assert not op.exists(out_file)

    # case 2: the output folder exists before the failing call,
    # so the folder itself must survive ...
    os.makedirs(out_dir)
    with pytest.raises(ValueError, match="Data format binary_float999"):
        write_brainvision(**bad_kwargs)
    assert op.exists(out_dir)

    # ... while the (incomplete/erroneous) files must still be gone
    for out_file in out_files:
        assert not op.exists(out_file)


def test_overwrite(tmpdir):
    """Check that existing files are not overwritten with overwrite=False.

    The same write is issued twice with ``overwrite=False``: the first call
    is expected to succeed, the second to raise ``IOError``.
    """
    write_kwargs = dict(data=data, sfreq=sfreq, ch_names=ch_names,
                        fname_base=fname, folder_out=tmpdir,
                        overwrite=False)

    # first call: must complete without raising
    write_brainvision(**write_kwargs)

    # second call: identical arguments, so the files from the first call
    # are in the way and the writer must refuse
    with pytest.raises(IOError, match="File already exists"):
        write_brainvision(**write_kwargs)