Commit: add ascend backend
YukMingLaw committed May 11, 2024
1 parent 4f63ee9 commit ba59f7b
Showing 3 changed files with 57 additions and 0 deletions.
13 changes: 13 additions & 0 deletions README.md
@@ -118,6 +118,19 @@ This is the command to install pytorch nightly instead which might have performa

```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121```

### ASCEND

Ascend users should install ```cann>=7.0.0```, ```torch+cpu>=2.1.0```, and ```torch_npu```. The installation reference documents are linked below.

https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha001/softwareinst/instg/instg_0001.html

https://gitee.com/ascend/pytorch

This is the command to launch ComfyUI using the Ascend backend.

```python main.py --use-npu```
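
Before launching, you can confirm that PyTorch actually sees the NPU. A minimal pre-flight check (an illustrative sketch, not part of this commit, assuming ```torch_npu``` was installed as described above):

```python
# Illustrative pre-flight check: confirm the Ascend stack is visible to PyTorch.
import torch

try:
    import torch_npu  # Ascend adapter; registers the torch.npu namespace
except ImportError as exc:
    raise SystemExit("torch_npu is not installed; see the install guides above.") from exc

if torch.npu.is_available():
    dev = torch.npu.current_device()
    print("Ascend NPU detected:", torch.npu.get_device_name(dev))
else:
    print("torch_npu is installed, but no NPU device is available.")
```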


#### Troubleshooting

If you get the "Torch not compiled with CUDA enabled" error, uninstall torch with:
4 changes: 4 additions & 0 deletions comfy/cli_args.py
@@ -81,6 +81,10 @@ def __call__(self, parser, namespace, values, option_string=None):

parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")

parser.add_argument("--disable-torchair-optimize", action="store_true", help="Disables torchair graph modee optimize when loading models with Ascend NPUs.")

parser.add_argument("--use-npu", action="store_true", help="use Huawei Ascend NPUs backend.")

class LatentPreviewMethod(enum.Enum):
NoPreviews = "none"
Auto = "auto"
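
For context, a small sketch of how the two new flags are consumed downstream (the names mirror the checks added in comfy/model_management.py below; run from a ComfyUI checkout, illustration only, not part of the commit):

```python
# Illustration: branching on the new CLI flags, mirroring comfy/model_management.py.
from comfy.cli_args import args

if args.use_npu:
    # Request the Huawei Ascend NPU backend instead of CUDA/XPU/MPS.
    print("Ascend NPU backend requested")
    if not args.disable_torchair_optimize:
        print("torchair graph-mode optimization will be applied at model load")
```
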
40 changes: 40 additions & 0 deletions comfy/model_management.py
@@ -4,6 +4,8 @@
from comfy.cli_args import args
import comfy.utils
import torch
if args.use_npu:
import torch_npu
import sys

class VRAMState(Enum):
@@ -71,6 +73,12 @@ def is_intel_xpu():
return True
return False

def is_ascend_npu():
if args.use_npu and torch.npu.is_available():
return True
return False


def get_torch_device():
global directml_enabled
global cpu_state
@@ -84,6 +92,8 @@ def get_torch_device():
else:
if is_intel_xpu():
return torch.device("xpu", torch.xpu.current_device())
elif is_ascend_npu():
return torch.device("npu", torch.npu.current_device())
else:
return torch.device(torch.cuda.current_device())

@@ -104,6 +114,11 @@ def get_total_memory(dev=None, torch_total_too=False):
mem_reserved = stats['reserved_bytes.all.current']
mem_total = torch.xpu.get_device_properties(dev).total_memory
mem_total_torch = mem_reserved
elif is_ascend_npu():
stats = torch.npu.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
mem_total = torch.npu.get_device_properties(dev).total_memory
mem_total_torch = mem_reserved
else:
stats = torch.cuda.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
@@ -179,6 +194,10 @@ def is_nvidia():
if is_intel_xpu():
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
if is_ascend_npu():
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True

except:
pass

@@ -249,6 +268,8 @@ def get_torch_device_name(device):
return "{}".format(device.type)
elif is_intel_xpu():
return "{} {}".format(device, torch.xpu.get_device_name(device))
elif is_ascend_npu():
return "{} {}".format(device, torch.npu.get_device_name(device))
else:
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))

@@ -306,6 +327,11 @@ def model_load(self, lowvram_model_memory=0):
if is_intel_xpu() and not args.disable_ipex_optimize:
self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True)

if is_ascend_npu() and not args.disable_torchair_optimize: # compile with the torchair graph-mode backend
import torchair as tng
npu_backend = tng.get_npu_backend()
self.real_model = torch.compile(self.real_model.eval(), backend=npu_backend, dynamic=False)

self.weights_loaded = True
return self.real_model

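A standalone sketch of the torchair graph-mode compile path used in model_load above (assumes torchair and torch_npu are installed; the toy model and input are only for illustration):

```python
# Sketch of the torchair + torch.compile path used above (Ascend-only dependencies).
import torch
import torch_npu          # registers the "npu" device type
import torchair as tng    # provides the NPU backend for torch.compile

model = torch.nn.Linear(16, 16).npu().eval()           # toy stand-in for self.real_model
npu_backend = tng.get_npu_backend()                     # torchair graph-mode backend
compiled = torch.compile(model, backend=npu_backend, dynamic=False)

with torch.no_grad():
    out = compiled(torch.randn(1, 16).npu())            # first call triggers graph capture
print(out.shape)
```
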
@@ -649,6 +675,8 @@ def xformers_enabled():
return False
if directml_enabled:
return False
if is_ascend_npu():
return False
return XFORMERS_IS_AVAILABLE


@@ -690,6 +718,13 @@ def get_free_memory(dev=None, torch_free_too=False):
mem_reserved = stats['reserved_bytes.all.current']
mem_free_torch = mem_reserved - mem_active
mem_free_total = torch.xpu.get_device_properties(dev).total_memory - mem_allocated
elif is_ascend_npu():
stats = torch.npu.memory_stats(dev)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_npu, _ = torch.npu.mem_get_info(dev)
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_npu + mem_free_torch
else:
stats = torch.cuda.memory_stats(dev)
mem_active = stats['active_bytes.all.current']
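
The NPU branch estimates free memory as what the driver reports as free plus the allocator's idle cache (reserved minus active). A standalone sketch of that arithmetic, assuming torch_npu is installed:

```python
# Sketch of the NPU free-memory estimate used in get_free_memory above.
import torch
import torch_npu  # registers the torch.npu namespace

dev = torch.device("npu", torch.npu.current_device())
stats = torch.npu.memory_stats(dev)
mem_active = stats['active_bytes.all.current']      # bytes held by live tensors
mem_reserved = stats['reserved_bytes.all.current']  # bytes held by the caching allocator
mem_free_npu, _ = torch.npu.mem_get_info(dev)       # bytes the driver reports as free
mem_free_torch = mem_reserved - mem_active          # reusable cache inside the allocator
print("estimated free bytes:", mem_free_npu + mem_free_torch)
```
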
@@ -755,6 +790,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if is_intel_xpu():
return True

if is_ascend_npu():
return True

if torch.version.hip:
return True

@@ -833,6 +871,8 @@ def soft_empty_cache(force=False):
torch.mps.empty_cache()
elif is_intel_xpu():
torch.xpu.empty_cache()
elif is_ascend_npu():
torch.npu.empty_cache()
elif torch.cuda.is_available():
if force or is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
torch.cuda.empty_cache()
