diff --git a/qt/python/mantidqt/mantidqt/widgets/workspacedisplay/table/group_model.py b/qt/python/mantidqt/mantidqt/widgets/workspacedisplay/table/group_model.py index da54659fc392..f06adde43d83 100644 --- a/qt/python/mantidqt/mantidqt/widgets/workspacedisplay/table/group_model.py +++ b/qt/python/mantidqt/mantidqt/widgets/workspacedisplay/table/group_model.py @@ -40,38 +40,34 @@ def __init__(self, ws: WorkspaceGroup): self.ws: WorkspaceGroup = ws self.ws_num_rows = sum(peakWs.rowCount() for peakWs in ws) - self.ws_num_cols = self.ws[0].columnCount() + 1 + self.ws_num_cols = self.ws[0].columnCount() + 2 self.marked_columns = MarkedColumns() self._original_column_headers = self.get_column_headers() self.block_model_replace = False - # loads the types of the columns - for col in range(1, len(self._original_column_headers)): - plot_type = self.ws[0].getPlotType(col - 1) + self._row_mapping = self._make_row_mapping() + self._load_col_types() + + def _make_row_mapping(self): + row_index = 0 + row_mapping = [] + for group_index, peaksWs in enumerate(self.ws): + group_start = row_index + group_end = row_index + len(peaksWs) + row_mapping.extend([(group_index, index - group_start) for index in range(group_start, group_end)]) + row_index += len(peaksWs) + return row_mapping + + def _load_col_types(self): + for col in range(2, len(self._original_column_headers)): + plot_type = self.ws[0].getPlotType(col - 2) if plot_type == TableWorkspaceColumnTypeMapping.X: self.marked_columns.add_x(col) elif plot_type == TableWorkspaceColumnTypeMapping.Y: self.marked_columns.add_y(col) elif plot_type == TableWorkspaceColumnTypeMapping.YERR: - err_for_column = self.ws[0].getLinkedYCol(col - 1) + err_for_column = self.ws[0].getLinkedYCol(col - 2) if err_for_column >= 0: - self.marked_columns.add_y_err(ErrorColumn(col, err_for_column)) - - def _get_group_and_workspace_indcies(self, row_indicies): - cumulative_size = 0 - ws_range_limits = [] - for peaksWs in self.ws: - ws_range_limits.append((cumulative_size, cumulative_size + len(peaksWs))) - cumulative_size += len(peaksWs) - - row_to_ws_index = defaultdict(list) - - for row_index in list(map(int, row_indicies.split(","))): - for group_index, ws_range_limit in enumerate(ws_range_limits): - if row_index >= ws_range_limit[0] and row_index < ws_range_limit[1]: - row_to_ws_index[group_index].append(row_index - ws_range_limit[0]) - break - - return row_to_ws_index + self.marked_columns.add_y_err(ErrorColumn(col, err_for_column + 2)) def original_column_headers(self): return self._original_column_headers[:] @@ -83,7 +79,7 @@ def get_name(self): return self.ws.name() def get_column_headers(self): - return ["Group Index"] + self.ws[0].getColumnNames() + return ["Group Index", "WS Index"] + self.ws[0].getColumnNames() def get_column(self, index): column_data = [] @@ -91,8 +87,10 @@ def get_column(self, index): for i, ws_item in enumerate(self.ws): if index == 0: column_data.extend([i] * ws_item.rowCount()) + elif index == 1: + column_data.extend([i for i in range(ws_item.rowCount())]) else: - column_data.extend(ws_item.column(index - 1)) + column_data.extend(ws_item.column(index - 2)) return column_data @@ -108,64 +106,49 @@ def get_column_header(self, index): def is_editable_column(self, icol): return self.get_column_headers()[icol] in self.EDITABLE_COLUMN_NAMES + def set_column_type(self, col, type, linked_col_index=-1): + self.ws[0].setPlotType(col - 2, type, linked_col_index - 2 if linked_col_index != -1 else linked_col_index) + def workspace_equals(self, workspace_name): return self.ws.name() == workspace_name - def set_column_type(self, col, type, linked_col_index=-1): - for peaksWs in self.ws: - peaksWs.setPlotType(col, type, linked_col_index) - def get_cell(self, row, column): - row_to_ws_index = self._get_group_and_workspace_indcies(f"{row}") + group_index, ws_index = self._row_mapping[row] - group_index, ws_index = next(iter(row_to_ws_index.items())) + if column == 0: + return group_index - return self.ws[group_index][ws_index] + if column == 1: + return ws_index + + column = column - 2 + + return self.ws[group_index].cell(ws_index, column) def set_cell_data(self, row, col, data, is_v3d): - cumulative_size = 0 - col = col - 1 - for peaksWs in self.ws: - if cumulative_size + len(peaksWs) > row: - local_index = row - cumulative_size - - p = peaksWs[local_index] - if self.ws.getColumnNames()[col] == "h": - p.setH(data) - elif self.ws.getColumnNames()[col] == "k": - p.setK(data) - elif self.ws.getColumnNames()[col] == "l": - p.setL(data) - - cumulative_size += len(peaksWs) + col = col - 2 + + group_index, ws_index = row + p = self.ws[group_index][ws_index] + if self.ws.getColumnNames()[col] == "h": + p.setH(data) + elif self.ws.getColumnNames()[col] == "k": + p.setK(data) + elif self.ws.getColumnNames()[col] == "l": + p.setL(data) def delete_rows(self, selected_rows): from mantid.simpleapi import DeleteTableRows - row_to_ws_index = self._get_group_and_workspace_indcies(selected_rows) + row_to_ws_index = defaultdict(list) + for group_index, ws_index in selected_rows: + row_to_ws_index[group_index].append(ws_index) for group_index in row_to_ws_index: DeleteTableRows(self.ws[group_index], ",".join(map(str, row_to_ws_index[group_index]))) def get_statistics(self, selected_columns): - from mantid.simpleapi import StatisticsOfTableWorkspace stats = StatisticsOfTableWorkspace(self.ws, selected_columns) return stats - - def sort(self, column_index, sort_ascending): - from mantid.simpleapi import SortPeaksWorkspace - - if column_index == 0: - return - - column_name = self.get_column_headers()[column_index] - - for peakWs in self.ws: - SortPeaksWorkspace( - InputWorkspace=peakWs, - OutputWorkspace=peakWs, - ColumnNameToSortBy=column_name, - SortAscending=sort_ascending, - ) diff --git a/qt/python/mantidqt/mantidqt/widgets/workspacedisplay/table/presenter.py b/qt/python/mantidqt/mantidqt/widgets/workspacedisplay/table/presenter.py index 27ea7b692b55..240fde7eb136 100644 --- a/qt/python/mantidqt/mantidqt/widgets/workspacedisplay/table/presenter.py +++ b/qt/python/mantidqt/mantidqt/widgets/workspacedisplay/table/presenter.py @@ -22,7 +22,8 @@ from mantidqt.widgets.workspacedisplay.table.presenter_standard import TableWorkspaceDataPresenterStandard from mantidqt.widgets.workspacedisplay.table.table_model import TableModel from mantidqt.widgets.workspacedisplay.table.view import TableWorkspaceDisplayView -from mantidqt.widgets.workspacedisplay.table.tableworkspace_item import QStandardItem, create_table_item, RevertibleItem # noqa: F401 +from mantidqt.widgets.workspacedisplay.table.group_table_model import GroupTableModel +from mantidqt.widgets.workspacedisplay.table.presenter_group import TableWorkspaceDataPresenterGroup class TableWorkspaceDisplay(ObservingPresenter, DataCopier): @@ -133,8 +134,15 @@ def _create_table_batch(self, ws, parent, window_flags, view, model): def _create_table_group(self, ws, parent, window_flags, view, model): model = model if model is not None else GroupTableWorkspaceDisplayModel(ws) - view = view if view else TableWorkspaceDisplayView(presenter=self, parent=parent, window_flags=window_flags) - self.presenter = TableWorkspaceDataPresenterStandard(model, view) + table_model = GroupTableModel(model, view) + view = ( + view + if view + else TableWorkspaceDisplayView( + presenter=self, parent=parent, window_flags=window_flags, table_model=table_model, wrap_sorting=True + ) + ) + self.presenter = TableWorkspaceDataPresenterGroup(model, view) return view, model @classmethod @@ -184,10 +192,13 @@ def action_delete_row(self): return selected_rows = selection_model.selectedRows() - selected_rows_list = [index.row() for index in selected_rows] - selected_rows_str = ",".join([str(row) for row in selected_rows_list]) - - self.presenter.model.delete_rows(selected_rows_str) + if not self.group: + selected_rows_list = [index.row() for index in selected_rows] + selected_rows_str = ",".join([str(row) for row in selected_rows_list]) + self.presenter.model.delete_rows(selected_rows_str) + else: + selected_rows_list = [index.row() for index in selected_rows] + self.presenter.delete_rows(selected_rows_list) def _get_selected_columns(self, max_selected=None, message_if_over_max=None): selection_model = self.presenter.view.selectionModel() @@ -292,7 +303,10 @@ def action_sort(self, sort_ascending): except ValueError: return - self.presenter.model.sort(selected_column, sort_ascending) + if not self.group: + self.presenter.model.sort(selected_column, sort_ascending) + else: + self.presenter.sort(selected_column, sort_ascending) def action_plot(self, plot_type): try: diff --git a/qt/python/mantidqt/mantidqt/widgets/workspacedisplay/table/view.py b/qt/python/mantidqt/mantidqt/widgets/workspacedisplay/table/view.py index c2d2545687dc..84b00465ee53 100644 --- a/qt/python/mantidqt/mantidqt/widgets/workspacedisplay/table/view.py +++ b/qt/python/mantidqt/mantidqt/widgets/workspacedisplay/table/view.py @@ -9,7 +9,7 @@ from functools import partial from qtpy import QtGui -from qtpy.QtCore import QVariant, Qt, Signal, Slot +from qtpy.QtCore import QVariant, Qt, Signal, Slot, QSortFilterProxyModel from qtpy.QtGui import QKeySequence, QStandardItemModel from qtpy.QtWidgets import QAction, QHeaderView, QItemEditorFactory, QMenu, QMessageBox, QStyledItemDelegate, QTableView @@ -34,10 +34,16 @@ def createEditor(self, user_type, parent): class TableWorkspaceDisplayView(QTableView): repaint_signal = Signal() - def __init__(self, presenter=None, parent=None, window_flags=Qt.Window, table_model=None): + def __init__(self, presenter=None, parent=None, window_flags=Qt.Window, table_model=None, wrap_sorting=False): super().__init__(parent) self.table_model = table_model if table_model else QStandardItemModel(parent) - self.setModel(self.table_model) + + if wrap_sorting: + sorting_model = QSortFilterProxyModel() + sorting_model.setSourceModel(self.table_model) + self.setModel(sorting_model) + else: + self.setModel(self.table_model) self.presenter = presenter self.COPY_ICON = mantidqt.icons.get_icon("mdi.content-copy") @@ -56,6 +62,9 @@ def __init__(self, presenter=None, parent=None, window_flags=Qt.Window, table_mo self.setWindowFlags(window_flags) + def model(self): + return self.table_model + def columnCount(self): return self.table_model.columnCount() @@ -245,3 +254,6 @@ def ask_confirmation(self, message, title="Mantid Workbench"): def show_warning(self, message, title="Mantid Workbench"): QMessageBox.warning(self, title, message) + + def sortBySelectedColumn(self, selected_column, sort_ascending): + self.sortByColumn(selected_column, Qt.AscendingOrder if sort_ascending else Qt.DescendingOrder)