Skip to content

Commit

Permalink
Add backend argument.
Browse files Browse the repository at this point in the history
  • Loading branch information
lshqqytiger committed Jun 2, 2023
1 parent 5e48f30 commit 0e4cbeb
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 7 deletions.
1 change: 1 addition & 0 deletions modules/cmd_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
# Launch-mode flags: run the API only, or bring up the UI without loading a model.
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui")
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
# Device/backend selection. --device-id narrows CUDA devices; --backend picks the
# torch compute backend ('auto' detects one at launch — presumably via driver
# probing elsewhere in the launcher; confirm against prepare_environment).
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
parser.add_argument("--backend", type=str, help="Select the backend to be used. Default: 'auto'", choices=["cuda", "rocm", "directml", "auto"], default="auto")
# Privilege and CORS configuration for the web server.
parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(s) in the form of a comma-separated list (no spaces)", default=None)
parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
Expand Down
8 changes: 5 additions & 3 deletions modules/devices.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import sys
import contextlib
import torch
import torch_directml
from modules import errors
from modules.sd_hijack_utils import CondFunc
from packaging import version
Expand Down Expand Up @@ -50,8 +49,11 @@ def get_optimal_device_name():
if has_mps():
return "mps"

if torch_directml.is_available():
return get_dml_device_string()
from modules import shared
if shared.cmd_opts.backend == 'directml':
import torch_directml
if torch_directml.is_available():
return get_dml_device_string()

return "cpu"

Expand Down
19 changes: 17 additions & 2 deletions modules/launch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,24 @@ def run_extensions_installers(settings_file):

def prepare_environment():
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
if shutil.which('nvidia-smi') is None and shutil.which('rocminfo') is None:
if args.backend == 'auto':
nvidia_driver_found = shutil.which('nvidia-smi') is not None
rocm_found = shutil.which('rocminfo') is not None
if nvidia_driver_found:
args.backend = 'cuda'
print("NVIDIA driver was found. Automatically changed backend to 'cuda'. You can manually select which backend will be used through '--backend' argument.")
elif rocm_found:
args.backend = 'rocm'
print("ROCm was found. Automatically changed backend to 'rocm'. You can manually select which backend will be used through '--backend' argument.")
else:
args.backend = 'directml'
if args.backend == 'cuda':
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
if args.backend == 'rocm':
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2")
if args.backend == 'directml':
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==2.0.0 torchvision==0.15.1 torch-directml")

requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")

xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17')
Expand Down
1 change: 0 additions & 1 deletion modules/sd_hijack_optimizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import psutil

import torch
import torch_directml
from torch import einsum

from ldm.util import default
Expand Down
2 changes: 1 addition & 1 deletion modules/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
loaded_hypernetworks = []


if device.type == 'privateuseone':
if cmd_opts.backend == 'directml':
import modules.dml


Expand Down

0 comments on commit 0e4cbeb

Please sign in to comment.