
[FEATURE] Add raw attention scores to the AttentionCell #951 (#964)
* Implement: Add raw attention scores to the AttentionCell #951

* Fix pylint issues.

* Address code-review comments.

* Separate _compute_score from _compute_weight.
emilmont authored and leezu committed Oct 25, 2019
1 parent 7557893 commit 06ecac8
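
For context, here is a minimal sketch of what the refactor enables, assuming the DotProductAttentionCell defined in this file, the usual AttentionCell call signature (query, key, value, mask), and illustrative shapes; none of this snippet is part of the commit itself:

import mxnet as mx
from gluonnlp.model.attention_cell import DotProductAttentionCell

# Toy inputs: batch_size=2, query_length=3, memory_length=4, units=8
query = mx.nd.random.normal(shape=(2, 3, 8))
key = mx.nd.random.normal(shape=(2, 4, 8))
mask = mx.nd.ones(shape=(2, 3, 4))

cell = DotProductAttentionCell(scaled=True)
cell.initialize()

# The public call still returns the context vector and the normalized attention weights.
context_vec, att_weights = cell(query, key, key, mask)

# After this commit the raw (pre-softmax) scores live in a separate _compute_score
# method; F (mx.nd or mx.sym) is passed explicitly because it is an internal hybrid helper.
raw_scores = cell._compute_score(mx.nd, query, key, mask)
print(raw_scores.shape)  # (2, 3, 4)

With the split, each cell only implements its own scoring in _compute_score, while _compute_weight reduces to softmax, optional masking, and dropout on top of those scores.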
Showing 1 changed file with 55 additions and 17 deletions.
72 changes: 55 additions & 17 deletions src/gluonnlp/model/attention_cell.py
@@ -25,6 +25,38 @@
from mxnet.gluon import nn
from .block import L2Normalization


+def _apply_mask(F, att_score, mask, dtype):
+    """Fill in the masked scores with a very small value
+
+    Parameters
+    ----------
+    F : symbol or ndarray
+    att_score : Symbol or NDArray
+        Shape (batch_size, query_length, memory_length)
+    mask : Symbol or NDArray or None
+        Shape (batch_size, query_length, memory_length)
+
+    Returns
+    -------
+    att_score : Symbol or NDArray
+        Shape (batch_size, query_length, memory_length)
+    """
+    # Fill in the masked scores with a very small value
+    neg = -1e18
+    if np.dtype(dtype) == np.float16:
+        neg = -1e4
+    else:
+        try:
+            # if AMP (automatic mixed precision) is enabled, -1e18 will cause NaN.
+            from mxnet.contrib import amp
+            if amp.amp._amp_initialized:
+                neg = -1e4
+        except ImportError:
+            pass
+    att_score = F.where(mask, att_score, neg * F.ones_like(att_score))
+    return att_score


# TODO(sxjscience) Add mask flag to softmax operator. Think about how to accelerate the kernel
def _masked_softmax(F, att_score, mask, dtype):
"""Ignore the masked elements when calculating the softmax
@@ -38,23 +70,12 @@ def _masked_softmax(F, att_score, mask, dtype):
        Shape (batch_size, query_length, memory_length)

    Returns
    -------
-    att_weights : Symborl or NDArray
+    att_weights : Symbol or NDArray
        Shape (batch_size, query_length, memory_length)
    """
    if mask is not None:
-        # Fill in the masked scores with a very small value
-        neg = -1e18
-        if np.dtype(dtype) == np.float16:
-            neg = -1e4
-        else:
-            try:
-                # if AMP (automatic mixed precision) is enabled, -1e18 will cause NaN.
-                from mxnet.contrib import amp
-                if amp.amp._amp_initialized:
-                    neg = -1e4
-            except ImportError:
-                pass
-        att_score = F.where(mask, att_score, neg * F.ones_like(att_score))
+        att_score = _apply_mask(F, att_score, mask, dtype)
        att_weights = F.softmax(att_score, axis=-1) * mask
    else:
        att_weights = F.softmax(att_score, axis=-1)
@@ -353,14 +374,23 @@ def hybrid_forward(self, F, x, g, v):  # pylint: disable=arguments-differ
weight_initializer=weight_initializer,
prefix='score_')

-    def _compute_weight(self, F, query, key, mask=None):
+    def _compute_score(self, F, query, key, mask=None):
        mapped_query = self._query_mid_layer(query)
        mapped_key = self._key_mid_layer(key)
        mid_feat = F.broadcast_add(F.expand_dims(mapped_query, axis=2),
                                   F.expand_dims(mapped_key, axis=1))
        mid_feat = self._act(mid_feat)
        att_score = self._attention_score(mid_feat).reshape(shape=(0, 0, 0))
-        att_weights = self._dropout_layer(_masked_softmax(F, att_score, mask, self._dtype))
+        if mask is not None:
+            att_score = _apply_mask(F, att_score, mask, self._dtype)
+        return att_score
+
+    def _compute_weight(self, F, query, key, mask=None):
+        att_score = self._compute_score(F, query, key, mask)
+        att_weights = F.softmax(att_score, axis=-1)
+        if mask is not None:
+            att_weights = att_weights * mask
+        att_weights = self._dropout_layer(att_weights)
        return att_weights


@@ -449,7 +479,7 @@ def __init__(self, units=None, luong_style=False, scaled=True, normalized=False,
with self.name_scope():
self._l2_norm = L2Normalization(axis=-1)

-    def _compute_weight(self, F, query, key, mask=None):
+    def _compute_score(self, F, query, key, mask=None):
        if self._units is not None:
            query = self._proj_query(query)
            if not self._luong_style:
@@ -466,6 +496,14 @@ def _compute_weight(self, F, query, key, mask=None):
            query = F.contrib.div_sqrt_dim(query)

        att_score = F.batch_dot(query, key, transpose_b=True)
+        if mask is not None:
+            att_score = _apply_mask(F, att_score, mask, self._dtype)
+        return att_score
+
-        att_weights = self._dropout_layer(_masked_softmax(F, att_score, mask, self._dtype))
+    def _compute_weight(self, F, query, key, mask=None):
+        att_score = self._compute_score(F, query, key, mask)
+        att_weights = F.softmax(att_score, axis=-1)
+        if mask is not None:
+            att_weights = att_weights * mask
+        att_weights = self._dropout_layer(att_weights)
        return att_weights
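
The masking trick is worth spelling out: _apply_mask overwrites the scores at masked positions with a very large negative number so that the subsequent softmax drives their weights to (numerically) zero. The float16 / AMP branch exists because -1e18 overflows float16, whose largest finite value is about 65504, hence the smaller -1e4 there. A standalone NumPy sketch of the same idea, not part of the commit:

import numpy as np

scores = np.array([[2.0, 1.0, 0.5, 3.0]], dtype=np.float32)
mask = np.array([[1.0, 1.0, 0.0, 0.0]], dtype=np.float32)  # last two memory positions are padding

# Same idea as _apply_mask: overwrite masked scores with a very large negative value.
neg = -1e18
masked_scores = np.where(mask > 0, scores, neg)

# Softmax over the last axis: masked positions end up with (numerically) zero weight.
weights = np.exp(masked_scores - masked_scores.max(axis=-1, keepdims=True))
weights = weights / weights.sum(axis=-1, keepdims=True)
print(weights.round(3))   # [[0.731 0.269 0.    0.   ]]

# Why the float16 branch uses -1e4 instead: -1e18 is not representable in float16.
print(np.float16(-1e18))  # -inf
print(np.float16(-1e4))   # -10000.0

This mirrors the branch in _apply_mask: fall back to -1e4 whenever the computation may actually run in float16, either because dtype is float16 or because AMP has been initialized.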
