Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce the number of concatenations for a 10% inference time reduction #1093

Merged
merged 5 commits into from Sep 13, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 9 additions & 2 deletions flair/data.py
Expand Up @@ -237,7 +237,9 @@ def set_embedding(self, name: str, vector: torch.tensor):
device = flair.device
if len(self._embeddings.keys()) > 0:
device = next(iter(self._embeddings.values())).device
self._embeddings[name] = vector.to(device)
if device != vector.device:
vector = vector.to(device)
self._embeddings[name] = vector

def to(self, device: str, pin_memory: bool = False):
for name, vector in self._embeddings.items():
Expand All @@ -257,6 +259,9 @@ def clear_embeddings(self, embedding_names: List[str] = None):
if name in self._embeddings.keys():
del self._embeddings[name]

def get_each_embedding(self) -> List[torch.Tensor]:
    """Return the individually stored embedding tensors, one per embedding name.

    Tensors are returned in sorted-key order so that callers (e.g. code that
    stacks per-embedding tensors across tokens before a single concat) see a
    deterministic, consistent ordering on every call.

    Returns:
        A list of ``torch.Tensor`` objects, ordered by embedding name.
        (Original annotation said ``torch.tensor``, but a *list* of tensors
        is what is actually produced — annotation corrected.)
    """
    # Sorted keys guarantee stable ordering regardless of insertion order.
    return [self._embeddings[name] for name in sorted(self._embeddings.keys())]

def get_embedding(self) -> torch.tensor:
embeddings = [
self._embeddings[embed] for embed in sorted(self._embeddings.keys())
Expand Down Expand Up @@ -642,7 +647,9 @@ def set_embedding(self, name: str, vector: torch.tensor):
device = flair.device
if len(self._embeddings.keys()) > 0:
device = next(iter(self._embeddings.values())).device
self._embeddings[name] = vector.to(device, non_blocking=True)
if device != vector.device:
vector = vector.to(device)
self._embeddings[name] = vector

def get_embedding(self) -> torch.tensor:
embeddings = []
Expand Down
15 changes: 12 additions & 3 deletions flair/embeddings.py
Expand Up @@ -2756,9 +2756,18 @@ def _add_embeddings_internal(self, sentences: Union[List[Sentence], Sentence]):

for s_id, sentence in enumerate(sentences):
# fill values with word embeddings
sentence_tensor[s_id][: len(sentence)] = torch.cat(
[token.get_embedding().unsqueeze(0) for token in sentence], 0
)
all_embs = list()

for index_token, token in enumerate(sentence):
embs = token.get_each_embedding()
if not all_embs:
all_embs = [list() for _ in range(len(embs))]
for index_emb, emb in enumerate(embs):
all_embs[index_emb].append(emb)

concat_word_emb = [torch.stack(embs) for embs in all_embs]
concat_sentence_emb = torch.cat(concat_word_emb, dim=1)
sentence_tensor[s_id][: len(sentence)] = concat_sentence_emb

# --------------------------------------------------------------------
# FF PART
Expand Down
15 changes: 12 additions & 3 deletions flair/models/sequence_tagger_model.py
Expand Up @@ -457,9 +457,18 @@ def forward(self, sentences: List[Sentence]):
)

for s_id, sentence in enumerate(sentences):
# fill values with word embeddings
token_embeddings = [token.get_embedding() for token in sentence]
sentence_tensor[s_id][: len(sentence)] = torch.stack(token_embeddings)
all_embs = list()

for index_token, token in enumerate(sentence):
embs = token.get_each_embedding()
if not all_embs:
all_embs = [list() for _ in range(len(embs))]
for index_emb, emb in enumerate(embs):
all_embs[index_emb].append(emb)

concat_word_emb = [torch.stack(embs) for embs in all_embs]
concat_sentence_emb = torch.cat(concat_word_emb, dim=1)
sentence_tensor[s_id][: len(sentence)] = concat_sentence_emb

# --------------------------------------------------------------------
# FF PART
Expand Down