Skip to content

Commit

Permalink
Add Type Hint for utils.py (#1162)
Browse files Browse the repository at this point in the history
Co-authored-by: Yang Ding <yang.ding@motioncorrect.com>
  • Loading branch information
cakester and dyt811 committed Nov 21, 2022
1 parent dcda8dd commit 1e8a6bd
Showing 1 changed file with 22 additions and 14 deletions.
36 changes: 22 additions & 14 deletions ivadomed/loader/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
import collections.abc
import re
import sys
Expand All @@ -17,6 +18,10 @@
from ivadomed.keywords import SplitDatasetKW, LoaderParamsKW, ROIParamsKW, ContrastParamsKW
import nibabel as nib
import random
import typing
if typing.TYPE_CHECKING:
from typing import Union
from typing import Optional

__numpy_type_map = {
'float64': torch.DoubleTensor,
Expand All @@ -40,7 +45,8 @@
".tiff", ".png", ".jpg", ".jpeg"]


def split_dataset(df, split_method, data_testing, random_seed, train_frac=0.8, test_frac=0.1):
def split_dataset(df: pd.DataFrame, split_method: str, data_testing: dict, random_seed: int, train_frac: float = 0.8,
                  test_frac: float = 0.1) -> tuple[list, list, Union[list, object]]:
"""Splits dataset into training, validation and testing sets by applying train, test and validation fractions
according to the split_method.
The "data_testing" parameter can be used to specify the data_type and data_value to include in the testing set,
Expand Down Expand Up @@ -117,8 +123,9 @@ def split_dataset(df, split_method, data_testing, random_seed, train_frac=0.8, t
return X_train, X_val, X_test


def get_new_subject_file_split(df, split_method, data_testing, random_seed,
train_frac, test_frac, path_output, balance, subject_selection=None):
def get_new_subject_file_split(df: pd.DataFrame, split_method: str, data_testing: dict, random_seed: int,
                               train_frac: float, test_frac: float, path_output: str, balance: str,
                               subject_selection: dict = None) -> tuple[list, list, list]:
"""Randomly split dataset between training / validation / testing.
Randomly split dataset between training / validation / testing\
Expand Down Expand Up @@ -187,7 +194,8 @@ def get_new_subject_file_split(df, split_method, data_testing, random_seed,
return train_lst, valid_lst, test_lst


def get_subdatasets_subject_files_list(split_params, df, path_output, subject_selection=None):
def get_subdatasets_subject_files_list(split_params: dict, df: pd.DataFrame, path_output: str,
                                       subject_selection: dict = None) -> tuple[list, list, list]:
"""Get lists of subject filenames for each sub-dataset between training / validation / testing.
Args:
Expand Down Expand Up @@ -233,7 +241,7 @@ def get_subdatasets_subject_files_list(split_params, df, path_output, subject_se
return train_lst, valid_lst, test_lst


def imed_collate(batch):
def imed_collate(batch: dict) -> dict | list | str | torch.Tensor:
"""Collates data to create batches
Args:
Expand Down Expand Up @@ -272,7 +280,7 @@ def imed_collate(batch):
return batch


def filter_roi(roi_data, nb_nonzero_thr):
def filter_roi(roi_data: np.ndarray, nb_nonzero_thr: int) -> bool:
"""Filter slices from dataset using ROI data.
This function filters slices (roi_data) where the number of non-zero voxels within the
Expand All @@ -290,7 +298,7 @@ def filter_roi(roi_data, nb_nonzero_thr):
return not np.any(roi_data) or np.count_nonzero(roi_data) <= nb_nonzero_thr


def orient_img_hwd(data, slice_axis):
def orient_img_hwd(data: np.ndarray, slice_axis: int) -> np.ndarray:
"""Orient a given RAS image to height, width, depth according to slice axis.
Args:
Expand All @@ -309,7 +317,7 @@ def orient_img_hwd(data, slice_axis):
return data


def orient_img_ras(data, slice_axis):
def orient_img_ras(data: np.ndarray, slice_axis: int) -> np.ndarray:
"""Orient a given array with dimensions (height, width, depth) to RAS orientation.
Args:
Expand All @@ -329,7 +337,7 @@ def orient_img_ras(data, slice_axis):
return data


def orient_shapes_hwd(data, slice_axis):
def orient_shapes_hwd(data: list | tuple, slice_axis: int) -> np.ndarray:
"""Swap dimensions according to match the height, width, depth orientation.
Args:
Expand All @@ -349,7 +357,7 @@ def orient_shapes_hwd(data, slice_axis):
return np.array(data)


def update_metadata(metadata_src_lst, metadata_dest_lst):
def update_metadata(metadata_src_lst: list, metadata_dest_lst: list) -> list:
"""Update metadata keys with a reference metadata.
A given list of metadata keys will be changed and given the values of the reference metadata.
Expand All @@ -371,7 +379,7 @@ def update_metadata(metadata_src_lst, metadata_dest_lst):
return metadata_dest_lst


def reorient_image(arr, slice_axis, nib_ref, nib_ref_canonical):
def reorient_image(arr: np.ndarray, slice_axis: int, nib_ref: nib.Nifti1Image, nib_ref_canonical: nib.Nifti1Image) -> np.ndarray:
"""Reorient an image to match a reference image orientation.
It reorients a array to a given orientation and convert it to a nibabel object using the
Expand All @@ -396,7 +404,7 @@ def reorient_image(arr, slice_axis, nib_ref, nib_ref_canonical):
return nib.orientations.apply_orientation(arr_ras, trans_orient)


def get_file_extension(filename):
def get_file_extension(filename: str) -> Optional[str]:
""" Get file extension if it is supported
Args:
filename (str): Path of the file.
Expand All @@ -409,7 +417,7 @@ def get_file_extension(filename):
return extension


def update_filename_to_nifti(filename):
def update_filename_to_nifti(filename: str) -> str:
"""
Update filename extension to 'nii.gz' if not a NifTI file.
Expand All @@ -430,7 +438,7 @@ def update_filename_to_nifti(filename):
return filename


def dropout_input(seg_pair):
def dropout_input(seg_pair: dict) -> dict:
"""Applies input-level dropout: zero to all channels minus one will be randomly set to zeros. This function verifies
if some channels are already empty. Always at least one input channel will be kept.
Expand Down

0 comments on commit 1e8a6bd

Please sign in to comment.