Small speedups #2389

Merged (4 commits, Aug 22, 2021)
25 changes: 15 additions & 10 deletions flair/embeddings/token.py
@@ -859,6 +859,9 @@ def __init__(
# if using context, can we cross document boundaries?
self.respect_document_boundaries = respect_document_boundaries

# send self to flair-device
self.to(flair.device)

# embedding parameters
if layers == 'all':
# send mini-token through to check how many layers the model has
@@ -894,7 +897,6 @@ def __init__(

# when initializing, embeddings are in eval mode by default
self.eval()
self.to(flair.device)

@staticmethod
def _get_begin_offset_of_tokenizer(tokenizer: PreTrainedTokenizer) -> int:
@@ -925,7 +927,11 @@ def _get_processed_token_text(self, token: Token) -> str:

def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:

batch_size = len(sentences)
# we require encoded subtokenized sentences, the mapping to original tokens and the number of
# parts that each sentence produces
subtokenized_sentences = []
all_token_subtoken_lengths = []
sentence_parts_lengths = []

# if we also use context, first expand sentence to include context
if self.context_length > 0:
@@ -958,11 +964,6 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:

sentences = expanded_sentences

# if sentence is too long, will be split into multiple parts
subtokenized_sentences = []
all_token_subtoken_lengths = []
sentence_parts_lengths = []

for sentence in sentences:

# subtokenize the sentence
@@ -991,14 +992,17 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:

n_parts: int = 0

subtokenized_sentence_splits = []
if self.allow_long_sentences:
# overlong sentences are handled as multiple splits
for encoded_input in encoded_inputs['input_ids']:
subtokenized_sentences.append(torch.tensor(encoded_input, dtype=torch.long))
subtokenized_sentence_splits.append(torch.tensor(encoded_input, dtype=torch.long))
n_parts += 1
else:
subtokenized_sentences.append(torch.tensor(encoded_inputs['input_ids'], dtype=torch.long))
subtokenized_sentence_splits.append(torch.tensor(encoded_inputs['input_ids'], dtype=torch.long))
n_parts += 1

subtokenized_sentences.extend(subtokenized_sentence_splits)
sentence_parts_lengths.append(n_parts)

# find longest sentence in batch
@@ -1101,7 +1105,8 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
sentences,
context_offsets):
for token_idx, token in enumerate(original_sentence):
token.set_embedding(self.name, expanded_sentence[token_idx + context_offset].get_embedding(self.name))
token.set_embedding(self.name,
expanded_sentence[token_idx + context_offset].get_embedding(self.name))
sentence = original_sentence

def _expand_sentence_with_context(self, sentence):
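Below is a minimal sketch (illustrative only, not part of the diff) of the split-handling pattern the change above introduces: an overlong sentence may be encoded as several input-id splits, which are first collected in a per-sentence list and only then appended to the batch-wide list, so the number of parts per sentence stays easy to track.

import torch

def collect_sentence_splits(encoded_inputs, allow_long_sentences):
    # encoded_inputs mimics the tokenizer output: 'input_ids' is either one list
    # of ids (a single part) or a list of id-lists (one entry per overlong split)
    splits = []
    if allow_long_sentences:
        for input_ids in encoded_inputs["input_ids"]:
            splits.append(torch.tensor(input_ids, dtype=torch.long))
    else:
        splits.append(torch.tensor(encoded_inputs["input_ids"], dtype=torch.long))
    # return the split tensors for this sentence and the number of parts it produced
    return splits, len(splits)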
130 changes: 67 additions & 63 deletions flair/models/entity_linker_model.py
@@ -10,20 +10,22 @@

log = logging.getLogger("flair")


class EntityLinker(flair.nn.DefaultClassifier):
"""
Entity Linking Model
The model expects text/sentences with annotated entity mentions and predicts entities to these mentions.
To this end a word embedding is used to embed the sentences and the embedding of the entity mention goes through a linear layer to get the actual class label.
The model is able to predict '<unk>' for entity mentions that the model can not confidently match to any of the known labels.
"""

def __init__(
self,
word_embeddings: flair.embeddings.TokenEmbeddings,
label_dictionary: Dictionary,
pooling_operation: str = 'average',
pooling_operation: str = 'average',
label_type: str = 'nel',
**classifierargs,
**classifierargs,
):
"""
Initializes an EntityLinker
@@ -34,118 +36,120 @@ def __init__(
the embedding of the first and the embedding of the last word.
:param label_type: name of the label you use.
"""

super(EntityLinker, self).__init__(label_dictionary, **classifierargs)
self.word_embeddings = word_embeddings

self.word_embeddings = word_embeddings
self.pooling_operation = pooling_operation
self._label_type = label_type
#if we concatenate the embeddings we need double input size in our linear layer

# if we concatenate the embeddings we need double input size in our linear layer
if self.pooling_operation == 'first&last':
self.decoder = nn.Linear(
2 * self.word_embeddings.embedding_length, len(self.label_dictionary)
).to(flair.device)
2 * self.word_embeddings.embedding_length, len(self.label_dictionary)
).to(flair.device)
else:
self.decoder = nn.Linear(
self.word_embeddings.embedding_length, len(self.label_dictionary)
).to(flair.device)
self.word_embeddings.embedding_length, len(self.label_dictionary)
).to(flair.device)

nn.init.xavier_uniform_(self.decoder.weight)

cases = {
'average': self.emb_mean,
'first': self.emb_first,
'last': self.emb_last,
'first&last': self.emb_firstAndLast
}
}

if pooling_operation not in cases:
raise KeyError('pooling_operation has to be one of "average", "first", "last" or "first&last"')

self.aggregated_embedding = cases.get(pooling_operation)

self.to(flair.device)

def emb_first(self, arg):
return arg[0]

def emb_last(self, arg):
return arg[-1]

def emb_firstAndLast(self,arg):
return torch.cat((arg[0],arg[-1]),0)
def emb_firstAndLast(self, arg):
return torch.cat((arg[0], arg[-1]), 0)

def emb_mean(self, arg):
return torch.mean(arg,0)
return torch.mean(arg, 0)

def forward_pass(self,
sentences: Union[List[DataPoint], DataPoint],
return_label_candidates: bool = False,
):
if isinstance(sentences,DataPoint):

if isinstance(sentences, DataPoint):
sentences = [sentences]
#filter sentences with no annotation

# filter sentences with no candidates (no candidates means nothing can be linked anyway)
filtered_sentences = []
for sentence in sentences:
if sentence.get_labels(self.label_type):
filtered_sentences.append(sentence)

#embedd all tokens
self.word_embeddings.embed(sentences)

embedding_names = self.word_embeddings.get_names()

embedding_list = []

# fields to return
span_labels = []
sentences_to_spans = []
empty_label_candidates = []
#get the embeddings of the entity mentions
for sentence in filtered_sentences:
spans = sentence.get_spans(self.label_type)

for span in spans:
mention_emb = torch.Tensor(0,self.word_embeddings.embedding_length).to(flair.device)

for token in span.tokens:
mention_emb=torch.cat((mention_emb, token.get_embedding(embedding_names).unsqueeze(0)), 0)

embedding_list.append(self.aggregated_embedding(mention_emb).unsqueeze(0))

span_labels.append([label.value for label in span.get_labels(typename=self.label_type)])

if return_label_candidates:
sentences_to_spans.append(sentence)
candidate = SpanLabel(span=span, value=None, score=None)
empty_label_candidates.append(candidate)

if len(embedding_list) > 0:
embedding_tensor = torch.cat(embedding_list, 0).to(flair.device)

scores = self.decoder(embedding_tensor)
#No entity mention in given sentences, return None
else:
# if the entire batch has no sentence with candidates, return empty
if len(filtered_sentences) == 0:
scores = None


# otherwise, embed sentence and send through prediction head
else:
# embed all tokens
self.word_embeddings.embed(filtered_sentences)

embedding_names = self.word_embeddings.get_names()

embedding_list = []
# get the embeddings of the entity mentions
for sentence in filtered_sentences:
spans = sentence.get_spans(self.label_type)

for span in spans:
mention_emb = torch.Tensor(0, self.word_embeddings.embedding_length).to(flair.device)

for token in span.tokens:
mention_emb = torch.cat((mention_emb, token.get_embedding(embedding_names).unsqueeze(0)), 0)

embedding_list.append(self.aggregated_embedding(mention_emb).unsqueeze(0))

span_labels.append([label.value for label in span.get_labels(typename=self.label_type)])

if return_label_candidates:
sentences_to_spans.append(sentence)
candidate = SpanLabel(span=span, value=None, score=None)
empty_label_candidates.append(candidate)

embedding_tensor = torch.cat(embedding_list, 0).to(flair.device)
scores = self.decoder(embedding_tensor)

# minimal return is scores and labels
return_tuple = (scores, span_labels)

if return_label_candidates:
return_tuple += (sentences_to_spans, empty_label_candidates)

return return_tuple



def _get_state_dict(self):
model_state = {
"state_dict": self.state_dict(),
"word_embeddings": self.word_embeddings,
"label_type": self.label_type,
"label_dictionary": self.label_dictionary,
"pooling_operation": self.pooling_operation,
}
}
return model_state

@staticmethod
@@ -159,7 +163,7 @@ def _init_model_with_state_dict(state):

model.load_state_dict(state["state_dict"])
return model

@property
def label_type(self):
return self._label_type
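The main saving in forward_pass above is that embedding now runs only for sentences that actually contain annotated mentions. A minimal sketch of that filter-then-embed idea, assuming a flair TokenEmbeddings instance and the default 'nel' label type (the function name here is illustrative, not from the PR):

from typing import List
from flair.data import Sentence

def filter_linkable(sentences: List[Sentence], label_type: str = "nel") -> List[Sentence]:
    # keep only sentences carrying at least one annotation of the given label type,
    # so sentences without mentions never reach the expensive embedding step
    return [s for s in sentences if s.get_labels(label_type)]

# usage sketch:
# linkable = filter_linkable(batch)
# if linkable:
#     word_embeddings.embed(linkable)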
14 changes: 5 additions & 9 deletions tests/test_embeddings.py
@@ -1,6 +1,7 @@
import pytest
import torch

from flair.data import Sentence, Dictionary
from flair.embeddings import (
WordEmbeddings,
TokenEmbeddings,
@@ -11,8 +12,6 @@
DocumentLMEmbeddings, TransformerWordEmbeddings, TransformerDocumentEmbeddings,
DocumentCNNEmbeddings,
)

from flair.data import Sentence, Dictionary
from flair.models import LanguageModel

glove: TokenEmbeddings = WordEmbeddings("turian")
@@ -33,7 +32,6 @@ def test_load_non_existing_flair_embedding():


def test_keep_batch_order():

embeddings = DocumentRNNEmbeddings([glove])
sentences_1 = [Sentence("First sentence"), Sentence("This is second sentence")]
sentences_2 = [Sentence("This is second sentence"), Sentence("First sentence")]
@@ -50,7 +48,6 @@ def test_keep_batch_order():


def test_stacked_embeddings():

embeddings: StackedEmbeddings = StackedEmbeddings([glove, flair_embedding])

sentence: Sentence = Sentence("I love Berlin. Berlin is a great place to live.")
@@ -66,7 +63,6 @@ def test_stacked_embeddings():


def test_transformer_word_embeddings():

embeddings = TransformerWordEmbeddings('distilbert-base-uncased', layers='-1,-2,-3,-4', layer_mean=False)

sentence: Sentence = Sentence("I love Berlin")
@@ -105,7 +101,6 @@ def test_transformer_word_embeddings():


def test_transformer_weird_sentences():

embeddings = TransformerWordEmbeddings('distilbert-base-uncased', layers='all', layer_mean=True)

sentence = Sentence("Hybrid mesons , qq ̄ states with an admixture")
@@ -146,6 +141,7 @@ def test_transformer_weird_sentences():
for token in sentence_2:
assert len(token.get_embedding()) == 768


def test_fine_tunable_flair_embedding():
language_model_forward = LanguageModel(
Dictionary.load("chars"), is_forward_lm=True, hidden_size=32, nlayers=1
@@ -258,7 +254,6 @@ def test_document_pool_embeddings_nonlinear():


def test_transformer_document_embeddings():

embeddings = TransformerDocumentEmbeddings('distilbert-base-uncased')

sentence: Sentence = Sentence("I love Berlin")
@@ -289,7 +284,8 @@ def test_transformer_document_embeddings():
sentence.clear_embeddings()

del embeddings



def test_document_cnn_embeddings():
sentence: Sentence = Sentence("I love Berlin. Berlin is a great place to live.")

@@ -305,4 +301,4 @@ def test_document_cnn_embeddings():
sentence.clear_embeddings()

assert len(sentence.get_embedding()) == 0
del embeddings
del embeddings
4 changes: 2 additions & 2 deletions tests/test_models.py
@@ -2,7 +2,8 @@

import torch.optim.optimizer

import flair, pytest
import flair
import pytest
from flair.data import Sentence, Corpus
from flair.datasets import ColumnCorpus, ClassificationCorpus
from flair.embeddings import WordEmbeddings, FlairEmbeddings, StackedEmbeddings, DocumentPoolEmbeddings, \
@@ -338,7 +339,6 @@ def test_text_classifier_transformer_finetune(results_base_path, tasks_base_path)

@pytest.mark.integration
def test_text_classifier_multi(results_base_path, tasks_base_path):

flair.set_seed(123)

corpus = ClassificationCorpus(tasks_base_path / "trivial" / "trivial_text_classification_multi",