Skip to content

Commit

Permalink
Remove the dependency of the pre-registered op (#74)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #74

We can register the TorchScriptOp on the fly

Reviewed By: louisfeng

Differential Revision: D41002899

Privacy Context Container: L1137347

fbshipit-source-id: dc12c86efc478054c7c75fc155a644fefb4516a1
  • Loading branch information
Sung-Han Lin authored and facebook-github-bot committed May 25, 2023
1 parent c29a4a8 commit df99f82
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 15 deletions.
10 changes: 8 additions & 2 deletions train/compute/python/lib/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from .init_helper import get_logger
from .iterator import config_iterator_map, ConfigIterator, DefaultConfigIterator
from .operator import op_map, OperatorInterface
from .pytorch.operator_impl import TorchScriptOp


logger = get_logger()

Expand Down Expand Up @@ -64,8 +66,12 @@ def input_data_generator(self, value: Type[DataGenerator]):
def make_op_config(op_name: str, op_info: Dict[str, Any], device: str):
global op_map
if (op_name not in op_map) or (not op_map[op_name]):
logger.warning(f"{op_name} has no valid callable defined, skipped.")
return None
if op_name.startswith("aten::"):
logger.debug(f"register op: {op_name}")
op_map[op_name] = TorchScriptOp(op_name)
else:
logger.warning(f"{op_name} has no valid callable defined, skipped.")
return None

op = op_map[op_name]
op.device = device
Expand Down
14 changes: 1 addition & 13 deletions train/compute/python/workloads/pytorch/native_basic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch

from ...lib.operator import OperatorInterface, register_operators
from ...lib.pytorch.operator_impl import BuildableOp, CallableOp, TorchScriptOp, UnaryOp
from ...lib.pytorch.operator_impl import BuildableOp, CallableOp, UnaryOp


# Unary
Expand Down Expand Up @@ -37,15 +37,3 @@
"torch.nn.Linear": BuildableOp(torch.nn.Linear),
}
register_operators(buildable_ops)


# Operator schema: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
torchscript_ops: Dict[str, OperatorInterface] = {
"aten::add": TorchScriptOp("aten::add"),
"aten::add_": TorchScriptOp("aten::add_"),
"aten::matmul": TorchScriptOp("aten::matmul"),
"aten::mul": TorchScriptOp("aten::mul"),
"aten::sum": TorchScriptOp("aten::sum"),
"aten::linear": TorchScriptOp("aten::linear"),
}
register_operators(torchscript_ops)

0 comments on commit df99f82

Please sign in to comment.