Changes from my multiple-choice work (allenai#4368)
* Ability to ignore dimensions in the bert pooler

* File reading utilities

* Productivity through formatting

* More reasonable defaults for the Huggingface AdamW optimizer

* Changelog

* Adds a test for the BertPooler

* We can't run the new transformers lib yet

* Pin more recent transformer version

* Update CHANGELOG.md

Co-authored-by: Evan Pete Walsh <epwalsh10@gmail.com>

* Adds ability to override transformer weights

* Adds a transformer cache, and the ability to override weights

* Fix up this PR

* Fix comment

Co-authored-by: Evan Pete Walsh <epwalsh10@gmail.com>
dirkgr and epwalsh committed Jun 29, 2020
1 parent eee15ca commit 96ff585
Showing 13 changed files with 196 additions and 36 deletions.
7 changes: 4 additions & 3 deletions CHANGELOG.md
@@ -16,17 +16,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
in distributed training.
- Fixed checking equality of `ArrayField`s.
- Fixed a bug where `NamespaceSwappingField` did not work correctly with `.empty_field()`.
- Put more sensible defaults on the `huggingface_adamw` optimizer.

### Added

- A method to ModelTestCase for running basic model tests when you aren't using config files.

### Added

- `BertPooler` can now unwrap and re-wrap extra dimensions if necessary.
- Added some convenience methods for reading files.
- Added an option to `file_utils.cached_path` to automatically extract archives.
- Added the ability to pass an archive file instead of a local directory to `Vocab.from_files`.
- Added the ability to pass an archive file instead of a glob to `ShardedDatasetReader`.


## [v1.0.0](https://github.com/allenai/allennlp/releases/tag/v1.0.0) - 2020-06-16

### Fixed
100 changes: 100 additions & 0 deletions allennlp/common/cached_transformers.py
@@ -0,0 +1,100 @@
import logging
from typing import NamedTuple, Optional, Dict, Tuple
import transformers
from transformers import AutoModel


logger = logging.getLogger(__name__)


class TransformerSpec(NamedTuple):
    model_name: str
    override_weights_file: Optional[str] = None
    override_weights_strip_prefix: Optional[str] = None


_model_cache: Dict[TransformerSpec, transformers.PreTrainedModel] = {}


def get(
    model_name: str,
    make_copy: bool,
    override_weights_file: Optional[str] = None,
    override_weights_strip_prefix: Optional[str] = None,
) -> transformers.PreTrainedModel:
    """
    Returns a transformer model from the cache.

    # Parameters

    model_name : `str`
        The name of the transformer, for example `"bert-base-cased"`
    make_copy : `bool`
        If this is `True`, return a copy of the model instead of the cached model itself. If you want to modify the
        parameters of the model, set this to `True`. If you want only part of the model, set this to `False`, but
        make sure to `copy.deepcopy()` the bits you are keeping.
    override_weights_file : `str`, optional
        If set, this specifies a file from which to load alternate weights that override the
        weights from huggingface. The file is expected to contain a PyTorch `state_dict`, created
        with `torch.save()`.
    override_weights_strip_prefix : `str`, optional
        If set, strip the given prefix from the state dict when loading it.
    """
    global _model_cache
    spec = TransformerSpec(model_name, override_weights_file, override_weights_strip_prefix)
    transformer = _model_cache.get(spec, None)
    if transformer is None:
        if override_weights_file is not None:
            from allennlp.common.file_utils import cached_path
            import torch

            override_weights_file = cached_path(override_weights_file)
            override_weights = torch.load(override_weights_file)
            if override_weights_strip_prefix is not None:

                def strip_prefix(s):
                    if s.startswith(override_weights_strip_prefix):
                        return s[len(override_weights_strip_prefix) :]
                    else:
                        return s

                valid_keys = {
                    k
                    for k in override_weights.keys()
                    if k.startswith(override_weights_strip_prefix)
                }
                if len(valid_keys) > 0:
                    logger.info(
                        "Loading %d tensors from %s", len(valid_keys), override_weights_file
                    )
                else:
                    raise ValueError(
                        f"Specified prefix of '{override_weights_strip_prefix}' means no tensors "
                        f"will be loaded from {override_weights_file}."
                    )
                override_weights = {strip_prefix(k): override_weights[k] for k in valid_keys}

            transformer = AutoModel.from_pretrained(model_name, state_dict=override_weights)
        else:
            transformer = AutoModel.from_pretrained(model_name)
        _model_cache[spec] = transformer
    if make_copy:
        import copy

        return copy.deepcopy(transformer)
    else:
        return transformer


_tokenizer_cache: Dict[Tuple[str, frozenset], transformers.PreTrainedTokenizer] = {}


def get_tokenizer(model_name: str, **kwargs) -> transformers.PreTrainedTokenizer:
    cache_key = (model_name, frozenset(kwargs.items()))

    global _tokenizer_cache
    tokenizer = _tokenizer_cache.get(cache_key, None)
    if tokenizer is None:
        tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, **kwargs)
        _tokenizer_cache[cache_key] = tokenizer
    return tokenizer
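
A minimal usage sketch of the new cache; the weights file name and the "backbone." prefix below are hypothetical, everything else is the API defined above.

from allennlp.common import cached_transformers

# Shared instance straight from the cache (make_copy=False): fine for read-only use,
# but callers must not mutate its parameters.
bert = cached_transformers.get("bert-base-cased", make_copy=False)

# Independent copy whose weights come from a state_dict saved with torch.save();
# keys starting with "backbone." have that prefix stripped before loading.
bert_tuned = cached_transformers.get(
    "bert-base-cased",
    make_copy=True,
    override_weights_file="backbone_weights.th",   # hypothetical file
    override_weights_strip_prefix="backbone.",     # hypothetical prefix
)

# Tokenizers are cached per (model_name, kwargs) combination.
tokenizer = cached_transformers.get_tokenizer("bert-base-cased", add_special_tokens=False)
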
15 changes: 14 additions & 1 deletion allennlp/common/file_utils.py
@@ -9,7 +9,7 @@
import json
from urllib.parse import urlparse
from pathlib import Path
from typing import Optional, Tuple, Union, IO, Callable, Set, List
from typing import Optional, Tuple, Union, IO, Callable, Set, List, Iterator, Iterable
from hashlib import sha256
from functools import wraps
from zipfile import ZipFile, is_zipfile
@@ -458,3 +458,16 @@ def open_compressed(

open_fn = bz2.open
return open_fn(filename, mode=mode, encoding=encoding, **kwargs)


def text_lines_from_file(filename: Union[str, Path], strip_lines: bool = True) -> Iterator[str]:
    with open_compressed(filename, "rt", encoding="UTF-8", errors="replace") as p:
        if strip_lines:
            for line in p:
                yield line.strip()
        else:
            yield from p


def json_lines_from_file(filename: Union[str, Path]) -> Iterable[Union[list, dict]]:
    return (json.loads(line) for line in text_lines_from_file(filename))
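
A short sketch of how the new helpers compose; the file names and the "text" key are hypothetical, and open_compressed (above) is assumed to choose the right opener from the file extension.

from allennlp.common.file_utils import json_lines_from_file, text_lines_from_file

# Hypothetical gzipped JSON-lines file: each line is parsed with json.loads.
for blob in json_lines_from_file("data.jsonl.gz"):
    print(blob["text"])

# Plain or compressed text, one stripped line at a time.
for line in text_lines_from_file("sentences.txt", strip_lines=True):
    print(line)
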
9 changes: 6 additions & 3 deletions allennlp/data/tokenizers/pretrained_transformer_tokenizer.py
@@ -3,7 +3,6 @@

from overrides import overrides
from transformers import PreTrainedTokenizer
from transformers.tokenization_auto import AutoTokenizer

from allennlp.common.util import sanitize_wordpiece
from allennlp.data.tokenizers.token import Token
@@ -74,7 +73,9 @@ def __init__(
tokenizer_kwargs.setdefault("use_fast", True)
# Note: Just because we request a fast tokenizer doesn't mean we get one.

self.tokenizer = AutoTokenizer.from_pretrained(
from allennlp.common import cached_transformers

self.tokenizer = cached_transformers.get_tokenizer(
model_name, add_special_tokens=False, **tokenizer_kwargs
)

@@ -114,7 +115,9 @@ def _reverse_engineer_special_tokens(
self.single_sequence_token_type_id = None

# Reverse-engineer the tokenizer for two sequences
tokenizer_with_special_tokens = AutoTokenizer.from_pretrained(
from allennlp.common import cached_transformers

tokenizer_with_special_tokens = cached_transformers.get_tokenizer(
model_name, add_special_tokens=True, **tokenizer_kwargs
)
dummy_output = tokenizer_with_special_tokens.encode_plus(
34 changes: 27 additions & 7 deletions allennlp/modules/seq2vec_encoders/bert_pooler.py
@@ -1,8 +1,9 @@
from typing import Optional

from overrides import overrides

import torch
import torch.nn
from transformers.modeling_auto import AutoModel

from allennlp.modules.seq2vec_encoders.seq2vec_encoder import Seq2VecEncoder

@@ -23,7 +24,7 @@ class BertPooler(Seq2VecEncoder):
pretrained_model : `Union[str, BertModel]`, required
The pretrained BERT model to use. If this is a string,
we will call `BertModel.from_pretrained(pretrained_model)`
we will call `transformers.AutoModel.from_pretrained(pretrained_model)`
and use that.
requires_grad : `bool`, optional, (default = `True`)
If True, the weights of the pooler will be updated during training.
@@ -33,15 +34,27 @@ class BertPooler(Seq2VecEncoder):
"""

def __init__(
self, pretrained_model: str, requires_grad: bool = True, dropout: float = 0.0
self,
pretrained_model: str,
*,
override_weights_file: Optional[str] = None,
override_weights_strip_prefix: Optional[str] = None,
requires_grad: bool = True,
dropout: float = 0.0
) -> None:
super().__init__()

model = AutoModel.from_pretrained(pretrained_model)
from allennlp.common import cached_transformers

model = cached_transformers.get(
pretrained_model, False, override_weights_file, override_weights_strip_prefix
)

self._dropout = torch.nn.Dropout(p=dropout)

self.pooler = model.pooler
import copy

self.pooler = copy.deepcopy(model.pooler)
for param in self.pooler.parameters():
param.requires_grad = requires_grad
self._embedding_dim = model.config.hidden_size
@@ -54,7 +67,14 @@ def get_input_dim(self) -> int:
def get_output_dim(self) -> int:
return self._embedding_dim

def forward(self, tokens: torch.Tensor, mask: torch.BoolTensor = None):
pooled = self.pooler(tokens)
def forward(
self, tokens: torch.Tensor, mask: torch.BoolTensor = None, num_wrapping_dims: int = 0
):
pooler = self.pooler
for _ in range(num_wrapping_dims):
from allennlp.modules import TimeDistributed

pooler = TimeDistributed(pooler)
pooled = pooler(tokens)
pooled = self._dropout(pooled)
return pooled
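
A sketch of what the new num_wrapping_dims argument buys, assuming a toy multiple-choice batch of already-embedded wordpieces; the shapes are illustrative only.

import torch
from allennlp.modules.seq2vec_encoders import BertPooler

pooler = BertPooler("bert-base-uncased")

# Hypothetical batch: 2 questions, 4 answer choices, 7 wordpieces each.
embeddings = torch.randn(2, 4, 7, pooler.get_input_dim())

# The extra "choices" dimension is unwrapped via TimeDistributed and re-wrapped,
# so the output keeps the shape (batch, num_choices, hidden).
pooled = pooler(embeddings, num_wrapping_dims=1)
assert pooled.shape == (2, 4, pooler.get_output_dim())
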
@@ -6,7 +6,6 @@
import torch
import torch.nn.functional as F
from transformers import XLNetConfig
from transformers.modeling_auto import AutoModel

from allennlp.data.tokenizers import PretrainedTransformerTokenizer
from allennlp.modules.token_embedders.token_embedder import TokenEmbedder
@@ -41,12 +40,19 @@ class PretrainedTransformerEmbedder(TokenEmbedder):
def __init__(
self,
model_name: str,
*,
max_length: int = None,
sub_module: str = None,
train_parameters: bool = True,
override_weights_file: Optional[str] = None,
override_weights_strip_prefix: Optional[str] = None
) -> None:
super().__init__()
self.transformer_model = AutoModel.from_pretrained(model_name)
from allennlp.common import cached_transformers

self.transformer_model = cached_transformers.get(
model_name, True, override_weights_file, override_weights_strip_prefix
)
self.config = self.transformer_model.config
if sub_module:
assert hasattr(self.transformer_model, sub_module)
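
A hedged sketch of overriding the embedder's weights; the checkpoint path and the "bert." key prefix are hypothetical, and the file is simply a state_dict saved with torch.save(), as documented in cached_transformers.get.

import torch
from allennlp.modules.token_embedders import PretrainedTransformerEmbedder

# Hypothetical: persist fine-tuned weights whose keys look like "bert.encoder...".
# torch.save(fine_tuned_model.state_dict(), "override_weights.th")

embedder = PretrainedTransformerEmbedder(
    "bert-base-cased",
    override_weights_file="override_weights.th",
    override_weights_strip_prefix="bert.",
)
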
@@ -35,7 +35,7 @@ def __init__(
super().__init__()
# The matched version v.s. mismatched
self._matched_embedder = PretrainedTransformerEmbedder(
model_name, max_length, train_parameters=train_parameters
model_name, max_length=max_length, train_parameters=train_parameters
)

@overrides
6 changes: 3 additions & 3 deletions allennlp/training/optimizers.py
@@ -265,11 +265,11 @@ def __init__(
self,
model_parameters: List[Tuple[str, torch.nn.Parameter]],
parameter_groups: List[Tuple[List[str], Dict[str, Any]]] = None,
lr: float = 0.001,
lr: float = 1e-5,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-06,
eps: float = 1e-08,
weight_decay: float = 0.0,
correct_bias: bool = False,
correct_bias: bool = True,
):
super().__init__(
params=make_parameter_groups(model_parameters, parameter_groups),
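
For comparison, a minimal sketch of what the new defaults amount to when written against transformers.AdamW directly; the model here is just a stand-in.

import transformers

model = transformers.AutoModel.from_pretrained("bert-base-cased")
optimizer = transformers.AdamW(
    model.parameters(),
    lr=1e-5,             # was 0.001
    betas=(0.9, 0.999),
    eps=1e-8,            # was 1e-06
    weight_decay=0.0,
    correct_bias=True,   # was False
)
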
2 changes: 1 addition & 1 deletion setup.py
@@ -64,7 +64,7 @@
"scikit-learn",
"scipy",
"pytest",
"transformers>=2.9,<2.12",
"transformers>=2.10,<2.11",
"jsonpickle",
"dataclasses;python_version<'3.7'",
"filelock>=3.0,<3.1",
19 changes: 9 additions & 10 deletions tests/data/token_indexers/pretrained_transformer_indexer_test.py
@@ -1,5 +1,4 @@
from transformers.tokenization_auto import AutoTokenizer

from allennlp.common import cached_transformers
from allennlp.common.testing import AllenNlpTestCase
from allennlp.data import Vocabulary
from allennlp.data.token_indexers import PretrainedTransformerIndexer
@@ -8,7 +7,7 @@

class TestPretrainedTransformerIndexer(AllenNlpTestCase):
def test_as_array_produces_token_sequence_bert_uncased(self):
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
tokenizer = cached_transformers.get_tokenizer("bert-base-uncased")
allennlp_tokenizer = PretrainedTransformerTokenizer("bert-base-uncased")
indexer = PretrainedTransformerIndexer(model_name="bert-base-uncased")
string_specials = "[CLS] AllenNLP is great [SEP]"
@@ -22,7 +21,7 @@ def test_as_array_produces_token_sequence_bert_uncased(self):
assert indexed["token_ids"] == expected_ids

def test_as_array_produces_token_sequence_bert_cased(self):
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
tokenizer = cached_transformers.get_tokenizer("bert-base-cased")
allennlp_tokenizer = PretrainedTransformerTokenizer("bert-base-cased")
indexer = PretrainedTransformerIndexer(model_name="bert-base-cased")
string_specials = "[CLS] AllenNLP is great [SEP]"
@@ -36,7 +35,7 @@ def test_as_array_produces_token_sequence_bert_cased(self):
assert indexed["token_ids"] == expected_ids

def test_as_array_produces_token_sequence_bert_cased_sentence_pair(self):
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
tokenizer = cached_transformers.get_tokenizer("bert-base-cased")
allennlp_tokenizer = PretrainedTransformerTokenizer(
"bert-base-cased", add_special_tokens=False
)
@@ -53,7 +52,7 @@ def test_as_array_produces_token_sequence_bert_cased_sentence_pair(self):
assert indexed["token_ids"] == expected_ids

def test_as_array_produces_token_sequence_roberta(self):
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
tokenizer = cached_transformers.get_tokenizer("roberta-base")
allennlp_tokenizer = PretrainedTransformerTokenizer("roberta-base")
indexer = PretrainedTransformerIndexer(model_name="roberta-base")
string_specials = "<s> AllenNLP is great </s>"
@@ -67,7 +66,7 @@ def test_as_array_produces_token_sequence_roberta(self):
assert indexed["token_ids"] == expected_ids

def test_as_array_produces_token_sequence_roberta_sentence_pair(self):
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
tokenizer = cached_transformers.get_tokenizer("roberta-base")
allennlp_tokenizer = PretrainedTransformerTokenizer(
"roberta-base", add_special_tokens=False
)
@@ -86,7 +85,7 @@ def test_as_array_produces_token_sequence_roberta_sentence_pair(self):
def test_transformers_vocab_sizes(self):
def check_vocab_size(model_name: str):
namespace = "tags"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer = cached_transformers.get_tokenizer(model_name)
allennlp_tokenizer = PretrainedTransformerTokenizer(model_name)
indexer = PretrainedTransformerIndexer(model_name=model_name, namespace=namespace)
allennlp_tokens = allennlp_tokenizer.tokenize("AllenNLP is great!")
@@ -102,7 +101,7 @@ def check_vocab_size(model_name: str):

def test_transformers_vocabs_added_correctly(self):
namespace, model_name = "tags", "roberta-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer = cached_transformers.get_tokenizer(model_name)
allennlp_tokenizer = PretrainedTransformerTokenizer(model_name)
indexer = PretrainedTransformerIndexer(model_name=model_name, namespace=namespace)
allennlp_tokens = allennlp_tokenizer.tokenize("AllenNLP is great!")
@@ -142,7 +141,7 @@ def test_mask(self):
assert padded_tokens["token_ids"][-padding_length:].tolist() == padding_suffix

def test_long_sequence_splitting(self):
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
tokenizer = cached_transformers.get_tokenizer("bert-base-uncased")
allennlp_tokenizer = PretrainedTransformerTokenizer("bert-base-uncased")
indexer = PretrainedTransformerIndexer(model_name="bert-base-uncased", max_length=4)
string_specials = "[CLS] AllenNLP is great [SEP]"
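
A small sketch of the caching behaviour behind this switch: identical (model_name, kwargs) pairs return the very same tokenizer object, which keeps repeated tokenizer construction in tests cheap.

from allennlp.common import cached_transformers

t1 = cached_transformers.get_tokenizer("bert-base-uncased")
t2 = cached_transformers.get_tokenizer("bert-base-uncased")
assert t1 is t2          # same cache entry

t3 = cached_transformers.get_tokenizer("bert-base-uncased", use_fast=False)
assert t3 is not t1      # different kwargs, different cache entry
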
