Set an LRU cache for word embeddings for a ~20% decrease in inference time #1084

Merged
9 commits merged on Sep 11, 2019
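The change memoizes the word-to-tensor lookup with `functools.lru_cache`, so repeated tokens skip both the vocabulary fallback logic and the numpy-to-torch conversion on every call. A minimal, self-contained sketch of the pattern (hypothetical class and names, not the exact Flair code):

```python
from functools import lru_cache

import numpy as np
import torch


class TinyWordEmbeddings:
    """Toy stand-in for a KeyedVectors-backed embedding class."""

    def __init__(self, vectors: dict, dim: int):
        self.precomputed = vectors  # word -> np.ndarray
        self.dim = dim

    @lru_cache(maxsize=10000)
    def get_cached_vec(self, word: str) -> torch.Tensor:
        # Fall back to a zero vector for out-of-vocabulary words.
        vec = self.precomputed.get(word, np.zeros(self.dim, dtype="float"))
        # The numpy -> torch conversion is cached too, which is where much
        # of the saved time comes from on repetitive text.
        return torch.tensor(vec, dtype=torch.float)


emb = TinyWordEmbeddings({"the": np.ones(3)}, dim=3)
emb.get_cached_vec("the")                # computed once
emb.get_cached_vec("the")                # served from the LRU cache
print(emb.get_cached_vec.cache_info())   # hits=1, misses=1
```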
4 changes: 2 additions & 2 deletions flair/data.py
@@ -237,7 +237,7 @@ def set_embedding(self, name: str, vector: torch.tensor):
device = flair.device
if len(self._embeddings.keys()) > 0:
device = next(iter(self._embeddings.values())).device
self._embeddings[name] = vector.to(device, non_blocking=True)
self._embeddings[name] = vector.to(device)

def to(self, device: str, pin_memory: bool = False):
for name, vector in self._embeddings.items():
@@ -638,7 +638,7 @@ def get_label_names(self) -> List[str]:
def embedding(self):
return self.get_embedding()

def set_embedding(self, name: str, vector):
def set_embedding(self, name: str, vector: torch.tensor):
device = flair.device
if len(self._embeddings.keys()) > 0:
device = next(iter(self._embeddings.values())).device
102 changes: 58 additions & 44 deletions flair/embeddings.py
@@ -3,6 +3,7 @@
import logging
from abc import abstractmethod
from collections import Counter
from functools import lru_cache
from pathlib import Path
from typing import List, Union, Dict

@@ -321,6 +322,28 @@ def __init__(self, embeddings: str, field: str = None):
def embedding_length(self) -> int:
return self.__embedding_length

@lru_cache(maxsize=10000, typed=False)
def get_cached_vec(self, word: str) -> torch.Tensor:
if word in self.precomputed_word_embeddings:
word_embedding = self.precomputed_word_embeddings[word]
elif word.lower() in self.precomputed_word_embeddings:
word_embedding = self.precomputed_word_embeddings[word.lower()]
elif re.sub(r"\d", "#", word.lower()) in self.precomputed_word_embeddings:
word_embedding = self.precomputed_word_embeddings[
re.sub(r"\d", "#", word.lower())
]
elif re.sub(r"\d", "0", word.lower()) in self.precomputed_word_embeddings:
word_embedding = self.precomputed_word_embeddings[
re.sub(r"\d", "0", word.lower())
]
else:
word_embedding = np.zeros(self.embedding_length, dtype="float")

word_embedding = torch.tensor(
word_embedding, device=flair.device, dtype=torch.float
)
return word_embedding

def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:

for i, sentence in enumerate(sentences):
@@ -332,26 +355,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
else:
word = token.get_tag(self.field).value

if word in self.precomputed_word_embeddings:
word_embedding = self.precomputed_word_embeddings[word]
elif word.lower() in self.precomputed_word_embeddings:
word_embedding = self.precomputed_word_embeddings[word.lower()]
elif (
re.sub(r"\d", "#", word.lower()) in self.precomputed_word_embeddings
):
word_embedding = self.precomputed_word_embeddings[
re.sub(r"\d", "#", word.lower())
]
elif (
re.sub(r"\d", "0", word.lower()) in self.precomputed_word_embeddings
):
word_embedding = self.precomputed_word_embeddings[
re.sub(r"\d", "0", word.lower())
]
else:
word_embedding = np.zeros(self.embedding_length, dtype="float")

word_embedding = torch.FloatTensor(word_embedding)
word_embedding = self.get_cached_vec(word=word)

token.set_embedding(self.name, word_embedding)
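
The lookup order in `get_cached_vec` above (exact word, then lowercase, then digits replaced by `#`, then digits replaced by `0`, then a zero vector) mirrors the previous inline code; only the digit normalisation may be non-obvious. For illustration:

```python
import re

word = "Sept-2019"
print(re.sub(r"\d", "#", word.lower()))  # sept-####
print(re.sub(r"\d", "0", word.lower()))  # sept-0000
```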

@@ -409,6 +413,18 @@ def __init__(self, embeddings: str, use_local: bool = True, field: str = None):
def embedding_length(self) -> int:
return self.__embedding_length

@lru_cache(maxsize=10000, typed=False)
def get_cached_vec(self, word: str) -> torch.Tensor:
try:
word_embedding = self.precomputed_word_embeddings[word]
except:
word_embedding = np.zeros(self.embedding_length, dtype="float")

word_embedding = torch.tensor(
word_embedding, device=flair.device, dtype=torch.float
)
return word_embedding

def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:

for i, sentence in enumerate(sentences):
@@ -420,12 +436,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
else:
word = token.get_tag(self.field).value

try:
word_embedding = self.precomputed_word_embeddings[word]
except:
word_embedding = np.zeros(self.embedding_length, dtype="float")

word_embedding = torch.FloatTensor(word_embedding)
word_embedding = self.get_cached_vec(word)

token.set_embedding(self.name, word_embedding)

@@ -561,6 +572,24 @@ def __init__(self,):
self.language_embeddings = {}
super().__init__()

@lru_cache(maxsize=10000, typed=False)
def get_cached_vec(self, language_code: str, word: str) -> torch.Tensor:
current_embedding_model = self.language_embeddings[language_code]
if word in current_embedding_model:
word_embedding = current_embedding_model[word]
elif word.lower() in current_embedding_model:
word_embedding = current_embedding_model[word.lower()]
elif re.sub(r"\d", "#", word.lower()) in current_embedding_model:
word_embedding = current_embedding_model[re.sub(r"\d", "#", word.lower())]
elif re.sub(r"\d", "0", word.lower()) in current_embedding_model:
word_embedding = current_embedding_model[re.sub(r"\d", "0", word.lower())]
else:
word_embedding = np.zeros(self.embedding_length, dtype="float")
word_embedding = torch.tensor(
word_embedding, device=flair.device, dtype=torch.float
)
return word_embedding

def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:

for i, sentence in enumerate(sentences):
@@ -613,31 +642,16 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
language_code
] = gensim.models.KeyedVectors.load(str(embeddings_file))

current_embedding_model = self.language_embeddings[language_code]

for token, token_idx in zip(sentence.tokens, range(len(sentence.tokens))):

if "field" not in self.__dict__ or self.field is None:
word = token.text
else:
word = token.get_tag(self.field).value

if word in current_embedding_model:
word_embedding = current_embedding_model[word]
elif word.lower() in current_embedding_model:
word_embedding = current_embedding_model[word.lower()]
elif re.sub(r"\d", "#", word.lower()) in current_embedding_model:
word_embedding = current_embedding_model[
re.sub(r"\d", "#", word.lower())
]
elif re.sub(r"\d", "0", word.lower()) in current_embedding_model:
word_embedding = current_embedding_model[
re.sub(r"\d", "0", word.lower())
]
else:
word_embedding = np.zeros(self.embedding_length, dtype="float")

word_embedding = torch.FloatTensor(word_embedding)
word_embedding = self.get_cached_vec(
language_code=language_code, word=word
)

token.set_embedding(self.name, word_embedding)

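One standard-library detail worth keeping in mind for `lru_cache` on instance methods like the ones added above (general `functools` behaviour, not something introduced by this PR): `self` is part of the cache key, so a single cache of up to `maxsize` entries is shared by all instances of the class and holds references to them until entries are evicted or the cache is cleared. A small self-contained demonstration:

```python
from functools import lru_cache


class Lookup:
    @lru_cache(maxsize=4)
    def get(self, key: str) -> str:
        return key.upper()


a, b = Lookup(), Lookup()
a.get("x")
b.get("x")
# Both calls hit the same cache object, keyed by (self, key), so even for
# the identical key this reports misses=2 and currsize=2.
print(Lookup.get.cache_info())
# Lookup.get.cache_clear() would drop all cached entries (and the references).
```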
5 changes: 2 additions & 3 deletions flair/models/sequence_tagger_model.py
@@ -458,9 +458,8 @@ def forward(self, sentences: List[Sentence]):

for s_id, sentence in enumerate(sentences):
# fill values with word embeddings
sentence_tensor[s_id][: len(sentence)] = torch.cat(
[token.get_embedding().unsqueeze(0) for token in sentence], 0
)
token_embeddings = [token.get_embedding() for token in sentence]
sentence_tensor[s_id][: len(sentence)] = torch.stack(token_embeddings)

# --------------------------------------------------------------------
# FF PART
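The sequence tagger change swaps a `torch.cat` over unsqueezed tensors for `torch.stack`, the idiomatic way to join equally shaped tensors along a new leading dimension. A quick equivalence check with illustrative shapes:

```python
import torch

token_embeddings = [torch.randn(4) for _ in range(3)]  # 3 tokens, embedding dim 4

via_cat = torch.cat([e.unsqueeze(0) for e in token_embeddings], 0)
via_stack = torch.stack(token_embeddings)

assert via_cat.shape == via_stack.shape == (3, 4)
assert torch.equal(via_cat, via_stack)
```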
6 changes: 4 additions & 2 deletions flair/visual/ner_html.py
@@ -16,7 +16,7 @@
<!DOCTYPE html>
<html lang="en">
<head>
<title>Flair</title>
<title>{title}</title>
</head>

<body style="font-size: 16px; font-family: 'Segoe UI'; padding: 4rem 2rem">{text}</body>
@@ -41,6 +41,7 @@ def split_to_spans(s: Sentence):

def render_ner_html(
sentences: Union[List[Sentence], Sentence],
title: str = "Flair",
colors={
"PER": "#F7FF53",
"ORG": "#E8902E",
@@ -53,6 +54,7 @@
) -> str:
"""
:param sentences: single sentence or list of sentences to convert to HTML
:param title: title of the HTML page
:param colors: dict where keys are tags and values are color HTML codes
:param default_color: color to use if colors parameter is missing a tag
:param wrap_page: if True method returns result of processing sentences wrapped by <html> and <body> tags, otherwise - without these tags
@@ -79,6 +81,6 @@
final_text = "".join(sentences_html)

if wrap_page:
return HTML_PAGE.format(text=final_text)
return HTML_PAGE.format(text=final_text, title=title)
else:
return final_text
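A usage sketch for the new `title` parameter (the tagger name and output path here are illustrative, not part of this PR):

```python
from flair.data import Sentence
from flair.models import SequenceTagger
from flair.visual.ner_html import render_ner_html

tagger = SequenceTagger.load("ner")
sentence = Sentence("George Washington went to Washington.")
tagger.predict(sentence)

# The page title now comes from the new keyword argument.
html = render_ner_html(sentence, title="My NER results", wrap_page=True)
with open("ner.html", "w") as f:
    f.write(html)
```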
3 changes: 2 additions & 1 deletion tests/test_visual.py
@@ -78,7 +78,8 @@ def test_html_rendering():
+ " leader in a ballot of party members and will become the next "
+ TAGGED_ENTITY.format(color="yellow", entity="UK", label="LOC")
+ " prime minister. &amp;"
)
),
title="Flair",
)

assert expected_res == actual