Skip to content

Commit

Permalink
Improve type annotations in MainWindowModel
Browse files Browse the repository at this point in the history
  • Loading branch information
samtygier-stfc committed May 9, 2024
1 parent e4d7c09 commit d329c50
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 17 deletions.
2 changes: 1 addition & 1 deletion mantidimaging/core/io/loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def load_stack_from_group(group: FilenameGroup, progress: Progress | None = None

def load_stack_from_image_params(image_params: ImageParameters,
progress: Progress | None = None,
dtype: npt.DTypeLike = np.float32):
dtype: npt.DTypeLike = np.float32) -> ImageStack:
return load(filename_group=image_params.file_group,
progress=progress,
dtype=dtype,
Expand Down
4 changes: 2 additions & 2 deletions mantidimaging/core/io/saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def image_save(images: ImageStack,
name_postfix: str = DEFAULT_NAME_POSTFIX,
indices: list[int] | Indices | None = None,
pixel_depth: str | None = None,
progress: Progress | None = None) -> str | list[str]:
progress: Progress | None = None) -> list[str]:
"""
Save image volume (3d) into a series of slices along the Z axis.
The Z axis in the script is the ndarray.shape[0].
Expand Down Expand Up @@ -144,7 +144,7 @@ def image_save(images: ImageStack,
if out_format in ['nxs']:
filename = os.path.join(output_dir, name_prefix + name_postfix)
write_nxs(data, filename + '.nxs', overwrite=overwrite_all)
return filename
return [filename]
else:
if out_format in ['fit', 'fits']:
write_func: Callable[[np.ndarray, str, bool, str | None], None] = write_fits
Expand Down
29 changes: 15 additions & 14 deletions mantidimaging/gui/windows/main/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from mantidimaging.core.data.dataset import StrictDataset, MixedDataset
from mantidimaging.core.io import loader, saver
from mantidimaging.core.io.filenames import FilenameGroup
from mantidimaging.core.io.loader.loader import LoadingParameters
from mantidimaging.core.io.loader.loader import LoadingParameters, ImageParameters
from mantidimaging.core.utility.data_containers import ProjectionAngles, FILE_TYPES

if TYPE_CHECKING:
Expand Down Expand Up @@ -38,7 +38,7 @@ def get_images_by_uuid(self, images_uuid: uuid.UUID) -> ImageStack | None:

def do_load_dataset(self, parameters: LoadingParameters, progress: Progress) -> StrictDataset:

def load(im_param):
def load(im_param: ImageParameters) -> ImageStack:
return loader.load_stack_from_image_params(im_param, progress, dtype=parameters.dtype)

sample = load(parameters.image_stacks[FILE_TYPES.SAMPLE])
Expand Down Expand Up @@ -68,7 +68,8 @@ def load_images_into_mixed_dataset(self, file_path: str, progress: Progress) ->
self.datasets[sd.id] = sd
return sd

def do_images_saving(self, images_id, output_dir, name_prefix, image_format, overwrite, pixel_depth, progress):
def do_images_saving(self, images_id: uuid.UUID, output_dir: str, name_prefix: str, image_format: str,
overwrite: bool, pixel_depth: str, progress: Progress) -> bool:
images = self.get_images_by_uuid(images_id)
if images is None:
self.raise_error_when_images_not_found(images_id)
Expand All @@ -82,9 +83,10 @@ def do_images_saving(self, images_id, output_dir, name_prefix, image_format, ove
images.filenames = filenames
return True

def do_nexus_saving(self, dataset_id: uuid.UUID, path: str, sample_name: str, save_as_float: bool) -> bool | None:
if dataset_id in self.datasets and isinstance(self.datasets[dataset_id], StrictDataset):
saver.nexus_save(self.datasets[dataset_id], path, sample_name, save_as_float) # type: ignore
def do_nexus_saving(self, dataset_id: uuid.UUID, path: str, sample_name: str, save_as_float: bool) -> bool:
dataset = self.datasets.get(dataset_id)
if isinstance(dataset, StrictDataset):
saver.nexus_save(dataset, path, sample_name, save_as_float)
return True
else:
raise RuntimeError(f"Failed to get StrictDataset with ID {dataset_id}")
Expand All @@ -95,13 +97,12 @@ def get_existing_180_id(self, dataset_id: uuid.UUID) -> uuid.UUID | None:
:param dataset_id: The Dataset ID.
:return: The 180 ID if found, None otherwise.
"""
if dataset_id in self.datasets and isinstance(self.datasets[dataset_id], StrictDataset):
dataset = self.datasets[dataset_id]
else:
dataset = self.datasets.get(dataset_id)
if not isinstance(dataset, StrictDataset):
raise RuntimeError(f"Failed to get StrictDataset with ID {dataset_id}")

if isinstance(dataset.proj180deg, ImageStack): # type: ignore
return dataset.proj180deg.id # type: ignore
if isinstance(dataset.proj180deg, ImageStack):
return dataset.proj180deg.id
return None

def add_180_deg_to_dataset(self, dataset_id: uuid.UUID, _180_deg_file: str) -> ImageStack:
Expand All @@ -123,7 +124,7 @@ def add_180_deg_to_dataset(self, dataset_id: uuid.UUID, _180_deg_file: str) -> I
dataset.proj180deg = _180_deg
return _180_deg

def add_projection_angles_to_sample(self, images_id: uuid.UUID, proj_angles: ProjectionAngles):
def add_projection_angles_to_sample(self, images_id: uuid.UUID, proj_angles: ProjectionAngles) -> None:
images = self.get_images_by_uuid(images_id)
if images is None:
self.raise_error_when_images_not_found(images_id)
Expand Down Expand Up @@ -160,7 +161,7 @@ def add_shutter_counts_to_sample(self, images_id: uuid.UUID, shutter_counts_file
raise RuntimeError
images.shutter_count_file = loader.load_shutter_counts(shutter_counts_file)

def _remove_dataset(self, dataset_id: uuid.UUID):
def _remove_dataset(self, dataset_id: uuid.UUID) -> None:
"""
Removes a dataset and the image stacks it contains from the model.
:param dataset_id: The dataset ID.
Expand Down Expand Up @@ -244,7 +245,7 @@ def add_recon_to_dataset(self, recon_data: ImageStack, stack_id: uuid.UUID) -> u
self.raise_error_when_parent_strict_dataset_not_found(stack_id)

@property
def recon_list_ids(self):
def recon_list_ids(self) -> list[uuid.UUID]:
return [dataset.recons.id for dataset in self.datasets.values()]

def get_recon_list_id(self, parent_id: uuid.UUID) -> uuid.UUID:
Expand Down

0 comments on commit d329c50

Please sign in to comment.