Skip to content

Commit

Permalink
Minor touch up on existing type hinting
Browse files Browse the repository at this point in the history
  • Loading branch information
dyt811 committed Nov 18, 2022
1 parent 24e61ff commit 4a7b58d
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions ivadomed/loader/film.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
from pathlib import Path
from copy import deepcopy
from typing import List
from typing import List, Union

import numpy as np
from loguru import logger
Expand All @@ -12,9 +12,12 @@
from sklearn.preprocessing import OneHotEncoder
from ivadomed.keywords import MetadataKW
import typing

if typing.TYPE_CHECKING:
from ivadomed.loader.bids_dataset import BidsDataset
from ivadomed.loader.bids3d_dataset import Bids3DDataset
from ivadomed.loader.mri2d_segmentation_dataset import MRI2DSegmentationDataset

import torch.nn as nn

from ivadomed import __path__
Expand All @@ -26,13 +29,16 @@
"acq-MToff_MTS": 3, "acq-MTon_MTS": 4, "acq-T1w_MTS": 5}


def normalize_metadata(ds_in: BidsDataset, clustering_models: any, debugging: bool, metadata_type: str, train_set: bool = False)\
-> (list, OneHotEncoder) | list:
def normalize_metadata(ds_in: Union[BidsDataset, Bids3DDataset, MRI2DSegmentationDataset],
clustering_models: dict,
debugging: bool,
metadata_type: str,
train_set: bool = False) -> (list, OneHotEncoder) | list:
"""Categorize each metadata value using a KDE clustering method, then apply a one-hot-encoding.
Args:
ds_in (BidsDataset): Dataset with metadata.
clustering_models: Pre-trained clustering model that has been trained on metadata of the training set.
ds_in (BidsDataset): Dataset BidsDataset, Bids3D, MRI2D with metadata.
clustering_models (dict): Pre-trained clustering model that has been trained on metadata of the training set.
debugging (bool): If True, extended verbosity and intermediate outputs.
metadata_type (str): Choice between 'mri_params', 'constrasts' or the name of a column from the
participants.tsv file.
Expand Down Expand Up @@ -117,7 +123,7 @@ def __init__(self) -> None:
self.kde = KernelDensity()
self.minima = None

def train(self, data: any, value_range: any, gridsearch_bandwidth_range: any) -> None:
def train(self, data: list, value_range: np.ndarray, gridsearch_bandwidth_range: np.ndarray) -> None:
# reshape data to fit sklearn
data = np.array(data).reshape(-1, 1)

Expand All @@ -138,7 +144,7 @@ def train(self, data: any, value_range: any, gridsearch_bandwidth_range: any) ->
# find local minima
self.minima = s[argrelextrema(e, np.less)[0]]

def predict(self, data: any) -> int:
def predict(self, data: float) -> int:
x = [i for i, m in enumerate(self.minima) if data < m]
pred = min(x) if len(x) else len(self.minima)
return pred
Expand All @@ -149,7 +155,7 @@ def clustering_fit(dataset: list, key_lst: List[str]) -> dict:
using Kernel Density Estimation algorithm.
Args:
datasets (list): data
dataset (list): data
key_lst (list of str): names of metadata to cluster
Returns:
Expand Down Expand Up @@ -230,7 +236,7 @@ def get_film_metadata_models(ds_train: MRI2DSegmentationDataset, metadata_type:


def store_film_params(gammas: dict, betas: dict, metadata_values: list, metadata: list, model: nn.Module,
film_layers: list, depth: int, film_metadata: str) -> (dict, dict):
film_layers: list, depth: int, film_metadata: str) -> (dict, dict, list):
"""Store FiLM params.
Args:
Expand All @@ -244,7 +250,7 @@ def store_film_params(gammas: dict, betas: dict, metadata_values: list, metadata
film_metadata (str): Metadata of interest used to modulate the network (e.g., contrast, tumor_type).
Returns:
dict, dict: gammas, betas
dict, dict, list: gammas, betas, metadata_values
"""
new_input = [metadata[k][0][film_metadata] for k in range(len(metadata))]
metadata_values.append(new_input)
Expand Down

0 comments on commit 4a7b58d

Please sign in to comment.