Skip to content

Commit

Permalink
ENH load.fiff.variable_length_epochs(): allow variable tmin
Browse files Browse the repository at this point in the history
  • Loading branch information
christianbrodbeck committed Oct 8, 2021
1 parent 462dc67 commit 3f2d2d4
Showing 1 changed file with 23 additions and 14 deletions.
37 changes: 23 additions & 14 deletions eelbrain/_io/fiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,15 +795,15 @@ def sensor_dim(

def variable_length_epochs(
ds: Dataset,
tmin: float,
tmax: Sequence[float] = None,
tmin: Union[float, Sequence[float]],
tmax: Union[float, Sequence[float]] = None,
baseline: BaselineArg = None,
allow_truncation: bool = False,
data: DataArg = None,
exclude: Union[str, Sequence[str]] = 'bads',
sysname: str = None,
connectivity: Union[str, Sequence] = None,
tstop: Sequence[float] = None,
tstop: Union[float, Sequence[float]] = None,
name: str = None,
**kwargs,
) -> List[NDVar]:
Expand Down Expand Up @@ -872,11 +872,11 @@ def variable_length_epochs(

def variable_length_mne_epochs(
ds: Dataset,
tmin: float,
tmax: Sequence[float] = None,
tmin: Union[float, Sequence[float]],
tmax: Union[float, Sequence[float]] = None,
baseline: BaselineArg = None,
allow_truncation: bool = False,
tstop: Sequence[float] = None,
tstop: Union[float, Sequence[float]] = None,
picks: PicksArg = None,
decim: int = 1,
**kwargs,
Expand Down Expand Up @@ -916,24 +916,33 @@ def variable_length_mne_epochs(
if tmax is None:
if tstop is None:
raise TypeError(f"{tmax=}, {tstop=}: must specify at least one")
else:
sfreq = raw.info['sfreq'] / decim
start_index = int(round(tmin * sfreq))
stop_index = [int(round(t * sfreq)) for t in tstop]
tmax = [tmin + (i - start_index - 1) / sfreq for i in stop_index]
n = len(tstop)
else:
n = len(tmax)
if np.isscalar(tmin):
tmin = np.repeat(tmin, n)
else:
tmin = np.asarray(tmin)
if tmax is None:
sfreq = raw.info['sfreq'] / decim
start_index = np.round(tmin * sfreq).astype(int)
stop_index = np.round(np.asarray(tstop) * sfreq).astype(int)
tmax = tmin + (stop_index - start_index - 1) / sfreq
elif np.isscalar(tmax):
tmax = np.repeat(tmax, n)
if picks is None and raw.info['bads']:
picks = mne.pick_types(raw.info, meg=True, eeg=True, eog=True, ref_meg=False, exclude=[])
events = _mne_events(ds)
out = []
for i, tmax_i in enumerate(tmax):
for i, (tmin_i, tmax_i) in enumerate(zip(tmin, tmax)):
i_max = events[i, 0] + floor(tmax_i * raw.info['sfreq'])
if raw.last_samp < i_max:
if allow_truncation:
tmax_i = (raw.last_samp - events[i, 0]) / raw.info['sfreq']
else:
missing = (i_max - raw.last_samp) / raw.info['sfreq']
raise ValueError(f"tmax={tmax}, tmax[{i}] {tmax_i} is outside of data range by {missing:g} s")
epochs_i = mne.Epochs(raw, events[i:i+1], None, tmin, tmax_i, baseline, picks, preload=True, decim=decim, **kwargs)
raise ValueError(f"{tmax[i]=} is outside of data range by {missing:g} s")
epochs_i = mne.Epochs(raw, events[i:i+1], None, tmin_i, tmax_i, baseline, picks, preload=True, decim=decim, **kwargs)
out.append(epochs_i)
return out

Expand Down

0 comments on commit 3f2d2d4

Please sign in to comment.