8
8
represent actual op definitions (i.e. YAML).
9
9
"""
10
10
11
- from typing import Any , Dict , List , Optional , Sequence , Set , Tuple , Union
11
+ from typing import Any , Dict , List , Optional , Sequence , Set , Tuple
12
12
13
13
from mlir import ir as _ir
14
14
@@ -50,7 +50,7 @@ def visit_affine_exprs(expr):
50
50
self .visit_tensor_exprs (visit_affine_exprs )
51
51
return results
52
52
53
- def collect_uses (self , uses : Set ["TensorUse" ]):
53
+ def collect_tensor_uses (self , uses : Set ["TensorUse" ]):
54
54
"""Collects all TensorUses reachable through this expression."""
55
55
56
56
def visit_tensor_use (expr ):
@@ -68,14 +68,14 @@ def visit_index(expr):
68
68
69
69
self .visit_tensor_exprs (visit_index )
70
70
71
- def collect_captures (self , captures : Set ["CaptureDef " ]):
72
- """Collects all CaptureDefs reachable through this expression."""
71
+ def collect_scalar_uses (self , uses : Set ["ScalarDef " ]):
72
+ """Collects all ScalarDefs reachable through this expression."""
73
73
74
- def visit_capture_def (expr ):
75
- if isinstance (expr , CaptureDef ):
76
- captures .add (expr )
74
+ def visit_scalar_def (expr ):
75
+ if isinstance (expr , ScalarDef ):
76
+ uses .add (expr )
77
77
78
- self .visit_tensor_exprs (visit_capture_def )
78
+ self .visit_tensor_exprs (visit_scalar_def )
79
79
80
80
def __add__ (self , rhs : "TensorExpression" ) -> "TensorExpression" :
81
81
return PrimFn .add (self , rhs )
@@ -101,19 +101,19 @@ class TensorUse(TensorExpression):
101
101
TensorDef.__setitem__
102
102
"""
103
103
104
- def __init__ (self , tensor_def : "TensorDef" , indices : Sequence [AffineExprDef ]):
105
- self .tensor_def = tensor_def
104
+ def __init__ (self , operand_def : "OperandDef" ,
105
+ indices : Sequence [AffineExprDef ]):
106
+ self .operand_def = operand_def
106
107
self .indices = tuple (indices )
107
108
108
109
def to_scalar_expression (self ) -> ScalarExpression :
109
- assert self .tensor_def .tensor_name is not None
110
- return ScalarArg (self .tensor_def .tensor_name ).expr ()
110
+ return ScalarArg (self .tensor_name ).expr ()
111
111
112
112
@property
113
113
def tensor_name (self ) -> str :
114
- n = self .tensor_def . tensor_name
115
- assert n is not None , "TensorDef not attached"
116
- return n
114
+ name = self .operand_def . name
115
+ assert name is not None , "TensorDef not attached"
116
+ return name
117
117
118
118
def __iadd__ (self , rhs : TensorExpression ) -> TensorExpression :
119
119
return ReduceFn .add (* self ._compute_reduce_dims (rhs ))(rhs )
@@ -133,40 +133,57 @@ def __repr__(self):
133
133
return f"{ self .tensor_name } [{ ', ' .join ([repr (i ) for i in self .indices ])} ]"
134
134
135
135
136
- class TensorDef :
137
- """Bookkeeping of a single registered tensor, held in dict by name ."""
136
+ class OperandDef :
137
+ """Definition of a Tensor or Scalar operand passed to an operation ."""
138
138
139
- def __init__ (self ,
140
- type_var : TypeVar ,
141
- * shape : AffineExprDef ,
142
- indexing_map : Optional [_ir .AffineMap ] = None ,
143
- output : bool = False ):
139
+ def __init__ (self , type_var : TypeVar , shape : Sequence [AffineExprDef ],
140
+ scalar : bool , output : bool ):
144
141
if not isinstance (type_var , TypeVar ):
145
- raise ValueError (f"TensorDef requires a TypeVar. Got: { repr (type_var )} " )
142
+ raise ValueError (f"OperandDef requires a TypeVar. Got: { repr (type_var )} " )
146
143
self .owner = None # type: Optional["LinalgOpDef"]
147
144
self .type_var = type_var
148
145
self .shape = shape
149
- self .indexing_map = indexing_map
146
+ self .scalar = scalar
150
147
self .output = output
151
- self .tensor_name = None # type: Optional[str]
148
+ self .name = None # type: Optional[str]
152
149
self .registered_index = - 1 # type: int
153
150
154
- @property
155
- def rank (self ) -> int :
156
- """The rank of the tensor."""
157
- return len (self .shape )
158
-
159
- def attach (self , index : int , tensor_name : str , owner : "LinalgOpDef" ):
151
+ def attach (self , index : int , name : str , owner : "LinalgOpDef" ):
160
152
if self .owner :
161
- raise ValueError (f"TensorDef already registered with op: { self } " )
153
+ raise ValueError (f"OperandDef already registered with op: { self } " )
162
154
self .registered_index = index
163
- self .tensor_name = tensor_name
155
+ self .name = name
164
156
self .owner = owner
165
157
158
+ def __hash__ (self ):
159
+ return hash (id (self ))
160
+
161
+ def __repr__ (self ):
162
+ output = "OUTPUT " if self .output else ""
163
+ scalar = "SCALAR " if self .scalar else ""
164
+ return (f"{ self .name } :OperandDef({ output } { scalar } "
165
+ f"{ repr (self .type_var )} , shape={ self .shape } )" )
166
+
167
+
168
+ class TensorDef :
169
+ """Tensor operand definition.
170
+
171
+ Tensor operands are indexed using the associated indexing_map when forwarded
172
+ to the body of the structured op. A unique name identifies the tensor operands
173
+ and an index determines their position in the operation's parameter list.
174
+ """
175
+
176
+ def __init__ (self ,
177
+ type_var : TypeVar ,
178
+ * shape : AffineExprDef ,
179
+ output : bool = False ):
180
+ self .operand_def = OperandDef (type_var , shape , False , output )
181
+
166
182
def __getitem__ (self , dims ) -> TensorUse :
167
- assert self .owner , "TensorDef is not attached to an op"
183
+ assert self .operand_def . owner , "TensorDef is not attached to an op"
168
184
state = AffineBuildState (
169
- global_state = self .owner ._affine_state , allow_new_symbols = False )
185
+ global_state = self .operand_def .owner ._affine_state ,
186
+ allow_new_symbols = False )
170
187
if not isinstance (dims , tuple ):
171
188
dims = (dims ,) # Handle single subscript case.
172
189
# Special case: (None) is a 0d-scalar use.
@@ -179,7 +196,7 @@ def __getitem__(self, dims) -> TensorUse:
179
196
raise KeyError (
180
197
"A TensorDef can only be subscripted by a tuple of affine dims" )
181
198
exprs .append (expr_def )
182
- return TensorUse (self , exprs )
199
+ return TensorUse (self . operand_def , exprs )
183
200
184
201
def __setitem__ (self , dims , value ):
185
202
"""Creates a new 1:1 comprehension by binding this tensor to an expression.
@@ -192,46 +209,28 @@ def __setitem__(self, dims, value):
192
209
f"Got: { repr (value )} " )
193
210
use = self [dims ]
194
211
comp = Comprehension ((use , value ))
195
- self .owner .comprehensions .append (comp )
212
+ self .operand_def . owner .comprehensions .append (comp )
196
213
197
- def __hash__ (self ):
198
- return hash (id (self ))
199
214
200
- def __repr__ (self ):
201
- output = "OUTPUT " if self .output else ""
202
- return (f"{ self .tensor_name } :TensorDef({ output } { repr (self .type_var )} , "
203
- f"shape={ self .shape } )" )
204
-
205
-
206
- class CaptureDef (TensorExpression ):
207
- """Defines an SSA value captured by the operation.
215
+ class ScalarDef (TensorExpression ):
216
+ """Scalar operand definition.
208
217
209
- The captured SSA values are not indexed by the indexing_maps of the
210
- structured op (as opposed to memrefs and tensors). A unique name
211
- identifies the captures and an index determines their position the
212
- operation's parameter list.
218
+ Scalar operands are forwarded to the body of the structured op as they are.
219
+ A unique name identifies the scalars and an index determines their position in
220
+ the operation's parameter list.
213
221
"""
214
222
215
223
def __init__ (self , type_var : TypeVar ):
216
- if not isinstance (type_var , TypeVar ):
217
- raise ValueError (f"CaptureDef requires a TypeVar. Got: { repr (type_var )} " )
218
- self .owner = None # type: Optional["LinalgOpDef"]
219
- self .type_var = type_var
220
- self .capture_name = None # type: Optional[str]
221
- self .registered_index = - 1 # type: int
224
+ self .operand_def = OperandDef (type_var , (), True , False )
222
225
223
- def attach (self , index : int , capture_name : str , owner : "LinalgOpDef" ):
224
- if self .owner :
225
- raise ValueError (f"CaptureDef already registered with op: { self } " )
226
- self .registered_index = index
227
- self .capture_name = capture_name
228
- self .owner = owner
226
+ @property
227
+ def scalar_name (self ) -> str :
228
+ name = self .operand_def .name
229
+ assert name is not None , "ScalarDef not attached"
230
+ return name
229
231
230
232
def to_scalar_expression (self ) -> ScalarExpression :
231
- return ScalarCapture (self .capture_name ).expr ()
232
-
233
- def __repr__ (self ):
234
- return (f"{ self .capture_name } :CaptureDef({ repr (self .type_var )} )" )
233
+ return ScalarArg (self .scalar_name ).expr ()
235
234
236
235
237
236
class Comprehension :
@@ -472,43 +471,34 @@ def __init__(self,
472
471
doc : Optional [str ] = None ):
473
472
self .metadata = OpMetadataDef (
474
473
name = name , cpp_class_name = cpp_class_name , doc = doc )
475
- self .registered_tensors = dict () # type: Dict[str, TensorDef]
476
- self .registered_captures = dict () # type: Dict[str, CaptureDef]
474
+ self .registered_operands = dict () # type: Dict[str, OperandDef]
477
475
self .comprehensions = list () # type: List[Comprehension]
478
476
self ._affine_state = AffineBuildState ()
479
477
480
478
@property
481
- def inputs (self ) -> Sequence [TensorDef ]:
482
- return [t for t in self .registered_tensors .values () if not t .output ]
479
+ def outputs (self ) -> Sequence [OperandDef ]:
480
+ return [
481
+ operand for operand in self .registered_operands .values ()
482
+ if operand .output
483
+ ]
483
484
484
- @property
485
- def outputs (self ) -> Sequence [TensorDef ]:
486
- return [t for t in self .registered_tensors .values () if t .output ]
487
-
488
- def add_tensor (self , tensor_name : str , tensor : TensorDef ):
489
- """Registers a tensor."""
490
- if tensor_name in self .registered_tensors :
491
- raise ValueError (f"Tensor { tensor_name } is already registered "
492
- f"to { self .registered_tensors ['tensor_name' ]} " )
493
- tensor .attach (len (self .registered_tensors ), tensor_name , self )
494
- self .registered_tensors [tensor_name ] = tensor
495
-
496
- def add_capture (self , capture_name : str , capture : CaptureDef ):
497
- """Registers a capture."""
498
- if capture_name in self .registered_captures :
499
- raise ValueError (f"Capture { capture_name } is already registered "
500
- f"to { self .registered_captures ['capture_name' ]} " )
501
- capture .attach (len (self .registered_captures ), capture_name , self )
502
- self .registered_captures [capture_name ] = capture
485
+ def add_operand (self , name : str , operand : OperandDef ):
486
+ """Registers an operand."""
487
+ if name in self .registered_operands :
488
+ raise ValueError (f"The operand { name } is already registered "
489
+ f"to { self .registered_operands ['name' ]} " )
490
+ if not operand .output and self .outputs :
491
+ raise ValueError (f"The operand { name } is an input registered after "
492
+ f"the output { self .outputs [- 1 ]} " )
493
+ operand .attach (len (self .registered_operands ), name , self )
494
+ self .registered_operands [name ] = operand
503
495
504
496
def __repr__ (self ):
505
497
lines = [
506
498
f"LinalgOpDef({ self .metadata .name } -> { self .metadata .cpp_class_name } ,"
507
499
]
508
- for name , tensor in self .registered_tensors .items ():
509
- lines .append (f" { tensor } " )
510
- for name , capture in self .registered_captures .items ():
511
- lines .append (f" { capture } " )
500
+ for name , operand in self .registered_operands .items ():
501
+ lines .append (f" { operand } " )
512
502
if self .comprehensions :
513
503
lines [- 1 ] += " {"
514
504
for comprehension in self .comprehensions :
0 commit comments