-
Notifications
You must be signed in to change notification settings - Fork 125
/
symbolic_convert.py
1196 lines (1046 loc) · 40.8 KB
/
symbolic_convert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import collections
import dataclasses
import dis
import functools
import importlib
import inspect
import itertools
import logging
import operator
import sys
import traceback
import types
import typing
from typing import Any
from typing import Dict
from typing import List
from unittest.mock import patch
import torchdynamo.side_effects
import torchdynamo.variables.base
from torchdynamo.source import AttrSource
from torchdynamo.source import GetItemSource
from torchdynamo.source import GlobalSource
from torchdynamo.source import LocalSource
from torchdynamo.variables.builder import VariableBuilder
from . import config
from . import skipfiles
from .allowed_functions import is_allowed
from .allowed_functions import is_builtin
from .bytecode_analysis import livevars_analysis
from .bytecode_transformation import Instruction
from .bytecode_transformation import cleaned_instructions
from .bytecode_transformation import create_instruction
from .bytecode_transformation import is_generator
from .bytecode_transformation import unique_id
from .codegen import PyCodegen
from .exc import RestartAnalysis
from .exc import Unsupported
from .exc import unimplemented
from .output_graph import OutputGraph
from .resume_execution import ContinueExecutionCache
from .resume_execution import ReenterWith
from .utils import counters
from .utils import istype
from .variables.base import MutableLocal
from .variables.base import VariableTracker
from .variables.base import typestr
from .variables.builtin import BuiltinVariable
from .variables.constant import ConstantVariable
from .variables.dicts import ConstDictVariable
from .variables.functions import BaseUserFunctionVariable
from .variables.functions import NestedUserFunctionVariable
from .variables.functions import UserFunctionVariable
from .variables.lists import BaseListVariable
from .variables.lists import ListIteratorVariable
from .variables.lists import ListVariable
from .variables.lists import SliceVariable
from .variables.lists import TupleVariable
from .variables.misc import ClosureVariable
from .variables.misc import ContextManagerVariable
from .variables.misc import GetAttrVariable
from .variables.misc import PythonModuleVariable
from .variables.misc import UnknownVariable
from .variables.misc import WithExitFunctionVariable
from .variables.nn_module import NNModuleVariable
from .variables.tensor import TensorVariable
from .variables.torch import TorchVariable
from .variables.user_defined import UserDefinedVariable
# Module-level logger, named after this module via __name__.
log = logging.getLogger(__name__)
@dataclasses.dataclass
class BlockStackEntry:
    """One entry of the symbolic block stack (loop / finally / ``with`` blocks).

    ``target`` is the jump target of the SETUP_* instruction that opened the
    block.  ``stack_index`` and ``with_context`` are only filled in for
    ``with`` blocks opened by a top-level (non-inlining) translator; only such
    entries can be re-entered after a graph break.
    """

    target: Instruction
    # Stack depth recorded when the block was set up; None when not recorded.
    stack_index: typing.Optional[int] = None
    # Context manager of a ``with`` block; None for every other kind of block.
    with_context: typing.Optional[ContextManagerVariable] = None

    def can_restore(self):
        """Return True if this block carries a context manager we can re-enter."""
        return self.with_context is not None

    def resume_fn(self):
        """Return a ``ReenterWith`` for this block's recorded stack depth."""
        assert self.stack_index is not None
        return ReenterWith(self.stack_index)

    def exit(self, tx):
        """Exit the wrapped context manager on behalf of translator ``tx``."""
        return self.with_context.exit(tx)
def stack_op(fn: typing.Callable):
    """Build an opcode handler that applies ``fn`` to operands on the stack.

    The handler pops as many operands as ``fn`` declares parameters, evaluates
    ``fn`` symbolically through a ``BuiltinVariable`` and pushes the result.
    """
    arity = len(inspect.signature(fn).parameters)
    builtin = BuiltinVariable(fn)

    @functools.wraps(fn)
    def impl(self: "InstructionTranslatorBase", inst: Instruction):
        operands = self.popn(arity)
        result = builtin.call_function(self, operands, {})
        self.push(result)

    return impl
def generic_jump(truth_fn: typing.Callable, push: bool):
    """Build a handler for a conditional-jump opcode.

    Args:
        truth_fn: applied to the popped condition; the jump is taken when it
            returns a truthy value (e.g. ``operator.truth``/``operator.not_``).
        push: True for the ``JUMP_IF_*_OR_POP`` family, which leave the
            condition on the stack when the jump is taken.

    The handler folds constant conditions and conditions whose length is
    known.  For a tensor condition, it may instead compile the graph so far
    and emit the real jump with a resume function for each branch.
    Anything else is reported via ``unimplemented``.
    """

    def inner(self: "InstructionTranslatorBase", inst: "Instruction"):
        value: VariableTracker = self.pop()
        self.output.guards.update(value.guards)
        if value.is_python_constant():
            if truth_fn(value.as_python_constant()):
                if push:
                    self.push(value)
                self.jump(inst)
        elif isinstance(value, TensorVariable) and self.should_compile_partial_graph():
            # compile a partial subgraph prefix then jump into user code
            self.push(value)
            self.output.compile_subgraph(self)
            self.pop()

            # Fall-through path: the condition is consumed in every variant.
            if_next = self.create_call_resume_at(self.next_instruction)
            # Jump path: *_OR_POP variants keep the condition on the stack.
            if push:
                self.push(value)
            if_jump = self.create_call_resume_at(inst.target)

            self.output.add_output_instructions(
                [create_instruction(inst.opname, target=if_jump[0])]
                + if_next
                + if_jump
            )
        elif value.has_unpack_var_sequence(self):
            # Containers with a known length: truthiness is len() != 0.
            if truth_fn(len(value.unpack_var_sequence(self))):
                if push:
                    self.push(value)
                self.jump(inst)
        else:
            unimplemented(f"generic_jump {typestr(value)}")

    return inner
def break_graph_if_unsupported(inner_fn):
    """Decorate an opcode handler so that ``Unsupported`` causes a graph break.

    When the wrapped handler raises ``Unsupported`` and the translator allows
    partial graphs, the translator state is rolled back to just before the
    instruction.  The graph traced so far is compiled, the original
    instruction is emitted verbatim, and a call that resumes tracing at the
    next instruction is appended.  Otherwise the exception propagates.
    """

    def emit_graph_break(tx, inst, checkpoint):
        # Rewind to the state captured before the handler ran.
        tx.restore_graphstate(checkpoint)
        tx.output.compile_subgraph(tx)
        # The instruction is assumed to push exactly one value, so its net
        # stack effect tells how many inputs it consumes.
        consumed = 1 - dis.stack_effect(inst.opcode, inst.arg)
        tx.popn(consumed)
        tx.output.add_output_instructions([inst])
        tx.push(UnknownVariable())
        tx.output.add_output_instructions(
            tx.create_call_resume_at(tx.next_instruction)
        )

    @functools.wraps(inner_fn)
    def wrapper(self: "InstructionTranslatorBase", inst: Instruction):
        checkpoint = self.copy_graphstate()
        try:
            return inner_fn(self, inst)
        except Unsupported as unsupported:
            if not self.should_compile_partial_graph():
                raise
            # Record this as a graph break rather than a hard failure.
            unsupported.remove_from_stats()
            unsupported.add_to_stats("graph_break")
            emit_graph_break(self, inst, checkpoint)

    return wrapper
class InstructionTranslatorBase(object):
    """Symbolically execute a list of bytecode instructions.

    Each opcode is handled by a method (or class attribute) named after it,
    e.g. ``LOAD_FAST``.  Handlers manipulate ``self.stack`` and
    ``self.symbolic_locals`` (``VariableTracker`` objects rather than real
    values) and record traced work on ``self.output`` (an ``OutputGraph``).

    ``should_compile_partial_graph`` and ``create_call_resume_at`` are called
    here but not defined in this class; subclasses must provide them.
    """

    def cell_and_freevars(self):
        """Return ``co_cellvars`` followed by ``co_freevars``, cached on first use."""
        if not hasattr(self, "_cell_and_freevars"):
            self._cell_and_freevars = tuple(
                self.code_options["co_cellvars"] or []
            ) + tuple(self.code_options["co_freevars"] or [])
        return self._cell_and_freevars

    def prune_dead_locals(self):
        """Drop symbolic locals not read after the current instruction.

        Cell and free variables are always kept.  Afterwards the side-effect
        tracker is given a chance to prune as well
        (``side_effects.prune_dead_object_new``).
        """
        reads = livevars_analysis(self.instructions, self.current_instruction)
        # implicit use by super()
        # reads = reads | {"__class__"}
        # output variables?
        reads = reads | set(self.cell_and_freevars())
        self.symbolic_locals = collections.OrderedDict(
            [(k, v) for k, v in self.symbolic_locals.items() if k in reads]
        )
        self.output.side_effects.prune_dead_object_new(self)

    def call_function(
        self,
        fn: VariableTracker,
        args: List[VariableTracker],
        kwargs: Dict[str, VariableTracker],
    ):
        """Call ``fn(*args, **kwargs)`` symbolically and push the result."""
        assert isinstance(fn, VariableTracker)
        assert isinstance(args, list)
        assert isinstance(kwargs, dict)
        assert all(
            isinstance(x, VariableTracker)
            for x in itertools.chain(args, kwargs.values())
        )
        self.push(fn.call_function(self, args, kwargs))

    def replace_all(self, oldvar: VariableTracker, newvar: VariableTracker):
        """Replace every reference to ``oldvar`` with ``newvar``.

        References are matched by identity of ``mutable_local``, across the
        side-effect tracker, the stack and the symbolic locals.

        Returns:
            The variable actually installed.  This is either the result of
            ``side_effects.mutation`` or a clone of ``newvar`` with a fresh
            ``MutableLocal``.
        """
        if isinstance(
            oldvar.mutable_local, torchdynamo.side_effects.MutableSideEffects
        ):
            newvar = self.output.side_effects.mutation(oldvar, newvar)
        else:
            assert isinstance(
                oldvar.mutable_local, torchdynamo.variables.base.MutableLocal
            )
            newvar = newvar.clone(
                mutable_local=torchdynamo.variables.base.MutableLocal()
            )

        def repl(v: VariableTracker):
            if v.mutable_local is oldvar.mutable_local:
                return newvar
            return v

        self.output.side_effects.apply(repl)
        self.stack = [VariableTracker.apply(repl, x) for x in self.stack]
        for k, x in self.symbolic_locals.items():
            self.symbolic_locals[k] = VariableTracker.apply(repl, x)
        return newvar

    def inline_user_function_return(self, fn, args, kwargs):
        """
        A call to some user defined function by inlining it.

        On any exception the graph state is restored before re-raising.
        """
        state = self.copy_graphstate()
        try:
            result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
            self.output.guards.update(fn.guards)
            return result
        except Exception:
            self.restore_graphstate(state)
            raise

    def step(self):
        """Process exactly one instruction; return False if we should exit.

        If the handler raises ``Unsupported`` and a checkpoint exists, the
        graph is compiled from that checkpoint.  The original instructions
        are then emitted behind a jump to the checkpointed instruction, and
        None is returned (falsy, which stops ``run``).
        """
        inst = self.instructions[self.instruction_pointer]
        self.current_instruction = inst
        self.instruction_pointer += 1
        if self.instruction_pointer < len(self.instructions):
            self.next_instruction = self.instructions[self.instruction_pointer]
        else:
            self.instruction_pointer = None
            self.next_instruction = None
        if inst.starts_line:
            self.lineno = inst.starts_line

        # An empty stack is a safe point to fall back to on failure.
        if len(self.stack) == 0 and self.should_compile_partial_graph():
            self.checkpoint = inst, self.copy_graphstate()

        if config.trace:
            print("TRACE", inst.opname, inst.argval, self.stack)

        try:
            if not hasattr(self, inst.opname):
                unimplemented(f"missing: {inst.opname}")
            getattr(self, inst.opname)(inst)
            return inst.opname != "RETURN_VALUE"
        except Unsupported as exc:
            exc.real_stack.append(self.frame_summary())
            if not self.checkpoint:
                raise

        # generate code from checkpoint
        assert not self.output.output_instructions
        continue_inst, state = self.checkpoint
        self.restore_graphstate(state)
        self.output.compile_subgraph(self, partial_convert=True)
        self.output.add_output_instructions(
            [create_instruction("JUMP_ABSOLUTE", target=continue_inst)]
            + self.instructions
        )

    def run(self):
        """Step through instructions until the code ends or a step signals exit.

        Also stops when ``output.should_exit`` is set.  ``Unsupported`` and
        ``RestartAnalysis`` propagate unchanged.  Any other exception is
        reported to stderr with its location, then re-raised.
        """
        try:
            while (
                self.instruction_pointer is not None
                and not self.output.should_exit
                and self.step()
            ):
                pass
        except (Unsupported, RestartAnalysis):
            raise
        except Exception as e:
            sys.stderr.write(
                f"ERROR FROM offset={self.current_instruction.offset} "
                f"filename {self.code_options.get('co_filename')} "
                f"{self.lineno} {typestr(e)}\n"
            )
            raise

    # ------------------------------------------------------------------
    # Stack primitives.  None is used as a sentinel by LOAD_METHOD and the
    # finally/with handlers, so it is the only non-VariableTracker allowed.
    # ------------------------------------------------------------------

    def push(self, val):
        assert val is None or isinstance(
            val, VariableTracker
        ), f"push expects VariableTracker, got {typestr(val)}"
        self.stack.append(val)

    def push_many(self, vals: List[VariableTracker]):
        for val in vals:
            self.push(val)

    def pop(self) -> VariableTracker:
        return self.stack.pop()

    def popn(self, n: int) -> List[VariableTracker]:
        """Pop ``n`` values and return them in stack order (deepest first)."""
        assert n >= 0
        return list(reversed([self.pop() for _ in range(n)]))

    # ------------------------------------------------------------------
    # Loads / stores
    # ------------------------------------------------------------------

    def LOAD_FAST(self, inst):
        name = inst.argval

        if name.startswith(".") and name not in self.symbolic_locals:
            # This happens in dict/list comprehensions
            name = name.replace(".", "implicit")
        assert name not in self.cell_and_freevars()
        if name not in self.symbolic_locals:
            unimplemented("undefined LOAD_FAST")
        self.push(self.symbolic_locals[name])
        # "___stack*" locals are read exactly once, so drop them after loading.
        if name.startswith("___stack"):
            self.symbolic_locals.pop(name)

    def LOAD_DEREF(self, inst):
        assert inst.argval in self.cell_and_freevars()
        if inst.argval not in self.symbolic_locals:
            unimplemented(f"undefined LOAD_DEREF {inst.argval}")
        self.push(self.symbolic_locals[inst.argval])

    def STORE_FAST(self, inst):
        self.symbolic_locals[inst.argval] = self.pop()

    def DELETE_FAST(self, inst):
        del self.symbolic_locals[inst.argval]

    STORE_DEREF = STORE_FAST

    def LOAD_CLOSURE(self, inst):
        self.push(ClosureVariable(name=inst.argval))

    def LOAD_CONST(self, inst):
        self.push(ConstantVariable(value=inst.argval))

    def LOAD_GLOBAL(self, inst):
        """Load a global, falling back to builtins if it is not in f_globals.

        Chooses a guard source that can be read from the root frame's globals.
        """
        try:
            value = self.f_globals[inst.argval]
        except KeyError:
            return self.load_builtin(inst)
        if self.output.root_globals is self.f_globals:
            source = GlobalSource(inst.argval)
        else:
            if "__name__" in self.f_globals:
                source = AttrSource(
                    self.import_source(self.f_globals["__name__"]), inst.argval
                )
            else:
                # No module name: expose the globals dict itself under a
                # unique name in the root globals.
                name = f"___unnamed_scope_{id(self.f_globals)}"
                if name not in self.output.root_globals:
                    self.output.install_global(name, self.f_globals)
                source = GetItemSource(GlobalSource(name), inst.argval)
        self.push(VariableBuilder(self, source)(value))

    def import_source(self, module_name):
        """Create an alias to a module for use in guards"""
        value = importlib.import_module(module_name)
        alias = f"__import_{module_name.replace('.', '_dot_')}"
        f_globals = self.output.root_globals
        assert alias not in f_globals or f_globals[alias] is value
        f_globals[alias] = value
        self.output.update_co_names(alias)
        return GlobalSource(alias)

    def IMPORT_NAME(self, inst):
        """Handle ``import a.b.c`` and ``from a.b import c``.

        This mirrors ``__import__``.  With an empty fromlist, ``a.b.c`` is
        imported and the top-level package ``a`` is pushed.  With a
        non-empty fromlist, the named module ``a.b`` itself is pushed so
        that IMPORT_FROM reads attributes from the right module.
        """
        level, fromlist = self.popn(2)
        if level.as_python_constant() != 0:
            unimplemented("IMPORT_NAME with level")

        if fromlist.as_python_constant():
            module_name = inst.argval
        else:
            # Make sure the full dotted path is imported, as __import__ does.
            importlib.import_module(inst.argval)
            module_name = inst.argval.split(".")[0]
        value = importlib.import_module(module_name)
        source = self.import_source(module_name)

        if is_allowed(value):
            self.push(TorchVariable(value, source=source))
        elif istype(value, types.ModuleType):
            self.push(PythonModuleVariable(value, source=source))
        else:
            unimplemented(f"IMPORT_NAME {typestr(value)}")

    def IMPORT_FROM(self, inst):
        # Keep the module on the stack and push module.<argval> above it.
        self.DUP_TOP(inst)
        self.LOAD_ATTR(inst)

    def load_builtin(self, inst):
        assert inst.argval in self.f_builtins
        val = self.f_builtins[inst.argval]
        assert is_builtin(val)
        self.push(VariableBuilder(self, GlobalSource(inst.argval))(val))

    # ------------------------------------------------------------------
    # Control flow
    # ------------------------------------------------------------------

    def jump(self, inst):
        self.instruction_pointer = self.indexof[id(inst.target)]

    JUMP_FORWARD = jump
    JUMP_ABSOLUTE = jump

    POP_JUMP_IF_FALSE = generic_jump(operator.not_, False)
    POP_JUMP_IF_TRUE = generic_jump(operator.truth, False)
    JUMP_IF_FALSE_OR_POP = generic_jump(operator.not_, True)
    JUMP_IF_TRUE_OR_POP = generic_jump(operator.truth, True)

    def SETUP_LOOP(self, inst):
        # only exists in python<=3.7
        self.block_stack.append(BlockStackEntry(inst.target))

    def POP_BLOCK(self, inst):
        self.block_stack.pop()

    def SETUP_WITH(self, inst):
        """Enter a ``with`` block whose context manager is on the stack.

        Pushes an exit-function placeholder and then the result of entering
        the context.
        """
        ctx = self.pop()
        if not isinstance(ctx, ContextManagerVariable):
            unimplemented(f"SETUP_WITH {ctx}")
        self.output.guards.update(ctx.guards)

        if isinstance(self, InstructionTranslator):
            self.block_stack.append(BlockStackEntry(inst.target, len(self.stack), ctx))
        else:
            # can't restore this while inlining
            self.block_stack.append(BlockStackEntry(inst.target))
        self.push(
            WithExitFunctionVariable(
                ctx,
                inst.target,
                **VariableTracker.propagate(ctx),
            )
        )
        self.push(ctx.enter(self))

    def SETUP_FINALLY(self, inst):
        self.block_stack.append(BlockStackEntry(inst.target))

    def BEGIN_FINALLY(self, inst):
        self.push(None)

    def WITH_CLEANUP_START(self, inst):
        # Only the no-exception path is supported: call __exit__(None, None, None).
        exit_fn, exc = self.popn(2)
        assert exc is None
        self.push(exc)
        self.push(exit_fn.call_function(self, [ConstantVariable(None)] * 3, {}))

    def WITH_CLEANUP_FINISH(self, inst):
        self.popn(2)
        self.push(None)

    def END_FINALLY(self, inst):
        assert self.pop() is None

    def FOR_ITER(self, inst):
        it = self.pop()
        if isinstance(it, ListIteratorVariable):
            self.output.guards.update(it.guards)
            try:
                val, next_iter = it.next_variables()
                # NOTE(review): replace_all may install a clone of next_iter
                # (fresh MutableLocal); the un-cloned next_iter is pushed here.
                self.replace_all(it, next_iter)
                self.push(next_iter)
                self.push(val)
            except StopIteration:
                self.jump(inst)
        else:
            unimplemented(f"FOR_ITER {typestr(it)}")

    def COMPARE_OP(self, inst):
        """Handle comparisons.

        Cases, in order: ``<non-None> is None`` folding, tensor comparisons,
        constant folding, and ``in`` / ``not in`` via ``__contains__``.
        """
        left, right = self.popn(2)
        options = VariableTracker.propagate([left, right])
        op = inst.argval
        supported_is_const = {
            "is": operator.is_,
            "is not": operator.is_not,
            "==": operator.eq,
            "!=": operator.ne,
        }
        supported_tensors = {
            ">": operator.gt,
            "<": operator.lt,
            ">=": operator.ge,
            "<=": operator.le,
            "==": operator.eq,
            "!=": operator.ne,
        }
        supported_any = dict(
            itertools.chain(supported_tensors.items(), supported_is_const.items())
        )
        if (
            isinstance(
                left,
                (
                    TensorVariable,
                    NNModuleVariable,
                    BaseListVariable,
                    UserDefinedVariable,
                    BaseUserFunctionVariable,
                ),
            )
            and isinstance(right, ConstantVariable)
            and right.value is None
            and op in supported_is_const
        ):
            # <non-None> is None
            self.push(
                ConstantVariable(
                    supported_is_const[op](object(), right.value), **options
                )
            )
        elif (
            isinstance(left, TensorVariable) or isinstance(right, TensorVariable)
        ) and op in supported_tensors:
            self.push(
                TensorVariable.create(
                    self,
                    supported_tensors[op](left.as_proxy(), right.as_proxy()),
                    **options,
                )
            )
        elif (
            left.is_python_constant()
            and right.is_python_constant()
            and op in supported_any
        ):
            # constant fold
            self.push(
                ConstantVariable(
                    supported_any[op](
                        left.as_python_constant(), right.as_python_constant()
                    ),
                    **options,
                )
            )
        elif op in ("in", "not in"):
            self.push(right.call_method(self, "__contains__", [left], {}))
            if op == "not in":
                self.UNARY_NOT(inst)
        else:
            unimplemented(f"COMPARE_OP {typestr(left)} {op} {typestr(right)}")

    def GET_ITER(self, inst):
        self.call_function(BuiltinVariable(iter), [self.pop()], {})

    # ------------------------------------------------------------------
    # Calls and attributes
    # ------------------------------------------------------------------

    @break_graph_if_unsupported
    def CALL_FUNCTION(self, inst):
        args = self.popn(inst.argval)
        fn = self.pop()
        self.call_function(fn, args, {})

    @break_graph_if_unsupported
    def CALL_FUNCTION_EX(self, inst):
        """Handle ``f(*args)`` (argval 0) and ``f(*args, **kwargs)`` (argval 1)."""
        if inst.argval == 0:
            kwargsvars = ConstDictVariable({})
            argsvars = self.pop()
        elif inst.argval == 1:
            kwargsvars = self.pop()
            argsvars = self.pop()
        else:
            unimplemented("CALL_FUNCTION_EX")
        fn = self.pop()
        self.output.guards.update(argsvars.guards)
        self.output.guards.update(kwargsvars.guards)

        if (
            isinstance(fn, GetAttrVariable)
            and isinstance(fn.obj, TensorVariable)
            and fn.name == "view"
            and isinstance(argsvars, (ConstantVariable, TensorVariable))
        ):
            # Hack to handle special case in some bert models.  Converts
            # x.view(*shape) into x.view(shape), which is correct for view()
            # but not generally.  See test_transpose_for_scores().
            argsvars = TupleVariable([argsvars])

        if not isinstance(
            argsvars, BaseListVariable
        ) and argsvars.has_unpack_var_sequence(self):
            argsvars = TupleVariable(argsvars.unpack_var_sequence(self))

        if not isinstance(argsvars, BaseListVariable) or not isinstance(
            kwargsvars, ConstDictVariable
        ):
            unimplemented(f"non-static call {typestr(argsvars)} {typestr(kwargsvars)}")

        self.call_function(fn, argsvars.items, kwargsvars.items)

    @break_graph_if_unsupported
    def CALL_FUNCTION_KW(self, inst):
        """Call with keyword arguments.

        The names of the trailing keyword arguments arrive as a constant
        tuple on top of the stack.
        """
        argnames = self.pop()
        args = self.popn(inst.argval)
        fn = self.pop()
        assert isinstance(argnames, ConstantVariable)
        argnames = argnames.value
        # Explicit split index: `args[:-len(argnames)]` misbehaves when no
        # keyword names are given (-0 == 0).
        split = len(args) - len(argnames)
        args, kwargs = args[:split], args[split:]
        kwargs = dict(zip(argnames, kwargs))
        assert len(kwargs) == len(argnames)
        self.call_function(fn, args, kwargs)

    def LOAD_METHOD(self, inst):
        # Modelled as a plain attribute load plus a None placeholder, which
        # CALL_METHOD discards.
        self.LOAD_ATTR(inst)
        self.push(None)

    def CALL_METHOD(self, inst):
        args = self.popn(inst.argval)
        dummy = self.pop()
        assert dummy is None
        fn = self.pop()
        self.call_function(fn, args, {})

    def LOAD_ATTR(self, inst):
        obj = self.pop()
        result = BuiltinVariable(getattr).call_function(
            self, [obj, ConstantVariable(inst.argval)], {}
        )
        self.push(result)

    def STORE_ATTR(self, inst):
        """Trace ``obj.<argval> = val``, breaking the graph if unsupported.

        This handler pushes nothing, so it implements its own graph break
        instead of using ``break_graph_if_unsupported`` (which assumes a
        single pushed value).
        """
        prior = self.copy_graphstate()
        val, obj = self.popn(2)
        try:
            self.output.guards.update(
                BuiltinVariable(setattr)
                .call_function(self, [obj, ConstantVariable(inst.argval), val], {})
                .guards
            )
            return
        except Unsupported as e:
            if not self.should_compile_partial_graph():
                raise
            e.remove_from_stats()
            e.add_to_stats("graph_break")
        self.restore_graphstate(prior)

        # break the graph
        self.output.compile_subgraph(self)
        self.output.add_output_instructions([inst])
        self.popn(2)
        self.output.add_output_instructions(
            self.create_call_resume_at(self.next_instruction)
        )

    def STORE_SUBSCR(self, inst):
        val, obj, key = self.popn(3)
        result = obj.call_method(self, "__setitem__", [key, val], {})
        # no result is pushed, so need to lift the guards to global
        self.output.guards.update(result.guards)

    # ------------------------------------------------------------------
    # Container construction
    # ------------------------------------------------------------------

    def BUILD_TUPLE(self, inst):
        items = self.popn(inst.argval)
        options = VariableTracker.propagate(items)
        self.push(TupleVariable(items, **options))

    def BUILD_SLICE(self, inst):
        items = self.popn(inst.argval)
        options = VariableTracker.propagate(items)
        self.push(SliceVariable(items, **options))

    def BUILD_LIST(self, inst):
        items = self.popn(inst.argval)
        options = VariableTracker.propagate(items)
        self.push(ListVariable(items, mutable_local=MutableLocal(), **options))

    def BUILD_LIST_UNPACK(self, inst, cls=ListVariable):
        """Concatenate ``inst.argval`` unpackable sequences into one ``cls``."""
        seqs = self.popn(inst.argval)
        options = VariableTracker.propagate(seqs)
        items = list()
        for seq in seqs:
            try:
                items.extend(seq.unpack_var_sequence(self))
            except NotImplementedError:
                unimplemented(f"BUILD_LIST_UNPACK {seq}")
        self.push(cls(items, mutable_local=MutableLocal(), **options))

    def BUILD_TUPLE_UNPACK(self, inst):
        self.BUILD_LIST_UNPACK(inst, cls=TupleVariable)

    BUILD_TUPLE_UNPACK_WITH_CALL = BUILD_TUPLE_UNPACK

    def BUILD_MAP(self, inst):
        """Build a dict from ``argval`` key/value pairs; keys must be constants."""
        items = self.popn(inst.argval * 2)
        options = VariableTracker.propagate(items)
        result = collections.OrderedDict()
        for k, v in zip(items[::2], items[1::2]):
            assert isinstance(k, ConstantVariable)
            result[k.value] = v
        assert len(result) == len(items) // 2
        self.push(ConstDictVariable(result, mutable_local=MutableLocal(), **options))

    def BUILD_CONST_KEY_MAP(self, inst):
        keys = self.pop()
        values = self.popn(inst.argval)
        options = VariableTracker.propagate([keys] + values)
        assert isinstance(keys, ConstantVariable)
        keys = keys.value
        assert istype(keys, tuple)
        assert len(keys) == len(values)
        self.push(
            ConstDictVariable(
                collections.OrderedDict(zip(keys, values)),
                mutable_local=MutableLocal(),
                **options,
            )
        )

    def MAP_ADD(self, inst):
        """Dict-comprehension append into the dict at ``stack[-arg]``."""
        # The operand order of MAP_ADD changed in Python 3.8.
        if sys.version_info < (3, 8):
            v, k = self.popn(2)
        else:
            k, v = self.popn(2)

        assert inst.argval > 0
        obj = self.stack[-inst.arg]
        assert isinstance(obj, ConstDictVariable)
        assert obj.mutable_local
        items = collections.OrderedDict(obj.items)
        items[k.as_python_constant()] = v
        self.replace_all(
            obj,
            ConstDictVariable(
                items,
                **VariableTracker.propagate([obj, k, v]),
            ),
        )

    def LIST_APPEND(self, inst):
        """List-comprehension append into the list at ``stack[-arg]``."""
        v = self.pop()
        assert inst.argval > 0
        obj = self.stack[-inst.arg]
        assert isinstance(obj, ListVariable)
        assert obj.mutable_local
        self.replace_all(
            obj,
            ListVariable(
                obj.items + [v],
                **VariableTracker.propagate([obj, v]),
            ),
        )

    def MAKE_FUNCTION(self, inst):
        """Create a nested user function; ``flags`` select optional operands."""
        flags = inst.arg
        old_stack = list(self.stack)
        fn_name = self.pop()
        code = self.pop()
        defaults = None
        closure = None
        annotations = None
        kwdefaults = None

        # Optional operands are popped in reverse of their push order.
        if flags & 0x08:
            closure = self.pop()
        if flags & 0x04:
            annotations = self.pop()
        if flags & 0x02:
            kwdefaults = self.pop()
        if flags & 0x01:
            defaults = self.pop()

        options = VariableTracker.propagate(old_stack[len(self.stack) :])
        self.push(
            NestedUserFunctionVariable(
                fn_name,
                code,
                self.f_globals,
                defaults,
                kwdefaults,
                annotations,
                closure,
                closure_scope=self,
                **options,
            )
        )

    def UNPACK_SEQUENCE(self, inst):
        # TODO(jansel): rewrite this using unpack_var_sequence
        seq = self.pop()
        options = VariableTracker.propagate([seq])
        if isinstance(seq, BaseListVariable):
            assert len(seq.items) == inst.argval
            self.output.guards.update(seq.guards)
            for i in reversed(seq.items):
                self.push(i)
        elif seq.is_python_constant() and isinstance(seq, ConstantVariable):
            val = seq.as_python_constant()
            assert len(val) == inst.argval
            for i in reversed(val):
                self.push(ConstantVariable(i, **options))
        elif isinstance(seq, TensorVariable):
            proxy = seq.as_proxy()
            for i in reversed(range(inst.argval)):
                self.push(TensorVariable.create(self, proxy[i], **options))
        elif isinstance(seq, GetAttrVariable) and isinstance(seq.obj, TensorVariable):
            # x, y = a.shape
            proxy = getattr(seq.obj.as_proxy(), seq.name)
            for i in reversed(range(inst.argval)):
                self.push(TensorVariable.create(self, proxy[i], **options))
        else:
            unimplemented(f"UNPACK_SEQUENCE {seq}")

    def UNPACK_EX(self, inst):
        """Starred unpacking ``a, *rest, b = seq``.

        The starred target is always bound to a list (PEP 3132), so it is
        built the same way BUILD_LIST builds lists.
        """
        assert 0 <= inst.argval <= 0xFFFF
        prefix = inst.argval & 0xFF  # low byte
        suffix = inst.argval >> 8  # high byte
        seq = self.pop()
        options = VariableTracker.propagate(seq)
        if seq.has_unpack_var_sequence(self):
            vals = list(seq.unpack_var_sequence(self))
            assert len(vals) >= prefix + suffix
            vals_prefix = vals[:prefix]
            vals_list = vals[prefix : len(vals) - suffix]
            vals_suffix = vals[len(vals) - suffix :]
            for item in reversed(vals_suffix):
                self.push(item.add_options(options))
            self.push(ListVariable(vals_list, mutable_local=MutableLocal(), **options))
            for item in reversed(vals_prefix):
                self.push(item.add_options(options))
        else:
            unimplemented(f"UNPACK_EX {seq}")

    # ------------------------------------------------------------------
    # Stack shuffling
    # ------------------------------------------------------------------

    def NOP(self, inst):
        pass

    def POP_TOP(self, inst):
        self.pop()

    def ROT_TWO(self, inst):
        a = self.pop()
        b = self.pop()
        self.push(a)
        self.push(b)

    def ROT_THREE(self, inst):
        a = self.pop()
        b = self.pop()
        c = self.pop()
        self.push(a)
        self.push(c)
        self.push(b)

    def ROT_FOUR(self, inst):
        a = self.pop()
        b = self.pop()
        c = self.pop()
        d = self.pop()
        self.push(a)
        self.push(d)
        self.push(c)
        self.push(b)

    def DUP_TOP(self, inst):
        a = self.pop()
        self.push(a)
        self.push(a)

    def DUP_TOP_TWO(self, inst):
        a = self.pop()
        b = self.pop()
        self.push(b)
        self.push(a)
        self.push(b)
        self.push(a)

    # ------------------------------------------------------------------
    # Unary / binary / in-place operators
    # ------------------------------------------------------------------

    UNARY_POSITIVE = stack_op(operator.pos)
    UNARY_NEGATIVE = stack_op(operator.neg)
    UNARY_NOT = stack_op(operator.not_)
    UNARY_INVERT = stack_op(operator.invert)

    BINARY_POWER = stack_op(operator.pow)
    BINARY_MULTIPLY = stack_op(operator.mul)
    BINARY_MATRIX_MULTIPLY = stack_op(operator.matmul)
    BINARY_FLOOR_DIVIDE = stack_op(operator.floordiv)
    BINARY_TRUE_DIVIDE = stack_op(operator.truediv)
    BINARY_MODULO = stack_op(operator.mod)
    BINARY_ADD = stack_op(operator.add)
    BINARY_SUBTRACT = stack_op(operator.sub)
    BINARY_SUBSCR = break_graph_if_unsupported(stack_op(operator.getitem))
    BINARY_LSHIFT = stack_op(operator.lshift)
    BINARY_RSHIFT = stack_op(operator.rshift)
    BINARY_AND = stack_op(operator.and_)
    BINARY_OR = stack_op(operator.or_)
    BINARY_XOR = stack_op(operator.xor)

    INPLACE_POWER = stack_op(operator.ipow)
    INPLACE_MULTIPLY = stack_op(operator.imul)
    INPLACE_MATRIX_MULTIPLY = stack_op(operator.imatmul)
    INPLACE_FLOOR_DIVIDE = stack_op(operator.ifloordiv)
    INPLACE_TRUE_DIVIDE = stack_op(operator.itruediv)
    INPLACE_MODULO = stack_op(operator.imod)
    INPLACE_ADD = stack_op(operator.iadd)
    INPLACE_SUBTRACT = stack_op(operator.isub)
    INPLACE_LSHIFT = stack_op(operator.ilshift)
    INPLACE_RSHIFT = stack_op(operator.irshift)
    INPLACE_AND = stack_op(operator.iand)
    INPLACE_XOR = stack_op(operator.ixor)
    INPLACE_OR = stack_op(operator.ior)

    # ------------------------------------------------------------------
    # Checkpointing
    # ------------------------------------------------------------------

    def copy_graphstate(self):
        """Create a checkpoint of the current state by copying everything"""
        return (
            self.output.copy_graphstate(),
            collections.OrderedDict(self.symbolic_locals),
            list(self.stack),
            list(self.block_stack),
            self.instruction_pointer,
            self.current_instruction,
            self.next_instruction,
            self.lineno,
        )

    def restore_graphstate(self, state):
        """Restore a checkpoint created by self.copy_graphstate()"""
        (
            output_state,
            self.symbolic_locals,
            self.stack,
            self.block_stack,
            self.instruction_pointer,
            self.current_instruction,
            self.next_instruction,
            self.lineno,
        ) = state
        self.output.restore_graphstate(output_state)

    def frame_summary(self):
        """Return a ``traceback.FrameSummary`` for the current source line."""
        return traceback.FrameSummary(
            getattr(self.f_code, "co_filename", "<unknown>"),
            self.lineno,
            getattr(self.f_code, "co_name", "<unknown>"),
            lookup_line=False,
        )

    def __init__(
        self,
        output: OutputGraph,
        instructions: List[Instruction],
        f_globals: Dict[str, Any],
        f_builtins: Dict[str, Any],
        code_options: Dict[str, Any],
        symbolic_locals: Dict[str, VariableTracker],
        f_code: types.CodeType,
    ):
        super(InstructionTranslatorBase, self).__init__()

        # Mutable state checkpointed by copy_graphstate()
        self.output: OutputGraph = output
        self.symbolic_locals: Dict[str, VariableTracker] = symbolic_locals
        self.stack: List[VariableTracker] = []
        self.instruction_pointer: typing.Optional[int] = 0
        self.current_instruction: Instruction = create_instruction("NOP")
        self.next_instruction: typing.Optional[Instruction] = None
        self.block_stack: List[BlockStackEntry] = []
        self.lineno: int = code_options.get("co_firstlineno")

        # Properties of the input/output code
        self.instructions: List[Instruction] = instructions
        # Maps id(instruction) -> index, used to resolve jump targets.
        self.indexof: Dict[int, int] = {id(i): n for n, i in enumerate(instructions)}
        self.f_globals: Dict[str, Any] = f_globals
        self.f_builtins: Dict[str, Any] = f_builtins
        self.code_options: Dict[str, Any] = code_options
        self.f_code: types.CodeType = f_code

        # (instruction, graph state) captured by step() at empty-stack points.
        self.checkpoint = None
class InstructionTranslator(InstructionTranslatorBase):
def __init__(
    self,
    instructions: List[Instruction],
    f_code,
    f_locals,
    f_globals,
    f_builtins,
    code_options,
    compiler_fn,
    one_graph,
):
    """Set up translation of a top-level frame.

    Creates the ``OutputGraph`` and seeds ``symbolic_locals`` from
    ``f_locals``, wrapping each value with ``VariableBuilder`` and a
    ``LocalSource``.  Locals are taken in ``co_varnames`` order, followed by
    any cell/free variables not already listed; names absent from
    ``f_locals`` are skipped.  Also records ``id()`` of each free-variable
    value present in ``f_locals`` in ``self._freevars_ids``.
    """
    super(InstructionTranslator, self).__init__(
        output=OutputGraph(f_globals, code_options, compiler_fn, self),
        instructions=instructions,
        f_globals=f_globals,
        f_builtins=f_builtins,
        code_options=code_options,
        symbolic_locals=collections.OrderedDict(),  # set below
        f_code=f_code,
    )
    self.one_graph: bool = one_graph

    # Renamed from `vars` to avoid shadowing the builtin.
    local_names = list(code_options["co_varnames"])
    local_names.extend(
        x for x in self.cell_and_freevars() if x not in local_names
    )
    self.symbolic_locals = collections.OrderedDict(
        (k, VariableBuilder(self, LocalSource(k))(f_locals[k]))
        for k in local_names
        if k in f_locals
    )

    # TODO(jansel): figure out why the following is needed for detectron2_maskrcnn
    for val in self.symbolic_locals.values():
        if isinstance(val, (ListIteratorVariable, BaseListVariable)):
            self.output.guards.update(val.guards)

    # NOTE(review): presumably consumed by match_nested_cell; confirm.
    self._freevars_ids = {}
    for name in self.code_options["co_freevars"]:
        if name in f_locals:
            self._freevars_ids[name] = id(f_locals[name])
def match_nested_cell(self, name, cell):