
Commit

Add layout + compute_layout support: TransformerNMT, BERT, ALBERT, ELECTRA, MobileBERT, RoBERTA, XLMR (dmlc#1258)

* Add layout support

* fix test

* Update transformer.py

* Update transformer.py

* Update README.md

* try to add set_layout

* update test case

* fix

* update

* update

* update

* Update bert.py

* fix bug

* update

* Update test_models_bert.py

* Update tokenizers.py

* add compute layout

* Update xlmr.py

* Update test_models_bert.py

* revise test cases

* Update layers.py

* move jieba to try import

* fix

* Update transformer.py

* fix

* Update bert.py

* Update setup.py

* Update test_models_bert.py

* Update test_models_bert.py

* fix

* update

* Revise

* Update electra.py

* Update electra.py

* Update test_models_electra.py

* fix

* fix bug

* Update test_models_albert.py

* add more testcases

* fix

* Update albert.py

* Update albert.py

* fix bug

* fix testcase

* Update test_models_electra.py

* Update bert.py

* update

* Update test_models_electra.py

* Update mobilebert.py

* Update mobilebert.py

* update mobilebert

* Update test_models_mobilebert.py

* Update mobilebert.py

* fix bug

* Update roberta.py

* fix roberta

* update

* update

* fix import

* fix bug

* update

* reduce test workloads

* address comment

* address comment
sxjscience committed Jul 29, 2020
1 parent 4d43f82 commit 3c87457
Showing 23 changed files with 2,280 additions and 765 deletions.
9 changes: 7 additions & 2 deletions README.md
@@ -19,12 +19,17 @@ This is a work-in-progress.
First of all, install the latest MXNet. You may use the following commands:

```bash
# Install the version with CUDA 10.0
pip install -U --pre "mxnet-cu100>=2.0.0b20200716" -f https://dist.mxnet.io/python

# Install the version with CUDA 10.1
pip install -U --pre mxnet-cu101>=2.0.0b20200716 -f https://dist.mxnet.io/python
pip install -U --pre "mxnet-cu101>=2.0.0b20200716" -f https://dist.mxnet.io/python

# Install the version with CUDA 10.2
pip install -U --pre "mxnet-cu102>=2.0.0b20200716" -f https://dist.mxnet.io/python

# Install the cpu-only version
pip install -U --pre mxnet>=2.0.0b20200716 -f https://dist.mxnet.io/python
pip install -U --pre "mxnet>=2.0.0b20200716" -f https://dist.mxnet.io/python
```


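The commands above install a nightly MXNet 2.0 wheel. As a purely illustrative sanity check (not part of this commit), you can confirm that the expected build and its numpy interface are present:

```python
# Illustrative only: verify the nightly MXNet build GluonNLP expects.
import mxnet as mx

print(mx.__version__)            # expect a 2.0.0 nightly, e.g. 2.0.0b20200716 or later
print(mx.np.ones((2, 3)).sum())  # the mx.np interface used throughout GluonNLP
```
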
3 changes: 1 addition & 2 deletions scripts/conversion_toolkits/README.md
@@ -75,8 +75,7 @@ Notice: pleas set up the `--electra_path` with the cloned path or get this electra

```bash
# Need to use TF 1.13.2 to use contrib layer
pip uninstall tensorflow
pip install tensorflow==1.13.2
pip install tensorflow==1.13.2 --upgrade --force-reinstall

# Actual conversion
bash convert_electra.sh
2 changes: 2 additions & 0 deletions setup.py
@@ -55,6 +55,8 @@ def find_version(*file_paths):
'scripts',
)),
package_dir={"": "src"},
package_data={'': [os.path.join('models', 'model_zoo_checksums', '*.txt'),
os.path.join('cli', 'data', 'url_checksums', '*.txt')]},
zip_safe=True,
include_package_data=True,
install_requires=requirements,
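The added `package_data` entry ships the model-zoo and URL checksum files inside the installed package. A hedged sketch for confirming that an installed copy actually contains them (the paths mirror the entry above):

```python
# Hedged sketch: list the checksum files shipped via the new package_data entry.
import os
import gluonnlp

checksum_dir = os.path.join(os.path.dirname(gluonnlp.__file__),
                            'models', 'model_zoo_checksums')
print(sorted(os.listdir(checksum_dir)))  # expect the *.txt checksum files
```
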
79 changes: 59 additions & 20 deletions src/gluonnlp/attention_cell.py
@@ -33,7 +33,8 @@
def gen_self_attn_mask(F, data,
valid_length=None,
dtype: type = np.float32,
attn_type: str = 'full'):
attn_type: str = 'full',
layout: str = 'NT'):
"""Generate the mask used for the encoder, i.e, self-attention.
In our implementation, 1 --> not masked, 0 --> masked
@@ -100,38 +101,50 @@ def gen_self_attn_mask(F, data,
Parameters
----------
F :
data :
The data. Shape (batch_size, seq_length, C)
valid_length :
F
data
The data.
- layout = 'NT'
Shape (batch_size, seq_length, C)
- layout = 'TN'
Shape (seq_length, batch_size, C)
valid_length
Shape (batch_size,)
dtype
Data type of the mask
attn_type : str
attn_type
Can be 'full' or 'causal'
layout
The layout of the data
Returns
-------
mask
Shape (batch_size, seq_length, seq_length)
"""
if layout == 'NT':
batch_axis, time_axis = 0, 1
elif layout == 'TN':
batch_axis, time_axis = 1, 0
else:
raise NotImplementedError('Unsupported layout={}'.format(layout))
if attn_type == 'full':
if valid_length is not None:
valid_length = valid_length.astype(dtype)
steps = F.npx.arange_like(data, axis=1) # (seq_length,)
steps = F.npx.arange_like(data, axis=time_axis) # (seq_length,)
mask1 = (F.npx.reshape(steps, (1, 1, -1))
< F.npx.reshape(valid_length, (-2, 1, 1)))
mask2 = (F.npx.reshape(steps, (1, -1, 1))
< F.npx.reshape(valid_length, (-2, 1, 1)))
mask = mask1 * mask2
else:
# TODO(sxjscience) optimize
seq_len_ones = F.np.ones_like(F.npx.arange_like(data, axis=1)) # (seq_length,)
batch_ones = F.np.ones_like(F.npx.arange_like(data, axis=0)) # (batch_size,)
seq_len_ones = F.np.ones_like(F.npx.arange_like(data, axis=time_axis)) # (seq_length,)
batch_ones = F.np.ones_like(F.npx.arange_like(data, axis=batch_axis)) # (batch_size,)
mask = batch_ones.reshape((-1, 1, 1)) * seq_len_ones.reshape((1, -1, 1))\
* seq_len_ones.reshape((1, 1, -1))
elif attn_type == 'causal':
steps = F.npx.arange_like(data, axis=1)
steps = F.npx.arange_like(data, axis=time_axis)
# mask: (seq_length, seq_length)
# batch_mask: (batch_size, seq_length)
mask = (F.np.expand_dims(steps, axis=0) <= F.np.expand_dims(steps, axis=1)).astype(dtype)
@@ -140,15 +153,17 @@ def gen_self_attn_mask(F, data,
batch_mask = (F.np.expand_dims(steps, axis=0) < F.np.expand_dims(valid_length, axis=-1)).astype(dtype)
mask = mask * F.np.expand_dims(batch_mask, axis=-1)
else:
batch_ones = F.np.ones_like(F.npx.arange_like(data, axis=0), dtype=np.float32) # (batch_size,)
batch_ones = F.np.ones_like(F.npx.arange_like(data, axis=batch_axis),
dtype=dtype) # (batch_size,)
mask = mask * batch_ones.reshape((-1, 1, 1))
else:
raise NotImplementedError
mask = mask.astype(dtype)
return mask


def gen_mem_attn_mask(F, mem, mem_valid_length, data, data_valid_length=None, dtype=np.float32):
def gen_mem_attn_mask(F, mem, mem_valid_length, data, data_valid_length=None,
dtype=np.float32, layout: str = 'NT'):
"""Generate the mask used for the decoder. All query slots are attended to the memory slots.
In our implementation, 1 --> not masked, 0 --> masked
@@ -183,34 +198,48 @@ def gen_mem_attn_mask(F, mem, mem_valid_length, data, data_valid_length=None, dt
Parameters
----------
F :
mem :
Shape (batch_size, mem_length, C_mem)
mem
- layout = 'NT'
Shape (batch_size, mem_length, C_mem)
- layout = 'TN'
Shape (mem_length, batch_size, C_mem)
mem_valid_length :
Shape (batch_size,)
data :
Shape (batch_size, query_length, C_data)
data
- layout = 'NT'
Shape (batch_size, query_length, C_data)
- layout = 'TN'
Shape (query_length, batch_size, C_data)
data_valid_length :
Shape (batch_size,)
dtype : type
dtype
Data type of the mask
layout
Layout of the data + mem tensor
Returns
-------
mask :
Shape (batch_size, query_length, mem_length)
"""
if layout == 'NT':
batch_axis, time_axis = 0, 1
elif layout == 'TN':
batch_axis, time_axis = 1, 0
else:
raise NotImplementedError('Unsupported layout={}'.format(layout))
mem_valid_length = mem_valid_length.astype(dtype)
mem_steps = F.npx.arange_like(mem, axis=1) # (mem_length,)
mem_steps = F.npx.arange_like(mem, axis=time_axis) # (mem_length,)
data_steps = F.npx.arange_like(data, axis=time_axis) # (query_length,)
mem_mask = (F.npx.reshape(mem_steps, (1, 1, -1))
< F.npx.reshape(mem_valid_length, (-2, 1, 1))).astype(dtype) # (B, 1, mem_length)
if data_valid_length is not None:
data_valid_length = data_valid_length.astype(dtype)
data_steps = F.npx.arange_like(data, axis=1) # (query_length,)
data_mask = (F.npx.reshape(data_steps, (1, -1, 1))
< F.npx.reshape(data_valid_length, (-2, 1, 1))).astype(dtype) # (B, query_length, 1)
mask = mem_mask * data_mask
else:
query_length_ones = F.np.ones_like(F.npx.arange_like(data, axis=1)) # (query_length,)
query_length_ones = F.np.ones_like(data_steps)
mask = query_length_ones.reshape((1, -1, 1)) * mem_mask
return mask

@@ -594,6 +623,7 @@ def __init__(self, query_units=None, num_heads=None, attention_dropout=0.0,
self._normalized = normalized
self._eps = eps
self._dtype = dtype
assert layout in ['NTK', 'NKT', 'TNK']
self._layout = layout
self._use_einsum = use_einsum
if self._query_units is not None:
@@ -604,6 +634,10 @@
else:
self._query_head_units = None

@property
def layout(self):
return self._layout

def hybrid_forward(self, F, query, key, value, mask=None, edge_scores=None):
return multi_head_dot_attn(F, query=query, key=key, value=value,
mask=mask, edge_scores=edge_scores,
@@ -764,6 +798,11 @@ def __init__(self, query_units,
else:
raise NotImplementedError('method="{}" is currently not supported!'.format(method))

@property
def layout(self) -> str:
"""Layout of the cell"""
return self._layout

def hybrid_forward(self, F, rel_positions, query=None):
"""
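Both mask helpers in `attention_cell.py` now take a `layout` argument ('NT' or 'TN') while still returning masks of shape (batch_size, seq_length, seq_length) and (batch_size, query_length, mem_length), and the attention cells expose their layout through a read-only `layout` property. A minimal sketch of the layout equivalence, assuming the helper can be called eagerly with `F` set to the mxnet module (only `mx.np` / `mx.npx` are used internally):

```python
# Sketch only: the same mask should come out of either layout.
import mxnet as mx
import numpy as np
from gluonnlp.attention_cell import gen_self_attn_mask

batch_size, seq_length, units = 2, 5, 4
data_nt = mx.np.random.normal(0, 1, (batch_size, seq_length, units))
valid_length = mx.np.array([3, 5], dtype=np.int32)

mask_nt = gen_self_attn_mask(mx, data_nt, valid_length,
                             attn_type='full', layout='NT')
mask_tn = gen_self_attn_mask(mx, data_nt.transpose((1, 0, 2)), valid_length,
                             attn_type='full', layout='TN')
assert mask_nt.shape == (batch_size, seq_length, seq_length)
assert float(mx.np.abs(mask_nt - mask_tn).sum()) == 0.0
```
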
20 changes: 10 additions & 10 deletions src/gluonnlp/data/tokenizers.py
@@ -26,21 +26,20 @@
import json
from collections import OrderedDict
import abc
import sys
import warnings
import itertools
from typing import NewType
import sacremoses
import jieba
from uuid import uuid4
from .vocab import Vocab
from ..registry import TOKENIZER_REGISTRY
from ..utils.lazy_imports import try_import_subword_nmt, \
try_import_sentencepiece, \
try_import_huggingface_tokenizers, \
try_import_yttm, \
try_import_spacy, \
try_import_jieba
from ..utils.lazy_imports import try_import_subword_nmt,\
try_import_sentencepiece,\
try_import_huggingface_tokenizers,\
try_import_yttm,\
try_import_spacy,\
try_import_jieba


SentencesType = NewType('SentencesType', Union[str, List[str]])
TokensType = NewType('TokensType', Union[List[str], List[List[str]]])
@@ -553,10 +552,10 @@ class JiebaTokenizer(BaseTokenizerWithVocab):
"""

def __init__(self, ditionary=None, vocab: Optional[Vocab] = None):
def __init__(self, dictionary=None, vocab: Optional[Vocab] = None):
self._vocab = vocab
jieba = try_import_jieba()
self._tokenizer = jieba.Tokenizer(ditionary)
self._tokenizer = jieba.Tokenizer(dictionary)
self._tokenizer.initialize(self._tokenizer.dictionary)

def encode(self, sentences, output_type=str):
@@ -626,6 +625,7 @@ def __getstate__(self):
return d

def __setstate__(self, state):
jieba = try_import_jieba()
self._tokenizer = jieba.Tokenizer()
for k, v in state.items():
setattr(self._tokenizer, k, v)
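Alongside the layout changes elsewhere in the commit, this file renames the misspelled `ditionary` argument of `JiebaTokenizer` to `dictionary` and defers the `jieba` import to `try_import_jieba`. A hedged usage sketch (requires `jieba` to be installed):

```python
# Sketch only: tokenize Chinese text with the renamed `dictionary` argument.
from gluonnlp.data.tokenizers import JiebaTokenizer

tokenizer = JiebaTokenizer(dictionary=None)  # None falls back to jieba's default dictionary
print(tokenizer.encode('苹果公司于1976年由史蒂夫·乔布斯创立', output_type=str))
```
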
5 changes: 3 additions & 2 deletions src/gluonnlp/layers.py
@@ -356,9 +356,10 @@ def __init__(self, mode='erf'):

def hybrid_forward(self, F, x):
if self._mode == 'erf':
return x * 0.5 * (1.0 + F.npx.erf(x / math.sqrt(2.0)))
return F.npx.leaky_relu(x, act_type='gelu')
elif self._mode == 'tanh':
return 0.5 * x * (1.0 + F.np.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * (x ** 3))))
return 0.5 * x\
* (1.0 + F.np.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * (x ** 3))))
elif self._mode == 'sigmoid':
return x * F.npx.sigmoid(1.702 * x)
else:
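In `layers.py`, the 'erf' branch of GELU now calls MXNet's fused gelu activation (`npx.leaky_relu(..., act_type='gelu')`) instead of spelling out the erf formula, while the 'tanh' branch keeps the explicit approximation. For reference, a standalone NumPy sketch of the two formulas (illustrative, not the library code):

```python
# Sketch only: exact (erf-based) GELU vs. the tanh approximation used by mode='tanh'.
import math
import numpy as np

def gelu_erf(x):
    return np.array([v * 0.5 * (1.0 + math.erf(v / math.sqrt(2.0))) for v in x])

def gelu_tanh(x):
    return 0.5 * x * (1.0 + np.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

x = np.linspace(-3.0, 3.0, 13)
print(np.max(np.abs(gelu_erf(x) - gelu_tanh(x))))  # the two agree to within ~1e-3
```
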
(Diffs for the remaining 17 changed files are not shown.)