IO for transformer component #178

Merged (8 commits, May 18, 2020)
Empty file.
8 changes: 8 additions & 0 deletions examples/own/dep-distilbert.cfg
@@ -48,5 +48,13 @@ maxout_pieces = 3
[nlp.pipeline.parser.model.tok2vec]
@architectures = "spacy.Tok2VecTransformer.v1"
name = "distilbert-base-uncased"
fast_tokenizer = true
width = 768

[nlp.pipeline.parser.model.tok2vec.pooling]
@layers = "reduce_mean.v1"

[nlp.pipeline.parser.model.tok2vec.get_spans]
@layers = "spacy-transformers.strided_spans.v1"
window = 256
stride = 196
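
For reference: the span getter and pooling layer configured above control how long texts reach the transformer. `spacy-transformers.strided_spans.v1` cuts each Doc into overlapping windows of `window` tokens, advancing by `stride` tokens, and the `reduce_mean.v1` pooling layer averages the wordpiece vectors aligned to each spaCy token. A rough sketch of the windowing arithmetic only, not the library's actual implementation:

```python
from typing import List, Tuple

def strided_windows(n_tokens: int, window: int = 256, stride: int = 196) -> List[Tuple[int, int]]:
    """Rough sketch of the (start, end) offsets strided spans would cover.

    Illustrative only -- the real spacy-transformers.strided_spans.v1 layer
    returns Span objects per Doc; this just shows the window/stride arithmetic.
    """
    spans = []
    start = 0
    while start < n_tokens:
        spans.append((start, min(start + window, n_tokens)))
        if start + window >= n_tokens:
            break
        start += stride
    return spans

# A 600-token doc with the config above:
print(strided_windows(600))  # [(0, 256), (196, 452), (392, 600)]
```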
1 change: 0 additions & 1 deletion examples/train_from_config.py
@@ -13,7 +13,6 @@ def main(config_path, train_path, eval_path, gpu_id):
config_path = Path(config_path)
train_path = Path(train_path)
eval_path = Path(eval_path)
install_extensions()
if gpu_id >= 0:
spacy.util.use_gpu(gpu_id)
use_pytorch_for_gpu_memory()
66 changes: 65 additions & 1 deletion spacy_transformers/pipeline.py
@@ -1,13 +1,22 @@
from typing import List, Callable, Optional
from spacy.pipeline import Pipe
from spacy.language import component
from spacy.pipeline.pipes import _load_cfg
from spacy.tokens import Doc
from spacy.vocab import Vocab
from spacy.gold import Example
from spacy import util
from spacy.util import minibatch, eg2doc, link_vectors_to_models
from spacy_transformers.wrapper import PyTorchTransformer
from thinc.api import Model, set_dropout_rate

from .util import null_annotation_setter
import srsly
import torch
from transformers import AutoModel, AutoTokenizer
from transformers import WEIGHTS_NAME, CONFIG_NAME
from pathlib import Path

from .util import null_annotation_setter, install_extensions
from .util import FullTransformerBatch, TransformerData


@@ -22,6 +31,22 @@ class Transformer(Pipe):
the same spaCy token, the spaCy token receives the sum of their values.
"""

@classmethod
def from_nlp(cls, nlp, model, **cfg):

# fmt: off
arch = nlp.config.get("transformer", {}).get("model", {}).get("@architectures", None)
# fmt: on

# we want to prevent downloading the model again - we can just read from file
if arch is not None and "TransformerByName" in arch:
nlp.config["transformer"]["model"] = {
"@architectures": "spacy.TransformerFromFile.v1",
"get_spans": nlp.config["transformer"]["model"]["get_spans"],
}

return cls(nlp.vocab, model, **cfg)

def __init__(
self,
vocab: Vocab,
@@ -38,6 +63,7 @@ def __init__(
self.cfg = dict(cfg)
self.cfg["max_batch_size"] = max_batch_size
self.listeners: List[TransformerListener] = []
install_extensions()

def create_listener(self):
listener = TransformerListener(
@@ -157,6 +183,44 @@ def begin_training(
self.model.initialize(X=docs)
link_vectors_to_models(self.vocab)

def to_disk(self, path, exclude=tuple(), **kwargs):
"""Serialize the pipe and its model to disk."""

def save_model(p):
trf_dir = Path(p).absolute()
trf_dir.mkdir()
self.model.attrs["tokenizer"].save_pretrained(str(trf_dir))
transformer = self.model.layers[0].shims[0]._model
torch.save(transformer.state_dict(), trf_dir / WEIGHTS_NAME)
transformer.config.to_json_file(trf_dir / CONFIG_NAME)

serialize = {}
serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
serialize["vocab"] = lambda p: self.vocab.to_disk(p)
serialize["model"] = lambda p: save_model(p)
exclude = util.get_serialization_exclude(serialize, exclude, kwargs)
util.to_disk(path, serialize, exclude)

def from_disk(self, path, exclude=tuple(), **kwargs):
"""Load the pipe and its model from disk."""

def load_model(p):
trf_dir = Path(p).absolute()
transformer = AutoModel.from_pretrained(str(trf_dir))
wrapper = PyTorchTransformer(transformer)
assert len(self.model.layers) == 0
self.model.layers.append(wrapper)
tokenizer = AutoTokenizer.from_pretrained(str(trf_dir))
self.model.attrs["tokenizer"] = tokenizer

deserialize = {}
deserialize["vocab"] = lambda p: self.vocab.from_disk(p)
deserialize["cfg"] = lambda p: self.cfg.update(_load_cfg(p))
deserialize["model"] = lambda p: load_model(p)
exclude = util.get_serialization_exclude(deserialize, exclude, kwargs)
util.from_disk(path, deserialize, exclude)
return self


class TransformerListener(Model):
"""A layer that gets fed its answers from an upstream connection,
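
Taken together, the new `to_disk`/`from_disk` methods let a saved pipeline be reloaded without contacting the Hugging Face hub again: `save_model()` writes the tokenizer, the PyTorch `state_dict` (`WEIGHTS_NAME`) and the transformer config (`CONFIG_NAME`) alongside the pipe's `cfg` and vocab, and `load_model()` rebuilds the wrapper from those files via `AutoModel.from_pretrained`/`AutoTokenizer.from_pretrained`. A hypothetical round trip (the path names and the trained pipeline are assumptions, not part of this diff):

```python
from pathlib import Path
import spacy

# Assumption: "trained" holds a pipeline whose config used a by-name
# transformer architecture, so the weights were originally downloaded.
nlp = spacy.load("trained")

out_dir = Path("exported")
nlp.to_disk(out_dir)        # Transformer.to_disk() serializes cfg, vocab and the
                            # HF tokenizer/weights/config via save_model()

nlp2 = spacy.load(out_dir)  # Transformer.from_nlp() swaps in spacy.TransformerFromFile.v1,
                            # so from_disk() restores the weights from the exported files
                            # instead of downloading the model again
```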
5 changes: 1 addition & 4 deletions spacy_transformers/tests/test_pipeline_component.py
@@ -24,10 +24,7 @@ def docs(vocab):

@pytest.fixture
def component(vocab):
try:
install_extensions()
except ValueError:
pass
install_extensions()
return Transformer(Vocab(), DummyTransformer())


3 changes: 1 addition & 2 deletions spacy_transformers/util.py
@@ -12,7 +12,6 @@
from spacy.tokens import Span
from ._align import get_token_positions


BatchEncoding = Dict


@@ -89,7 +88,7 @@ def split_by_doc(self) -> List["TransformerData"]:


def install_extensions():
Doc.set_extension("trf_data", default=TransformerData.empty())
Doc.set_extension("trf_data", default=TransformerData.empty(), force=True)


@registry.layers("spacy-transformers.strided_spans.v1")
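
The `force=True` flag is what lets `install_extensions()` move into `Transformer.__init__` (and out of `train_from_config.py` and the test fixture's `try/except`): `Doc.set_extension` raises a `ValueError` when the extension name is already registered, while `force=True` simply overwrites it, so repeated calls are safe. A minimal illustration:

```python
from spacy.tokens import Doc

Doc.set_extension("trf_data", default=None)

# A second registration without force=True would raise ValueError because
# "trf_data" already exists on Doc; with force=True it is overwritten,
# so install_extensions() can run every time the component is constructed.
Doc.set_extension("trf_data", default=None, force=True)
```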
11 changes: 11 additions & 0 deletions spacy_transformers/wrapper.py
@@ -35,6 +35,17 @@ def TransformerModel(
)


@registry.architectures.register("spacy.TransformerFromFile.v1")
def TransformerFromFile(get_spans: Callable) -> Model[List[Doc], TransformerData]:
# This Model needs to be loaded further by calling from_disk on the pipeline component
return Model(
"transformer",
forward,
attrs={"get_spans": get_spans},
dims={"nO": None},
)


def forward(
model: Model, docs: List[Doc], is_train: bool
) -> Tuple[FullTransformerBatch, Callable]:
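
`spacy.TransformerFromFile.v1` deliberately builds an empty shell: no layers and no `tokenizer` attr, just the `forward` function and the span getter. `Transformer.from_disk()` then fills it in, appending the `PyTorchTransformer` wrapper and setting the tokenizer after loading both from the saved directory (hence the `assert len(self.model.layers) == 0`). A hypothetical sketch of the config rewrite `from_nlp()` performs so this architecture is used on reload; the exact by-name architecture string is an assumption, since `from_nlp()` only checks for the substring:

```python
# Illustrative only: roughly what Transformer.from_nlp() does to the model
# config before the component is created, so that loading a saved pipeline
# resolves TransformerFromFile instead of re-downloading by name.
model_cfg = {
    "@architectures": "spacy.TransformerByName.v1",  # assumed name for illustration
    "name": "distilbert-base-uncased",
    "get_spans": {
        "@layers": "spacy-transformers.strided_spans.v1",
        "window": 256,
        "stride": 196,
    },
}

if "TransformerByName" in model_cfg["@architectures"]:
    model_cfg = {
        "@architectures": "spacy.TransformerFromFile.v1",
        "get_spans": model_cfg["get_spans"],
    }

# The resolved Model starts with zero layers; Transformer.from_disk() later
# appends the PyTorchTransformer wrapper and sets model.attrs["tokenizer"].
```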