Commit 662f9bf
Author: Tobias Gysi

[mlir][linalg][python] Adapt the OpDSL to use scalars.

The patch replaces the existing capture functionality with the scalar operands
introduced by https://reviews.llvm.org/D104109. Scalar operands behave like
tensor operands except that they are not indexed. As a result, a ScalarDef can
be accessed directly, since no indexing expression is needed.

The patch only updates the OpDSL; the C++ side is updated by a follow-up patch.

Differential Revision: https://reviews.llvm.org/D104220
1 parent 389e749 commit 662f9bf
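
To illustrate the change, here is a minimal sketch of an OpDSL definition after
this patch. The scale op is hypothetical (it is not part of this commit) and
assumes the type variable T and the symbols S and D exported by the OpDSL lang
module:

    from mlir.dialects.linalg.opdsl.lang import *

    @linalg_structured_op
    def scale(
        value=ScalarDef(T),                     # scalar operand, never indexed
        I=TensorDef(T, S.M, S.N),               # tensor operand, indexed below
        O=TensorDef(T, S.M, S.N, output=True)):
      # Tensor operands are subscripted with affine dimensions; the ScalarDef
      # is accessed directly, since no indexing expression is needed.
      O[D.m, D.n] = I[D.m, D.n] * value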

File tree: 10 files changed, +262 -361 lines

mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py

Lines changed: 84 additions & 94 deletions
@@ -8,7 +8,7 @@
 represent actual op definitions (i.e. YAML).
 """

-from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
+from typing import Any, Dict, List, Optional, Sequence, Set, Tuple

 from mlir import ir as _ir

@@ -50,7 +50,7 @@ def visit_affine_exprs(expr):
     self.visit_tensor_exprs(visit_affine_exprs)
     return results

-  def collect_uses(self, uses: Set["TensorUse"]):
+  def collect_tensor_uses(self, uses: Set["TensorUse"]):
     """Collects all TensorUses reachable through this expression."""

     def visit_tensor_use(expr):
@@ -68,14 +68,14 @@ def visit_index(expr):

     self.visit_tensor_exprs(visit_index)

-  def collect_captures(self, captures: Set["CaptureDef"]):
-    """Collects all CaptureDefs reachable through this expression."""
+  def collect_scalar_uses(self, uses: Set["ScalarDef"]):
+    """Collects all ScalarDefs reachable through this expression."""

-    def visit_capture_def(expr):
-      if isinstance(expr, CaptureDef):
-        captures.add(expr)
+    def visit_scalar_def(expr):
+      if isinstance(expr, ScalarDef):
+        uses.add(expr)

-    self.visit_tensor_exprs(visit_capture_def)
+    self.visit_tensor_exprs(visit_scalar_def)

   def __add__(self, rhs: "TensorExpression") -> "TensorExpression":
     return PrimFn.add(self, rhs)
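
Illustrative usage of the renamed collectors (not part of the diff; expr stands
for any attached TensorExpression):

    tensor_uses = set()  # type: Set[TensorUse]
    scalar_uses = set()  # type: Set[ScalarDef]
    expr.collect_tensor_uses(tensor_uses)  # gathers indexed tensor accesses
    expr.collect_scalar_uses(scalar_uses)  # gathers scalar operand references
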
@@ -101,19 +101,19 @@ class TensorUse(TensorExpression):
   TensorDef.__setitem__
   """

-  def __init__(self, tensor_def: "TensorDef", indices: Sequence[AffineExprDef]):
-    self.tensor_def = tensor_def
+  def __init__(self, operand_def: "OperandDef",
+               indices: Sequence[AffineExprDef]):
+    self.operand_def = operand_def
     self.indices = tuple(indices)

   def to_scalar_expression(self) -> ScalarExpression:
-    assert self.tensor_def.tensor_name is not None
-    return ScalarArg(self.tensor_def.tensor_name).expr()
+    return ScalarArg(self.tensor_name).expr()

   @property
   def tensor_name(self) -> str:
-    n = self.tensor_def.tensor_name
-    assert n is not None, "TensorDef not attached"
-    return n
+    name = self.operand_def.name
+    assert name is not None, "TensorDef not attached"
+    return name

   def __iadd__(self, rhs: TensorExpression) -> TensorExpression:
     return ReduceFn.add(*self._compute_reduce_dims(rhs))(rhs)
@@ -133,40 +133,57 @@ def __repr__(self):
     return f"{self.tensor_name}[{', '.join([repr(i) for i in self.indices])}]"


-class TensorDef:
-  """Bookkeeping of a single registered tensor, held in dict by name."""
+class OperandDef:
+  """Definition of a Tensor or Scalar operand passed to an operation."""

-  def __init__(self,
-               type_var: TypeVar,
-               *shape: AffineExprDef,
-               indexing_map: Optional[_ir.AffineMap] = None,
-               output: bool = False):
+  def __init__(self, type_var: TypeVar, shape: Sequence[AffineExprDef],
+               scalar: bool, output: bool):
     if not isinstance(type_var, TypeVar):
-      raise ValueError(f"TensorDef requires a TypeVar. Got: {repr(type_var)}")
+      raise ValueError(f"OperandDef requires a TypeVar. Got: {repr(type_var)}")
     self.owner = None  # type: Optional["LinalgOpDef"]
     self.type_var = type_var
     self.shape = shape
-    self.indexing_map = indexing_map
+    self.scalar = scalar
     self.output = output
-    self.tensor_name = None  # type: Optional[str]
+    self.name = None  # type: Optional[str]
     self.registered_index = -1  # type: int

-  @property
-  def rank(self) -> int:
-    """The rank of the tensor."""
-    return len(self.shape)
-
-  def attach(self, index: int, tensor_name: str, owner: "LinalgOpDef"):
+  def attach(self, index: int, name: str, owner: "LinalgOpDef"):
     if self.owner:
-      raise ValueError(f"TensorDef already registered with op: {self}")
+      raise ValueError(f"OperandDef already registered with op: {self}")
     self.registered_index = index
-    self.tensor_name = tensor_name
+    self.name = name
     self.owner = owner

+  def __hash__(self):
+    return hash(id(self))
+
+  def __repr__(self):
+    output = "OUTPUT " if self.output else ""
+    scalar = "SCALAR " if self.scalar else ""
+    return (f"{self.name}:OperandDef({output}{scalar}"
+            f"{repr(self.type_var)}, shape={self.shape})")
+
+
+class TensorDef:
+  """Tensor operand definition.
+
+  Tensor operands are indexed using the associated indexing_map when forwarded
+  to the body of the structured op. A unique name identifies the tensor operands
+  and an index determines their position in the operation's parameter list.
+  """
+
+  def __init__(self,
+               type_var: TypeVar,
+               *shape: AffineExprDef,
+               output: bool = False):
+    self.operand_def = OperandDef(type_var, shape, False, output)
+
   def __getitem__(self, dims) -> TensorUse:
-    assert self.owner, "TensorDef is not attached to an op"
+    assert self.operand_def.owner, "TensorDef is not attached to an op"
     state = AffineBuildState(
-        global_state=self.owner._affine_state, allow_new_symbols=False)
+        global_state=self.operand_def.owner._affine_state,
+        allow_new_symbols=False)
     if not isinstance(dims, tuple):
       dims = (dims,)  # Handle single subscript case.
     # Special case: (None) is a 0d-scalar use.
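
After this hunk, TensorDef is a thin wrapper: all bookkeeping (name, index,
owner) lives on the OperandDef it constructs with scalar=False. A sketch,
assuming the type variable T and shape symbols S from the lang module:

    A = TensorDef(T, S.M, S.K)               # OperandDef(T, (S.M, S.K), False, False)
    C = TensorDef(T, S.M, S.N, output=True)  # OperandDef(T, (S.M, S.N), False, True)
    assert C.operand_def.output and not A.operand_def.output
    assert A.operand_def.name is None        # assigned later by OperandDef.attach
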
@@ -179,7 +196,7 @@ def __getitem__(self, dims) -> TensorUse:
         raise KeyError(
             "A TensorDef can only be subscripted by a tuple of affine dims")
       exprs.append(expr_def)
-    return TensorUse(self, exprs)
+    return TensorUse(self.operand_def, exprs)

   def __setitem__(self, dims, value):
     """Creates a new 1:1 comprehension by binding this tensor to an expression.
@@ -192,46 +209,28 @@ def __setitem__(self, dims, value):
                        f"Got: {repr(value)}")
     use = self[dims]
     comp = Comprehension((use, value))
-    self.owner.comprehensions.append(comp)
+    self.operand_def.owner.comprehensions.append(comp)

-  def __hash__(self):
-    return hash(id(self))

-  def __repr__(self):
-    output = "OUTPUT " if self.output else ""
-    return (f"{self.tensor_name}:TensorDef({output}{repr(self.type_var)}, "
-            f"shape={self.shape})")
-
-
-class CaptureDef(TensorExpression):
-  """Defines an SSA value captured by the operation.
+class ScalarDef(TensorExpression):
+  """Scalar operand definition.

-  The captured SSA values are not indexed by the indexing_maps of the
-  structured op (as opposed to memrefs and tensors). A unique name
-  identifies the captures and an index determines their position the
-  operation's parameter list.
+  Scalar operands are forwarded to the body of the structured op as they are.
+  A unique name identifies the scalars and an index determines their position in
+  the operation's parameter list.
   """

   def __init__(self, type_var: TypeVar):
-    if not isinstance(type_var, TypeVar):
-      raise ValueError(f"CaptureDef requires a TypeVar. Got: {repr(type_var)}")
-    self.owner = None  # type: Optional["LinalgOpDef"]
-    self.type_var = type_var
-    self.capture_name = None  # type: Optional[str]
-    self.registered_index = -1  # type: int
+    self.operand_def = OperandDef(type_var, (), True, False)

-  def attach(self, index: int, capture_name: str, owner: "LinalgOpDef"):
-    if self.owner:
-      raise ValueError(f"CaptureDef already registered with op: {self}")
-    self.registered_index = index
-    self.capture_name = capture_name
-    self.owner = owner
+  @property
+  def scalar_name(self) -> str:
+    name = self.operand_def.name
+    assert name is not None, "ScalarDef not attached"
+    return name

   def to_scalar_expression(self) -> ScalarExpression:
-    return ScalarCapture(self.capture_name).expr()
-
-  def __repr__(self):
-    return (f"{self.capture_name}:CaptureDef({repr(self.type_var)})")
+    return ScalarArg(self.scalar_name).expr()


 class Comprehension:
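
A sketch of the new ScalarDef in isolation (hypothetical wiring; the
linalg_structured_op decorator normally performs the registration). The scalar
wraps a shapeless OperandDef with scalar=True and lowers to a plain ScalarArg
once attached:

    op = LinalgOpDef(name="scale", cpp_class_name="ScaleOp")
    value = ScalarDef(T)
    op.add_operand("value", value.operand_def)
    print(value.scalar_name)             # -> value
    expr = value.to_scalar_expression()  # ScalarArg("value").expr(), no indexing
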
@@ -472,43 +471,34 @@ def __init__(self,
                doc: Optional[str] = None):
     self.metadata = OpMetadataDef(
         name=name, cpp_class_name=cpp_class_name, doc=doc)
-    self.registered_tensors = dict()  # type: Dict[str, TensorDef]
-    self.registered_captures = dict()  # type: Dict[str, CaptureDef]
+    self.registered_operands = dict()  # type: Dict[str, OperandDef]
     self.comprehensions = list()  # type: List[Comprehension]
     self._affine_state = AffineBuildState()

   @property
-  def inputs(self) -> Sequence[TensorDef]:
-    return [t for t in self.registered_tensors.values() if not t.output]
+  def outputs(self) -> Sequence[OperandDef]:
+    return [
+        operand for operand in self.registered_operands.values()
+        if operand.output
+    ]

-  @property
-  def outputs(self) -> Sequence[TensorDef]:
-    return [t for t in self.registered_tensors.values() if t.output]
-
-  def add_tensor(self, tensor_name: str, tensor: TensorDef):
-    """Registers a tensor."""
-    if tensor_name in self.registered_tensors:
-      raise ValueError(f"Tensor {tensor_name} is already registered "
-                       f"to {self.registered_tensors['tensor_name']}")
-    tensor.attach(len(self.registered_tensors), tensor_name, self)
-    self.registered_tensors[tensor_name] = tensor
-
-  def add_capture(self, capture_name: str, capture: CaptureDef):
-    """Registers a capture."""
-    if capture_name in self.registered_captures:
-      raise ValueError(f"Capture {capture_name} is already registered "
-                       f"to {self.registered_captures['capture_name']}")
-    capture.attach(len(self.registered_captures), capture_name, self)
-    self.registered_captures[capture_name] = capture
+  def add_operand(self, name: str, operand: OperandDef):
+    """Registers an operand."""
+    if name in self.registered_operands:
+      raise ValueError(f"The operand {name} is already registered "
+                       f"to {self.registered_operands['name']}")
+    if not operand.output and self.outputs:
+      raise ValueError(f"The operand {name} is an input registered after "
+                       f"the output {self.outputs[-1]}")
+    operand.attach(len(self.registered_operands), name, self)
+    self.registered_operands[name] = operand

   def __repr__(self):
     lines = [
         f"LinalgOpDef({self.metadata.name} -> {self.metadata.cpp_class_name},"
     ]
-    for name, tensor in self.registered_tensors.items():
-      lines.append(f"  {tensor}")
-    for name, capture in self.registered_captures.items():
-      lines.append(f"  {capture}")
+    for name, operand in self.registered_operands.items():
+      lines.append(f"  {operand}")
     if self.comprehensions:
       lines[-1] += " {"
     for comprehension in self.comprehensions:
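
Note that add_operand also enforces an ordering invariant that the separate
tensor and capture registries did not: every input, tensor or scalar, must be
registered before the first output. A hedged sketch of the failure mode:

    op = LinalgOpDef(name="fill", cpp_class_name="FillOp")
    op.add_operand("O", OperandDef(T, (S.M, S.N), False, True))  # output first
    op.add_operand("min", OperandDef(T, (), True, False))
    # ValueError: The operand min is an input registered after the output ...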
