-
Notifications
You must be signed in to change notification settings - Fork 514
/
base.py
648 lines (482 loc) · 19.8 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
"""
Copyright 2019 Goldman Sachs.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
"""
import builtins
import copy
import datetime as dt
import logging
from abc import ABC, ABCMeta, abstractmethod
from collections import namedtuple
from dataclasses import Field, InitVar, MISSING, dataclass, field, fields, replace
from enum import EnumMeta
from functools import update_wrapper
from typing import Iterable, Mapping, Optional, Union, Tuple
import numpy as np
from dataclasses_json import config, global_config, LetterCase, dataclass_json
from dataclasses_json.core import _decode_generic, _is_supported_generic
from inflection import camelize, underscore
from gs_quant.context_base import ContextBase, ContextMeta
from gs_quant.json_convertors import encode_date_or_str, decode_date_or_str, decode_optional_date, encode_datetime, \
decode_datetime, decode_float_or_str, decode_instrument, encode_dictable, decode_quote_report, decode_quote_reports, \
decode_custom_comment, decode_custom_comments
# Module-level logger
_logger = logging.getLogger(__name__)
# Names of all Python builtins. Not referenced within this chunk; presumably used elsewhere to detect names that
# would shadow builtins — confirm.
__builtins = set(dir(builtins))
# Unbound object-level attribute accessors. Base uses these to read/write attributes while bypassing its own
# __getattr__/__setattr__ overrides (name lookup and value coercion).
__getattribute__ = object.__getattribute__
__setattr__ = object.__setattr__
# Memo of name -> snake_case name, filled lazily by _get_underscore
_rename_cache = {}
def exclude_none(o):
    """Exclusion predicate for dataclasses_json config: drop a field whose value is None."""
    if o is None:
        return True
    return False
def exclude_always(_o):
    """Exclusion predicate for dataclasses_json config that ignores the value and always drops the field."""
    excluded = True
    return excluded
def is_iterable(o, t):
    """
    Return True if ``o`` is iterable and every item it yields is an instance of ``t``.

    An empty iterable counts as True. Note that ``o`` is consumed if it is an iterator.
    """
    if not isinstance(o, Iterable):
        return False

    for item in o:
        if not isinstance(item, t):
            return False

    return True
def is_instance_or_iterable(o, t):
    """Return True if ``o`` is an instance of ``t``, or an iterable whose items are all instances of ``t``."""
    if isinstance(o, t):
        return True
    return is_iterable(o, t)
def _get_underscore(arg):
    """Return the snake_case form of ``arg`` (via inflection.underscore), memoised in _rename_cache."""
    try:
        return _rename_cache[arg]
    except KeyError:
        converted = _rename_cache[arg] = underscore(arg)
        return converted
def handle_camel_case_args(cls):
    """
    Class decorator allowing ``cls`` to be constructed with camelCase (legacy) or mapped keyword names.

    Each keyword argument that is not all upper case is converted to snake_case, then translated through
    ``cls._field_mappings()``, before being passed to the original ``__init__``.

    :raises ValueError: if both a name and its snake_case form are supplied
    """
    original_init = cls.__init__

    def wrapper(self, *args, **kwargs):
        translated = {}
        for name, val in kwargs.items():
            # All-upper-case names are left as given rather than being snake_cased
            if not name.isupper():
                snake = _get_underscore(name)
                if snake != name and snake in kwargs:
                    raise ValueError('{} and {} both specified'.format(name, snake))
                name = snake

            translated[cls._field_mappings().get(name, name)] = val

        return original_init(self, *args, **translated)

    cls.__init__ = update_wrapper(wrapper=wrapper, wrapped=original_init)
    return cls
# dataclasses_json field config whose exclude predicate omits the field from serialised output when it is None
field_metadata = config(exclude=exclude_none)
# dataclasses_json field config whose exclude predicate always omits the field (used for `name` fields)
name_metadata = config(exclude=exclude_always)
class RiskKey(namedtuple('RiskKey', ('provider', 'date', 'market', 'params', 'scenario', 'risk_measure'))):
    """
    Immutable key describing a risk calculation: provider, date, market, request parameters, scenario and
    risk measure.
    """

    def __key_with(self, risk_measure):
        # Rebuild params from csa_term, raw_results and market_behaviour, passing False as the third positional
        # argument of RiskRequestParameters (presumably the historical-diddle flag, given the property names —
        # confirm against RiskRequestParameters)
        from gs_quant.target.common import RiskRequestParameters

        original = self.params
        stripped_params = RiskRequestParameters(original.csa_term, original.raw_results, False,
                                                original.market_behaviour)
        return RiskKey(self.provider, self.date, self.market, stripped_params, self.scenario, risk_measure)

    @property
    def ex_measure(self):
        """A copy of this key with no risk measure and the third params flag forced to False."""
        return self.__key_with(None)

    @property
    def ex_historical_diddle(self):
        """A copy of this key, keeping its risk measure, with the third params flag forced to False."""
        return self.__key_with(self.risk_measure)

    @property
    def fields(self):
        """The field names of this key (the namedtuple's ``_fields``)."""
        return self._fields
class EnumBase:
    """
    Mixin for Enum classes.

    Provides case-insensitive lookup by value, pickling by value, ordering by value, and str()/repr() that are
    simply the member's value.
    """

    @classmethod
    def _missing_(cls: EnumMeta, key):
        # Called by Enum when an exact value match fails: retry ignoring case (non-str keys are stringified first)
        needle = (key if isinstance(key, str) else str(key)).lower()
        for member in cls.__members__.values():
            if member.value.lower() == needle:
                return member
        return None

    def __reduce_ex__(self, protocol):
        # Unpickle by calling the enum class with the member's value
        enum_cls = self.__class__
        return enum_cls, (self.value,)

    def __lt__(self: EnumMeta, other):
        return self.value < other.value

    def __str__(self):
        return self.value

    def __repr__(self):
        return str(self)
class HashableDict(dict):
    """
    A dict that can be hashed.

    The hash is independent of insertion order, so dicts which compare equal also hash equally. Nested dict
    values are handled recursively; all other values must themselves be hashable.
    """

    @staticmethod
    def hashables(in_dict) -> Tuple:
        """
        Flatten ``in_dict`` into a tuple of (key, value) pairs, recursively converting nested dict values.

        The result follows the dict's insertion order.
        """
        hashables = []
        for key, value in in_dict.items():
            if isinstance(value, dict):
                hashables.append((key, HashableDict.hashables(value)))
            else:
                hashables.append((key, value))

        return tuple(hashables)

    @staticmethod
    def _frozen(in_dict) -> frozenset:
        # Order-independent, hashable snapshot of in_dict (nested dicts become nested frozensets)
        return frozenset((key, HashableDict._frozen(value) if isinstance(value, dict) else value)
                         for key, value in in_dict.items())

    def __hash__(self):
        # dict equality ignores insertion order, so the hash must too (an ordered tuple would break hash/eq)
        return hash(HashableDict._frozen(self))
class DictBase(HashableDict):
    """
    A hashable dict whose keys are stored in camelCase and whose entries can also be read and written as
    attributes, using either snake_case or camelCase names.

    Subclasses may restrict the permitted names by setting ``_PROPERTIES`` to a set of snake_case property names;
    when it is empty, any name is accepted. ``None`` values are never stored.
    """
    _PROPERTIES = set()

    def __init__(self, *args, **kwargs):
        if self._PROPERTIES:
            # Accept camelCase keyword names as well as snake_case ones, consistent with __getattr__/__setattr__
            invalid_arg = next((k for k in kwargs.keys()
                                if k not in self._PROPERTIES and _get_underscore(k) not in self._PROPERTIES), None)
            if invalid_arg is not None:
                raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{invalid_arg}'")

        # Keyword keys are camelCased and None values dropped; positional args are passed to dict() unchanged
        super().__init__(*args, **{camelize(k, uppercase_first_letter=False): v for k, v in kwargs.items()
                                   if v is not None})

    def __getitem__(self, item):
        return super().__getitem__(camelize(item, uppercase_first_letter=False))

    def __setitem__(self, key, value):
        # NOTE: assigning None is a no-op; it neither stores None nor removes an existing value
        if value is not None:
            return super().__setitem__(camelize(key, uppercase_first_letter=False), value)

    def __getattr__(self, item):
        if self._PROPERTIES:
            if _get_underscore(item) in self._PROPERTIES:
                # Keys are stored camelCased, so a snake_case attribute name must be converted before lookup;
                # fall back to the name as given for keys supplied unconverted via positional args
                return self.get(camelize(item, uppercase_first_letter=False), self.get(item))
        elif item:
            # Try the camelCase key first (how keys are normally stored), then the name as given; never raise
            # KeyError from attribute access
            for key in (camelize(item, uppercase_first_letter=False), item):
                if key in self:
                    return super().__getitem__(key)

        raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{item}'")

    def __setattr__(self, key, value):
        if key in dir(self):
            # Genuine attributes (methods, class/instance attributes) are set normally rather than as dict entries
            return super().__setattr__(key, value)
        elif self._PROPERTIES and _get_underscore(key) not in self._PROPERTIES:
            raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{key}'")

        self[key] = value

    @classmethod
    def properties(cls) -> set:
        """The permitted property names (empty when unrestricted)"""
        return cls._PROPERTIES
class Base(ABC):
    """The base class for all generated classes"""

    # Lazily-built per-class caches (see _fields_by_name / _field_mappings). Each class stores its own copy under
    # the mangled names _Base__fields_by_name / _Base__field_mappings.
    __fields_by_name = None
    __field_mappings = None

    def __getattr__(self, item):
        """
        Fallback attribute lookup, only reached when normal lookup fails.

        Resolves camelCase names (legacy behaviour) and mapped names (see _field_mappings) to the underlying
        snake_case attribute before giving up.
        """
        fields_by_name = __getattribute__(self, '_fields_by_name')()

        if item.startswith('_') or item in fields_by_name:
            return __getattribute__(self, item)

        # Handle getting via camelCase names (legacy behaviour) and field mappings from disallowed names
        snake_case_item = _get_underscore(item)
        field_mappings = __getattribute__(self, '_field_mappings')()
        snake_case_item = field_mappings.get(snake_case_item, snake_case_item)

        try:
            return __getattribute__(self, snake_case_item)
        except AttributeError:
            return __getattribute__(self, item)

    def __setattr__(self, key, value):
        """
        Set an attribute, accepting camelCase (legacy) and mapped names for dataclass fields.

        Values assigned to fields are coerced via __coerce_value.

        :raises ValueError: if the field is not an init field
        """
        # Handle setting via camelCase names (legacy behaviour)
        snake_case_key = _get_underscore(key)
        snake_case_key = self._field_mappings().get(snake_case_key, snake_case_key)
        fld = self._fields_by_name().get(snake_case_key)

        if fld:
            if not fld.init:
                raise ValueError(f'{key} cannot be set')

            key = snake_case_key
            value = self.__coerce_value(fld.type, value)

        __setattr__(self, key, value)

    def __repr__(self):
        # Not every subclass has a `name` field; fall back to the default repr rather than raising
        name = getattr(self, 'name', None)
        if name is not None:
            return f'{name} ({self.__class__.__name__})'

        return super().__repr__()

    @classmethod
    def __coerce_value(cls, typ: type, value):
        """
        Convert ``value`` into the form stored on a field of type ``typ``.

        Numpy scalars and anything with ``tolist`` (e.g. numpy arrays) become native Python values. Base
        instances assigned to DictBase fields become dicts via ``to_dict()``. Otherwise, supported generic types
        are decoded by dataclasses_json, and anything else is returned unchanged.
        """
        if isinstance(value, np.generic):
            # Handle numpy scalar types
            return value.item()
        elif hasattr(value, 'tolist'):
            # tolist converts scalar or array to native python type if not already native.
            return value.tolist()
        elif typ in (DictBase, Optional[DictBase]) and isinstance(value, Base):
            return value.to_dict()

        if _is_supported_generic(typ):
            return _decode_generic(typ, value, False)
        else:
            return value

    @classmethod
    def _fields_by_name(cls) -> Mapping[str, Field]:
        """Map of field name -> dataclasses.Field for this class, cached per class"""
        if cls is Base:
            return {}

        # Read the cache from this class's own namespace only: a normal attribute lookup could find a cache
        # built for a base class, which would lack any fields this subclass adds
        fields_by_name = cls.__dict__.get('_Base__fields_by_name')
        if fields_by_name is None:
            fields_by_name = {f.name: f for f in fields(cls)}
            cls.__fields_by_name = fields_by_name

        return fields_by_name

    @classmethod
    def _field_mappings(cls) -> Mapping[str, str]:
        """
        Map of serialised-name override -> field name, cached per class.

        The map is built from fields whose dataclasses_json metadata carries a ``letter_case`` function.
        """
        if cls is Base:
            return {}

        # As for _fields_by_name, never reuse a base class's cache
        field_mappings = cls.__dict__.get('_Base__field_mappings')
        if field_mappings is None:
            field_mappings = {}
            for fld in fields(cls):
                config_fn = fld.metadata.get('dataclasses_json', {}).get('letter_case')
                if config_fn:
                    # NOTE(review): called with a placeholder argument; presumably the function is a
                    # config(field_name=...) override that ignores its argument — confirm
                    mapped_name = config_fn('field_name')
                    if mapped_name:
                        field_mappings[mapped_name] = fld.name

            cls.__field_mappings = field_mappings

        return field_mappings

    def clone(self, **kwargs):
        """
        Clone this object, overriding specified values

        :param kwargs: property names and values, e.g. swap.clone(fixed_rate=0.01)

        **Examples**

        To clone a cap with a different cap rate:

        >>> from gs_quant.instrument import IRCap
        >>> cap = IRCap('5y', 'GBP')
        >>>
        >>> new_cap = cap.clone(cap_rate=0.01)
        """
        return replace(self, **kwargs)

    @classmethod
    def properties(cls) -> set:
        """The public property names of this class (a single trailing underscore is stripped)"""
        return set(f[:-1] if f[-1] == '_' else f for f in cls._fields_by_name().keys())

    @classmethod
    def properties_init(cls) -> set:
        """The public property names of this class which can be set via the constructor (init fields)"""
        return set(f[:-1] if f[-1] == '_' else f for f, v in cls._fields_by_name().items() if v.init)

    def as_dict(self, as_camel_case: bool = False) -> dict:
        """Dictionary of the public, non-null properties and values"""
        # to_dict() converts all the values to JSON type, does camel case and name mappings
        # asdict() does not convert values or case of the keys or do name mappings
        ret = {}
        field_mappings = {v: k for k, v in self._field_mappings().items()}

        # Use _fields_by_name() so the cache is guaranteed to be populated and to belong to this class
        for key in self._fields_by_name().keys():
            value = __getattribute__(self, key)
            key = field_mappings.get(key, key)

            if value is not None:
                if as_camel_case:
                    key = camelize(key, uppercase_first_letter=False)

                ret[key] = value

        return ret

    @classmethod
    def default_instance(cls):
        """
        Construct a default instance of this type

        Every init field is passed its declared default, or None where it has none. Fields with a
        ``default_factory`` are left to the constructor so that the factory is used.
        """
        required = {}
        for f in fields(cls):
            if not f.init or f.default_factory is not MISSING:
                continue
            required[f.name] = None if f.default is MISSING else f.default

        return cls(**required)

    def from_instance(self, instance):
        """
        Copy the values from an existing instance of the same type to our self

        Only init fields are copied, and values are assigned directly (bypassing coercion).

        :param instance: from which to copy
        :raises ValueError: if instance is not of this object's type
        """
        if not isinstance(instance, type(self)):
            raise ValueError('Can only use from_instance with an object of the same type')

        for fld in fields(self.__class__):
            if fld.init:
                __setattr__(self, fld.name, __getattribute__(instance, fld.name))
@dataclass_json
@dataclass
class Priceable(Base):
    """
    Interface for objects that can be resolved, priced and have risk measures calculated.

    Every method here raises NotImplementedError; concrete subclasses provide the implementations.
    """

    def resolve(self, in_place: bool = True):
        """
        Resolve non-supplied properties of an instrument

        :param in_place: presumably whether to update this object rather than return a resolved copy; the
            semantics are defined by the implementing subclass

        **Examples**

        >>> from gs_quant.instrument import IRSwap
        >>>
        >>> swap = IRSwap('Pay', '10y', 'USD')
        >>> rate = swap.fixed_rate

        rate is None

        >>> swap.resolve()
        >>> rate = swap.fixed_rate

        rate is now the solved fixed rate
        """
        raise NotImplementedError

    def dollar_price(self):
        """
        Present value in USD

        :return: a float or a future, depending on whether the current PricingContext is async, or has been entered

        **Examples**

        >>> from gs_quant.instrument import IRCap
        >>>
        >>> cap = IRCap('1y', 'EUR')
        >>> price = cap.dollar_price()

        price is the present value in USD (a float)

        >>> cap_usd = IRCap('1y', 'USD')
        >>> cap_eur = IRCap('1y', 'EUR')
        >>>
        >>> from gs_quant.markets import PricingContext
        >>>
        >>> with PricingContext():
        >>>     price_usd_f = cap_usd.dollar_price()
        >>>     price_eur_f = cap_eur.dollar_price()
        >>>
        >>> price_usd = price_usd_f.result()
        >>> price_eur = price_eur_f.result()

        price_usd_f and price_eur_f are futures, price_usd and price_eur are floats
        """
        raise NotImplementedError

    def price(self):
        """
        Present value in local currency. Note that this is not yet supported on all instruments

        **Examples**

        >>> from gs_quant.instrument import IRSwap
        >>>
        >>> swap = IRSwap('Pay', '10y', 'EUR')
        >>> price = swap.price()

        price is the present value in EUR (a float)
        """
        raise NotImplementedError

    def calc(self, risk_measure, fn=None):
        """
        Calculate the value of the risk_measure

        :param risk_measure: the risk measure to compute, e.g. IRDelta (from gs_quant.risk)
        :param fn: a function for post-processing results
        :return: a float or dataframe, depending on whether the value is scalar or structured, or a future thereof
            (depending on how PricingContext is being used)

        **Examples**

        >>> from gs_quant.instrument import IRCap
        >>> from gs_quant.risk import IRDelta
        >>>
        >>> cap = IRCap('1y', 'USD')
        >>> delta = cap.calc(IRDelta)

        delta is a dataframe

        >>> from gs_quant.instrument import EqOption
        >>> from gs_quant.risk import EqDelta
        >>>
        >>> option = EqOption('.SPX', '3m', 'ATMF', 'Call', 'European')
        >>> delta = option.calc(EqDelta)

        delta is a float

        >>> from gs_quant.markets import PricingContext
        >>>
        >>> cap_usd = IRCap('1y', 'USD')
        >>> cap_eur = IRCap('1y', 'EUR')
        >>> with PricingContext():
        >>>     usd_delta_f = cap_usd.calc(IRDelta)
        >>>     eur_delta_f = cap_eur.calc(IRDelta)
        >>>
        >>> usd_delta = usd_delta_f.result()
        >>> eur_delta = eur_delta_f.result()

        usd_delta_f and eur_delta_f are futures, usd_delta and eur_delta are dataframes
        """
        raise NotImplementedError
class __ScenarioMeta(ABCMeta, ContextMeta):
    """
    Combined metaclass used by Scenario, which derives from both ABC (metaclass ABCMeta) and ContextBase
    (presumably metaclass ContextMeta), so that the two metaclasses do not conflict.
    """
    pass
@dataclass
class Scenario(Base, ContextBase, ABC, metaclass=__ScenarioMeta):
    """
    Base class for scenarios.

    A scenario is displayed by its name when it has one, otherwise as ``scenario_type(key:value, ...)`` built from
    its non-null properties. Ordering follows that representation.
    """

    def __lt__(self, other):
        if not isinstance(other, Scenario):
            return NotImplemented

        # repr() is the name when set and a parameter summary otherwise, so this orders named and unnamed
        # scenarios alike (comparing names directly fails when either is None)
        return repr(self) < repr(other)

    def __repr__(self):
        if self.name:
            return self.name

        params = self.as_dict()
        # Sort keys case-insensitively so the representation does not depend on field order
        rendered = ', '.join(
            f'{key}:{repr(params[key]) if isinstance(params[key], Base) else params[key]}'
            for key in sorted(params.keys(), key=lambda x: x.lower()))
        return self.scenario_type + '(' + rendered + ')'
@dataclass
class RiskMeasureParameter(Base, ABC):
    """Abstract base for risk measure parameter types; adds nothing beyond Base."""
    pass
@dataclass
class InstrumentBase(Base, ABC):
    """
    Base class for instruments.

    On top of Base it exposes a quantity, free-form metadata, and, for instruments produced by ``resolved()``,
    the original unresolved instrument and the RiskKey used for resolution. The latter three are private instance
    attributes that are only set on demand, so their properties return None until they have been set.
    """

    quantity_: InitVar[float] = field(default=1, init=False)

    @property
    @abstractmethod
    def provider(self):
        ...

    @property
    def instrument_quantity(self) -> float:
        return self.quantity_

    @property
    def resolution_key(self) -> Optional[RiskKey]:
        """The RiskKey this instrument was resolved under, or None"""
        try:
            return self.__resolution_key
        except AttributeError:
            return None

    @property
    def unresolved(self):
        """The instrument as it was before resolution, or None"""
        try:
            return self.__unresolved
        except AttributeError:
            return None

    @property
    def metadata(self):
        """Arbitrary user metadata attached to the instrument, or None"""
        try:
            return self.__metadata
        except AttributeError:
            return None

    @metadata.setter
    def metadata(self, value):
        self.__metadata = value

    def from_instance(self, instance):
        """Copy field values, resolution state and the unresolved original from ``instance``"""
        self.__resolution_key = None
        super().from_instance(instance)
        # Read via the properties: an instance that was never resolved or cloned has no private attributes
        # set, and accessing them directly would raise AttributeError
        self.__unresolved = instance.unresolved
        self.__resolution_key = instance.resolution_key

    def resolved(self, values: dict, resolution_key: RiskKey):
        """
        Create a resolved copy of this instrument.

        :param values: property values to overlay on this instrument's current camelCase values
            (presumably camelCase keys too — confirm against callers)
        :param resolution_key: the RiskKey under which the values were resolved
        :return: a new instrument built via from_dict, with the same name, remembering a shallow copy of this
            instrument as its unresolved form and the resolution key
        """
        all_values = self.as_dict(True)
        all_values.update(values)
        new_instrument = self.from_dict(all_values)
        new_instrument.name = self.name
        new_instrument.__unresolved = copy.copy(self)
        new_instrument.__resolution_key = resolution_key
        return new_instrument

    def clone(self, **kwargs):
        """Clone as Base.clone, also carrying over the unresolved original, metadata and resolution key"""
        new_instrument = super().clone(**kwargs)
        new_instrument.__unresolved = self.unresolved
        new_instrument.metadata = self.metadata
        new_instrument.__resolution_key = self.resolution_key
        return new_instrument
@dataclass
class Market(ABC):
    """
    Abstract base class for markets.

    A market's identity is ``market`` when that is truthy, otherwise ``location``. Equality and hashing use that
    identity, while ordering uses repr().
    """

    def __hash__(self):
        return hash(self.market or self.location)

    def __eq__(self, other):
        # Return NotImplemented for non-Market operands so Python falls back to its default comparison instead
        # of raising AttributeError on other.market
        if not isinstance(other, Market):
            return NotImplemented

        return (self.market or self.location) == (other.market or other.location)

    def __lt__(self, other):
        return repr(self) < repr(other)

    @property
    @abstractmethod
    def market(self):
        """The underlying market object (its to_dict() is used for serialisation)"""
        ...

    @property
    @abstractmethod
    def location(self):
        """The location, used as the market's identity when ``market`` is falsy"""
        ...

    def to_dict(self):
        """Serialise via the underlying market's to_dict()"""
        return self.market.to_dict()
class Sentinel:
    """A named placeholder value; sentinels with the same name are equal and hash equally."""

    def __init__(self, name: str):
        self.__name = name

    def __eq__(self, other):
        # Non-Sentinels have no _Sentinel__name; let Python fall back to its default comparison
        if not isinstance(other, Sentinel):
            return NotImplemented

        return self.__name == other.__name

    def __hash__(self):
        # Defining __eq__ alone would make instances unhashable; hash by name to stay consistent with __eq__
        return hash(self.__name)

    def __repr__(self):
        return f'{self.__class__.__name__}({self.__name!r})'
@dataclass
class QuoteReport(Base, ABC):
    """Abstract base for quote report types; adds nothing beyond Base. Its decoders are registered on global_config."""
    pass
@dataclass
class CustomComments(Base, ABC):
    """Abstract base for custom comment types; adds nothing beyond Base. Its decoders are registered on global_config."""
    pass
def get_enum_value(enum_type: EnumMeta, value: Union[EnumBase, str]):
    """
    Convert ``value`` to a member of ``enum_type`` where possible.

    :param enum_type: the enum class to convert to
    :param value: an enum member, a raw value, or None
    :return: None for None, ``value`` itself if it is already a member of ``enum_type``, the matching member
        otherwise, or ``value`` unchanged (with a warning logged) if it is not a valid entry
    """
    # Identity check: `value in (None,)` would invoke value.__eq__, which can misbehave for array-like values
    if value is None:
        return None

    if isinstance(value, enum_type):
        return value

    try:
        enum_value = enum_type(value)
    except ValueError:
        _logger.warning('Setting value to %s, which is not a valid entry in %s', value, enum_type)
        enum_value = value

    return enum_value
@handle_camel_case_args
@dataclass_json(letter_case=LetterCase.CAMEL)
@dataclass(unsafe_hash=True, repr=False)
class MarketDataScenario(Base):
    """
    Wraps a Scenario together with a subtract_base flag; serialised by dataclasses_json with camelCase keys and
    constructible with camelCase keyword names.
    """
    # NOTE(review): annotated as Scenario but defaults to None; Optional[Scenario] would be more accurate, but the
    # annotation drives dataclasses_json decoding, so change it with care
    scenario: Scenario = field(default=None, metadata=field_metadata)
    # Meaning not visible here (presumably whether base values are subtracted from scenario results — confirm)
    subtract_base: Optional[bool] = field(default=False, metadata=field_metadata)
    # Never serialised: name_metadata always excludes it
    name: Optional[str] = field(default=None, metadata=name_metadata)
# Register encoders/decoders on dataclasses_json's process-wide global_config. This mutates library-global state,
# so these registrations apply to every dataclasses_json class in the process, not only the ones defined here.

# Dates: plain dates encode as ISO strings; optional/union forms may also carry strings
global_config.encoders[dt.date] = dt.date.isoformat
global_config.encoders[Optional[dt.date]] = encode_date_or_str
global_config.decoders[dt.date] = decode_optional_date
global_config.decoders[Optional[dt.date]] = decode_optional_date
global_config.encoders[Union[dt.date, str]] = encode_date_or_str
global_config.encoders[Optional[Union[dt.date, str]]] = encode_date_or_str
global_config.decoders[Union[dt.date, str]] = decode_date_or_str
global_config.decoders[Optional[Union[dt.date, str]]] = decode_date_or_str
# Datetimes
global_config.encoders[dt.datetime] = encode_datetime
global_config.encoders[Optional[dt.datetime]] = encode_datetime
global_config.decoders[dt.datetime] = decode_datetime
global_config.decoders[Optional[dt.datetime]] = decode_datetime
# Values that may be either a number or a string
global_config.decoders[Union[float, str]] = decode_float_or_str
global_config.decoders[Optional[Union[float, str]]] = decode_float_or_str
# Project base types, decoded via their dedicated decoders
global_config.decoders[InstrumentBase] = decode_instrument
global_config.decoders[Optional[InstrumentBase]] = decode_instrument
global_config.decoders[QuoteReport] = decode_quote_report
global_config.decoders[Optional[Tuple[QuoteReport, ...]]] = decode_quote_reports
global_config.decoders[CustomComments] = decode_custom_comment
global_config.decoders[Optional[Tuple[CustomComments, ...]]] = decode_custom_comments
# Markets are encoded by encode_dictable (presumably via Market.to_dict() — confirm)
global_config.encoders[Market] = encode_dictable
global_config.encoders[Optional[Market]] = encode_dictable