
Update builder.py #5249

Closed
wants to merge 1 commit into from

Conversation

ehartford

Fix deepspeed2 with ROCm

@ehartford
Author

@microsoft-github-policy-service agree

@tjruwase
Contributor

@ehartford, thanks for the PR. Can you please share a bit more about the issue that this fixes?

@ehartford
Author

ehartford commented Mar 11, 2024

Ubuntu Server 20.04
AMD mi-210 (gfx90a)
ROCm 6.0
torch-2.3.0.dev20240309+rocm6.0
DeepSpeed tag v0.14.0

$ python -c "import torch; print(torch.version.hip)"
6.0.32830-d62f6a171

DeepSpeed ZeRO-1 was working, but ZeRO-2 was not:

[rank14]:   File "/home/ehartford/miniconda3/envs/axolotl/lib/python3.12/site-packages/accelerate/accelerator.py", line 1598, in _prepare_deepspeed
[rank14]:     optimizer = DeepSpeedCPUAdam(optimizer.param_groups, **defaults)
[rank14]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]:   File "/home/ehartford/miniconda3/envs/axolotl/lib/python3.12/site-packages/deepspeed-0.14.1+535a908f-py3.12.egg/deepspeed/ops/adam/cpu_adam.py", line 94, in __init__
[rank14]:     self.ds_opt_adam = CPUAdamBuilder().load()
[rank14]:                        ^^^^^^^^^^^^^^^^^^^^^^^
[rank14]:   File "/home/ehartford/miniconda3/envs/axolotl/lib/python3.12/site-packages/deepspeed-0.14.1+535a908f-py3.12.egg/deepspeed/ops/op_builder/builder.py", line 479, in load
[rank14]:     return self.jit_load(verbose)
[rank14]:            ^^^^^^^^^^^^^^^^^^^^^^
[rank14]:   File "/home/ehartford/miniconda3/envs/axolotl/lib/python3.12/site-packages/deepspeed-0.14.1+535a908f-py3.12.egg/deepspeed/ops/op_builder/builder.py", line 511, in jit_load
[rank14]:     cxx_args = self.strip_empty_entries(self.cxx_args())
[rank14]:                                         ^^^^^^^^^^^^^^^
[rank14]:   File "/home/ehartford/miniconda3/envs/axolotl/lib/python3.12/site-packages/deepspeed-0.14.1+535a908f-py3.12.egg/deepspeed/ops/op_builder/builder.py", line 766, in cxx_args
[rank14]:     CUDA_ENABLE = self.is_cuda_enable()
[rank14]:                   ^^^^^^^^^^^^^^^^^^^^^
[rank14]:   File "/home/ehartford/miniconda3/envs/axolotl/lib/python3.12/site-packages/deepspeed-0.14.1+535a908f-py3.12.egg/deepspeed/ops/op_builder/builder.py", line 370, in is_cuda_enable
[rank14]:     assert_no_cuda_mismatch(self.name)
[rank14]:   File "/home/ehartford/miniconda3/envs/axolotl/lib/python3.12/site-packages/deepspeed-0.14.1+535a908f-py3.12.egg/deepspeed/ops/op_builder/builder.py", line 85, in assert_no_cuda_mismatch
[rank14]:     torch_cuda_version = ".".join(torch.version.cuda.split('.')[:2])

So I deleted DeepSpeed and installed it manually from source.

I set environment variables like this:

export GPU_ARCHS="gfx90a"
export ROCM_TARGET="gfx90a"
export HIP_PATH="/opt/rocm-6.0.0"
export ROCM_PATH="/opt/rocm-6.0.0"
export ROCM_HOME="/opt/rocm-6.0.0"
export HIP_PLATFORM=amd
export DS_BUILD_CPU_ADAM=1 
export TORCH_HIP_ARCH_LIST="gfx90a"

Then, when I run DS_BUILD_CPU_ADAM=1 TORCH_HIP_ARCH_LIST="gfx90a" python setup.py install, I get:

  File "/scratch/axolotl/DeepSpeed/op_builder/builder.py", line 85, in assert_no_cuda_mismatch
    torch_cuda_version = ".".join(torch.version.cuda.split('.')[:2])
                                  ^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'split'
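The root cause of this AttributeError is that on a ROCm build of PyTorch, torch.version.cuda is None (only torch.version.hip is set), so the unguarded version check crashes. A minimal standalone reproduction of that failure mode:

```python
# Minimal reproduction: on a ROCm wheel, torch.version.cuda is None,
# so the ".".join(...split(...)) in assert_no_cuda_mismatch fails.
cuda_version = None  # what torch.version.cuda returns on a ROCm build

try:
    torch_cuda_version = ".".join(cuda_version.split(".")[:2])
except AttributeError as err:
    print(err)  # 'NoneType' object has no attribute 'split'
```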

So I asked Claude Opus to fix it, and it suggested the change in this PR.

After making that change, I was able to compile and install, and DeepSpeed ZeRO-2 worked.

@IlyasMoutawwakil

IlyasMoutawwakil commented Mar 11, 2024

This doesn't solve the issue, in my opinion; it just makes it more confusing by calling a CUDA-specific function (assert_no_cuda_mismatch) on a ROCm system.
The error was raised because, unlike in
https://github.com/microsoft/DeepSpeed/blob/535a908f1b60f819df4ccf1071f7c917c39dabbe/op_builder/builder.py#L632C1-L633C51
self.is_rocm_pytorch is not checked before calling assert_no_cuda_mismatch in https://github.com/microsoft/DeepSpeed/blob/535a908f1b60f819df4ccf1071f7c917c39dabbe/op_builder/builder.py#L369C1-L371C39
That is why you got the error, so just adding that check will solve your issue :)
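The guard being described can be sketched as follows. This is a hedged, simplified illustration with stand-in names (OpBuilderSketch, a stubbed assert_no_cuda_mismatch), not DeepSpeed's actual code: the point is that checking is_rocm_pytorch first means the CUDA-only check is never reached on a ROCm build.

```python
# Simplified sketch of the missing guard, not DeepSpeed's real builder.

def assert_no_cuda_mismatch(name):
    # Stub for the real CUDA-only check; on a ROCm wheel
    # torch.version.cuda is None, so the real function fails like this.
    raise AttributeError("'NoneType' object has no attribute 'split'")

class OpBuilderSketch:
    def __init__(self, is_rocm_pytorch):
        self.name = "cpu_adam"
        self.is_rocm_pytorch = is_rocm_pytorch

    def is_cuda_enable(self):
        # The suggested guard: skip the CUDA version check on ROCm.
        if not self.is_rocm_pytorch:
            assert_no_cuda_mismatch(self.name)
        return "-D__ENABLE_CUDA__"

print(OpBuilderSketch(is_rocm_pytorch=True).is_cuda_enable())  # no crash
```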

To truly solve this issue, what's needed is an assert_no_rocm_hip_mismatch utility function that implements logic similar to its CUDA counterpart (checking whether the system's ROCm version is compatible with the installed torch wheel's HIP version).
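Such a utility might look like the following. This is a hypothetical sketch, not an existing DeepSpeed function: it compares major.minor versions only, mirroring the CUDA counterpart's behavior, and the version strings used in the example come from the environment details reported above.

```python
# Hypothetical assert_no_rocm_hip_mismatch sketch: compare the system
# ROCm version with the HIP version the torch wheel was built against,
# on a major.minor basis, like the CUDA counterpart does.

def assert_no_rocm_hip_mismatch(system_rocm_version, torch_hip_version):
    sys_mm = ".".join(system_rocm_version.split(".")[:2])
    hip_mm = ".".join(torch_hip_version.split(".")[:2])
    if sys_mm != hip_mm:
        raise RuntimeError(
            f"Installed ROCm {sys_mm} does not match the HIP version "
            f"PyTorch was built with ({hip_mm})."
        )

# From the report: ROCm 6.0 system, torch.version.hip == '6.0.32830-d62f6a171'
assert_no_rocm_hip_mismatch("6.0.0", "6.0.32830-d62f6a171")  # compatible, no error
```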

@ehartford
Author

ok.
I have no stake.
This code unblocked me.
But a proper fix would be awesome.
Thanks for your consideration.

@ehartford ehartford closed this Mar 11, 2024
@tjruwase
Contributor

@ehartford, thanks for sharing these details. I am glad that you are unblocked. Could you please create a ticket for this issue with the details above? That would be very helpful for our investigation. Thanks!
