Skip to content

[Feature] Use a Primitive wrapper to implement the new Dispatch mechanism #2231

@lvyufeng

Description

@lvyufeng
  1. Add an ENABLE_DISPATCH environment variable
  2. Use the primitive wrapper
import mindspore
from mindspore.ops import Primitive
from mindspore.ops._primitive_cache import _get_cache_prim
from mindspore import ops

def hook_call(instance, *, use_dispatch=None):
    """Wrap a primitive so each call is logged and optionally re-targeted.

    When dispatch is enabled, the wrapper aligns the primitive's execution
    target with the device of the first positional argument before calling it.

    Args:
        instance: The callable primitive to wrap. When dispatch is enabled it
            must provide ``set_device(target)``.
        use_dispatch: Optional override for the module-level ``USE_DISPATCH``
            flag. ``None`` (the default) reads ``USE_DISPATCH`` at call time,
            which matches the original behavior.

    Returns:
        A callable that forwards ``*args, **kwargs`` to ``instance`` and
        returns its result.
    """
    def wrapped_call(*args, **kwargs):
        enabled = USE_DISPATCH if use_dispatch is None else use_dispatch
        if enabled and args:
            # Skip arguments without a ``device`` attribute (e.g. Python scalars)
            # instead of raising AttributeError.
            device = getattr(args[0], 'device', None)
            if device is not None:
                # The device string is assumed to look like "CPU:0" or
                # "Ascend:10" (TODO confirm). set_device takes only the backend
                # name, so drop the ":<index>" suffix. The old ``[:-2]`` slice
                # broke on multi-digit indices.
                target = device.split(':', 1)[0]
                # Compare against the normalized name. The old code compared with
                # the full "name:index" string, which never matched, so it called
                # set_device on every invocation. The default is None (not 'CPU')
                # so that a primitive with no target set is always assigned one,
                # just as before.
                if getattr(instance, 'primitive_target', None) != target:
                    instance.set_device(target)
        print("【装饰器Hook】调用前")  # "[decorator hook] before call"
        result = instance(*args, **kwargs)
        print("【装饰器Hook】调用后")  # "[decorator hook] after call"
        return result
    return wrapped_call


# Demo: run a hook-wrapped ops.Add primitive inside a jit-compiled function.
x = ops.randn(3)  # random input tensor of shape (3,)

# Global switch read by hook_call's wrapper at call time. When False, the
# wrapper only logs and never changes the primitive's device target.
USE_DISPATCH = False
# NOTE(review): fullgraph=True / backend='GE' compile fn as a whole graph.
# SOURCE does not show whether the Python-side getattr/print/set_device calls
# inside hook_call are supported in this mode; confirm against the MindSpore
# jit documentation.
@mindspore.jit(fullgraph=True, backend='GE')
def fn(x):
    # Build the primitive, wrap it with the logging/dispatch hook, and apply it.
    add_op = ops.Add()
    # Debug aid: hook_call changes the target only when the wrapper is *called*,
    # so the two prints below would show the same value. primitive_target may
    # be absent if it was never set (hook_call reads it via getattr with a
    # default).
    # print(add_op.primitive_target)
    add = hook_call(add_op)
    # print(add_op.primitive_target)
    y = add(x, 1)
    return y
y = fn(x)
print(y, y.device)  # show the result and the device it ended up on

Metadata

Metadata

Assignees

Labels

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions