-
Notifications
You must be signed in to change notification settings - Fork 11.6k
/
module.py
126 lines (102 loc) · 3.21 KB
/
module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# RUN: %PYTHON %s | FileCheck %s
import gc
from mlir.ir import *
def run(f):
    """Run one test callable and check that no Context outlives it.

    Prints a ``TEST: <name>`` banner (the FileCheck labels in this file match
    on it), calls ``f`` with no arguments, forces a garbage collection, and
    then asserts that ``Context._get_live_count()`` reports zero.
    """
    print(f"\nTEST: {f.__name__}")
    f()
    # Collect before counting so unreachable objects do not pin a context.
    gc.collect()
    live_contexts = Context._get_live_count()
    assert live_contexts == 0
# Verify successful parse.
# CHECK-LABEL: TEST: testParseSuccess
# CHECK: module @successfulParse
def testParseSuccess():
    """Parse a trivial named module and print it after dropping the context."""
    context = Context()
    parsed = Module.parse(r"""module @successfulParse {}""", context)
    assert parsed.context is context
    print("CLEAR CONTEXT")
    # Drop the local reference to ensure that the module captures the context.
    del context
    gc.collect()
    # dump() just outputs to stderr; calling it verifies that it functions.
    parsed.dump()
    print(parsed)
run(testParseSuccess)
# Verify parse error.
# CHECK-LABEL: TEST: testParseError
# CHECK: testParseError: Unable to parse module assembly (see diagnostics)
def testParseError():
    """Parse malformed assembly and print the ValueError it produces.

    Prints "Exception not produced" instead when parsing does not raise.
    """
    context = Context()
    try:
        # The result is never read; the binding is kept only for the
        # no-exception path.
        parsed = Module.parse(r"""}SYNTAX ERROR{""", context)
    except ValueError as err:
        print(f"testParseError: {err}")
    else:
        print("Exception not produced")
run(testParseError)
# Verify creation of an empty module.
# CHECK-LABEL: TEST: testCreateEmpty
# CHECK: module {
def testCreateEmpty():
    """Create an empty module at an unknown location and print it."""
    context = Context()
    unknown_loc = Location.unknown(context)
    empty_module = Module.create(unknown_loc)
    print("CLEAR CONTEXT")
    # Drop the local reference to ensure that the module captures the context.
    del context
    gc.collect()
    print(empty_module)
run(testCreateEmpty)
# Verify round-trip of ASM that contains unicode.
# Note that this does not test that the print path converts unicode properly
# because MLIR asm always normalizes it to the hex encoding.
# CHECK-LABEL: TEST: testRoundtripUnicode
# CHECK: func private @roundtripUnicode()
# CHECK: foo = "\F0\9F\98\8A"
def testRoundtripUnicode():
    """Parse assembly containing a non-ASCII string attribute and print it."""
    context = Context()
    asm = r"""
func private @roundtripUnicode() attributes { foo = "😊" }
"""
    parsed = Module.parse(asm, context)
    print(parsed)
run(testRoundtripUnicode)
# Tests that module.operation works and correctly interns instances.
# CHECK-LABEL: TEST: testModuleOperation
def testModuleOperation():
    """Check that module.operation returns one shared, correctly-lived object.

    Uses the context's live module/operation counters to verify that repeated
    accesses yield the same object, that the operation remains printable after
    the module reference is dropped, and that both counters return to zero
    once every reference is released.
    """
    context = Context()
    parsed = Module.parse(r"""module @successfulParse {}""", context)
    assert context._get_live_module_count() == 1
    first_op = parsed.operation
    assert context._get_live_operation_count() == 1
    # CHECK: module @successfulParse
    print(first_op)

    # A second access must hand back the very same object, not a new one.
    second_op = parsed.operation
    assert context._get_live_operation_count() == 1
    assert first_op is second_op

    # The operation must stay valid once the module is de-referenced.
    del parsed
    gc.collect()
    print(first_op)

    # Release the remaining references, then verify lifetime.
    del first_op, second_op
    gc.collect()
    print("LIVE OPERATIONS:", context._get_live_operation_count())
    assert context._get_live_operation_count() == 0
    assert context._get_live_module_count() == 0
run(testModuleOperation)
# CHECK-LABEL: TEST: testModuleCapsule
def testModuleCapsule():
    """Round-trip a module through its C-API capsule and check its lifetime.

    Reads ``_CAPIPtr`` from a parsed module, prints the capsule, rebuilds a
    module from it with ``Module._CAPICreate``, and asserts that the result is
    the original object bound to the same context. Finally verifies that the
    context's live module count drops to zero once all references are gone.
    """
    context = Context()
    parsed = Module.parse(r"""module @successfulParse {}""", context)
    assert context._get_live_module_count() == 1
    # CHECK: "mlir.ir.Module._CAPIPtr"
    capsule = parsed._CAPIPtr
    print(capsule)
    rewrapped = Module._CAPICreate(capsule)
    assert parsed is rewrapped
    assert rewrapped.context is context
    # Drop every reference, collect, and verify the module was destructed.
    del parsed, capsule, rewrapped
    gc.collect()
    assert context._get_live_module_count() == 0
run(testModuleCapsule)