/
simpler.py
135 lines (99 loc) · 3.34 KB
/
simpler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# Simpler context vars.
import threading
from typing import *
# Generic type parameters shared by the classes in this module.
T = TypeVar('T')    # arbitrary value type
KT = TypeVar('KT')  # mapping key type
VT = TypeVar('VT')  # mapping value type
# Stand-in thread state so this pure-Python prototype can run.
class ThreadState(threading.local):  # type: ignore # See https://github.com/python/typeshed/issues/1591
    """Python stand-in for CPython's PyThreadState.

    Being a threading.local subclass, each thread sees its own attribute
    values, so this implementation really is thread-local.
    """

    # Context currently active in this thread; None until something assigns it.
    ctx: Optional['Context'] = None


# The one ThreadState object; its attributes still differ per thread.
_ts = ThreadState()
def get_TS() -> ThreadState:
    """Return the module's single ThreadState object.

    One object serves every thread; because ThreadState subclasses
    threading.local, the attributes read through it are per-thread.
    """
    return _ts
# Accessors for the Context attached to the calling thread.
def get_ctx() -> 'Context':
    """Return the calling thread's Context.

    If the thread has none yet, an empty Context is created, stored on the
    thread state and returned.
    """
    state = get_TS()
    current = state.ctx
    if current is None:
        # Lazily install a fresh, empty Context for this thread.
        current = state.ctx = Context()
    return current
def _set_ctx(ctx: 'Context') -> None:
    """Install *ctx* as the calling thread's Context, replacing any previous one."""
    get_TS().ctx = ctx
# Private sentinel meaning "no value". It is distinct from None so that None
# can be a legitimate stored value or default.
_no_default: Any = object()


class Token(Generic[T]):
    """Receipt produced by ContextVar.set() and consumed by ContextVar.reset().

    It holds the value to restore on reset. The ``_no_default`` sentinel
    means the variable should be removed again.
    """

    # The ContextVar that issued this token, if one was recorded. The
    # annotation is quoted because ContextVar is defined after this class.
    # Unquoted, a class-level annotation is evaluated when the class is
    # created and raises NameError.
    _cv: 'ContextVar[T]'
    # Value to restore on reset(), or the _no_default sentinel.
    _orig: T

    def __init__(self, orig: T = _no_default) -> None:
        """Remember *orig* as the value to restore (sentinel when omitted)."""
        self._orig = orig
class ContextVar(Generic[T]):
    """Context variable.

    Values live in the calling thread's current Context (see get_ctx()),
    keyed by the ContextVar object itself.
    """

    def __init__(self, name: str, *, default: T = _no_default) -> None:
        """Create a variable called *name*.

        *default* (keyword-only) is what get() returns when the current
        Context holds no value and get() is not given its own default.
        """
        self._name = name
        self._default = default

    @property
    def name(self) -> str:
        """The name given at construction."""
        return self._name

    @property
    def default(self) -> T:
        """The construction-time default (the private sentinel if none was given)."""
        return self._default

    # Methods that take the current context into account.

    def get(self, default: T = _no_default) -> T:
        """Return the variable's value in the current Context.

        Falls back first to *default*, then to the construction-time
        default. Raises LookupError if neither is available.
        """
        ctx: 'Context' = get_ctx()
        if self in ctx:
            value: T = ctx[self]
            return value
        if default is not _no_default:
            return default
        if self._default is not _no_default:
            return self._default
        raise LookupError(self._name)

    def set(self, value: T) -> Token[T]:
        """Overwrite the value in the current Context and return a Token.

        The token records the previous value, or the sentinel if the
        variable was unset, so that reset() can undo this call.
        """
        ctx: 'Context' = get_ctx()
        # Capture the previous value BEFORE overwriting it. The sentinel
        # marks "was not set".
        orig = ctx.get(self, _no_default)
        ctx._setitem(self, value)
        token: Token[T] = Token(orig)
        token._cv = self  # remember which variable issued the token
        return token

    def reset(self, t: Token[T]) -> None:
        """Restore state as it was when set() returned *t*.

        Raises ValueError if *t* was issued by a different ContextVar. If
        the variable had no value before that set(), it is removed from the
        current Context again. The underlying dict deletion raises KeyError
        if it is no longer present.
        """
        owner = getattr(t, '_cv', None)
        if owner is not None and owner is not self:
            raise ValueError("Token was created by a different ContextVar")
        ctx = get_ctx()
        if t._orig is _no_default:
            ctx._delitem(self)
        else:
            ctx._setitem(self, t._orig)
class AbstractContext(Mapping[KT, VT]):
    """Mapping backed by a private dict.

    The public interface is that of a read-only Mapping. Contents can
    still be changed through the private _setitem()/_delitem() helpers.
    """

    def __init__(self, d: Optional[Mapping[KT, VT]] = None) -> None:
        """Start from a shallow copy of *d* (empty when *d* is omitted or None).

        Because the contents are copied, later changes to *d* do not affect
        this object, and vice versa.
        """
        # Double underscore: name-mangled to _AbstractContext__d, so
        # subclasses cannot clash with it. Maybe a weakkeydict?
        self.__d = {} if d is None else dict(d)

    def __getitem__(self, key: KT) -> VT:
        return self.__d[key]

    def _setitem(self, key: KT, value: VT) -> None:
        """Privately bind *key* to *value*."""
        self.__d[key] = value

    def _delitem(self, key: KT) -> None:
        """Privately remove *key*; raises KeyError if it is absent."""
        del self.__d[key]

    def __len__(self) -> int:
        return len(self.__d)

    def __iter__(self) -> Iterator[KT]:
        return iter(self.__d)

    def __contains__(self, key: object) -> bool:
        # Direct dict membership test. It avoids Mapping's default, which
        # goes through __getitem__ and catches KeyError.
        return key in self.__d

    # The remaining methods (get, keys, items, values, __eq__, ...) come
    # from the Mapping mixin defaults.
class Context(AbstractContext[ContextVar, Any]):
    """Mapping from ContextVar objects to their values.

    Externally this is only meant to act as an (immutable) Mapping.
    """

    def run(self, func: Callable[..., T], *args: Any, **kwds: Any) -> T:
        """Call func(*args, **kwds) with this Context installed as current.

        The previously current Context is reinstated afterwards, whether
        func returns normally or raises.
        """
        previous = get_ctx()
        try:
            _set_ctx(self)
            return func(*args, **kwds)
        finally:
            # Always put the caller's Context back.
            _set_ctx(previous)