Skip to content

Commit

Permalink
Squashed commit of the following:
Browse files Browse the repository at this point in the history
commit 232e0b6
Author: ZheyuYe <zheyu.ye1995@gmail.com>
Date:   Thu Jul 30 01:05:17 2020 +0800

    update

commit 995e5d7
Author: ZheyuYe <zheyu.ye1995@gmail.com>
Date:   Thu Jul 30 01:01:56 2020 +0800

    fix

commit 9623240
Author: ZheyuYe <zheyu.ye1995@gmail.com>
Date:   Thu Jul 30 00:52:17 2020 +0800

    fix

commit d9c4140
Author: ZheyuYe <zheyu.ye1995@gmail.com>
Date:   Wed Jul 29 23:07:10 2020 +0800

    fix transformer

commit e49fbe1
Author: ZheyuYe <zheyu.ye1995@gmail.com>
Date:   Wed Jul 29 22:18:12 2020 +0800

    update

commit 1f75b26
Author: ZheyuYe <zheyu.ye1995@gmail.com>
Date:   Wed Jul 29 22:04:08 2020 +0800

    test bart

commit 5bab516
Author: ZheyuYe <zheyu.ye1995@gmail.com>
Date:   Wed Jul 29 21:34:47 2020 +0800

    fix cfg

commit 6c62a29
Merge: 3366cf3 033214e
Author: ZheyuYe <zheyu.ye1995@gmail.com>
Date:   Wed Jul 29 21:33:10 2020 +0800

    Merge remote-tracking branch 'upstream/numpy' into bart

commit 033214e
Author: Xingjian Shi <xshiab@connect.ust.hk>
Date:   Wed Jul 29 00:36:57 2020 -0700

    [Numpy] Fix SQuAD + Fix GLUE downloading (dmlc#1280)

    * Update run_squad.py

    * Update run_squad.py

    * Update prepare_glue.py

commit 3c87457
Author: Xingjian Shi <xshiab@connect.ust.hk>
Date:   Tue Jul 28 18:03:21 2020 -0700

    Add layout + compute_layout support: TransformerNMT, BERT, ALBERT, ELECTRA, MobileBERT, RoBERTA, XLMR (dmlc#1258)

    * Add layout support

    * fix test

    * Update transformer.py

    * Update transformer.py

    * Update README.md

    * try to add set_layout

    * update test case

    * fix

    * update

    * update

    * update

    * Update bert.py

    * fix bug

    * update

    * Update test_models_bert.py

    * Update tokenizers.py

    * add compute layout

    * Update xlmr.py

    * Update test_models_bert.py

    * revise test cases

    * Update layers.py

    * move jieba to try import

    * fix

    * Update transformer.py

    * fix

    * Update bert.py

    * Update setup.py

    * Update test_models_bert.py

    * Update test_models_bert.py

    * fix

    * update

    * Revise

    * Update electra.py

    * Update electra.py

    * Update test_models_electra.py

    * fix

    * fix bug

    * Update test_models_albert.py

    * add more testcases

    * fix

    * Update albert.py

    * Update albert.py

    * fix bug

    * fix testcase

    * Update test_models_electra.py

    * Update bert.py

    * update

    * Update test_models_electra.py

    * Update mobilebert.py

    * Update mobilebert.py

    * update mobilebert

    * Update test_models_mobilebert.py

    * Update mobilebert.py

    * fix bug

    * Update roberta.py

    * fix roberta

    * update

    * update

    * fix import

    * fix bug

    * update

    * reduce test workloads

    * address comment

    * address comment

commit 4d43f82
Author: Sheng Zha <szha@users.noreply.github.com>
Date:   Mon Jul 27 20:21:00 2020 -0700

    add subversion/wget to docker, add readme (dmlc#1279)

commit d76897b
Author: phile <phile_999@126.com>
Date:   Tue Jul 28 10:10:13 2020 +0800

    Add embedding related methods in numpy version (dmlc#1263)

    * A draft for embedding

    * fix embed_loader

    * add hyperbolic space and some updates

    * revise evaluation

    * fix

    * simple fixes

    * move l2norm to op.py

    * new features

    * fix

    * update

    * add tests, update

    * newline
  • Loading branch information
zheyuye committed Jul 29, 2020
1 parent 3366cf3 commit a8853f9
Show file tree
Hide file tree
Showing 35 changed files with 3,880 additions and 901 deletions.
9 changes: 7 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,17 @@ This is a work-in-progress.
First of all, install the latest MXNet. You may use the following commands:

```bash
# Install the version with CUDA 10.0
pip install -U --pre "mxnet-cu100>=2.0.0b20200716" -f https://dist.mxnet.io/python

# Install the version with CUDA 10.1
pip install -U --pre mxnet-cu101>=2.0.0b20200716 -f https://dist.mxnet.io/python
pip install -U --pre "mxnet-cu101>=2.0.0b20200716" -f https://dist.mxnet.io/python

# Install the version with CUDA 10.2
pip install -U --pre "mxnet-cu102>=2.0.0b20200716" -f https://dist.mxnet.io/python

# Install the cpu-only version
pip install -U --pre mxnet>=2.0.0b20200716 -f https://dist.mxnet.io/python
pip install -U --pre "mxnet>=2.0.0b20200716" -f https://dist.mxnet.io/python
```


Expand Down
96 changes: 62 additions & 34 deletions scripts/datasets/general_nlp_benchmark/prepare_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,23 @@ def read_tsv_glue(tsv_file, num_skip=1, keep_column_names=False):
nrows = len(elements)
else:
assert nrows == len(elements)
return pd.DataFrame(out, columns=column_names)
df = pd.DataFrame(out, columns=column_names)
series_l = []
for col_name in df.columns:
idx = df[col_name].first_valid_index()
val = df[col_name][idx]
if isinstance(val, str):
try:
dat = pd.to_numeric(df[col_name])
series_l.append(dat)
continue
except ValueError:
pass
finally:
pass
series_l.append(df[col_name])
new_df = pd.DataFrame({name: series for name, series in zip(df.columns, series_l)})
return new_df


def read_jsonl_superglue(jsonl_file):
Expand Down Expand Up @@ -157,6 +173,13 @@ def read_sts(dir_path):
else:
df = df[[7, 8, 1, 9]]
df.columns = ['sentence1', 'sentence2', 'genre', 'score']
genre_l = []
for ele in df['genre'].tolist():
if ele == 'main-forum':
genre_l.append('main-forums')
else:
genre_l.append(ele)
df['genre'] = pd.Series(genre_l)
df_dict[fold] = df
return df_dict, None

Expand Down Expand Up @@ -320,8 +343,8 @@ def read_rte_superglue(dir_path):
def read_wic(dir_path):
df_dict = dict()
meta_data = dict()
meta_data['entities1'] = {'type': 'entity', 'parent': 'sentence1'}
meta_data['entities2'] = {'type': 'entity', 'parent': 'sentence2'}
meta_data['entities1'] = {'type': 'entity', 'attrs': {'parent': 'sentence1'}}
meta_data['entities2'] = {'type': 'entity', 'attrs': {'parent': 'sentence2'}}

for fold in ['train', 'val', 'test']:
if fold != 'test':
Expand All @@ -340,13 +363,13 @@ def read_wic(dir_path):
end2 = row['end2']
if fold == 'test':
out.append([sentence1, sentence2,
(start1, end1),
(start2, end2)])
{'start': start1, 'end': end1},
{'start': start2, 'end': end2}])
else:
label = row['label']
out.append([sentence1, sentence2,
(start1, end1),
(start2, end2),
{'start': start1, 'end': end1},
{'start': start2, 'end': end2},
label])
df = pd.DataFrame(out, columns=columns)
df_dict[fold] = df
Expand All @@ -357,8 +380,8 @@ def read_wsc(dir_path):
df_dict = dict()
tokenizer = WhitespaceTokenizer()
meta_data = dict()
meta_data['noun'] = {'type': 'entity', 'parent': 'text'}
meta_data['pronoun'] = {'type': 'entity', 'parent': 'text'}
meta_data['noun'] = {'type': 'entity', 'attrs': {'parent': 'text'}}
meta_data['pronoun'] = {'type': 'entity', 'attrs': {'parent': 'text'}}
for fold in ['train', 'val', 'test']:
jsonl_path = os.path.join(dir_path, '{}.jsonl'.format(fold))
df = read_jsonl_superglue(jsonl_path)
Expand All @@ -374,20 +397,20 @@ def read_wsc(dir_path):
span2_text = target['span2_text']
# Build entity
# list of entities
# 'entity': {'start': 0, 'end': 100}
# 'entities': {'start': 0, 'end': 100}
tokens, offsets = tokenizer.encode_with_offsets(text, str)
pos_start1 = offsets[span1_index][0]
pos_end1 = pos_start1 + len(span1_text)
pos_start2 = offsets[span2_index][0]
pos_end2 = pos_start2 + len(span2_text)
if fold == 'test':
samples.append({'text': text,
'noun': (pos_start1, pos_end1),
'pronoun': (pos_start2, pos_end2)})
'noun': {'start': pos_start1, 'end': pos_end1},
'pronoun': {'start': pos_start2, 'end': pos_end2}})
else:
samples.append({'text': text,
'noun': (pos_start1, pos_end1),
'pronoun': (pos_start2, pos_end2),
'noun': {'start': pos_start1, 'end': pos_end1},
'pronoun': {'start': pos_start2, 'end': pos_end2},
'label': label})
df = pd.DataFrame(samples)
df_dict[fold] = df
Expand All @@ -406,8 +429,8 @@ def read_boolq(dir_path):
def read_record(dir_path):
df_dict = dict()
meta_data = dict()
meta_data['entities'] = {'type': 'entity', 'parent': 'text'}
meta_data['answers'] = {'type': 'entity', 'parent': 'text'}
meta_data['entities'] = {'type': 'entity', 'attrs': {'parent': 'text'}}
meta_data['answers'] = {'type': 'entity', 'attrs': {'parent': 'text'}}
for fold in ['train', 'val', 'test']:
if fold != 'test':
columns = ['source', 'text', 'entities', 'query', 'answers']
Expand All @@ -422,15 +445,11 @@ def read_record(dir_path):
passage = row['passage']
text = passage['text']
entities = passage['entities']
entities = [(ele['start'], ele['end']) for ele in entities]
entities = [{'start': ele['start'], 'end': ele['end']} for ele in entities]
for qas in row['qas']:
query = qas['query']
if fold != 'test':
answer_entities = []
for answer in qas['answers']:
start = answer['start']
end = answer['end']
answer_entities.append((start, end))
answer_entities = qas['answers']
out.append((source, text, entities, query, answer_entities))
else:
out.append((source, text, entities, query))
Expand Down Expand Up @@ -518,11 +537,15 @@ def format_mrpc(data_dir):
os.makedirs(mrpc_dir, exist_ok=True)
mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt")
mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt")
download(GLUE_TASK2PATH["mrpc"]['train'], mrpc_train_file)
download(GLUE_TASK2PATH["mrpc"]['test'], mrpc_test_file)
download(GLUE_TASK2PATH["mrpc"]['train'], mrpc_train_file,
sha1_hash=_URL_FILE_STATS[GLUE_TASK2PATH["mrpc"]['train']])
download(GLUE_TASK2PATH["mrpc"]['test'], mrpc_test_file,
sha1_hash=_URL_FILE_STATS[GLUE_TASK2PATH["mrpc"]['test']])
assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file
assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file
download(GLUE_TASK2PATH["mrpc"]['dev'], os.path.join(mrpc_dir, "dev_ids.tsv"))
download(GLUE_TASK2PATH["mrpc"]['dev'],
os.path.join(mrpc_dir, "dev_ids.tsv"),
sha1_hash=_URL_FILE_STATS[GLUE_TASK2PATH["mrpc"]['dev']])

dev_ids = []
with open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding="utf8") as ids_fh:
Expand Down Expand Up @@ -575,7 +598,7 @@ def get_tasks(benchmark, task_names):
@DATA_PARSER_REGISTRY.register('prepare_glue')
def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--benchmark", choices=['glue', 'superglue', 'sts'],
parser.add_argument("--benchmark", choices=['glue', 'superglue'],
default='glue', type=str)
parser.add_argument("-d", "--data_dir", help="directory to save data to", type=str,
default=None)
Expand Down Expand Up @@ -618,39 +641,44 @@ def main(args):
base_dir = os.path.join(args.data_dir, 'rte_diagnostic')
os.makedirs(base_dir, exist_ok=True)
download(TASK2PATH['diagnostic'][0],
path=os.path.join(base_dir, 'diagnostic.tsv'))
path=os.path.join(base_dir, 'diagnostic.tsv'),
sha1_hash=_URL_FILE_STATS[TASK2PATH['diagnostic'][0]])
download(TASK2PATH['diagnostic'][1],
path=os.path.join(base_dir, 'diagnostic-full.tsv'))
path=os.path.join(base_dir, 'diagnostic-full.tsv'),
sha1_hash=_URL_FILE_STATS[TASK2PATH['diagnostic'][1]])
df = reader(base_dir)
df.to_pickle(os.path.join(base_dir, 'diagnostic-full.pd.pkl'))
df.to_parquet(os.path.join(base_dir, 'diagnostic-full.parquet'))
else:
for key, name in [('broadcoverage-diagnostic', 'AX-b'),
('winogender-diagnostic', 'AX-g')]:
data_file = os.path.join(args.cache_path, "{}.zip".format(key))
url = TASK2PATH[key]
reader = TASK2READER[key]
download(url, data_file)
download(url, data_file, sha1_hash=_URL_FILE_STATS[url])
with zipfile.ZipFile(data_file) as zipdata:
zipdata.extractall(args.data_dir)
df = reader(os.path.join(args.data_dir, name))
df.to_pickle(os.path.join(args.data_dir, name, '{}.pd.pkl'.format(name)))
df.to_parquet(os.path.join(args.data_dir, name, '{}.parquet'.format(name)))
elif task == 'mrpc':
reader = TASK2READER[task]
format_mrpc(args.data_dir)
df_dict, meta_data = reader(os.path.join(args.data_dir, 'mrpc'))
for key, df in df_dict.items():
if key == 'val':
key = 'dev'
df.to_pickle(os.path.join(args.data_dir, 'mrpc', '{}.pd.pkl'.format(key)))
df.to_parquet(os.path.join(args.data_dir, 'mrpc', '{}.parquet'.format(key)))
with open(os.path.join(args.data_dir, 'mrpc', 'metadata.json'), 'w') as f:
json.dump(meta_data, f)
else:
# Download data
data_file = os.path.join(args.cache_path, "{}.zip".format(task))
url = TASK2PATH[task]
reader = TASK2READER[task]
download(url, data_file)
download(url, data_file, sha1_hash=_URL_FILE_STATS[url])
base_dir = os.path.join(args.data_dir, task)
if os.path.exists(base_dir):
print('Found!')
continue
zip_dir_name = None
with zipfile.ZipFile(data_file) as zipdata:
if zip_dir_name is None:
Expand All @@ -662,7 +690,7 @@ def main(args):
for key, df in df_dict.items():
if key == 'val':
key = 'dev'
df.to_pickle(os.path.join(base_dir, '{}.pd.pkl'.format(key)))
df.to_parquet(os.path.join(base_dir, '{}.parquet'.format(key)))
if meta_data is not None:
with open(os.path.join(base_dir, 'metadata.json'), 'w') as f:
json.dump(meta_data, f)
Expand Down
4 changes: 2 additions & 2 deletions scripts/question_answering/run_squad.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,8 +563,8 @@ def train(args):
segment_ids = sample.segment_ids.as_in_ctx(ctx) if use_segmentation else None
valid_length = sample.valid_length.as_in_ctx(ctx)
p_mask = sample.masks.as_in_ctx(ctx)
gt_start = sample.gt_start.as_in_ctx(ctx)
gt_end = sample.gt_end.as_in_ctx(ctx)
gt_start = sample.gt_start.as_in_ctx(ctx).astype(np.int32)
gt_end = sample.gt_end.as_in_ctx(ctx).astype(np.int32)
is_impossible = sample.is_impossible.as_in_ctx(ctx).astype(np.int32)
batch_idx = mx.np.arange(tokens.shape[0], dtype=np.int32, ctx=ctx)
p_mask = 1 - p_mask # In the network, we use 1 --> no_mask, 0 --> mask
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def find_version(*file_paths):
'scripts',
)),
package_dir={"": "src"},
package_data={'': [os.path.join('models', 'model_zoo_checksums', '*.txt'),
os.path.join('cli', 'data', 'url_checksums', '*.txt')]},
zip_safe=True,
include_package_data=True,
install_requires=requirements,
Expand Down
1 change: 1 addition & 0 deletions src/gluonnlp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@
from . import optimizer
from . import registry
from . import sequence_sampler
from . import embedding

0 comments on commit a8853f9

Please sign in to comment.