-
Notifications
You must be signed in to change notification settings - Fork 245
/
lazy_entity.py
67 lines (54 loc) · 2.12 KB
/
lazy_entity.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import typing
from threading import Lock
from flytekit import FlyteContext
from flytekit.remote.remote_callable import RemoteEntity
# Type variable for the concrete remote entity type a LazyEntity stands in for. The bound
# restricts it (for static type checkers) to RemoteEntity and its subclasses.
T = typing.TypeVar("T", bound=RemoteEntity)
class LazyEntity(RemoteEntity, typing.Generic[T]):
    """
    Fetches the entity when the entity is called or when the entity is retrieved.
    The entity is derived from RemoteEntity so that it behaves exactly like the mimicked entity.

    Nothing is fetched at construction time. The first access to ``entity`` (directly, via
    ``compile``, via ``__call__``, or via any attribute forwarded by ``__getattr__``) invokes
    ``getter`` and caches its result.
    """

    def __init__(self, name: str, getter: typing.Callable[[], T], *args, **kwargs):
        """
        :param name: Name reported by ``name`` / ``__str__`` and used in error messages. Reading it
            does not trigger a fetch.
        :param getter: Zero-argument callable that fetches and returns the real entity.
        :param args: Positional arguments forwarded to ``RemoteEntity.__init__``.
        :param kwargs: Keyword arguments forwarded to ``RemoteEntity.__init__``.
        :raises ValueError: If ``getter`` is falsy (e.g. ``None``).
        """
        super().__init__(*args, **kwargs)
        self._entity = None
        self._getter = getter
        self._name = name
        if not self._getter:
            raise ValueError("getter method is required to create a Lazy loadable Remote Entity.")
        self._mutex = Lock()

    @property
    def name(self) -> str:
        """The name given at construction; available without fetching the entity."""
        return self._name

    def entity_fetched(self) -> bool:
        """Return True if the entity has already been fetched and cached. Never triggers a fetch."""
        with self._mutex:
            return self._entity is not None

    @property
    def entity(self) -> T:
        """
        If not already fetched / available, then the entity will be force fetched.

        The check-and-fetch runs while holding ``self._mutex``, so concurrent accesses are
        serialized, and once ``getter`` has returned a non-None value it is not called again.
        If ``getter`` returns None, nothing is cached and it is called again on the next access.
        ``self._mutex`` is a plain (non-reentrant) ``Lock``, so ``getter`` must not itself read
        this object's ``entity``.

        :raises RuntimeError: If ``getter`` raises AttributeError (the original is chained as
            ``__cause__``).
        """
        with self._mutex:
            if self._entity is None:
                try:
                    self._entity = self._getter()
                except AttributeError as e:
                    # An AttributeError escaping a property makes Python fall back to
                    # __getattr__, which would hide the real failure; surface it as RuntimeError.
                    raise RuntimeError(
                        f"Error downloading the entity {self._name}, (check original exception...)"
                    ) from e
            return self._entity

    def __getattr__(self, item: str) -> typing.Any:
        """
        Forwards all other attributes to entity, causing the entity to be fetched!

        Python only calls this when normal attribute lookup fails. ``entity`` and the wrapper's
        own private fields are never forwarded. On a fully initialised instance they are always
        found by normal lookup, so reaching this method for them means the instance is not fully
        initialised. That happens for an instance created via ``__new__`` by ``copy``/``pickle``,
        or for an attribute read made before ``__init__`` assigned the fields
        (``super().__init__`` runs first). Forwarding them would re-enter the ``entity`` property,
        which reads them again and recurses until RecursionError. Raising AttributeError instead
        lets ``hasattr``/``getattr(..., default)`` behave normally.
        """
        if item in ("entity", "_entity", "_getter", "_name", "_mutex"):
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
        return getattr(self.entity, item)

    def compile(self, ctx: FlyteContext, *args, **kwargs):
        """Delegate to the underlying entity's ``compile``, fetching the entity first if needed."""
        return self.entity.compile(ctx, *args, **kwargs)

    def __call__(self, *args, **kwargs):
        """
        Forwards the call to the underlying entity. The entity will be fetched if not already present
        """
        return self.entity(*args, **kwargs)

    def __repr__(self) -> str:
        # Same text as __str__; does not trigger a fetch.
        return str(self)

    def __str__(self) -> str:
        return f"Promise for entity [{self._name}]"