This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

[Numpy] Add "match_tokens_with_char_spans" + Enable downloading from S3 + Add Ubuntu test #1249

Merged 18 commits on Jun 16, 2020
2 changes: 0 additions & 2 deletions .coveragerc
@@ -2,8 +2,6 @@
[run]
omit =
tests/*
conda/*
scripts/tests/*
scripts/*
concurrency =
multiprocessing
6 changes: 3 additions & 3 deletions .github/workflows/unittests.yml
@@ -12,8 +12,8 @@ jobs:
strategy:
fail-fast: false
matrix:
# TODO Add ubuntu test by "ubuntu-latest", Add windows test by using "windows-latest"
os: [macos-latest]
# TODO Add windows test by using "windows-latest"
os: [macos-latest, ubuntu-latest]
python-version: [ '3.6', '3.7', '3.8']
steps:
- name: Checkout repository
@@ -35,7 +35,7 @@ jobs:
python -m pip install --user --upgrade pip
python -m pip install --user setuptools pytest pytest-cov
python -m pip install --upgrade cython
python -m pip install --pre --user mxnet==2.0.0b20200604 -f https://dist.mxnet.io/python
python -m pip install --pre --user mxnet>=2.0.0b20200604 -f https://dist.mxnet.io/python
python -m pip install --user -e .[extras]
- name: Test project
run: |
4 changes: 2 additions & 2 deletions README.md
@@ -21,10 +21,10 @@ First of all, install the latest MXNet. You may use the following commands:
```bash

# Install the version with CUDA 10.1
pip install -U --pre mxnet-cu101==2.0.0b20200604 -f https://dist.mxnet.io/python
pip install -U --pre mxnet-cu101>=2.0.0b20200604 -f https://dist.mxnet.io/python

# Install the cpu-only version
pip install -U --pre mxnet==2.0.0b20200604 -f https://dist.mxnet.io/python
pip install -U --pre mxnet>=2.0.0b20200604 -f https://dist.mxnet.io/python
```


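The README change above relaxes the pinned nightly build (`==2.0.0b20200604`) to a minimum-version constraint (`>=2.0.0b20200604`). A quick post-install sanity check (a minimal sketch, not part of the diff) that the MXNet 2.0 NumPy interface used by GluonNLP is available:

```python
# Minimal sanity check after installing a nightly MXNet 2.0 build (not part of this PR).
import mxnet as mx

print(mx.__version__)        # should report a 2.0.0 nightly build
x = mx.np.ones((2, 3))       # the NumPy-compatible `mx.np` interface used throughout GluonNLP
print(x.sum())               # expected: 6.0
```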
10 changes: 5 additions & 5 deletions scripts/datasets/README.md
@@ -10,7 +10,7 @@ Thus, the typical workflow for running experiments:

- Download and prepare data with scripts in [datasets](.).
In case you will need to preprocess the dataset, there are toolkits in [preprocess](../preprocess).
- Run the experiments in [scripts](../scripts)
- Run the experiments in [scripts](..)


## Available Datasets
@@ -24,16 +24,16 @@ In case you will need to preprocess the dataset, there are toolkits in [preproce
- [Text8](./language_modeling)
- [Enwiki8](./language_modeling)
- [Google Billion Words](./language_modeling)
- [Music Generation](TBA)
- [Music Generation](./music_generation)
- [LakhMIDI](./music_generation/README.md#lakh-midi)
- [MAESTRO](./music_generation/README.md#maestro)
- [Pretraining Corpus](./pretrain_corpus)
- [Wikipedia](./pretrain_corpus/README.md#wikipedia)
- [BookCorpus](./pretrain_corpus/README.md#bookcorpus)
- [OpenWebText](./pretrain_corpus/README.md#openwebtext)
- [General NLP Benchmarks](./general_benchmarks)
- [GLUE](./general_benchmarks/README.md#glue-benchmark)
- [SuperGLUE](./general_benchmarks/README.md#superglue-benchmark)
- [General NLP Benchmarks](./general_nlp_benchmark)
- [GLUE](./general_nlp_benchmark/README.md#glue-benchmark)
- [SuperGLUE](./general_nlp_benchmark/README.md#superglue-benchmark)

## Contribution Guide

29 changes: 10 additions & 19 deletions scripts/question_answering/squad_utils.py
@@ -2,7 +2,6 @@
from typing import Optional, List
from collections import namedtuple
import itertools
import bisect
import re
import numpy as np
import numpy.ma as ma
@@ -12,6 +11,7 @@
import json
import string
from gluonnlp.data.tokenizers import BaseTokenizerWithVocab
from gluonnlp.utils.preprocessing import match_tokens_with_char_spans
from typing import Tuple
from mxnet.gluon.utils import download

@@ -389,34 +389,25 @@ def convert_squad_example_to_feature(example: SquadExample,
gt_span_start_pos, gt_span_end_pos = None, None
token_answer_mismatch = False
unreliable_span = False
np_offsets = np.array(offsets)
if is_training and not example.is_impossible:
assert example.start_position >= 0 and example.end_position >= 0
# From the offsets, we locate the first offset that contains start_pos and the last offset
# that contains end_pos, i.e.
# offsets[lower_idx][0] <= start_pos < offsets[lower_idx][1]
# offsets[upper_idx][0] < end_pos <= offsets[upper_idx][1]
# We convert the character-level offsets to token-level offsets
# Also, if the answer after tokenization + detokenization is not the same as the original
# answer,
offsets_lower = [offset[0] for offset in offsets]
offsets_upper = [offset[1] for offset in offsets]
# answer, we try to localize the answer text and do a rematch
candidates = [(example.start_position, example.end_position)]
all_possible_start_pos = {example.start_position}
find_all_candidates = False
lower_idx, upper_idx = None, None
first_lower_idx, first_upper_idx = None, None
while len(candidates) > 0:
start_position, end_position = candidates.pop()
if end_position > offsets_upper[-1] or start_position < offsets_lower[0]:
# Detect the out-of-boundary case
warnings.warn('The selected answer is not covered by the tokens! '
'Use the end_position. '
'qas_id={}, context_text={}, start_pos={}, end_pos={}, '
'offsets={}'.format(example.qas_id, context_text,
start_position, end_position, offsets))
end_position = min(offsets_upper[-1], end_position)
start_position = max(offsets_upper[0], start_position)
lower_idx = bisect.bisect(offsets_lower, start_position) - 1
upper_idx = bisect.bisect_left(offsets_upper, end_position)
# Match the token offsets
token_start_ends = match_tokens_with_char_spans(np_offsets,
np.array([[start_position,
end_position]]))
lower_idx = int(token_start_ends[0][0])
upper_idx = int(token_start_ends[0][1])
if not find_all_candidates:
first_lower_idx = lower_idx
first_upper_idx = upper_idx
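The hunk above replaces the manual `bisect` search over `offsets_lower`/`offsets_upper` with the new `match_tokens_with_char_spans` utility. A minimal sketch of the conversion it performs, using hypothetical token offsets and answer positions rather than values from the diff:

```python
import numpy as np
from gluonnlp.utils.preprocessing import match_tokens_with_char_spans

# Hypothetical character offsets from an offsets-aware tokenizer and a
# hypothetical character-level answer span (start_position, end_position).
offsets = [(0, 3), (4, 9), (10, 16)]
start_position, end_position = 4, 16

token_start_ends = match_tokens_with_char_spans(
    np.array(offsets), np.array([[start_position, end_position]]))
lower_idx = int(token_start_ends[0][0])  # first token covering the answer
upper_idx = int(token_start_ends[0][1])  # last token covering the answer (inclusive)
assert (lower_idx, upper_idx) == (1, 2)
```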
2 changes: 1 addition & 1 deletion src/gluonnlp/utils/__init__.py
@@ -1,7 +1,7 @@
from . import config
from . import lazy_imports
from . import misc
from . import preprocessing
from . import registry
from . import testing
from .parameter import *
from .misc import *
15 changes: 14 additions & 1 deletion src/gluonnlp/utils/lazy_imports.py
@@ -23,7 +23,8 @@
'try_import_scipy',
'try_import_mwparserfromhell',
'try_import_fasttext',
'try_import_langid']
'try_import_langid',
'try_import_boto3']


def try_import_sentencepiece():
@@ -132,3 +133,15 @@ def try_import_langid():
raise ImportError('"langid" is not installed. You must install langid in order to use the'
' functionality. You may try to use `pip install langid`.')
return langid


def try_import_boto3():
try:
import boto3
except ImportError:
raise ImportError('"boto3" is not installed. To enable fast downloading in EC2. You should '
'install boto3 and correctly configure the S3. '
'See https://boto3.readthedocs.io/ for more information. '
'If you are using EC2, downloading from s3:// will '
'be multiple times faster than using the traditional http/https URL.')
return boto3
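`try_import_boto3` follows the same lazy-import pattern as the other helpers in this module. A minimal usage sketch, mirroring the call site added in `src/gluonnlp/utils/misc.py` below:

```python
from gluonnlp.utils.lazy_imports import try_import_boto3

# Raises an informative ImportError if boto3 is missing; otherwise returns the module.
boto3 = try_import_boto3()
s3 = boto3.resource('s3')  # requires AWS credentials to be configured
```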
65 changes: 48 additions & 17 deletions src/gluonnlp/utils/misc.py
@@ -16,10 +16,15 @@
import tqdm
except ImportError:
tqdm = None
from .lazy_imports import try_import_boto3
from mxnet.gluon.utils import shape_is_known, replace_file
from collections import OrderedDict
import glob as _glob


S3_PREFIX = 's3://'


def glob(url, separator=','):
"""Return a list of paths matching a pathname pattern.

@@ -396,6 +401,15 @@ def download(url: str,
fname
The file path of the downloaded file.
"""
is_s3 = url.startswith(S3_PREFIX)
if is_s3:
boto3 = try_import_boto3()
s3 = boto3.resource('s3')
components = url[len(S3_PREFIX):].split('/')
if len(components) < 2:
raise ValueError('Invalid S3 url. Received url={}'.format(url))
s3_bucket_name = components[0]
s3_key = '/'.join(components[1:])
if path is None:
fname = url.split('/')[-1]
# Empty filenames are invalid
@@ -424,23 +438,40 @@
# pylint: disable=W0703
try:
print('Downloading {} from {}...'.format(fname, url))
r = requests.get(url, stream=True, verify=verify_ssl)
if r.status_code != 200:
raise RuntimeError('Failed downloading url {}'.format(url))
# create uuid for temporary files
random_uuid = str(uuid.uuid4())
total_size = int(r.headers.get('content-length', 0))
chunk_size = 1024
if tqdm is not None:
t = tqdm.tqdm(total=total_size, unit='iB', unit_scale=True)
with open('{}.{}'.format(fname, random_uuid), 'wb') as f:
for chunk in r.iter_content(chunk_size=chunk_size):
if chunk: # filter out keep-alive new chunks
if tqdm is not None:
t.update(len(chunk))
f.write(chunk)
if tqdm is not None:
t.close()
if is_s3:
response = s3.meta.client.head_object(Bucket=s3_bucket_name,
Key=s3_key)
total_size = int(response.get('ContentLength', 0))
random_uuid = str(uuid.uuid4())
tmp_path = '{}.{}'.format(fname, random_uuid)
if tqdm is not None:
def hook(t_obj):
def inner(bytes_amount):
t_obj.update(bytes_amount)
return inner
with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True) as t:
s3.meta.client.download_file(s3_bucket_name, s3_key, tmp_path,
Callback=hook(t))
else:
s3.meta.client.download_file(s3_bucket_name, s3_key, tmp_path)
else:
r = requests.get(url, stream=True, verify=verify_ssl)
if r.status_code != 200:
raise RuntimeError('Failed downloading url {}'.format(url))
# create uuid for temporary files
random_uuid = str(uuid.uuid4())
total_size = int(r.headers.get('content-length', 0))
chunk_size = 1024
if tqdm is not None:
t = tqdm.tqdm(total=total_size, unit='iB', unit_scale=True)
with open('{}.{}'.format(fname, random_uuid), 'wb') as f:
for chunk in r.iter_content(chunk_size=chunk_size):
if chunk: # filter out keep-alive new chunks
if tqdm is not None:
t.update(len(chunk))
f.write(chunk)
if tqdm is not None:
t.close()
# if the target file exists (created by another process)
# and has the same hash as the target file,
# delete the temporary file
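With this change, `download` dispatches on the URL scheme: `s3://bucket/key` URLs are fetched through boto3 (with a tqdm progress callback when tqdm is available), while plain http/https URLs keep the existing `requests` path. A hedged usage sketch with a hypothetical bucket and key:

```python
from gluonnlp.utils.misc import download

# Existing behaviour: plain HTTPS download (hypothetical URL).
local_path = download('https://example.com/data/vocab.json', path='vocab.json')

# New behaviour: S3 download via boto3. The bucket name ('my-bucket') and key
# ('models/vocab.json') are parsed from the URL; requires boto3 and configured
# AWS credentials.
local_path = download('s3://my-bucket/models/vocab.json', path='vocab.json')
```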
59 changes: 59 additions & 0 deletions src/gluonnlp/utils/preprocessing.py
@@ -51,3 +51,62 @@ def get_trimmed_lengths(lengths: List[int],
return trimmed_lengths
else:
return np.minimum(lengths, max_length)


def match_tokens_with_char_spans(token_offsets: np.ndarray,
spans: np.ndarray) -> np.ndarray:
"""Match the span offsets with the character-level offsets.

For each span, we perform the following:

1: Cutoff the boundary

span[0] = max(span[0], token_offsets[0, 0])
span[1] = min(span[1], token_offsets[-1, 1])

2: Find start + end

We try to select the smallest number of tokens that cover the entity, i.e.,
we will find start + end such that tokens[start:end + 1] covers the span.

We will use the following algorithm:

For "start", we search for
token_offsets[start, 0] <= span[0] < token_offsets[start + 1, 0]

For "end", we search for:
token_offsets[end - 1, 1] < span[1] <= token_offsets[end, 1]

Parameters
----------
token_offsets
The offsets of the input tokens. Must be sorted.
That is, it will satisfy
1. token_offsets[i][0] <= token_offsets[i][1]
2. token_offsets[i][0] <= token_offsets[i + 1][0]
3. token_offsets[i][1] <= token_offsets[i + 1][1]
Shape (#num_tokens, 2)
spans
The character-level offsets (begin/end) of the selected spans.
Shape (#spans, 2)

Returns
-------
token_start_ends
The token-level starts and ends. The end index is inclusive, i.e., tokens[start:end + 1] covers the span.
Shape (#spans, 2)
"""
offsets_starts = token_offsets[:, 0]
offsets_ends = token_offsets[:, 1]
span_char_starts = spans[:, 0]
span_char_ends = spans[:, 1]

# Truncate the span
span_char_starts = np.maximum(offsets_starts[0], span_char_starts)
span_char_ends = np.minimum(offsets_ends[-1], span_char_ends)

# Search for valid start + end
span_token_starts = np.searchsorted(offsets_starts, span_char_starts, side='right') - 1
span_token_ends = np.searchsorted(offsets_ends, span_char_ends, side='left')
return np.concatenate((np.expand_dims(span_token_starts, axis=-1),
np.expand_dims(span_token_ends, axis=-1)), axis=-1)
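A small sketch of the boundary cutoff described in the docstring: spans that extend past the tokenized region are clipped to the first/last token offsets before matching (hypothetical offsets, not taken from the diff):

```python
import numpy as np
from gluonnlp.utils.preprocessing import match_tokens_with_char_spans

token_offsets = np.array([[5, 10], [11, 20], [21, 25]])
spans = np.array([[0, 8],     # starts before the first token -> start clipped to 5
                  [18, 99]])  # ends after the last token      -> end clipped to 25
print(match_tokens_with_char_spans(token_offsets, spans))
# Expected output:
# [[0 0]
#  [1 2]]
```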
42 changes: 21 additions & 21 deletions tests/test_models.py
@@ -11,25 +11,25 @@ def test_list_backbone_names():
assert len(list_backbone_names()) > 0


@pytest.mark.parametrize('name', list_backbone_names())
def test_get_backbone(name):
with tempfile.TemporaryDirectory() as root:
model_cls, cfg, tokenizer, local_params_path, _ = get_backbone(name, root=root)
net = model_cls.from_cfg(cfg)
net.load_parameters(local_params_path)
net.hybridize()
num_params, num_fixed_params = count_parameters(net.collect_params())
assert num_params > 0
def test_get_backbone():
for name in list_backbone_names():
with tempfile.TemporaryDirectory() as root:
model_cls, cfg, tokenizer, local_params_path, _ = get_backbone(name, root=root)
net = model_cls.from_cfg(cfg)
net.load_parameters(local_params_path)
net.hybridize()
num_params, num_fixed_params = count_parameters(net.collect_params())
assert num_params > 0

# Test for model export + save
batch_size = 1
sequence_length = 16
inputs = mx.np.random.randint(0, 10, (batch_size, sequence_length))
token_types = mx.np.random.randint(0, 2, (batch_size, sequence_length))
valid_length = mx.np.random.randint(1, 10, (batch_size,))
if 'roberta' in name or 'xlmr' in name:
out = net(inputs, valid_length)
else:
out = net(inputs, token_types, valid_length)
mx.npx.waitall()
net.export(os.path.join(root, 'model'))
# Test for model export + save
batch_size = 1
sequence_length = 4
inputs = mx.np.random.randint(0, 10, (batch_size, sequence_length))
token_types = mx.np.random.randint(0, 2, (batch_size, sequence_length))
valid_length = mx.np.random.randint(1, sequence_length, (batch_size,))
if 'roberta' in name or 'xlmr' in name:
out = net(inputs, valid_length)
else:
out = net(inputs, token_types, valid_length)
mx.npx.waitall()
net.export(os.path.join(root, 'model'))