diff --git a/.github/workflows/qt_viz_tests.yml b/.github/workflows/qt_viz_tests.yml index acd37218..02076273 100644 --- a/.github/workflows/qt_viz_tests.yml +++ b/.github/workflows/qt_viz_tests.yml @@ -101,7 +101,9 @@ jobs: run: mne sys_info - run: pytest -m pgtest --cov=mne_qt_browser --cov-report=xml ../mne-python/mne/viz name: Run MNE-Tests - - run: pytest mne_qt_browser/tests + - run: pytest mne_qt_browser/tests/test_pg_specific.py + name: Run pyqtgraph-specific tests + - run: pytest mne_qt_browser/tests/test_speed.py name: Run benchmarks - uses: codecov/codecov-action@v1 if: always() diff --git a/mne_qt_browser/_pg_figure.py b/mne_qt_browser/_pg_figure.py index ab855b59..8ff2614c 100644 --- a/mne_qt_browser/_pg_figure.py +++ b/mne_qt_browser/_pg_figure.py @@ -1,11 +1,12 @@ # -*- coding: utf-8 -*- """Base classes and functions for 2D browser backends.""" -# Authors: Martin Schulz +# Author: Martin Schulz # -# License: Simplified BSD +# License: BSD-3-Clause import datetime +import functools import gc import math import platform @@ -13,6 +14,7 @@ from ast import literal_eval from collections import OrderedDict from contextlib import contextmanager +from copy import copy from functools import partial from os.path import getsize @@ -33,6 +35,7 @@ QGraphicsLineItem, QGraphicsScene, QTextEdit, QSizePolicy, QSpinBox, QDesktopWidget, QSlider) from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg +from matplotlib.colors import to_rgba_array from pyqtgraph import (AxisItem, GraphicsView, InfLineLabel, InfiniteLine, LinearRegionItem, PlotCurveItem, PlotItem, Point, TextItem, ViewBox, mkBrush, @@ -62,7 +65,6 @@ def capture_exceptions(): name = 'pyqtgraph' - # This can be removed when mne==1.0 is released. try: from mne.viz.backends._utils import _init_mne_qtapp @@ -124,7 +126,7 @@ def _init_mne_qtapp(enable_icon=True, pg_app=False): def _get_std_icon(icon_name): return QApplication.instance().style().standardIcon( - getattr(QStyle, icon_name)) + getattr(QStyle, icon_name)) def _get_color(color_spec): @@ -149,34 +151,143 @@ def _get_color(color_spec): return color -class RawTraceItem(PlotCurveItem): +def propagate_to_children(method): + @functools.wraps(method) + def wrapper(*args, **kwargs): + propagate = kwargs.pop('propagate', True) + result = method(*args, **kwargs) + if args[0].mne.is_epochs and propagate: + # parent always goes first + if hasattr(args[0], 'child_traces'): + for child_trace in args[0].child_traces: + getattr(child_trace, method.__name__)(*args[1:], **kwargs) + return result + + return wrapper + + +class DataTrace(PlotCurveItem): """Graphics-Object for single data trace.""" - def __init__(self, mne, ch_idx, child=False): - super().__init__(clickable=True) - self.mne = mne + def __init__(self, main, ch_idx, child_idx=None, parent_trace=None): + super().__init__() + self.main = main + self.mne = main.mne + + # Set clickable with small area around trace to make clicking easier. + self.setClickable(True, 12) # Set default z-value to 1 to be before other items in scene self.setZValue(1) - if self.mne.is_epochs and not child: - self.bad_trace = RawTraceItem(self.mne, ch_idx, child=True) - + # General attributes + # The ch_idx is the index of the channel represented by this trace + # in the channel-order from the unchanged instance (which also picks + # refer to). + self.ch_idx = None + # The range_idx is the index of the channel represented by this trace + # in the shown range. + self.range_idx = None + # The order_idx is the index of the channel represented by this trace + # in the channel-order (defined e.g. by group_by). + self.order_idx = None + # Name of the channel the trace represents. + self.ch_name = None + # Indicates if trace is bad. + self.isbad = None + # Channel-type of trace. + self.ch_type = None + # Color-specifier (all possible matplotlib color formats) + self.color = None + + # Attributes for epochs-mode + # Index of child if child. + self.child_idx = child_idx + # Reference to parent if child. + self.parent_trace = parent_trace + + # Only for parent traces + if self.parent_trace is None: + # Add to main trace list + self.mne.traces.append(self) + # References to children + self.child_traces = list() + # Colors of trace in viewrange + self.trace_colors = None + + # set attributes self.set_ch_idx(ch_idx) self.update_color() - self.update_data() + self.update_scale() + # Avoid calling self.update_data() twice on initialization + # (because of update_scale()). + if self.mne.clipping is None: + self.update_data() + # Add to main plot + self.mne.plt.addItem(self) + + @propagate_to_children + def remove(self): + self.mne.plt.removeItem(self) + # Only for parent trace + if self.parent_trace is None: + self.mne.traces.remove(self) + self.deleteLater() + + @propagate_to_children def update_color(self): - """Update the color of the trace (depending on ch_type and bad).""" - if self.isbad and not self.mne.butterfly: - self.setPen(_get_color(self.mne.ch_color_bad)) + """Update the color of the trace.""" + + # Epochs + if self.mne.is_epochs: + # Add child traces if shown trace needs to have multiple colors + # (PlotCurveItem only supports one color per object). + # There always as many color-specific traces added depending + # on the whole time range of the instance regardless of the + # currently visible time range (to avoid checking for new colors + # while scrolling horizontally). + + # Only for parent trace + if hasattr(self, 'child_traces'): + self.trace_colors = np.unique( + self.mne.epoch_color_ref[self.ch_idx], axis=0) + n_childs = len(self.child_traces) + trace_diff = len(self.trace_colors) - n_childs - 1 + # Add child traces if necessary + if trace_diff > 0: + for cix in range(n_childs, n_childs + trace_diff): + child = DataTrace(self.main, self.ch_idx, + child_idx=cix, parent_trace=self) + self.child_traces.append(child) + elif trace_diff < 0: + for _ in range(abs(trace_diff)): + rm_trace = self.child_traces.pop() + rm_trace.remove() + + # Set parent color + self.color = self.trace_colors[0] + + # Only for child trace + else: + self.color = self.parent_trace.trace_colors[ + self.child_idx + 1] + + # Raw/ICA else: - self.setPen(_get_color(self.color)) + if self.isbad: + self.color = self.mne.ch_color_bad + else: + self.color = self.mne.ch_color_ref[self.ch_name] + self.setPen(_get_color(self.color)) + + @propagate_to_children def update_range_idx(self): """Should be updated when view-range or ch_idx changes.""" self.range_idx = np.argwhere(self.mne.picks == self.ch_idx)[0][0] + @propagate_to_children def update_ypos(self): """Should be updated when butterfly is toggled or ch_idx changes.""" if self.mne.butterfly and self.mne.fig_selection is not None: @@ -189,6 +300,16 @@ def update_ypos(self): else: self.ypos = self.range_idx + self.mne.ch_start + 1 + @propagate_to_children + def update_scale(self): + transform = QTransform() + transform.scale(1., self.mne.scale_factor) + self.setTransform(transform) + + if self.mne.clipping is not None: + self.update_data(propagate=False) + + @propagate_to_children def set_ch_idx(self, ch_idx): """Sets the channel index and all deriving indices.""" # The ch_idx is the index of the channel represented by this trace @@ -197,16 +318,16 @@ def set_ch_idx(self, ch_idx): self.ch_idx = ch_idx # The range_idx is the index of the channel represented by this trace # in the shown range. - self.update_range_idx() + self.update_range_idx(propagate=False) # The order_idx is the index of the channel represented by this trace # in the channel-order (defined e.g. by group_by). self.order_idx = np.argwhere(self.mne.ch_order == self.ch_idx)[0][0] self.ch_name = self.mne.inst.ch_names[ch_idx] self.isbad = self.ch_name in self.mne.info['bads'] self.ch_type = self.mne.ch_types[ch_idx] - self.color = self.mne.ch_color_assoc[self.ch_name] - self.update_ypos() + self.update_ypos(propagate=False) + @propagate_to_children def update_data(self): """Update data (fetch data from self.mne according to self.ch_idx).""" if self.mne.is_epochs or (self.mne.clipping is not None and @@ -229,20 +350,106 @@ def update_data(self): else: times = self.mne.times + # For multiple color traces with epochs + # replace values from other colors with NaN. + if self.mne.is_epochs: + data = np.copy(data) + check_color = self.mne.epoch_color_ref[self.ch_idx, + self.mne.epoch_idx] + bool_ixs = np.invert(np.equal(self.color, check_color).all(axis=1)) + starts = self.mne.boundary_times[self.mne.epoch_idx][bool_ixs] + stops = self.mne.boundary_times[self.mne.epoch_idx + 1][bool_ixs] + + for start, stop in zip(starts, stops): + data[np.logical_and(start <= times, times <= stop)] = np.nan + self.setData(times, data, connect=connect, skipFiniteCheck=skip, antialias=self.mne.antialiasing) self.setPos(0, self.ypos) + def toggle_bad(self, x=None): + """Toggle bad status.""" + # Toggle bad epoch + if self.mne.is_epochs and x is not None: + epoch_idx, color = self.main._toggle_bad_epoch(x) + + # Update epoch color + if color != 'none': + new_epo_color = np.repeat(to_rgba_array(color), + len(self.mne.inst.ch_names), axis=0) + elif self.mne.epoch_colors is None: + new_epo_color = np.concatenate( + [to_rgba_array(c) for c + in self.mne.ch_color_ref.values()]) + else: + new_epo_color = \ + np.concatenate([to_rgba_array(c) for c in + self.mne.epoch_colors[epoch_idx]]) + + # Update bad channel colors + bad_idxs = np.in1d(self.mne.ch_names, self.mne.info['bads']) + new_epo_color[bad_idxs] = to_rgba_array(self.mne.ch_color_bad) + + self.mne.epoch_color_ref[:, epoch_idx] = new_epo_color + + # Update overview-bar + self.mne.overview_bar.update_bad_epochs() + + # Update other traces inlcuding self + for trace in self.mne.traces: + trace.update_color() + # Update data is necessary because colored segments will vary + trace.update_data() + + # Toggle bad channel + else: + bad_color, pick, marked_bad = self.main._toggle_bad_channel( + self.range_idx) + + # Update line color status + self.isbad = not self.isbad + + # Update colors for epochs + if self.mne.is_epochs: + if marked_bad: + new_ch_color = np.repeat(to_rgba_array(bad_color), + len(self.mne.inst), axis=0) + elif self.mne.epoch_colors is None: + ch_color = self.mne.ch_color_ref[self.ch_name] + new_ch_color = np.repeat(to_rgba_array(ch_color), + len(self.mne.inst), axis=0) + else: + new_ch_color = np.concatenate([to_rgba_array(c[pick]) for + c in self.mne.epoch_colors]) + + self.mne.epoch_color_ref[pick, :] = new_ch_color + + # Update trace color + self.update_color() + if self.mne.is_epochs: + self.update_data() + + # Update channel-axis + self.main._update_yaxis_labels() + + # Update overview-bar + self.mne.overview_bar.update_bad_channels() + + # Update sensor color (if in selection mode) + if self.mne.fig_selection is not None: + self.mne.fig_selection._update_bad_sensors(pick, marked_bad) + def mouseClickEvent(self, ev): """Customize mouse click events.""" if (not self.clickable or ev.button() != Qt.MouseButton.LeftButton or self.mne.annotation_mode): + # Explicitly ignore events in annotation-mode ev.ignore() return if self.mouseShape().contains(ev.pos()): ev.accept() - self.sigClicked.emit(self, ev) + self.toggle_bad(ev.pos().x()) def get_xdata(self): """Get xdata for testing.""" @@ -264,9 +471,10 @@ def __init__(self, mne): def tickValues(self, minVal, maxVal, size): """Customize creation of axis values from visible axis range.""" if self.mne.is_epochs: - values = self.mne.midpoints[np.argwhere( - minVal <= self.mne.midpoints <= maxVal)] - tick_values = [(len(self.mne.inst.times), values)] + value_idxs = np.searchsorted(self.mne.midpoints, [minVal, maxVal]) + values = self.mne.midpoints[slice(*value_idxs)] + spacing = len(self.mne.inst.times) / self.mne.info['sfreq'] + tick_values = [(spacing, values)] return tick_values else: # Save _spacing for later use @@ -277,7 +485,7 @@ def tickStrings(self, values, scale, spacing): """Customize strings of axis values.""" if self.mne.is_epochs: epoch_nums = self.mne.inst.selection - ts = epoch_nums[np.in1d(self.mne.midpoints, values).nonzero()[0]] + ts = epoch_nums[np.searchsorted(self.mne.midpoints, values)] tick_strings = [str(v) for v in ts] elif self.mne.time_format == 'clock': @@ -362,7 +570,7 @@ def drawPicture(self, p, axisSpec, tickSpecs, textSpecs): elif text in self.mne.info['bads']: p.setPen(_get_color(self.mne.ch_color_bad)) else: - p.setPen(_get_color(self.mne.ch_color_assoc[text])) + p.setPen(_get_color(self.mne.ch_color_ref[text])) self.ch_texts[text] = ((rect.left(), rect.left() + rect.width()), (rect.top(), rect.top() + rect.height())) p.drawText(rect, int(flags), text) @@ -387,7 +595,7 @@ def mouseClickEvent(self, event): trace = [tr for tr in self.mne.traces if tr.ch_name == ch_name][0] if event.button() == Qt.LeftButton: - self.main._bad_ch_clicked(trace) + trace.toggle_bad() elif event.button() == Qt.RightButton: self.main._create_ch_context_fig(trace.range_idx) @@ -415,8 +623,8 @@ def mousePressEvent(self, event): opt = QStyleOptionSlider() self.initStyleOption(opt) control = self.style().hitTestComplexControl( - QStyle.CC_ScrollBar, opt, - event.pos(), self) + QStyle.CC_ScrollBar, opt, + event.pos(), self) if (control == QStyle.SC_ScrollBarAddPage or control == QStyle.SC_ScrollBarSubPage): # scroll here @@ -441,8 +649,9 @@ def mousePressEvent(self, event): sliderMin = gr.y() sliderMax = gr.bottom() - sliderLength + 1 self.setValue(QStyle.sliderValueFromPosition( - self.minimum(), self.maximum(), - pos - sliderMin, sliderMax - sliderMin, opt.upsideDown)) + self.minimum(), self.maximum(), + pos - sliderMin, sliderMax - sliderMin, + opt.upsideDown)) return return super().mousePressEvent(event) @@ -454,12 +663,10 @@ class TimeScrollBar(BaseScrollBar): def __init__(self, mne): super().__init__(Qt.Horizontal) self.mne = mne - self.step_factor = None - + self.step_factor = 1 self.setMinimum(0) self.setSingleStep(1) - self.setPageStep(self.mne.scroll_sensitivity) - self._update_duration() + self.update_duration() self.setFocusPolicy(Qt.WheelFocus) # Because valueChanged is needed (captures every input to scrollbar, # not just sliderMoved), there has to be made a differentiation @@ -469,7 +676,11 @@ def __init__(self, mne): def _time_changed(self, value): if not self.external_change: - value /= self.step_factor + if self.mne.is_epochs: + # Convert Epoch index to time + value = self.mne.boundary_times[int(value)] + else: + value /= self.step_factor self.mne.plt.setXRange(value, value + self.mne.duration, padding=0) @@ -477,21 +688,27 @@ def update_value(self, value): """Update value of the ScrollBar.""" # Mark change as external to avoid setting # XRange again in _time_changed. - self._update_duration() self.external_change = True - self.setValue(int(value * self.step_factor)) + if self.mne.is_epochs: + set_value = np.searchsorted(self.mne.midpoints, value) + else: + set_value = int(value * self.step_factor) + self.setValue(set_value) self.external_change = False - def _update_duration(self): - new_step_factor = self.mne.scroll_sensitivity / self.mne.duration - if new_step_factor != self.step_factor: - self.step_factor = new_step_factor - new_maximum = int((self.mne.xmax - self.mne.duration) - * self.step_factor) - self.setMaximum(new_maximum) + def update_duration(self): + """Update bar size.""" + if self.mne.is_epochs: + self.setPageStep(self.mne.n_epochs) + self.setMaximum(len(self.mne.inst) - self.mne.n_epochs) + else: + self.setPageStep(int(self.mne.duration)) + self.step_factor = self.mne.scroll_sensitivity / self.mne.duration + self.setMaximum(int((self.mne.xmax - self.mne.duration) + * self.step_factor)) def _update_scroll_sensitivity(self): - self.setPageStep(self.mne.scroll_sensitivity) + self.update_duration() self.update_value(self.value() / self.step_factor) def keyPressEvent(self, event): @@ -508,8 +725,8 @@ def __init__(self, mne): self.mne = mne self.setMinimum(0) - self._update_nchan() self.setSingleStep(1) + self.update_nchan() self.setFocusPolicy(Qt.WheelFocus) # Because valueChanged is needed (captures every input to scrollbar, # not just sliderMoved), there has to be made a differentiation @@ -534,9 +751,9 @@ def update_value(self, value): self.external_change = True self.setValue(value) self.external_change = False - self._update_nchan() - def _update_nchan(self): + def update_nchan(self): + """Update bar size.""" if self.mne.group_by in ['position', 'selection']: self.setPageStep(1) self.setMaximum(len(self.mne.ch_selections) - 1) @@ -584,14 +801,36 @@ def __init__(self, main): self.event_line_dict = dict() self.update_events() - # Annotations - self.annotations_rect_dict = dict() - self.update_annotations() + if self.mne.is_epochs: + # Epochs Lines + self.epoch_line_dict = dict() + self.update_epoch_lines() + self.bad_epoch_rect_dict = dict() + self.update_bad_epochs() + else: + # Annotations + self.annotations_rect_dict = dict() + self.update_annotations() + + # VLine + self.v_line = None + self.update_vline() # View Range self.viewrange_rect = None self.update_viewrange() + def update_epoch_lines(self): + """Update representation of epoch lines.""" + epoch_line_pen = mkPen(color='k', width=1) + for t in self.mne.boundary_times[1:-1]: + top_left = self._mapFromData(t, 0) + bottom_right = self._mapFromData(t, len(self.mne.ch_order)) + line = self.scene().addLine(QLineF(top_left, bottom_right), + epoch_line_pen) + line.setZValue(1) + self.epoch_line_dict[t] = line + def update_bad_channels(self): """Update representation of bad channels.""" bad_set = set(self.mne.info['bads']) @@ -613,7 +852,30 @@ def update_bad_channels(self): self.scene().removeItem(self.bad_line_dict[ch_name]) self.bad_line_dict.pop(ch_name) + def update_bad_epochs(self): + bad_set = set(self.mne.bad_epochs) + rect_set = set(self.bad_epoch_rect_dict.keys()) + + add_epos = bad_set.difference(rect_set) + rm_epos = rect_set.difference(bad_set) + + for epo_num in self.mne.inst.selection: + if epo_num in add_epos: + epo_idx = self.mne.inst.selection.tolist().index(epo_num) + start, stop = self.mne.boundary_times[epo_idx:epo_idx + 2] + top_left = self._mapFromData(start, 0) + bottom_right = self._mapFromData(stop, len(self.mne.ch_order)) + pen = _get_color(self.mne.epoch_color_bad) + rect = self.scene().addRect(QRectF(top_left, bottom_right), + pen=pen, brush=pen) + rect.setZValue(3) + self.bad_epoch_rect_dict[epo_num] = rect + elif epo_num in rm_epos: + self.scene().removeItem(self.bad_epoch_rect_dict[epo_num]) + self.bad_epoch_rect_dict.pop(epo_num) + def update_events(self): + """Update representation of events.""" if self.mne.event_nums is not None and self.mne.events_visible: for ev_t, ev_id in zip(self.mne.event_times, self.mne.event_nums): color_name = self.mne.event_color_dict[ev_id] @@ -685,7 +947,30 @@ def update_annotations(self): len(self.mne.ch_order)) rect.setRect(QRectF(top_left, bottom_right)) + def update_vline(self): + """Update representation of vline.""" + if self.mne.is_epochs: + # VLine representation not useful in epochs-mode + pass + # Add VLine-Representation + elif self.mne.vline is not None: + value = self.mne.vline.value() + top_left = self._mapFromData(value, 0) + bottom_right = self._mapFromData(value, len(self.mne.ch_order)) + line = QLineF(top_left, bottom_right) + if self.v_line is None: + pen = mkPen('g') + self.v_line = self.scene().addLine(line, pen) + self.v_line.setZValue(1) + else: + self.v_line.setLine(line) + # Remove VLine-Representation + elif self.v_line is not None: + self.scene().removeItem(self.v_line) + self.v_line = None + def update_viewrange(self): + """Update representation of viewrange.""" if self.mne.butterfly: top_left = self._mapFromData(self.mne.t_start, 0) bottom_right = self._mapFromData(self.mne.t_start + @@ -707,24 +992,34 @@ def update_viewrange(self): def _set_range_from_pos(self, pos): x, y = self._mapToData(pos) - if x == '-offbounds': - xmin, xmax = (0, self.mne.duration) - elif x == '+offbounds': - xmin, xmax = (self.mne.xmax - self.mne.duration, self.mne.xmax) + + # Set X + # Move click position to middle of view range + if self.mne.is_epochs: + epo_idx = max(x - self.mne.n_epochs // 2, 0) + xmin = self.mne.boundary_times[epo_idx] else: - # Move middle of view range to click position xmin = x - self.mne.duration / 2 - xmax = xmin + self.mne.duration + xmax = xmin + self.mne.duration + + # Check boundaries + if x == '-offbounds' or xmin < 0: + xmin = 0 + xmax = self.mne.duration + elif x == '+offbounds' or xmax > self.mne.xmax: + xmin = self.mne.xmax - self.mne.duration + xmax = self.mne.xmax self.mne.plt.setXRange(xmin, xmax, padding=0) - if y == '-offbounds': + # Set Y + ymin = y - self.mne.n_channels / 2 + ymax = ymin + self.mne.n_channels + 1 + # Check boundaries + if y == '-offbounds' or ymin < 0: ymin, ymax = (0, self.mne.n_channels + 1) - elif y == '+offbounds': + elif y == '+offbounds' or ymax > self.mne.ymax: ymin, ymax = (self.mne.ymax - self.mne.n_channels - 1, self.mne.ymax) - else: - ymin = y - self.mne.n_channels / 2 - ymax = ymin + self.mne.n_channels + 1 if self.mne.fig_selection: self.mne.fig_selection._scroll_to_idx(int(ymin)) else: @@ -762,10 +1057,6 @@ def resizeEvent(self, event): # Resize backgounrd self._fit_bg_img() - # ToDo: This could be improved a lot with view-transforms e.g. with - # QGraphicsView.fitInView. The margin-problem could be approached - # with https://stackoverflow.com/questions/19640642/ - # qgraphicsview-fitinview-margins, but came with other problems. # Resize Graphics Items (assuming height never changes) # Resize bad_channels for bad_ch_line in self.bad_line_dict.values(): @@ -780,16 +1071,37 @@ def resizeEvent(self, event): bottom_right = self._mapFromData(ev_t, len(self.mne.ch_order)) event_line.setLine(QLineF(top_left, bottom_right)) - # Resize annotation-rects - for annot_dict in self.annotations_rect_dict.values(): - annot_rect = annot_dict['rect'] - plot_onset = annot_dict['plot_onset'] - duration = annot_dict['duration'] + if self.mne.is_epochs: + # Resize epoch lines + for epo_t, epoch_line in self.epoch_line_dict.items(): + top_left = self._mapFromData(epo_t, 0) + bottom_right = self._mapFromData(epo_t, + len(self.mne.ch_order)) + epoch_line.setLine(QLineF(top_left, bottom_right)) + # Resize bad rects + for epo_idx, epoch_rect in self.bad_epoch_rect_dict.items(): + start, stop = self.mne.boundary_times[epo_idx:epo_idx + 2] + top_left = self._mapFromData(start, 0) + bottom_right = self._mapFromData(stop, len(self.mne.ch_order)) + epoch_rect.setRect(QRectF(top_left, bottom_right)) + else: + # Resize annotation-rects + for annot_dict in self.annotations_rect_dict.values(): + annot_rect = annot_dict['rect'] + plot_onset = annot_dict['plot_onset'] + duration = annot_dict['duration'] - top_left = self._mapFromData(plot_onset, 0) - bottom_right = self._mapFromData(plot_onset + duration, - len(self.mne.ch_order)) - annot_rect.setRect(QRectF(top_left, bottom_right)) + top_left = self._mapFromData(plot_onset, 0) + bottom_right = self._mapFromData(plot_onset + duration, + len(self.mne.ch_order)) + annot_rect.setRect(QRectF(top_left, bottom_right)) + + # Update vline + if all([i is not None for i in [self.v_line, self.mne.vline]]): + value = self.mne.vline.value() + top_left = self._mapFromData(value, 0) + bottom_right = self._mapFromData(value, len(self.mne.ch_order)) + self.v_line.setLine(QLineF(top_left, bottom_right)) # Update viewrange-rect top_left = self._mapFromData(self.mne.t_start, self.mne.ch_start) @@ -842,8 +1154,12 @@ def _mapToData(self, point): elif xnorm > 1: x = '+offbounds' else: - time_idx = int((len(self.mne.inst.times) - 1) * xnorm) - x = self.mne.inst.times[time_idx] + if self.mne.is_epochs: + # Return epoch index for epochs + x = int(len(self.mne.inst) * xnorm) + else: + time_idx = int((len(self.mne.inst.times) - 1) * xnorm) + x = self.mne.inst.times[time_idx] ynorm = point.y() / self.height() if ynorm < 0: @@ -880,7 +1196,7 @@ def mouseDragEvent(self, event, axis=None): description = self.mne.current_description if event.isStart(): self._drag_start = self.mapSceneToView( - event.lastScenePos()).x() + event.lastScenePos()).x() drag_stop = self.mapSceneToView(event.scenePos()).x() self._drag_region = AnnotRegion(self.mne, description=description, @@ -926,12 +1242,14 @@ def mouseDragEvent(self, event, axis=None): # Update Overview-Bar self.mne.overview_bar.update_annotations() else: - self._drag_region.setRegion((self._drag_start, - self.mapSceneToView( - event.scenePos()).x())) + x_to = self.mapSceneToView(event.scenePos()).x() + self._drag_region.setRegion((self._drag_start, x_to)) + elif event.isFinish(): - QMessageBox.warning(self.main, 'No description!', - 'No description is given, add one!') + self.main.message_box(text='No description!', + info_text='No description is given, ' + 'add one!', + icon=QMessageBox.Warning) def mouseClickEvent(self, event): """Customize mouse click events.""" @@ -940,7 +1258,7 @@ def mouseClickEvent(self, event): if not self.mne.annotation_mode: if event.button() == Qt.LeftButton: self.main._add_vline(self.mapSceneToView( - event.scenePos()).x()) + event.scenePos()).x()) elif event.button() == Qt.RightButton: self.main._remove_vline() @@ -963,34 +1281,49 @@ class VLineLabel(InfLineLabel): def __init__(self, vline): super().__init__(vline, text='{value:.3f} s', position=0.98, fill='g', color='b', movable=True) - self.vline = vline self.cursorOffset = None def mouseDragEvent(self, ev): """Customize mouse drag events.""" if self.movable and ev.button() == Qt.LeftButton: if ev.isStart(): - self.vline.moving = True - self.cursorOffset = (self.vline.pos() - + self.line.moving = True + self.cursorOffset = (self.line.pos() - self.mapToView(ev.buttonDownPos())) ev.accept() - if not self.vline.moving: + if not self.line.moving: return - self.vline.setPos(self.cursorOffset + self.mapToView(ev.pos())) - self.vline.sigDragged.emit(self) + self.line.setPos(self.cursorOffset + self.mapToView(ev.pos())) + self.line.sigDragged.emit(self) if ev.isFinish(): - self.vline.moving = False - self.vline.sigPositionChangeFinished.emit(self) + self.line.moving = False + self.line.sigPositionChangeFinished.emit(self.line) + + def valueChanged(self): + """Customize what happens on value change.""" + if not self.isVisible(): + return + value = self.line.value() + if self.line.mne.is_epochs: + # Show epoch-time + t_vals_abs = np.linspace(0, self.line.mne.epoch_dur, + len(self.line.mne.inst.times)) + search_val = value % self.line.mne.epoch_dur + t_idx = np.searchsorted(t_vals_abs, search_val) + value = self.line.mne.inst.times[t_idx] + self.setText(self.format.format(value=value)) + self.updatePosition() class VLine(InfiniteLine): """Marker to be placed inside the Trace-Plot.""" - def __init__(self, pos, bounds): + def __init__(self, mne, pos, bounds): super().__init__(pos, pen='g', hoverPen='y', movable=True, bounds=bounds) + self.mne = mne self.label = VLineLabel(self) @@ -1183,31 +1516,31 @@ def __init__(self, main, **kwargs): self.downsampling_box.setMinimum(0) self.downsampling_box.setSpecialValueText('Auto') self.downsampling_box.valueChanged.connect(partial( - self._value_changed, value_name='downsampling')) + self._value_changed, value_name='downsampling')) self.downsampling_box.setValue(0 if self.mne.downsampling == 'auto' else self.mne.downsampling) layout.addRow('downsampling', self.downsampling_box) self.ds_method_cmbx = QComboBox() self.ds_method_cmbx.setToolTip( - '

Downsampling Method

' - '' - '(Those methods are adapted from ' - 'pyqtgraph)
' - 'Default is "peak".') + '

Downsampling Method

' + '' + '(Those methods are adapted from ' + 'pyqtgraph)
' + 'Default is "peak".') self.ds_method_cmbx.addItems(['subsample', 'mean', 'peak']) self.ds_method_cmbx.currentTextChanged.connect(partial( - self._value_changed, value_name='ds_method')) + self._value_changed, value_name='ds_method')) self.ds_method_cmbx.setCurrentText( - self.mne.ds_method) + self.mne.ds_method) layout.addRow('ds_method', self.ds_method_cmbx) self.scroll_sensitivity_slider = QSlider(Qt.Horizontal) @@ -1217,7 +1550,7 @@ def __init__(self, main, **kwargs): 'the scrolling in ' 'horizontal direction.') self.scroll_sensitivity_slider.valueChanged.connect(partial( - self._value_changed, value_name='scroll_sensitivity')) + self._value_changed, value_name='scroll_sensitivity')) # Set default self.scroll_sensitivity_slider.setValue(self.mne.scroll_sensitivity) layout.addRow('horizontal scroll sensitivity', @@ -1539,7 +1872,7 @@ def _style_butterfly(self): def _scroll_selection(self, step): name_idx = list(self.mne.ch_selections.keys()).index( - self.mne.old_selection) + self.mne.old_selection) new_idx = np.clip(name_idx + step, 0, len(self.mne.ch_selections) - 1) new_label = list(self.mne.ch_selections.keys())[new_idx] @@ -1668,18 +2001,17 @@ def __init__(self, annot_dock): self.ad = annot_dock self.current_mode = None - self.curr_des = None layout = QVBoxLayout() self.descr_label = QLabel() if self.mne.selected_region: self.mode_cmbx = QComboBox() - self.mode_cmbx.addItems(['group', 'current']) + self.mode_cmbx.addItems(['all', 'selected']) self.mode_cmbx.currentTextChanged.connect(self._mode_changed) layout.addWidget(QLabel('Edit Scope:')) layout.addWidget(self.mode_cmbx) # Set group as default - self._mode_changed('group') + self._mode_changed('all') layout.addWidget(self.descr_label) self.input_w = QLineEdit() @@ -1697,55 +2029,19 @@ def __init__(self, annot_dock): def _mode_changed(self, mode): self.current_mode = mode - if mode == 'group': + if mode == 'all': curr_des = self.ad.description_cmbx.currentText() else: curr_des = self.mne.selected_region.description self.descr_label.setText(f'Change "{curr_des}" to:') - self.curr_des = curr_des def _edit(self): new_des = self.input_w.text() if new_des: - if self.current_mode == 'group' or self.mne.selected_region is \ - None: - edit_regions = [r for r in self.mne.regions - if r.description == self.curr_des] - for ed_region in edit_regions: - idx = self.main._get_onset_idx( - ed_region.getRegion()[0]) - self.mne.inst.annotations.description[idx] = new_des - ed_region.update_description(new_des) - self.mne.new_annotation_labels.remove(self.curr_des) - self.mne.new_annotation_labels = \ - self.main._get_annotation_labels() - self.mne.visible_annotations[new_des] = \ - self.mne.visible_annotations.pop(self.curr_des) - self.mne.annotation_segment_colors[new_des] = \ - self.mne.annotation_segment_colors.pop( - self.curr_des) + if self.current_mode == 'all' or self.mne.selected_region is None: + self.ad._edit_description_all(new_des) else: - idx = self.main._get_onset_idx( - self.mne.selected_region.getRegion()[0]) - self.mne.inst.annotations.description[idx] = new_des - self.mne.selected_region.update_description(new_des) - if new_des not in self.mne.new_annotation_labels: - self.mne.new_annotation_labels.append(new_des) - self.mne.visible_annotations[new_des] = \ - self.mne.visible_annotations[self.curr_des] - self.mne.annotation_segment_colors[new_des] = \ - self.mne.annotation_segment_colors[self.curr_des] - if self.curr_des not in \ - self.mne.inst.annotations.description: - self.mne.new_annotation_labels.remove( - self.curr_des) - self.mne.visible_annotations.pop(self.curr_des) - self.mne.annotation_segment_colors.pop( - self.curr_des) - self.mne.current_description = new_des - self.main._setup_annotation_colors() - self.ad._update_description_cmbx() - self.ad._update_regions_colors() + self.ad._edit_description_selected(new_des) self.close() @@ -1777,11 +2073,11 @@ def _init_ui(self): layout.addWidget(add_bt) rm_bt = QPushButton('Remove Description') - rm_bt.clicked.connect(self._remove_description) + rm_bt.clicked.connect(self._remove_description_dlg) layout.addWidget(rm_bt) edit_bt = QPushButton('Edit Description') - edit_bt.clicked.connect(self._edit_description) + edit_bt.clicked.connect(self._edit_description_dlg) layout.addWidget(edit_bt) # Uncomment when custom colors for annotations are implemented in @@ -1840,35 +2136,66 @@ def _add_description_dlg(self): and new_description not in self.mne.new_annotation_labels: self._add_description(new_description) - def _edit_description(self): + def _edit_description_all(self, new_des): + """Update descriptions of all annotations with the same description.""" + old_des = self.description_cmbx.currentText() + edit_regions = [r for r in self.mne.regions + if r.description == old_des] + # Update regions & annotations + for ed_region in edit_regions: + idx = self.main._get_onset_idx(ed_region.getRegion()[0]) + self.mne.inst.annotations.description[idx] = new_des + ed_region.update_description(new_des) + # Update containers with annotation-attributes + self.mne.new_annotation_labels.remove(old_des) + self.mne.new_annotation_labels = self.main._get_annotation_labels() + self.mne.visible_annotations[new_des] = \ + self.mne.visible_annotations.pop(old_des) + self.mne.annotation_segment_colors[new_des] = \ + self.mne.annotation_segment_colors.pop(old_des) + + # Update related widgets + self.main._setup_annotation_colors() + self._update_regions_colors() + self._update_description_cmbx() + + def _edit_description_selected(self, new_des): + """Update description only of selected region.""" + old_des = self.mne.selected_region.description + idx = self.main._get_onset_idx(self.mne.selected_region.getRegion()[0]) + # Update regions & annotations + self.mne.inst.annotations.description[idx] = new_des + self.mne.selected_region.update_description(new_des) + # Update containers with annotation-attributes + if new_des not in self.mne.new_annotation_labels: + self.mne.new_annotation_labels.append(new_des) + self.mne.visible_annotations[new_des] = \ + copy(self.mne.visible_annotations[old_des]) + if old_des not in self.mne.inst.annotations.description: + self.mne.new_annotation_labels.remove(old_des) + self.mne.visible_annotations.pop(old_des) + self.mne.annotation_segment_colors[new_des] = \ + self.mne.annotation_segment_colors.pop(old_des) + + # Update related widgets + self.main._setup_annotation_colors() + self._update_regions_colors() + self._update_description_cmbx() + + def _edit_description_dlg(self): if len(self.mne.inst.annotations.description) > 0: _AnnotEditDialog(self) else: - QMessageBox.information(self, 'No Annotations!', - 'There are no annotations yet to edit!') + self.main.message_box(text='No Annotations!', + info_text='There are no annotations ' + 'yet to edit!', + icon=QMessageBox.Information) - def _remove_description(self): - rm_description = self.description_cmbx.currentText() - existing_annot = list(self.mne.inst.annotations.description).count( - rm_description) - if existing_annot > 0: - ans = QMessageBox.question(self, - f'Remove annotations ' - f'with {rm_description}?', - f'There exist {existing_annot} ' - f'annotations with ' - f'"{rm_description}".\n' - f'Do you really want to remove them?') - if ans == QMessageBox.Yes: - rm_idxs = np.where( - self.mne.inst.annotations.description == rm_description) - for idx in rm_idxs: - self.mne.inst.annotations.delete(idx) - for rm_region in [r for r in self.mne.regions - if r.description == rm_description]: - rm_region.remove() - else: - return + def _remove_description(self, rm_description): + # Remove regions + for rm_region in [r for r in self.mne.regions + if r.description == rm_description]: + rm_region.remove() # Remove from descriptions self.mne.new_annotation_labels.remove(rm_description) @@ -1887,6 +2214,26 @@ def _remove_description(self): self.mne.current_description = \ self.description_cmbx.currentText() + def _remove_description_dlg(self): + rm_description = self.description_cmbx.currentText() + existing_annot = list(self.mne.inst.annotations.description).count( + rm_description) + if existing_annot > 0: + text = f'Remove annotations with {rm_description}?' + info_text = f'There exist {existing_annot} annotations with ' \ + f'"{rm_description}".\n' \ + f'Do you really want to remove them?' + buttons = QMessageBox.Yes | QMessageBox.No + ans = self.main.message_box(text=text, info_text=info_text, + buttons=buttons, + default_button=QMessageBox.Yes, + icon=QMessageBox.Question) + else: + ans = QMessageBox.Yes + + if ans == QMessageBox.Yes: + self._remove_description(rm_description) + def _select_annotations(self): def _set_visible_region(state, description): self.mne.visible_annotations[description] = bool(state) @@ -1954,8 +2301,11 @@ def _start_changed(self): if start < stop: sel_region.setRegion((start, stop)) else: - QMessageBox.warning(self, 'Invalid value!', - 'Start can\'t be bigger or equal to Stop!') + self.main.message_box(text='Invalid value!', + info_text='Start can\'t be bigger or ' + 'equal to Stop!', + icon=QMessageBox.Critical, + modal=False) self.start_bx.setValue(sel_region.getRegion()[0]) def _stop_changed(self): @@ -1966,9 +2316,10 @@ def _stop_changed(self): if start < stop: sel_region.setRegion((start, stop)) else: - QMessageBox.warning(self, 'Invalid value!', - 'Stop can\'t be smaller ' - 'or equal to Start!') + self.main.message_box(text='Invalid value!', + info_text='Stop can\'t be smaller or ' + 'equal to Start!', + icon=QMessageBox.Critical) self.stop_bx.setValue(sel_region.getRegion()[1]) def _set_color(self): @@ -2011,36 +2362,35 @@ def reset(self): self.stop_bx.setValue(0) def _show_help(self): - QMessageBox.information(self, 'Annotations-Help', - '

Help

' - '

Annotations

' - '

Add Annotations

' - 'Drag inside the data-view to create ' - 'annotations with the description currently ' - 'selected (leftmost item of the toolbar).' - 'If there is no description yet, add one ' - 'with the button "Add description".' - '

Remove Annotations

' - 'You can remove single annotations by ' - 'right-clicking on them.' - '

Edit Annotations

' - 'You can edit annotations by dragging them or ' - 'their boundaries. Or you can use the dials ' - 'in the toolbar to adjust the boundaries for ' - 'the current selected annotation.' - '

Descriptions

' - '

Add Description

' - 'Add a new description with the button' - '"Add description".' - '

Edit Description

' - 'You can edit the description of one single ' - 'annotation or all annotations of the ' - 'currently selected kind with the button ' - '"Edit description".' - '

Remove Description

' - 'You can remove all annotations of the ' - 'currently selected kind with the button ' - '"Remove description".') + info_text = '

Help

' \ + '

Annotations

' \ + '

Add Annotations

' \ + 'Drag inside the data-view to create annotations with '\ + 'the description currently selected (leftmost item of '\ + 'the toolbar).If there is no description yet, add one ' \ + 'with the button "Add description".' \ + '

Remove Annotations

' \ + 'You can remove single annotations by right-clicking on '\ + 'them.' \ + '

Edit Annotations

' \ + 'You can edit annotations by dragging them or their '\ + 'boundaries. Or you can use the dials in the toolbar to '\ + 'adjust the boundaries for the current selected '\ + 'annotation.' \ + '

Descriptions

' \ + '

Add Description

' \ + 'Add a new description with ' \ + 'the button "Add description".' \ + '

Edit Description

' \ + 'You can edit the description of one single annotation '\ + 'or all annotations of the currently selected kind with '\ + 'the button "Edit description".' \ + '

Remove Description

' \ + 'You can remove all annotations of the currently '\ + 'selected kind with the button "Remove description".' + self.main.message_box(text='Annotations-Help', + info_text=info_text, + icon=QMessageBox.Information) class BrowserView(GraphicsView): @@ -2091,39 +2441,50 @@ def run(self): # because of the frequent gui-update-calls. # Thus n_chunks = 10 should suffice. data = None - times = None - n_chunks = 10 - if not self.mne.is_epochs: - chunk_size = len(self.browser.mne.inst) // n_chunks - for n in range(n_chunks): - start = n * chunk_size - if n == n_chunks - 1: - # Get last chunk which may be larger due to rounding above - stop = None - else: - stop = start + chunk_size - # Load data + if self.mne.is_epochs: + times = np.arange(len(self.mne.inst) * len(self.mne.inst.times)) \ + / self.mne.info['sfreq'] + else: + times = None + n_chunks = min(10, len(self.mne.inst)) + chunk_size = len(self.mne.inst) // n_chunks + for n in range(n_chunks): + start = n * chunk_size + if n == n_chunks - 1: + # Get last chunk which may be larger due to rounding above + stop = None + else: + stop = start + chunk_size + # Load epochs + if self.mne.is_epochs: + item = slice(start, stop) + with self.mne.inst.info._unlock(): + data_chunk = np.concatenate( + self.mne.inst.get_data(item=item), axis=-1) + # Load raw + else: data_chunk, times_chunk = self.browser._load_data(start, stop) - if data is None: - data = data_chunk + if times is None: times = times_chunk else: - data = np.concatenate((data, data_chunk), axis=1) times = np.concatenate((times, times_chunk), axis=0) - self.loadProgress.emit(n + 1) - else: - self.browser._load_data() - self.loadProgress.emit(n_chunks) - picks = self.browser.mne.ch_order + if data is None: + data = data_chunk + else: + data = np.concatenate((data, data_chunk), axis=1) + + self.loadProgress.emit(n + 1) + + picks = self.mne.ch_order # Deactive remove dc because it will be removed for visible range stashed_remove_dc = self.mne.remove_dc self.mne.remove_dc = False data = self.browser._process_data(data, 0, len(data), picks, self) self.mne.remove_dc = stashed_remove_dc - self.browser.mne.global_data = data - self.browser.mne.global_times = times + self.mne.global_data = data + self.mne.global_times = times # Calculate Z-Scores self.processText.emit('Calculating Z-Scores...') @@ -2168,7 +2529,7 @@ class _PGMetaClass(type(BrowserBase), type(QMainWindow)): 'downsampling': 1, # Downsampling-Method (set SettingsDialog for details) 'ds_method': 'peak' - } +} class PyQtGraphBrowser(BrowserBase, QMainWindow, metaclass=_PGMetaClass): @@ -2190,16 +2551,44 @@ def __init__(self, **kwargs): # Initialize attributes which are only used by pyqtgraph, not by # matplotlib and add them to MNEBrowseParams. - self.load_thread = None + + # Blocks concurrent scolling to avoid segmentation faults + self.is_scrolling = False + # Exactly one MessageBox for messages to facilitate testing/debugging + self.msg_box = QMessageBox(self) + # MessageBox modality needs to be adapted for tests + # (otherwise test execution blocks) + self.test_mode = False + # A Settings-Dialog self.mne.fig_settings = None + # Stores decimated data self.mne.decim_data = None self.mne.decim_times = None + # Stores ypos for selection-mode self.mne.selection_ypos_dict = dict() + # Parameters for precomputing self.mne.enable_precompute = False self.mne.data_precomputed = False + self._rerun_load_thread = False + # Parameters for overviewbar self.mne.show_overview_bar = True self.mne.overview_mode = 'channels' self.mne.zscore_rgba = None + # Container for traces + self.mne.traces = list() + # Scale-Factor + self.mne.scale_factor = 1 + # Stores channel-types for butterfly-mode + self.mne.butterfly_type_order = [tp for tp in + _DATA_CH_TYPES_ORDER_DEFAULT + if tp in self.mne.ch_types] + if self.mne.is_epochs: + # Stores parameters for epochs + self.mne.epoch_dur = np.diff(self.mne.boundary_times[:2])[0] + epoch_idx = np.searchsorted(self.mne.midpoints, + (self.mne.t_start, + self.mne.t_start + self.mne.duration)) + self.mne.epoch_idx = np.arange(epoch_idx[0], epoch_idx[1]) # Load from QSettings if available for qparam in qsettings_params: @@ -2217,10 +2606,34 @@ def __init__(self, **kwargs): setattr(self.mne, qparam, qvalue) # Initialize channel-colors for faster indexing later - self.mne.ch_color_assoc = dict() + self.mne.ch_color_ref = dict() for idx, ch_name in enumerate(self.mne.ch_names): ch_type = self.mne.ch_types[idx] - self.mne.ch_color_assoc[ch_name] = self.mne.ch_color_dict[ch_type] + self.mne.ch_color_ref[ch_name] = self.mne.ch_color_dict[ch_type] + + # Initialize epoch colors for faster indexing later + if self.mne.is_epochs: + if self.mne.epoch_colors is None: + self.mne.epoch_color_ref = \ + np.repeat([to_rgba_array(c) for c + in self.mne.ch_color_ref.values()], + len(self.mne.inst), axis=1) + else: + self.mne.epoch_color_ref = np.empty((len(self.mne.ch_names), + len(self.mne.inst), 4)) + for epo_idx, epo in enumerate(self.mne.epoch_colors): + for ch_idx, color in enumerate(epo): + self.mne.epoch_color_ref[ch_idx, epo_idx] = \ + to_rgba_array(color) + + # Mark bad epochs + self.mne.epoch_color_ref[:, self.mne.bad_epochs] = \ + to_rgba_array(self.mne.epoch_color_bad) + + # Mark bad channels + bad_idxs = np.in1d(self.mne.ch_names, self.mne.info['bads']) + self.mne.epoch_color_ref[bad_idxs, :] = \ + to_rgba_array(self.mne.ch_color_bad) # Add Load-Progressbar for loading in a thread self.mne.load_prog_label = QLabel('Loading...') @@ -2232,23 +2645,12 @@ def __init__(self, **kwargs): self.statusBar().addWidget(self.mne.load_progressbar, stretch=1) self.mne.load_progressbar.hide() - self.mne.traces = list() - self.mne.scale_factor = 1 - self.mne.butterfly_type_order = [tp for tp in - _DATA_CH_TYPES_ORDER_DEFAULT - if tp in self.mne.ch_types] - - # Initialize annotations (ToDo: Adjust to MPL) - self.mne.annotation_mode = False - self.mne.annotations_visible = True - self.mne.new_annotation_labels = self._get_annotation_labels() - if len(self.mne.new_annotation_labels) > 0: - self.mne.current_description = self.mne.new_annotation_labels[0] - else: - self.mne.current_description = None - self._setup_annotation_colors() - self.mne.regions = list() - self.mne.selected_region = None + # A QThread for preloading + self.load_thread = LoadThread(self) + self.load_thread.loadProgress.connect(self.mne. + load_progressbar.setValue) + self.load_thread.processText.connect(self._show_process) + self.load_thread.loadingFinished.connect(self._precompute_finished) # Create centralWidget and layout widget = QWidget() @@ -2265,7 +2667,7 @@ def __init__(self, **kwargs): # Start precomputing if enabled self._init_precompute() - # Initialize data (needed in RawTraceItem.update_data). + # Initialize data (needed in DataTrace.update_data). self._update_data() # Initialize Trace-Plot @@ -2274,7 +2676,11 @@ def __init__(self, **kwargs): # Hide AutoRange-Button plt.hideButtons() # Configure XY-Range - self.mne.xmax = self.mne.inst.times[-1] + if self.mne.is_epochs: + self.mne.xmax = len(self.mne.inst.times) * len(self.mne.inst) \ + / self.mne.info['sfreq'] + else: + self.mne.xmax = self.mne.inst.times[-1] # Add one empty line as padding at top (y=0). # Negative Y-Axis to display channels from top. self.mne.ymax = len(self.mne.ch_order) + 1 @@ -2287,11 +2693,18 @@ def __init__(self, **kwargs): # Add traces for ch_idx in self.mne.picks: - self._add_trace(ch_idx) + DataTrace(self, ch_idx) - # Add events (add all once, since their representation is simple - # they shouldn't have a big impact on performance when showing them - # is handled by QGraphicsView). + # Initialize Epochs Grid + if self.mne.is_epochs: + grid_pen = mkPen(color='k', width=2, style=Qt.DashLine) + for x_grid in self.mne.boundary_times[1:-1]: + grid_line = InfiniteLine(pos=x_grid, + pen=grid_pen, + movable=False) + plt.addItem(grid_line) + + # Add events if self.mne.event_nums is not None: self.mne.events_visible = True for ev_time, ev_id in zip(self.mne.event_times, @@ -2311,18 +2724,28 @@ def __init__(self, **kwargs): # Check for OpenGL if self.mne.use_opengl is None: # default: opt-in self.mne.use_opengl = ( - get_config('MNE_BROWSE_USE_OPENGL', '').lower() == 'true') + get_config('MNE_BROWSE_USE_OPENGL', '').lower() == 'true') + + # Epochs currently only work with OpenGL enabled + # (https://github.com/mne-tools/mne-qt-browser/issues/53) + mac_epochs = self.mne.is_epochs and sys.platform == 'darwin' + if mac_epochs: + self.mne.use_opengl = True + if self.mne.use_opengl: try: import OpenGL except (ModuleNotFoundError, ImportError): warn('PyOpenGL was not found and OpenGL can\'t be used!\n' - 'Consider installing pyopengl with "pip install pyopengl"' + 'Consider installing pyopengl with pip or conda' 'or set "use_opengl" to False to avoid this warning.') + if mac_epochs: + warn('Plotting epochs on MacOS without OpenGL' + 'may be unstable!') self.mne.use_opengl = False else: logger.info( - f'Using pyopengl with version {OpenGL.__version__}') + f'Using pyopengl with version {OpenGL.__version__}') # Initialize BrowserView (inherits QGraphicsView) view = BrowserView(plt, background='w', useOpenGL=self.mne.use_opengl) @@ -2335,6 +2758,21 @@ def __init__(self, **kwargs): ax_vscroll = ChannelScrollBar(self.mne) layout.addWidget(ax_vscroll, 0, 1) + # Initialize VLine + self.mne.vline = None + self.mne.vline_visible = False + + # Initialize crosshair (as in pyqtgraph example) + self.mne.crosshair_enabled = False + self.mne.crosshair_h = None + self.mne.crosshair = None + view.sigSceneMouseMoved.connect(self._mouse_moved) + + # Initialize Annotation-Widgets + self.mne.annotation_mode = False + if not self.mne.is_epochs: + self._init_annot_mode() + # OverviewBar overview_bar = OverviewBar(self) layout.addWidget(overview_bar, 2, 0, 1, 2) @@ -2359,7 +2797,7 @@ def __init__(self, **kwargs): if self.mne.enable_precompute: self.overview_mode_chkbx.addItems(['zscore']) self.overview_mode_chkbx.currentTextChanged.connect( - self._overview_mode_changed) + self._overview_mode_changed) self.overview_mode_chkbx.setCurrentIndex(0) # Avoid taking keyboard-focus self.overview_mode_chkbx.setFocusPolicy(Qt.NoFocus) @@ -2373,32 +2811,6 @@ def __init__(self, **kwargs): widget.setLayout(layout) self.setCentralWidget(widget) - # Initialize Annotation-Dock - fig_annotation = AnnotationDock(self) - self.addDockWidget(Qt.TopDockWidgetArea, fig_annotation) - fig_annotation.setVisible(False) - vars(self.mne).update(fig_annotation=fig_annotation) - - # Add annotations as regions - for annot in self.mne.inst.annotations: - plot_onset = _sync_onset(self.mne.inst, annot['onset']) - duration = annot['duration'] - description = annot['description'] - self._add_region(plot_onset, duration, description) - - # Initialize annotations - self._change_annot_mode() - - # Initialize VLine - self.mne.vline = None - self.mne.vline_visible = False - - # Initialize crosshair (as in pyqtgraph example) - self.mne.crosshair_enabled = False - self.mne.crosshair_h = None - self.mne.crosshair = None - view.sigSceneMouseMoved.connect(self._mouse_moved) - # Initialize Toolbar toolbar = self.addToolBar('Tools') toolbar.setToolButtonStyle(Qt.ToolButtonTextBesideIcon) @@ -2433,10 +2845,11 @@ def __init__(self, **kwargs): aincr_nchan.triggered.connect(partial(self.scale_all, 5 / 4)) toolbar.addAction(aincr_nchan) - atoggle_annot = QAction(_get_std_icon('SP_DialogResetButton'), - 'Annotations', parent=self) - atoggle_annot.triggered.connect(self._toggle_annotation_fig) - toolbar.addAction(atoggle_annot) + if not self.mne.is_epochs: + atoggle_annot = QAction(_get_std_icon('SP_DialogResetButton'), + 'Annotations', parent=self) + atoggle_annot.triggered.connect(self._toggle_annotation_fig) + toolbar.addAction(atoggle_annot) atoggle_proj = QAction(_get_std_icon('SP_DialogOkButton'), 'SSP', parent=self) @@ -2460,15 +2873,19 @@ def __init__(self, **kwargs): # Add GUI-Elements to MNEBrowserParams-Instance vars(self.mne).update( - plt=plt, view=view, ax_hscroll=ax_hscroll, ax_vscroll=ax_vscroll, - overview_bar=overview_bar, fig_annotation=fig_annotation, - toolbar=toolbar + plt=plt, view=view, ax_hscroll=ax_hscroll, + ax_vscroll=ax_vscroll, + overview_bar=overview_bar, toolbar=toolbar ) # Set Start-Range (after all necessary elements are initialized) - plt.setXRange(self.mne.t_start, self.mne.t_start + self.mne.duration, + plt.setXRange(self.mne.t_start, + self.mne.t_start + self.mne.duration, padding=0) - plt.setYRange(0, self.mne.n_channels + 1, padding=0) + if self.mne.butterfly: + self._set_butterfly(True) + else: + plt.setYRange(0, self.mne.n_channels + 1, padding=0) # Set Size width = int(self.mne.figsize[0] * self.logicalDpiX()) @@ -2658,50 +3075,9 @@ def __init__(self, **kwargs): } } - def _get_scale_transform(self): - transform = QTransform() - transform.scale(1, self.mne.scale_factor) - - return transform - def _update_yaxis_labels(self): self.mne.channel_axis.repaint() - def _bad_ch_clicked(self, line): - """Slot for bad channel click.""" - _, pick, marked_bad = self._toggle_bad_channel(line.range_idx) - - # Update line color - line.isbad = not line.isbad - line.update_color() - - # Update Channel-Axis - self._update_yaxis_labels() - - # Update Overview-Bar - self.mne.overview_bar.update_bad_channels() - - # update sensor color (if in selection mode) - if self.mne.fig_selection is not None: - self.mne.fig_selection._update_bad_sensors(pick, marked_bad) - - def _add_trace(self, ch_idx): - trace = RawTraceItem(self.mne, ch_idx) - - # Apply scaling - transform = self._get_scale_transform() - trace.setTransform(transform) - - # Add Item early to have access to viewBox - self.mne.plt.addItem(trace) - self.mne.traces.append(trace) - - trace.sigClicked.connect(lambda tr, _: self._bad_ch_clicked(tr)) - - def _remove_trace(self, trace): - self.mne.plt.removeItem(trace) - self.mne.traces.remove(trace) - def _add_scalebars(self): """Add scalebars for all channel-types. (scene handles showing them in when in view @@ -2719,7 +3095,6 @@ def _add_scalebars(self): ct in self.mne.scalings and ct in getattr(self.mne, 'units', {}) and ct in getattr(self.mne, 'unit_scalings', {})]: - scale_bar = ScaleBar(self.mne, ch_type) self.mne.scalebars[ch_type] = scale_bar self.mne.plt.addItem(scale_bar) @@ -2771,23 +3146,32 @@ def _overview_mode_changed(self, new_mode): def scale_all(self, step): """Scale all traces by multiplying with step.""" self.mne.scale_factor *= step - transform = self._get_scale_transform() + + # Reapply clipping if necessary + if self.mne.clipping is not None: + self._update_data() # Scale Traces (by scaling the Item, not the data) for line in self.mne.traces: - line.setTransform(transform) - if self.mne.clipping is not None: - line.update_data() + line.update_scale() # Update Scalebars self._update_scalebar_values() def hscroll(self, step): """Scroll horizontally by step.""" + if self.is_scrolling: + return + + self.is_scrolling = True + if step == '+full': rel_step = self.mne.duration elif step == '-full': rel_step = - self.mne.duration + elif self.mne.is_epochs: + direction = 1 if step > 0 else -1 + rel_step = direction * self.mne.duration / self.mne.n_epochs else: rel_step = step * self.mne.duration / self.mne.scroll_sensitivity # Get current range and add step to it @@ -2804,6 +3188,11 @@ def hscroll(self, step): def vscroll(self, step): """Scroll vertically by step.""" + if self.is_scrolling: + return + + self.is_scrolling = True + if self.mne.fig_selection is not None: if step == '+full': step = 1 @@ -2812,6 +3201,8 @@ def vscroll(self, step): else: step = int(step) self.mne.fig_selection._scroll_selection(step) + elif self.mne.butterfly: + return else: # Get current range and add step to it if step == '+full': @@ -2831,15 +3222,19 @@ def vscroll(self, step): def change_duration(self, step): """Change duration by step.""" - rel_step = self.mne.duration * step xmin, xmax = self.mne.viewbox.viewRange()[0] if self.mne.is_epochs: # use the length of one epoch as duration change min_dur = len(self.mne.inst.times) / self.mne.info['sfreq'] + step_dir = (1 if step > 0 else -1) + rel_step = min_dur * step_dir + self.mne.n_epochs = np.clip(self.mne.n_epochs + step_dir, + 1, len(self.mne.inst)) else: # never show fewer than 3 samples min_dur = 3 * np.diff(self.mne.inst.times[:2])[0] + rel_step = self.mne.duration * step xmax += rel_step @@ -2854,6 +3249,7 @@ def change_duration(self, step): if xmin < 0: xmin = 0 + self.mne.ax_hscroll.update_duration() self.mne.plt.setXRange(xmin, xmax, padding=0) def change_nchan(self, step): @@ -2875,10 +3271,11 @@ def change_nchan(self, step): if ymax - ymin <= 2: ymax = ymin + 2 + self.mne.ax_vscroll.update_nchan() self.mne.plt.setYRange(ymin, ymax, padding=0) def _remove_vline(self): - if self.mne.vline: + if self.mne.vline is not None: if self.mne.is_epochs: for vline in self.mne.vline: self.mne.plt.removeItem(vline) @@ -2887,14 +3284,55 @@ def _remove_vline(self): self.mne.vline = None self.mne.vline_visible = False + self.mne.overview_bar.update_vline() + + def _get_vline_times(self, t): + rel_time = t % self.mne.epoch_dur + abs_time = self.mne.times[0] + ts = np.arange( + self.mne.n_epochs) * self.mne.epoch_dur + abs_time + rel_time - def _add_vline(self, pos): - # Remove vline if already shown - self._remove_vline() + return ts + + def _vline_slot(self, orig_vline): + if self.mne.is_epochs: + ts = self._get_vline_times(orig_vline.value()) + for vl, xt in zip(self.mne.vline, ts): + if vl != orig_vline: + vl.setPos(xt) + self.mne.overview_bar.update_vline() + + def _add_vline(self, t): + if self.mne.is_epochs: + ts = self._get_vline_times(t) + + # Add vline if None + if self.mne.vline is None: + self.mne.vline = list() + for xt in ts: + epo_idx = np.searchsorted(self.mne.boundary_times, xt) - 1 + bmin, bmax = self.mne.boundary_times[epo_idx:epo_idx + 2] + # Avoid off-by-one-error at bmax for VlineLabel + bmax -= 1 / self.mne.info['sfreq'] + vl = VLine(self.mne, xt, bounds=(bmin, bmax)) + # Should only be emitted when dragged + vl.sigPositionChangeFinished.connect(self._vline_slot) + self.mne.vline.append(vl) + self.mne.plt.addItem(vl) + else: + for vl, xt in zip(self.mne.vline, ts): + vl.setPos(xt) + else: + if self.mne.vline is None: + self.mne.vline = VLine(self.mne, t, bounds=(0, self.mne.xmax)) + self.mne.vline.sigPositionChangeFinished.connect( + self._vline_slot) + self.mne.plt.addItem(self.mne.vline) + else: + self.mne.vline.setPos(t) - self.mne.vline = VLine(pos, bounds=(0, self.mne.xmax)) - self.mne.plt.addItem(self.mne.vline) self.mne.vline_visible = True + self.mne.overview_bar.update_vline() def _mouse_moved(self, pos): """Show Crosshair if enabled at mouse move.""" @@ -2939,12 +3377,39 @@ def _toggle_crosshair(self): def _xrange_changed(self, _, xrange): # Update data + if self.mne.is_epochs: + if self.mne.vline is not None: + rel_vl_t = self.mne.vline[0].value() \ + - self.mne.boundary_times[self.mne.epoch_idx][0] + + # Depends on only allowing xrange showing full epochs + boundary_idxs = np.searchsorted(self.mne.midpoints, xrange) + self.mne.epoch_idx = np.arange(*boundary_idxs) + + # Update colors + for trace in self.mne.traces: + trace.update_color() + + # Update vlines + if self.mne.vline is not None: + for bmin, bmax, vl in zip(self.mne.boundary_times[ + self.mne.epoch_idx], + self.mne.boundary_times[ + self.mne.epoch_idx + 1], + self.mne.vline): + # Avoid off-by-one-error at bmax for VlineLabel + bmax -= 1 / self.mne.info['sfreq'] + vl.setBounds((bmin, bmax)) + vl.setValue(bmin + rel_vl_t) + self.mne.t_start = xrange[0] self.mne.duration = xrange[1] - xrange[0] + self._redraw(update_data=True) # Update annotations - self._update_annotations_xrange(xrange) + if not self.mne.is_epochs: + self._update_annotations_xrange(xrange) # Update Events self._update_events_xrange(xrange) @@ -2958,6 +3423,9 @@ def _xrange_changed(self, _, xrange): # Update Scalebars self._update_scalebar_x_positions() + # Relieve Scrolling-Block + self.is_scrolling = False + def _update_events_xrange(self, xrange): """Add or remove event-lines depending on view-range. @@ -3035,7 +3503,7 @@ def _yrange_changed(self, _, yrange): # Only remove from traces not in picks. remove_traces = off_traces[:abs(trace_diff)] for trace in remove_traces: - self._remove_trace(trace) + trace.remove() off_traces.remove(trace) # Add new traces if necessary. @@ -3043,7 +3511,7 @@ def _yrange_changed(self, _, yrange): # Make copy to avoid skipping iteration. idxs_copy = add_idxs.copy() for aidx in idxs_copy[:trace_diff]: - self._add_trace(aidx) + DataTrace(self, aidx) add_idxs.remove(aidx) # Update data of traces outside of yrange (reuse remaining trace-items) @@ -3052,6 +3520,9 @@ def _yrange_changed(self, _, yrange): trace.update_color() trace.update_data() + # Relieve Scrolling-Block + self.is_scrolling = False + # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # DATA HANDLING # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # @@ -3136,6 +3607,10 @@ def _precompute_finished(self): # Show loaded overview image self.mne.overview_bar.set_background() + if self._rerun_load_thread: + self._rerun_load_thread = False + self._init_precompute() + def _init_precompute(self): # Remove previously loaded data self.mne.data_precomputed = False @@ -3153,13 +3628,14 @@ def _init_precompute(self): # Start precompute thread self.mne.load_progressbar.show() self.mne.load_prog_label.show() - self.load_thread = LoadThread(self) - self.load_thread.loadProgress.connect(self.mne. - load_progressbar.setValue) - self.load_thread.processText.connect(self._show_process) - self.load_thread.loadingFinished.connect(self._precompute_finished) self.load_thread.start() + def _rerun_precompute(self): + if self.load_thread.isRunning(): + self._rerun_load_thread = True + else: + self._init_precompute() + def _check_space_for_precompute(self): try: import psutil @@ -3169,10 +3645,14 @@ def _check_space_for_precompute(self): 'Setting precompute to False.') return False else: - if self.mne.inst.filenames[0]: + if self.mne.is_epochs: + files = [self.mne.inst.filename] + else: + files = self.mne.inst.filenames + if files[0] is not None: # Get disk-space of raw-file(s) disk_space = 0 - for fn in self.mne.inst.filenames: + for fn in files: disk_space += getsize(fn) # Determine expected RAM space based on orig_format @@ -3246,8 +3726,8 @@ def _update_data(self): # decim can vary by channel type, # so compute different `times` vectors. self.mne.decim_times = {decim_value: self.mne.times[::decim_value] - + self.mne.first_time for - decim_value in set(self.mne.decim_data)} + + self.mne.first_time for decim_value + in set(self.mne.decim_data)} # Apply clipping if self.mne.clipping == 'clamp': @@ -3305,7 +3785,7 @@ def _add_region(self, plot_onset, duration, description, region=None): region = AnnotRegion(self.mne, description=description, values=(plot_onset, plot_onset + duration)) if (any([self.mne.t_start < v < self.mne.t_start + self.mne.duration - for v in [plot_onset, plot_onset + duration]]) and + for v in [plot_onset, plot_onset + duration]]) and region not in self.mne.plt.items): self.mne.plt.addItem(region) self.mne.plt.addItem(region.label_item) @@ -3313,7 +3793,7 @@ def _add_region(self, plot_onset, duration, description, region=None): region.gotSelected.connect(self._region_selected) region.removeRequested.connect(self._remove_region) self.mne.viewbox.sigYRangeChanged.connect( - region.update_label_pos) + region.update_label_pos) self.mne.regions.append(region) region.update_label_pos() @@ -3376,6 +3856,35 @@ def _draw_annotations(self): # which is faster than handling adding/removing in Python. pass + def _init_annot_mode(self): + self.mne.annotations_visible = True + self.mne.new_annotation_labels = self._get_annotation_labels() + if len(self.mne.new_annotation_labels) > 0: + self.mne.current_description = self.mne.new_annotation_labels[0] + else: + self.mne.current_description = None + self._setup_annotation_colors() + self.mne.regions = list() + self.mne.selected_region = None + + # Initialize Annotation-Dock + existing_dock = getattr(self.mne, 'fig_annotation', None) + if existing_dock is None: + fig_annotation = AnnotationDock(self) + self.addDockWidget(Qt.TopDockWidgetArea, fig_annotation) + fig_annotation.setVisible(False) + vars(self.mne).update(fig_annotation=fig_annotation) + + # Add annotations as regions + for annot in self.mne.inst.annotations: + plot_onset = _sync_onset(self.mne.inst, annot['onset']) + duration = annot['duration'] + description = annot['description'] + self._add_region(plot_onset, duration, description) + + # Initialize showing annotation widgets + self._change_annot_mode() + def _change_annot_mode(self): if not self.mne.annotation_mode: # Reset Widgets in Annotation-Figure @@ -3397,13 +3906,14 @@ def _change_annot_mode(self): self.mne.selected_region.select(self.mne.annotation_mode) def _toggle_annotation_fig(self): - self.mne.annotation_mode = not self.mne.annotation_mode - self._change_annot_mode() + if not self.mne.is_epochs: + self.mne.annotation_mode = not self.mne.annotation_mode + self._change_annot_mode() def _update_regions_visible(self): for region in self.mne.regions: region.update_visible( - self.mne.visible_annotations[region.description]) + self.mne.visible_annotations[region.description]) self.mne.overview_bar.update_annotations() def _set_annotations_visible(self, visible): @@ -3437,7 +3947,7 @@ def _apply_update_projectors(self, toggle_all=False): self.mne.projs_on = new_state self._update_projector() # If data was precomputed it needs to be precomputed again. - self._init_precompute() + self._rerun_precompute() self._redraw() def _toggle_proj_fig(self): @@ -3455,7 +3965,7 @@ def _toggle_all_projs(self): def _toggle_whitening(self): super()._toggle_whitening() # If data was precomputed it needs to be precomputed again. - self._init_precompute() + self._rerun_precompute() self._redraw() def _toggle_settings_fig(self): @@ -3484,13 +3994,16 @@ def _set_butterfly(self, butterfly): for pick in picks: self.mne.selection_ypos_dict[pick] = idx + 1 ymax = len(selections_dict) + 1 + self.mne.ymax = ymax self.mne.plt.setLimits(yMax=ymax) self.mne.plt.setYRange(0, ymax, padding=0) elif butterfly: ymax = len(self.mne.butterfly_type_order) + 1 + self.mne.ymax = ymax self.mne.plt.setLimits(yMax=ymax) self.mne.plt.setYRange(0, ymax, padding=0) else: + self.mne.ymax = len(self.mne.ch_order) + 1 self.mne.plt.setLimits(yMax=self.mne.ymax) self.mne.plt.setYRange(self.mne.ch_start, self.mne.ch_start + self.mne.n_channels + 1, @@ -3509,8 +4022,8 @@ def _set_butterfly(self, butterfly): # update ypos and color for butterfly-mode for trace in self.mne.traces: - trace.update_ypos() trace.update_color() + trace.update_ypos() self._draw_traces() @@ -3615,12 +4128,38 @@ def _create_ch_context_fig(self, idx): if fig is not None: self._get_dlg_from_mpl(fig) + def _toggle_epoch_histogramm(self): + if self.mne.is_epochs: + fig = self._create_epoch_histogram() + if fig is not None: + self._get_dlg_from_mpl(fig) + def _update_trace_offsets(self): pass def _create_selection_fig(self): SelectionDialog(self) + def message_box(self, text, info_text=None, buttons=None, + default_button=None, icon=None, modal=True): + self.msg_box.setText(f'{text}') + if info_text is not None: + self.msg_box.setInformativeText(info_text) + if buttons is not None: + self.msg_box.setStandardButtons(buttons) + if default_button is not None: + self.msg_box.setDefaultButton(default_button) + if icon is not None: + self.msg_box.setIcon(icon) + + # Allow interacting with message_box in test-mode. + # Set modal=False only if no return value is expected. + self.msg_box.setModal(False if self.test_mode else modal) + if self.test_mode or not modal: + self.msg_box.show() + else: + return self.msg_box.exec() + def keyPressEvent(self, event): """Customize key press events.""" # On MacOs additionally KeypadModifier is set when arrow-keys @@ -3677,6 +4216,9 @@ def _fake_keypress(self, key, fig=None): if key.isupper(): key = key.lower() modifier = Qt.ShiftModifier + elif key.startswith('shift+'): + key = key[6:] + modifier = Qt.ShiftModifier else: modifier = Qt.NoModifier @@ -3732,7 +4274,7 @@ def _fake_click(self, point, add_points=None, fig=None, ax=None, point = self.mne.viewbox.mapViewToScene(Point(*point)) for idx, apoint in enumerate(add_points): add_points[idx] = self.mne.viewbox.mapViewToScene( - Point(*apoint)) + Point(*apoint)) elif xform == 'none' or xform is None: if isinstance(point, (tuple, list)): @@ -3773,6 +4315,10 @@ def _fake_scroll(self, x, y, step, fig=None): self.vscroll(step) def _click_ch_name(self, ch_index, button): + self.mne.channel_axis.repaint() + # Wait because channel-axis may need time + # (came up with test_epochs::test_plot_epochs_clicks) + QTest.qWait(100) if not self.mne.butterfly: ch_name = self.mne.ch_names[self.mne.picks[ch_index]] xrange, yrange = self.mne.channel_axis.ch_texts[ch_name] diff --git a/mne_qt_browser/_version.py b/mne_qt_browser/_version.py index fead2c14..3e6615e7 100644 --- a/mne_qt_browser/_version.py +++ b/mne_qt_browser/_version.py @@ -1,2 +1,2 @@ """The version number.""" -__version__ = '0.1.8.dev0' +__version__ = '0.2.0' diff --git a/mne_qt_browser/conftest.py b/mne_qt_browser/conftest.py index 0b546ba0..85ca5127 100644 --- a/mne_qt_browser/conftest.py +++ b/mne_qt_browser/conftest.py @@ -1,22 +1,16 @@ # -*- coding: utf-8 -*- -# Author: Eric Larson +# Authors: Eric Larson +# Martin Schulz # # License: BSD-3-Clause import pytest -from mne.viz import use_browser_backend from mne.conftest import (raw_orig, pg_backend, garbage_collect) # noqa: F401 -_store = dict() - - -@pytest.fixture -def browser_backend(garbage_collect): # noqa: F811 - """Parametrizes the name of the browser backend.""" - with use_browser_backend('pyqtgraph') as backend: - yield backend +_store = {'Raw': {}, + 'Epochs': {}} def pytest_configure(config): @@ -39,8 +33,10 @@ def pytest_sessionfinish(session, exitstatus): writer = TerminalWriter() writer.line() # newline writer.sep('=', 'benchmark results') - for name, vals in _store.items(): - writer.line( - f'{name}:\n' - f' Horizontal: {vals["h"]:6.2f}\n' - f' Vertical: {vals["v"]:6.2f}') + for type_name, results in _store.items(): + writer.sep('-', type_name) + for name, vals in results.items(): + writer.line( + f'{name}:\n' + f' Horizontal: {vals["h"]:6.2f}\n' + f' Vertical: {vals["v"]:6.2f}') diff --git a/mne_qt_browser/tests/test_pg_specific.py b/mne_qt_browser/tests/test_pg_specific.py new file mode 100644 index 00000000..618598c8 --- /dev/null +++ b/mne_qt_browser/tests/test_pg_specific.py @@ -0,0 +1,83 @@ +# -*- coding: utf-8 -*- +# Author: Martin Schulz +# +# License: BSD-3-Clause + +import numpy as np + + +def test_annotations_interactions(raw_orig, pg_backend): + """Test interactions specific to pyqtgraph-backend.""" + # Add test-annotations + onsets = np.arange(2, 8, 2) + raw_orig.first_time + durations = np.repeat(1, len(onsets)) + descriptions = ['A', 'B', 'C'] + for onset, duration, description in zip(onsets, durations, descriptions): + raw_orig.annotations.append(onset, duration, description) + n_anns = len(raw_orig.annotations) + fig = raw_orig.plot() + fig.test_mode = True + annot_dock = fig.mne.fig_annotation + + # Activate annotation_mode + fig._fake_keypress('a') + + # Set current description to index 1 + annot_dock.description_cmbx.activated.emit(1) + assert fig.mne.current_description == 'B' + + # Draw additional annotation + fig._fake_click((8., 1.), add_points=[(9., 1.)], xform='data', button=1, + kind='drag') + assert len(raw_orig.annotations.onset) == n_anns + 1 + assert len(raw_orig.annotations.duration) == n_anns + 1 + assert len(raw_orig.annotations.description) == n_anns + 1 + assert raw_orig.annotations.description[-1] == 'B' + + # Test remove all regions description + annot_dock._remove_description('B') + assert len(raw_orig.annotations.onset) == n_anns - 1 + assert len(raw_orig.annotations.duration) == n_anns - 1 + assert len(raw_orig.annotations.description) == n_anns - 1 + assert fig.mne.current_description == 'A' + assert fig.mne.selected_region is None + + # Redraw annotation (now with 'A') + fig._fake_click((4., 1.), add_points=[(5., 1.)], xform='data', button=1, + kind='drag') + assert len(raw_orig.annotations.onset) == n_anns + assert len(raw_orig.annotations.duration) == n_anns + assert len(raw_orig.annotations.description) == n_anns + + # Test editing descriptions (all) + annot_dock._edit_description_all('D') + assert len(np.where(raw_orig.annotations.description == 'D')[0]) == 2 + + # Test editing descriptions (selected) + # Select second region + fig._fake_click((4.5, 1.), xform='data') + assert fig.mne.selected_region.description == 'D' + annot_dock._edit_description_selected('E') + assert raw_orig.annotations.description[1] == 'E' + + # Test Spinbox behaviour + # Update of Spinboxes + fig._fake_click((2.5, 1.), xform='data') + assert annot_dock.start_bx.value() == 2. + assert annot_dock.stop_bx.value() == 3. + + # Setting values with Spinboxex + annot_dock.start_bx.setValue(1.5) + annot_dock.start_bx.editingFinished.emit() + annot_dock.stop_bx.setValue(3.5) + annot_dock.stop_bx.editingFinished.emit() + assert raw_orig.annotations.onset[0] == 1.5 + raw_orig.first_time + assert raw_orig.annotations.duration[0] == 2. + + # Test SpinBox Warning + annot_dock.start_bx.setValue(6) + annot_dock.start_bx.editingFinished.emit() + assert fig.msg_box.isVisible() + assert fig.msg_box.informativeText() == 'Start can\'t be bigger or ' \ + 'equal to Stop!' + fig.msg_box.close() diff --git a/mne_qt_browser/tests/test_speed.py b/mne_qt_browser/tests/test_speed.py index 29374405..3f80d3b4 100644 --- a/mne_qt_browser/tests/test_speed.py +++ b/mne_qt_browser/tests/test_speed.py @@ -1,9 +1,20 @@ +# -*- coding: utf-8 -*- +# Authors: Eric Larson +# Martin Schulz +# +# License: BSD-3-Clause + +import sys from copy import copy from functools import partial -import sys +from time import perf_counter import numpy as np import pytest +from PyQt5.QtCore import QTimer +from PyQt5.QtWidgets import QApplication + +import mne bm_limit = 50 bm_count = copy(bm_limit) @@ -11,7 +22,8 @@ vscroll_dir = True h_last_time = None v_last_time = None - +hscroll_diffs = list() +vscroll_diffs = list() try: import OpenGL # noqa @@ -22,86 +34,176 @@ has_gl = True reason = '' gl_mark = pytest.mark.skipif( - not has_gl, reason=f'Requires PyOpengl (got {reason})') + not has_gl, reason=f'Requires PyOpengl (got {reason})') -@pytest.mark.benchmark -@pytest.mark.parametrize('benchmark_param', [ - pytest.param({'use_opengl': False}, id='use_opengl=False'), - pytest.param({'use_opengl': True}, id='use_opengl=True', marks=gl_mark), - pytest.param({'precompute': False}, id='precompute=False'), - pytest.param({'precompute': True}, id='precompute=True'), - pytest.param({}, id='defaults'), -]) -def test_scroll_speed(raw_orig, benchmark_param, store, pg_backend, request): - """Test the speed of a parameter.""" - # Remove spaces and get params with values - from time import perf_counter - - from PyQt5.QtCore import QTimer - from PyQt5.QtWidgets import QApplication +def _reinit_bm_values(): + global bm_count + global hscroll_dir + global vscroll_dir + global h_last_time + global v_last_time + global hscroll_diffs + global vscroll_diffs + bm_limit = 50 + bm_count = copy(bm_limit) + hscroll_dir = True + vscroll_dir = True + h_last_time = None + v_last_time = None hscroll_diffs = list() vscroll_diffs = list() - def _initiate_hscroll(pg_fig): - global bm_count - global hscroll_dir - global vscroll_dir - global h_last_time - global v_last_time - - if bm_count > 0: - bm_count -= 1 - # Scroll in horizontal direction and turn at ends. - if pg_fig.mne.t_start + pg_fig.mne.duration \ - >= pg_fig.mne.inst.times[-1]: - hscroll_dir = False - elif pg_fig.mne.t_start <= 0: - hscroll_dir = True - key = 'right' if hscroll_dir else 'left' - pg_fig._fake_keypress(key) - # Get time-difference - now = perf_counter() - if h_last_time is not None: - hscroll_diffs.append(now - h_last_time) - h_last_time = now - elif bm_count > -bm_limit: - bm_count -= 1 - # Scroll in vertical direction and turn at ends. - if pg_fig.mne.ch_start + pg_fig.mne.n_channels \ - >= len(pg_fig.mne.inst.ch_names): - vscroll_dir = False - elif pg_fig.mne.ch_start <= 0: - vscroll_dir = True - key = 'down' if vscroll_dir else 'up' - pg_fig._fake_keypress(key) - # get time-difference - now = perf_counter() - if v_last_time is not None: - vscroll_diffs.append(now - v_last_time) - v_last_time = now + +def _initiate_hscroll(pg_fig, store, request, timer): + global bm_count + global hscroll_dir + global vscroll_dir + global h_last_time + global v_last_time + if bm_count > 0: + bm_count -= 1 + + if pg_fig.mne.is_epochs: + t_limit = pg_fig.mne.boundary_times[-1] else: - timer.stop() - bm_count = copy(bm_limit) + t_limit = pg_fig.mne.inst.times[-1] + + # Scroll in horizontal direction and turn at ends. + if pg_fig.mne.t_start + pg_fig.mne.duration >= t_limit: + hscroll_dir = False + elif pg_fig.mne.t_start <= 0: hscroll_dir = True + key = 'right' if hscroll_dir else 'left' + pg_fig._fake_keypress(key) + # Get time-difference + now = perf_counter() + if h_last_time is not None: + hscroll_diffs.append(now - h_last_time) + h_last_time = now + elif bm_count > -bm_limit: + bm_count -= 1 + # Scroll in vertical direction and turn at ends. + if pg_fig.mne.ch_start + pg_fig.mne.n_channels \ + >= len(pg_fig.mne.ch_order): + vscroll_dir = False + elif pg_fig.mne.ch_start <= 0: vscroll_dir = True - h_last_time = None - v_last_time = None + key = 'down' if vscroll_dir else 'up' + pg_fig._fake_keypress(key) + # get time-difference + now = perf_counter() + if v_last_time is not None: + vscroll_diffs.append(now - v_last_time) + v_last_time = now + else: + timer.stop() + + h_mean_fps = 1 / np.median(hscroll_diffs) + v_mean_fps = 1 / np.median(vscroll_diffs) + type_key = 'Epochs' if pg_fig.mne.is_epochs else 'Raw' + store[type_key][request.node.callspec.id] = dict(h=h_mean_fps, + v=v_mean_fps) + pg_fig.close() + + +@pytest.mark.benchmark +@pytest.mark.parametrize('benchmark_param', [ + pytest.param({'use_opengl': False, 'precompute': False}, + id='use_opengl=False'), + pytest.param({'use_opengl': True, 'precompute': False}, + id='use_opengl=True', marks=gl_mark), + pytest.param({'precompute': False, 'use_opengl': False}, + id='precompute=False'), + pytest.param({'precompute': True, 'use_opengl': False}, + id='precompute=True'), + pytest.param({}, id='defaults'), +]) +def test_scroll_speed_raw(raw_orig, benchmark_param, store, + pg_backend, request): + """Test the speed of a parameter.""" + # Remove spaces and get params with values - h_mean_fps = 1 / np.median(hscroll_diffs) - v_mean_fps = 1 / np.median(vscroll_diffs) - store[request.node.callspec.id] = dict( - h=h_mean_fps, v=v_mean_fps) - pg_fig.close() + _reinit_bm_values() app = QApplication.instance() if app is None: app = QApplication(sys.argv) fig = raw_orig.plot(duration=5, n_channels=40, show=False, block=False, **benchmark_param) + + # # Wait max. 10 s for precomputed data to load + if fig.load_thread.isRunning(): + fig.load_thread.wait(10000) + + timer = QTimer() + timer.timeout.connect(partial(_initiate_hscroll, fig, store, + request, timer)) + timer.start(0) + + fig.show() + with pytest.raises(SystemExit): + sys.exit(app.exec()) + + +@pytest.mark.benchmark +@pytest.mark.parametrize('benchmark_param', [ + pytest.param({'use_opengl': False, 'precompute': False}, + id='use_opengl=False'), + pytest.param({'use_opengl': True, 'precompute': False}, + id='use_opengl=True', marks=gl_mark), + pytest.param({'precompute': False, 'use_opengl': False}, + id='precompute=False'), + pytest.param({'precompute': True, 'use_opengl': False}, + id='precompute=True'), + pytest.param({}, id='defaults'), +]) +def test_scroll_speed_epochs(raw_orig, benchmark_param, store, + pg_backend, request): + from PyQt5.QtCore import QTimer + from PyQt5.QtWidgets import QApplication + + _reinit_bm_values() + + app = QApplication.instance() + if app is None: + app = QApplication(sys.argv) + + events = np.full((50, 3), [0, 0, 1]) + events[:, 0] = np.arange(0, len(raw_orig), len(raw_orig) / 50) \ + + raw_orig.first_samp + epochs = mne.Epochs(raw_orig, events, preload=True) + # Prevent problems with info's locked-stated + epochs.info._unlocked = True + # Make colored segments (simulating bad epochs, + # bad segments from autoreject) + epoch_col1 = np.asarray(['b'] * len(epochs.ch_names)) + epoch_col1[::2] = 'r' + epoch_col2 = np.asarray(['r'] * len(epochs.ch_names)) + epoch_col2[::2] = 'b' + epoch_col3 = np.asarray(['g'] * len(epochs.ch_names)) + epoch_col3[::2] = 'b' + epoch_colors = np.asarray([['b'] * len(epochs.ch_names) for _ in + range(len(epochs))]) + epoch_colors[::3] = epoch_col1 + epoch_colors[1::3] = epoch_col2 + epoch_colors[2::3] = epoch_col3 + epoch_colors = epoch_colors.tolist() + + if sys.platform == 'darwin': + benchmark_param['use_opengl'] = True + + fig = epochs.plot(show=False, block=False, epoch_colors=epoch_colors, + **benchmark_param) + + # # Wait max. 10 s for precomputed data to load + if fig.load_thread.isRunning(): + fig.load_thread.wait(10000) + timer = QTimer() - timer.timeout.connect(partial(_initiate_hscroll, fig)) + timer.timeout.connect(partial(_initiate_hscroll, fig, store, + request, timer)) timer.start(0) fig.show()