/
context.py
168 lines (130 loc) · 5.1 KB
/
context.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import threading
import json
import yaml
from typing import Union, List
from collections import UserDict
from .registry import registered_names, get_major_component_names
class TradingContext(UserDict):
    """A configuration context that is activated with the `with` statement.

    Entering a `TradingContext` pushes it onto a per-thread stack; leaving the
    `with` block pops it again. `get_context` returns the innermost active
    context.

    This implementation is heavily borrowed from the pymc3 library and adapted
    with the design goals of TensorTrade in mind.

    Arguments:
        base: Stored under the `base` key of the mapping and of `shared`.
        instruments: Stored as given in the mapping; a single string is
            wrapped in a list before being stored in `shared`.
        config: Additional configuration. The keys `shared`, `exchanges`,
            `actions`, `rewards`, `features` and `slippage` are exposed through
            the read-only properties of the same name (each defaults to `{}`).
            Every key that is not a registered component name is also merged
            into `shared`.

    Warnings:
        Components that were initialized under different contexts may hold
        conflicting configuration. Using such components together can lead to
        unwanted behavior, so users should be warned when this happens.

    Reference:
        https://github.com/pymc-devs/pymc3/blob/master/pymc3/model.py
    """

    # One independent context stack per thread.
    contexts = threading.local()

    def __init__(self,
                 base: str = 'USD',
                 instruments: Union[str, List[str]] = 'BTC',
                 **config):
        # The mapping itself keeps `instruments` exactly as passed in; only the
        # copy placed in `shared` below is normalized to a list.
        super().__init__(
            base=base,
            instruments=instruments,
            **config
        )

        if isinstance(instruments, str):
            instruments = [instruments]

        # Query the registry once instead of once per loop iteration / key.
        names = list(registered_names())
        major_names = list(get_major_component_names())

        # NOTE(review): if a registered non-major name ever matches one of the
        # read-only properties below, this setattr raises AttributeError --
        # confirm against the registry contents.
        for name in names:
            if name not in major_names:
                setattr(self, name, config.get(name, {}))

        # Everything that is not addressed to a registered component is
        # treated as shared configuration.
        config_items = {k: v for k, v in config.items() if k not in names}

        self._exchanges = config.get('exchanges', {})
        self._actions = config.get('actions', {})
        self._rewards = config.get('rewards', {})
        self._features = config.get('features', {})
        self._slippage = config.get('slippage', {})

        # Later entries win: explicit `shared` values override base/instruments,
        # and loose config items override both.
        self._shared = {
            'base': base,
            'instruments': instruments,
            **config.get('shared', {}),
            **config_items
        }

    @property
    def shared(self) -> dict:
        """Configuration shared by all components."""
        return self._shared

    @property
    def exchanges(self) -> dict:
        """Configuration passed under the `exchanges` key."""
        return self._exchanges

    @property
    def actions(self) -> dict:
        """Configuration passed under the `actions` key."""
        return self._actions

    @property
    def rewards(self) -> dict:
        """Configuration passed under the `rewards` key."""
        return self._rewards

    @property
    def features(self) -> dict:
        """Configuration passed under the `features` key."""
        return self._features

    @property
    def slippage(self) -> dict:
        """Configuration passed under the `slippage` key."""
        return self._slippage

    def __enter__(self):
        """Adds this context to the context stack.

        This method is used by the `with` statement and pushes the
        `TradingContext` onto the current thread's context stack, where it
        becomes the context returned by `get_context`.

        Returns:
            This context, so it can be bound with `with ... as ctx`.
        """
        type(self).get_contexts().append(self)
        return self

    def __exit__(self, typ, value, traceback):
        """Removes the most recently entered context from the stack."""
        type(self).get_contexts().pop()

    @classmethod
    def get_contexts(cls):
        """Gets the current thread's context stack.

        The stack is created on first use and seeded with a default
        `TradingContext`.
        """
        if not hasattr(cls.contexts, 'stack'):
            cls.contexts.stack = [TradingContext()]
        return cls.contexts.stack

    @classmethod
    def get_context(cls):
        """Gets the deepest context on the stack."""
        return cls.get_contexts()[-1]

    @classmethod
    def from_json(cls, path: str):
        """Builds a context from the JSON document stored at `path`.

        The top-level JSON object is expanded into keyword arguments.
        """
        with open(path, "rb") as fp:
            config = json.load(fp)
        # Use `cls` so that subclasses get an instance of themselves.
        return cls(**config)

    @classmethod
    def from_yaml(cls, path: str):
        """Builds a context from the YAML document stored at `path`.

        The top-level YAML mapping is expanded into keyword arguments.
        """
        with open(path, "rb") as fp:
            # SECURITY NOTE(review): FullLoader is not meant for untrusted
            # input. If configuration files can come from outside, switch to
            # yaml.safe_load after confirming no config relies on extra tags.
            config = yaml.load(fp, Loader=yaml.FullLoader)
        # Use `cls` so that subclasses get an instance of themselves.
        return cls(**config)
class Context(UserDict):
    """A context that is injected into every instance of a class that is
    a subclass of component.

    All constructor arguments are stored as mapping entries and are also
    copied into the instance `__dict__` at construction time.

    Arguments:
        base_instrument: The exchange symbol of the instrument to store/measure value in.
        instruments: The exchange symbols of the instruments being traded.
        kwargs: Any additional configuration entries.
    """

    def __init__(self,
                 base_instrument: str = 'USD',
                 instruments: Union[str, List[str]] = 'BTC',
                 **kwargs):
        super().__init__(
            base_instrument=base_instrument,
            instruments=instruments,
            **kwargs
        )

        self._base_instrument = base_instrument
        self._instruments = instruments

        # Mirror the mapping entries as instance attributes. This is a
        # snapshot: entries added to the mapping later are not mirrored.
        self.__dict__ = {**self.__dict__, **self.data}

    @property
    def base(self):
        """The base instrument this context was created with."""
        # The value is stored as `_base_instrument`; reading `_base` (as the
        # previous code did) always raised AttributeError.
        # NOTE(review): a `base` key passed through kwargs is kept in the
        # mapping but is not consulted here -- confirm which one callers expect.
        return self._base_instrument

    @property
    def instruments(self):
        """The instruments this context was created with."""
        return self._instruments

    def __str__(self):
        """Formats the context as `<ClassName: key=value, ...>`."""
        # The class defines no `__slots__`, so render the mapping entries.
        data = ['{}={}'.format(k, v) for k, v in self.data.items()]
        return '<{}: {}>'.format(self.__class__.__name__, ', '.join(data))