
[BUG] error: use of undeclared identifier '__double2half'; did you mean '__double2hiint'? #3197

Closed
WeiMa01 opened this issue Apr 12, 2023 · 4 comments
Labels: `bug` (Something isn't working), `inference`


WeiMa01 commented Apr 12, 2023

When I launch the script with DeepSpeed, the just-in-time build of the `transformer_inference` op fails with `error: use of undeclared identifier '__double2half'; did you mean '__double2hiint'?`

## 1. The script is as follows:

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import time
import os
import deepspeed

model_id = "/home/user/workspace/workspace/gpt2"
payload = "Hello my name is Philipp. I am getting in touch with you because i didn't get a response from you. What do I need to do to get my"

local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))
print("world_size", world_size)
generator = pipeline('text-generation', model=model_id,
                     device=local_rank, torch_dtype=torch.float16)
print("generator")
generator.model = deepspeed.init_inference(generator.model,
                                           mp_size=world_size,
                                           dtype=torch.float16,
                                           replace_with_kernel_inject=True)

string = generator(payload, do_sample=True, num_beams=1, min_length=256, max_new_tokens=256, pad_token_id=50256)
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
    print(string)
```
## 2. Hardware and software
`ds_report` output:

```
DeepSpeed C++/CUDA extension op report

NOTE: Ops not installed will be just-in-time (JIT) compiled at
runtime if needed. Op compatibility means that your system
meet the required dependencies to JIT install the op.

JIT compiled ops requires ninja
ninja .................. [OKAY]

op name ................ installed .. compatible

async_io ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
[WARNING] sparse_attn is not compatible with ROCM
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
utils .................. [NO] ....... [OKAY]

DeepSpeed general environment info:
torch install path ............... ['/opt/conda/lib/python3.8/site-packages/torch']
torch version .................... 1.13.0a0+git941769a
deepspeed install path ........... ['/home/user/workspace/DeepSpeed/deepspeed']
deepspeed info ................... 0.9.0+f662bfcd, f662bfc, master
torch cuda version ............... None
torch hip version ................ 5.4.22801-aaa1e3d8
nvcc version ..................... None
deepspeed wheel compiled w. ...... torch 1.13, hip 5.4
```

## 3. Launch command

```shell
deepspeed --num_gpus 4 gpt2_deepspeed.py
```

## 4. Error details

```
Using envvar MAX_JOBS (32) as the number of workers...
[1/5] /opt/rocm/bin/hipcc -DWITH_HIP -DTORCH_EXTENSION_NAME=transformer_inference -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1013\" -I/home/user/workspace/DeepSpeed/csrc/transformer/inference/includes -I/home/user/workspace/DeepSpeed/csrc/includes -isystem /opt/conda/lib/python3.8/site-packages/torch/include -isystem /opt/conda/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -isystem /opt/conda/lib/python3.8/site-packages/torch/include/TH -isystem /opt/conda/lib/python3.8/site-packages/torch/include/THC -isystem /opt/conda/lib/python3.8/site-packages/torch/include/THH -isystem /opt/rocm/include -isystem /opt/rocm/miopen/include -isystem /opt/rocm/hip/include -isystem /opt/conda/include/python3.8 -D_GLIBCXX_USE_CXX11_ABI=1 -fPIC -std=c++14 -O3 -std=c++14 -g -Wno-reorder -fPIC -D__HIP_PLATFORM_HCC__=1 -DUSE_ROCM=1 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -std=c++14 -U__HIP_NO_HALF_OPERATORS__ -U__HIP_NO_HALF_CONVERSIONS__ -U__HIP_NO_HALF2_OPERATORS__ -DROCM_VERSION_MAJOR=5 -DROCM_VERSION_MINOR=4 --amdgpu-target=gfx900 --amdgpu-target=gfx906 --amdgpu-target=gfx908 --amdgpu-target=gfx90a --amdgpu-target=gfx1030 -fno-gpu-rdc -c /home/user/workspace/DeepSpeed/csrc/transformer/inference/csrc/layer_norm.hip -o layer_norm.cuda.o
FAILED: layer_norm.cuda.o
/opt/rocm/bin/hipcc -DWITH_HIP -DTORCH_EXTENSION_NAME=transformer_inference -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1013\" -I/home/user/workspace/DeepSpeed/csrc/transformer/inference/includes -I/home/user/workspace/DeepSpeed/csrc/includes -isystem /opt/conda/lib/python3.8/site-packages/torch/include -isystem /opt/conda/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -isystem /opt/conda/lib/python3.8/site-packages/torch/include/TH -isystem /opt/conda/lib/python3.8/site-packages/torch/include/THC -isystem /opt/conda/lib/python3.8/site-packages/torch/include/THH -isystem /opt/rocm/include -isystem /opt/rocm/miopen/include -isystem /opt/rocm/hip/include -isystem /opt/conda/include/python3.8 -D_GLIBCXX_USE_CXX11_ABI=1 -fPIC -std=c++14 -O3 -std=c++14 -g -Wno-reorder -fPIC -D__HIP_PLATFORM_HCC__=1 -DUSE_ROCM=1 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -std=c++14 -U__HIP_NO_HALF_OPERATORS__ -U__HIP_NO_HALF_CONVERSIONS__ -U__HIP_NO_HALF2_OPERATORS__ -DROCM_VERSION_MAJOR=5 -DROCM_VERSION_MINOR=4 --amdgpu-target=gfx900 --amdgpu-target=gfx906 --amdgpu-target=gfx908 --amdgpu-target=gfx90a --amdgpu-target=gfx1030 -fno-gpu-rdc -c /home/user/workspace/DeepSpeed/csrc/transformer/inference/csrc/layer_norm.hip -o layer_norm.cuda.o
Warning: The --amdgpu-target option has been deprecated and will be removed in the future. Use --offload-arch instead.
Warning: The --amdgpu-target option has been deprecated and will be removed in the future. Use --offload-arch instead.
Warning: The --amdgpu-target option has been deprecated and will be removed in the future. Use --offload-arch instead.
Warning: The --amdgpu-target option has been deprecated and will be removed in the future. Use --offload-arch instead.
Warning: The --amdgpu-target option has been deprecated and will be removed in the future. Use --offload-arch instead.
In file included from /home/user/workspace/DeepSpeed/csrc/transformer/inference/csrc/layer_norm.hip:8:
/home/user/workspace/DeepSpeed/csrc/includes/conversion_utils_hip.h:270:12: error: use of undeclared identifier '__double2half'; did you mean '__double2hiint'?
    return __double2half(val);
           ^~~~~~~~~~~~~
           __double2hiint
/opt/rocm-5.4.0/include/hip/amd_detail/amd_device_functions.h:440:30: note: '__double2hiint' declared here
__device__ static inline int __double2hiint(double x) {
                             ^
In file included from /home/user/workspace/DeepSpeed/csrc/transformer/inference/csrc/layer_norm.hip:12:
/home/user/workspace/DeepSpeed/csrc/includes/reduction_utils_hip.h:278:43: error: excess elements in struct initializer
    constexpr __half2_raw zero = {0x0000, 0x0000};
                                          ^~~~~~
/home/user/workspace/DeepSpeed/csrc/includes/reduction_utils_hip.h:285:42: error: excess elements in struct initializer
    constexpr __half2_raw inf = {0x7C00, 0x7C00};
                                         ^~~~~~
/home/user/workspace/DeepSpeed/csrc/includes/reduction_utils_hip.h:292:46: error: excess elements in struct initializer
    constexpr __half2_raw neg_inf = {0xFC00, 0xFC00};
                                             ^~~~~~
In file included from /home/user/workspace/DeepSpeed/csrc/transformer/inference/csrc/layer_norm.hip:8:
In file included from /home/user/workspace/DeepSpeed/csrc/includes/conversion_utils_hip.h:9:
In file included from /home/user/workspace/DeepSpeed/csrc/includes/ds_kernel_utils_hip.h:24:
In file included from /opt/rocm-5.4.0/include/hip/hip_cooperative_groups.h:38:
/opt/rocm-5.4.0/include/hip/amd_detail/amd_hip_cooperative_groups.h:527:3: error: static assertion failed due to requirement 'integral_constant<bool, false>::value': Tile size is either not a power of 2 or greater than the wavefront size
  static_assert(is_valid_tile_size<size>::value,
  ^             ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/opt/rocm-5.4.0/include/hip/amd_detail/amd_hip_cooperative_groups.h:563:39: note: in instantiation of template class 'cooperative_groups::thread_block_tile_base<64>' requested here
class thread_block_tile_type : public thread_block_tile_base<tileSize>, public tiled_group {
                                      ^
/opt/rocm-5.4.0/include/hip/amd_detail/amd_hip_cooperative_groups.h:624:43: note: in instantiation of template class 'cooperative_groups::thread_block_tile_type<64>' requested here
class thread_block_tile_internal : public thread_block_tile_type<size, ParentCGTy> {
                                          ^
/opt/rocm-5.4.0/include/hip/amd_detail/amd_hip_cooperative_groups.h:650:46: note: in instantiation of template class 'cooperative_groups::impl::thread_block_tile_internal<64, void>' requested here
class thread_block_tile<size, void> : public impl::thread_block_tile_internal<size, void> {
                                             ^
/home/user/workspace/DeepSpeed/csrc/includes/reduction_utils_hip.h:343:44: note: in instantiation of template class 'cooperative_groups::thread_block_tile<64, void>' requested here
        data[0] = element<Op>(data[0], warp.shfl_xor(data[0], i));
                                           ^
/home/user/workspace/DeepSpeed/csrc/includes/reduction_utils_hip.h:352:46: error: no member named 'shfl_xor' in 'cooperative_groups::thread_block_tile<64, void>'
        data[0] = element<Op1>(data[0], warp.shfl_xor(data[0], i));
                                        ~~~~ ^
/home/user/workspace/DeepSpeed/csrc/includes/reduction_utils_hip.h:353:46: error: no member named 'shfl_xor' in 'cooperative_groups::thread_block_tile<64, void>'
        data[1] = element<Op2>(data[1], warp.shfl_xor(data[1], i));
                                        ~~~~ ^
/home/user/workspace/DeepSpeed/csrc/includes/reduction_utils_hip.h:362:46: error: no member named 'shfl_xor' in 'cooperative_groups::thread_block_tile<64, void>'
        data[0] = element<Op1>(data[0], warp.shfl_xor(data[0], i));
                                        ~~~~ ^
/home/user/workspace/DeepSpeed/csrc/includes/reduction_utils_hip.h:363:46: error: no member named 'shfl_xor' in 'cooperative_groups::thread_block_tile<64, void>'
        data[1] = element<Op2>(data[1], warp.shfl_xor(data[1], i));
                                        ~~~~ ^
/home/user/workspace/DeepSpeed/csrc/includes/reduction_utils_hip.h:364:46: error: no member named 'shfl_xor' in 'cooperative_groups::thread_block_tile<64, void>'
        data[2] = element<Op3>(data[2], warp.shfl_xor(data[2], i));
                                        ~~~~ ^
/home/user/workspace/DeepSpeed/csrc/includes/reduction_utils_hip.h:373:46: error: no member named 'shfl_xor' in 'cooperative_groups::thread_block_tile<64, void>'
        data[0] = element<Op1>(data[0], warp.shfl_xor(data[0], i));
                                        ~~~~ ^
/home/user/workspace/DeepSpeed/csrc/includes/reduction_utils_hip.h:374:46: error: no member named 'shfl_xor' in 'cooperative_groups::thread_block_tile<64, void>'
        data[1] = element<Op2>(data[1], warp.shfl_xor(data[1], i));
                                        ~~~~ ^
/home/user/workspace/DeepSpeed/csrc/includes/reduction_utils_hip.h:375:46: error: no member named 'shfl_xor' in 'cooperative_groups::thread_block_tile<64, void>'
        data[2] = element<Op3>(data[2], warp.shfl_xor(data[2], i));
                                        ~~~~ ^
/home/user/workspace/DeepSpeed/csrc/includes/reduction_utils_hip.h:376:46: error: no member named 'shfl_xor' in 'cooperative_groups::thread_block_tile<64, void>'
        data[3] = element<Op4>(data[3], warp.shfl_xor(data[3], i));
                                        ~~~~ ^
/home/user/workspace/DeepSpeed/csrc/includes/reduction_utils_hip.h:412:18: error: no member named 'meta_group_size' in 'cooperative_groups::thread_block_tile<64, void>'
    if (warp_arg.meta_group_size() > 1 && total_warps != 1) {
        ~~~~~~~~ ^
/home/user/workspace/DeepSpeed/csrc/includes/reduction_utils_hip.h:413:22: error: no member named 'thread_rank' in 'cooperative_groups::thread_block_tile<64, void>'
        if (warp_arg.thread_rank() == 0) {
            ~~~~~~~~ ^
/home/user/workspace/DeepSpeed/csrc/includes/reduction_utils_hip.h:417:54: error: no member named 'meta_group_rank' in 'cooperative_groups::thread_block_tile<64, void>'
                    reduce_buffer + elems * warp_arg.meta_group_rank() + i, data + i);
                                            ~~~~~~~~ ^
/home/user/workspace/DeepSpeed/csrc/includes/reduction_utils_hip.h:424:22: error: no member named 'meta_group_rank' in 'cooperative_groups::thread_block_tile<64, void>'
        if (warp_arg.meta_group_rank() == 0) {
            ~~~~~~~~ ^
/home/user/workspace/DeepSpeed/csrc/includes/reduction_utils_hip.h:425:26: error: no member named 'thread_rank' in 'cooperative_groups::thread_block_tile<64, void>'
            if (warp_arg.thread_rank() < warp_arg.meta_group_size()) {
                ~~~~~~~~ ^
fatal error: too many errors emitted, stopping now [-ferror-limit=]
20 errors generated when compiling for gfx1030.
```
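The compiler's suggestion in the first error is misleading. In the CUDA math API, `__double2half` converts a `double` to a `half` using round-to-nearest-even, while `__double2hiint` (the function the ROCm header in the log declares) reinterprets the high 32 bits of the `double`'s bit pattern as an `int`, so it is not a substitute. The Python sketch below emulates both with the standard `struct` module, whose `'e'` format packs IEEE 754 binary16. It also emulates a hypothetical fallback that narrows through `float` first (the analogue of `__float2half(static_cast<float>(val))`). That fallback is an assumption for illustration, not necessarily how the maintainers fixed this, and it can disagree with a direct conversion on rare rounding ties.

```python
import struct


def double2hiint(x: float) -> int:
    """Emulate __double2hiint: the high 32 bits of the IEEE 754 double's
    bit pattern, reinterpreted as a signed 32-bit integer."""
    bits = struct.unpack("<Q", struct.pack("<d", x))[0]
    hi = bits >> 32
    return hi - (1 << 32) if hi >= (1 << 31) else hi


def double2half_bits(x: float) -> int:
    """Emulate __double2half: round the double directly to binary16
    (round-to-nearest-even) and return the raw 16-bit pattern."""
    return struct.unpack("<H", struct.pack("<e", x))[0]


def double2half_via_float_bits(x: float) -> int:
    """Hypothetical fallback: round to float32 first, then to binary16,
    mirroring __float2half(static_cast<float>(val))."""
    as_float = struct.unpack("<f", struct.pack("<f", x))[0]
    return struct.unpack("<H", struct.pack("<e", as_float))[0]


print(hex(double2hiint(1.0)))      # 0x3ff00000: the double's high word, not a half
print(hex(double2half_bits(1.0)))  # 0x3c00: 1.0 in binary16

# A value just above a binary16 rounding tie: the float32 step drops the
# 2**-40 term, leaving an exact tie that then rounds to even.
x = 1.0 + 2.0**-11 + 2.0**-40
print(hex(double2half_bits(x)))            # 0x3c01
print(hex(double2half_via_float_bits(x)))  # 0x3c00
```

For typical activations the two-step path gives the same result; the difference only shows up when the `float` rounding lands exactly on a `half` tie.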

## Expected behavior
Has anyone encountered the same error, and how can it be solved? Thank you!


## System info

  • OS: Ubuntu 20.04
  • GPU: 8 × MI100
  • Python version: 3.8.13
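The `static_assert` failure and the `shfl_xor`/`meta_group_size` errors in the log all involve `thread_block_tile<64>`, and the log ends with "20 errors generated when compiling for gfx1030". The JIT build targets five architectures, while the MI100 is gfx908. The rough sketch below extracts the targets from an abbreviated copy of the hipcc command and applies the tile-size condition stated in the assertion message. It assumes a default HIP wavefront width of 64 on the gfx9 targets and 32 on gfx1030 (an RDNA2 part); the `WAVEFRONT` table is that assumption, not something read from ROCm.

```python
import re
from typing import List

# Abbreviated copy of the hipcc command from the log (most flags omitted).
HIPCC_CMD = (
    "/opt/rocm/bin/hipcc -DWITH_HIP -O3 -std=c++14 "
    "--amdgpu-target=gfx900 --amdgpu-target=gfx906 --amdgpu-target=gfx908 "
    "--amdgpu-target=gfx90a --amdgpu-target=gfx1030 -fno-gpu-rdc "
    "-c layer_norm.hip -o layer_norm.cuda.o"
)

# Assumption: default HIP wavefront width per target
# (64 on the gfx9 GCN/CDNA parts, 32 on the RDNA2 gfx1030).
WAVEFRONT = {"gfx900": 64, "gfx906": 64, "gfx908": 64, "gfx90a": 64, "gfx1030": 32}


def targets(cmd: str) -> List[str]:
    """Return every --amdgpu-target=<arch> value in command-line order."""
    return re.findall(r"--amdgpu-target=(\S+)", cmd)


def valid_tile(size: int, wavefront: int) -> bool:
    """Condition stated in the static_assert message: a power of two that
    is no wider than the wavefront."""
    return size > 0 and (size & (size - 1)) == 0 and size <= wavefront


for arch in targets(HIPCC_CMD):
    verdict = "ok" if valid_tile(64, WAVEFRONT[arch]) else "rejected"
    print(f"{arch}: thread_block_tile<64> {verdict}")
```

If this reading is right, only the gfx1030 pass rejects the 64-wide tile. Restricting the build to gfx908 might avoid those errors; one way is the `PYTORCH_ROCM_ARCH` environment variable, assuming the installed PyTorch's `torch.utils.cpp_extension` honors it for JIT builds. That alone would not address the `__double2half` and `__half2_raw` errors, which do not appear to be target-specific.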

## Docker context
`rocm/pytorch:latest`

WeiMa01 added the `bug` and `inference` labels on Apr 12, 2023
WeiMa01 changed the title from "[BUG]" to "[BUG] error: use of undeclared identifier '__double2half'; did you mean '__double2hiint'?" on Apr 12, 2023
mrwyattii (Contributor)

Hi @WeiMa01, thanks for reporting this issue. It looks like this might be a problem in the hipification of our CUDA code. Could you please share whether you were able to run this example previously (perhaps with an older version of ROCm/DeepSpeed)? I'll look into this further. Thanks

WeiMa01 (Author) commented Apr 13, 2023


Hi @mrwyattii, I was previously able to run the script without problems on an NVIDIA V100 GPU, but this error occurs after porting it to the MI100 server. I have not run it with any other version of ROCm/DeepSpeed on AMD GPUs.

seungrokjung

I am also stuck on this issue:
`error: use of undeclared identifier '__double2half'; did you mean '__double2hiint'?`

loadams (Contributor) commented Apr 18, 2023

This should now be fixed by this PR: #3236

Can you try again and let us know if that fixes this issue?

loadams closed this as completed on Apr 18, 2023