add autotune p_keep for text (refs #24)
cwmeijer committed Dec 22, 2021
1 parent 7e8b498 commit 430a022
Showing 2 changed files with 84 additions and 35 deletions.
65 changes: 44 additions & 21 deletions dianna/methods/rise.py
@@ -33,42 +33,68 @@ def explain_text(self, model_or_function, input_text, labels=(0,), batch_size=10
runner = get_function(model_or_function, preprocess_function=self.preprocess_function)
input_tokens = np.asarray(model_or_function.tokenizer(input_text))
text_length = len(input_tokens)
p_keep = self._determine_p_keep()
self.masks = self._generate_masks_for_text(text_length, p_keep) # Expose masks to make user inspection possible
sentences = self._create_masked_sentences(input_tokens)
# p_keep = self._determine_p_keep_for_text(input_tokens, runner) if self.p_keep is None else self.p_keep
p_keep = 0.5
input_shape = (text_length,)
self.masks = self._generate_masks_for_text(input_shape, p_keep, self.n_masks) # Expose masks to make user inspection possible
sentences = self._create_masked_sentences(input_tokens, self.masks)
saliencies = self._get_saliencies(runner, sentences, text_length, batch_size, p_keep)
return self._reshape_result(input_tokens, labels, saliencies)
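
For orientation, here is a minimal usage sketch of this text explainer through dianna's top-level API, mirroring the skipped test near the end of this diff (the ModelRunner import location is an assumption; the model and word-vector paths are the test fixtures):

import numpy as np
import dianna
from tests.utils import ModelRunner  # assumption: where the test helper lives

np.random.seed(42)
runner = ModelRunner('tests/test_data/movie_review_model.onnx',
                     'tests/test_data/word_vectors.txt', max_filter_size=5)
# One list per requested label, each holding (word, char_index, saliency) tuples.
positive_explanation = dianna.explain_text(runner, 'such a bad movie',
                                           labels=(1, 0), method='RISE')[0]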

@staticmethod
def _reshape_result(input_tokens, labels, saliencies):
word_lengths = [len(t) for t in input_tokens]
word_indices = [sum(word_lengths[:i]) + i for i in range(len(input_tokens))]
return [list(zip(input_tokens, word_indices, saliencies[label])) for label in labels]
def _determine_p_keep_for_text(self, input_data, runner, n_masks=100):
p_keeps = np.arange(0.1, 1.0, 0.1)
stds = []
for p_keep in p_keeps:
std = self._calculate_mean_class_std_for_text(p_keep, runner, input_data, n_masks=n_masks)
stds += [std]
best_i = np.argmax(stds)
best_p_keep = p_keeps[best_i]
print('Rise parameter p_keep was automatically determined at {}'.format(best_p_keep))
return best_p_keep
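
As a worked check of this selection rule: plugging in the mean class stds quoted in test_rise_determine_p_keep_for_text below, the largest spread sits at index 4 of the sweep, so the method settles on p_keep = 0.5:

import numpy as np

p_keeps = np.arange(0.1, 1.0, 0.1)
stds = [0.18085817, 0.239386, 0.27801532, 0.30555934, 0.31592548,
        0.31345606, 0.2901688, 0.2539522, 0.19383237]
best_p_keep = p_keeps[np.argmax(stds)]  # argmax returns index 4
assert np.isclose(best_p_keep, 0.5)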

def _calculate_mean_class_std_for_text(self, p_keep, runner, input_data, n_masks):
batch_size = 50
input_shape = input_data.shape
masks = self._generate_masks_for_text(input_shape, p_keep, n_masks)
masked = self._create_masked_sentences(input_data, masks)
predictions = []
for i in range(0, n_masks, batch_size):
current_input = masked[i:i + batch_size]
current_predictions = runner(current_input)
predictions.append(current_predictions)
predictions = np.concatenate(predictions)
std_per_class = predictions.std(axis=0)
return np.mean(std_per_class)
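
The intuition behind using the mean class std as the tuning signal: if p_keep is too high the masked inputs barely differ from the original, and if it is too low the model sees mostly empty sentences; in both cases the class scores hardly vary across masks, so an informative masking rate maximizes that variation. A small sketch of the statistic on dummy outputs:

import numpy as np

predictions = np.array([[0.9, 0.1],   # shape (n_masks, n_classes)
                        [0.2, 0.8],
                        [0.6, 0.4]])
std_per_class = predictions.std(axis=0)  # spread of each class score across masks
mean_class_std = np.mean(std_per_class)  # scalar compared across p_keep values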

def _generate_masks_for_text(self, input_shape, p_keep, n_masks):
masks = np.random.choice(a=(True, False), size=(n_masks,) + input_shape, p=(p_keep, 1 - p_keep))
return masks

def _get_saliencies(self, runner, sentences, text_length, batch_size, p_keep):
self.predictions = self._get_predictions(sentences, runner, batch_size)
unnormalized_saliency = self.predictions.T.dot(self.masks.reshape(self.n_masks, -1)).reshape(-1, text_length)
return normalize(unnormalized_saliency, self.n_masks, p_keep)
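
The dot product above computes the usual RISE estimate: the unnormalized saliency of token i for class c is the sum of class-c scores over all masks that kept token i. A shape sketch, assuming normalize divides by n_masks * p_keep as in the original RISE formulation (its body is not shown in this diff):

import numpy as np

n_masks, n_classes, text_length, p_keep = 8, 2, 4, 0.5
predictions = np.random.rand(n_masks, n_classes)        # model scores per mask
masks = np.random.rand(n_masks, text_length) < p_keep   # True = token kept
unnormalized = predictions.T.dot(masks.reshape(n_masks, -1)).reshape(-1, text_length)
saliency = unnormalized / n_masks / p_keep              # assumed normalize() behaviour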

@staticmethod
def _reshape_result(input_tokens, labels, saliencies):
word_lengths = [len(t) for t in input_tokens]
word_indices = [sum(word_lengths[:i]) + i for i in range(len(input_tokens))]
return [list(zip(input_tokens, word_indices, saliencies[label])) for label in labels]

def _get_predictions(self, sentences, runner, batch_size):
predictions = []
for i in tqdm(range(0, self.n_masks, batch_size), desc='Explaining'):
predictions.append(runner(sentences[i:i + batch_size]))
predictions = np.concatenate(predictions)
return predictions

def _create_masked_sentences(self, tokens):
def _create_masked_sentences(self, tokens, masks):
tokens_masked = []
for mask in self.masks:
for mask in masks:
tokens_masked.append(tokens[mask])
sentences = [" ".join(t) for t in tokens_masked]
return sentences
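
Concretely, each boolean mask keeps the tokens marked True and boolean indexing drops the rest:

import numpy as np

tokens = np.asarray(['such', 'a', 'bad', 'movie'])
mask = np.array([True, False, True, True])
print(' '.join(tokens[mask]))  # prints: such bad movie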

def _generate_masks_for_text(self, input_size, p_keep):
masks = np.random.choice(a=(True, False), size=(self.n_masks, input_size), p=(p_keep, 1 - p_keep))
return masks

def explain_image(self, model_or_function, input_data, batch_size=100):
"""Run the RISE explainer.
The model will be called with masked images,
@@ -101,17 +127,17 @@ def explain_image(self, model_or_function, input_data, batch_size=100):
return normalize(saliency, self.n_masks, p_keep)

def _determine_p_keep_for_images(self, input_data, runner, n_masks=100):
p_keeps = np.arange(0.1, 0.9, 0.1)
p_keeps = np.arange(0.1, 1.0, 0.1)
stds = []
for p_keep in p_keeps:
std = self._calculate_mean_class_std(p_keep, runner, input_data, n_masks=n_masks)
std = self._calculate_mean_class_std_for_images(p_keep, runner, input_data, n_masks=n_masks)
stds += [std]
best_i = np.argmax(stds)
best_p_keep = p_keeps[best_i]
print('Rise parameter p_keep was automatically determined at {}'.format(best_p_keep))
return best_p_keep

def _calculate_mean_class_std(self, p_keep, runner, input_data, n_masks):
def _calculate_mean_class_std_for_images(self, p_keep, runner, input_data, n_masks):
batch_size = 50
img_shape = input_data.shape[1:3]
masks = self.generate_masks_for_images(img_shape, p_keep, n_masks, use_progress_bar=False)
@@ -125,9 +151,6 @@ def _calculate_mean_class_std(self, p_keep, runner, input_data, n_masks):
std_per_class = predictions.std(axis=0)
return np.mean(std_per_class)

def _determine_p_keep(self):
return self.p_keep if not self.p_keep is None else 0.5

def generate_masks_for_images(self, input_size, p_keep, n_masks, use_progress_bar=True):
"""Generate a set of random masks to mask the input data
54 changes: 40 additions & 14 deletions tests/test_rise.py
@@ -32,19 +32,18 @@ def test_rise_filename(self):
def test_rise_determine_p_keep_for_images(self):
'''
When using a large sample size of 10000, the mean STD for each class for the following p_keeps
[ 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
[ 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
is as follows:
[2.071084, 2.6516771, 2.896659, 2.9460478, 2.9888847, 2.8803914, 2.6940017, 2.3410206]
[2.069784, 2.600222, 2.8940516, 2.9950087, 2.9579144, 2.8919978, 2.6288269, 2.319147, 1.763127]
So the best p_keep should be .4 or .5 (or at least between .3 and .6).
When using 20 n_masks we got this p_keep histogram: [ 1 7 19 24 21 18 8 2]
When using 30 n_masks we got this p_keep histogram: [ 0 4 11 30 23 23 9]
When using 50 n_masks we got this p_keep histogram: [ 0 3 16 35 28 14 4]
When using 100 n_masks we got this p_keep histogram: [ 0 3 14 37 23 21 2]
When using 200 n_masks we got this p_keep histogram: [ 0 0 16 37 32 15]
When using 20 n_masks we got this p_keep bincount: [ 1 7 19 24 21 18 8 2]
When using 30 n_masks we got this p_keep bincount: [ 0 4 11 30 23 23 9]
When using 50 n_masks we got this p_keep bincount: [ 0 3 16 35 28 14 4]
When using 100 n_masks we got this p_keep bincount: [ 0 3 14 37 23 21 2]
When using 200 n_masks we got this p_keep bincount: [ 0 0 16 37 32 15]
It seems 20 is not enough to have a good chance of getting a good p_keep. With 200 every sample returns a reasonable p_keep, but that is a bit much to be practical. I think we should use 100 to be on the safe side.
'''
np.random.seed(0)
@@ -53,21 +52,17 @@ def test_rise_determine_p_keep_for_images(self):
model_filename = 'tests/test_data/mnist_model.onnx'
data = get_mnist_1_data()

explainer = rise.RISE()
p_keep = explainer._determine_p_keep_for_images(data, get_function(model_filename))
p_keep = rise.RISE()._determine_p_keep_for_images(data, get_function(model_filename))

assert p_keep in expected_p_keeps # Sanity check: is the outcome in the acceptable range?
assert p_keep == expected_p_exact_keep ## Exact test: is the outcome the same as before?
assert p_keep == expected_p_exact_keep # Exact test: is the outcome the same as before?
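
The bincounts quoted in the docstring above could have been collected with a loop along these lines (a hypothetical reconstruction, not part of this commit; the helper import locations are assumptions):

import numpy as np
from dianna.methods import rise
from dianna.utils import get_function          # assumption: import location
from tests.utils import get_mnist_1_data       # assumption: import location

counts = np.zeros(9, dtype=int)
for seed in range(100):                        # rerun the tuning with fresh seeds
    np.random.seed(seed)
    p_keep = rise.RISE()._determine_p_keep_for_images(
        get_mnist_1_data(), get_function('tests/test_data/mnist_model.onnx'), n_masks=20)
    counts[int(round(p_keep * 10)) - 1] += 1   # tally results into bins 0.1 .. 0.9
print(counts)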

class RiseOnText(TestCase):
@pytest.mark.skip
def test_rise_text(self):
np.random.seed(42)

model_path = 'tests/test_data/movie_review_model.onnx'
word_vector_file = 'tests/test_data/word_vectors.txt'
runner = ModelRunner(model_path, word_vector_file, max_filter_size=5)

review = 'such a bad movie'

positive_explanation = dianna.explain_text(runner, review, labels=(1, 0), method='RISE')[0]
@@ -83,3 +78,34 @@ def test_rise_text(self):
assert words == expected_words
assert word_indices == expected_word_indices
assert np.allclose(positive_scores, expected_positive_scores)

def test_rise_determine_p_keep_for_text(self):
'''
When using a large sample size of 10000, the mean STD for each class for the following p_keeps
[ 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
is as follows:
[0.18085817, 0.239386, 0.27801532, 0.30555934, 0.31592548, 0.31345606, 0.2901688, 0.2539522, 0.19383237]
So the best p_keep should be .4 or .5 (or at least between .4 and .7).
When using 20 n_masks we got this p_keep bincount: [ 0 0 2 11 37 41 7 2]
When using 30 n_masks we got this p_keep bincount: [ 0 0 1 9 33 42 14 1]
When using 50 n_masks we got this p_keep bincount: [ 0 0 0 5 50 42 3]
When using 100 n_masks we got this p_keep bincount: [ 0 0 0 7 61 31 1]
It seems 20 is not enough to have a good chance of getting a good p_keep, while 50 n_masks already appears sufficient here.
To be on the safe side we go for 100 anyway.
'''
np.random.seed(0)
expected_p_keeps = [.3, .4, .5, .6]
expected_p_exact_keep = .5
model_path = 'tests/test_data/movie_review_model.onnx'
word_vector_file = 'tests/test_data/word_vectors.txt'
runner = ModelRunner(model_path, word_vector_file, max_filter_size=5)
input_text = 'such a bad movie'
runner = get_function(runner)
input_tokens = np.asarray(runner.tokenizer(input_text))

p_keep = rise.RISE()._determine_p_keep_for_text(input_tokens, runner)

assert p_keep in expected_p_keeps # Sanity check: is the outcome in the acceptable range?
assert p_keep == expected_p_exact_keep # Exact test: is the outcome the same as before?
