Skip to content

Commit

Permalink
Delete old training data saves and warn user (#365)
Browse files Browse the repository at this point in the history
* Move qtpy functionality to brainglobe-utils

* Prompt user before overwriting data

* Update use of brainglobe-utils functions
  • Loading branch information
adamltyson committed Jan 18, 2024
1 parent baa8b2d commit ae10446
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 145 deletions.
71 changes: 54 additions & 17 deletions cellfinder/napari/curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
import tifffile
from brainglobe_napari_io.cellfinder.utils import convert_layer_to_cells
from brainglobe_utils.cells.cells import Cell
from brainglobe_utils.general.system import delete_directory_contents
from brainglobe_utils.IO.yaml import save_yaml
from brainglobe_utils.qtpy.dialog import display_warning
from brainglobe_utils.qtpy.interaction import add_button, add_combobox
from magicgui.widgets import ProgressBar
from napari.qt.threading import thread_worker
from napari.utils.notifications import show_info
Expand All @@ -20,8 +23,6 @@
QWidget,
)

from .utils import add_button, add_combobox, display_question

# Constants used throughout
WINDOW_HEIGHT = 750
WINDOW_WIDTH = 1500
Expand Down Expand Up @@ -173,33 +174,33 @@ def add_loading_panel(self, row: int, column: int = 0):
self.load_data_layout,
"Training_data (non_cells)",
self.point_layer_names,
4,
row=4,
callback=self.set_training_data_non_cell,
)
self.mark_as_cell_button = add_button(
"Mark as cell(s)",
self.load_data_layout,
self.mark_as_cell,
5,
row=5,
)
self.mark_as_non_cell_button = add_button(
"Mark as non cell(s)",
self.load_data_layout,
self.mark_as_non_cell,
5,
row=5,
column=1,
)
self.add_training_data_button = add_button(
"Add training data layers",
self.load_data_layout,
self.add_training_data,
6,
row=6,
)
self.save_training_data_button = add_button(
"Save training data",
self.load_data_layout,
self.save_training_data,
6,
row=6,
column=1,
)
self.load_data_layout.setColumnMinimumWidth(0, COLUMN_WIDTH)
Expand Down Expand Up @@ -256,7 +257,7 @@ def add_training_data(self):

overwrite = False
if self.training_data_cell_layer or self.training_data_non_cell_layer:
overwrite = display_question(
overwrite = display_warning(
self,
"Training data layers exist",
"Training data layers already exist, "
Expand Down Expand Up @@ -363,7 +364,10 @@ def mark_point_as_type(self, point_type: str):
)

def save_training_data(
self, *, block: bool = False, prompt_for_directory: bool = True
self,
*,
block: bool = False,
prompt_for_directory: bool = True,
) -> None:
"""
Parameters
Expand All @@ -373,16 +377,45 @@ def save_training_data(
prompt_for_directory :
If `True` show a file dialog for the user to select a directory.
"""

if self.is_data_extractable():
if prompt_for_directory:
self.get_output_directory()
# if the directory is not empty
if any(self.output_directory.iterdir()):
choice = display_warning(
self,
"About to save training data",
"Existing files will be will be deleted. Proceed?",
)
if not choice:
return
if self.output_directory is not None:
self.__prep_directories_for_save()
self.__extract_cubes(block=block)
self.__save_yaml_file()
show_info("Done")

self.update_status_label("Ready")

def __prep_directories_for_save(self):
self.yaml_filename = self.output_directory / "training.yml"
self.cell_cube_dir = self.output_directory / "cells"
self.no_cell_cube_dir = self.output_directory / "non_cells"

self.__delete_existing_saved_training_data()

def __delete_existing_saved_training_data(self):
self.yaml_filename.unlink(missing_ok=True)
for directory in (
self.cell_cube_dir,
self.no_cell_cube_dir,
):
if directory.exists():
delete_directory_contents(directory)
else:
directory.mkdir(exist_ok=True, parents=True)

def __extract_cubes(self, *, block=False):
"""
Parameters
Expand Down Expand Up @@ -489,18 +522,16 @@ def convert_layers_to_cells(self):
self.non_cells_to_extract = list(set(self.non_cells_to_extract))

def __save_yaml_file(self):
# TODO: implement this in a portable way
yaml_filename = self.output_directory / "training.yml"
yaml_section = [
{
"cube_dir": str(self.output_directory / "cells"),
"cube_dir": str(self.cell_cube_dir),
"cell_def": "",
"type": "cell",
"signal_channel": 0,
"bg_channel": 1,
},
{
"cube_dir": str(self.output_directory / "non_cells"),
"cube_dir": str(self.no_cell_cube_dir),
"cell_def": "",
"type": "no_cell",
"signal_channel": 0,
Expand All @@ -509,7 +540,7 @@ def __save_yaml_file(self):
]

yaml_contents = {"data": yaml_section}
save_yaml(yaml_contents, yaml_filename)
save_yaml(yaml_contents, self.yaml_filename)

def update_progress(self, attributes: dict):
"""
Expand Down Expand Up @@ -538,9 +569,15 @@ def extract_cubes(self):
"non_cells": self.non_cells_to_extract,
}

for cell_type, cell_list in to_extract.items():
cell_type_output_directory = self.output_directory / cell_type
cell_type_output_directory.mkdir(exist_ok=True, parents=True)
directories = {
"cells": self.cell_cube_dir,
"non_cells": self.no_cell_cube_dir,
}

for cell_type in ["cells", "non_cells"]:
cell_type_output_directory = directories[cell_type]
cell_list = to_extract[cell_type]

self.update_status_label(f"Saving {cell_type}...")

cube_generator = CubeGeneratorFromFile(
Expand Down
90 changes: 1 addition & 89 deletions cellfinder/napari/utils.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,10 @@
from typing import Callable, List, Optional, Tuple
from typing import List, Tuple

import napari
import numpy as np
import pandas as pd
from brainglobe_utils.cells.cells import Cell
from pkg_resources import resource_filename
from qtpy.QtWidgets import (
QComboBox,
QLabel,
QLayout,
QMessageBox,
QPushButton,
QWidget,
)

brainglobe_logo = resource_filename(
"cellfinder", "napari/images/brainglobe.png"
Expand Down Expand Up @@ -98,83 +90,3 @@ def cells_to_array(cells: List[Cell]) -> Tuple[np.ndarray, np.ndarray]:
points = cells_df_as_np(df[df["type"] == Cell.CELL])
rejected = cells_df_as_np(df[df["type"] == Cell.UNKNOWN])
return points, rejected


def add_combobox(
layout: QLayout,
label: str,
items: List[str],
row: int,
column: int = 0,
label_stack: bool = False,
callback=None,
width: int = 150,
) -> Tuple[QComboBox, Optional[QLabel]]:
"""
Add a selection box to *layout*.
"""
if label_stack:
combobox_row = row + 1
combobox_column = column
else:
combobox_row = row
combobox_column = column + 1
combobox = QComboBox()
combobox.addItems(items)
if callback:
combobox.currentIndexChanged.connect(callback)
combobox.setMaximumWidth = width

if label is not None:
combobox_label = QLabel(label)
combobox_label.setMaximumWidth = width
layout.addWidget(combobox_label, row, column)
else:
combobox_label = None

layout.addWidget(combobox, combobox_row, combobox_column)
return combobox, combobox_label


def add_button(
label: str,
layout: QLayout,
connected_function: Callable,
row: int,
column: int = 0,
visibility: bool = True,
minimum_width: int = 0,
alignment: str = "center",
) -> QPushButton:
"""
Add a button to *layout*.
"""
button = QPushButton(label)
if alignment == "center":
pass
elif alignment == "left":
button.setStyleSheet("QPushButton { text-align: left; }")
elif alignment == "right":
button.setStyleSheet("QPushButton { text-align: right; }")

button.setVisible(visibility)
button.setMinimumWidth(minimum_width)
layout.addWidget(button, row, column)
button.clicked.connect(connected_function)
return button


def display_question(widget: QWidget, title: str, message: str) -> bool:
"""
Display a warning in a pop up that informs about overwriting files.
"""
message_reply = QMessageBox.question(
widget,
title,
message,
QMessageBox.Yes | QMessageBox.Cancel,
)
if message_reply == QMessageBox.Yes:
return True
else:
return False
39 changes: 0 additions & 39 deletions tests/napari/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
import pytest
from brainglobe_utils.cells.cells import Cell
from qtpy.QtWidgets import QGridLayout

from cellfinder.napari.utils import (
add_button,
add_combobox,
add_layers,
html_label_widget,
)
Expand All @@ -27,38 +23,3 @@ def test_html_label_widget():
label_widget = html_label_widget("A nice label", tag="h1")
assert label_widget["widget_type"] == "Label"
assert label_widget["label"] == "<h1>A nice label</h1>"


@pytest.mark.parametrize("label_stack", [True, False])
@pytest.mark.parametrize("label", ["A label", None])
def test_add_combobox(label, label_stack):
"""
Smoke test for add_combobox for all conditional branches
"""
layout = QGridLayout()
combobox = add_combobox(
layout,
row=0,
label=label,
items=["item 1", "item 2"],
label_stack=label_stack,
)
assert combobox is not None


@pytest.mark.parametrize(
argnames="alignment", argvalues=["center", "left", "right"]
)
def test_add_button(alignment):
"""
Smoke tests for add_button for all conditional branches
"""
layout = QGridLayout()
button = add_button(
layout=layout,
connected_function=lambda: None,
label="A button",
row=0,
alignment=alignment,
)
assert button is not None

0 comments on commit ae10446

Please sign in to comment.