-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
regionpasses.py
122 lines (101 loc) · 3.65 KB
/
regionpasses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from collections.abc import Mapping
from numba_rvsdg.core.datastructures.scfg import (
SCFG,
)
from numba_rvsdg.core.datastructures.basic_block import (
BasicBlock,
RegionBlock,
)
def _compute_incoming_labels(graph: Mapping[str, BasicBlock]) -> dict[str, set[str]]:
jump_table: dict[str, set[str]] = {}
blk: BasicBlock
for k in graph:
jump_table[k] = set()
for blk in graph.values():
for dst in blk.jump_targets:
if dst in jump_table:
jump_table[dst].add(blk.name)
return jump_table
def toposort_graph(graph: Mapping[str, BasicBlock]) -> list[list[str]]:
    """Group the labels of ``graph`` into topologically ordered levels.

    The result is a list of levels. A label is placed in the first level at
    which every one of its incoming labels (as computed by
    ``_compute_incoming_labels``) already belongs to an earlier level, so
    the first level holds the labels without incoming edges from inside
    ``graph``. An empty graph yields an empty list.

    Raises:
        ValueError: if a pass over the remaining labels cannot place any of
            them, e.g. because the jump targets form a cycle (including a
            block that targets itself). Without this check the loop below
            would never terminate.
    """
    incoming_labels = _compute_incoming_labels(graph)
    visited: set[str] = set()
    toposorted: list[list[str]] = []
    while incoming_labels:
        # A label is ready once all of its predecessors have been emitted.
        level = [k for k, vs in incoming_labels.items() if vs <= visited]
        if not level:
            # No progress is possible: every remaining label still waits on
            # a predecessor that can never be emitted.
            raise ValueError(
                "cannot topologically sort graph; labels with unresolved "
                f"incoming edges: {list(incoming_labels)!r}"
            )
        for k in level:
            del incoming_labels[k]
        visited.update(level)
        toposorted.append(level)
    return toposorted
class RegionVisitor:
    """Walk the blocks of an SCFG level by level in topological order.

    ``visit_graph`` threads a ``data`` value through the traversal: the value
    returned by each ``visit`` call is passed to the next one. Subclasses
    override the ``visit_block`` / ``visit_loop`` / ``visit_switch`` hooks;
    the base implementations do nothing and return ``None``.
    """

    # Traversal order of the toposorted levels: "forward" or "backward".
    direction = "forward"

    def visit_block(self, block: BasicBlock, data):
        """Hook for a block that is not a ``RegionBlock``. Returns ``None``."""
        pass

    def visit_loop(self, region: RegionBlock, data):
        """Hook for a region whose ``kind`` is ``"loop"``. Returns ``None``."""
        pass

    def visit_switch(self, region: RegionBlock, data):
        """Hook for a region whose ``kind`` is ``"switch"``. Returns ``None``."""
        pass

    def visit_linear(self, region: RegionBlock, data):
        """Visit the subregion of ``region`` as a graph and return the result.

        ``visit`` never dispatches here on its own; it is available for
        subclasses to call explicitly.
        """
        return self.visit_graph(region.subregion, data)

    def visit_graph(self, scfg: SCFG, data):
        """Visit every block of ``scfg`` and return the final ``data``.

        Blocks are visited level by level in the order given by
        ``_toposort_graph``; each visit receives the previous visit's result.
        """
        toposorted = self._toposort_graph(scfg)
        label: str
        for lvl in toposorted:
            for label in lvl:
                data = self.visit(scfg[label], data)
        return data

    def _toposort_graph(self, scfg: SCFG):
        """Return the toposorted levels of ``scfg`` honouring ``direction``.

        For ``"backward"`` the order of the levels is reversed (the labels
        inside each level keep their order) and an iterator is returned.

        Raises:
            AssertionError: if ``direction`` is neither ``"forward"`` nor
                ``"backward"``.
        """
        # Raise explicitly instead of ``assert False``: an assert statement
        # is removed under ``python -O``, which would make this method
        # silently return None for an invalid direction.
        if self.direction not in ("forward", "backward"):
            raise AssertionError(f"invalid direction {self.direction!r}")
        toposorted = toposort_graph(scfg.graph)
        if self.direction == "forward":
            return toposorted
        return reversed(toposorted)

    def visit(self, block: BasicBlock, data):
        """Dispatch ``block`` to the matching hook and return its result.

        Raises:
            NotImplementedError: if ``block`` is a ``RegionBlock`` whose
                ``kind`` is neither ``"loop"`` nor ``"switch"``.
        """
        if isinstance(block, RegionBlock):
            if block.kind == "loop":
                fn = self.visit_loop
            elif block.kind == "switch":
                fn = self.visit_switch
            else:
                # Include name and kind so the offending block can be
                # identified from the exception.
                raise NotImplementedError(
                    'unreachable', block.name, block.kind
                )
            data = fn(block, data)
        else:
            data = self.visit_block(block, data)
        return data
class RegionTransformer(RegionVisitor):
    """Region visitor whose hooks also receive the enclosing SCFG.

    Every hook takes the ``parent`` graph as an extra first argument. The
    base hooks do nothing and return ``None``.

    NOTE(review): ``visit_graph`` here always walks the levels in the order
    returned by ``toposort_graph``; the inherited ``direction`` attribute is
    not consulted. Confirm that this is intended.
    """

    def visit_block(self, parent: SCFG, block: BasicBlock, data):
        """Hook for a block that is not a ``RegionBlock``. Returns ``None``."""
        pass

    def visit_loop(self, parent: SCFG, region: RegionBlock, data):
        """Hook for a region whose ``kind`` is ``"loop"``. Returns ``None``."""
        pass

    def visit_switch(self, parent: SCFG, region: RegionBlock, data):
        """Hook for a region whose ``kind`` is ``"switch"``. Returns ``None``."""
        pass

    def visit_linear(self, parent: SCFG, region: RegionBlock, data):
        """Visit the subregion of ``region``; ``parent`` is not used."""
        return self.visit_graph(region.subregion, data)

    def visit_graph(self, scfg: SCFG, data):
        """Visit each block of ``scfg`` with ``scfg`` as its parent.

        The levels are computed once up front; the result of each visit is
        passed as ``data`` to the next, and the last result is returned.
        """
        for level in toposort_graph(scfg.graph):
            for name in level:
                data = self.visit(scfg, scfg[name], data)
        return data

    def visit(self, parent: SCFG, block: BasicBlock, data):
        """Dispatch ``block`` to the matching hook and return its result.

        Raises:
            NotImplementedError: if ``block`` is a ``RegionBlock`` whose
                ``kind`` is not one of ``"loop"``, ``"switch"``, ``"head"``,
                ``"tail"`` or ``"branch"``.
        """
        # Plain blocks take the short path.
        if not isinstance(block, RegionBlock):
            return self.visit_block(parent, block, data)

        if block.kind == "loop":
            handler = self.visit_loop
        elif block.kind == "switch":
            handler = self.visit_switch
        elif block.kind in {"head", "tail", "branch"}:
            handler = self.visit_linear
        else:
            raise NotImplementedError('unreachable', block.name, block.kind)
        return handler(parent, block, data)