
tutel/jit_kernels/sparse.py torch.float16 calculation bug: the CUDA result is inconsistent with the CPU result and the array access is out of bounds #196

WsqRichards1 opened this issue Mar 8, 2023 · 1 comment

WsqRichards1 commented Mar 8, 2023

Code:

import numpy as np
import torch
from tutel.jit_kernels import sparse as jit_kernel

print(torch.__version__)

def moe_dispatch_bwd_gate():
    samples = 2
    capacity = 2
    hidden = 2
    num_experts = 1
    indices = [0, 0]
    locations = [0, 0]
    input = [0.4946, -0.0043, 0.5386, -0.8354]
    dispatch = [0.7085, 0.8257, -0.1455, -0.1788]
    # int32
    indices_t = np.asarray(indices, dtype=np.int32)
    locations_t = np.asarray(locations, dtype=np.int32)
    # float / half
    input_t = np.asarray(input, dtype=np.float16)
    dispatch_t = np.asarray(dispatch, dtype=np.float16)
    indices_gpu = torch.from_numpy(indices_t).cuda()
    locations_gpu = torch.from_numpy(locations_t).cuda()
    input_gpu = torch.from_numpy(input_t).cuda()
    dispatch_gpu = torch.from_numpy(dispatch_t).cuda()
    print("cuda:")
    print("indices_gpu:", indices_gpu)
    print("locations_gpu:", locations_gpu)
    print("input_gpu:", input_gpu)
    print("dispatch_gpu:", dispatch_gpu)
    # call gpu func
    grad_gates = torch.zeros([samples], dtype=input_gpu.dtype, device=input_gpu.device)
    moe_dispatch_bwd_gate = jit_kernel.create_backward_gate(input_gpu.dtype, input_gpu.is_cuda)
    moe_dispatch_bwd_gate(grad_gates, indices_gpu, locations_gpu, input_gpu, dispatch_gpu, extra=[samples, hidden, capacity])
    print("grad_gates:", grad_gates)
    # call cpu func
    input_t = np.asarray(input, dtype=np.float32)
    dispatch_t = np.asarray(dispatch, dtype=np.float32)
    indices_cpu = torch.from_numpy(indices_t)
    locations_cpu = torch.from_numpy(locations_t)
    input_cpu = torch.from_numpy(input_t)
    print("cpu:")
    # print("input_cpu:", input_cpu)
    dispatch_cpu = torch.from_numpy(dispatch_t)
    grad_gates_cpu = torch.zeros([samples], dtype=input_cpu.dtype, device=input_cpu.device)
    moe_dispatch_bwd_gate = jit_kernel.create_backward_gate(input_cpu.dtype, input_cpu.is_cuda)
    moe_dispatch_bwd_gate(grad_gates_cpu, indices_cpu, locations_cpu, input_cpu, dispatch_cpu, extra=[samples, hidden, capacity])
    print("grad_gates_cpu:", grad_gates_cpu)

if __name__ == '__main__':
    moe_dispatch_bwd_gate()

Problem: the CUDA calculation result is inconsistent with the CPU calculation result:
cuda: [0.4180, 0.0000]
cpu:  [0.3469, -0.3082]
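
For reference, a minimal NumPy sketch (written for this report, not taken from tutel) of the gate-gradient formula the CPU kernel is expected to compute, grad_gates[s] = dot(dispatch[indices[s] * capacity + locations[s]], input[s]), reproduces the CPU numbers:

import numpy as np

# Repro data reshaped to [samples, hidden] and [num_experts * capacity, hidden].
inp = np.asarray([0.4946, -0.0043, 0.5386, -0.8354], dtype=np.float32).reshape(2, 2)
dsp = np.asarray([0.7085, 0.8257, -0.1455, -0.1788], dtype=np.float32).reshape(2, 2)
indices, locations, capacity = [0, 0], [0, 0], 2

# Each gate gradient is the dot product of the sample's input row with the
# dispatched row it was routed to (both samples route to row 0 here).
grad_gates = [float(np.dot(dsp[indices[s] * capacity + locations[s]], inp[s])) for s in range(2)]
print(grad_gates)  # ~[0.3469, -0.3082], matching the CPU output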

Analysis of the CUDA calculation:

When the sample index is 0, the kernel computes the gradient of the first gate.

Because dispatched_input and reshaped_input are typed as half2, each element spans two fp16 values (the pointer behaves like a float pointer).

So for i = 0, the dispatch subscript index * hidden + i = 0 and the input subscript index * hidden + i = 0 both select the first two half values, and their product is accumulated into grad_gates1_s_rf.

Values read: dispatch = [0.7085, 0.8257], input = [0.4946, -0.0043]

i = 0 result: grad_gates1_s_rf = 0.7085 * 0.4946 + 0.8257 * (-0.0043) = 0.34687359

For i = 1, the dispatch subscript index * hidden + i = 1 and the input subscript index * hidden + i = 1 select the last two half values, which are also added into the first gate's gradient.

Values read: dispatch = [-0.1455, -0.1788], input = [0.5386, -0.8354]

i = 1 result: grad_gates1_s_rf += 0.5386 * (-0.1455) + (-0.8354) * (-0.1788) = 0.07100322

Final grad_gates1_s_rf = 0.34687359 + 0.07100322 = 0.41787681

When the sample index is 1, the kernel computes the gradient of the second gate. The starting input subscript is 2 (in half2 units), so the array access is out of bounds; the illegal address apparently reads as 0, which is why the second gradient comes out as 0.
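
A minimal NumPy sketch (written for this report, not taken from tutel) of that half2 reading pattern reproduces the CUDA numbers, assuming the kernel walks hidden half2 pairs per sample:

import numpy as np

# The fp16 kernel treats input/dispatch as half2, i.e. pairs of fp16 values.
inp_pairs = np.asarray([0.4946, -0.0043, 0.5386, -0.8354], dtype=np.float32).reshape(-1, 2)
dsp_pairs = np.asarray([0.7085, 0.8257, -0.1455, -0.1788], dtype=np.float32).reshape(-1, 2)
hidden = 2  # the unhalved value fed to the kernel in the repro above

# Sample 0: half2 pairs 0 and 1 are both accumulated into the first gate's gradient.
grad_gate0 = sum(float(np.dot(dsp_pairs[i], inp_pairs[i])) for i in range(hidden))
print(grad_gate0)  # ~0.4179, matching the CUDA output

# Sample 1: its first input pair index is 1 * hidden + 0 = 2, one past the end of
# the 2-pair buffer, so the GPU reads out of bounds (observed here as 0).
print(1 * hidden + 0, ">=", len(inp_pairs))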

ghostplant (Contributor) commented

Hi, thanks for your info. According to tracing, this is not a bug; your code just doesn't use the kernel in the correct way:

CUDA's evaluation from your code is based on fp16x2, so the hidden_size value fed to that kernel should be divided by 2 as well (see https://github.com/microsoft/tutel/blob/main/tutel/impls/fast_dispatch.py#L95).

In other words, you should change your code from:

moe_dispatch_bwd_gate(grad_gates, indices_gpu, locations_gpu, input_gpu, dispatch_gpu, extra=[samples, hidden, capacity])

into

moe_dispatch_bwd_gate(grad_gates, indices_gpu, locations_gpu, input_gpu, dispatch_gpu, extra=[samples, hidden if input_gpu.dtype is not torch.float16 else hidden // 2, capacity])
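
Equivalently, as a small sketch of that adjustment (the helper name below is illustrative, not part of tutel's API), the kernel-side hidden size can be computed once before the call; presumably hidden has to be even for the fp16x2 path:

import torch

def kernel_hidden(hidden, dtype):
    # Assumption: the fp16 kernel packs values as half2, so hidden must be even.
    if dtype is torch.float16:
        assert hidden % 2 == 0, "fp16 kernel expects an even hidden size"
        return hidden // 2
    return hidden

# moe_dispatch_bwd_gate(grad_gates, indices_gpu, locations_gpu, input_gpu, dispatch_gpu,
#                       extra=[samples, kernel_hidden(hidden, input_gpu.dtype), capacity])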

ghostplant added the "invalid" label on Mar 10, 2023