Skip to content

Commit

Permalink
Add TypeHint for loader/film.py (#1151)
Browse files Browse the repository at this point in the history
  • Loading branch information
cakester committed Nov 20, 2022
1 parent 2a5ab30 commit 33a8960
Showing 1 changed file with 27 additions and 12 deletions.
39 changes: 27 additions & 12 deletions ivadomed/loader/film.py
@@ -1,6 +1,8 @@
from __future__ import annotations
import json
from pathlib import Path
from copy import deepcopy
from typing import List, Union

import numpy as np
from loguru import logger
Expand All @@ -9,6 +11,14 @@
from sklearn.neighbors import KernelDensity
from sklearn.preprocessing import OneHotEncoder
from ivadomed.keywords import MetadataKW
import typing

if typing.TYPE_CHECKING:
from ivadomed.loader.bids_dataset import BidsDataset
from ivadomed.loader.bids3d_dataset import Bids3DDataset
from ivadomed.loader.mri2d_segmentation_dataset import MRI2DSegmentationDataset

import torch.nn as nn

from ivadomed import __path__

Expand All @@ -19,12 +29,16 @@
"acq-MToff_MTS": 3, "acq-MTon_MTS": 4, "acq-T1w_MTS": 5}


def normalize_metadata(ds_in, clustering_models, debugging, metadata_type, train_set=False):
def normalize_metadata(ds_in: Union[BidsDataset, Bids3DDataset, MRI2DSegmentationDataset],
clustering_models: dict,
debugging: bool,
metadata_type: str,
train_set: bool = False) -> (list, OneHotEncoder) | list:
"""Categorize each metadata value using a KDE clustering method, then apply a one-hot-encoding.
Args:
ds_in (BidsDataset): Dataset with metadata.
clustering_models: Pre-trained clustering model that has been trained on metadata of the training set.
ds_in (BidsDataset): Dataset BidsDataset, Bids3D, MRI2D with metadata.
clustering_models (dict): Pre-trained clustering model that has been trained on metadata of the training set.
debugging (bool): If True, extended verbosity and intermediate outputs.
metadata_type (str): Choice between 'mri_params', 'constrasts' or the name of a column from the
participants.tsv file.
Expand Down Expand Up @@ -105,11 +119,11 @@ class Kde_model():
kde (sklearn.neighbors.KernelDensity):
minima (float): Local minima.
"""
def __init__(self):
def __init__(self) -> None:
self.kde = KernelDensity()
self.minima = None

def train(self, data, value_range, gridsearch_bandwidth_range):
def train(self, data: list, value_range: np.ndarray, gridsearch_bandwidth_range: np.ndarray) -> None:
# reshape data to fit sklearn
data = np.array(data).reshape(-1, 1)

Expand All @@ -130,18 +144,18 @@ def train(self, data, value_range, gridsearch_bandwidth_range):
# find local minima
self.minima = s[argrelextrema(e, np.less)[0]]

def predict(self, data):
def predict(self, data: float) -> int:
x = [i for i, m in enumerate(self.minima) if data < m]
pred = min(x) if len(x) else len(self.minima)
return pred


def clustering_fit(dataset, key_lst):
def clustering_fit(dataset: list, key_lst: List[str]) -> dict:
"""This function creates clustering models for each metadata type,
using Kernel Density Estimation algorithm.
Args:
datasets (list): data
dataset (list): data
key_lst (list of str): names of metadata to cluster
Returns:
Expand Down Expand Up @@ -192,7 +206,7 @@ def check_isMRIparam(mri_param_type: str, mri_param: dict, subject: str, metadat
return True


def get_film_metadata_models(ds_train, metadata_type, debugging=False):
def get_film_metadata_models(ds_train: MRI2DSegmentationDataset, metadata_type: str, debugging: bool = False):
"""Get FiLM models.
This function pulls the clustering and one-hot encoder models that are used by FiLMedUnet.
Expand Down Expand Up @@ -221,7 +235,8 @@ def get_film_metadata_models(ds_train, metadata_type, debugging=False):
return ds_train, train_onehotencoder, metadata_clustering_models


def store_film_params(gammas, betas, metadata_values, metadata, model, film_layers, depth, film_metadata):
def store_film_params(gammas: dict, betas: dict, metadata_values: list, metadata: list, model: nn.Module,
film_layers: list, depth: int, film_metadata: str) -> (dict, dict, list):
"""Store FiLM params.
Args:
Expand All @@ -235,7 +250,7 @@ def store_film_params(gammas, betas, metadata_values, metadata, model, film_laye
film_metadata (str): Metadata of interest used to modulate the network (e.g., contrast, tumor_type).
Returns:
dict, dict: gammas, betas
dict, dict, list: gammas, betas, metadata_values
"""
new_input = [metadata[k][0][film_metadata] for k in range(len(metadata))]
metadata_values.append(new_input)
Expand All @@ -255,7 +270,7 @@ def store_film_params(gammas, betas, metadata_values, metadata, model, film_laye
return gammas, betas, metadata_values


def save_film_params(gammas, betas, metadata_values, depth, ofolder):
def save_film_params(gammas: dict, betas: dict, metadata_values: list, depth: int, ofolder: str) -> None:
"""Save FiLM params as npy files.
These parameters can be further used for visualisation purposes. They are saved in the `ofolder` with `.npy` format.
Expand Down

0 comments on commit 33a8960

Please sign in to comment.