Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trouble when fusing layernorm with pointwise operation. #88

Closed
xcwang1999 opened this issue Jul 11, 2024 · 2 comments
Closed

Trouble when fusing layernorm with pointwise operation. #88

xcwang1999 opened this issue Jul 11, 2024 · 2 comments

Comments

@xcwang1999
Copy link

xcwang1999 commented Jul 11, 2024

Describe the bug
I'm trying to fuse a layernorm node with a pointwise node (mul).
I get `CUDNN_STATUS_BAD_PARAM_NULL_POINTER` when executing the graph. The error seems to come from the parameters passed in the variant pack, but I can't locate my mistake.

System Environment (please complete the following information):

  • cudnn_frontend version: 1.5.1
  • cudnn_backend version: 9.1.0
  • GPU arch: RTX 3050
  • cuda runtime version: 12.4
  • cuda driver version: 550.90.07
  • host compiler: gcc 12.3.0
  • OS: Ubuntu 22.04

API logs
be.log
fe.log

To Reproduce

import cudnn
import pytest
import torch

import functools

def torch_fork_set_rng(seed=None):
    """Build a decorator that runs the wrapped callable under a forked torch RNG.

    The call executes inside ``torch.random.fork_rng`` covering the CPU RNG
    and every visible CUDA device, so random draws made by the callable do
    not advance the caller's RNG streams.

    Args:
        seed: When not ``None``, ``torch.manual_seed(seed)`` is applied inside
            the fork before the call, making the callable's draws repeatable.

    Returns:
        A decorator; the wrapper it produces keeps the wrapped function's
        metadata (``functools.wraps``) and forwards all arguments unchanged.
    """

    def _decorate(fn):
        @functools.wraps(fn)
        def _run_isolated(*args, **kwargs):
            # Device count is queried per call, not at decoration time.
            device_ids = range(torch.cuda.device_count())
            with torch.random.fork_rng(devices=device_ids):
                if seed is not None:
                    torch.manual_seed(seed)
                result = fn(*args, **kwargs)
            return result

        return _run_isolated

    return _decorate

@pytest.mark.parametrize("param_extract", [(1600, torch.float32)])
@torch_fork_set_rng(seed=0)
def test_layernorm(param_extract):
    """Repro: training-mode LayerNorm fed by a pointwise ``mul`` in one cuDNN graph.

    The graph computes ``layernorm(X * mask, scale, bias, epsilon)`` and
    outputs Y, mean and inv_var. These are compared against a PyTorch
    reference computed on the same masked input.

    NOTE: per the maintainer reply on this issue, cuDNN does not currently
    support layernorm + pointwise fusions. In the reporter's environment this
    surfaced as CUDNN_STATUS_BAD_PARAM_NULL_POINTER when executing the graph,
    rather than CUDNN_STATUS_NOT_SUPPORTED.

    Args:
        param_extract: ``(embedding_dim, input_type)`` -- the normalized
            channel count C and the torch dtype used for X, scale, bias,
            mask and Y.
    """
    embedding_dim, input_type = param_extract

    # bf16 carries far fewer mantissa bits, so it gets a looser tolerance.
    if input_type == torch.bfloat16:
        atol, rtol = 0.125, 0.125
    else:
        atol, rtol = 1e-2, 1e-2

    # One row per (batch, seq) token, normalized over C (H = W = 1).
    batch_size, seq_size = 16, 128
    N, C, H, W = batch_size * seq_size, embedding_dim, 1, 1

    epsilon_value = 1e-3

    x_gpu = (
        3
        * torch.randn(
            N, C, H, W, requires_grad=False, device="cuda", dtype=input_type
        ).to(memory_format=torch.channels_last)
        - 0.5
    )
    scale_gpu = (
        5
        * torch.randn(
            1, C, H, W, requires_grad=False, device="cuda", dtype=input_type
        ).to(memory_format=torch.channels_last)
        - 1
    )
    bias_gpu = (
        7
        * torch.randn(
            1, C, H, W, requires_grad=False, device="cuda", dtype=input_type
        ).to(memory_format=torch.channels_last)
        - 2
    )
    # epsilon is declared pass-by-value below, so it lives on the host.
    epsilon_cpu = torch.full(
        (1, 1, 1, 1),
        epsilon_value,
        requires_grad=False,
        device="cpu",
        dtype=torch.float32,
    )

    mask_gpu = torch.ones(N, C, H, W, device="cuda", dtype=input_type).to(
        memory_format=torch.channels_last
    )

    # The reference models the same computation as the graph: layer norm of
    # X * mask. (The mask is all ones here, but the reference must not
    # silently ignore the mul node.)
    x_masked = x_gpu * mask_gpu
    Y_expected = torch.nn.functional.layer_norm(
        x_masked,
        [C, H, W],
        weight=scale_gpu.squeeze(0),
        bias=bias_gpu.squeeze(0),
        eps=epsilon_value,
    )
    x_masked_f32 = x_masked.to(torch.float32)
    mean_expected = x_masked_f32.mean(dim=(1, 2, 3), keepdim=True)
    # Layer norm normalizes with the biased (population) variance, while
    # torch.var defaults to the unbiased estimator -- request the biased one.
    inv_var_expected = torch.rsqrt(
        torch.var(x_masked_f32, dim=(1, 2, 3), keepdim=True, unbiased=False)
        + epsilon_value
    )

    handle = cudnn.create_handle()
    try:
        stream = torch.cuda.current_stream().cuda_stream
        cudnn.set_stream(handle=handle, stream=stream)

        graph = cudnn.pygraph(
            intermediate_data_type=cudnn.data_type.FLOAT,
            compute_data_type=cudnn.data_type.FLOAT,
            handle=handle,
        )

        X = graph.tensor(
            name="X", dim=x_gpu.size(), stride=x_gpu.stride(), data_type=x_gpu.dtype
        )
        scale = graph.tensor(
            name="scale",
            dim=scale_gpu.size(),
            stride=scale_gpu.stride(),
            data_type=scale_gpu.dtype,
        )
        bias = graph.tensor(
            name="bias",
            dim=bias_gpu.size(),
            stride=bias_gpu.stride(),
            data_type=bias_gpu.dtype,
        )
        epsilon = graph.tensor(
            name="epsilon",
            dim=epsilon_cpu.size(),
            stride=epsilon_cpu.stride(),
            is_pass_by_value=True,
            data_type=epsilon_cpu.dtype,
        )

        mask = graph.tensor(
            name="mask",
            dim=mask_gpu.size(),
            stride=mask_gpu.stride(),
            data_type=mask_gpu.dtype,
        )

        # Pointwise producer feeding the norm: this is the fusion the cuDNN
        # team reports as not yet supported.
        X_after_mul = graph.mul(
            name="mul", a=X, b=mask, compute_data_type=cudnn.data_type.FLOAT
        )

        Y, mean, inv_var = graph.layernorm(
            name="LN",
            norm_forward_phase=cudnn.norm_forward_phase.TRAINING,
            input=X_after_mul,
            scale=scale,
            bias=bias,
            epsilon=epsilon,
        )

        Y.set_output(True).set_data_type(x_gpu.dtype)
        mean.set_output(True).set_data_type(mean_expected.dtype)
        inv_var.set_output(True).set_data_type(inv_var_expected.dtype)

        graph.validate()
        graph.build_operation_graph()
        graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
        graph.check_support()
        graph.build_plans(cudnn.build_plan_policy.ALL)

        Y_actual = torch.empty_like(x_gpu)
        mean_actual = torch.empty_like(mean_expected)
        inv_var_actual = torch.empty_like(inv_var_expected)

        workspace = torch.empty(
            graph.get_workspace_size(), device="cuda", dtype=torch.uint8
        )

        # Map every graph tensor (inputs and outputs) to its buffer.
        graph.execute(
            {
                X: x_gpu.detach(),
                scale: scale_gpu.detach(),
                bias: bias_gpu.detach(),
                mask: mask_gpu,
                epsilon: epsilon_cpu,
                Y: Y_actual,
                mean: mean_actual,
                inv_var: inv_var_actual,
            },
            workspace,
            handle=handle,
        )

        torch.cuda.synchronize()

        torch.testing.assert_close(Y_expected, Y_actual, atol=atol, rtol=rtol)
        torch.testing.assert_close(mean_expected, mean_actual, atol=atol, rtol=rtol)
        torch.testing.assert_close(
            inv_var_expected, inv_var_actual, atol=atol, rtol=rtol
        )
    finally:
        # Release the handle even when building, executing or an assertion
        # fails -- previously the handle leaked on any error.
        cudnn.destroy_handle(handle)

if __name__ == "__main__":
    # Run one configuration directly without pytest:
    # embedding_dim=1600, fp32 inputs (torch.float32 is torch.float).
    default_params = (1600, torch.float32)
    test_layernorm(default_params)

Additional context
Thanks in advance for any help.

@Anerudhan
Copy link
Collaborator

Hi @xcwang1999 ,

Currently, layer norm fusions (layer norm + pointwise fusions) are not supported in cuDNN. However, we are working on extending support for them.

As an aside, CUDNN_STATUS_BAD_PARAM_NULL_POINTER is the wrong error message here; it should have been CUDNN_STATUS_NOT_SUPPORTED. We will also work on improving the diagnostic messages.

Thanks for raising this issue.

Regards
Anerudhan obo cudnn team

@xcwang1999
Copy link
Author

Thanks for the reply!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants