From 33a89608d452c1aa1ab130c453d57a1f07999706 Mon Sep 17 00:00:00 2001 From: cakester Date: Sat, 19 Nov 2022 21:33:25 -0500 Subject: [PATCH] Add TypeHint for loader/film.py (#1151) --- ivadomed/loader/film.py | 39 +++++++++++++++++++++++++++------------ 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/ivadomed/loader/film.py b/ivadomed/loader/film.py index 3d5403db5..13f04039b 100644 --- a/ivadomed/loader/film.py +++ b/ivadomed/loader/film.py @@ -1,6 +1,8 @@ +from __future__ import annotations import json from pathlib import Path from copy import deepcopy +from typing import List, Union import numpy as np from loguru import logger @@ -9,6 +11,14 @@ from sklearn.neighbors import KernelDensity from sklearn.preprocessing import OneHotEncoder from ivadomed.keywords import MetadataKW +import typing + +if typing.TYPE_CHECKING: + from ivadomed.loader.bids_dataset import BidsDataset + from ivadomed.loader.bids3d_dataset import Bids3DDataset + from ivadomed.loader.mri2d_segmentation_dataset import MRI2DSegmentationDataset + + import torch.nn as nn from ivadomed import __path__ @@ -19,12 +29,16 @@ "acq-MToff_MTS": 3, "acq-MTon_MTS": 4, "acq-T1w_MTS": 5} -def normalize_metadata(ds_in, clustering_models, debugging, metadata_type, train_set=False): +def normalize_metadata(ds_in: Union[BidsDataset, Bids3DDataset, MRI2DSegmentationDataset], + clustering_models: dict, + debugging: bool, + metadata_type: str, + train_set: bool = False) -> (list, OneHotEncoder) | list: """Categorize each metadata value using a KDE clustering method, then apply a one-hot-encoding. Args: - ds_in (BidsDataset): Dataset with metadata. - clustering_models: Pre-trained clustering model that has been trained on metadata of the training set. + ds_in (BidsDataset): Dataset BidsDataset, Bids3D, MRI2D with metadata. + clustering_models (dict): Pre-trained clustering model that has been trained on metadata of the training set. debugging (bool): If True, extended verbosity and intermediate outputs. metadata_type (str): Choice between 'mri_params', 'constrasts' or the name of a column from the participants.tsv file. @@ -105,11 +119,11 @@ class Kde_model(): kde (sklearn.neighbors.KernelDensity): minima (float): Local minima. """ - def __init__(self): + def __init__(self) -> None: self.kde = KernelDensity() self.minima = None - def train(self, data, value_range, gridsearch_bandwidth_range): + def train(self, data: list, value_range: np.ndarray, gridsearch_bandwidth_range: np.ndarray) -> None: # reshape data to fit sklearn data = np.array(data).reshape(-1, 1) @@ -130,18 +144,18 @@ def train(self, data, value_range, gridsearch_bandwidth_range): # find local minima self.minima = s[argrelextrema(e, np.less)[0]] - def predict(self, data): + def predict(self, data: float) -> int: x = [i for i, m in enumerate(self.minima) if data < m] pred = min(x) if len(x) else len(self.minima) return pred -def clustering_fit(dataset, key_lst): +def clustering_fit(dataset: list, key_lst: List[str]) -> dict: """This function creates clustering models for each metadata type, using Kernel Density Estimation algorithm. Args: - datasets (list): data + dataset (list): data key_lst (list of str): names of metadata to cluster Returns: @@ -192,7 +206,7 @@ def check_isMRIparam(mri_param_type: str, mri_param: dict, subject: str, metadat return True -def get_film_metadata_models(ds_train, metadata_type, debugging=False): +def get_film_metadata_models(ds_train: MRI2DSegmentationDataset, metadata_type: str, debugging: bool = False): """Get FiLM models. This function pulls the clustering and one-hot encoder models that are used by FiLMedUnet. @@ -221,7 +235,8 @@ def get_film_metadata_models(ds_train, metadata_type, debugging=False): return ds_train, train_onehotencoder, metadata_clustering_models -def store_film_params(gammas, betas, metadata_values, metadata, model, film_layers, depth, film_metadata): +def store_film_params(gammas: dict, betas: dict, metadata_values: list, metadata: list, model: nn.Module, + film_layers: list, depth: int, film_metadata: str) -> (dict, dict, list): """Store FiLM params. Args: @@ -235,7 +250,7 @@ def store_film_params(gammas, betas, metadata_values, metadata, model, film_laye film_metadata (str): Metadata of interest used to modulate the network (e.g., contrast, tumor_type). Returns: - dict, dict: gammas, betas + dict, dict, list: gammas, betas, metadata_values """ new_input = [metadata[k][0][film_metadata] for k in range(len(metadata))] metadata_values.append(new_input) @@ -255,7 +270,7 @@ def store_film_params(gammas, betas, metadata_values, metadata, model, film_laye return gammas, betas, metadata_values -def save_film_params(gammas, betas, metadata_values, depth, ofolder): +def save_film_params(gammas: dict, betas: dict, metadata_values: list, depth: int, ofolder: str) -> None: """Save FiLM params as npy files. These parameters can be further used for visualisation purposes. They are saved in the `ofolder` with `.npy` format.