Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions flair/embeddings/token.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import hashlib
import pickle
import inspect
from abc import abstractmethod
from pathlib import Path
from typing import List, Union, Dict
from collections import Counter
from functools import lru_cache

import torch
from bpemb import BPEmb
Expand All @@ -21,7 +19,7 @@

from flair.data import Sentence, Token, Corpus, Dictionary
from flair.embeddings.base import Embeddings, ScalarMix
from flair.file_utils import cached_path, open_inside_zip
from flair.file_utils import cached_path, open_inside_zip, instance_lru_cache

log = logging.getLogger("flair")

Expand Down Expand Up @@ -212,7 +210,7 @@ def __init__(self, embeddings: str, field: str = None):
def embedding_length(self) -> int:
return self.__embedding_length

@lru_cache(maxsize=10000, typed=False)
@instance_lru_cache(maxsize=10000, typed=False)
def get_cached_vec(self, word: str) -> torch.Tensor:
if word in self.precomputed_word_embeddings:
word_embedding = self.precomputed_word_embeddings[word]
Expand Down Expand Up @@ -1222,7 +1220,7 @@ def __init__(self, embeddings: str, use_local: bool = True, field: str = None):
def embedding_length(self) -> int:
return self.__embedding_length

@lru_cache(maxsize=10000, typed=False)
@instance_lru_cache(maxsize=10000, typed=False)
def get_cached_vec(self, word: str) -> torch.Tensor:
try:
word_embedding = self.precomputed_word_embeddings[word]
Expand Down Expand Up @@ -1430,7 +1428,7 @@ def __init__(self, ):
self.language_embeddings = {}
super().__init__()

@lru_cache(maxsize=10000, typed=False)
@instance_lru_cache(maxsize=10000, typed=False)
def get_cached_vec(self, language_code: str, word: str) -> torch.Tensor:
current_embedding_model = self.language_embeddings[language_code]
if word in current_embedding_model:
Expand Down
12 changes: 12 additions & 0 deletions flair/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import shutil
import tempfile
import re
import functools
from urllib.parse import urlparse

import mmap
Expand Down Expand Up @@ -319,3 +320,14 @@ def tqdm(*args, **kwargs):
new_kwargs = {"mininterval": Tqdm.default_mininterval, **kwargs}

return _tqdm(*args, **new_kwargs)

def instance_lru_cache(*cache_args, **cache_kwargs):
    """Decorate a method so that each instance gets its own ``lru_cache``.

    The class-level attribute is a thin wrapper. When it runs for an instance
    it builds ``functools.lru_cache(*cache_args, **cache_kwargs)`` around the
    undecorated method, binds it to that instance, and stores the bound cache
    on the instance under the method's name. Later ``instance.method(...)``
    lookups then resolve to the per-instance cache directly.

    :param cache_args: positional arguments forwarded to ``functools.lru_cache``
    :param cache_kwargs: keyword arguments forwarded to ``functools.lru_cache``
        (for example ``maxsize`` and ``typed``)
    :return: a decorator to apply to a method taking ``self`` as first argument
    """

    def decorator(func):
        @functools.wraps(func)
        def create_cache(self, *args, **kwargs):
            # The wrapper can run more than once for the same instance, e.g.
            # through a bound reference captured before the first call or via
            # ``Class.method(instance, ...)``. Reuse the cache that is already
            # installed instead of replacing it with an empty one.
            instance_cache = getattr(self, "__dict__", {}).get(func.__name__)

            # Only reuse a cache that wraps *this* function. Anything else
            # stored under the same name (such as the cache of an overriding
            # method) is replaced, which is what happened before this check.
            if getattr(instance_cache, "__wrapped__", None) is not func:
                instance_cache = functools.lru_cache(*cache_args, **cache_kwargs)(func)
                # Bind to the instance so callers do not pass ``self``.
                instance_cache = instance_cache.__get__(self, self.__class__)
                # The instance attribute shadows the class-level wrapper.
                setattr(self, func.__name__, instance_cache)

            return instance_cache(*args, **kwargs)

        return create_cache

    return decorator