Skip to content

Commit aa128f2

Browse files
committed
Merge branch 'master' of https://github.com/s0tt/modAL
2 parents 72f9d92 + 8b71d24 commit aa128f2

File tree

1 file changed

+0
-35
lines changed

1 file changed

+0
-35
lines changed

modAL/dropout.py

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import numpy as np
2-
import sys
32
import torch
43
from collections.abc import Mapping
54
from typing import Callable
@@ -19,27 +18,6 @@ def default_logits_adaptor(input_tensor: torch.tensor, samples: modALinput):
1918
# default Callable parameter for get_predictions
2019
return input_tensor
2120

22-
def KL_divergence(classifier: BaseEstimator, X: modALinput, n_instances: int = 1,
23-
random_tie_break: bool = False, dropout_layer_indexes: list = [],
24-
num_cycles : int = 50, **mc_dropout_kwargs) -> np.ndarray:
25-
"""
26-
TODO: Work in progress
27-
"""
28-
# set dropout layers to train mode
29-
set_dropout_mode(classifier.estimator.module_, dropout_layer_indexes, train_mode=True)
30-
31-
predictions = get_predictions(classifier, X, num_cycles)
32-
33-
# set dropout layers to eval
34-
set_dropout_mode(classifier.estimator.module_, dropout_layer_indexes, train_mode=False)
35-
36-
#KL_divergence = _KL_divergence(predictions)
37-
38-
if not random_tie_break:
39-
return multi_argmax(KL_divergence, n_instances=n_instances)
40-
41-
return shuffled_argmax(KL_divergence, n_instances=n_instances)
42-
4321
def mc_dropout_multi(classifier: BaseEstimator, X: modALinput, query_strategies: list = ["bald", "mean_st", "max_entropy", "max_var"],
4422
n_instances: int = 1, random_tie_break: bool = False, dropout_layer_indexes: list = [],
4523
num_cycles : int = 50, sample_per_forward_pass: int = 1000,
@@ -436,23 +414,10 @@ def _bald_divergence(proba: list) -> np.ndarray:
436414
bald = np.sum(shaped, where=~np.isnan(shaped), axis=-1)
437415
return bald
438416

439-
def _KL_divergence(proba) -> np.ndarray:
440-
441-
#create 3D or 4D array from prediction dim: (drop_cycles, proba.shape[0], proba.shape[1], opt:proba.shape[2])
442-
proba_stacked = np.stack(proba, axis=len(proba[0].shape))
443-
# TODO work in progress
444-
# TODO add dimensionality adaption
445-
#number_of_dimensions = proba_stacked.ndim
446-
#if proba_stacked.ndim > 2:
447-
448-
normalized_proba = normalize(proba_stacked, axis=0)
449-
450417

451418
def set_dropout_mode(model, dropout_layer_indexes: list, train_mode: bool):
452419
"""
453420
Function to enable the dropout layers by setting them to user specified mode (bool: train_mode)
454-
TODO: Reduce maybe complexity
455-
TODO: Keras support
456421
"""
457422

458423
modules = list(model.modules()) # list of all modules in the network.

0 commit comments

Comments
 (0)