In [None]:
import sys
import os

dir_notebook = os.path.dirname(os.path.abspath("__file__"))
# 親ディレクトリのパスを取得
dir_parent = os.path.dirname(dir_notebook)
if not dir_parent in sys.path:
    sys.path.append(dir_parent)

from PyQt5.QtWidgets import QMainWindow, QWidget, QGridLayout, QVBoxLayout, QHBoxLayout, QApplication
from optic.config import *
from optic.controls import *
from optic.dialog import *
from optic.gui import *
from optic.io import *
from optic.manager import *
from optic.gui.bind_func import *

class Suite2pROICurationGUI(QMainWindow):
    def __init__(self):
        APP_NAME = "SUITE2P_ROI_CURATION"
        QMainWindow.__init__(self)
        self.widget_manager, self.config_manager, self.data_manager, self.control_manager, self.layout_manager = initManagers(
            WidgetManager(), ConfigManager(), DataManager(), ControlManager(), LayoutManager()
        )
        self.config_manager.setCurrentApp(APP_NAME)
        self.app_keys = self.config_manager.gui_defaults["APP_KEYS"]
        self.app_key_pri = self.app_keys[0]

        self.setupUI_done = False
        setupMainWindow(self, self.config_manager.gui_defaults)

        self.initUI()

    """
    setup UI Function
    """
    def initUI(self):
        self.central_widget = QWidget(self)
        self.setCentralWidget(self.central_widget)
        self.layout_main = QGridLayout(self.central_widget)

        # FileLoadUI用のレイアウト
        self.layout_file_load = QVBoxLayout()
        self.setupFileLoadUI()
        self.layout_main.addLayout(self.layout_file_load, 1, 0, 1, 1)

        # メインUI用のレイアウト
        self.layout_main_ui = QGridLayout()
        self.layout_main.addLayout(self.layout_main_ui, 0, 0, 1, 1)

    def setupFileLoadUI(self):
        file_load_widget = QWidget()
        layout = QVBoxLayout(file_load_widget)
        # ファイル読み込み用のUIを追加
        layout.addLayout(self.makeLayoutSectionBottom())
        # bindFunc
        self.bindFuncFileLoadUI()

        self.layout_file_load.addWidget(file_load_widget)

    def loadFilePathsandInitialize(self):
        self.control_manager, self.data_manager = initManagers(self.control_manager, self.data_manager)
        success = self.loadData()
        if success:
            QMessageBox.information(self, "File load", "File loaded successfully!")
            self.setupMainUI()
        else:
            QMessageBox.warning(self, "File Load Error", "Failed to load the file.")
            return

    def setupMainUI(self):
        if self.setupUI_done:
            # メインUIのクリア
            clearLayout(self.layout_main_ui)
        
        # 新しいメインUIの設定
        self.setupMainUILayouts()
        self.setupControls()
        self.bindFuncAllWidget()

        self.setupUI_done = True

    def loadData(self):
        success = self.data_manager.loadFallMat(
            app_key=self.app_key_pri, 
            path_fall=self.widget_manager.dict_lineedit[f"path_fall_{self.app_key_pri}"].text()
        )
        if self.widget_manager.dict_lineedit[f"path_reftif_{self.app_key_pri}"].text() != "":
            success = self.data_manager.loadTifImage(
                app_key=self.app_key_pri,
                path_image=self.widget_manager.dict_lineedit[f"path_reftif_{self.app_key_pri}"].text(), 
            )
        return success

    def setupMainUILayouts(self):
        self.layout_main_ui.addLayout(self.makeLayoutSectionLeftUpper(), 0, 0)
        self.layout_main_ui.addLayout(self.makeLayoutSectionMiddleUpper(), 0, 1)
        self.layout_main_ui.addLayout(self.makeLayoutSectionRightUpper(), 0, 2)

    def setupControls(self):
        self.control_manager.table_controls[self.app_key_pri] = TableControl(
            app_key=self.app_key_pri,
            q_table=self.widget_manager.dict_table[self.app_key_pri],
            data_manager=self.data_manager,
            widget_manager=self.widget_manager,
            config_manager=self.config_manager,
            control_manager=self.control_manager,
        )
        
        self.control_manager.table_controls[self.app_key_pri].setupWidgetROITable(self.app_key_pri)
        self.control_manager.view_controls[self.app_key_pri] = ViewControl(
            app_key=self.app_key_pri,
            q_view=self.widget_manager.dict_view[self.app_key_pri], 
            q_scene=self.widget_manager.dict_scene[self.app_key_pri], 
            data_manager=self.data_manager, 
            widget_manager=self.widget_manager,
            config_manager=self.config_manager,
            control_manager=self.control_manager,
        )
        self.control_manager.view_controls[self.app_key_pri].setViewSize()
        self.control_manager.canvas_controls[self.app_key_pri] = CanvasControl(
            app_key=self.app_key_pri,
            figure=self.widget_manager.dict_figure[self.app_key_pri], 
            canvas=self.widget_manager.dict_canvas[self.app_key_pri], 
            data_manager=self.data_manager, 
            widget_manager=self.widget_manager,
            config_manager=self.config_manager,
            control_manager=self.control_manager,
            ax_layout="triple"
        )
        self.control_manager.initializeSkipROITypes(self.app_key_pri, self.control_manager.table_controls[self.app_key_pri].table_columns)

    """
    makeLayout Function; Component
    小要素のLayout
    return -> Layout
    """

    "Bottom"
    # ファイル読み込み用UI Layout
    def makeLayoutComponentFileLoadUI(self):
        layout = QVBoxLayout()

        # LineEdit
        list_label = ["Fall mat file path", "Reference Tiff image file path (optional)"]
        list_key = [f"path_fall_{self.app_key_pri}", f"path_reftif_{self.app_key_pri}"]
        for label, key in zip(list_label, list_key):
            layout.addLayout(makeLayoutLoadFileWidget(
                self.widget_manager, 
                label=label, 
                key_label=key, 
                key_lineedit=key, 
                key_button=key
            ))
        # Button
        layout.addLayout(makeLayoutLoadFileExitHelp(self.widget_manager))
        return layout

    "Left Upper"
    def makeLayoutComponentPlotProperty(self):
        layout = QHBoxLayout()
        layout.addLayout(makeLayoutLightPlotMode(self.widget_manager, self.config_manager))
        layout.addLayout(makeLayoutMinimumPlotRange(self.widget_manager, self.config_manager, self.app_key_pri))
        return layout
    
    # EventFile load, plot property
    def makeLayoutComponentEventFilePlotProperty(self):
        layout = makeLayoutEventFilePlotProperty(
            self.widget_manager, 
            f"{self.app_key_pri}_load_eventfile",
            f"{self.app_key_pri}_clear_eventfile",
            f"{self.app_key_pri}_plot_eventfile",
            f"{self.app_key_pri}_plot_eventfile_ffneu",
            f"{self.app_key_pri}_plot_eventfile_dff0",
            f"{self.app_key_pri}_eventfile_prop_range",
            f"{self.app_key_pri}_eventfile_prop_ffneu",
            f"{self.app_key_pri}_eventfile_prop_dff0",
            f"{self.app_key_pri}_eventfile_loaded",
            f"{self.app_key_pri}_eventfile_prop_range",
            f"{self.app_key_pri}_eventfile_prop_ffneu",
            f"{self.app_key_pri}_eventfile_prop_dff0",
            f"{self.app_key_pri}_eventfile_loaded",
            self.app_key_pri)
        return layout
    
    "Middle Upper"
    # ROI view
    def makeLayoutComponentROIView(self):
        layout = makeLayoutViewWithZTSlider(self.widget_manager, self.app_key_pri)
        return layout

    # ROI property label
    def makeLayoutComponentROIPropertyDisplay_Threshold(self):
        layout = QVBoxLayout()
        layout.addLayout(makeLayoutROIProperty(self.widget_manager, key_label=f"{self.app_key_pri}_roi_prop"))
        return layout

    # ROI display, background image button group, checkbox
    def makeLayoutComponentROIDisplay_BGImageDisplay_ROISkip(self):
        layout = QHBoxLayout()
        layout.addWidget(makeLayoutWidgetDislplayCelltype(
            self.widget_manager, 
            key_label=f'{self.app_key_pri}_display_celltype',
            key_checkbox=f'{self.app_key_pri}_display_celltype', 
            key_scrollarea=f'{self.app_key_pri}_display_celltype', 
            table_columns=self.config_manager.table_columns[self.app_key_pri],
            gui_defaults=self.config_manager.gui_defaults,
        ))
        layout.addWidget(makeLayoutWidgetDislplayCheckbox(
            self.widget_manager, 
            key_label=f'{self.app_key_pri}_display_checkbox',
            key_checkbox=f'{self.app_key_pri}_display_checkbox', 
            key_scrollarea=f'{self.app_key_pri}_display_checkbox', 
            table_columns=self.config_manager.table_columns[self.app_key_pri],
            gui_defaults=self.config_manager.gui_defaults,
        ))
        layout.addWidget(makeLayoutWidgetBGImageTypeDisplay(
            self, 
            self.widget_manager, 
            key_label=f'{self.app_key_pri}_im_bg_type',
            key_buttongroup=f'{self.app_key_pri}_im_bg_type',
            key_scrollarea=f'{self.app_key_pri}_im_bg_type',
            gui_defaults=self.config_manager.gui_defaults,
        ))
        layout.addWidget(makeLayoutWidgetROIChooseSkip(
            self.widget_manager, 
            key_label=f'{self.app_key_pri}_skip_celltype',
            key_checkbox=f'{self.app_key_pri}_skip_celltype', 
            key_scrollarea=f'{self.app_key_pri}_skip_celltype', 
            table_columns=self.config_manager.table_columns[self.app_key_pri],
            gui_defaults=self.config_manager.gui_defaults,
        ))
        return layout

    # channel contrast, ROI opacity slider
    def makeLayoutComponentContrastOpacitySlider(self):
        layout = QVBoxLayout()
        channels = self.config_manager.gui_defaults["CHANNELS"]
        layout_channel = QHBoxLayout()
        for channel in channels:
            layout_channel.addLayout(makeLayoutContrastSlider(
                self.widget_manager, 
                key_label=f"{self.app_key_pri}_{channel}", 
                key_checkbox=f"{self.app_key_pri}_{channel}", 
                key_slider=f"{self.app_key_pri}_{channel}", 
                label_checkbox=f"Show {channel} channel", 
                label_label=f"{channel} Value", 
                checked=True
            ))

        layout.addLayout(layout_channel)
        layout.addLayout(makeLayoutOpacitySlider(
            self.widget_manager, 
            key_label=self.app_key_pri, 
            key_slider=self.app_key_pri, 
            label=self.app_key_pri
        ))
        layout.addLayout(makeLayoutDisplayROIContours(
            self.widget_manager,
            key_checkbox_contour_all=f"{self.app_key_pri}_display_contour_all",
            key_checkbox_contour_selected=f"{self.app_key_pri}_display_contour_selected",
            key_checkbox_contour_next=f"{self.app_key_pri}_display_contour_next",
        ))
        return layout

    "Right Upper"
    # Table, ROI count label, Table Columns Config, Set ROI Celltype, ROICheck IO
    def makeLayoutComponentTable_ROICountLabel_ROISetSameCelltype_ROICheckIO(self):
        layout = QVBoxLayout()
        layout.addLayout(makeLayoutTableROICountLabel(
            self.widget_manager, 
            key_label=self.app_key_pri, 
            key_table=self.app_key_pri, 
            table_columns=self.config_manager.table_columns[self.app_key_pri]
        ))
        layout.addWidget(self.widget_manager.makeWidgetButton(key=f"{self.app_key_pri}_config_table", label="Table Columns Config"))
        layout.addWidget(self.widget_manager.makeWidgetButton(key=f"{self.app_key_pri}_roi_celltype_set", label="Set ROI Celltype"))
        layout.addLayout(makeLayoutROICheckIO(
            self.widget_manager, 
            key_button_save=f"roicuration_save_{self.app_key_pri}",
            key_button_load=f"roicuration_load_{self.app_key_pri}",
        ))
        return layout

    # ROI Filter, threshold
    def makeLayoutComponentROIFilter(self):
        layout = QHBoxLayout()
        layout.addLayout(makeLayoutROIFilterThreshold(
            self.widget_manager, 
            key_label=f"{self.app_key_pri}_roi_filter", 
            key_lineedit=f"{self.app_key_pri}_roi_filter",
            dict_roi_threshold=self.config_manager.gui_defaults["ROI_THRESHOLDS"]
        ))
        layout.addLayout(makeLayoutROIFilterButton(
            self.widget_manager, 
            key_label=f"{self.app_key_pri}_roi_filter", 
            key_button=f"{self.app_key_pri}_roi_filter"
        ))
        return layout
    

    """
    makeLayout Function; Section
    領域レベルの大Layout
    """
    # 左上
    def makeLayoutSectionLeftUpper(self):
        layout = QVBoxLayout()
        layout.addLayout(makeLayoutCanvasTracePlot(
            self.widget_manager, 
            key_figure=self.app_key_pri, 
            key_canvas=self.app_key_pri, 
            key_button=f"export_canvas_{self.app_key_pri}"
        ), stretch=1)
        layout.addLayout(self.makeLayoutComponentPlotProperty())
        layout.addLayout(self.makeLayoutComponentEventFilePlotProperty())
        return layout

    # 中上
    def makeLayoutSectionMiddleUpper(self):
        layout = QVBoxLayout()
        layout.addLayout(self.makeLayoutComponentROIView())
        layout.addLayout(self.makeLayoutComponentROIPropertyDisplay_Threshold())
        layout.addLayout(self.makeLayoutComponentROIDisplay_BGImageDisplay_ROISkip())
        layout.addLayout(self.makeLayoutComponentContrastOpacitySlider())
        return layout
    
    # 右上
    def makeLayoutSectionRightUpper(self):
        layout = QVBoxLayout()
        layout.addLayout(self.makeLayoutComponentTable_ROICountLabel_ROISetSameCelltype_ROICheckIO())
        layout.addLayout(self.makeLayoutComponentROIFilter())
        return layout

    # 下
    def makeLayoutSectionBottom(self):
        layout = self.makeLayoutComponentFileLoadUI()
        return layout
    
    """
    make SubWindow, Dialog Function
    """
    def showSubWindowTableColumnConfig(self, app_key):
        config_window = TableColumnConfigDialog(
            self, 
            self.control_manager.table_controls[app_key].table_columns, 
            self.config_manager.gui_defaults
        )
        if config_window.exec_():
            self.loadFilePathsandInitialize()

    def showSubWindowSetROICellTypeSet(self, app_key):
        celltype_window = ROICellTypeSetDialog(
            self, 
            self.app_key_pri,
            self.config_manager,
            self.control_manager.table_controls[app_key],
            self.config_manager.gui_defaults
        )
        celltype_window.show()
            

    """
    bindFunc Function
    配置したwidgetに関数を紐づけ
    """
    def bindFuncFileLoadUI(self):        
        list_key = [f"path_fall_{self.app_key_pri}", f"path_reftif_{self.app_key_pri}"]
        list_filetype = [Extension.MAT, Extension.TIFF]
        for key, filetype in zip(list_key, list_filetype):
            bindFuncLoadFileWidget(
                q_widget=self, 
                q_button=self.widget_manager.dict_button[key], 
                q_lineedit=self.widget_manager.dict_lineedit[key], 
                filetype=filetype
            )

        self.widget_manager.dict_button["load_file"].clicked.connect(lambda: self.loadFilePathsandInitialize())
        bindFuncExit(q_window=self, q_button=self.widget_manager.dict_button["exit"])
        bindFuncHelp(q_button=self.widget_manager.dict_button["help"], url=AccessURL.HELP[self.config_manager.current_app])

    def bindFuncAllWidget(self):
        # ROICheck save load
        bindFuncROICheckIO(
            q_window=self, 
            q_lineedit=self.widget_manager.dict_lineedit[f"path_fall_{self.app_key_pri}"], 
            q_button_save=self.widget_manager.dict_button[f"roicuration_save_{self.app_key_pri}"], 
            q_button_load=self.widget_manager.dict_button[f"roicuration_load_{self.app_key_pri}"], 
            q_table=self.widget_manager.dict_table[f"{self.app_key_pri}"], 
            widget_manager=self.widget_manager,
            config_manager=self.config_manager,
            control_manager=self.control_manager,
            app_key=self.app_key_pri,
            local_var=False
        )
        # Table Column Config
        self.widget_manager.dict_button[f"{self.app_key_pri}_config_table"].clicked.connect(
            lambda: self.showSubWindowTableColumnConfig(self.app_key_pri)
        )
        # Set ROI Celltype
        self.widget_manager.dict_button[f"{self.app_key_pri}_roi_celltype_set"].clicked.connect(
            lambda: self.showSubWindowSetROICellTypeSet(self.app_key_pri)
        )
        # Radiobutton BGImageType buttonChanged
        bindFuncRadiobuttonBGImageTypeChanged(
            q_buttongroup=self.widget_manager.dict_buttongroup[f"{self.app_key_pri}_im_bg_type"], 
            view_control=self.control_manager.view_controls[self.app_key_pri],
        )
        # Radiobutton ROIDisplayType checkboxChanged
        bindFuncCheckBoxDisplayCelltypeChanged(
            dict_q_checkbox_celltype={key.split("celltype_roi_display_")[-1]: self.widget_manager.dict_checkbox[key] for key in self.widget_manager.dict_checkbox.keys() if ("celltype_roi_display_" in key) and (key.split("celltype_roi_display_")[-1] in set(self.config_manager.table_columns[self.app_key_pri].getColumns().keys()))},
            dict_q_checkbox_checkbox={key.split("checkbox_roi_display_")[-1]: self.widget_manager.dict_checkbox[key] for key in self.widget_manager.dict_checkbox.keys() if ("checkbox_roi_display_" in key) and (key.split("checkbox_roi_display_")[-1] in set(self.config_manager.table_columns[self.app_key_pri].getColumns().keys()))},
            view_control=self.control_manager.view_controls[self.app_key_pri],
            table_control=self.control_manager.table_controls[self.app_key_pri],
        )
        # Checkbox ROIDisplayType checkboxChanged
        bindFuncCheckBoxDisplayCheckBoxChanged(
            dict_q_checkbox_celltype={key.split("celltype_roi_display_")[-1]: self.widget_manager.dict_checkbox[key] for key in self.widget_manager.dict_checkbox.keys() if ("celltype_roi_display_" in key) and (key.split("celltype_roi_display_")[-1] in set(self.config_manager.table_columns[self.app_key_pri].getColumns().keys()))},
            dict_q_checkbox_checkbox={key.split("checkbox_roi_display_")[-1]: self.widget_manager.dict_checkbox[key] for key in self.widget_manager.dict_checkbox.keys() if ("checkbox_roi_display_" in key) and (key.split("checkbox_roi_display_")[-1] in set(self.config_manager.table_columns[self.app_key_pri].getColumns().keys()))},
            view_control=self.control_manager.view_controls[self.app_key_pri],
            table_control=self.control_manager.table_controls[self.app_key_pri],
        )
        # Checkbox ROISkip stateChanged
        bindFuncCheckBoxROIChooseSkip(
            dict_q_checkbox={key.split("celltype_skip_choose_")[-1]: self.widget_manager.dict_checkbox[key] for key in self.widget_manager.dict_checkbox.keys() if ("celltype_skip_choose_" in key) and (key.split("celltype_skip_choose_")[-1] in set(self.config_manager.table_columns[self.app_key_pri].getColumns().keys()))},
            control_manager=self.control_manager,
            app_key=self.app_key_pri,
        )
        # Filter ROIs
        bindFuncButtonFilterROI(
            q_button=self.widget_manager.dict_button[f"{self.app_key_pri}_roi_filter"],
            dict_q_lineedit={key: self.widget_manager.dict_lineedit[f"{self.app_key_pri}_roi_filter_{key}"] for key in self.config_manager.gui_defaults["ROI_THRESHOLDS"].keys()},
            table_control=self.control_manager.table_controls[self.app_key_pri],
            view_control=self.control_manager.view_controls[self.app_key_pri],
        )
        # ROICheck Table onSelectionChanged
        bindFuncTableSelectionChanged(
            q_table=self.widget_manager.dict_table[self.app_key_pri],
            table_control=self.control_manager.table_controls[self.app_key_pri],
            view_control=self.control_manager.view_controls[self.app_key_pri],
            canvas_control=self.control_manager.canvas_controls[self.app_key_pri],
        )
        # ROICheck Table TableColumn CellType Changed
        bindFuncRadiobuttonOfTableChanged(
            table_control=self.control_manager.table_controls[self.app_key_pri],
            view_control=self.control_manager.view_controls[self.app_key_pri],
        )
        # ROICheck Table TableColumn Checkbox Changed
        bindFuncCheckboxOfTableChanged(
            table_control=self.control_manager.table_controls[self.app_key_pri],
            view_control=self.control_manager.view_controls[self.app_key_pri],
        )
        # Slider Opacity valueChanged
        bindFuncOpacitySlider(
            q_slider=self.widget_manager.dict_slider[f"{self.app_key_pri}_opacity_roi_all"],
            view_control=self.control_manager.view_controls[self.app_key_pri],
        )
        bindFuncHighlightOpacitySlider(
            q_slider=self.widget_manager.dict_slider[f"{self.app_key_pri}_opacity_roi_selected"],
            view_control=self.control_manager.view_controls[self.app_key_pri],
        )
        # Slider Contrast valueChanged, Checkbox show channel stateChanged
        for channel in self.config_manager.gui_defaults["CHANNELS"]:
            bindFuncBackgroundContrastSlider(
                q_slider_min=self.widget_manager.dict_slider[f"{self.app_key_pri}_{channel}_contrast_min"],
                q_slider_max=self.widget_manager.dict_slider[f"{self.app_key_pri}_{channel}_contrast_max"],
                view_control=self.control_manager.view_controls[self.app_key_pri],
                channel=channel
            )
            bindFuncBackgroundVisibilityCheckbox(
                q_checkbox=self.widget_manager.dict_checkbox[f"{self.app_key_pri}_{channel}_show"], 
                view_control=self.control_manager.view_controls[self.app_key_pri],
                channel=channel,
            )
        # display all, selected, next ROI Contour
        bindFuncCheckBoxDisplayROIContours(
            q_checkbox_contour_all=self.widget_manager.dict_checkbox[f"{self.app_key_pri}_display_contour_all"],
            q_checkbox_contour_selected=self.widget_manager.dict_checkbox[f"{self.app_key_pri}_display_contour_selected"],
            q_checkbox_contour_next=self.widget_manager.dict_checkbox[f"{self.app_key_pri}_display_contour_next"],
            view_control=self.control_manager.view_controls[self.app_key_pri],
        )
        # View Events
        bindFuncViewEvents(
            q_view=self.widget_manager.dict_view[self.app_key_pri],
            view_control=self.control_manager.view_controls[self.app_key_pri],
        )
        # Canvas MouseEvent
        # Top axis events
        canvas_control = self.control_manager.canvas_controls[self.app_key_pri]
        bindFuncCanvasMouseEvent(
            canvas_control.canvas,
            canvas_control,
            canvas_control.axes[AxisKeys.TOP],
            list_event=['scroll_event', 'button_press_event', 'button_release_event', 'motion_notify_event'],
            list_func=[canvas_control.onScroll, canvas_control.onPress, canvas_control.onRelease, canvas_control.onMotion]
        )
        # Middle axis events
        bindFuncCanvasMouseEvent(
            canvas_control.canvas,
            canvas_control,
            canvas_control.axes[AxisKeys.MID],
            list_event=['button_press_event'],
            list_func=[canvas_control.onClick]
        )
        # export figure
        bindFuncButtonExportFigure(
            self.widget_manager.dict_button[f"export_canvas_{self.app_key_pri}"],
            self,
            self.widget_manager.dict_figure[self.app_key_pri],
            path_dst = self.widget_manager.dict_lineedit[f"path_fall_{self.app_key_pri}"].text().replace(".mat", "_traceplot.png")
        )
        # Canvas load EventFile
        bindFuncButtonEventfileIO(
            q_button_load=self.widget_manager.dict_button[f"{self.app_key_pri}_load_eventfile"],
            q_button_clear=self.widget_manager.dict_button[f"{self.app_key_pri}_clear_eventfile"],
            q_window=self,
            q_combobox_eventfile=self.widget_manager.dict_combobox[f"{self.app_key_pri}_eventfile_loaded"],
            data_manager=self.data_manager,
            control_manager=self.control_manager,
            canvas_control=self.control_manager.canvas_controls[self.app_key_pri],
            app_key=self.app_key_pri,
        )
        # Canvas plot EventFile property
        bindFuncCheckboxEventfilePlotProperty(
            q_checkbox_ffneu=self.widget_manager.dict_checkbox[f"{self.app_key_pri}_plot_eventfile_ffneu"],
            q_checkbox_dff0=self.widget_manager.dict_checkbox[f"{self.app_key_pri}_plot_eventfile_dff0"],
            canvas_control=self.control_manager.canvas_controls[self.app_key_pri],
        )

if __name__ == "__main__":
    app = QApplication(sys.argv) if QApplication.instance() is None else QApplication.instance()
    applyAppStyle(app)
    gui = Suite2pROICurationGUI()
    gui.show()
    sys.exit(app.exec_())

SystemExit: 0

To exit: use 'exit', 'quit', or Ctrl-D.


In [8]:
gui.control_manager.table_controls["pri"].dict_checkboxes[0]["Check"].connect

AttributeError: 'QTableWidgetItem' object has no attribute 'connect'

In [3]:
gui.control_manager.table_controls["pri"].q_table.item(0, 5).isChecked()

AttributeError: 'QTableWidgetItem' object has no attribute 'isChecked'

In [33]:
gui.control_manager.table_controls["pri"].getSharedAttr_CheckboxVisibility()

{'Check': True, 'Tracking': False}

In [2]:
gui.control_manager.getAllSharedAttrs("pri")

{'dict_roi_display': {0: {'celltype': True, 'checkbox': False},
  1: {'celltype': True, 'checkbox': False},
  2: {'celltype': True, 'checkbox': False},
  3: {'celltype': True, 'checkbox': False},
  4: {'celltype': True, 'checkbox': False},
  5: {'celltype': True, 'checkbox': False},
  6: {'celltype': True, 'checkbox': False},
  7: {'celltype': True, 'checkbox': False},
  8: {'celltype': True, 'checkbox': False},
  9: {'celltype': True, 'checkbox': False},
  10: {'celltype': True, 'checkbox': False},
  11: {'celltype': True, 'checkbox': False},
  12: {'celltype': True, 'checkbox': False},
  13: {'celltype': True, 'checkbox': False},
  14: {'celltype': True, 'checkbox': False},
  15: {'celltype': True, 'checkbox': False},
  16: {'celltype': True, 'checkbox': False},
  17: {'celltype': True, 'checkbox': False},
  18: {'celltype': True, 'checkbox': False},
  19: {'celltype': True, 'checkbox': False},
  20: {'celltype': True, 'checkbox': False},
  21: {'celltype': True, 'checkbox': False},


In [17]:
gui.config_manager.table_columns["pri"].getColumns().keys()

dict_keys(['Cell_ID', 'Neuron', 'new_cell_type', 'Astrocyte', 'Not_Cell', 'Check', 'Tracking', 'Memo'])

In [15]:
gui.widget_manager.dict_checkbox

{'light_plot_mode': <PyQt5.QtWidgets.QCheckBox at 0x183ab7a5430>,
 'pri_plot_eventfile': <PyQt5.QtWidgets.QCheckBox at 0x183ab1a9310>,
 'pri_plot_eventfile_ffneu': <PyQt5.QtWidgets.QCheckBox at 0x183ab1ab040>,
 'pri_plot_eventfile_dff0': <PyQt5.QtWidgets.QCheckBox at 0x183ab1a9280>,
 'pri_display_celltype_roi_display_Neuron': <PyQt5.QtWidgets.QCheckBox at 0x183ab1aaaf0>,
 'pri_display_celltype_roi_display_Astrocyte': <PyQt5.QtWidgets.QCheckBox at 0x183ab1aaee0>,
 'pri_display_celltype_roi_display_Not_Cell': <PyQt5.QtWidgets.QCheckBox at 0x183ab1aae50>,
 'pri_display_checkbox_roi_display_Check': <PyQt5.QtWidgets.QCheckBox at 0x183ab1a8ca0>,
 'pri_display_checkbox_roi_display_Tracking': <PyQt5.QtWidgets.QCheckBox at 0x183ab1ad550>,
 'pri_skip_celltype_skip_choose_Neuron': <PyQt5.QtWidgets.QCheckBox at 0x183ab1aa040>,
 'pri_skip_celltype_skip_choose_Astrocyte': <PyQt5.QtWidgets.QCheckBox at 0x183ab17fd30>,
 'pri_skip_celltype_skip_choose_Not_Cell': <PyQt5.QtWidgets.QCheckBox at 0x183ab17f

In [6]:
gui.control_manager.getAllSharedAttrs("pri")

{'roi_display': {0: False,
  1: False,
  2: False,
  3: False,
  4: False,
  5: False,
  6: False,
  7: False,
  8: False,
  9: False,
  10: False,
  11: False,
  12: False,
  13: False,
  14: False,
  15: False,
  16: False,
  17: False,
  18: False,
  19: False,
  20: False,
  21: False,
  22: False,
  23: False,
  24: False,
  25: False,
  26: False,
  27: False,
  28: False,
  29: False,
  30: False,
  31: False,
  32: False,
  33: False,
  34: False,
  35: False,
  36: False,
  37: False,
  38: False,
  39: False,
  40: False,
  41: False,
  42: False,
  43: False,
  44: False,
  45: False,
  46: False,
  47: False,
  48: False,
  49: False,
  50: False,
  51: False,
  52: False,
  53: False,
  54: False,
  55: False,
  56: False,
  57: False,
  58: False,
  59: False,
  60: False,
  61: False,
  62: False,
  63: False,
  64: False,
  65: False,
  66: False,
  67: False,
  68: False,
  69: False,
  70: False,
  71: False,
  72: False,
  73: False,
  74: False,
  75: False,
  76: F

In [None]:
table_columns_ = gui.control_manager.table_controls["pri"].table_columns.getColumns()
dict_checkbox_state = {}
for column_name in table_columns_.keys():
    if table_columns_[column_name]["type"] == "checkbox":
        dict_checkbox_state[column_name] = table_columns_[column_name]["state"]

{'Cell_ID': {'order': 0,
  'type': 'id',
  'width': 80,
  'removable': False,
  'name_fixed': True,
  'editable': False},
 'Neuron': {'order': 1,
  'type': 'celltype',
  'width': 80,
  'removable': True,
  'default': True},
 'Astrocyte': {'order': 2,
  'type': 'celltype',
  'width': 80,
  'removable': True,
  'default': False},
 'Not_Cell': {'order': 3,
  'type': 'celltype',
  'width': 80,
  'removable': False,
  'name_fixed': True,
  'default': False},
 'Check': {'order': 4,
  'type': 'checkbox',
  'width': 80,
  'removable': True,
  'default': False},
 'Tracking': {'order': 5,
  'type': 'checkbox',
  'width': 80,
  'removable': True,
  'default': False},
 'Memo': {'order': 6, 'type': 'string', 'width': 200, 'removable': True}}

In [50]:
# from ROIcuration.mat, load Fall traces as input
def loadInputDictTraces(path_roicuration: str):
    mat_roicheck = loadmat(path_roicuration, simplify_cells=True)
    path_fall = mat_roicheck["path_Fall"]

    # prepare input Fall
    dict_Fall = loadFallMat(path_fall)
    dict_traces = {
        "F"     : dict_Fall["F"],
        "Fneu"  : dict_Fall["Fneu"],
        "spks"  : dict_Fall["spks"],
    }
    return dict_traces

# from ROIcuration.mat, load celltypes as output and dict for matching celltypes and indexes
def loadOutputCelltypesAndDictIndexesCelltypes(path_roicuration: str):
    mat_roicheck = loadmat(path_roicuration, simplify_cells=True)
    number_of_rois = mat_roicheck["NumberOfROI"]
    # prepare output celltype
    date = list(mat_roicheck["manualROIcheck"].keys())[-1]
    dict_roicheck = mat_roicheck["manualROIcheck"][date]
    table_columns = dict_roicheck["TableColumns"]
    list_celltype_table_columns = [col_name for col_name, col_info in table_columns.items() if col_info["type"] == "celltype"]

    list_celltype_current = gui.control_manager.table_controls["pri"].table_columns._celltype

    # check celltype, all current celltypes are included in ROIcuration.mat
    if not set(list_celltype_current) == set(list_celltype_table_columns):
        print("Current celltype:", list_celltype_current)
        print("ROIcuration.mat celltype:", list_celltype_table_columns)
        raise KeyError
    
    # one-hot vector for output, celltype
    celltype_onehot = np.zeros((number_of_rois, len(list_celltype_current)))
    # Dict[column: int, celltype: str]
    dict_idx_celltype = {idx:celltype for idx, celltype in enumerate(list_celltype_current)}

    for idx_col, celltype in dict_idx_celltype.items():
        for idx_row in dict_roicheck[celltype]:
            celltype_onehot[idx_row, idx_col] = 1

    return celltype_onehot, dict_idx_celltype

In [87]:
list_path_roicuration = [
    "D:/optic/data/ROIcuration_KA86-f1-z160_240227_1-preHolo.mat",
    "D:/optic/data/ROIcuration_KA86-f1-z160_240227_2-preWsk4bit.mat",
    "D:/optic/data/ROIcuration_KA91-f2-z130_240425_2-preWsk4bit--HfixWskVib.mat",
]

inputs_all = {
    "F"     : [], # list[np.array[float, float]],  # [sample, num_ROIs, sequence_length]
    "Fneu"  : [], # list[np.array[float, float]],  # [sample, num_ROIs, sequence_length]
    "spks"  : [], # list[np.array[float, float]],  # [sample, num_ROIs, sequence_length]
}

outputs_all = {
    "celltype"          : [], # list[np.array[int, int]],   # [sample, num_ROIs, num_celltypes]
    'dict_idx_celltype' : {}, # Dict[int, str]     # {0: "Neuron", 1: "Astrocite", 2: "Not_Cell"}
}

for i, path_roicuration in enumerate(list_path_roicuration):
    print(path_roicuration)
    dict_traces = loadInputDictTraces(path_roicuration)
    celltype_onehot, dict_idx_celltype = loadOutputCelltypesAndDictIndexesCelltypes(path_roicuration)
    print(dict_traces["F"].shape) # (number_of_rois, number_of_frames)
    print(celltype_onehot.shape) # (number_of_rois, number_of_celltypes)
    print(dict_idx_celltype) # {0: celltype1, 1: celltype2, ...}

    if i == 0:
        for trace in dict_traces.keys():
            inputs_all[trace].append(dict_traces[trace])
        outputs_all["celltype"].append(celltype_onehot)
        outputs_all["dict_idx_celltype"] = dict_idx_celltype
    else:
        if not dict_idx_celltype == outputs_all["dict_idx_celltype"]:
            print("dict_idx_celltype is not same")
            raise KeyError
        for trace in dict_traces.keys():
            inputs_all[trace].append(dict_traces[trace])
        outputs_all["celltype"].append(celltype_onehot)

D:/optic/data/ROIcuration_KA86-f1-z160_240227_1-preHolo.mat
(226, 24608)
(226, 3)
{0: 'Neuron', 1: 'Astrocyte', 2: 'Not_Cell'}
D:/optic/data/ROIcuration_KA86-f1-z160_240227_2-preWsk4bit.mat
(365, 20780)
(365, 3)
{0: 'Neuron', 1: 'Astrocyte', 2: 'Not_Cell'}
D:/optic/data/ROIcuration_KA91-f2-z130_240425_2-preWsk4bit--HfixWskVib.mat
(760, 20438)
(760, 3)
{0: 'Neuron', 1: 'Astrocyte', 2: 'Not_Cell'}


In [91]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

class CelltypeDataset(Dataset):
    """カルシウムイメージングデータセット"""
    
    def __init__(self, F, Fneu, spks, celltype):
        self.F = F
        self.Fneu = Fneu
        self.spks = spks
        self.celltype = celltype
        
    def __len__(self):
        return len(self.F)
    
    def __getitem__(self, idx):
        return {
            'F': self.F[idx],
            'Fneu': self.Fneu[idx],
            'spks': self.spks[idx],
            'celltype': self.celltype[idx]
        }


def collate_fn(batch):
    """バッチ内のシーケンスを動的にパディング"""
    # バッチ内の最長シーケンスの長さを取得
    max_len = max([len(item['F']) for item in batch])
    
    F_padded = []
    Fneu_padded = []
    spks_padded = []
    celltype_batch = []
    
    for item in batch:
        # 現在のシーケンス長
        seq_len = len(item['F'])
        
        # パディング
        F_pad = torch.nn.functional.pad(torch.tensor(item['F'], dtype=torch.float32), (0, max_len - seq_len))
        Fneu_pad = torch.nn.functional.pad(torch.tensor(item['Fneu'], dtype=torch.float32), (0, max_len - seq_len))
        spks_pad = torch.nn.functional.pad(torch.tensor(item['spks'], dtype=torch.float32), (0, max_len - seq_len))
        
        F_padded.append(F_pad)
        Fneu_padded.append(Fneu_pad)
        spks_padded.append(spks_pad)
        celltype_batch.append(torch.tensor(item['celltype'], dtype=torch.float32))
    
    return {
        'F': torch.stack(F_padded),
        'Fneu': torch.stack(Fneu_padded),
        'spks': torch.stack(spks_padded),
        'celltype': torch.stack(celltype_batch)
    }


class CelltypeTransformer(nn.Module):
    """Transformerベースのcelltype分類モデル"""
    
    def __init__(self, num_classes, d_model=64, nhead=4, num_layers=2):
        super(CelltypeTransformer, self).__init__()
        
        # 入力特徴量の変換 (F, Fneu, spks -> d_model)
        self.input_proj = nn.Linear(3, d_model)
        
        # 標準のTransformerエンコーダー
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, 
            nhead=nhead,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        
        # 分類ヘッド
        self.classifier = nn.Linear(d_model, num_classes)
        
    def forward(self, F, Fneu, spks):
        # 入力特徴量の結合 [batch_size, seq_len, 3]
        x = torch.stack([F, Fneu, spks], dim=2)
        
        # 特徴量の次元変換 [batch_size, seq_len, d_model]
        x = self.input_proj(x)
        
        # Transformerエンコーダー
        x = self.transformer(x)
        
        # 全時間ステップの平均をとる
        x = torch.mean(x, dim=1)
        
        # 分類
        x = self.classifier(x)
        
        return x


def train_celltype_model(inputs_all, outputs_all, dict_idx_celltype, batch_size=32, epochs=20, lr=0.001):
    """モデルの学習パイプライン"""
    # tqdmのインポート
    from tqdm import tqdm
    """モデルの学習パイプライン"""
    
    # 全データの収集
    all_F = []
    all_Fneu = []
    all_spks = []
    all_celltype = []
    
    for i in range(len(inputs_all['F'])):
        for j in range(len(inputs_all['F'][i])):
            all_F.append(inputs_all['F'][i][j])
            all_Fneu.append(inputs_all['Fneu'][i][j])
            all_spks.append(inputs_all['spks'][i][j])
            all_celltype.append(outputs_all['celltype'][i][j])
    
    # データの分割
    train_indices, val_indices = train_test_split(
        range(len(all_F)), test_size=0.2, random_state=42
    )

    print(f"Training samples: {len(train_indices)}, Validation samples: {len(val_indices)}")
    
    # データセットの作成
    train_dataset = CelltypeDataset(
        [all_F[i] for i in train_indices],
        [all_Fneu[i] for i in train_indices],
        [all_spks[i] for i in train_indices],
        [all_celltype[i] for i in train_indices]
    )
    
    val_dataset = CelltypeDataset(
        [all_F[i] for i in val_indices],
        [all_Fneu[i] for i in val_indices],
        [all_spks[i] for i in val_indices],
        [all_celltype[i] for i in val_indices]
    )
    
    # データローダーの作成
    train_loader = DataLoader(
        train_dataset, 
        batch_size=batch_size, 
        shuffle=True, 
        collate_fn=collate_fn
    )
    
    val_loader = DataLoader(
        val_dataset, 
        batch_size=batch_size, 
        collate_fn=collate_fn
    )
    
    # デバイスの設定
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # モデルの初期化
    num_classes = len(dict_idx_celltype)
    model = CelltypeTransformer(num_classes).to(device)
    
    # 損失関数と最適化アルゴリズム
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    print("Training started...")
    # 学習ループ
    for epoch in range(epochs):
        # 訓練モード
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        
        from tqdm import tqdm
        for batch in tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs} (Train)'):
            # データをデバイスに転送
            F = batch['F'].to(device)
            Fneu = batch['Fneu'].to(device)
            spks = batch['spks'].to(device)
            celltype = batch['celltype'].to(device)
            
            # one-hot から整数インデックスに変換
            labels = torch.argmax(celltype, dim=1)
            
            # 勾配のリセット
            optimizer.zero_grad()
            
            # 順伝播
            outputs = model(F, Fneu, spks)
            
            # 損失計算
            loss = criterion(outputs, labels)
            
            # 逆伝播と最適化
            loss.backward()
            optimizer.step()
            
            # 統計更新
            train_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()
        
        # 評価モード
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f'Epoch {epoch+1}/{epochs} (Val)'):
                # データをデバイスに転送
                F = batch['F'].to(device)
                Fneu = batch['Fneu'].to(device)
                spks = batch['spks'].to(device)
                celltype = batch['celltype'].to(device)
                
                # one-hot から整数インデックスに変換
                labels = torch.argmax(celltype, dim=1)
                
                # 順伝播
                outputs = model(F, Fneu, spks)
                
                # 損失計算
                loss = criterion(outputs, labels)
                
                # 統計更新
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()
        
        # エポックごとの進捗表示
        train_loss = train_loss / len(train_loader)
        train_acc = 100.0 * train_correct / train_total
        val_loss = val_loss / len(val_loader)
        val_acc = 100.0 * val_correct / val_total
        
        print(f'Epoch {epoch+1}/{epochs} - '
              f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, '
              f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
    
    print("Training finished.")

    # モデルの保存
    torch.save(model.state_dict(), 'celltype_model.pth')
    
    return model


def predict(model, F, Fneu, spks, dict_idx_celltype):
    """新しいROIデータに対する予測"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.eval()
    
    # テンソルに変換
    F = torch.tensor(F, dtype=torch.float32).unsqueeze(0).to(device)
    Fneu = torch.tensor(Fneu, dtype=torch.float32).unsqueeze(0).to(device)
    spks = torch.tensor(spks, dtype=torch.float32).unsqueeze(0).to(device)
    
    with torch.no_grad():
        # 予測
        outputs = model(F, Fneu, spks)
        _, predicted = torch.max(outputs, 1)
        
    # クラスIDからcelltype名に変換
    predicted_celltype = dict_idx_celltype[predicted.item()]
    
    return predicted_celltype


# 使用例
if __name__ == "__main__":
    # モデルの学習
    model = train_celltype_model(inputs_all, outputs_all, dict_idx_celltype, epochs=5)
    
    # 新しいデータで予測
    new_F = np.random.randn(2000).astype(np.float32)
    new_Fneu = np.random.randn(2000).astype(np.float32)
    new_spks = np.random.randn(2000).astype(np.float32)
    
    predicted_celltype = predict(model, new_F, new_Fneu, new_spks, dict_idx_celltype)
    print(f"予測されたcelltype: {predicted_celltype}")

Training samples: 1080, Validation samples: 271
Training started...


Epoch 1/5 (Train): 100%|██████████| 34/34 [11:44<00:00, 20.71s/it]
Epoch 1/5 (Val): 100%|██████████| 9/9 [01:20<00:00,  8.89s/it]


Epoch 1/5 - Train Loss: 0.5977, Train Acc: 74.17%, Val Loss: 0.6022, Val Acc: 71.59%


Epoch 2/5 (Train): 100%|██████████| 34/34 [11:43<00:00, 20.69s/it]
Epoch 2/5 (Val): 100%|██████████| 9/9 [01:20<00:00,  8.90s/it]


Epoch 2/5 - Train Loss: 0.5550, Train Acc: 75.46%, Val Loss: 0.5669, Val Acc: 71.59%


Epoch 3/5 (Train): 100%|██████████| 34/34 [11:37<00:00, 20.52s/it]
Epoch 3/5 (Val): 100%|██████████| 9/9 [01:20<00:00,  8.89s/it]


Epoch 3/5 - Train Loss: 0.4961, Train Acc: 75.46%, Val Loss: 0.4633, Val Acc: 75.28%


Epoch 4/5 (Train): 100%|██████████| 34/34 [11:33<00:00, 20.39s/it]
Epoch 4/5 (Val): 100%|██████████| 9/9 [01:19<00:00,  8.78s/it]


Epoch 4/5 - Train Loss: 0.4248, Train Acc: 80.00%, Val Loss: 0.4155, Val Acc: 78.97%


Epoch 5/5 (Train): 100%|██████████| 34/34 [11:33<00:00, 20.39s/it]
Epoch 5/5 (Val): 100%|██████████| 9/9 [01:18<00:00,  8.77s/it]

Epoch 5/5 - Train Loss: 0.4047, Train Acc: 77.87%, Val Loss: 0.3865, Val Acc: 82.29%
Training finished.
予測されたcelltype: Not_Cell





In [90]:
!pip install tqdm

