[Improvement] Use masked_softmax + masked_logsoftmax + remove adamw (#1457)

* Use masked softmax + masked logsoftmax

* remove adamw optimizer in gluonnlp and use mxnet version

* use boolean mask

* try to fix

* Update attention_cell.py

* fix

* update

* Update attention_cell.py

* Update attention_cell.py

* Update README.md

* Use scale in the OP as Moises suggested

* Update attention_cell.py

* Update attention_cell.py

* Update attention_cell.py

* Update test_attention_cell.py
sxjscience committed Dec 10, 2020
1 parent d4a1d5d commit 8e84bd1
Showing 6 changed files with 41 additions and 397 deletions.
13 changes: 13 additions & 0 deletions scripts/question_answering/README.md
@@ -71,6 +71,19 @@ python3 run_squad.py \
--overwrite_cache \
```

In addition, to train models with pre-chosen hyper-parameters, you can try out the scripts in [commands](./commands).

```
# Run FP32 training on SQuAD 2.0
bash commands/run_squad2_albert_base.sh 0 2.0 float32
# Run HOROVOD + FP32 training on SQuAD 2.0
bash commands/run_squad2_albert_base.sh 1 2.0 float32
# Run HOROVOD + AMP on SQuAD 2.0
bash commands/run_squad2_albert_base.sh 1 2.0 float16
```

### Using Horovod

We could speed up multi-GPU training via [Horovod](https://github.com/horovod/horovod).
1 change: 0 additions & 1 deletion src/gluonnlp/__init__.py
@@ -9,7 +9,6 @@
from . import loss
from . import lr_scheduler
from . import op
- from . import optimizer
from . import registry
from . import sequence_sampler
from . import embedding
155 changes: 27 additions & 128 deletions src/gluonnlp/attention_cell.py
@@ -159,8 +159,7 @@ def gen_self_attn_mask(data,
mask = mask * batch_ones.reshape((-1, 1, 1))
else:
raise NotImplementedError
- mask = mask.astype(dtype)
- return mask.astype(np.int32)
+ return mask.astype(np.bool)


def gen_mem_attn_mask(mem, mem_valid_length, data, data_valid_length=None,
@@ -241,53 +240,38 @@ def gen_mem_attn_mask(mem, mem_valid_length, data, data_valid_length=None,
else:
query_length_ones = np.ones_like(data_steps)
mask = query_length_ones.reshape((1, -1, 1)) * mem_mask
- return mask.astype(np.int32)
+ return mask.astype(np.bool)

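As an aside (not part of this diff), here is a minimal sketch of how the two mask helpers above can be exercised after this change. The `attn_type` keyword and the default `'NT'` layout are assumptions based on the surrounding code rather than lines shown in this hunk; both helpers now return boolean masks.

```python
# Illustrative sketch only; keyword names not visible in the hunk (e.g. attn_type)
# are assumptions about the surrounding attention_cell.py code.
from mxnet import np, npx
from gluonnlp.attention_cell import gen_self_attn_mask, gen_mem_attn_mask

npx.set_np()  # the numpy-based gluonnlp API expects MXNet's numpy mode

batch_size, query_length, mem_length, units = 2, 4, 6, 8
data = np.ones((batch_size, query_length, units))
mem = np.ones((batch_size, mem_length, units))
valid_length = np.array([4, 2])
mem_valid_length = np.array([6, 3])

# Causal self-attention mask: each position attends to itself and earlier valid steps.
self_mask = gen_self_attn_mask(data, valid_length, attn_type='causal')
print(self_mask.dtype, self_mask.shape)  # expected: bool mask of shape (2, 4, 4)

# Memory attention mask: every query position attends to the valid memory steps.
mem_mask = gen_mem_attn_mask(mem, mem_valid_length, data, valid_length)
print(mem_mask.dtype, mem_mask.shape)    # expected: bool mask of shape (2, 4, 6)
```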

# TODO(sxjscience) Directly implement a kernel for masked softmax
- def masked_softmax(att_score, mask, dtype=np.float32, axis: int = -1):
+ def masked_softmax(att_score, mask, axis: int = -1, temperature=None):
"""Ignore the masked elements when calculating the softmax. The mask can be broadcastable.
Parameters
----------
- att_score : Symborl or NDArray
+ att_score : Symbol or NDArray
Shape (..., length, ...)
mask : Symbol or NDArray or None
Shape (..., length, ...)
1 --> The element is not masked
0 --> The element is masked
- dtype
- data type
axis
The axis to calculate the softmax. att_score.shape[axis] must be the same as mask.shape[axis]
+ temperature
+ The temperature. It scales down the scores before applying the softmax.
Returns
-------
att_weights : Symbol or NDArray
Shape (..., length, ...)
"""
- if mask is not None:
- # Fill in the masked scores with a very small value
- neg = -1e18
- if _np.dtype(dtype) == np.float16:
- neg = -1e4
- else:
- try:
- # if AMP (automatic mixed precision) is enabled, -1e18 will cause NaN.
- from mxnet import amp
- if amp.amp._amp_initialized:
- neg = -1e4
- except ImportError:
- pass
-
- att_score = np.where(mask, att_score, neg)
- logits = npx.softmax(att_score, axis=axis) * mask
+ if mask is None:
+ return npx.softmax(att_score, axis=axis, temperature=temperature)
else:
- logits = npx.softmax(att_score, axis=axis)
- return logits
+ return npx.masked_softmax(att_score, mask=mask.astype(np.bool),
+ axis=axis, temperature=temperature)

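For reference, a small usage sketch of the new `masked_softmax` (not part of the diff; it assumes MXNet is switched into numpy mode with `npx.set_np()`). Masked positions receive exactly zero weight, and `temperature` divides the scores before the softmax, which is how the attention scaling is folded in further down in `multi_head_dot_attn`.

```python
from mxnet import np, npx
from gluonnlp.attention_cell import masked_softmax

npx.set_np()

scores = np.array([[1.0, 2.0, 3.0, 4.0]])
mask = np.array([[1.0, 1.0, 1.0, 0.0]])  # last position masked; the helper casts it to bool

weights = masked_softmax(scores, mask, axis=-1, temperature=1.0)
print(weights)               # the masked entry is exactly 0
print(weights.sum(axis=-1))  # the remaining weights still sum to 1

# With mask=None the call falls back to a plain (temperature-scaled) softmax.
plain = masked_softmax(scores, None, axis=-1, temperature=2.0)
```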

# TODO(sxjscience) Directly implement a kernel for masked logsoftmax
- def masked_logsoftmax(att_score, mask, dtype=np.float32, axis: int = -1):
+ def masked_logsoftmax(att_score, mask, axis: int = -1):
"""Ignore the masked elements when calculating the softmax. The mask can be broadcastable.
Parameters
@@ -298,96 +282,20 @@ def masked_logsoftmax(att_score, mask, dtype=np.float32, axis: int = -1):
Shape (..., length, ...)
mask = 1 --> not masked
mask = 0 --> masked
- dtype
- data type
axis
The axis to calculate the softmax. att_score.shape[axis] must be the same as mask.shape[axis]
Returns
-------
logits : Symbol or NDArray
Shape (..., length, ...)
The masked values will be set to -inf
"""
- if mask is not None:
- # Fill in the masked scores with a very small value
- neg = -1e18
- if _np.dtype(dtype) == np.float16:
- neg = -1e4
- else:
- try:
- # if AMP (automatic mixed precision) is enabled, -1e18 will cause NaN.
- from mxnet import amp
- if amp.amp._amp_initialized:
- neg = -1e4
- except ImportError:
- pass
- att_score = np.where(mask, att_score, neg)
- logits = np.where(mask, npx.log_softmax(att_score, axis=axis), -np.inf)
- else:
- logits = npx.log_softmax(att_score, axis=axis)
- return logits


- # TODO(sxjscience) Default to einsum. Current it is not the default because
- # 1) einsum is super-slow: https://github.com/apache/incubator-mxnet/issues/18043
- def dot_attn_score(query, key, scaled=True, normalized=False, eps=1E-6,
- layout='NT'):
- """The inner function call to calculate the score used in dot-product attention.
- We support multiple leading batch dimensions.
- scaled is True:
- D(h_q, h_k) = <h_q, h_k> / sqrt(dim_q)
- normalized is True:
- D(h_q, h_k) = <h_q / ||h_q||, h_k / ||h_k||>
- both scaled and normalized:
- D(h_q, h_k) = <h_q / ||h_q||, h_k / ||h_k||> / sqrt(dim_q)
- Parameters
- ----------
- query : symbol or ndarray
- - layout is 'NT'
- (B0, ..., BN, query_length, query_dim)
- - layout is 'TN'
- (query_length, B0, ..., BN, query_dim)
- key : symbol or ndarray
- - layout is 'NT'
- (B0, ..., BN, key_length, key_dim)
- - layout is 'TN'
- (key_length, B0, ..., BN, key_dim)
- scaled : bool
- Whether to divide the query by the square-root of the query_dim
- If True: D(h_q, h_k) = <h_q, h_k> / sqrt(dim_q)
- normalized : bool
- Whether to normalize the query and the key embeddings
- If True: D(h_q, h_k) = <h_q / ||h_q||, h_k / ||h_k||>
- eps : float
- The epsilon used in the normalization
- layout
- The layout of the layer. Can be 'TN' or 'NT'.
- Returns
- -------
- scores : symbol or ndarray
- (B0, ..., BN, query_length, key_length)
- """
- if normalized:
- query = l2_normalize(query, -1, eps=eps)
- key = l2_normalize(key, -1, eps=eps)
- if scaled:
- query_shape = npx.shape_array(query)
- # TODO(sxjscience) Remove .astype(np.float32).
- # Wait for https://github.com/apache/incubator-mxnet/issues/18084
- query_units = query_shape[-1].astype(np.float32)
- query = query / np.sqrt(query_units)
- if layout == 'NT':
- scores = npx.batch_dot(query, key, transpose_b=True)
+ if mask is None:
+ return npx.log_softmax(att_score, axis=axis)
else:
- raise NotImplementedError('layout={} is not supported.'
- ' Currently, only layout = "NT" is implemented!'.format(layout))
- return scores
+ mask = mask.astype(np.bool)
+ return np.where(mask, npx.masked_log_softmax(att_score, mask, axis=axis), -np.inf)

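A companion sketch for `masked_logsoftmax` under the same numpy-mode assumption (again, not part of the diff): masked positions come back as `-inf`, so exponentiating the result reproduces `masked_softmax`.

```python
from mxnet import np, npx
from gluonnlp.attention_cell import masked_logsoftmax, masked_softmax

npx.set_np()

scores = np.array([[0.5, 1.5, -2.0]])
mask = np.array([[1.0, 1.0, 0.0]])  # the helper casts the mask to bool internally

log_weights = masked_logsoftmax(scores, mask, axis=-1)
print(log_weights)          # the masked position is -inf
print(np.exp(log_weights))  # matches masked_softmax(scores, mask, axis=-1)
```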

def multi_head_dot_attn(query, key, value,
@@ -397,8 +305,7 @@ def multi_head_dot_attn(query, key, value,
scaled: bool = True, normalized: bool = False,
eps: float = 1E-6, query_head_units: Optional[int] = None,
layout: str = 'NKT',
- use_einsum: bool = False,
- dtype=np.float32):
+ use_einsum: bool = False):
"""Multihead dot product attention between the query, key, value.
scaled is False, normalized is False:
@@ -488,8 +395,7 @@ def multi_head_dot_attn(query, key, value,
key = l2_normalize(key, axis=-1, eps=eps)
if scaled:
if query_head_units is None:
- query_shape = npx.shape_array(query)
- scale = np.sqrt(query_shape[-1])
+ raise NotImplementedError('You will need to specify query_head_units!')
else:
scale = math.sqrt(query_head_units)
else:
@@ -498,15 +404,13 @@ def multi_head_dot_attn(query, key, value,
# 1. Expand the dimension of the mask:
# (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
if mask is not None:
- mask = np.expand_dims(mask, axis=1)
+ mask = np.expand_dims(mask, axis=1).astype(np.bool)
# 2. Calculate the attention weights
# Score: (B, N, L_query, C_Q) X (B, N, L_mem, C_Q) --> (B, N, L_query, L_mem)
scores = npx.batch_dot(query, key, transpose_b=True)
if edge_scores is not None:
scores = scores + edge_scores
- if scaled:
- scores = scores / scale
- attn_weights = masked_softmax(scores, mask, dtype=dtype, axis=-1)
+ attn_weights = masked_softmax(scores, mask, axis=-1, temperature=scale)
attn_weights = npx.dropout(attn_weights, p=dropout)
# 3. Calculate the context vector
# (B, N, L_query, L_mem) X (B, N, L_mem, C_V) --> (B, L_query, N * C_V)
@@ -519,19 +423,17 @@ def multi_head_dot_attn(query, key, value,
# 1. Expand the dimension of the mask:
# (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
if mask is not None:
- mask = np.expand_dims(mask, axis=1)
+ mask = np.expand_dims(mask, axis=1).astype(np.bool)
# 2. Calculate the attention weights
# Score: (B, L_query, N, C_Q) X (B, L_mem, N, C_Q) --> (B, N, L_query, L_mem)
if use_einsum:
scores = np.einsum('binc,bjnc->bnij', query, key)
else:
scores = npx.batch_dot(np.swapaxes(query, 1, 2), np.swapaxes(key, 1, 2),
- transpose_b=True)
+ transpose_b=True)
if edge_scores is not None:
scores = scores + edge_scores
- if scaled:
- scores = scores / scale
- attn_weights = masked_softmax(scores, mask, dtype=dtype)
+ attn_weights = masked_softmax(scores, mask, axis=-1, temperature=scale)
attn_weights = npx.dropout(attn_weights, p=dropout)
# 3. Calculate the context vector
# (B, N, L_query, L_mem) X (B, L_mem, N, C_V) --> (B, L_query, N * C_V)
@@ -545,7 +447,7 @@ def multi_head_dot_attn(query, key, value,
# 1. Expand the dimension of the mask:
# (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
if mask is not None:
- mask = np.expand_dims(mask, axis=1)
+ mask = np.expand_dims(mask, axis=1).astype(np.bool)
# 2. Calculate the attention weights
# Score: (L_query, B, N, C_Q) X (L_mem, B, N, C_Q) --> (B, N, L_query, L_mem)
# This layout structure can be implemented very efficiently because B, N are consecutive
@@ -560,9 +462,7 @@ def multi_head_dot_attn(query, key, value,
key.transpose((1, 2, 3, 0)))
if edge_scores is not None:
scores = scores + edge_scores
- if scaled:
- scores = scores / scale
- attn_weights = masked_softmax(scores, mask, dtype=dtype)
+ attn_weights = masked_softmax(scores, mask, axis=-1, temperature=scale)
attn_weights = npx.dropout(attn_weights, p=dropout)
# 3. Calculate the context vector
# (B, N, L_query, L_mem) X (L_mem, B, N, C_V) --> (L_query, B, N * C_V)
@@ -641,8 +541,7 @@ def forward(self, query, key, value, mask=None, edge_scores=None):
scaled=self._scaled, normalized=self._normalized,
eps=self._eps,
query_head_units=self._query_head_units,
- layout=self._layout, use_einsum=self._use_einsum,
- dtype=self._dtype)
+ layout=self._layout, use_einsum=self._use_einsum)

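To make the updated contract concrete, here is a shape-level sketch (not from this commit) of calling the functional `multi_head_dot_attn` directly with the default `'NKT'` layout. It assumes the parameters elided above (`mask`, `edge_scores`, `dropout`) keep their usual defaults; the behavioural change is that with `scaled=True` you must now pass `query_head_units` yourself, and the `1/sqrt(query_head_units)` factor is applied through `masked_softmax`'s temperature instead of by rescaling the scores.

```python
from mxnet import np, npx
from gluonnlp.attention_cell import multi_head_dot_attn

npx.set_np()

B, N, L_q, L_m, C_Q, C_V = 2, 4, 5, 7, 16, 16
query = np.random.normal(size=(B, N, L_q, C_Q))  # layout 'NKT': (batch, heads, length, channels)
key = np.random.normal(size=(B, N, L_m, C_Q))
value = np.random.normal(size=(B, N, L_m, C_V))
mask = np.ones((B, L_q, L_m))                    # broadcast across heads inside the call

out = multi_head_dot_attn(query, key, value, mask=mask,
                          query_head_units=C_Q, layout='NKT', use_einsum=False)

# Omitting query_head_units while scaled=True now raises NotImplementedError
# instead of silently falling back to npx.shape_array.
```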
def __repr__(self):
s = '{name}(\n' \
