Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions onnxscript/_thirdparty/asciichartpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,8 @@ def plot(series, *, bin_edges=None, cfg=None):
height = cfg.get("height", interval)
ratio = height / interval if interval > 0 else 1

min2 = int(floor(minimum * ratio))
max2 = int(ceil(maximum * ratio))
min2 = floor(minimum * ratio)
max2 = ceil(maximum * ratio)

def clamp(n):
return min(max(n, minimum), maximum)
Expand Down
16 changes: 8 additions & 8 deletions onnxscript/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1239,14 +1239,14 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
if i != len(loop_stmt.body) - 1:
self.fail(s, "Instruction break must be the last one of the loop.")

_current_scope = self._current_scope()
if s.test.id not in _current_scope:
current_scope = self._current_scope()
if s.test.id not in current_scope:
self.fail(
loop_stmt,
f"Unable to find condition variable {s.test.id!r} in known "
f"variables {list(_current_scope)!r}.",
f"variables {list(current_scope)!r}.",
)
condition_name = _current_scope[s.test.id].value
condition_name = current_scope[s.test.id].value
operator_name = "Not"
continue
self._translate_stmt(s)
Expand All @@ -1255,14 +1255,14 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:

if cond_while is not None:
# Loop while
_current_scope = self._current_scope()
if cond_while not in _current_scope:
current_scope = self._current_scope()
if cond_while not in current_scope:
self.fail(
loop_stmt,
f"Unable to find condition variable {cond_while!r} in known "
f"variables {list(_current_scope)!r}.",
f"variables {list(current_scope)!r}.",
)
o_cond_var = _current_scope[cond_while].value
o_cond_var = current_scope[cond_while].value

self.emit(
[o_cond_out],
Expand Down
8 changes: 4 additions & 4 deletions onnxscript/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,16 +290,16 @@ def eval_function(
has_array = False
for arg, param_schema in tagged_args:
if param_schema.is_input:
adapted_arg, _has_array = _adapt_to_eager_mode(arg)
has_array = has_array or _has_array
adapted_arg, has_array_ = _adapt_to_eager_mode(arg)
has_array = has_array or has_array_
adapted_args.append(adapted_arg)
else:
adapted_args.append(arg)

for key, (arg, param_schema) in tagged_kwargs.items():
if param_schema.is_input:
adapted_arg, _has_array = _adapt_to_eager_mode(arg)
has_array = has_array or _has_array
adapted_arg, has_array_ = _adapt_to_eager_mode(arg)
has_array = has_array or has_array_
adapted_kwargs[key] = adapted_arg
else:
adapted_kwargs[key] = arg
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,15 +210,15 @@ def type_constraints(self, signature_only: bool = True) -> OnnxFunctionTypeConst
)

# Rename type constraints to T0, T1, T2, ...
_seen_type_constraints: Set[TypeConstraint] = set()
seen_type_constraints: Set[TypeConstraint] = set()
for type_constraint in (
*input_type_constraints.values(),
*output_type_constraints.values(),
*intermediate_type_constraints.values(),
):
if type_constraint is not None and type_constraint not in _seen_type_constraints:
type_constraint.name = f"T{len(_seen_type_constraints)}"
_seen_type_constraints.add(type_constraint)
if type_constraint is not None and type_constraint not in seen_type_constraints:
type_constraint.name = f"T{len(seen_type_constraints)}"
seen_type_constraints.add(type_constraint)

return OnnxFunctionTypeConstraints(
input_type_constraints, output_type_constraints, intermediate_type_constraints
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def main(args: argparse.Namespace) -> None:
functions[module_name] = {}
op_name = get_op_name(func)
if op_name in functions[module_name]:
logging.warning(
logging.warning( # noqa: LOG015
"Duplicated function: %s, overload: %s", op_name, func.func.name.overload_name
)
continue
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def _get_func_schema_in_namespace(namespaces: List[_OpNamespace]) -> Dict[str, F
# to "resize(Tensor a, SymInt[] shape) -> Tensor"
if "!" in op_overload_packet.schema:
op_overload_packet.schema = re.sub( # type: ignore[attr-defined]
"[(][A-Za-z]![)]", "", op_overload_packet.schema
r"[(][A-Za-z]![)]", "", op_overload_packet.schema
)

# FIXME: remove below code if the issue below is fixed.
Expand All @@ -283,7 +283,7 @@ def main(args: argparse.Namespace) -> None:
if module_name not in functions:
functions[module_name] = {}
if op_name in functions[module_name]:
logging.warning(
logging.warning( # noqa: LOG015
"Duplicated function: %s, overload: %s",
op_name,
func_schema.name.overload_name,
Expand Down
4 changes: 2 additions & 2 deletions onnxscript/ir/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,7 +1071,7 @@ def format_name(value_name: str) -> str:

for input in function.inputs:
if not input.name:
logging.warning(
logger.warning(
"Function '%s': Value name not set for function input: %s",
function_qualified_name,
input,
Expand All @@ -1084,7 +1084,7 @@ def format_name(value_name: str) -> str:
for node in function:
for node_output in node.outputs:
if not node_output.name:
logging.warning(
logger.warning(
"Function '%s': Value name not set for node output: %s",
function_qualified_name,
node_output,
Expand Down
8 changes: 4 additions & 4 deletions onnxscript/rewriter/onnxruntime/xformers/_smollm_1layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def main_graph(
val_191 = opset18.Transpose(slice_scatter, perm=[1, 0, 2, 3])
slice_scatter_1 = opset18.Transpose(val_191, perm=[1, 0, 2, 3])
unsqueeze_6 = opset18.Unsqueeze(input2, 1)
_to_copy_1 = opset18.Cast(unsqueeze_6, to=1)
to_copy_1 = opset18.Cast(unsqueeze_6, to=1)
view_1 = opset18.Constant(
value=make_tensor(
"value",
Expand Down Expand Up @@ -113,7 +113,7 @@ def main_graph(
],
)
)
view_2 = opset18.Reshape(_to_copy_1, [1, 1, 10], allowzero=0)
view_2 = opset18.Reshape(to_copy_1, [1, 1, 10], allowzero=0)
bmm = view_1 @ view_2
view_3 = opset18.Reshape(bmm, [1, 32, 10], allowzero=0)
transpose = opset18.Transpose(view_3, perm=[0, 2, 1])
Expand Down Expand Up @@ -199,8 +199,8 @@ def main_graph(
mul_13 = model_norm_weight * mul_12
t_7 = opset18.Transpose(lm_head_weight, perm=[1, 0])
view_23 = mul_13 @ t_7
_to_copy_12 = opset18.Identity(view_23)
return _to_copy_12, add_3, transpose_3
to_copy_12 = opset18.Identity(view_23)
return to_copy_12, add_3, transpose_3

model = main_graph.to_model_proto()
return model
Expand Down
2 changes: 1 addition & 1 deletion onnxscript/rewriter/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ def clone(self, node_map: dict[NodePattern, NodePattern]) -> ValuePattern:
def name(self) -> str | None:
return self._name

def producer(self) -> None | NodePattern:
def producer(self) -> NodePattern | None:
return None

def uses(self) -> Sequence[tuple[NodePattern, int]]:
Expand Down
2 changes: 1 addition & 1 deletion onnxscript/tools/benchmark/benchmark_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def _cmd_line(script_name: str, **kwargs: dict[str, Any]) -> list[str]:


def _extract_metrics(text: str) -> dict[str, str]:
reg = re.compile(":(.*?),(.*.?);")
reg = re.compile(r":(.*?),(.*.?);")
res = reg.findall(text)
if len(res) == 0:
return {}
Expand Down
2 changes: 1 addition & 1 deletion onnxscript/tools/benchmark/benchmark_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def _cmd_line(script_name: str, **kwargs: dict[str, str | int | float]) -> list[


def _extract_metrics(text: str) -> dict[str, str]:
reg = re.compile(":(.*?),(.*.?);")
reg = re.compile(r":(.*?),(.*.?);")
res = reg.findall(text)
if len(res) == 0:
return {}
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ ignore = [
"PYI041", # int | float is more clear
"RUF022", # We don't need to sort __all__ for elements to be grouped
"RUF031", # Parentheses for tuple in subscripts is more readable
"RUF052", # Variables with `_` prefix may not be dummy variables in all cases
"SIM102", # Collapible if statements are not always more readable
"SIM108", # We don't always encourage ternary operators
"SIM114", # Don't always combine if branches for debugability
Expand Down
2 changes: 1 addition & 1 deletion requirements/lintrunner/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# This file is auto updated by dependabot
lintrunner-adapters>=0.8.0
# RUFF, RUFF-FIX
ruff==0.7.3
ruff==0.8.4
# MYPY
mypy==1.10.1
types-PyYAML==6.0.12.20240808
Expand Down
18 changes: 6 additions & 12 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,19 +254,16 @@ def _embedding_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
"""Remove arguments not present in the aten op signature."""
if "max_norm" in kwargs:
del kwargs["max_norm"]
if "norm_type" in kwargs:
del kwargs["norm_type"]
kwargs.pop("max_norm", None)
kwargs.pop("norm_type", None)
return args, kwargs


def _empty_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
"""Remove arguments not present in the aten op signature."""
if "requires_grad" in kwargs:
del kwargs["requires_grad"]
kwargs.pop("requires_grad", None)
return args, kwargs


Expand Down Expand Up @@ -325,8 +322,7 @@ def _max_pool_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
# Remove return_indices argument because this op doesn't accept it
if "return_indices" in kwargs:
del kwargs["return_indices"]
kwargs.pop("return_indices", None)
return args, kwargs


Expand Down Expand Up @@ -364,8 +360,7 @@ def _nll_loss_input_wrangler(
def _nonzero_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
if "as_tuple" in kwargs:
del kwargs["as_tuple"]
kwargs.pop("as_tuple", None)
return args, kwargs


Expand Down Expand Up @@ -421,8 +416,7 @@ def _roll_input_wrangler(
def _scalar_tensor_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
if "requires_grad" in kwargs:
del kwargs["requires_grad"]
kwargs.pop("requires_grad", None)
return args, kwargs


Expand Down
Loading