
[Fix] Mark the reduce fp16 operator not fusible #100

Merged
merged 3 commits into hidet-org:main from fix-reduce-f16 on Feb 13, 2023

Conversation

yaoyaoding
Member

The reduce fp16 operator has not been marked as not fusible. By default, every operator is allowed to fuse with preceding and succeeding elementwise operators. We should override allow_prologue and allow_epilogue to mark that this task does not support such fusion (because we use implicit access through pointers to the inputs/outputs).
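
A minimal sketch of the intended change, assuming the reduce fp16 schedule is implemented as a Task subclass and that Task exposes allow_prologue / allow_epilogue hooks; the class name and import path below are illustrative, not the actual patch:

from hidet.ir.task import Task  # assumed module path for the Task base class


class ReduceF16Task(Task):  # hypothetical name for the fp16 reduce task
    def allow_prologue(self) -> bool:
        # inputs are read through raw pointers, so fusing a preceding
        # elementwise operator into this task would be unsafe
        return False

    def allow_epilogue(self) -> bool:
        # outputs are written through raw pointers, so fusing a succeeding
        # elementwise operator would likewise be unsafe
        return False

With both hooks returning False, the fusion pass should treat the reduce fp16 task as a fusion boundary and keep the surrounding elementwise operators in separate kernels.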

Reproduce the error:

import hidet


def demo_fusion():
    x = hidet.symbol([9, 1, 3136, 64], dtype='float16').cuda()
    y = hidet.symbol([1, 64, 1, 1], dtype='float16').cuda()
    y_2 = hidet.ops.sum(x, dims=[0])
    y_3 = hidet.ops.reshape(y_2, [1, 1, 56, 56, 64])
    y_4 = hidet.ops.rearrange(y_3, [[4], [2], [3]])
    z = y_4 + y
    y_1 = hidet.ops.relu(z)

    graph = hidet.trace_from(y_1, [x, y])
    with hidet.graph.PassContext() as ctx:
        graph_opt = hidet.graph.optimize(graph)

    print(graph)
    print(graph_opt)
    xx = hidet.randn_like(x)
    yy = hidet.randn_like(y)
    hidet.option.save_lower_ir()
    yy_1 = graph_opt(xx, yy)


if __name__ == '__main__':
    demo_fusion()

FYI @hjjq.

@yaoyaoding yaoyaoding merged commit d97dc2f into hidet-org:main Feb 13, 2023
@yaoyaoding yaoyaoding deleted the fix-reduce-f16 branch February 13, 2023 23:58