-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
infer_memref_layout.py
231 lines (209 loc) · 8.03 KB
/
infer_memref_layout.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference for memref layout and memory space."""
# mypy: ignore-errors
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import func
from jaxlib.mlir.dialects import memref
import numpy as np
from . import tpu
def _round_up(x: int, to: int):
return ((x + to - 1) // to) * to
def _tiling_factor(
num_128s: int, hardware_generation: int, bitwidth: int
) -> int:
"""Returns the number of 128-element groups in a tile.
Arguments:
num_128s: A number of 128-element groups in the full operand.
hardware_generation: An integer indicating the target TPU generation.
bitwidth: The bitwidth of the element type of the operand.
"""
assert bitwidth.bit_count() == 1 and (4 <= bitwidth <= 32)
packing = 32 // bitwidth
min_tiling = (1 + (hardware_generation < 4)) * packing
max_tiling = 8
tiling = min_tiling
while tiling < min(num_128s, max_tiling):
tiling *= 2
return tiling
def _element_bitwidth(element_type: ir.Type) -> int:
  """Returns the bitwidth of a memref element type supported by inference.

  Raises:
    NotImplementedError: If the type is not bf16, f32 or an integer type, or if
      its bitwidth is not one of 4, 8, 16 or 32.
  """
  if ir.BF16Type.isinstance(element_type):
    bitwidth = 16
  elif ir.F32Type.isinstance(element_type):
    bitwidth = 32
  elif ir.IntegerType.isinstance(element_type):
    bitwidth = ir.IntegerType(element_type).width
  else:
    raise NotImplementedError(f"Unrecognized element type: {element_type}")
  # Reject unsupported widths before any tiling arithmetic: `_tiling_factor`
  # only accepts these widths, and `32 // bitwidth` below relies on them.
  if bitwidth not in (4, 8, 16, 32):
    raise NotImplementedError(f"Unsupported bitwidth: {bitwidth}")
  return bitwidth


def _infer_tiled_layout(
    memref: ir.MemRefType, hardware_generation: int
) -> ir.Attribute:
  """Builds a `#tpu.tiled` layout attribute for an identity-layout memref."""
  bitwidth = _element_bitwidth(memref.element_type)
  rank = memref.rank
  if rank == 0:
    # Without this check the code below fails with an IndexError on shape[-2].
    raise NotImplementedError("0-D memrefs are not supported")
  shape = memref.shape
  # Types narrower than 32 bits get an extra (packing, 1) tile.
  packing_tile = "" if bitwidth == 32 else f"({32 // bitwidth},1)"
  if rank == 1:
    num_128s = (shape[-1] + 127) // 128
    tile = _tiling_factor(num_128s, hardware_generation, bitwidth) * 128
    trailing_tiles = f"(128){packing_tile}" if packing_tile else ""
    return ir.Attribute.parse(f"#tpu.tiled<({tile}){trailing_tiles},[1]>")
  leading_tile = _tiling_factor(shape[-2], hardware_generation, bitwidth)
  # Tile strides are accumulated from the minormost dimension outwards; the
  # two minor dimensions contribute their tile counts, the rest their sizes.
  tile_strides = [None] * rank
  stride = 1
  for i in range(rank - 1, -1, -1):
    tile_strides[i] = stride
    if i == rank - 1:
      stride *= (shape[i] + 127) // 128
    elif i == rank - 2:
      stride *= (shape[i] + leading_tile - 1) // leading_tile
    else:
      stride *= shape[i]
  return ir.Attribute.parse(
      f"#tpu.tiled<({leading_tile},128){packing_tile},{tile_strides}>"
  )


def infer_memref(
    memref: ir.MemRefType, hardware_generation: int
) -> ir.MemRefType:
  """Infers the layout and memory space attributes of the given memref.

  A memref without a memory space is placed in `#tpu.memory_space<vmem>`. A
  memref with an identity affine layout gets a tiled layout inferred from its
  shape and element type; an existing tiled layout is kept. The dimensions
  covered by the first tile are rounded up to a multiple of the tile size in
  the returned type.

  Arguments:
    memref: The memref type with potentially missing layout and memory space.
    hardware_generation: The TPU hardware generation to target.

  Returns:
    A memref type with layout and memory space filled in.

  Raises:
    NotImplementedError: If the layout is a non-identity affine map or an
      unrecognized attribute, the element type or its bitwidth is unsupported,
      the memref has rank 0, or the tiling is too complicated.
    ValueError: If the first tile has more dimensions than the memref.
  """
  if tpu.private_has_no_memory_space(memref):
    memory_space = ir.Attribute.parse("#tpu.memory_space<vmem>")
  else:
    memory_space = memref.memory_space
  if ir.AffineMapAttr.isinstance(memref.layout):
    if not tpu.private_is_identity(memref.layout):
      raise NotImplementedError("Non-identity affine layout")
    layout = _infer_tiled_layout(memref, hardware_generation)
  elif tpu.private_is_tiled_layout(memref.layout):
    layout = memref.layout
  else:
    raise NotImplementedError("Unrecognized layout annotation")
  new_shape = list(memref.shape)
  # Make sure only the first tile might introduce padding.
  first_tile, *_ = tiles = tpu.private_get_tiles(layout)
  tiled_dims = np.asarray(first_tile, dtype=np.int32)
  for t in tiles:
    t = np.asarray(t, dtype=np.int32)
    if len(t) > len(tiled_dims):
      raise NotImplementedError("Layout too complicated")
    untiled_prefix, tiled_suffix = tiled_dims[:-len(t)], tiled_dims[-len(t):]
    if np.any(tiled_suffix % t != 0):
      raise NotImplementedError("Layout too complicated")
    tiled_dims = np.concatenate([untiled_prefix, tiled_suffix // t, t])
  untiled_dims = len(new_shape) - len(first_tile)
  if untiled_dims < 0:
    raise ValueError("Invalid tiling")
  # Pad the tiled (trailing) dimensions up to a whole number of first tiles.
  for i, size in enumerate(first_tile):
    new_shape[untiled_dims + i] = _round_up(
        new_shape[untiled_dims + i], to=size
    )
  return ir.MemRefType.get(
      new_shape, memref.element_type, layout, memory_space
  )
def infer_block(block: ir.Block, hardware_generation: int):
  """Infers memref layouts for every `memref.AllocaOp` in the given block.

  The result type of each alloca is replaced with the type computed by
  `infer_memref`. When that changes the type, a `tpu.EraseLayoutOp` is inserted
  directly after the alloca and handed to
  `tpu.private_replace_all_uses_except` together with the alloca result. All
  other operations are processed through `infer_op`.

  Arguments:
    block: The block whose operations should be processed.
    hardware_generation: The TPU hardware generation to target.
  """
  # Snapshot the operations up front: new ops are inserted into the block
  # while we walk it.
  operations = list(block.operations)
  if not operations:
    return
  # Pair each op with its successor (None for the last op in the block).
  successors = [*operations[1:], None]
  for current, following in zip(operations, successors):
    if not isinstance(current, memref.AllocaOp):
      infer_op(current, hardware_generation)
      continue
    allocation = current.result
    old_ty = allocation.type
    new_ty = infer_memref(old_ty, hardware_generation)
    allocation.set_type(new_ty)
    if old_ty != new_ty:
      # Insert right after the alloca: before the next op, or at the end of
      # the block when the alloca is the last op.
      insertion_point = (
          ir.InsertionPoint.at_block_end(block)
          if following is None
          else ir.InsertionPoint(following)
      )
      with insertion_point:
        erased_ty = ir.MemRefType.get(
            new_ty.shape,
            old_ty.element_type,
            None,
            new_ty.memory_space,
        )
        erase_op = tpu.EraseLayoutOp(erased_ty, allocation)
      tpu.private_replace_all_uses_except(
          allocation, erase_op.result, erase_op
      )
def infer_op(op: ir.Operation, hardware_generation: int):
  """Runs `infer_block` on every block of every region of the operation.

  Arguments:
    op: The operation whose regions should be processed.
    hardware_generation: The TPU hardware generation to target.
  """
  nested_blocks = (
      nested for region in op.regions for nested in region.blocks
  )
  for nested in nested_blocks:
    infer_block(nested, hardware_generation)
def infer_func(f: func.FuncOp, hardware_generation: int):
  """Infers memref layouts for a function's arguments and nested blocks.

  Every memref-typed entry block argument gets the type computed by
  `infer_memref`; non-memref arguments are left as they are. The function's
  `function_type` attribute is then rebuilt from the new argument types, and
  each operation of the entry block is processed through `infer_op`.

  Arguments:
    f: The function to process. Its body must consist of a single block.
    hardware_generation: The TPU hardware generation to target.

  Raises:
    ValueError: If the function body does not have exactly one block.
  """
  body_blocks = f.body.blocks
  if len(body_blocks) != 1:
    raise ValueError("Functions should only have a single block")
  (entry,) = body_blocks
  inferred_arg_types = []
  with ir.InsertionPoint.at_block_begin(entry):
    for arg in entry.arguments:
      try:
        old_ty = ir.MemRefType(arg.type)
      except ValueError:
        # Not a memref: keep the argument type unchanged.
        inferred_arg_types.append(arg.type)
        continue
      new_ty = infer_memref(old_ty, hardware_generation)
      arg.set_type(new_ty)
      inferred_arg_types.append(new_ty)
      if old_ty != new_ty:
        # Some standard MLIR ops have static checks that seem unreasonable,
        # and we know they hold in the way they are used in Mosaic. Still,
        # verification with layouts likes to fail, because it can't statically
        # prove the properties.
        erased_ty = ir.MemRefType.get(
            new_ty.shape,
            old_ty.element_type,
            None,
            new_ty.memory_space,
        )
        erase_op = tpu.EraseLayoutOp(erased_ty, arg)
        tpu.private_replace_all_uses_except(arg, erase_op.result, erase_op)
    new_function_type = ir.FunctionType.get(
        inferred_arg_types, f.type.results
    )
    f.attributes["function_type"] = ir.TypeAttr.get(new_function_type)
    # NOTE(review): entry-block ops only go through `infer_op`, which visits
    # their nested regions; an alloca placed directly in the entry block does
    # not appear to be rewritten here -- confirm this is intended.
    for op in entry.operations:
      infer_op(op, hardware_generation)
def infer_module(module: ir.Module, hardware_generation: int):
  """Infers the layout and memory space attributes of function memref arguments.

  In the future we should require those annotations from Mosaic users, but it's
  best to keep them internal for as long as they are under development.

  Every operation in the module body is expected to be a `func.FuncOp` and is
  processed with `infer_func`.

  Arguments:
    module: The MLIR module on which to perform the inference.
    hardware_generation: The TPU hardware generation to target.
  """
  # TODO(apaszke): Do layout assignment for scoped allocations too.
  for top_level_op in module.body:
    assert isinstance(top_level_op, func.FuncOp)
    infer_func(top_level_op, hardware_generation)