diff --git a/tritonbench/operators/rms_norm/fused_triton.py b/tritonbench/operators/rms_norm/fused_triton.py index 5a8f1472..6dd3aae5 100644 --- a/tritonbench/operators/rms_norm/fused_triton.py +++ b/tritonbench/operators/rms_norm/fused_triton.py @@ -83,13 +83,12 @@ def forward(ctx, x, normalized_shape, weight, eps): # allocate output y = torch.empty_like(x) # reshape input data into 2D tensor - x_arg = x.reshape(-1, x.shape[-1]).to(weight.dtype) def rmsnorm_ref(inp, w, eps=1e-6): rms = 1.0 / torch.sqrt(torch.mean(inp.square(), dim=-1, keepdim=True) + eps) return (inp * rms * w).to(inp.dtype), rms - y, rms = rmsnorm_ref(x_arg, weight, eps) + y, rms = rmsnorm_ref(x, weight, eps) ctx.save_for_backward(x, weight, rms) ctx.eps = eps return y