diff --git a/.github/workflows/package_wheels.yml b/.github/workflows/package_wheels.yml index 9f68138..3909865 100644 --- a/.github/workflows/package_wheels.yml +++ b/.github/workflows/package_wheels.yml @@ -76,4 +76,4 @@ jobs: uses: actions/cache/save@v3 with: path: ${{ env.archive_name }}.zip - key: ${{ env.archive_name }} + key: ${{ env.archive_name }}-${{ hashFiles('reqs.txt') }} diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 438f2eb..67d97ef 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -6,7 +6,7 @@ on: name: description: Release tag / name ? required: true - default: "latest" + default: 'latest' type: string environment: description: Environment to run tests against @@ -27,9 +27,9 @@ jobs: - name: โ™ป๏ธ Checking out the repository uses: actions/checkout@v3 with: - submodules: "recursive" + submodules: 'recursive' path: ${{ env.repo_name }} - + # - name: ๐Ÿ“ Prepare file with paths to remove # run: | # find ${{ env.repo_name }} -type f -size +10M > .release_ignore @@ -56,7 +56,7 @@ jobs: else echo "No .release_ignore file found. Skipping removal of files and directories." fi - + - name: ๐Ÿ“ฆ Building custom comfy nodes shell: bash run: | @@ -98,7 +98,7 @@ jobs: id: cache with: path: ${{ env.archive_name }}.zip - key: ${{ env.archive_name }} + key: ${{ env.archive_name }}-${{ hashFiles('reqs.txt') }} - name: ๐Ÿ“ฆ Unzip wheels shell: bash run: | diff --git a/.github/workflows/test_embedded.yml b/.github/workflows/test_embedded.yml index 26982d5..31c4c61 100644 --- a/.github/workflows/test_embedded.yml +++ b/.github/workflows/test_embedded.yml @@ -3,7 +3,6 @@ name: ๐Ÿงช Test Comfy Portable on: workflow_dispatch: inputs: - invalidate_cache: description: 'Whether to invalidate the cache or not' required: false @@ -14,28 +13,33 @@ jobs: env: repo_name: ${{ github.event.repository.name }} steps: - - name: โš ๏ธ Invalidate Cache if Needed - if: ${{ github.event.inputs.invalidate_cache == 'true' }} - run: | - echo "Invalidating Cache..." - rmdir /s /q C:\Users\runner\AppData\Local\Temp\caches + - name: โšก๏ธ Restore Cache if Available + id: cache-comfy + uses: actions/cache/restore@v3 + with: + path: ComfyUI_windows_portable + key: ${{ runner.os }}-comfy-env - name: ๐Ÿšก Download and Extract Comfy + id: download-extract-comfy + if: steps.cache-comfy.outputs.cache-hit != 'true' + shell: bash run: | mkdir comfy_temp curl -L -o comfy_temp/comfyui.7z https://github.com/comfyanonymous/ComfyUI/releases/download/latest/ComfyUI_windows_portable_nvidia_cu118_or_cpu.7z + 7z x comfy_temp/comfyui.7z -o./comfy_temp # mv comfy_temp/ComfyUI_windows_portable/python_embeded . # mv comfy_temp/ComfyUI_windows_portable/ComfyUI . # mv comfy_temp/ComfyUI_windows_portable/update . - + ls mv comfy_temp/ComfyUI_windows_portable . - - name: ๐Ÿ“ฆ Cache Comfy Environment - id: cache-comfy - uses: actions/cache@v2 + - name: ๐Ÿ’พ Store cache + uses: actions/cache/save@v3 + if: steps.cache-comfy.outputs.cache-hit != 'true' with: path: ComfyUI_windows_portable key: ${{ runner.os }}-comfy-env @@ -50,10 +54,15 @@ jobs: shell: bash run: | # run install - export COMFY_PYTHON="ComfyUI_windows_portable/python_embeded/python.exe" - cd ComfyUI_windows_portable/ComfyUI/ + export COMFY_PYTHON="${GITHUB_WORKSPACE}/ComfyUI_windows_portable/python_embeded/python.exe" + cd "${GITHUB_WORKSPACE}/ComfyUI_windows_portable/ComfyUI/" $COMFY_PYTHON custom_nodes/${{ env.repo_name }}/install.py -w - # check node loading state - cd custom_nodes - $COMFY_PYTHON -c "import ${{ env.repo_name }}" + - name: โฌ Import mtb_nodes + shell: bash + run: | + export COMFY_PYTHON="${GITHUB_WORKSPACE}/ComfyUI_windows_portable/python_embeded/python.exe" + cd "${GITHUB_WORKSPACE}/ComfyUI_windows_portable/ComfyUI" + $COMFY_PYTHON -s main.py --quick-test-for-ci --cpu + + $COMFY_PYTHON -m pip freeze diff --git a/__init__.py b/__init__.py index aec26e0..0439ac0 100644 --- a/__init__.py +++ b/__init__.py @@ -26,7 +26,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {} NODE_CLASS_MAPPINGS_DEBUG = {} -__version__ = "0.1.2" +__version__ = "0.1.4" def extract_nodes_from_source(filename): diff --git a/install.py b/install.py index a2af587..604ca65 100644 --- a/install.py +++ b/install.py @@ -1,7 +1,6 @@ import requests import os import ast -import re import argparse import sys import subprocess @@ -9,10 +8,12 @@ import platform from pathlib import Path import sys -import zipfile -import shutil import stat - +import threading +import signal +from contextlib import suppress +from queue import Queue, Empty +from contextlib import contextmanager here = Path(__file__).parent executable = sys.executable @@ -27,7 +28,7 @@ mode = "venv" -if mode == None: +if mode is None: mode = "unknown" # region ansi @@ -102,12 +103,104 @@ def print_formatted(text, *formats, color=None, background=None, **kwargs): formatted_text = apply_format(text, *formats) formatted_text = apply_color(formatted_text, color, background) file = kwargs.get("file", sys.stdout) - print( - apply_color(apply_format("[mtb install] ", "bold"), color="yellow"), - formatted_text, - file=file, + header = "[mtb install] " + + # Handle console encoding for Unicode characters (utf-8) + encoded_header = header.encode("utf-8", errors="replace").decode("utf-8") + encoded_text = formatted_text.encode("utf-8", errors="replace").decode("utf-8") + + if sys.platform == "win32": + output_text = ( + " " * len(encoded_header) + if kwargs.get("no_header") + else apply_color(apply_format(encoded_header, "bold"), color="yellow") + ) + output_text += encoded_text + "\n" + sys.stdout.buffer.write(output_text.encode("utf-8")) + else: + print( + " " * len(encoded_header) + if kwargs.get("no_header") + else apply_color(apply_format(encoded_header, "bold"), color="yellow"), + encoded_text, + file=file, + ) + + +# endregion + + +# region utils +def enqueue_output(out, queue): + for line in iter(out.readline, b""): + queue.put(line) + out.close() + + +def run_command(cmd): + if isinstance(cmd, str): + shell_cmd = cmd + shell = True + elif isinstance(cmd, list): + shell_cmd = " ".join(cmd) + shell = False + else: + raise ValueError( + "Invalid 'cmd' argument. It must be a string or a list of arguments." + ) + + process = subprocess.Popen( + shell_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + shell=shell, ) + # Create separate threads to read standard output and standard error streams + stdout_queue = Queue() + stderr_queue = Queue() + stdout_thread = threading.Thread( + target=enqueue_output, args=(process.stdout, stdout_queue) + ) + stderr_thread = threading.Thread( + target=enqueue_output, args=(process.stderr, stderr_queue) + ) + stdout_thread.daemon = True + stderr_thread.daemon = True + stdout_thread.start() + stderr_thread.start() + + interrupted = False + + def signal_handler(signum, frame): + nonlocal interrupted + interrupted = True + print("Command execution interrupted.") + + # Register the signal handler for keyboard interrupts (SIGINT) + signal.signal(signal.SIGINT, signal_handler) + + # Process output from both streams until the process completes or interrupted + while not interrupted and ( + process.poll() is None or not stdout_queue.empty() or not stderr_queue.empty() + ): + with suppress(Empty): + stdout_line = stdout_queue.get_nowait() + if stdout_line.strip() != "": + print(stdout_line.strip()) + with suppress(Empty): + stderr_line = stderr_queue.get_nowait() + if stderr_line.strip() != "": + print(stderr_line.strip()) + return_code = process.returncode + + if return_code == 0 and not interrupted: + print("Command executed successfully!") + else: + if not interrupted: + print(f"Command failed with return code: {return_code}") + # endregion @@ -115,9 +208,7 @@ def print_formatted(text, *formats, color=None, background=None, **kwargs): import requirements except ImportError: print_formatted("Installing requirements-parser...", "italic", color="yellow") - subprocess.check_call( - [sys.executable, "-m", "pip", "install", "requirements-parser"] - ) + run_command([sys.executable, "-m", "pip", "install", "requirements-parser"]) import requirements print_formatted("Done.", "italic", color="green") @@ -126,7 +217,7 @@ def print_formatted(text, *formats, color=None, background=None, **kwargs): from tqdm import tqdm except ImportError: print_formatted("Installing tqdm...", "italic", color="yellow") - subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", "tqdm"]) + run_command([sys.executable, "-m", "pip", "install", "--upgrade", "tqdm"]) from tqdm import tqdm import importlib @@ -141,16 +232,41 @@ def print_formatted(text, *formats, color=None, background=None, **kwargs): def is_pipe(): - try: - mode = os.fstat(0).st_mode - return ( - stat.S_ISFIFO(mode) - or stat.S_ISREG(mode) - or stat.S_ISBLK(mode) - or stat.S_ISSOCK(mode) - ) - except OSError: + if not sys.stdin.isatty(): return False + if sys.platform == "win32": + try: + import msvcrt + + return msvcrt.get_osfhandle(0) != -1 + except ImportError: + return False + else: + try: + mode = os.fstat(0).st_mode + return ( + stat.S_ISFIFO(mode) + or stat.S_ISREG(mode) + or stat.S_ISBLK(mode) + or stat.S_ISSOCK(mode) + ) + except OSError: + return False + + +@contextmanager +def suppress_std(): + with open(os.devnull, "w") as devnull: + old_stdout = sys.stdout + old_stderr = sys.stderr + sys.stdout = devnull + sys.stderr = devnull + + try: + yield + finally: + sys.stdout = old_stdout + sys.stderr = old_stderr # Get the version from __init__.py @@ -211,46 +327,51 @@ def try_import(requirement): installed = False pip_name = dependency - if specs := requirement.specs: - pip_name += "".join(specs[0]) - + pip_spec = "".join(specs[0]) if (specs := requirement.specs) else "" try: - import_module(import_name) + with suppress_std(): + import_module(import_name) print_formatted( - f"Package {pip_name} already installed (import name: '{import_name}').", + f"\tโœ… Package {pip_name} already installed (import name: '{import_name}').", "bold", color="green", + no_header=True, ) installed = True except ImportError: - pass + print_formatted( + f"\tโ›” Package {pip_name} is missing (import name: '{import_name}').", + "bold", + color="red", + no_header=True, + ) - return (installed, pip_name, import_name) + return (installed, pip_name, pip_spec, import_name) def import_or_install(requirement, dry=False): - installed, pip_name, import_name = try_import(requirement) + installed, pip_name, pip_spec, import_name = try_import(requirement) + + pip_install_name = pip_name + pip_spec if not installed: print_formatted(f"Installing package {pip_name}...", "italic", color="yellow") if dry: print_formatted( - f"Dry-run: Package {pip_name} would be installed (import name: '{import_name}').", + f"Dry-run: Package {pip_install_name} would be installed (import name: '{import_name}').", color="yellow", ) else: try: - subprocess.check_call( - [sys.executable, "-m", "pip", "install", pip_name] - ) + run_command([sys.executable, "-m", "pip", "install", pip_install_name]) print_formatted( - f"Package {pip_name} installed successfully using pip package name (import name: '{import_name}')", + f"Package {pip_install_name} installed successfully using pip package name (import name: '{import_name}')", "bold", color="green", ) except subprocess.CalledProcessError as e: print_formatted( - f"Failed to install package {pip_name} using pip package name (import name: '{import_name}'). Error: {str(e)}", + f"Failed to install package {pip_install_name} using pip package name (import name: '{import_name}'). Error: {str(e)}", "bold", color="red", ) @@ -279,7 +400,7 @@ def install_dependencies(dry=False): if not clone_dir.exists(): clone_dir.parent.mkdir(parents=True, exist_ok=True) print_formatted(f"Cloning {url} to {clone_dir}", "italic", color="yellow") - subprocess.check_call(["git", "clone", "--recursive", url, clone_dir]) + run_command(["git", "clone", "--recursive", url, clone_dir.as_posix()]) # os.chdir(clone_dir) here = clone_dir @@ -316,7 +437,7 @@ def install_dependencies(dry=False): args = parser.parse_args() - wheels_directory = here / "wheels" + # wheels_directory = here / "wheels" print_formatted(f"Detected environment: {apply_color(mode,'cyan')}") # Install dependencies from requirements.txt @@ -335,23 +456,26 @@ def install_dependencies(dry=False): if mode in ["colab", "embeded"]: print_formatted( f"Downloading and installing release wheels since we are in a Comfy {apply_color(mode,'cyan')} environment", + "italic", + color="yellow", ) if full: print_formatted( - f"Downloading and installing release wheels since no arguments where provided" + f"Downloading and installing release wheels since no arguments where provided", + "italic", + color="yellow", ) # - Check the env before proceeding. - missing_wheels = False + missing_deps = [] parsed_requirements = get_requirements(here / "reqs.txt") if parsed_requirements: for requirement in parsed_requirements: - installed, pip_name, import_name = try_import(requirement) + installed, pip_name, pip_spec, import_name = try_import(requirement) if not installed: - missing_wheels = True - break + missing_deps.append(pip_name.split("-")[0]) - if not missing_wheels: + if len(missing_deps) == 0: print_formatted( f"All requirements are already installed.", "italic", color="green" ) @@ -392,97 +516,120 @@ def install_dependencies(dry=False): # ) # sys.exit() - # Download the assets for the given version + short_platform = { + "windows": "win_amd64", + "linux": "linux_x86_64", + } matching_assets = [ asset for asset in tag_data["assets"] - if current_platform in asset["name"] and asset["name"].endswith("zip") + if asset["name"].endswith(".whl") + and ( + "any" in asset["name"] or short_platform[current_platform] in asset["name"] + ) ] if not matching_assets: print_formatted( f"Unsupported operating system: {current_platform}", color="yellow" ) - - wheels_directory.mkdir(exist_ok=True) - # - Install the wheels - for asset in matching_assets: - asset_name = asset["name"] - asset_download_url = asset["browser_download_url"] - print_formatted(f"Downloading asset: {asset_name}", color="yellow") - asset_dest = wheels_directory / asset_name - download_file(asset_download_url, asset_dest) - - # - Unzip to wheels dir - whl_files = [] - whl_order = None - with zipfile.ZipFile(asset_dest, "r") as zip_ref: - for item in tqdm(zip_ref.namelist(), desc="Extracting", unit="file"): - if item.endswith(".whl"): - item_basename = os.path.basename(item) - target_path = wheels_directory / item_basename - with zip_ref.open(item) as source, open( - target_path, "wb" - ) as target: - whl_files.append(target_path) - shutil.copyfileobj(source, target) - elif item.endswith("order.txt"): - item_basename = os.path.basename(item) - target_path = wheels_directory / item_basename - with zip_ref.open(item) as source, open( - target_path, "wb" - ) as target: - whl_order = target_path - shutil.copyfileobj(source, target) - + wheel_order_asset = next( + (asset for asset in tag_data["assets"] if asset["name"] == "wheel_order.txt"), + None, + ) + if wheel_order_asset is not None: print_formatted( - f"Wheels extracted for {current_platform} to the '{wheels_directory}' directory.", - "bold", - color="green", + "โš™๏ธ Sorting the release wheels using wheels order", "italic", color="yellow" ) + response = requests.get(wheel_order_asset["browser_download_url"]) + if response.status_code == 200: + wheel_order = [line.strip() for line in response.text.splitlines()] - if whl_files: - if whl_order: - with open(whl_order, "r") as order: - wheel_order_lines = [line.strip() for line in order] - whl_files = sorted( - whl_files, - key=lambda x: wheel_order_lines.index(x.name.split("-")[0]), - ) - - for whl_file in tqdm(whl_files, desc="Installing", unit="package"): - whl_path = wheels_directory / whl_file - - # check if installed + def get_order_index(val): try: - whl_dep = whl_path.name.split("-")[0] - import_name = pip_map.get(whl_dep, whl_dep) - import_module(import_name) - tqdm.write( - f"Package {import_name} already installed, skipping wheel installation.", - ) - continue - except ImportError: - if args.dry: - tqdm.write( - f"Dry-run: Package {whl_path.name} would be installed.", - ) - continue - - tqdm.write("Installing wheel: " + whl_path.name) - - subprocess.check_call( - [ - sys.executable, - "-m", - "pip", - "install", - whl_path.as_posix(), - ] - ) - - print_formatted("Wheels installation completed.", color="green") + return wheel_order.index(val) + except ValueError: + return len(wheel_order) + + matching_assets = sorted( + matching_assets, + key=lambda x: get_order_index(x["name"].split("-")[0]), + ) else: - print_formatted("No .whl files found. Nothing to install.", color="yellow") + print("Failed to fetch wheel_order.txt. Status code:", response.status_code) + + missing_deps_urls = [] + for whl_file in matching_assets: + # check if installed + whl_dep = whl_file["name"].split("-")[0] + missing_deps_urls.append(whl_file["browser_download_url"]) + + # run_command( + # [ + # sys.executable, + # "-m", + # "pip", + # "install", + # whl_path.as_posix(), + # ] + # ) + # # - Install the wheels + # for asset in matching_assets: + # asset_name = asset["name"] + # asset_download_url = asset["browser_download_url"] + # print_formatted(f"Downloading asset: {asset_name}", color="yellow") + # asset_dest = wheels_directory / asset_name + # download_file(asset_download_url, asset_dest) + + # # - Unzip to wheels dir + # whl_files = [] + # whl_order = None + # with zipfile.ZipFile(asset_dest, "r") as zip_ref: + # for item in tqdm(zip_ref.namelist(), desc="Extracting", unit="file"): + # if item.endswith(".whl"): + # item_basename = os.path.basename(item) + # target_path = wheels_directory / item_basename + # with zip_ref.open(item) as source, open( + # target_path, "wb" + # ) as target: + # whl_files.append(target_path) + # shutil.copyfileobj(source, target) + # elif item.endswith("order.txt"): + # item_basename = os.path.basename(item) + # target_path = wheels_directory / item_basename + # with zip_ref.open(item) as source, open( + # target_path, "wb" + # ) as target: + # whl_order = target_path + # shutil.copyfileobj(source, target) + + # print_formatted( + # f"Wheels extracted for {current_platform} to the '{wheels_directory}' directory.", + # "bold", + # color="green", + # ) + + # print_formatted( + # "\tFound those missing wheels from the release:\n\t\t -" + # + "\n\t\t - ".join(missing_deps_urls), + # "italic", + # color="yellow", + # no_header=True, + # ) + + install_cmd = [sys.executable, "-m", "pip", "install"] - # - Install all remainings - install_dependencies(dry=args.dry) + wheel_cmd = install_cmd + missing_deps_urls + + # - Install all deps + if not args.dry: + run_command(wheel_cmd) + run_command(install_cmd + ["-r", (here / "reqs.txt").as_posix()]) + print_formatted( + "Successfully installed all dependencies.", "italic", color="green" + ) + else: + print_formatted( + f"Would have run the following command:\n\t{apply_color(' '.join(install_cmd),'cyan')}", + "italic", + color="yellow", + ) diff --git a/nodes/faceenhance.py b/nodes/faceenhance.py index e922559..4dafd39 100644 --- a/nodes/faceenhance.py +++ b/nodes/faceenhance.py @@ -4,9 +4,12 @@ import os from pathlib import Path import folder_paths +from ..utils import pil2tensor, np2tensor, tensor2np + from basicsr.utils import imwrite + + from PIL import Image -from ..utils import pil2tensor, tensor2pil, np2tensor, tensor2np import torch from ..log import NullWriter, log from comfy import model_management @@ -28,7 +31,7 @@ def get_models_root(cls): @classmethod def get_models(cls): models_path = cls.get_models_root() - + if not models_path.exists(): log.warning(f"No models found at {models_path}") return [] diff --git a/nodes/faceswap.py b/nodes/faceswap.py index bf8c2db..4f0eba9 100644 --- a/nodes/faceswap.py +++ b/nodes/faceswap.py @@ -22,15 +22,19 @@ log = mklog(__name__) + class LoadFaceAnalysisModel: """Loads a face analysis model""" models = [] + @staticmethod def get_models() -> List[str]: models_path = os.path.join(folder_paths.models_dir, "insightface/*") models = glob.glob(models_path) - models = [Path(x).name for x in models if x.endswith(".onnx") or x.endswith(".pth")] + models = [ + Path(x).name for x in models if x.endswith(".onnx") or x.endswith(".pth") + ] return models @classmethod @@ -50,10 +54,12 @@ def INPUT_TYPES(cls): def load_model(self, faceswap_model: str): face_analyser = insightface.app.FaceAnalysis( - name=faceswap_model, root=os.path.join(folder_paths.models_dir, "insightface") + name=faceswap_model, + root=os.path.join(folder_paths.models_dir, "insightface"), ) return (face_analyser,) + class LoadFaceSwapModel: """Loads a faceswap model""" @@ -140,7 +146,7 @@ def do_swap(img): int(x) for x in faces_index.strip(",").split(",") if x.isnumeric() } sys.stdout = NullWriter() - swapped = swap_face(faceanalysis_model,ref, img, faceswap_model, face_ids) + swapped = swap_face(faceanalysis_model, ref, img, faceswap_model, face_ids) sys.stdout = sys.__stdout__ return pil2tensor(swapped) @@ -164,15 +170,18 @@ def do_swap(img): # region face swap utils -def get_face_single(face_analyser,img_data: np.ndarray, face_index=0, det_size=(640, 640)): - +def get_face_single( + face_analyser, img_data: np.ndarray, face_index=0, det_size=(640, 640) +): face_analyser.prepare(ctx_id=0, det_size=det_size) face = face_analyser.get(img_data) if len(face) == 0 and det_size[0] > 320 and det_size[1] > 320: log.debug("No face ed, trying again with smaller image") det_size_half = (det_size[0] // 2, det_size[1] // 2) - return get_face_single(face_analyser,img_data, face_index=face_index, det_size=det_size_half) + return get_face_single( + face_analyser, img_data, face_index=face_index, det_size=det_size_half + ) try: return sorted(face, key=lambda x: x.bbox[0])[face_index] @@ -195,12 +204,14 @@ def swap_face( if face_swapper_model is not None: cv_source_img = cv2.cvtColor(np.array(source_img), cv2.COLOR_RGB2BGR) cv_target_img = cv2.cvtColor(np.array(target_img), cv2.COLOR_RGB2BGR) - source_face = get_face_single(face_analyser,cv_source_img, face_index=0) + source_face = get_face_single(face_analyser, cv_source_img, face_index=0) if source_face is not None: result = cv_target_img for face_num in faces_index: - target_face = get_face_single(face_analyser,cv_target_img, face_index=face_num) + target_face = get_face_single( + face_analyser, cv_target_img, face_index=face_num + ) if target_face is not None: sys.stdout = NullWriter() result = face_swapper_model.get(result, target_face, source_face) diff --git a/reqs.txt b/reqs.txt index 2d60329..b0c986f 100644 --- a/reqs.txt +++ b/reqs.txt @@ -1,17 +1,11 @@ onnxruntime-gpu==1.15.1 -imageio===2.28.1 qrcode[pil] -numpy==1.23.5 -rembg==2.0.37 +rembg==2.0.50 # on windows non WSL 2.10 is the last version with GPU support -tensorflow<2.11.0; platform_system == "Windows" +tensorflow==2.10.1; platform_system == "Windows" tb-nightly==2.12.0a20230126; platform_system == "Windows" tensorflow; platform_system != "Windows" # the old tf version on windows comes with a breaking protobuf version -protobuf==3.20.2; platform_system == "Windows" -gdown @ git+https://github.com/melMass/gdown@main -mmdet==3.0.0 facexlib==0.3.0 insightface==0.7.3 -mmcv==2.0.0 basicsr==1.4.2 diff --git a/scripts/download_models.py b/scripts/download_models.py index 505ff9c..171bdbd 100644 --- a/scripts/download_models.py +++ b/scripts/download_models.py @@ -2,6 +2,8 @@ import requests from rich.console import Console from tqdm import tqdm +import subprocess +import sys try: import folder_paths @@ -30,13 +32,13 @@ "size": 332, "download_url": [ "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth", + "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth" # TODO: provide a way to selectively download models from "packs" # https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth # https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth - # https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth # https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth ], - "destination": "upscale_models", + "destination": "face_restore", }, "FILM: Frame Interpolation for Large Motion": { "size": 402, @@ -51,7 +53,6 @@ from urllib.parse import urlparse from pathlib import Path -import gdown def download_model(download_url, destination): @@ -63,6 +64,21 @@ def download_model(download_url, destination): filename = os.path.basename(urlparse(download_url).path) response = None if "drive.google.com" in download_url: + try: + import gdown + except ImportError: + print("Installing gdown") + subprocess.check_call( + [ + sys.executable, + "-m", + "pip", + "install", + "git+https://github.com/melMass/gdown@main", + ] + ) + import gdown + if "/folders/" in download_url: # download folder try: diff --git a/utils.py b/utils.py index 9bc7788..1152504 100644 --- a/utils.py +++ b/utils.py @@ -3,11 +3,10 @@ import torch from pathlib import Path import sys - -from typing import Union, List -from .log import log +from typing import List +# region MISC Utilities def add_path(path, prepend=False): if isinstance(path, list): for p in path: @@ -24,6 +23,33 @@ def add_path(path, prepend=False): sys.path.append(path) +# todo use the requirements library +reqs_map = { + "onnxruntime": "onnxruntime-gpu==1.15.1", + "basicsr": "basicsr==1.4.2", + "rembg": "rembg==2.0.50", + "qrcode": "qrcode[pil]", +} + + +def import_install(package_name): + from pip._internal import main as pip_main + + try: + __import__(package_name) + except ImportError: + package_spec = reqs_map.get(package_name) + if package_spec is None: + print(f"Installing {package_name}") + package_spec = package_name + + pip_main(["install", package_spec]) + __import__(package_name) + + +# endregion + +# region GLOBAL VARIABLES # Get the absolute path of the parent directory of the current script here = Path(__file__).parent.resolve() @@ -44,8 +70,10 @@ def add_path(path, prepend=False): # Add the ComfyUI directory and custom nodes path to the sys.path list add_path(comfy_dir) add_path((comfy_dir / "custom_nodes")) +# endregion +# region TENSOR UTILITIES def tensor2pil(image: torch.Tensor) -> List[Image.Image]: batch_count = 1 if len(image.shape) > 3: @@ -89,3 +117,6 @@ def tensor2np(tensor: torch.Tensor) -> List[np.ndarray]: return out return [np.clip(255.0 * tensor.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)] + + +# endregion