/
semantics.py
1524 lines (1251 loc) · 54.5 KB
/
semantics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from __future__ import division
import datetime
import itertools
import math
import re
from abc import ABCMeta, abstractmethod
from collections import namedtuple
import scipy
import scipy.linalg
import six
from dateutil.relativedelta import relativedelta
from scipy import ndarray
from quantdsl.domain.model.simulated_price import make_simulated_price_id
from quantdsl.domain.services.uuids import create_uuid4
from quantdsl.exceptions import DslError, DslNameError, DslSyntaxError, DslSystemError
from quantdsl.priceprocess.base import get_duration_years
class DslObject(six.with_metaclass(ABCMeta)):
    """
    Base class for DSL language objects.

    Responsible for maintaining reference to original AST (for error reporting),
    and for rendering objects into valid DSL source code. Also has methods for
    validating object arguments, and finding child nodes of a particular type.
    """

    def __init__(self, *args, **kwds):
        """
        Validates the positional args and stores them on the object.

        The optional 'node' keyword arg is kept as self.node, and is passed
        to the DSL errors raised by the validation helpers.
        """
        self.node = kwds.pop('node', None)
        self.validate(args)
        self._args = list(args)
        self._hash = None

    def __str__(self):
        """
        Returns DSL source code, that can be parsed to generate a clone of self.
        """
        return self.pprint()

    # Todo: More tests that this round trip actually works.
    def pprint(self, indent=''):
        """Returns Quant DSL source code for the DSL object."""
        tab = 4
        num_args = len(self._args)
        # Objects with more than one arg are rendered with one arg per line.
        multiline = num_args > 1
        msg = self.__class__.__name__ + "("
        if multiline:
            msg += "\n"
        indent += ' ' * tab
        for i, arg in enumerate(self._args):
            if multiline:
                msg += indent
            if isinstance(arg, DslObject):
                msg += arg.pprint(indent)
            else:
                msg += str(arg)
            if i < num_args - 1:
                msg += ","
            if multiline:
                msg += "\n"
        # Restore the caller's indentation for the closing bracket.
        indent = indent[:-tab]
        if multiline:
            msg += indent
        msg += ")"
        return msg

    @property
    def hash(self):
        """
        Returns a hash computed from the hashes of the object's args.

        The value is computed on first access and then cached in self._hash.
        """
        if self._hash is None:
            hashes = ""
            for arg in self._args:
                # Lists are not hashable, so hash an equivalent tuple.
                if isinstance(arg, list):
                    arg = tuple(arg)
                hashes += str(hash(arg))
            self._hash = hash(hashes)
        return self._hash

    def __hash__(self):
        return self.hash

    @abstractmethod
    def validate(self, args):
        """
        Raises an exception if the object's args are not valid.
        """

    # Todo: Rework validation, perhaps by considering a declarative form in which to express the requirements.
    def assert_args_len(self, args, required_len=None, min_len=None, max_len=None):
        """
        Raises DslSyntaxError unless len(args) satisfies the given constraints.

        Each of required_len, min_len and max_len is only checked when it is
        not None.
        """
        error = "%s is broken" % self.__class__.__name__
        if min_len is not None and len(args) < min_len:
            descr = "requires at least %s arguments (%s were given)" % (min_len, len(args))
            raise DslSyntaxError(error, descr, self.node)
        if max_len is not None and len(args) > max_len:
            descr = "requires at most %s arguments (%s were given)" % (max_len, len(args))
            raise DslSyntaxError(error, descr, self.node)
        if required_len is not None and len(args) != required_len:
            descr = "requires %s arguments (%s were given)" % (required_len, len(args))
            raise DslSyntaxError(error, descr, self.node)

    def assert_args_arg(self, args, posn, required_type):
        """
        Raises DslSyntaxError unless args[posn] is an instance of required_type.

        required_type is a type, a tuple of types, or a list holding exactly
        one such item. The list form means that args[posn] must be a list,
        each item of which is then checked against the contained type.
        """
        if isinstance(required_type, list):
            # Ahem, this is a way of saying we require a list of the type (should be a list length 1).
            self.assert_args_arg(args, posn, list)
            assert len(required_type) == 1, "List def should only have one item."
            required_type = required_type[0]
            list_of_args = args[posn]
            for i in range(len(list_of_args)):
                self.assert_args_arg(list_of_args, i, required_type)
        elif not isinstance(args[posn], required_type):
            error = "%s is broken" % self.__class__.__name__
            if isinstance(required_type, (list, tuple)):
                required_type_names = ", ".join([i.__name__ for i in required_type])
            else:
                required_type_names = required_type.__name__
            desc = "argument %s must be %s" % (posn, required_type_names)
            desc += " (but a %s was found): " % (args[posn].__class__.__name__)
            desc += str(args[posn])
            raise DslSyntaxError(error, desc, self.node)

    def list_instances(self, dsl_type):
        """Returns a list of the objects yielded by find_instances()."""
        return list(self.find_instances(dsl_type))

    def has_instances(self, dsl_type):
        """Returns True if find_instances() yields at least one object."""
        # any() stops at the first yielded object, like the early return it replaces.
        return any(True for _ in self.find_instances(dsl_type))

    def find_instances(self, dsl_type):
        """
        Yields self (if it is an instance of dsl_type), followed by matching
        objects found recursively in those args that are DSL objects.

        Args that are plain lists are not descended into.
        """
        if isinstance(self, dsl_type):
            yield self
        for arg in self._args:
            if isinstance(arg, DslObject):
                for dsl_obj in arg.find_instances(dsl_type):
                    yield dsl_obj

    def reduce(self, dsl_locals, dsl_globals, effective_present_time=None, pending_call_stack=None):
        """
        Reduces by reducing all args, and then using those args
        to create a new instance of self.
        """
        new_dsl_args = []
        for dsl_arg in self._args:
            # Args that are not DSL objects are passed through unchanged.
            if isinstance(dsl_arg, DslObject):
                dsl_arg = dsl_arg.reduce(dsl_locals, dsl_globals, effective_present_time,
                                         pending_call_stack=pending_call_stack)
            new_dsl_args.append(dsl_arg)
        return self.__class__(node=self.node, *new_dsl_args)

    def identify_price_simulation_requirements(self, requirements, **kwds):
        """Passes the call on to each arg that is a DSL object."""
        for dsl_arg in self._args:
            if isinstance(dsl_arg, DslObject):
                dsl_arg.identify_price_simulation_requirements(requirements, **kwds)

    def identify_perturbation_dependencies(self, dependencies, **kwds):
        """Passes the call on to each arg that is a DSL object."""
        for dsl_arg in self._args:
            if isinstance(dsl_arg, DslObject):
                dsl_arg.identify_perturbation_dependencies(dependencies, **kwds)
class DslExpression(DslObject):
    """
    Base class for DSL objects that can be evaluated to a value.
    """

    @abstractmethod
    def evaluate(self, **kwds):
        """
        Returns the value of the expression.
        """

    def discount(self, value, date, **kwds):
        """
        Returns value multiplied by exp(-r * T).

        r is kwds['interest_rate'] divided by 100, and T is the result of
        get_duration_years() for kwds['present_time'] and the given date.
        """
        rate = float(kwds['interest_rate']) / 100
        years = get_duration_years(kwds['present_time'], date)
        discount_factor = math.exp(- rate * years)
        return value * discount_factor
class DslConstant(DslExpression):
    """
    Base class for constant expressions, which wrap a single value.

    Subclasses set required_type to the type (or tuple of types) accepted for
    the single arg, and may override parse() to convert the arg to a value.
    """
    required_type = None

    def pprint(self, indent=''):
        """Returns the repr of the constant's value."""
        return repr(self.value)

    def validate(self, args):
        """
        Checks there is exactly one arg of the required type, and that it
        can be parsed.
        """
        self.assert_args_len(args, required_len=1)
        assert self.required_type is not None, "required_type attribute not set on %s" % self.__class__
        self.assert_args_arg(args, posn=0, required_type=self.required_type)
        self.parse(args[0])

    @property
    def value(self):
        """Returns the parsed arg (parsed on first access, then cached)."""
        if not hasattr(self, '_value'):
            self._value = self.parse(self._args[0])
        return self._value

    def evaluate(self, **_):
        """Returns the constant's value; keyword args are ignored."""
        return self.value

    def parse(self, value):
        """Returns the given value unchanged (subclasses may convert it)."""
        return value

    def __eq__(self, other):
        """
        Compares by value.

        Returns NotImplemented, rather than raising AttributeError, when
        the other object has no 'value' attribute.
        """
        try:
            other_value = other.value
        except AttributeError:
            return NotImplemented
        return self.value == other_value

    def __ne__(self, other):
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        # Python 3 sets __hash__ to None on any class that defines __eq__
        # without also defining __hash__. Restore the inherited, args-based
        # hash so that constants remain hashable.
        return self.hash
class String(DslConstant):
    """Constant whose single arg must be a string."""

    required_type = six.string_types
class Number(DslConstant):
    """Constant whose single arg must be an integer, a float, or an ndarray."""

    required_type = six.integer_types + (float, ndarray)
class Date(DslConstant):
    """
    Constant date, whose value is a datetime.datetime.

    The arg can be a 'YYYY-MM-DD' style string (optionally wrapped in a String
    object), a datetime.datetime, or a datetime.date.
    """

    required_type = six.string_types + (String, datetime.date, datetime.datetime)

    def pprint(self, indent=''):
        """Returns DSL source code for the date, with a zero-padded date string."""
        value = self.value
        return "Date('%04d-%02d-%02d')" % (value.year, value.month, value.day)

    def parse(self, value):
        """
        Returns a datetime.datetime for the given value.

        Raises DslSyntaxError if a date string cannot be converted.
        """
        if isinstance(value, (six.string_types, String)):
            # Unwrap String objects to get the Python string.
            date_str = value.evaluate() if isinstance(value, String) else value
            try:
                year, month, day = [int(part) for part in date_str.split('-')]
                return datetime.datetime(year, month, day)
                # return dateutil.parser.parse(date_str).replace()
            except ValueError:
                raise DslSyntaxError("invalid date string", date_str, node=self.node)
        elif isinstance(value, datetime.datetime):
            # Checked before datetime.date, since datetime is a subclass of date.
            return value
        elif isinstance(value, datetime.date):
            return datetime.datetime(value.year, value.month, value.day)
class TimeDelta(DslConstant):
    """
    Constant time delta.

    The arg can be a String such as '3d', '2m' or '1y' (days, months or years),
    a datetime.timedelta, or a relativedelta.
    """

    required_type = (String, datetime.timedelta, relativedelta)

    def pprint(self, indent=''):
        """Returns the class name followed by the original arg in brackets."""
        return "{}({})".format(self.__class__.__name__, self._args[0])

    def parse(self, value, regex=re.compile(r'((?P<days>\d+?)d|(?P<months>\d+?)m|(?P<years>\d+?)y)?')):
        """
        Returns a relativedelta for a String value, or the value itself when
        it is already a timedelta or relativedelta.

        The pattern is matched at the start of the string only. Raises
        DslSyntaxError if no days, months or years value was matched.
        """
        if isinstance(value, (datetime.timedelta, relativedelta)):
            return value
        if not isinstance(value, String):
            raise DslSystemError("shouldn't get here", value, node=self.node)

        duration_str = value.evaluate()
        groups = regex.match(duration_str).groupdict()
        # Keep only the named groups that actually matched something.
        params = dict((name, int(text)) for (name, text) in six.iteritems(groups) if text)
        if not params:
            raise DslSyntaxError('invalid "time delta" string', duration_str, node=self.node)
        return relativedelta(**params)
class UnaryOp(DslExpression):
    """
    Base class for operators that take a single operand expression.
    """

    opchar = None

    def pprint(self, indent=''):
        """Returns the operator character immediately followed by the operand."""
        return "".join([str(self.opchar), str(self.operand)])

    def validate(self, args):
        """Requires exactly one arg, which must be a DSL expression."""
        self.assert_args_len(args, required_len=1)
        self.assert_args_arg(args, posn=0, required_type=DslExpression)

    @property
    def operand(self):
        """The expression being operated on."""
        return self._args[0]

    def evaluate(self, **kwds):
        """Evaluates the operand, and returns the result of applying op() to it."""
        operand_value = self.operand.evaluate(**kwds)
        return self.op(operand_value)

    @abstractmethod
    def op(self, value):
        """
        Returns the result of operating on the given value.
        """
class UnarySub(UnaryOp):
    """Unary minus."""

    opchar = '-'

    def op(self, value):
        """Returns the negation of the given value."""
        return -value
class BoolOp(DslExpression):
    """
    Base class for boolean operators over a list of expressions.

    Evaluation stops at the first expression for which op(value) is true.
    """

    def validate(self, args):
        """Requires exactly one arg, which must be a list."""
        self.assert_args_len(args, required_len=1)
        self.assert_args_arg(args, posn=0, required_type=list)

    @property
    def values(self):
        """The list of operand expressions."""
        return self._args[0]

    def evaluate(self, **kwds):
        """
        Returns op(True) if op() is true for the value of any operand,
        otherwise op(False). Operands are evaluated in order, lazily.
        """
        assert len(self.values) >= 2
        # any() short-circuits, so later operands are not evaluated after a hit.
        was_decided = any(self.op(dsl_expr.evaluate(**kwds)) for dsl_expr in self.values)
        return self.op(was_decided)

    @abstractmethod
    def op(self, value):
        """
        Returns value, or not value, according to implementation.
        """

    def pprint(self, indent=''):
        """Returns the operands joined by the lower-cased class name, in brackets."""
        separator = ' ' + self.__class__.__name__.lower() + ' '
        operands = [str(operand) for operand in self._args[0]]
        return indent + '(' + separator.join(operands) + ')'
class Or(BoolOp):
    """Boolean 'or': decided by the first operand whose value is true."""

    def op(self, value):
        """Returns the value unchanged."""
        return value
class And(BoolOp):
    """Boolean 'and': decided by the first operand whose value is false."""

    def op(self, value):
        """Returns the logical negation of the value."""
        return not value
class BinOp(DslExpression):
    """
    Base class for operators that take a left and a right operand.

    Subclasses with an opchar are rendered in infix form; others are
    rendered like a function call.
    """

    opchar = ''

    @abstractmethod
    def op(self, left, right):
        """
        Returns result of operating on two args.
        """

    def pprint(self, indent=''):
        """Returns DSL source code for the operation."""
        if not self.opchar:
            return indent + '%s(%s, %s)' % (self.__class__.__name__, self.left, self.right)

        def render(operand):
            # Nested binary operations are wrapped in brackets.
            rendered = str(operand)
            if isinstance(operand, BinOp):
                rendered = "(" + rendered + ")"
            return rendered

        return indent + render(self.left) + " " + self.opchar + " " + render(self.right)

    def validate(self, args):
        """Requires exactly two args, each of an accepted operand type."""
        operand_types = (DslExpression, Date, TimeDelta, Underlying)
        self.assert_args_len(args, required_len=2)
        self.assert_args_arg(args, posn=0, required_type=operand_types)
        self.assert_args_arg(args, posn=1, required_type=operand_types)

    @property
    def left(self):
        """The left hand operand."""
        return self._args[0]

    @property
    def right(self):
        """The right hand operand."""
        return self._args[1]

    def evaluate(self, **kwds):
        """
        Evaluates both operands (left first) and returns op(left, right).

        A TypeError raised by op() is converted into a DslSyntaxError.
        """
        left_value = self.left.evaluate(**kwds)
        right_value = self.right.evaluate(**kwds)
        try:
            return self.op(left_value, right_value)
        except TypeError as e:
            error = "unable to %s" % self.__class__.__name__.lower()
            descr = "%s %s: %s" % (left_value, right_value, e)
            raise DslSyntaxError(error, descr, node=self.node)
class Add(BinOp):
    """Infix addition."""

    opchar = '+'

    def op(self, left, right):
        """Returns the sum of left and right."""
        return left + right
class Sub(BinOp):
    """Infix subtraction."""

    opchar = '-'

    def op(self, left, right):
        """Returns left minus right."""
        return left - right
class Mult(BinOp):
    """Infix multiplication."""

    opchar = '*'

    def op(self, left, right):
        """Returns the product of left and right."""
        return left * right
class Div(BinOp):
    """Infix division, using the '/' operator."""

    opchar = '/'

    def op(self, left, right):
        """Returns left divided by right."""
        return left / right
# Todo: Pow, Mod, FloorDiv don't have proofs, so shouldn't really be used for combining random variables? Either
# prevent usage with ndarray inputs, or do the proofs. :-)
class Pow(BinOp):
    """Infix exponentiation."""

    opchar = '**'

    def op(self, left, right):
        """Returns left raised to the power of right."""
        return left ** right
class Mod(BinOp):
    """Infix modulo."""

    opchar = '%'

    def op(self, left, right):
        """Returns left modulo right."""
        return left % right
class FloorDiv(BinOp):
    """Infix floor division."""

    opchar = '//'

    def op(self, left, right):
        """Returns the floor division of left by right."""
        return left // right
class NonInfixedBinOp(BinOp):
    """
    Base class for binary operations that are rendered like a function call,
    and which work on scalar numbers, vectors, or a mix of both.
    """

    def op(self, a, b):
        """
        Returns scalar_op(a, b) when both args are an int or a float,
        otherwise vector_op(a, b).

        When only one arg is a number, it is first repeated into an array
        with the length of the other arg. Two vectors must have equal length.
        """
        # Assume a and b have EITHER type ndarray, OR type int or float.
        a_is_scalar = isinstance(a, (int, float))
        b_is_scalar = isinstance(b, (int, float))

        if a_is_scalar and b_is_scalar:
            # Neither are vectors.
            return self.scalar_op(a, b)

        if a_is_scalar:
            # Only b is a vector, so stretch a to the same length.
            # Todo: Optimise with scipy.zeros() when a equals zero?
            a = scipy.array([a] * len(b))
        elif b_is_scalar:
            # Only a is a vector, so stretch b to the same length.
            # Todo: Optimise with scipy.zeros() when b equals zero?
            b = scipy.array([b] * len(a))
        else:
            # Both are vectors.
            msg = "Vectors have different length: %s and %s" % (len(a), len(b))
            assert len(a) == len(b), msg
        return self.vector_op(a, b)

    @abstractmethod
    def vector_op(self, a, b):
        """Computes result of operation on vector values."""

    @abstractmethod
    def scalar_op(self, a, b):
        """Computes result of operation on scalar values."""
class Min(NonInfixedBinOp):
    """Minimum of two values."""

    def vector_op(self, a, b):
        """Returns the minimum along axis 0 of the array made from a and b."""
        stacked = scipy.array([a, b])
        return stacked.min(axis=0)

    def scalar_op(self, a, b):
        """Returns the smaller of the two numbers."""
        return min(a, b)
class Max(NonInfixedBinOp):
    """Maximum of two values."""

    def vector_op(self, a, b):
        """Returns the maximum along axis 0 of the array made from a and b."""
        stacked = scipy.array([a, b])
        return stacked.max(axis=0)

    def scalar_op(self, a, b):
        """Returns the larger of the two numbers."""
        return max(a, b)
class Name(DslExpression):
    """
    A name, which is looked up in a namespace when evaluated or reduced.
    """

    def pprint(self, indent=''):
        """Returns the name itself."""
        return self.name

    def validate(self, args):
        """Asserts that the first arg is a Python string or a String object."""
        first_arg = args[0]
        assert isinstance(first_arg, (six.string_types, String)), type(first_arg)

    @property
    def name(self):
        """
        Returns a Python string.
        """
        raw_name = self._args[0]
        if isinstance(raw_name, six.string_types):
            return raw_name
        elif isinstance(raw_name, String):
            return raw_name.evaluate()

    def reduce(self, dsl_locals, dsl_globals, effective_present_time=None, pending_call_stack=False):
        """
        Replaces the name with the named value found in the namespaces.

        Locals take precedence over globals. DSL objects are returned as they
        are; dates, numbers, strings and time deltas are wrapped in the
        corresponding DSL constant. Raises DslSyntaxError for other values.
        """
        # Locals are chained last, so they override globals with the same name.
        combined_namespace = DslNamespace(itertools.chain(dsl_globals.items(), dsl_locals.items()))
        value = self.evaluate(**combined_namespace)

        if isinstance(value, datetime.date):
            return Date(value, node=self.node)
        if isinstance(value, DslObject):
            return value
        if isinstance(value, six.integer_types + (float, ndarray)):
            return Number(value, node=self.node)
        if isinstance(value, six.string_types):
            return String(value, node=self.node)
        if isinstance(value, (datetime.timedelta, relativedelta)):
            return TimeDelta(value, node=self.node)
        # elif isinstance(value, (SynchronizedArray, Synchronized)):
        #     return Number(numpy_from_sharedmem(value), node=self.node)
        raise DslSyntaxError(
            "expected number, string, date, time delta, or DSL object when reducing name '%s'" % self.name,
            repr(value),
            node=self.node
        )

    def evaluate(self, **kwds):
        """
        Returns the value in kwds keyed by the name.

        Raises DslNameError if the name is not one of the keywords.
        """
        name = self.name
        try:
            return kwds[name]
        except KeyError:
            raise DslNameError(
                "'%s' is not defined. Current frame defines" % name,
                kwds.keys() or "None",
                node=self.node
            )
class Stub(Name):
    """
    Stubs are named values. Stubs are used to associate a value in a stubbed expression
    with the value of another expression in a dependency graph.
    """

    def pprint(self, indent=''):
        """Returns DSL source code for the stub, with the name as a quoted string."""
        # Can't just return a Python string, like with Names, because this
        # is normally a UUID, and UUIDs are not valid Python variable names
        # because they have dashes and sometimes start with numbers.
        return "Stub('%s')" % (self.name,)
class Underlying(DslObject):
    """
    Wrapper around a single arg, which is returned as it is when evaluated.
    """

    def validate(self, args):
        """Requires exactly one arg."""
        self.assert_args_len(args, required_len=1)

    @property
    def expr(self):
        """The wrapped arg."""
        return self._args[0]

    def evaluate(self, **_):
        """Returns the wrapped arg, without evaluating it; keyword args are ignored."""
        return self.expr
class FunctionDef(DslObject):
    """
    A DSL function def creates DSL expressions when called. They can be defined as
    simple or conditionally recursive functions. Loops aren't supported, neither
    are assignments.

    The four args are: the function name, the list of call arg definitions,
    the body, and the list of decorator names.
    """

    def pprint(self, indent=''):
        """Returns DSL source code for the function def, including decorators."""
        msg = ""
        for decorator_name in self.decorator_names:
            msg += "@" + decorator_name + "\n"
        msg += "def %s(%s):\n" % (self.name, ", ".join(self.call_arg_names))
        if isinstance(self.body, DslObject):
            try:
                msg += self.body.pprint(indent=indent + ' ')
            except TypeError:
                raise DslSystemError("DSL object can't handle indent: %s" % type(self.body))
        else:
            msg += str(self.body)
        return msg

    def __init__(self, *args, **kwds):
        super(FunctionDef, self).__init__(*args, **kwds)
        # Initialise the function call cache for this function def.
        self.call_cache = {}
        self.enclosed_namespace = DslNamespace()
        # Second attempt to implement module namespaces...
        self.module_namespace = None

    def validate(self, args):
        """Requires exactly four args."""
        self.assert_args_len(args, required_len=4)

    @property
    def name(self):
        """The name of the function."""
        return self._args[0]

    @property
    def call_arg_names(self):
        """The names of the call arg definitions (computed once, then cached)."""
        if not hasattr(self, '_call_arg_names'):
            self._call_arg_names = [i.name for i in self._args[1]]
        return self._call_arg_names

    @property
    def callArgs(self):
        """The list of call arg definitions."""
        return self._args[1]

    @property
    def body(self):
        """The body of the function."""
        return self._args[2]

    @property
    def decorator_names(self):
        """The names of the decorators applied to the function."""
        return self._args[3]

    def validateCallArgs(self, dsl_locals):
        """Raises DslSyntaxError if any call arg name is missing from dsl_locals."""
        for call_arg_name in self.call_arg_names:
            if call_arg_name not in dsl_locals:
                raise DslSyntaxError('expected call arg not found',
                                     "arg '%s' not in call arg namespace %s" % (call_arg_name, dsl_locals.keys()))

    def apply(self, dsl_globals=None, effective_present_time=None, pending_call_stack=None, is_destacking=False,
              **dsl_locals):
        """
        Calls the function with the given call arg values (dsl_locals).

        Returns either the reduced body expression, or, when there is a
        pending call stack and the function is not decorated with 'inline',
        a Stub that names the result of a call put on the pending call stack.

        Unless is_destacking is true, results are cached by call arg values
        and effective present time, and repeated calls return the cached
        result.
        """
        # It's a function call, so create a new namespace "context".
        if dsl_globals is None:
            dsl_globals = DslNamespace()
        # The module namespace is None until something assigns it, so treat None as empty.
        module_namespace = self.module_namespace if self.module_namespace is not None else {}
        dsl_globals = DslNamespace(itertools.chain(self.enclosed_namespace.items(), module_namespace.items(),
                                                   dsl_globals.items()))
        dsl_locals = DslNamespace(dsl_locals)

        # Validate the call args with the definition.
        self.validateCallArgs(dsl_locals)

        # Create the cache key. The effective present time is part of the key,
        # because it is passed on both to stacked calls and to the reduction
        # of the body, so calls that differ only by it must not share a result.
        call_cache_key_dict = dsl_locals.copy()
        call_cache_key_dict["__effective_present_time__"] = effective_present_time
        call_cache_key = self.create_hash(call_cache_key_dict)

        # Check the call cache, to see whether this function has already been called with these args.
        if not is_destacking and call_cache_key in self.call_cache:
            return self.call_cache[call_cache_key]

        if pending_call_stack and not is_destacking and 'inline' not in self.decorator_names:
            # Just stack the call expression and return a stub.
            # Create a new stub - the stub ID is the name of the return value of the function call..
            stub_id = create_uuid4()
            dsl_stub = Stub(stub_id, node=self.node)

            # Put the function call on the call stack, with the stub ID.
            # assert isinstance(pending_call_stack, PendingCallQueue)
            pending_call_stack.put(
                stub_id=stub_id,
                stacked_function_def=self,
                stacked_locals=dsl_locals.copy(),
                stacked_globals=dsl_globals.copy(),
                effective_present_time=effective_present_time
            )
            # Return the stub so that the containing DSL can be fully evaluated
            # once the stacked function call has been evaluated.
            dsl_expr = dsl_stub
        else:
            # Todo: Make sure the expression can be selected with the dsl_locals?
            # - ie the conditional expressions should be functions only of call arg
            # values that can be fully evaluated without evaluating contractual DSL objects.
            selected_expression = self.select_expression(self.body, dsl_locals)

            # Add this function to the dslNamespace (just in case it's called by itself).
            new_dsl_globals = DslNamespace(dsl_globals)
            new_dsl_globals[self.name] = self

            # Reduce the selected expression.
            # assert isinstance(selected_expression, DslExpression)
            dsl_expr = selected_expression.reduce(
                dsl_locals=dsl_locals,
                dsl_globals=new_dsl_globals,
                effective_present_time=effective_present_time,
                pending_call_stack=pending_call_stack
            )

        # Cache the result.
        if not is_destacking:
            self.call_cache[call_cache_key] = dsl_expr

        return dsl_expr

    def select_expression(self, dsl_expr, call_arg_namespace):
        """
        Returns the expression selected by evaluating any 'if' tests.

        If the DSL expression is an instance of BaseIf, the test is evaluated
        with the call arg namespace, and the body or orelse expression is
        selected accordingly. This repeats with the selected expression
        (supports if-elif-elif-else). Otherwise the DSL expression itself is
        returned.
        """
        if isinstance(dsl_expr, BaseIf):
            # Todo: Implement a check that this test expression can be evaluated? Or handle case when it can't?
            # Todo: Also allow user defined functions that just do dates or numbers in test expression.
            # it doesn't have or expand into DSL elements that are the functions of time (Wait, Choice, Market, etc).
            if dsl_expr.test.evaluate(**call_arg_namespace):
                selected = dsl_expr.body
            else:
                selected = dsl_expr.orelse
            selected = self.select_expression(selected, call_arg_namespace)
        else:
            selected = dsl_expr
        return selected

    def create_hash(self, obj):
        """
        Returns a hash for the given object, for use in call cache keys.

        Supports None, numbers, strings, dates, time deltas, DSL objects, and
        dicts and lists of those. Raises DslSystemError for other types.
        """
        if obj is None:
            # The effective present time is None when it hasn't been set.
            return hash(None)
        if isinstance(obj, relativedelta):
            # Relative deltas are hashed via their repr.
            return hash(repr(obj))
        if isinstance(obj, six.integer_types + six.string_types + (
                float, datetime.datetime, datetime.date, datetime.timedelta)):
            return hash(obj)
        if isinstance(obj, dict):
            return hash(tuple(sorted([(a, self.create_hash(b)) for a, b in obj.items()])))
        if isinstance(obj, list):
            return hash(tuple(sorted([self.create_hash(a) for a in obj])))
        elif isinstance(obj, DslObject):
            return hash(obj)
        else:
            raise DslSystemError("Can't create hash from obj type '%s'" % type(obj), obj,
                                 node=obj.node if isinstance(obj, DslObject) else None)
class FunctionCall(DslExpression):
    """
    A call to a named function, with a list of call arg expressions.
    """

    def pprint(self, indent=''):
        """Returns DSL source code for the function call."""
        rendered_args = ", ".join([str(arg) for arg in self.callArgExprs])
        return indent + "%s(%s)" % (self.functionDefName, rendered_args)

    def validate(self, args):
        """Requires exactly two args: a Name and a list."""
        self.assert_args_len(args, required_len=2)
        self.assert_args_arg(args, posn=0, required_type=Name)
        self.assert_args_arg(args, posn=1, required_type=list)

    @property
    def functionDefName(self):
        """The Name object that refers to the called function."""
        return self._args[0]

    @property
    def callArgExprs(self):
        """The list of call arg expressions."""
        return self._args[1]

    def reduce(self, dsl_locals, dsl_globals, effective_present_time=None, pending_call_stack=False):
        """
        Reduces function call to result of evaluating function def with function call args.

        Raises DslSyntaxError if the number of call arg expressions differs
        from the number of call args of the function def.
        """
        # Function def name (a Name object) should reduce to a FunctionDef object in the namespace.
        # - it's an error for the name to be defined as anything other than a function, but that's not possible here?
        # assert isinstance(function_def, FunctionDef)
        function_def = self.functionDefName.reduce(dsl_locals, dsl_globals, effective_present_time,
                                                   pending_call_stack=pending_call_stack)

        # Check lengths of arg names matches length of arg exprs (function signature must
        # satisfy the call). Or the other way around :).
        num_expected = len(function_def.callArgs)
        num_received = len(self.callArgExprs)
        if num_expected != num_received:
            raise DslSyntaxError(
                "mismatched call args",
                "expected %s but got %s. Expected args: %s. Received exprs: %s" % (
                    num_expected,
                    num_received,
                    function_def.call_arg_names,
                    self.callArgExprs,
                ),
                node=self.node
            )

        # Obtain the call arg values, in a new call arg namespace.
        new_dsl_locals = DslNamespace()
        for arg_expr, arg_def in zip(self.callArgExprs, function_def.callArgs):
            # Todo: Think about and improve the way these levels are separated.
            if isinstance(arg_expr, DslObject):
                # Substitute names, etc.
                reduced_expr = arg_expr.reduce(dsl_locals, dsl_globals, effective_present_time,
                                               pending_call_stack=pending_call_stack)
                # Decide whether to evaluate, or just pass the expression into the function call.
                if isinstance(reduced_expr, Underlying):
                    # It's explicitly wrapped as an "underlying", so unwrap it as expected.
                    arg_value = reduced_expr.evaluate()
                elif reduced_expr.has_instances((Market, Fixing, Choice, Settlement, FunctionDef, Stub)):
                    # It's an underlying contract, or a stub. In any case, can't evaluate here, so pass it through.
                    arg_value = reduced_expr
                else:
                    # assert isinstance(reduced_expr, DslExpression)
                    # It's a sum of two constants, or something like that - evaluate the full expression.
                    arg_value = reduced_expr.evaluate()
            else:
                # It's a simple value - pass through, not much else to do.
                arg_value = arg_expr
            new_dsl_locals[arg_def.name] = arg_value

        # Evaluate the function def with the dict of call arg values.
        # The result of this function call (stubbed or otherwise) should be a DSL expression.
        # assert isinstance(dsl_expr, DslExpression)
        return function_def.apply(dsl_globals, effective_present_time, pending_call_stack=pending_call_stack,
                                  is_destacking=False, **new_dsl_locals)

    def evaluate(self, **kwds):
        """Always raises DslSyntaxError, naming the called function."""
        raise DslSyntaxError('call to undefined name', self.functionDefName.name, node=self.node)
class FunctionArg(DslObject):
    """
    A call arg definition of a function def, having a name and a type name.
    """

    def validate(self, args):
        """Requires exactly two args."""
        self.assert_args_len(args, required_len=2)

    @property
    def name(self):
        """The first arg: the name of the call arg."""
        arg_name = self._args[0]
        return arg_name

    @property
    def dsl_typeName(self):
        """The second arg: the type name of the call arg."""
        type_name = self._args[1]
        return type_name
class BaseIf(DslExpression):
    """
    Base class for conditional expressions, having a test, a body, and an
    alternative (orelse).
    """

    def validate(self, args):
        """Requires exactly three args, each of which must be a DSL expression."""
        self.assert_args_len(args, required_len=3)
        for posn in (0, 1, 2):
            self.assert_args_arg(args, posn=posn, required_type=DslExpression)

    @property
    def test(self):
        """The condition expression."""
        return self._args[0]

    @property
    def body(self):
        """The expression used when the test is true."""
        return self._args[1]

    @property
    def orelse(self):
        """The expression used when the test is false."""
        return self._args[2]

    def evaluate(self, **kwds):
        """
        Evaluates the test, and returns the value of either the body or the
        orelse expression.

        Raises DslSyntaxError if the test evaluates to a DSL object.
        """
        test_result = self.test.evaluate(**kwds)
        if isinstance(test_result, DslObject):
            raise DslSyntaxError("If test condition result cannot be a DSL object", str(test_result), node=self.node)
        selected = self.body if test_result else self.orelse
        return selected.evaluate(**kwds)
class If(BaseIf):
    """
    An if-elif-else statement. Each 'elif' is an If object held as the
    orelse of the preceding one.
    """

    def pprint(self, indent=''):
        """Returns DSL source code for the whole if-elif-else statement."""
        lines = [
            "\n",
            indent + "if %s:\n" % self.test,
            indent + " %s\n" % self.body,
            self.orelse_to_str(self.orelse, indent),
        ]
        return "".join(lines)

    def orelse_to_str(self, orelse, indent):
        """Returns source code for the 'elif' clauses and the final 'else' clause."""
        lines = []
        # Walk down the "linked list" of alternatives...
        while isinstance(orelse, If):
            lines.append(indent + "elif %s:\n" % orelse.test)
            lines.append(indent + " %s\n" % orelse.body)
            orelse = orelse.orelse
        # ...until we reach the final alternative.
        lines.append(indent + "else:\n")
        lines.append(indent + " %s\n" % orelse)
        return "".join(lines)
class IfExp(If):
    """
    Special case of If, where if-else clause is one expression (no elif support).
    """

    def pprint(self, indent=''):
        """Returns source code in the form 'body if test else orelse'."""
        rendered = "%s if %s else %s" % (self.body, self.test, self.orelse)
        return indent + rendered
class Compare(DslExpression):
    """
    A (possibly chained) comparison, such as 'a < b <= c'.

    The three args are: the left-most expression, the list of operator
    names, and the list of comparator expressions.
    """

    # Maps operator names to functions that do the comparison.
    valid_ops = {
        'Eq': lambda a, b: a == b,
        'NotEq': lambda a, b: a != b,
        'Lt': lambda a, b: a < b,
        'LtE': lambda a, b: a <= b,
        'Gt': lambda a, b: a > b,
        'GtE': lambda a, b: a >= b,
    }

    # Maps operator names to the symbols used in source code.
    opcodes = {
        'Eq': '==',
        'NotEq': '!=',
        'Lt': '<',
        'LtE': '<=',
        'Gt': '>',
        'GtE': '>=',
    }

    def pprint(self, indent=''):
        """Returns DSL source code for the comparison."""
        clauses = []
        for (op, right) in zip(self.op_names, self.comparators):
            clauses.append(str(self.opcodes[op]) + ' ' + str(right))
        return indent + str(self.left) + ' ' + " ".join(clauses)

    def validate(self, args):
        """
        Requires three args: an expression (or Date), a list of supported
        operator names, and a list.
        """
        self.assert_args_len(args, required_len=3)
        # , Date, Number, String, int, float, six.string_types, datetime.datetime))
        self.assert_args_arg(args, 0, required_type=(DslExpression, Date))
        self.assert_args_arg(args, 1, required_type=list)
        self.assert_args_arg(args, 2, required_type=list)
        for op_name in args[1]:
            if op_name not in self.valid_ops:
                raise DslSyntaxError("Op name '%s' not supported" % op_name)

    @property
    def left(self):
        """The left-most expression."""
        return self._args[0]

    @property
    def op_names(self):
        """The list of operator names."""
        return self._args[1]

    @property
    def comparators(self):
        """The list of expressions to the right of each operator."""
        return self._args[2]

    def evaluate(self, **kwds):
        """
        Returns True if every comparison in the chain holds, otherwise False.

        Comparators are evaluated in order, stopping at the first comparison
        that does not hold.
        """
        previous = self.left.evaluate(**kwds)
        for i, op_name in enumerate(self.op_names):
            current = self.comparators[i].evaluate(**kwds)
            compare = self.valid_ops[op_name]
            if not compare(previous, current):
                return False
            # The right hand side becomes the left hand side of the next comparison.
            previous = current
        return True
class Module(DslObject):
    """
    A DSL module has a body, which is a list of DSL statements either
    function defs or expressions.
    """

    def __init__(self, *args, **kwds):
        super(Module, self).__init__(*args, **kwds)

    def pprint(self, indent=''):
        """Returns the source code of each statement in the body, one per line."""
        statements = [str(statement) for statement in self.body]
        return indent + "\n".join(statements)

    def validate(self, args):
        """
        Requires two args: a list of function defs, expressions or dates,
        and a DslNamespace.
        """
        self.assert_args_len(args, required_len=2)
        self.assert_args_arg(args, 0, [(FunctionDef, DslExpression, Date)])
        self.assert_args_arg(args, 1, DslNamespace)

    @property
    def body(self):
        """The list of statements in the module."""
        return self._args[0]

    @property
    def namespace(self):
        """The namespace of the module."""
        return self._args[1]
def inline(*args):
    """
    Dummy 'inline' Quant DSL decorator - we just want the name in the namespace.

    Returns the Mock class, whatever args are given. The third-party 'mock'
    package is used when it is installed, otherwise the standard library's
    'unittest.mock' is used, so that 'mock' is not a hard requirement.
    """
    try:
        import mock
    except ImportError:
        from unittest import mock
    return mock.Mock
class DslNamespace(dict):
    """
    A dict of named values, whose copies are of the same class.
    """

    def copy(self):
        """Returns a shallow copy that is an instance of this object's own class."""
        return type(self)(self)
class StochasticObject(DslObject):
@abstractmethod
def validate(self, args):
"""
Returns value of stochastic object.