
[BUG] Duplicate/Wrong(?) Compute capability flags added #7972

@Flamefire

Description


for cc in ccs:
    num = cc[0] + cc[1].split('+')[0]
    args.append(f'-gencode=arch=compute_{num},code=sm_{num}')
    if cc[1].endswith('+PTX'):
        args.append(f'-gencode=arch=compute_{num},code=compute_{num}')

adds -gencode= flags for the current devices.
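For illustration, the loop above can be wrapped into a standalone function (the name gencode_args and the exact shape of ccs are assumptions here; ccs is taken to be a list of (major, minor[+PTX]) tuples consistent with the indexing in the snippet):

```python
def gencode_args(ccs):
    # ccs: list of (major, minor[+PTX]) tuples, e.g. [('8', '0+PTX')]
    args = []
    for cc in ccs:
        # '8' + '0+PTX'.split('+')[0] -> '80'
        num = cc[0] + cc[1].split('+')[0]
        # real-code flag: compile SASS for this exact architecture
        args.append(f'-gencode=arch=compute_{num},code=sm_{num}')
        if cc[1].endswith('+PTX'):
            # additionally embed PTX for forward compatibility
            args.append(f'-gencode=arch=compute_{num},code=compute_{num}')
    return args

print(gencode_args([('8', '0+PTX')]))
# → ['-gencode=arch=compute_80,code=sm_80', '-gencode=arch=compute_80,code=compute_80']
```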

As

if "TORCH_CUDA_ARCH_LIST" in os.environ:
    torch_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST")
    os.environ["TORCH_CUDA_ARCH_LIST"] = ""

clears $TORCH_CUDA_ARCH_LIST, PyTorch will compile for the current device(s) and add the appropriate flags again.

This results in an overly long command line:
nvcc [...] -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_80,code=sm_80 [...] -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_80,code=compute_80 [...]/deepspeed/ops/csrc/fp_quantizer/fp_quantize_impl.cu -o fp_quantize_impl.cuda.o

Depending on how nvcc handles the repeated flags, this may lead to redundant compilation.

A better approach would be to not pass CUDA CC flags at all and instead set TORCH_CUDA_ARCH_LIST appropriately, letting PyTorch handle this. That would also avoid:

UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
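A minimal sketch of the proposed approach, assuming the same (major, minor[+PTX]) tuples as above: translate the detected compute capabilities into PyTorch's semicolon-separated TORCH_CUDA_ARCH_LIST format and let torch.utils.cpp_extension derive the -gencode flags itself. The helper name set_arch_list is hypothetical, not DeepSpeed code:

```python
import os

def set_arch_list(ccs):
    # Hypothetical helper: instead of emitting -gencode flags directly,
    # build a "8.0;8.6+PTX"-style value and export it so PyTorch's
    # extension builder generates the flags exactly once.
    entries = []
    for major, minor in ccs:
        suffix = '+PTX' if minor.endswith('+PTX') else ''
        entries.append(f"{major}.{minor.split('+')[0]}{suffix}")
    os.environ['TORCH_CUDA_ARCH_LIST'] = ';'.join(entries)

set_arch_list([('8', '0'), ('8', '6+PTX')])
print(os.environ['TORCH_CUDA_ARCH_LIST'])
# → 8.0;8.6+PTX
```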
