In [4]:
import os
import pathlib
import sys
dir_notebook = os.path.dirname(os.path.abspath("__file__"))
dir_parent = os.path.dirname(dir_notebook)
if not dir_parent in sys.path:
    sys.path.append(dir_parent)
from deepinterpolation.inference_collection import core_inference
from deepinterpolation.generator_collection import SingleTifGenerator
import tkinter.filedialog 
import numpy as np
import tifffile
import h5py
import json
import tqdm
import datetime
import matplotlib.pyplot as plt

from PyQt5.QtWidgets import *
from PyQt5.QtGui import QPixmap, QImage, QPainter, QPen, QColor, QFont, QPainterPath, QBrush
from PyQt5.QtCore import Qt, QTimer, QItemSelection

from deepinterpolation.generator_collection import SingleTifGenerator, MultiContinuousTifGenerator
from deepinterpolation.trainor_collection import core_trainer
from deepinterpolation.network_collection import unet_single_1024

import sys
current_dir = os.path.dirname(os.path.abspath("__file__"))
parent_dir = os.path.dirname(os.path.dirname(current_dir))
sys.path.append(f"{parent_dir}optic")

from optic.config import *
from optic.manager import *
from optic.gui import *
from optic.io import *
from optic.utils import *
from optic.gui.bind_func import *
from optic.preprocessing import *

class DIPTrain(QMainWindow):
    def __init__(self):
        QMainWindow.__init__(self)
        self.widget_manager = initManagers(WidgetManager())
        self.current_dir = os.path.dirname(os.path.abspath("__file__"))

        self.setupUI()

    def setupUI(self):
        self.central_widget = QWidget(self)
        self.setCentralWidget(self.central_widget)
        self.setGeometry(100, 100, 1200, 600)
        self.layout_main = QGridLayout(self.central_widget)
        self.layout_main.addLayout(self.makeLayoutMain(), 0, 0)

        self.bindFuncAllWidget()

    """
    makeLayout Function; Component
    小要素のLayout
    return -> Layout
    """
    def makeLayoutComponentDIPConfig(self):
        layout = QGridLayout()
        list_key_label = ["pre_post_frame", "steps_per_epoch", "batch_size"]
        list_key_lineedit = list_key_label
        list_label = list_key_label
        list_text_set = ["30", "10", "5"]
        list_position = [(0, 0), (0, 1), (1, 0)]
        for key_label, key_lineedit, label, text_set, position in zip(list_key_label, list_key_lineedit, list_label, list_text_set, list_position):
            layout.addLayout(makeLayoutLineEditLabel(
                widget_manager=self.widget_manager, 
                key_label=key_label,
                key_lineedit=key_lineedit, 
                label=label, 
                text_set=text_set,
                axis="vertical")
                , position[0], position[1], 1, 1)
        return layout

    def makeLayoutComponentLoadTIF(self):
        layout = QVBoxLayout()
        layout.addLayout(makeLayoutLoadFileWidget(
            self.widget_manager, 
            label="TIF file path", 
            key_label="browse_tif", 
            key_lineedit="browse_tif", 
            key_button="browse_tif"
        ))
        return layout

    def makeLayoutComponentButtons(self):
        layout = QHBoxLayout()
        layout.addWidget(self.widget_manager.makeWidgetButton(key="run", label="Run"))
        layout.addWidget(self.widget_manager.makeWidgetButton(key="exit", label="Exit"))
        return layout


    """
    makeLayout Function; Section
    領域レベルの大Layout
    """
    def makeLayoutTop(self):
        layout = QVBoxLayout()
        layout.addLayout(self.makeLayoutComponentDIPConfig())
        return layout

    def makeLayoutBottom(self):
        layout = QVBoxLayout()
        layout.addLayout(self.makeLayoutComponentLoadTIF())
        layout.addLayout(self.makeLayoutComponentButtons())
        return layout

    
    def makeLayoutMain(self):
        layout = QVBoxLayout()
        layout.addLayout(self.makeLayoutTop())
        layout.addLayout(self.makeLayoutBottom())
        return layout
    
    """
    Functions
    """
    def getTiffStackShape(self):
        x, y, c, z, t = getTiffStackShape(self.path_input)
        return x, y
        
    def paramsSetting(self):
        training_param = {}
        generator_param = {}
        network_param = {}
        generator_test_param = {}

        now = datetime.datetime.now()
        now = now.strftime('%y%m%d_%H%M%S')

        # get params from lineedit
        pre_post_frame = int(self.widget_manager.dict_lineedit["pre_post_frame"].text())
        steps_per_epoch = int(self.widget_manager.dict_lineedit["steps_per_epoch"].text())
        batch_size = int(self.widget_manager.dict_lineedit["batch_size"].text())
        # get tiff stack shape
        width, height = self.getTiffStackShape()

        generator_test_param["pre_post_frame"] = pre_post_frame

        filePath = self.path_input
        generator_test_param["train_path"] = filePath
        generator_test_param["batch_size"] = batch_size
        generator_test_param["start_frame"] = 0
        generator_test_param["end_frame"] = 99
        generator_test_param["pre_post_omission"] = 1
        generator_test_param["steps_per_epoch"] = -1

        generator_param["pre_post_frame"] = pre_post_frame

        generator_param["train_path"] = filePath
        generator_param["batch_size"] = 1
        generator_param["start_frame"] = 0
        generator_param["end_frame"] = 99
        generator_param["pre_post_omission"] = 0
        generator_param["steps_per_epoch"] = -1

        training_param["run_uid"] = ""
        training_param["batch_size"] = batch_size
        training_param["steps_per_epoch"] = steps_per_epoch
        training_param["period_save"] = 25 
        training_param["nb_gpus"] = 0
        training_param["apply_learning_decay"] = 0
        training_param["nb_times_through_data"] = 1
        training_param["learning_rate"] = 0.0001
        training_param["pre_post_frame"] = pre_post_frame
        training_param["loss"] = "mean_absolute_error"
        training_param["nb_workers"] = 1


        run_uid = ""
        training_param["run_uid"] = run_uid
        training_param["model_string"] = (
            f"w{width}_"
            + f"h{height}_"
            + f"prepostframe{pre_post_frame}_"
            + now
        )
        jobdir = os.path.join(
            "models",
            training_param["model_string"]
        )

        training_param["output_dir"] = jobdir

        try:
            os.mkdir(jobdir)
        except:
            print("folder already exists")

        # if use multi tiff stack, implement also MultiContinuousTifGenerator
        train_generator = SingleTifGenerator(generator_param)
        test_generator = SingleTifGenerator(generator_test_param)
        training_class = core_trainer(train_generator, test_generator, unet_single_1024({}), training_param)
        return training_class, training_param

    def runDIPTrain(self):
        self.path_input = self.widget_manager.dict_lineedit["browse_tif"].text()
        # try:
        #     print("input: ", self.path_input)
        #     training_class, training_param = self.paramsSetting()
        #     print("output: ", training_param["output_dir"])
        #     training_class.run()
        #     training_class.finalize()
        # except Exception as e:
        #     QMessageBox.warning(self, "Error", str(e))

        print("input: ", self.path_input)
        self.training_class, self.training_param = self.paramsSetting()
        print(self.training_param)
        print("output: ", self.training_param["output_dir"])
        self.training_class.run()
        self.training_class.finalize()

        print("DeepInterpolation Training completed.")
            


    """
    bindFunc Function
    配置したwidgetに関数を紐づけ
    """
    def bindFuncAllWidget(self):
        bindFuncLoadFileWidget(
            q_widget=self, 
            q_button=self.widget_manager.dict_button["browse_tif"], 
            q_lineedit=self.widget_manager.dict_lineedit["browse_tif"], 
            filetype=Extension.TIFF
        )
        self.widget_manager.dict_button["run"].clicked.connect(self.runDIPTrain)
        bindFuncExit(q_window=self, q_button=self.widget_manager.dict_button["exit"])


if __name__ == "__main__":
    app = QApplication(sys.argv) if QApplication.instance() is None else QApplication.instance()
    applyAppStyle(app)
    gui = DIPTrain()
    gui.show()
    sys.exit(app.exec_())

input:  D:/deepinterpolation/sample_data/AN13-f1-z120_240610_2-preWsk_trial001.tif
{'run_uid': '', 'batch_size': 5, 'steps_per_epoch': 10, 'period_save': 25, 'nb_gpus': 0, 'apply_learning_decay': 0, 'nb_times_through_data': 1, 'learning_rate': 0.0001, 'pre_post_frame': 30, 'loss': 'mean_absolute_error', 'nb_workers': 1, 'model_string': 'w512_h512_prepostframe30_241023_111153', 'output_dir': 'models\\w512_h512_prepostframe30_241023_111153'}
output:  models\w512_h512_prepostframe30_241023_111153
Epoch 1/7
Epoch 2/7
Epoch 3/7
Epoch 4/7
Epoch 5/7
Epoch 6/7
Epoch 7/7
Saved model to disk
DeepInterpolation Training completed.


SystemExit: 0

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [3]:
current_dir

'd:\\deepinterpolation\\notebook'