Adding Sentence Order Prediction #1061

Merged
183 commits merged on Apr 24, 2020
430f942
misc run scripts
phu-pmh Oct 30, 2019
39603c3
sbatch
phu-pmh Oct 31, 2019
9b324f9
sweep scripts
phu-pmh Nov 4, 2019
d3cc769
Merge branch 'master' of https://github.com/nyu-mll/jiant
phu-pmh Nov 5, 2019
00bc40c
Merge branch 'master' of https://github.com/nyu-mll/jiant
phu-pmh Nov 9, 2019
4e297b1
update
phu-pmh Nov 9, 2019
b75d0f5
qa
phu-pmh Nov 10, 2019
1aadf48
update
phu-pmh Nov 10, 2019
8993b9e
Merge branch 'master' of https://github.com/nyu-mll/jiant
phu-pmh Nov 10, 2019
a3f10e2
update
phu-pmh Nov 13, 2019
aa0d8b4
update
phu-pmh Nov 13, 2019
275d7a3
Merge branch 'master' of https://github.com/nyu-mll/jiant
phu-pmh Nov 13, 2019
4b6b939
Merge branch 'master' of https://github.com/nyu-mll/jiant
phu-pmh Nov 16, 2019
7252ea5
update
phu-pmh Nov 16, 2019
f0d9c56
update
phu-pmh Nov 20, 2019
00223c6
Merge branch 'master' of https://github.com/nyu-mll/jiant
phu-pmh Nov 27, 2019
b0a8ec3
sb file
phu-pmh Dec 12, 2019
c4d2601
moving update_metrics to outside scope of dataparallel
Jan 14, 2020
acb9d24
fixing micro_avg calculation
Jan 16, 2020
8bdec95
undo debugging
Jan 16, 2020
0d879b1
Merge branch 'master' of https://github.com/nyu-mll/jiant
phu-pmh Jan 17, 2020
4f0a169
Merge branch 'master' into fix_dataparallel_metric_calculation
Jan 17, 2020
5bb8389
Fixing tests, moving update_metrics out of other tasks
Jan 17, 2020
fb59ecc
Merge branch 'master' of https://github.com/nyu-mll/jiant into fix_da…
Jan 17, 2020
04dbbda
Merge branch 'fix_dataparallel_metric_calculation' of https://github.…
Jan 17, 2020
3ddf564
remove extraneous change
Jan 17, 2020
e588909
MLM task
phu-pmh Jan 21, 2020
dfa9fd9
Added MLM task
phu-pmh Jan 21, 2020
46182a9
update
phu-pmh Jan 24, 2020
607bcd2
Merge branch 'MLM' of https://github.com/nyu-mll/jiant into MLM
phu-pmh Jan 24, 2020
d1daf23
fix multiple choice dataparallel forward
Jan 25, 2020
9539302
Merge branch 'master' into fix_dataparallel_metric_calculation
Jan 25, 2020
fc5f026
update
phu-pmh Jan 27, 2020
ce7f5c2
add _mask_id to transformers
HaokunLiu Jan 28, 2020
ffc7354
Update
phu-pmh Jan 30, 2020
c50d75b
Merge branch 'master' of https://github.com/nyu-mll/jiant into MLM
phu-pmh Jan 30, 2020
9649224
Merge branch 'master' into fix_dataparallel_metric_calculation
Jan 30, 2020
69a9364
MLM update
phu-pmh Jan 30, 2020
697d62c
Merge branch 'add-_mask_id-to-transformers' into MLM
HaokunLiu Jan 30, 2020
a4666da
adding update_metrics abstraction
Jan 30, 2020
fa13f6f
delete update_metrics_ notation
Jan 30, 2020
6b61e8b
fixed wrong index problem
phu-pmh Jan 30, 2020
3e10e3b
Merge branch 'MLM' of https://github.com/nyu-mll/jiant into MLM
phu-pmh Jan 30, 2020
afc0938
removed unrelated files
phu-pmh Jan 31, 2020
dcff7e7
removed unrelated files
phu-pmh Jan 31, 2020
1c1e6fb
removed unrelated files
phu-pmh Jan 31, 2020
f25ee99
fix PEP8
phu-pmh Jan 31, 2020
3f35212
Fixed get_pretained_lm_head for BERT and ALBERT
phu-pmh Jan 31, 2020
fc85270
spelling check
Feb 1, 2020
321bda8
black formatting
Feb 1, 2020
ae92b78
fixing tests
Feb 2, 2020
4f36878
bug fix
phu-pmh Feb 3, 2020
0467871
Adding batch_size constraints to multi-GPU setting
Feb 5, 2020
e3c5c79
adding documentation
Feb 5, 2020
6e96fd0
adding batch size test
Feb 5, 2020
845bf4f
Merge branch 'master' of https://github.com/nyu-mll/jiant into fix_da…
Feb 5, 2020
b41c268
black correct version
Feb 5, 2020
6f82412
Fixing batch size assertion
Feb 5, 2020
c749ea7
generalize batch size assertion for more than 2 GPU setting
Feb 5, 2020
73222a5
reducing label loops in code
Feb 6, 2020
fe39525
fixing span forward
Feb 8, 2020
745836d
Fixing span prediction forward for multi-GPU
invalid-email-address Feb 8, 2020
14caaab
fix commonsenseQA forward
invalid-email-address Feb 8, 2020
4271a7a
Merge branch 'master' of https://github.com/nyu-mll/jiant into MLM
phu-pmh Feb 10, 2020
918c0df
MLM
phu-pmh Feb 10, 2020
5ed0691
adding function documentation
Feb 11, 2020
ffac8bf
Merge branch 'master' into fix_dataparallel_metric_calculation
Feb 11, 2020
fe86d96
resolving nits, fixing seq_gen forward
Feb 11, 2020
eee439f
Merge branch 'fix_dataparallel_metric_calculation' of https://github.…
Feb 11, 2020
b61fa7c
remove nit
Feb 11, 2020
55312e8
fixing batch_size assert and SpanPrediction task
Feb 12, 2020
7d165cf
Remove debugging
Feb 12, 2020
52f66c7
Fix batch size mismatch multi-GPU test
Feb 12, 2020
a0220f8
Fix order of assert checking for batch size mismatch
Feb 12, 2020
fe89674
mlm training
phu-pmh Feb 12, 2020
2218e5b
update
phu-pmh Feb 14, 2020
cd75715
Merge branch 'fix_dataparallel_metric_calculation' of https://github.…
phu-pmh Feb 14, 2020
58b2914
sbatch
phu-pmh Feb 16, 2020
052b1c0
update
phu-pmh Feb 17, 2020
b26927a
data parallel
phu-pmh Feb 17, 2020
cd4b5a6
update data parallel stuffs
phu-pmh Feb 19, 2020
0d6d691
update MLM
phu-pmh Feb 20, 2020
b3617fa
using sequencelabel, using 1 paragraph per example
Feb 23, 2020
0af6476
update label mapping
phu-pmh Feb 24, 2020
e9f863c
adding exmaples-porportion-mixing
Feb 24, 2020
89e44c5
changing dataloader to work with wikitext103
Feb 24, 2020
0752771
weight sampling
Feb 24, 2020
5482ac2
add early stopping only onb one task
Mar 5, 2020
6d85b27
commit
phu-pmh Mar 6, 2020
d67e195
Merge branch 'MLM' of https://github.com/nyu-mll/jiant into MLM
phu-pmh Mar 6, 2020
921e717
Merge branch 'master' of https://github.com/nyu-mll/jiant into MLM
Mar 8, 2020
05d5750
Cleaning up code
Mar 8, 2020
ddcd357
Removing unecessarily tracked git folders
Mar 8, 2020
9e4e3a7
Removing unnecesary changes
Mar 8, 2020
b9b5f57
revert README
Mar 8, 2020
6b4c9d5
revert README.md again
Mar 8, 2020
35130ca
Making more general for Transformer-based embedders
Mar 8, 2020
20de779
torch.uint8 -> torch.bool
Mar 8, 2020
4020c81
Merge branch 'MLM' of https://github.com/nyu-mll/jiant into MLM
Mar 8, 2020
09f5903
Fixing indexing issues
Mar 8, 2020
4f45826
get rid of unecessary changes
Mar 8, 2020
8ac8c70
black cleanup
Mar 8, 2020
6cee66e
update
phu-pmh Mar 8, 2020
3709696
Prevent updating update_metrics twice in one step
Mar 10, 2020
3fb4e3e
ALBERT SOP update
phu-pmh Mar 16, 2020
a56b7c7
update
phu-pmh Mar 18, 2020
b84da1d
update
phu-pmh Mar 18, 2020
2a19c2c
update
phu-pmh Mar 18, 2020
b1ac702
update
phu-pmh Mar 20, 2020
2ba9651
Fixing SOP to work with jiant
Mar 20, 2020
214abb3
delete debugging
Mar 20, 2020
df0e556
tying pooler weights from ALBERT
Mar 22, 2020
c9ef3f4
fixed SOP tie weight, and MLM vocab error
phu-pmh Mar 22, 2020
b5c9469
Merge branch 'add_sop' of https://github.com/nyu-mll/jiant into add_sop
phu-pmh Mar 22, 2020
d671173
dataset update for SOP
phu-pmh Mar 24, 2020
5f9cc19
removed pdb
phu-pmh Mar 24, 2020
6333b66
Fix ALBERT -> MLM problem, reduce amount of times get_data_iter is ca…
Mar 24, 2020
fab7d8c
Merge branch 'add_sop' of https://github.com/nyu-mll/jiant into add_sop
Mar 24, 2020
e77f096
delete debugging
Mar 24, 2020
8587bb5
adding utf-8 encoding
Mar 24, 2020
58bb44c
Removing two-layer MLM class hierarchy
Mar 24, 2020
c4723fc
MLM indexing bug
phu-pmh Mar 26, 2020
f109381
Merge branch 'add_sop' of https://github.com/nyu-mll/jiant into add_sop
Mar 27, 2020
23f440b
fixing MLM error
phu-pmh Mar 27, 2020
a17a612
Merge branch 'add_sop' of https://github.com/nyu-mll/jiant into add_sop
Mar 28, 2020
ffc2740
removed rest of the shifting code
Mar 28, 2020
060471c
adding
Mar 29, 2020
4430441
Merge branch 'add_sop' of https://github.com/nyu-mll/jiant into add_sop
Apr 2, 2020
8df2645
fixing batch[inputs] error
Apr 2, 2020
a55b447
change corpus to wikipedia raw
phu-pmh Apr 9, 2020
238fdaf
change corpus to wikipedia raw
phu-pmh Apr 9, 2020
9c42d98
Merge branch 'master' of https://github.com/nyu-mll/jiant into add_sop
Apr 10, 2020
4a6a9d4
Finish merge
Apr 10, 2020
dc59491
style
Apr 10, 2020
307ca7a
Revert rest of mlm_weight
Apr 10, 2020
632858d
Revert LM change
Apr 10, 2020
fe62a16
Revert
Apr 10, 2020
bb8ebda
Merge branch 'add_sop' of https://github.com/nyu-mll/jiant into add_sop
Apr 10, 2020
431d6f7
Merging SOP
Apr 10, 2020
e99f39a
Improving documentation
Apr 10, 2020
089939a
Revert base_roberta
Apr 10, 2020
689fd52
revert unecessary change
Apr 10, 2020
b28cab6
Correcting documentation
Apr 10, 2020
8735c78
revert unnecessary changes
Apr 10, 2020
32e00ac
Refactoring SOP to make clearer
Apr 11, 2020
df39c8a
Adding SOPClassifier
Apr 11, 2020
5036612
Fixing SOP Task
Apr 11, 2020
283559e
Adding further documentation
Apr 11, 2020
98d873f
Adding more description of dataset
Apr 11, 2020
9fa473c
fixing merge conflict
phu-pmh Apr 11, 2020
a4e6d10
data_iter fix
phu-pmh Apr 11, 2020
6049772
cleaning up unnecessary files
phu-pmh Apr 11, 2020
8fc7b26
Making documentation clearer about our implementation of ALBERT SOP
Apr 12, 2020
2dbf444
Fix docstring
Apr 12, 2020
469164b
Refactoring SOP back as a PairClassificationTask, adding more documen…
Apr 12, 2020
d36e985
Adding more documentation, adding process_split
Apr 12, 2020
470e6b8
Fix typo in comment
Apr 12, 2020
e93d812
Adding modified SOP code
Apr 13, 2020
fd9b880
fixing based on comments
phu-pmh Apr 13, 2020
abe3230
fixing len(current_chunk)==1 condition
phu-pmh Apr 13, 2020
03546fb
fixing len(current_chunk)==1 condition
phu-pmh Apr 13, 2020
84b229f
documentation fix
phu-pmh Apr 13, 2020
8beac4a
minor fix
phu-pmh Apr 13, 2020
0413e04
minor fix: tokenizer
phu-pmh Apr 13, 2020
23a8df9
minor fix: current_length update
phu-pmh Apr 13, 2020
a2c1ed0
minor fix: current_length update
phu-pmh Apr 13, 2020
d398956
minor fix
phu-pmh Apr 13, 2020
01f959e
bug fix
phu-pmh Apr 13, 2020
670cde6
bug fix
phu-pmh Apr 13, 2020
6530c16
Fixing document leakage bug
Apr 13, 2020
ffc2cb9
Fixing document delimiting bug
Apr 13, 2020
3b92b7f
Cleaning up test
Apr 13, 2020
59c5635
Black style
Apr 13, 2020
a030f0b
Accurately updating current_length based on when len for_next_chunk > 2
Apr 14, 2020
821c7b6
SOP data generation insturctions
phu-pmh Apr 15, 2020
e284ba9
Fix documentation
phu-pmh Apr 15, 2020
298330f
Merge branch 'master' into add_sop
Apr 22, 2020
337d53e
Fixing docstrings and adding source of code
Apr 22, 2020
4f0fd05
Merge branch 'master' into add_sop
Apr 23, 2020
62d4bba
Fixing typos and data script documentation
Apr 23, 2020
0ff4596
Merge branch 'add_sop' of https://github.com/nyu-mll/jiant into add_sop
Apr 23, 2020
c58bc2c
Merge branch 'master' into add_sop
Apr 23, 2020
b5d711b
Revert merge mistake
Apr 23, 2020
34 changes: 34 additions & 0 deletions jiant/models.py
@@ -31,6 +31,7 @@
PairClassifier,
NullPhraseLayer,
TokenMultiProjectionEncoder,
SOPClassifier,
)
from jiant.modules.attn_pair_encoder import AttnPairEncoder
from jiant.modules.sentence_encoder import SentenceEncoder
@@ -65,6 +66,7 @@
WiCTask,
MRPCTask,
QQPTask,
SentenceOrderTask,
)
from jiant.utils import config
from jiant.utils.utils import (
@@ -687,6 +689,34 @@ def build_single_sentence_module(task, d_inp: int, project_before_pooling: bool,
return module


def build_sop(task, d_inp, model, params):
"""
Build and load the pretrained head for the sentence order prediction task.
Right now, there is only support for ALBERT.
Parameters
----------
task: Task,
d_inp: int,
model: MultiTaskModel,
params: Params

Returns
-------
module: SOPClassifier, which is loaded with pretrained weights from ALBERT SOP
pretraining.

"""
input_module = model.sent_encoder._text_field_embedder.input_module
assert (
"albert" in input_module
), "SOP is only supported for ALBERT, please set input_module to an ALBERT model"
module = SOPClassifier(d_inp, task.n_classes, params)
# The huggingface implementation exposes the pretrained projection layer for the SOP task, which
# we use. See: https://github.com/huggingface/transformers/issues/2671 for more details.
module.pooler.project = model.sent_encoder._text_field_embedder.model.pooler
return module


def build_pair_sentence_module(task, d_inp, model, params):
""" Build a pair classifier, shared if necessary """

@@ -745,6 +775,8 @@ def build_pair_attn(d_in, d_hid_attn):
d_out = d_out + d_inp if isinstance(task, WiCTask) else d_out
classifier = Classifier.from_params(4 * d_out, n_classes, params)
module = PairClassifier(pooler, classifier, pair_attn)
if isinstance(task, SentenceOrderTask):
module = build_sop(task, d_inp, model, params)
return module


@@ -899,6 +931,8 @@ def forward(self, task, batch, predict=False):
out = self._span_forward(batch, task, predict)
elif isinstance(task, SpanPredictionTask):
out = self._span_prediction_forward(batch, task, predict)
elif isinstance(task, SentenceOrderTask):
out = self._sop_forward(batch, task, predict)
else:
raise ValueError("Task-specific components not found!")
return out
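For context on the `pooler` reuse in `build_sop` above: the Huggingface ALBERT model exposes its pretrained pooling projection as a plain attribute, which is the layer jiant wires into the SOP head. A rough sketch of locating that layer with `transformers` (attribute names are version-dependent, so treat this as an assumption rather than jiant's code):

```python
import torch.nn as nn
from transformers import AlbertModel

albert = AlbertModel.from_pretrained("albert-base-v2")
# The pretrained projection applied to the first token's hidden state;
# build_sop assigns this layer to SOPClassifier.pooler.project.
pretrained_pooler: nn.Linear = albert.pooler
print(pretrained_pooler)  # Linear(in_features=768, out_features=768, bias=True)
```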
24 changes: 24 additions & 0 deletions jiant/modules/simple_modules.py
@@ -94,6 +94,30 @@ def from_params(cls, d_inp, n_classes, params):
)


class SOPClassifier(nn.Module):
"""
Task head for the sentence order prediction task. We reproduce ALBERT's pooled output
via a linear layer followed by a Tanh activation, which is then fed into a linear
classification layer.
"""

def __init__(self, d_inp, n_classes, params):
super(SOPClassifier, self).__init__()
self.activation = nn.Tanh()
self.pooler = Pooler(d_inp=d_inp, d_proj=d_inp, pool_type=params["pool_type"])
assert params["cls_type"] == "log_reg", (
"The ALBERT implementation of the SOP "
"task takes the final layer from the pooled"
"output. Please set cls_type = log_reg."
)
self.classifier = Classifier.from_params(d_inp, n_classes, params)

def forward(self, seq_emb, mask):
seq_emb = self.activation(self.pooler(seq_emb, mask))
logits = self.classifier(seq_emb)
return logits


class SingleClassifier(nn.Module):
""" Thin wrapper around a set of modules. For single-sentence classification. """

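To make the head's shape concrete, here is a minimal, self-contained PyTorch sketch of the same structure as `SOPClassifier` above: first-token pooling through a linear projection and Tanh, then a linear classification layer. The class and parameter names are illustrative only, not jiant's actual `Pooler`/`Classifier` API.

```python
import torch
import torch.nn as nn


class SimpleSOPHead(nn.Module):
    """Illustrative stand-in for SOPClassifier: pooled projection + linear classifier."""

    def __init__(self, d_inp: int, n_classes: int = 2):
        super().__init__()
        self.project = nn.Linear(d_inp, d_inp)  # stands in for the pretrained ALBERT pooler
        self.activation = nn.Tanh()
        self.classifier = nn.Linear(d_inp, n_classes)  # the "log_reg" classifier

    def forward(self, seq_emb: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # First-token ([CLS]-style) pooling; the mask is unused in this simplified version.
        pooled = self.activation(self.project(seq_emb[:, 0]))
        return self.classifier(pooled)


# Dummy usage: batch of 4 sequences, 128 tokens, hidden size 768.
head = SimpleSOPHead(d_inp=768)
logits = head(torch.randn(4, 128, 768), torch.ones(4, 128, dtype=torch.bool))
print(logits.shape)  # torch.Size([4, 2])
```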
169 changes: 169 additions & 0 deletions jiant/tasks/tasks.py
@@ -4,6 +4,7 @@
import logging as log
import os
from typing import Any, Dict, Iterable, List, Sequence, Type
import random

import numpy as np
import pandas as pd
@@ -3731,3 +3732,171 @@ def get_metrics(self, reset=False):
"""Get metrics specific to the task"""
acc = self.scorer1.get_metric(reset)
return {"accuracy": acc}


@register_task("wikipedia_corpus_sop", rel_path="wikipedia_sop_small")
class SentenceOrderTask(PairClassificationTask):
""" Task class for Sentence Order Prediction (SOP). See the ALBERT paper for details on SOP:
https://arxiv.org/abs/1909.11942.
We currently use a preprocessed version of the Wikipedia corpus (specifically, the
2020-03-01 Wikidump), consisting of roughly 5% of the data. You can generate
the data by following the instructions in jiant/scripts/sop.
Note that in our SOP ALBERT implementation we do not load the pretrained
weights for the SOP head, because they are unavailable in Huggingface. We only use the
pretrained weights of the ALBERT linear layer that produces the pooled output used in SOP.
"""

def __init__(self, path, max_seq_len, name, **kw):
super(SentenceOrderTask, self).__init__(name, n_classes=2, **kw)
self.path = path
self.max_seq_len = max_seq_len
self.train_data_text = None
self.val_data_text = None
self.test_data_text = None
self.files_by_split = {
"train": os.path.join(path, "train.txt"),
"val": os.path.join(path, "valid.txt"),
"test": os.path.join(path, "test.txt"),
}
self._label_namespace = self.name + "_labels"

def get_target_seq_length(self):
target_is_max = random.random() > 0.1
max_seq_len = self.max_seq_len - 3 # exclude [CLS], [SEP], and [SEP]
if target_is_max:
target_seq_length = max_seq_len
else:
target_seq_length = random.randint(2, max_seq_len)
return target_seq_length

def get_data_iter(self, path):
"""Loading data file and tokenizing the text. We override the
this function and all functions that call this function because
the step of reading in the data for SOP is different than other
PairClassificationTasks.

ALBERT does SOP classification by, for each document:
For each example, we first fetch target_seq_length number of sentences from the dcoument:
pruksmhc marked this conversation as resolved.
Show resolved Hide resolved
-90% of the time, this target_seq_length is equal to max_seq_length, and
pruksmhc marked this conversation as resolved.
Show resolved Hide resolved
10% of the time, it is set to a random number of tokens between 2 and max_seq_length.
-Given the sampled sentences, randomly sample N such that the first N sentences in the
sampled go to the first segment, and the rest go to the second.
-50% of the time, the first and second segments are switched.
Args:
path: (str) data file path
"""

def _tokenize(tokenizer_name, sent):
tokenizer = get_tokenizer(tokenizer_name)
return tokenizer.tokenize(sent)

def is_end_document(seg):
tokenized_eod = _tokenize(self._tokenizer_name, "END OF ARTICLE")
return set(tokenized_eod).issubset(set(seg))

f = open(path, "r")
# The dataset comes with one sentence per line, thus we split by
# line here.
current_chunk = [_tokenize(self._tokenizer_name, next(f))]
current_length = len(current_chunk[0])
target_seq_length = self.get_target_seq_length()
while len(current_chunk) > 0:
segment = next(f)
segment = _tokenize(self._tokenizer_name, segment)
if is_end_document(segment) or current_length >= target_seq_length:
for_next_chunk = []
if current_length > target_seq_length:
# Since the most recently added sentence pushes the chunk past the target
# length, we save it for the next chunk (next example).
for_next_chunk.append(current_chunk.pop())
if not is_end_document(segment):
for_next_chunk.append(segment)
target_seq_length = self.get_target_seq_length()
if len(current_chunk) >= 2:
# Make sure we have at least 2 sentences to distribute between the two
# segments.
a_end = random.randint(1, len(current_chunk) - 1)
tokens_a = []
for j in range(a_end):
tokens_a.extend(current_chunk[j])
tokens_b = []
for j in range(a_end, len(current_chunk)):
tokens_b.extend(current_chunk[j])
in_order = random.random() < 0.5
if in_order:
yield (tokens_a, tokens_b, in_order)
else:
yield (tokens_b, tokens_a, in_order)
# if len(current_chunk) >=2, we will yield and reinitialize
# if len(current_chunk) == 1, we will not yield, and reinitialize
if len(for_next_chunk) > 0 and not is_end_document(segment):
# Make sure the sentences sampled for an example all
# belong to the same document.
current_chunk = for_next_chunk
current_length = sum([len(chunk) for chunk in for_next_chunk])
else:
# We find the next sentence for the next example.
try:  # next(f) raises StopIteration at the end of the file
current_chunk = [_tokenize(self._tokenizer_name, next(f))]
current_length = len(current_chunk[0])
except StopIteration:
print("Done loading data for SOP")
current_chunk = []
current_length = 0
pass
else:
current_chunk.append(segment)
current_length += len(segment)

def load_data(self):
pass

def process_split(
self, split, indexers, model_preprocessing_interface
) -> Iterable[Type[Instance]]:
"""Process a sentence order prediction split by indexing and creating fields.
We override PairClassificationTask.process_split because, given the dataset size, SOP
loads its data through a memory-efficient generator rather than the in-memory split
used by other PairClassificationTasks.
Args:
split: (list) a single list of sentences
indexers: (Indexer object) indexer to index input words
"""

def _make_instance(sent_pairs_):
sent_a, sent_b, is_right_order = sent_pairs_
inp = model_preprocessing_interface.boundary_token_fn(sent_a, sent_b)
input_sent = sentence_to_text_field(inp, indexers)
label = LabelField(is_right_order, label_namespace="labels", skip_indexing=True)
d = {"inputs": input_sent, "labels": label}
return Instance(d)

for sent_pairs in split:
yield _make_instance(sent_pairs)

def get_split_text(self, split: str):
"""Get split text as iterable of records.
Args:
split: (str) should be one of 'train', 'val', or 'test'.
"""
return self.get_data_iter(self.files_by_split[split])

def get_sentences(self) -> Iterable[Sequence[str]]:
"""Yield sentences, used to compute vocabulary.
"""
for split in self.files_by_split:
# Don't use test set for vocab building.
if split.startswith("test"):
continue
for sent in self.get_data_iter(self.files_by_split[split]):
# only counting sent[0] is enough for computing vocab
yield sent[0]

def count_examples(self):
"""Computes number of samples
Assuming every line is one example.
"""
example_counts = {}
for split, split_path in self.files_by_split.items():
example_counts[split] = sum(1 for _ in self.get_data_iter(split_path))
self.example_counts = example_counts
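As a companion to the `get_data_iter` docstring above, here is a simplified, standalone sketch of the example-construction procedure: pack sentences up to a target token length, split the chunk at a random sentence boundary, and swap the two segments half of the time. It assumes pre-tokenized sentences from a single document and omits the 90/10 target-length sampling, end-of-article handling, and chunk carry-over of the real implementation.

```python
import random
from typing import Iterator, List, Tuple


def sop_pairs(
    sentences: List[List[str]], target_len: int
) -> Iterator[Tuple[List[str], List[str], bool]]:
    """Yield (segment_a, segment_b, in_order) SOP examples from one document."""
    chunk, length = [], 0
    for sent in sentences:
        chunk.append(sent)
        length += len(sent)
        if length >= target_len:
            if len(chunk) >= 2:  # need at least two sentences to form two segments
                split = random.randint(1, len(chunk) - 1)
                seg_a = [tok for s in chunk[:split] for tok in s]
                seg_b = [tok for s in chunk[split:] for tok in s]
                in_order = random.random() < 0.5
                yield (seg_a, seg_b, True) if in_order else (seg_b, seg_a, False)
            chunk, length = [], 0


doc = [["the", "cat", "sat"], ["on", "the", "mat"], ["it", "purred"], ["loudly", "today"]]
for a, b, label in sop_pairs(doc, target_len=5):
    print(a, b, label)
```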
30 changes: 30 additions & 0 deletions scripts/sop/README.md
@@ -0,0 +1,30 @@
# Downloading Wikipedia Corpus

We use the preprocessing code from https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/LanguageModeling/BERT#getting-the-data
and the bash scripts provided here are used to streamline the data generation steps in the NVIDIA repository.

First, git clone https://github.com/NVIDIA/DeepLearningExamples.git.
Then, move create_wiki_sop_data.sh and get_small_english_wiki.sh into DeepLearningExamples/PyTorch/LanguageModeling/BERT/data.

Then, follow the instructions below:

The NVIDIA script downloads the latest Wikipedia dump; we use the 2020-03-01 dump.
To download the 2020-03-01 dump, replace line 29 of `DeepLearningExamples/PyTorch/LanguageModeling/BERT/data/WikiDownloader.py`:
`'en' : 'https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2',` with `'en' : 'https://dumps.wikimedia.org/enwiki/20200301/enwiki-20200301-pages-articles.xml.bz2',`.

The data creation for SOP is almost the same as for MLM, except that you need to make the following edit.
In `DeepLearningExamples/PyTorch/LanguageModeling/BERT/data/TextSharding.py`, replace line 55:
`self.articles[global_article_count] = line.rstrip()` with `self.articles[global_article_count] = line.rstrip() + "\n ========THIS IS THE END OF ARTICLE.========"`.
This is because SOP requires a signal for the end of each Wikipedia article.

Run `bash create_wiki_sop_data.sh $lang $save_directory`.
The NVIDIA code supports English (en) and Chinese (zh) Wikipedia.

For example, to download and process English Wikipedia and save it in the `~/Download` directory, run
`bash create_wiki_sop_data.sh en ~/Download`

The above command will download the entire English Wikipedia.

In our experiments, we only use a small subset (around 5%) of the entire English Wikipedia, which has the same number of sentences as WikiText-103.
To get this subset, run `bash get_small_english_wiki.sh $path_to_wikicorpus_en`, where `$path_to_wikicorpus_en` is the directory where you saved the fully processed `wikicorpus_en` corpus.

49 changes: 49 additions & 0 deletions scripts/sop/create_wiki_sop_data.sh
@@ -0,0 +1,49 @@
#!/bin/bash

# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

lang=$1 #the language, 'en' for English wikipedia
export BERT_PREP_WORKING_DIR=$2

# clone wikiextractor if it doesn't exist
if [ ! -d "wikiextractor" ]; then
git clone https://github.com/attardi/wikiextractor.git
fi

echo "Downloading $lang wikpedia in directory $save_dir"
# Download
python3 bertPrep.py --action download --dataset wikicorpus_$lang


# Properly format the text files
python3 bertPrep.py --action text_formatting --dataset wikicorpus_$lang


# Shard the text files (group wiki+books then shard)
python3 bertPrep.py --action sharding --dataset wikicorpus_$lang


# Combine sharded files into one
save_dir=$BERT_PREP_WORKING_DIR/sharded_training_shards_256_test_shards_256_fraction_0.2/wikicorpus_$lang
cat $save_dir/*training*.txt > $save_dir/train_$lang.txt
cat $save_dir/*test*.txt > $save_dir/test_$lang.txt
rm -rf $save_dir/wiki*training*.txt
rm -rf $save_dir/wiki*test*.txt

# remove some remaining xml tags
sed -i 's/<[^>]*>//g' $save_dir/train_$lang.txt
sed -i 's/<[^>]*>//g' $save_dir/test_$lang.txt

echo "Your corpus is saved in $save_dir"

6 changes: 6 additions & 0 deletions scripts/sop/get_small_english_wiki.sh
@@ -0,0 +1,6 @@
wiki_path=$1

mkdir -p $wiki_path/wikipedia_sop_small
head -3978309 $wiki_path/train_en.txt > $wiki_path/wikipedia_sop_small/train.txt
head -10001 $wiki_path/test_en.txt > $wiki_path/wikipedia_sop_small/test.txt
tail -8438 $wiki_path/train_en.txt > $wiki_path/wikipedia_sop_small/valid.txt