This repository has been archived by the owner on Oct 25, 2023. It is now read-only.
/
test_frontend_dynamo.py
369 lines (321 loc) · 13.9 KB
/
test_frontend_dynamo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
pytest.importorskip("torch._dynamo")
import tvm
from tvm import relax, meta_schedule as ms, tir
import tvm.testing
import torch
import torch._dynamo as dynamo
from tvm.relax.frontend.torch import relax_dynamo
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
def test_relax_dynamo():
    """Compile a Linear+ReLU model through ``relax_dynamo`` with a hand-built tuning database.

    A MetaSchedule database is populated with one tuning record for a
    matmul + bias-add + relu TIR workload. The model is then compiled with
    ``torch.compile(..., backend=relax_dynamo())`` while that database is the
    active context, and the compiled output is compared numerically against
    eager PyTorch.
    """

    class Input1(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = torch.nn.Linear(100, 10)

        def forward(self, x):
            return torch.nn.functional.relu(self.lin(x))

    model = Input1()

    ### construct the database
    # TIR workload computing relu(inp_0 @ param_0 + param_1). It is presumably
    # meant to match the fused kernel the Relax pipeline produces for Input1,
    # with param_0 being the transposed Linear weight. TODO confirm.
    @tvm.script.ir_module
    class Input1_ir:
        @T.prim_func
        def main(
            inp_0: T.Buffer((T.int64(10), T.int64(100)), "float32"),
            param_0: T.Buffer((T.int64(100), T.int64(10)), "float32"),
            param_1: T.Buffer(T.int64(10), "float32"),
            compute: T.Buffer((T.int64(10), T.int64(10)), "float32"),
        ):
            # function attr dict
            T.func_attr({"tir.noalias": True, "global_symbol": "main"})
            # body
            # with T.block("root")
            matmul = T.alloc_buffer([T.int64(10), T.int64(10)], dtype="float32")
            T_add = T.alloc_buffer([T.int64(10), T.int64(10)], dtype="float32")
            for i0, i1, k in T.grid(T.int64(10), T.int64(10), T.int64(100)):
                with T.block("matmul"):
                    v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                    T.reads(inp_0[v_i0, v_k], param_0[v_k, v_i1])
                    T.writes(matmul[v_i0, v_i1])
                    with T.init():
                        matmul[v_i0, v_i1] = T.float32(0)
                    matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + inp_0[v_i0, v_k] * param_0[v_k, v_i1]
            for ax0, ax1 in T.grid(T.int64(10), T.int64(10)):
                with T.block("T_add"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(matmul[v_ax0, v_ax1], param_1[v_ax1])
                    T.writes(T_add[v_ax0, v_ax1])
                    T_add[v_ax0, v_ax1] = matmul[v_ax0, v_ax1] + param_1[v_ax1]
            for i0, i1 in T.grid(T.int64(10), T.int64(10)):
                with T.block("compute"):
                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                    T.reads(T_add[v_i0, v_i1])
                    T.writes(compute[v_i0, v_i1])
                    compute[v_i0, v_i1] = T.max(T_add[v_i0, v_i1], T.float32(0))

    db = ms.Database.create("memory")
    workload = db.commit_workload(Input1_ir)

    # Build a deterministic schedule trace:
    # - inline the bias add into its consumer;
    # - tile the matmul loops with pinned decisions (10 = 1*2*5*1, 10 = 1*1*10*1,
    #   100 = 100*1);
    # - move the consumer under the second outer tile loop;
    # - annotate the root block for parallel/vectorize/unroll.
    sch = tir.Schedule(Input1_ir, debug_mask="all")
    b0 = sch.get_block(name="matmul", func_name="main")
    b1 = sch.get_block(name="T_add", func_name="main")
    b2 = sch.get_block(name="root", func_name="main")
    sch.compute_inline(block=b1)
    sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")
    l3, l4, l5 = sch.get_loops(block=b0)
    v6, v7, v8, v9 = sch.sample_perfect_tile(
        loop=l3, n=4, max_innermost_factor=64, decision=[1, 2, 5, 1]
    )
    l10, l11, l12, l13 = sch.split(loop=l3, factors=[v6, v7, v8, v9], preserve_unit_iters=True)
    v14, v15, v16, v17 = sch.sample_perfect_tile(
        loop=l4, n=4, max_innermost_factor=64, decision=[1, 1, 10, 1]
    )
    l18, l19, l20, l21 = sch.split(loop=l4, factors=[v14, v15, v16, v17], preserve_unit_iters=True)
    v22, v23 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64, decision=[100, 1])
    l24, l25 = sch.split(loop=l5, factors=[v22, v23], preserve_unit_iters=True)
    sch.reorder(l10, l18, l11, l19, l24, l12, l20, l25, l13, l21)
    (b26,) = sch.get_consumers(block=b0)
    sch.reverse_compute_at(block=b26, loop=l18, preserve_unit_loops=True, index=-1)
    sch.annotate(block_or_loop=b2, ann_key="meta_schedule.parallel", ann_val=96)
    sch.annotate(block_or_loop=b2, ann_key="meta_schedule.vectorize", ann_val=64)
    v27 = sch.sample_categorical(
        candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=0
    )
    sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v27)
    # run_secs is presumably a dummy latency, since nothing was actually measured.
    tuning_record = ms.database.TuningRecord(sch.trace, workload, run_secs=[0.0])
    db.commit_tuning_record(tuning_record)

    ### Optimize the model with tuned-log
    inp = torch.randn(10, 100)
    with db:
        opt_model = torch.compile(model, backend=relax_dynamo())
        # torch.compile is lazy: the backend runs on the first call, not when
        # torch.compile() returns. Trigger that call while `db` is still the
        # active database so compilation can see the tuning record.
        tvm_out = opt_model(inp).detach().numpy()
    tvm.testing.assert_allclose(tvm_out, model(inp).detach().numpy(), rtol=1e-5, atol=1e-5)
def test_subgraph_capture():
    """Check the Relax IR produced by ``dynamo_capture_subgraphs`` for three inputs.

    Covers a single-graph ``nn.Module``, a plain function whose data-dependent
    ``if`` yields two subgraphs, and a module whose ``forward`` takes a Python
    keyword argument.
    """
    import torch
    from tvm.relax.frontend.torch.dynamo import dynamo_capture_subgraphs

    # Case 1: Linear + ReLU module, expected to be captured as a single subgraph.
    class Input1(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = torch.nn.Linear(100, 10)

        def forward(self, x):
            return torch.nn.functional.relu(self.lin(x))

    # w0 / w1 are placeholders for the Linear weight and bias. They are bound to
    # the model's real parameter values below, before the comparison.
    @tvm.script.ir_module
    class Expected1:
        @R.function
        def subgraph_0(
            inp_0: R.Tensor((10, 100), dtype="float32"),
            w0: R.Tensor((10, 100), dtype="float32"),
            w1: R.Tensor((10,), dtype="float32"),
        ) -> R.Tensor((10, 10), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((100, 10), dtype="float32") = R.permute_dims(w0, axes=None)
                lv1: R.Tensor((10, 10), dtype="float32") = R.matmul(inp_0, lv, out_dtype="float32")
                lv2: R.Tensor((10, 10), dtype="float32") = R.add(lv1, w1)
                lv3: R.Tensor((10, 10), dtype="float32") = R.nn.relu(lv2)
                gv: R.Tensor((10, 10), dtype="float32") = lv3
                R.output(gv)
            return gv

    model = Input1()
    mod = dynamo_capture_subgraphs(model, torch.randn(10, 100))
    # Bind the model's actual weight/bias into Expected1 by parameter name, so the
    # expected module carries the same constant values as the captured one.
    binding = {"w0": model.lin.weight.detach().numpy(), "w1": model.lin.bias.detach().numpy()}
    binding = {k: tvm.nd.array(v) for k, v in binding.items()}
    expected = relax.transform.BindParams("subgraph_0", binding)(Expected1)
    tvm.ir.assert_structural_equal(mod, expected)

    # Case 2: a plain function with a data-dependent branch.
    def Input2(a, b):
        x = a / (torch.sin(a) + 1)
        if torch.sum(b) < 1:
            b = b * -1
        return x * b

    # Expected2 encodes a split into two subgraphs: subgraph_0 returns x together
    # with the branch condition, and subgraph_1 computes the final product.
    # With b = torch.ones(10), sum(b) is 10, so the negation branch is not taken
    # and does not appear in subgraph_1.
    @tvm.script.ir_module
    class Expected2:
        @R.function
        def subgraph_0(
            inp_0: R.Tensor((10,), dtype="float32"), inp_1: R.Tensor((10,), dtype="float32")
        ) -> R.Tuple(R.Tensor((10,), dtype="float32"), R.Tensor((), dtype="bool")):
            # block 0
            with R.dataflow():
                lv: R.Tensor((10,), dtype="float32") = R.sin(inp_0)
                lv1: R.Tensor((10,), dtype="float32") = R.add(lv, R.const(1, "float32"))
                lv2: R.Tensor((10,), dtype="float32") = R.divide(inp_0, lv1)
                lv3: R.Tensor((), dtype="float32") = R.sum(inp_1, axis=None, keepdims=False)
                lv4: R.Tensor((), dtype="bool") = R.less(lv3, R.const(1, "float32"))
                gv: R.Tuple(R.Tensor((10,), dtype="float32"), R.Tensor((), dtype="bool")) = (
                    lv2,
                    lv4,
                )
                R.output(gv)
            return gv

        @R.function
        def subgraph_1(
            inp_01: R.Tensor((10,), dtype="float32"), inp_11: R.Tensor((10,), dtype="float32")
        ) -> R.Tensor((10,), dtype="float32"):
            # block 0
            with R.dataflow():
                lv5: R.Tensor((10,), dtype="float32") = R.multiply(inp_11, inp_01)
                gv1: R.Tensor((10,), dtype="float32") = lv5
                R.output(gv1)
            return gv1

    mod = dynamo_capture_subgraphs(Input2, torch.randn(10), torch.ones(10))
    tvm.ir.assert_structural_equal(mod, Expected2)

    # Case 3: forward() takes a Python keyword argument that changes the graph.
    class Input3(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = torch.nn.Linear(100, 10)

        def forward(self, x, add_one=False):
            if add_one:
                x = x + 1
            return torch.nn.functional.relu(self.lin(x))

    # add_one=True is passed at capture time, so Expected3 contains the "+ 1"
    # (lv0) ahead of the linear layer.
    @tvm.script.ir_module
    class Expected3:
        @R.function
        def subgraph_0(
            inp_0: R.Tensor((10, 100), dtype="float32"),
            w0: R.Tensor((10, 100), dtype="float32"),
            w1: R.Tensor((10,), dtype="float32"),
        ) -> R.Tensor((10, 10), dtype="float32"):
            # block 0
            with R.dataflow():
                lv0 = R.add(inp_0, R.const(1, "float32"))
                lv: R.Tensor((100, 10), dtype="float32") = R.permute_dims(w0, axes=None)
                lv1: R.Tensor((10, 10), dtype="float32") = R.matmul(lv0, lv, out_dtype="float32")
                lv2: R.Tensor((10, 10), dtype="float32") = R.add(lv1, w1)
                lv3: R.Tensor((10, 10), dtype="float32") = R.nn.relu(lv2)
                gv: R.Tensor((10, 10), dtype="float32") = lv3
                R.output(gv)
            return gv

    model = Input3()
    mod = dynamo_capture_subgraphs(model, torch.randn(10, 100), add_one=True)
    # Same weight binding as case 1, applied to Expected3.
    binding = {"w0": model.lin.weight.detach().numpy(), "w1": model.lin.bias.detach().numpy()}
    binding = {k: tvm.nd.array(v) for k, v in binding.items()}
    expected = relax.transform.BindParams("subgraph_0", binding)(Expected3)
    tvm.ir.assert_structural_equal(mod, expected)
def verify_dynamo_model(torch_model, input_info, binding, expected):
    """Assert that a Dynamo-exported model imports to the expected Relax module.

    Parameters
    ----------
    torch_model : torch.nn.Module
        Model to export with ``torch._dynamo.export``.
    input_info : list
        One entry per model input. Element 0 is the shape and element 1 is a
        dtype string accepted by ``_convert_data_type``.
    binding : dict
        Maps parameter name to an array-like value. Each value is wrapped with
        ``tvm.nd.array`` and bound into ``expected``'s ``main`` before comparison.
    expected : tvm.IRModule
        Reference Relax module.
    """
    import torch
    import torch._dynamo as dynamo
    from tvm.relax.frontend.torch import from_fx

    # Zero-filled example inputs built from input_info's shapes and dtypes.
    example_inputs = [
        torch.zeros(*info[0], dtype=_convert_data_type(info[1])) for info in input_info
    ]
    exported = dynamo.export(torch_model, *example_inputs)
    graph_module = exported[0]
    imported = from_fx(graph_module, input_info, unwrap_unit_return_tuple=True)

    bound_params = {name: tvm.nd.array(value) for name, value in binding.items()}
    reference = relax.transform.BindParams("main", bound_params)(expected)
    tvm.ir.assert_structural_equal(imported, reference)
def _convert_data_type(input_type):
"""converts the PyTorch scalar type input_type to a TVM dtype."""
import torch # type: ignore
input_type = input_type.lower() if isinstance(input_type, str) else input_type
if input_type == "float32":
return torch.float32
elif input_type == "float16":
return torch.float16
elif input_type == "int64":
return torch.int64
elif input_type == "int32":
return torch.int32
elif input_type == "bool":
return torch.bool
else:
raise NotImplementedError("input_type {} is not handled yet".format(input_type))
@tvm.testing.requires_gpu
def test_ones():
    """``torch.ones`` with a constant shape imports as ``R.full`` filled with 1."""
    import torch

    input_info = [([256, 256], "float32")]

    class OnesModel(torch.nn.Module):
        def forward(self, x):
            # The input is unused; the output depends only on constants.
            return torch.ones((10, 10), dtype=torch.float32)

    @I.ir_module
    class ExpectedOnes:
        @R.function
        def main(
            inp_0: R.Tensor((256, 256), dtype="float32")
        ) -> R.Tensor((10, 10), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((10, 10), dtype="float32") = R.full(
                    R.shape([10, 10]), R.const(1, "float32"), dtype="float32"
                )
                gv: R.Tensor((10, 10), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_dynamo_model(
        torch_model=OnesModel(),
        input_info=input_info,
        binding={},
        expected=ExpectedOnes,
    )
@tvm.testing.requires_gpu
def test_full():
    """``torch.full`` with a constant shape and fill value imports as ``R.full``."""
    import torch

    input_info = [([256, 256], "float32")]

    class FullModel(torch.nn.Module):
        def forward(self, x):
            # The input is unused; the output depends only on constants.
            return torch.full((10, 10), 1, dtype=torch.float32)

    @I.ir_module
    class ExpectedFull:
        @R.function
        def main(
            inp_0: R.Tensor((256, 256), dtype="float32")
        ) -> R.Tensor((10, 10), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((10, 10), dtype="float32") = R.full(
                    R.shape([10, 10]), R.const(1, "float32"), dtype="float32"
                )
                gv: R.Tensor((10, 10), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_dynamo_model(
        torch_model=FullModel(),
        input_info=input_info,
        binding={},
        expected=ExpectedFull,
    )
@tvm.testing.requires_gpu
def test_masked_fill():
    """Out-of-place and in-place ``masked_fill`` both import to the same Relax module."""
    import torch

    class MaskedFill(torch.nn.Module):
        def forward(self, mask, data):
            return data.masked_fill(mask, 0)

    class InplaceMaskedFill(torch.nn.Module):
        def forward(self, mask, data):
            data.masked_fill_(mask, 0)
            return data

    # Both variants are expected to become full_like(...) followed by where(mask, ...).
    @I.ir_module
    class ExpectedMaskedFill:
        @R.function
        def main(
            inp_0: R.Tensor((256, 256), dtype="bool"), inp_1: R.Tensor((256, 256), dtype="float32")
        ) -> R.Tensor((256, 256), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((256, 256), dtype="float32") = R.full_like(
                    inp_1, R.const(0, "int32"), dtype="void"
                )
                lv1: R.Tensor((256, 256), dtype="float32") = R.where(inp_0, lv, inp_1)
                gv: R.Tensor((256, 256), dtype="float32") = lv1
                R.output(gv)
            return gv

    input_info = [([256, 256], "bool"), ([256, 256], "float32")]
    for model_cls in (MaskedFill, InplaceMaskedFill):
        verify_dynamo_model(model_cls(), input_info, {}, ExpectedMaskedFill)
if __name__ == "__main__":
    # Entry point when this file is executed directly; delegates to tvm.testing.main().
    tvm.testing.main()