11import numpy as np
2- import sys
32import torch
43from collections .abc import Mapping
54from typing import Callable
@@ -19,27 +18,6 @@ def default_logits_adaptor(input_tensor: torch.tensor, samples: modALinput):
1918 # default Callable parameter for get_predictions
2019 return input_tensor
2120
22- def KL_divergence (classifier : BaseEstimator , X : modALinput , n_instances : int = 1 ,
23- random_tie_break : bool = False , dropout_layer_indexes : list = [],
24- num_cycles : int = 50 , ** mc_dropout_kwargs ) -> np .ndarray :
25- """
26- TODO: Work in progress
27- """
28- # set dropout layers to train mode
29- set_dropout_mode (classifier .estimator .module_ , dropout_layer_indexes , train_mode = True )
30-
31- predictions = get_predictions (classifier , X , num_cycles )
32-
33- # set dropout layers to eval
34- set_dropout_mode (classifier .estimator .module_ , dropout_layer_indexes , train_mode = False )
35-
36- #KL_divergence = _KL_divergence(predictions)
37-
38- if not random_tie_break :
39- return multi_argmax (KL_divergence , n_instances = n_instances )
40-
41- return shuffled_argmax (KL_divergence , n_instances = n_instances )
42-
4321def mc_dropout_multi (classifier : BaseEstimator , X : modALinput , query_strategies : list = ["bald" , "mean_st" , "max_entropy" , "max_var" ],
4422 n_instances : int = 1 , random_tie_break : bool = False , dropout_layer_indexes : list = [],
4523 num_cycles : int = 50 , sample_per_forward_pass : int = 1000 ,
@@ -436,23 +414,10 @@ def _bald_divergence(proba: list) -> np.ndarray:
436414 bald = np .sum (shaped , where = ~ np .isnan (shaped ), axis = - 1 )
437415 return bald
438416
439- def _KL_divergence (proba ) -> np .ndarray :
440-
441- #create 3D or 4D array from prediction dim: (drop_cycles, proba.shape[0], proba.shape[1], opt:proba.shape[2])
442- proba_stacked = np .stack (proba , axis = len (proba [0 ].shape ))
443- # TODO work in progress
444- # TODO add dimensionality adaption
445- #number_of_dimensions = proba_stacked.ndim
446- #if proba_stacked.ndim > 2:
447-
448- normalized_proba = normalize (proba_stacked , axis = 0 )
449-
450417
451418def set_dropout_mode (model , dropout_layer_indexes : list , train_mode : bool ):
452419 """
453420 Function to enable the dropout layers by setting them to user specified mode (bool: train_mode)
454- TODO: Reduce maybe complexity
455- TODO: Keras support
456421 """
457422
458423 modules = list (model .modules ()) # list of all modules in the network.
0 commit comments