diff --git a/README.md b/README.md index 230e204..60d055d 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ To install, either: - Input Image - Input Float - Input Integer +- Input Video Each input node supports setting a default value and additional configuration options. @@ -25,6 +26,9 @@ Each input node supports setting a default value and additional configuration op - Save Image - Save Images - Save Video - VHS +- Output Text +- Output Float +- Output Integer ### Convert Widgets to ShellAgent Inputs diff --git a/comfy-nodes/input_audio.py b/comfy-nodes/input_audio.py new file mode 100644 index 0000000..7e5997f --- /dev/null +++ b/comfy-nodes/input_audio.py @@ -0,0 +1,200 @@ +import folder_paths +import node_helpers + +from PIL import Image, ImageOps, ImageSequence, ImageFile +import numpy as np +import torch +import os +import uuid +import tqdm +import torchaudio +import hashlib +from comfy_extras.nodes_audio import SaveAudio + + +class LoadAudio: + @classmethod + def INPUT_TYPES(s): + input_dir = folder_paths.get_input_directory() + files = folder_paths.filter_files_content_types( + os.listdir(input_dir), ["audio", "video"]) + return {"required": {"audio": (sorted(files), {"audio_upload": True})}} + + CATEGORY = "audio" + + RETURN_TYPES = ("AUDIO", ) + FUNCTION = "load" + + def load(self, audio): + audio_path = folder_paths.get_annotated_filepath(audio) + waveform, sample_rate = torchaudio.load(audio_path) + audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate} + return (audio, ) + + @classmethod + def IS_CHANGED(s, audio): + image_path = folder_paths.get_annotated_filepath(audio) + m = hashlib.sha256() + with open(image_path, 'rb') as f: + m.update(f.read()) + return m.digest().hex() + + @classmethod + def VALIDATE_INPUTS(s, audio): + if not folder_paths.exists_annotated_filepath(audio): + return "Invalid audio file: {}".format(audio) + return True + + +class ShellAgentPluginInputAudio: + @classmethod + def INPUT_TYPES(s): + input_dir = folder_paths.get_input_directory() + files = folder_paths.filter_files_content_types( + os.listdir(input_dir), ["audio", "video"]) + return { + "required": { + "input_name": ( + "STRING", + {"multiline": False, "default": "input_audio", "forceInput": False}, + ), + "default_value": ( + sorted(files), {"audio_upload": True, "forceInput": False} + ), + }, + "optional": { + "description": ( + "STRING", + {"multiline": True, "default": "", "forceInput": False}, + ), + } + } + + RETURN_TYPES = ("AUDIO", ) + FUNCTION = "load" + + CATEGORY = "shellagent" + + @classmethod + def validate(cls, **kwargs): + schema = { + "title": kwargs["input_name"], + "type": "string", + "default": kwargs["default_value"], + "description": kwargs.get("description", ""), + "url_type": "audio" + } + return schema + + @classmethod + def VALIDATE_INPUTS(s, audio): + if not folder_paths.exists_annotated_filepath(audio): + return "Invalid audio file: {}".format(audio) + return True + + @classmethod + def VALIDATE_INPUTS(s, input_name, default_value, description=""): + audio = default_value + if audio.startswith("http"): + return True + + if not folder_paths.exists_annotated_filepath(audio): + return "Invalid audio file: {}".format(audio) + return True + + def load(self, input_name, default_value=None, display_name=None, description=None): + input_dir = folder_paths.get_input_directory() + audio_path = default_value + try: + if audio_path.startswith('http'): + import requests + from io import BytesIO + print("Fetching audio from url: ", audio_path) + response = requests.get(audio_path) + response.raise_for_status() + audio_file = BytesIO(response.content) + waveform, sample_rate = torchaudio.load(audio_file) + else: + if not os.path.isfile(audio_path): # abs path + # local path + audio_path = os.path.join(input_dir, audio_path) + waveform, sample_rate = torchaudio.load(audio_path) + + audio = {"waveform": waveform.unsqueeze( + 0), "sample_rate": sample_rate} + return (audio, ) + # image = ImageOps.exif_transpose(image) + # image = image.convert("RGB") + # image = np.array(image).astype(np.float32) / 255.0 + # image = torch.from_numpy(image)[None,] + # return [image] + except Exception as e: + raise e + + +class ShellAgentSaveAudios(SaveAudio): + @classmethod + def INPUT_TYPES(s): + return {"required": {"audio": ("AUDIO", ), + "output_name": ("STRING", {"multiline": False, "default": "output_audio"},), + "filename_prefix": ("STRING", {"default": "audio/ComfyUI"})}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + } + # { + # "required": { + # "images": ("IMAGE", {"tooltip": "The audio to save."}), + # "output_name": ("STRING", {"multiline": False, "default": "output_image"},), + # "filename_prefix": ("STRING", {"default": "ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}) + # }, + # "hidden": { + # "prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO" + # }, + # } + + CATEGORY = "shellagent" + + @classmethod + def validate(cls, **kwargs): + schema = { + "title": kwargs["output_name"], + "type": "array", + "items": { + "type": "string", + "url_type": "audio", + } + } + return schema + + def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None, **extra_kwargs): + results = super().save_audio(audio, filename_prefix, prompt, extra_pnginfo) + results["shellagent_kwargs"] = extra_kwargs + return results + + def save_flac(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None, **extra_kwargs): + results = super().save_flac(audio, filename_prefix, "flac", prompt, extra_pnginfo) + results["shellagent_kwargs"] = extra_kwargs + return results + + +class ShellAgentSaveAudio(ShellAgentSaveAudios): + @classmethod + def validate(cls, **kwargs): + schema = { + "title": kwargs["output_name"], + "type": "string", + "url_type": "audio", + } + return schema + + +NODE_CLASS_MAPPINGS = { + "ShellAgentPluginInputAudio": ShellAgentPluginInputAudio, + "ShellAgentPluginSaveAudios": ShellAgentSaveAudios, + "ShellAgentPluginSaveAudio": ShellAgentSaveAudio, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "ShellAgentPluginInputAudio": "Input Audio (ShellAgent Plugin)", + "ShellAgentPluginSaveAudios": "Save Audios (ShellAgent Plugin)", + "ShellAgentPluginSaveAudio": "Save Audio (ShellAgent Plugin)", +} diff --git a/comfy-nodes/input_image.py b/comfy-nodes/input_image.py index e2b25fc..3399a4f 100755 --- a/comfy-nodes/input_image.py +++ b/comfy-nodes/input_image.py @@ -1,11 +1,39 @@ import folder_paths -from PIL import Image, ImageOps +import node_helpers + +from PIL import Image, ImageOps, ImageSequence, ImageFile import numpy as np import torch import os import uuid import tqdm - +from io import BytesIO +import PIL +import cv2 +from pillow_heif import register_heif_opener + +register_heif_opener() + +def safe_open_image(image_bytes): + try: + image_pil = Image.open(BytesIO(image_bytes)) + except PIL.UnidentifiedImageError as e: + print(e) + # Convert response content (bytes) to a NumPy array + image_array = np.frombuffer(image_bytes, np.uint8) + + # Decode the image from the NumPy array (OpenCV format: BGR) + image_cv = cv2.imdecode(image_array, cv2.IMREAD_COLOR) + + if image_cv is not None: + # Convert the BGR image to RGB + image_rgb = cv2.cvtColor(image_cv, cv2.COLOR_BGR2RGB) + + # Convert the RGB NumPy array to a PIL Image + image_pil = Image.fromarray(image_rgb) + else: + raise ValueError("The image cannot be identified by neither PIL nor OpenCV") + return image_pil class ShellAgentPluginInputImage: @classmethod @@ -20,20 +48,19 @@ def INPUT_TYPES(s): {"multiline": False, "default": "input_image", "forceInput": False}, ), "default_value": ( - # "STRING", {"image_upload": True, "default": files[0] if len(files) else ""}, sorted(files), {"image_upload": True, "forceInput": False} ), }, "optional": { "description": ( "STRING", - {"multiline": True, "default": "", "forceInput": False}, + {"multiline": False, "default": "", "forceInput": False}, ), } } - RETURN_TYPES = ("IMAGE",) - RETURN_NAMES = ("image",) + RETURN_TYPES = ("IMAGE", "MASK") + # RETURN_NAMES = ("image",) FUNCTION = "run" @@ -45,7 +72,7 @@ def validate(cls, **kwargs): "title": kwargs["input_name"], "type": "string", "default": kwargs["default_value"], - "description": kwargs["description"], + "description": kwargs.get("description", ""), "url_type": "image" } return schema @@ -53,24 +80,72 @@ def validate(cls, **kwargs): @classmethod def VALIDATE_INPUTS(s, input_name, default_value, description=""): image = default_value + if image.startswith("http"): return True + if image == "": + return "Invalid image file: please check if the image is empty or invalid" + + if os.path.isfile(image): + return True + if not folder_paths.exists_annotated_filepath(image): return "Invalid image file: {}".format(image) return True + + def convert_image_mask(self, img): + output_images = [] + output_masks = [] + w, h = None, None + + excluded_formats = ['MPO'] + + for i in ImageSequence.Iterator(img): + i = node_helpers.pillow(ImageOps.exif_transpose, i) + + if i.mode == 'I': + i = i.point(lambda i: i * (1 / 255)) + image = i.convert("RGB") + + if len(output_images) == 0: + w = image.size[0] + h = image.size[1] + + if image.size[0] != w or image.size[1] != h: + continue + + image = np.array(image).astype(np.float32) / 255.0 + image = torch.from_numpy(image)[None,] + if 'A' in i.getbands(): + mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 + mask = 1. - torch.from_numpy(mask) + else: + mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") + output_images.append(image) + output_masks.append(mask.unsqueeze(0)) + + if len(output_images) > 1 and img.format not in excluded_formats: + output_image = torch.cat(output_images, dim=0) + output_mask = torch.cat(output_masks, dim=0) + else: + output_image = output_images[0] + output_mask = output_masks[0] + + return (output_image, output_mask) + def run(self, input_name, default_value=None, display_name=None, description=None): - input_dir = folder_paths.get_input_directory() image_path = default_value + input_dir = folder_paths.get_input_directory() try: if image_path.startswith('http'): import requests from io import BytesIO print("Fetching image from url: ", image_path) response = requests.get(image_path) - image = Image.open(BytesIO(response.content)) + image = safe_open_image(response.content) elif image_path.startswith('data:image/png;base64,') or image_path.startswith('data:image/jpeg;base64,') or image_path.startswith('data:image/jpg;base64,'): import base64 from io import BytesIO @@ -82,13 +157,14 @@ def run(self, input_name, default_value=None, display_name=None, description=Non if not os.path.isfile(image_path): # abs path # local path image_path = os.path.join(input_dir, image_path) - image = Image.open(image_path).convert("RGB") - - image = ImageOps.exif_transpose(image) - image = image.convert("RGB") - image = np.array(image).astype(np.float32) / 255.0 - image = torch.from_numpy(image)[None,] - return [image] + image = node_helpers.pillow(Image.open, image_path) + + return self.convert_image_mask(image) + # image = ImageOps.exif_transpose(image) + # image = image.convert("RGB") + # image = np.array(image).astype(np.float32) / 255.0 + # image = torch.from_numpy(image)[None,] + # return [image] except Exception as e: raise e diff --git a/comfy-nodes/input_text.py b/comfy-nodes/input_text.py index f499322..1fb3f56 100755 --- a/comfy-nodes/input_text.py +++ b/comfy-nodes/input_text.py @@ -42,7 +42,7 @@ def validate(cls, **kwargs): "title": kwargs["input_name"], "type": "string", "default": kwargs["default_value"], - "description": kwargs["description"], + "description": kwargs.get("description", ""), } if kwargs.get("choices", "") != "": schema["enums"] = eval(kwargs["choices"]) @@ -101,7 +101,7 @@ def validate(cls, **kwargs): "title": kwargs["input_name"], "type": "number", "default": kwargs["default_value"], - "description": kwargs["description"], + "description": kwargs.get("description", ""), } if kwargs.get("choices", "") != "": schema["enums"] = eval(kwargs["choices"]) @@ -184,14 +184,58 @@ def validate(cls, **kwargs): def run(self, input_name, default_value=None, display_name=None, description=None, **kwargs): return [default_value] +class ShellAgentPluginInputBoolean: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "input_name": ( + "STRING", + {"multiline": False, "default": "input_bool"}, + ), + }, + "optional": { + "default_value": ( + "BOOLEAN", + {"default": False}, + ), + "description": ( + "STRING", + {"multiline": True, "default": ""}, + ), + } + } + + RETURN_TYPES = ("BOOLEAN",) + RETURN_NAMES = ("boolean",) + + FUNCTION = "run" + + CATEGORY = "shellagent" + + @classmethod + def validate(cls, **kwargs): + schema = { + "title": kwargs["input_name"], + "type": "boolean", + "default": kwargs["default_value"], + "description": kwargs.get("description", ""), + } + return schema + + def run(self, input_name, default_value=None, display_name=None, description=None, **kwargs): + return [default_value] + NODE_CLASS_MAPPINGS = { "ShellAgentPluginInputText": ShellAgentPluginInputText, "ShellAgentPluginInputFloat": ShellAgentPluginInputFloat, - "ShellAgentPluginInputInteger": ShellAgentPluginInputInteger + "ShellAgentPluginInputInteger": ShellAgentPluginInputInteger, + "ShellAgentPluginInputBoolean": ShellAgentPluginInputBoolean, } NODE_DISPLAY_NAME_MAPPINGS = { "ShellAgentPluginInputText": "Input Text (ShellAgent Plugin)", "ShellAgentPluginInputFloat": "Input Float (ShellAgent Plugin)", "ShellAgentPluginInputInteger": "Input Integer (ShellAgent Plugin)", + "ShellAgentPluginInputBoolean": "Input Boolean (ShellAgent Plugin)", } \ No newline at end of file diff --git a/comfy-nodes/input_video.py b/comfy-nodes/input_video.py index 48e94a0..c13c029 100644 --- a/comfy-nodes/input_video.py +++ b/comfy-nodes/input_video.py @@ -4,7 +4,7 @@ import torch import os import uuid -import tqdm +from tqdm import tqdm # class ShellAgentPluginInputImage: @@ -120,6 +120,15 @@ def validate(cls, **kwargs): "url_type": "video" } return schema + + @classmethod + def VALIDATE_INPUTS(s, input_name, default_value, description=""): + video = default_value + if video.startswith("http"): + return True + if not folder_paths.exists_annotated_filepath(video): + return "Invalid video file: {}".format(video) + return True def run(self, input_name, default_value=None, description=None): input_dir = folder_paths.get_input_directory() @@ -166,4 +175,4 @@ def run(self, input_name, default_value=None, description=None): NODE_DISPLAY_NAME_MAPPINGS = { # "ShellAgentPluginInputImage": "Input Image (ShellAgent Plugin)", "ShellAgentPluginInputVideo": "Input Video (ShellAgent Plugin)" -} \ No newline at end of file +} diff --git a/comfy-nodes/output_image.py b/comfy-nodes/output_image.py index 3198e4f..7a884b6 100644 --- a/comfy-nodes/output_image.py +++ b/comfy-nodes/output_image.py @@ -75,12 +75,15 @@ def validate(cls, **kwargs): return schema def save_video(self, filenames, **kwargs): - status, (preview_image, video_path) = filenames + status, output_files = filenames + if len(output_files) == 0: + raise ValueError("the filenames are empty") + print("output_files", output_files) + video_path = output_files[-1] cwd = os.getcwd() - preview_image = os.path.relpath(preview_image) - video_path = os.path.relpath(video_path) - results = {"ui": {"image": [preview_image], "video": [video_path]}} - print(results) + # preview_image = os.path.relpath(preview_image) + video_path = os.path.relpath(video_path, folder_paths.base_path) + results = {"ui": {"video": [video_path]}} return results diff --git a/comfy-nodes/output_text.py b/comfy-nodes/output_text.py new file mode 100644 index 0000000..fbb6c91 --- /dev/null +++ b/comfy-nodes/output_text.py @@ -0,0 +1,88 @@ + +json_type_mapipng = { + "text": "string", + "float": "number", + "integer": "integer", + "boolean": "boolean", +} + +class ShellAgentOutputText: + TYPE_STR = "text" + @classmethod + def INPUT_TYPES(s): + return { + "required": { + s.TYPE_STR: ("STRING", {"tooltip": f"The {s.TYPE_STR} to output."}), + "output_name": ("STRING", {"multiline": False, "default": f"output_{s.TYPE_STR}"},), + }, + } + + RETURN_TYPES = () + FUNCTION = "output_var" + + OUTPUT_NODE = True + + CATEGORY = "shellagent" + DESCRIPTION = "output the text" + + @classmethod + def validate(cls, **kwargs): + schema = { + "title": kwargs["output_name"], + "type": json_type_mapipng[cls.TYPE_STR] + } + return schema + + def output_var(self, **kwargs): + results = {"ui": {"output": [kwargs[self.TYPE_STR]]}} + return results + +class ShellAgentOutputFloat(ShellAgentOutputText): + TYPE_STR = "float" + DESCRIPTION = "output the float" + @classmethod + def INPUT_TYPES(s): + return { + "required": { + s.TYPE_STR: ("FLOAT", {"tooltip": f"The {s.TYPE_STR} to output."}), + "output_name": ("STRING", {"multiline": False, "default": f"output_{s.TYPE_STR}"},), + }, + } + + +class ShellAgentOutputInteger(ShellAgentOutputText): + TYPE_STR = "integer" + DESCRIPTION = "output the integer" + @classmethod + def INPUT_TYPES(s): + return { + "required": { + s.TYPE_STR: ("INT", {"tooltip": f"The {s.TYPE_STR} to output."}), + "output_name": ("STRING", {"multiline": False, "default": f"output_{s.TYPE_STR}"},), + }, + } + +class ShellAgentOutputBoolean(ShellAgentOutputText): + TYPE_STR = "boolean" + DESCRIPTION = "output the integer" + @classmethod + def INPUT_TYPES(s): + return { + "required": { + s.TYPE_STR: ("BOOLEAN", {"tooltip": f"The {s.TYPE_STR} to output."}), + "output_name": ("STRING", {"multiline": False, "default": f"output_{s.TYPE_STR}"},), + }, + } + + +NODE_CLASS_MAPPINGS = { + "ShellAgentPluginOutputText": ShellAgentOutputText, + "ShellAgentPluginOutputFloat": ShellAgentOutputFloat, + "ShellAgentPluginOutputInteger": ShellAgentOutputInteger, + "ShellAgentPluginOutputBoolean": ShellAgentOutputBoolean, +} +NODE_DISPLAY_NAME_MAPPINGS = { + "ShellAgentPluginOutputText": "Output Text (ShellAgent Plugin)", + "ShellAgentPluginOutputFloat": "Output Float (ShellAgent Plugin)", + "ShellAgentPluginOutputInteger": "Output Integer (ShellAgent Plugin)", +} \ No newline at end of file diff --git a/custom_routes.py b/custom_routes.py index 3fbab8b..9e2dca2 100755 --- a/custom_routes.py +++ b/custom_routes.py @@ -29,9 +29,12 @@ from datetime import datetime import nodes import traceback +import re +import keyword +import uuid -from .dependency_checker import resolve_dependencies - +from .dependency_checker import resolve_dependencies, inspect_repo_version +from folder_paths import base_path as BASE_PATH WORKFLOW_ROOT = "shellagent/comfy_workflow" @@ -45,6 +48,14 @@ "ShellAgentPluginSaveVideoVHS": "video", } +# Regular expression for a valid Python variable name +variable_name_pattern = r'^[a-zA-Z_][a-zA-Z0-9_]*$' + +def is_valid_variable_name(name): + # Check if it matches the pattern and is not a keyword + if re.match(variable_name_pattern, name) and not keyword.iskeyword(name): + return True + return False def schema_validator(prompt): from nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS @@ -84,6 +95,9 @@ def schema_validator(prompt): continue if hasattr(node_cls, "validate"): schema = node_cls.validate(**node_info["inputs"]) + # validate schema + if not is_valid_variable_name(schema["title"]): + raise ValueError(f'`{schema["title"]}` is not a valid variable name!') else: raise NotImplementedError("the validate is not implemented") schemas[mode][node_id] = schema @@ -158,16 +172,56 @@ async def shellagent_export(request): # for fname, dict_to_save in fname_mapping.items(): # with open(os.path.join(save_root, fname), "w") as f: # json.dump(dict_to_save, f, indent=2) + warning_message = "" + if dependency_results.get("black_list_nodes", []): + warning_message = "The following nodes cannot be deployed to myshell:\n" + for item in dependency_results["black_list_nodes"]: + warning_message += f" {item['name']}: {item['reason']}\n" + + if len(schemas["inputs"]) + len(schemas["outputs"]) == 0: + warning_message += f"The workflow contains neither inputs nor outputs!\n" return_dict = { "success": True, - "dependencies": dependency_results, + "dependencies": dependency_results["dependencies"], + "warning_message": warning_message, "schemas": schemas } except Exception as e: status = 400 return_dict = { "success": False, - "message": str(traceback.format_exc()), + "message_detail": str(traceback.format_exc()), + "message": str(e), } - return web.json_response(return_dict, status=status) \ No newline at end of file + return web.json_response(return_dict, status=status) + + +@server.PromptServer.instance.routes.post("/shellagent/inspect_version") # data same as queue prompt, plus workflow_name +async def shellagent_inspect_version(request): + data = await request.json() + comfyui_version = inspect_repo_version(BASE_PATH) + comfyui_shellagent_plugin_version = inspect_repo_version(os.path.dirname(__file__)) + return_dict = { + "comfyui_version": comfyui_version, + "comfyui_shellagent_plugin_version": comfyui_shellagent_plugin_version, + } + return web.json_response(return_dict, status=200) + + +@server.PromptServer.instance.routes.post("/shellagent/get_mac_addr") # data same as queue prompt, plus workflow_name +async def shellagent_get_mac_addr(request): + data = await request.json() + return_dict = { + "mac_addr": uuid.getnode() + } + return web.json_response(return_dict, status=200) + +@server.PromptServer.instance.routes.post("/shellagent/check_exist") # check if the file or folder exist +async def shellagent_check_exist(request): + data = await request.json() + + return_dict = { + "exist": uuid.getnode() == data["mac_addr"] and os.path.exists(data["path"]) # really exist, instead of same name + } + return web.json_response(return_dict, status=200) \ No newline at end of file diff --git a/dependency_checker.py b/dependency_checker.py index b9fee0b..b516da9 100644 --- a/dependency_checker.py +++ b/dependency_checker.py @@ -5,25 +5,36 @@ from functools import partial import re import glob +import sys from folder_paths import models_dir as MODELS_DIR from folder_paths import base_path as BASE_PATH +from folder_paths import get_full_path -from .utils import compute_sha256, windows_to_linux_path + +from .utils.utils import compute_sha256, windows_to_linux_path +from .utils.pytree import tree_map from .file_upload import collect_local_file, process_local_file_path_async model_list_json = json.load(open(os.path.join(os.path.dirname(__file__), "model_info.json"))) model_loaders_info = json.load(open(os.path.join(os.path.dirname(__file__), "model_loader_info.json"))) node_deps_info = json.load(open(os.path.join(os.path.dirname(__file__), "node_deps_info.json"))) +node_blacklist = json.load(open(os.path.join(os.path.dirname(__file__), "node_blacklist.json"))) +node_remote_skip_models = json.load(open(os.path.join(os.path.dirname(__file__), "node_remote.json"))) + +model_suffix = [".ckpt", ".safetensors", ".bin", ".pth", ".pt", ".onnx", ".gguf", ".sft", ".ttf"] +extra_packages = ["transformers", "timm", "diffusers", "accelerate"] + +def get_full_path_or_raise(folder_name: str, filename: str) -> str: + full_path = get_full_path(folder_name, filename) + if full_path is None: + raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.") + return full_path -model_suffix = [".ckpt", ".safetensors", ".bin", ".pth", ".pt", ".onnx"] -def handle_model_info(ckpt_path): +def handle_model_info(ckpt_path, filename, rel_save_path): ckpt_path = windows_to_linux_path(ckpt_path) - filename = os.path.basename(ckpt_path) - dirname = os.path.dirname(ckpt_path) - save_path = os.path.dirname(os.path.relpath(ckpt_path, MODELS_DIR)) metadata_path = ckpt_path + ".json" if os.path.isfile(metadata_path): metadata = json.load(open(metadata_path)) @@ -35,7 +46,7 @@ def handle_model_info(ckpt_path): model_id = compute_sha256(ckpt_path) data = { "id": model_id, - "save_path": save_path, + "save_path": rel_save_path, "filename": filename, } json.dump(data, open(metadata_path, "w")) @@ -45,8 +56,8 @@ def handle_model_info(ckpt_path): urls = [] item = { - "filename": filename, - "save_path": windows_to_linux_path(save_path), + "filename": windows_to_linux_path(filename), + "save_path": windows_to_linux_path(rel_save_path), "urls": urls, } return model_id, item @@ -59,6 +70,10 @@ def inspect_repo_version(module_path): "repo": "", "commit": "" } + + if not os.path.isdir(os.path.join(module_path, ".git")): + return result + # Get the remote repository URL try: remote_url = subprocess.check_output( @@ -87,7 +102,7 @@ def inspect_repo_version(module_path): def fetch_model_searcher_results(model_ids): import requests - url = "https://shellagent.myshell.ai/models_searcher/search_urls" + url = "https://models-searcher.myshell.life/search_urls" headers = { "Content-Type": "application/json" } @@ -96,59 +111,164 @@ def fetch_model_searcher_results(model_ids): } response = requests.post(url, headers=headers, json=data) - results = [item[:10] for item in response.json()] + if response.status_code == 200: + results = [item[:10] for item in response.json()] + else: + results = None return results +def split_package_version(require_line): + require_line = require_line.strip() + + pattern = r"^([a-zA-Z0-9_\-\[\]]+)(.*)$" + match = re.match(pattern, require_line.strip()) + + if match: + package_name = match.group(1) # First capturing group is the package name + version_specifier = match.group(2) if match.group(2) else "" # Second group is the version, if present + return package_name, version_specifier + else: + assert len(require_line) == 0 or require_line.strip()[0] == "#", require_line + return None, None + +def get_package_version(package_name): + try: + if sys.version_info >= (3, 8): + from importlib.metadata import version, PackageNotFoundError + return version(package_name) + else: + from pkg_resources import get_distribution, DistributionNotFound + return get_distribution(package_name).version + except Exception: + return None + def resolve_dependencies(prompt, custom_dependencies): # resolve custom nodes and models at the same time from nodes import NODE_CLASS_MAPPINGS + import folder_paths + custom_nodes = [] - ckpt_paths = [] + ckpt_paths = {} file_mapping_dict = {} + + SKIP_FOLDER_NAMES = ["configs", "custom_nodes"] + def collect_unknown_models(filename, node_id, node_info, custom_node_path): + if type(filename) != str: + return + is_model = False + for possible_suffix in model_suffix: + if filename.endswith(possible_suffix): + is_model = True + if is_model: + print(f"find {filename}, is_model=True") + # find possible paths + matching_files = {} + # Walk through all subdirectories and files in the directory + rel_save_path = None + for possible_folder_name in folder_paths.folder_names_and_paths: + if possible_folder_name in SKIP_FOLDER_NAMES: + print(f"skip {possible_folder_name}") + continue + full_path = folder_paths.get_full_path(possible_folder_name, filename) + if full_path is None: + continue + rel_save_path = os.path.relpath(folder_paths.folder_names_and_paths[possible_folder_name][0][0], folder_paths.models_dir) + matching_files[full_path] = { + "rel_save_path": rel_save_path + } + + print(f"matched files: {matching_files}") + + # step 2: search for all the files under "models" + + for full_path in glob.glob(f"{folder_paths.models_dir}/**/*", recursive=True): + if os.path.isfile(full_path) and full_path.endswith(filename) and full_path not in matching_files: + folder_path = full_path[:-len(filename)] + rel_save_path = os.path.relpath(folder_path, folder_paths.models_dir) + matching_files[full_path] = { + "rel_save_path": rel_save_path + } + + print(f"matched files: {matching_files}") + + # step 3: search inside the custom nodes + if custom_node_path is not None: + for full_path in glob.glob(f"{custom_node_path}/**/*", recursive=True): + if os.path.isfile(full_path) and full_path.endswith(filename) and full_path not in matching_files: + folder_path = full_path[:-len(filename)] + rel_save_path = os.path.relpath(folder_path, folder_paths.models_dir) + matching_files[full_path] = { + "rel_save_path": rel_save_path + } + + if len(matching_files) == 0: + raise ValueError(f"Cannot find model: `{filename}`, Node ID: `{node_id}`, Node Info: `{node_info}`") + + elif len(matching_files) <= 3: + for full_path, info in matching_files.items(): + ckpt_paths[full_path] = { + "filename": filename, + "rel_save_path": info["rel_save_path"] + } + return + else: + raise ValueError(f"Multiple models of `{filename}` founded, Node ID: `{node_id}`, Node Info: `{node_info}`, Possible paths: `{list(matching_files.keys())}`") + + for node_id, node_info in prompt.items(): node_class_type = node_info.get("class_type") if node_class_type is None: raise NotImplementedError(f"Missing nodes founded, please first install the missing nodes using ComfyUI Manager") node_cls = NODE_CLASS_MAPPINGS[node_class_type] - if hasattr(node_cls, "RELATIVE_PYTHON_MODULE"): + + skip_model_check = False + + custom_node_path = None + if hasattr(node_cls, "RELATIVE_PYTHON_MODULE") and node_cls.RELATIVE_PYTHON_MODULE.startswith("custom_nodes."): + print(node_cls.RELATIVE_PYTHON_MODULE) custom_nodes.append(node_cls.RELATIVE_PYTHON_MODULE) + custom_node_path = os.path.join(BASE_PATH, node_cls.RELATIVE_PYTHON_MODULE.replace(".", "/")) + if node_cls.RELATIVE_PYTHON_MODULE[len("custom_nodes."):] in node_remote_skip_models: + skip_model_check = True + print(f"skip model check for {node_class_type}") + if node_class_type in model_loaders_info: - for field_name, filename in node_info["inputs"].items(): - for item in model_loaders_info[node_class_type]: - pattern = item["field_name"] - if re.match(f"^{pattern}$", field_name): - ckpt_path = os.path.join(MODELS_DIR, item["save_path"], filename) - ckpt_paths.append(ckpt_path) - else: for field_name, filename in node_info["inputs"].items(): if type(filename) != str: continue - is_model = False - for possible_suffix in model_suffix: - if filename.endswith(possible_suffix): - is_model = True - if is_model: - print(f"find {filename}, is_model=True") - # find possible paths - matching_files = [] - # Walk through all subdirectories and files in the directory - for possible_filename in glob.glob(os.path.join(MODELS_DIR, "**", "*"), recursive=True): - if os.path.isfile(possible_filename) and possible_filename.endswith(filename): - matching_files.append(possible_filename) - print(f"matched files: {matching_files}") - if len(matching_files) == 1: - ckpt_paths.append(matching_files[0]) + for item in model_loaders_info[node_class_type]: + pattern = item["field_name"] + if re.match(f"^{pattern}$", field_name) and any([filename.endswith(possible_suffix) for possible_suffix in model_suffix]): + ckpt_path = get_full_path_or_raise(item["save_path"], filename) + if hasattr(folder_paths, "map_legacy"): + save_folder = folder_paths.map_legacy(item["save_path"]) + else: + save_folder = item["save_path"] + rel_save_path = os.path.relpath(folder_paths.folder_names_and_paths[save_folder][0][0], folder_paths.models_dir) + ckpt_paths[ckpt_path] = { + "filename": filename, + "rel_save_path": rel_save_path + } + elif not skip_model_check: + tree_map(lambda x: collect_unknown_models(x, node_id, node_info, custom_node_path), node_info["inputs"]) + list(map(partial(collect_local_file, mapping_dict=file_mapping_dict), node_info["inputs"].values())) - ckpt_paths = list(set(ckpt_paths)) print("ckpt_paths:", ckpt_paths) custom_nodes = list(set(custom_nodes)) # step 0: comfyui version - comfyui_version = inspect_repo_version(BASE_PATH) - + repo_info = inspect_repo_version(BASE_PATH) + if repo_info["repo"] == "": + repo_info["require_recheck"] = True + if repo_info["name"] in custom_dependencies["custom_nodes"]: + repo_info["repo"] = custom_dependencies["custom_nodes"][repo_info["name"]].get("repo", "") + repo_info["commit"] = custom_dependencies["custom_nodes"][repo_info["name"]].get("commit", "") + comfyui_version = repo_info + # step 1: custom nodes custom_nodes_list = [] custom_nodes_names = [] + requirements_lines = [] for custom_node in custom_nodes: try: repo_info = inspect_repo_version(os.path.join(BASE_PATH, custom_node.replace(".", "/"))) @@ -161,20 +281,39 @@ def resolve_dependencies(prompt, custom_dependencies): # resolve custom nodes an custom_nodes_names.append(repo_info["name"]) except: print(f"failed to resolve repo info of {custom_node}") + requirement_file = os.path.join(BASE_PATH, custom_node.replace(".", "/"), "requirements.txt") + if os.path.isfile(requirement_file): + try: + requirements_lines += open(requirement_file).readlines() + except: + pass + requirements_lines = list(set(requirements_lines)) + requirements_packages = [package_name for package_name, version_specifier in map(split_package_version, requirements_lines) if package_name is not None] + package_names = set(requirements_packages + extra_packages) + pypi_deps = { + package_name: get_package_version(package_name) + for package_name in package_names + } for repo_name in custom_nodes_names: if repo_name in node_deps_info: for deps_node in node_deps_info[repo_name]: if deps_node["name"] not in custom_nodes_names: - repo_info = inspect_repo_version(os.path.join("custom_nodes", deps_node["name"])) + repo_info = inspect_repo_version(os.path.join(BASE_PATH, "custom_nodes", deps_node["name"])) deps_node["commit"] = repo_info["commit"] custom_nodes_list.append(deps_node) + custom_nodes_names.append(deps_node["name"]) + + black_list_nodes = [] + for repo_name in custom_nodes_names: + if repo_name in node_blacklist: + black_list_nodes.append({"name": repo_name, "reason": node_blacklist[repo_name]["reason"]}) # step 2: models models_dict = {} missing_model_ids = [] - for ckpt_path in ckpt_paths: - model_id, item = handle_model_info(ckpt_path) + for ckpt_path, ckpt_info in ckpt_paths.items(): + model_id, item = handle_model_info(ckpt_path, ckpt_info["filename"], ckpt_info["rel_save_path"]) models_dict[model_id] = item if len(item["urls"]) == 0: item["require_recheck"] = True @@ -184,20 +323,30 @@ def resolve_dependencies(prompt, custom_dependencies): # resolve custom nodes an # try to fetch from myshell model searcher missing_model_results_myshell = fetch_model_searcher_results(missing_model_ids) - for missing_model_id, missing_model_urls in zip(missing_model_ids, missing_model_results_myshell): - if len(missing_model_urls) > 0: - models_dict[missing_model_id]["require_recheck"] = False - models_dict[missing_model_id]["urls"] = missing_model_urls - print("successfully fetch results from myshell", models_dict[missing_model_id]) + if missing_model_results_myshell is not None: + for missing_model_id, missing_model_urls in zip(missing_model_ids, missing_model_results_myshell): + if len(missing_model_urls) > 0: + models_dict[missing_model_id]["require_recheck"] = False + models_dict[missing_model_id]["urls"] = missing_model_urls + print("successfully fetch results from myshell", models_dict[missing_model_id]) # step 3: handle local files process_local_file_path_async(file_mapping_dict, max_workers=20) - files_dict = {v[0]: {"filename": windows_to_linux_path(os.path.relpath(v[2], BASE_PATH)), "urls": [v[1]]} for v in file_mapping_dict.values()} + files_dict = { + v[0]: { + "filename": windows_to_linux_path(os.path.relpath(v[2], BASE_PATH)) if not v[3] else v[2], + "urls": [v[1]]} for v in file_mapping_dict.values()} - results = { + depencencies = { "comfyui_version": comfyui_version, "custom_nodes": custom_nodes_list, "models": models_dict, "files": files_dict, + "pypi": pypi_deps + } + + return_dict = { + "dependencies": depencencies, + "black_list_nodes": black_list_nodes, } - return results \ No newline at end of file + return return_dict diff --git a/file_upload.py b/file_upload.py index fd471fe..aa13ff8 100644 --- a/file_upload.py +++ b/file_upload.py @@ -5,7 +5,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed import folder_paths -from .utils import compute_sha256 +from .utils.utils import compute_sha256, get_alphanumeric_hash ext_to_type = { # image @@ -27,7 +27,7 @@ '.m4a': 'audio/mp4', } -def upload_file_to_myshell(local_file: str) -> str: +def upload_file_to_myshell(local_file: str, target_path: str, is_abs) -> str: ''' Now we only support upload file one-by-one ''' MYSHELL_KEY = os.environ.get('MYSHELL_KEY', "OPENSOURCE_FIXED") @@ -51,8 +51,8 @@ def upload_file_to_myshell(local_file: str) -> str: response = requests.request("POST", server_url, headers=headers, files=files) if response.status_code == 200: end_time = time.time() - logging.info(f"{local_file} uploaded, time elapsed: {end_time - start_time}") - return [sha256sum, response.json()['url'], local_file] + logging.info(f"{local_file} uploaded, time elapsed: {end_time - start_time}, will be saved to {target_path}") + return [sha256sum, response.json()['url'], target_path, is_abs] else: raise Exception( f"[HTTP ERROR] {response.status_code} - {response.text} \n" @@ -66,8 +66,11 @@ def collect_local_file(item, mapping_dict={}): abspath = os.path.abspath(item) input_abspath = os.path.join(input_dir, item) # required file type + is_abs = False if os.path.isfile(abspath): fpath = abspath + is_abs = True + elif os.path.isfile(input_abspath): fpath = input_abspath else: @@ -75,7 +78,13 @@ def collect_local_file(item, mapping_dict={}): if fpath is not None: ext = os.path.splitext(fpath)[1] if ext.lower() in ext_to_type.keys(): - mapping_dict[item] = fpath + if is_abs: # if use abs path, replace it + filename_hash = get_alphanumeric_hash(abspath)[:16] + count = len(mapping_dict) + target_path = f"/ShellAgentDeploy/ComfyUI/input/{filename_hash}_{count:06d}{ext}" + mapping_dict[item] = (fpath, target_path, is_abs) + else: + mapping_dict[item] = (fpath, fpath, is_abs) return else: return @@ -86,7 +95,7 @@ def process_local_file_path_async(mapping_dict, max_workers=10): start_time = time.time() with ThreadPoolExecutor(max_workers=max_workers) as executor: # Submit tasks to the executor - futures = {executor.submit(upload_file_to_myshell, full_path): filename for filename, full_path in mapping_dict.items()} + futures = {executor.submit(upload_file_to_myshell, source_path, target_path, is_abs): filename for filename, (source_path, target_path, is_abs) in mapping_dict.items()} logging.info("submit done") # Collect the results as they complete for future in as_completed(futures): @@ -95,7 +104,8 @@ def process_local_file_path_async(mapping_dict, max_workers=10): result = future.result() mapping_dict[filename] = result except Exception as e: - print(f"Error processing {filename}: {e}") + del mapping_dict[filename] + raise NotImplementedError(f"Error processing {filename}: {e}") end_time = time.time() logging.info(f"upload end, elapsed time: {end_time - start_time}") return \ No newline at end of file diff --git a/node_blacklist.json b/node_blacklist.json new file mode 100644 index 0000000..4fca91b --- /dev/null +++ b/node_blacklist.json @@ -0,0 +1,5 @@ +{ + "comfyui-ollama": { + "reason": "this node requires installing an extra software on linux, which is currently unsupported" + } +} \ No newline at end of file diff --git a/node_deps_info.json b/node_deps_info.json index dfb1945..040669e 100644 --- a/node_deps_info.json +++ b/node_deps_info.json @@ -1,19 +1,45 @@ { - "ComfyUI-Easy-Use": [ - { - "name": "ComfyUI-Inspire-Pack", - "repo": "https://github.com/ltdrdata/ComfyUI-Inspire-Pack.git", - "commit": "" - }, - { - "name": "ComfyUI-Advanced-ControlNet", - "repo": "https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet.git", - "commit": "" - }, - { - "name": "ComfyUI_smZNodes", - "repo": "https://github.com/shiimizu/ComfyUI_smZNodes.git", - "commit": "" - } - ] -} \ No newline at end of file + "ComfyUI-Easy-Use": [ + { + "name": "ComfyUI-Inspire-Pack", + "repo": "https://github.com/ltdrdata/ComfyUI-Inspire-Pack.git", + "commit": "" + }, + { + "name": "ComfyUI-Advanced-ControlNet", + "repo": "https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet.git", + "commit": "" + }, + { + "name": "ComfyUI_smZNodes", + "repo": "https://github.com/shiimizu/ComfyUI_smZNodes.git", + "commit": "" + }, + { + "name": "ComfyUI_IPAdapter_plus", + "repo": "https://github.com/cubiq/ComfyUI_IPAdapter_plus.git", + "commit": "" + } + ], + "efficiency-nodes-comfyui": [ + { + "name": "comfyui_controlnet_aux", + "repo": "https://github.com/Fannovel16/comfyui_controlnet_aux.git", + "commit": "" + } + ], + "ComfyUI-Anyline": [ + { + "name": "comfyui_controlnet_aux", + "repo": "https://github.com/Fannovel16/comfyui_controlnet_aux.git", + "commit": "" + } + ], + "ComfyUI-Impact-Pack": [ + { + "name": "ComfyUI-Impact-Subpack", + "repo": "https://github.com/ltdrdata/ComfyUI-Impact-Subpack.git", + "commit": "" + } + ] +} diff --git a/node_remote.json b/node_remote.json new file mode 100644 index 0000000..f089f82 --- /dev/null +++ b/node_remote.json @@ -0,0 +1,3 @@ +[ + "BizyAir" +] \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index f1d3a61..6b6d458 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,6 @@ pydantic opencv-python imageio-ffmpeg brotli -# logfire \ No newline at end of file +pillow_heif +easydict +# logfire diff --git a/utils/pytree.py b/utils/pytree.py new file mode 100644 index 0000000..d326a84 --- /dev/null +++ b/utils/pytree.py @@ -0,0 +1,1197 @@ +""" +Copy from https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py + +Contains utility functions for working with nested python data structures. + +A *pytree* is Python nested data structure. It is a tree in the sense that +nodes are Python collections (e.g., list, tuple, dict) and the leaves are +Python values. Furthermore, a pytree should not contain reference cycles. + +pytrees are useful for working with nested collections of Tensors. For example, +one can use `tree_map` to map a function over all Tensors inside some nested +collection of Tensors and `tree_leaves` to get a flat list of all Tensors +inside some nested collection. pytrees are helpful for implementing nested +collection support for PyTorch APIs. + +This pytree implementation is not very performant due to Python overhead +To improve the performance we can move parts of the implementation to C++. +""" + +import dataclasses +import importlib +import json +import threading +import warnings +from collections import defaultdict, deque, namedtuple, OrderedDict +from typing import ( + Any, + Callable, + cast, + DefaultDict, + Deque, + Dict, + FrozenSet, + Iterable, + List, + NamedTuple, + Optional, + OrderedDict as GenericOrderedDict, + overload, + Tuple, + Type, + TypeVar, + Union, +) +from easydict import EasyDict as edict + + +__all__ = [ + "PyTree", + "Context", + "FlattenFunc", + "UnflattenFunc", + "DumpableContext", + "ToDumpableContextFn", + "FromDumpableContextFn", + "TreeSpec", + "LeafSpec", + "register_pytree_node", + "tree_flatten", + "tree_unflatten", + "tree_leaves", + "tree_structure", + "tree_map", + "tree_map_", + "tree_map_only", + "tree_map_only_", + "tree_all", + "tree_any", + "tree_all_only", + "tree_any_only", + "treespec_dumps", + "treespec_loads", + "treespec_pprint", +] + + +T = TypeVar("T") +S = TypeVar("S") +U = TypeVar("U") +R = TypeVar("R") + + +DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL = 1 +NO_SERIALIZED_TYPE_NAME_FOUND = "NO_SERIALIZED_TYPE_NAME_FOUND" + +Context = Any +PyTree = Any +FlattenFunc = Callable[[PyTree], Tuple[List[Any], Context]] +UnflattenFunc = Callable[[Iterable[Any], Context], PyTree] +DumpableContext = Any # Any json dumpable text +ToDumpableContextFn = Callable[[Context], DumpableContext] +FromDumpableContextFn = Callable[[DumpableContext], Context] +ToStrFunc = Callable[["TreeSpec", List[str]], str] +MaybeFromStrFunc = Callable[[str], Optional[Tuple[Any, Context, str]]] + + +# A NodeDef holds two callables: +# - flatten_fn should take the collection and return a flat list of values. +# It can also return some context that is used in reconstructing the +# collection. +# - unflatten_fn should take a flat list of values and some context +# (returned by flatten_fn). It returns the collection by reconstructing +# it from the list and the context. +class NodeDef(NamedTuple): + type: Type[Any] + flatten_fn: FlattenFunc + unflatten_fn: UnflattenFunc + + +_NODE_REGISTRY_LOCK = threading.Lock() +SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {} + + +# _SerializeNodeDef holds the following: +# - typ: the type of the node (e.g., "Dict", "List", etc) +# - serialized_type_name: the fully qualified name of the type, e.g. "collections.OrderedDict" +# - to_dumpable_context takes a TreeSpec, and returns a serialized string format of the +# context, and the version number +# - from_dumpable_context takes in a string representation of the context, and the +# version, and returns the deserialized context +class _SerializeNodeDef(NamedTuple): + typ: Type[Any] + serialized_type_name: str + to_dumpable_context: Optional[ToDumpableContextFn] + from_dumpable_context: Optional[FromDumpableContextFn] + + +SUPPORTED_SERIALIZED_TYPES: Dict[Type[Any], _SerializeNodeDef] = {} +SERIALIZED_TYPE_TO_PYTHON_TYPE: Dict[str, Type[Any]] = {} + + +def register_pytree_node( + cls: Type[Any], + flatten_fn: FlattenFunc, + unflatten_fn: UnflattenFunc, + *, + serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, +) -> None: + """Register a container-like type as pytree node. + + Args: + cls: the type to register + flatten_fn: A callable that takes a pytree and returns a flattened + representation of the pytree and additional context to represent the + flattened pytree. + unflatten_fn: A callable that takes a flattened version of the pytree, + additional context, and returns an unflattened pytree. + serialized_type_name: A keyword argument used to specify the fully qualified + name used when serializing the tree spec. + to_dumpable_context: An optional keyword argument to custom specify how + to convert the context of the pytree to a custom json dumpable + representation. This is used for json serialization, which is being + used in torch.export right now. + from_dumpable_context: An optional keyword argument to custom specify how + to convert the custom json dumpable representation of the context + back to the original context. This is used for json deserialization, + which is being used in torch.export right now. + """ + with _NODE_REGISTRY_LOCK: + if cls in SUPPORTED_NODES: + raise ValueError(f"{cls} is already registered as pytree node.") + + _private_register_pytree_node( + cls, + flatten_fn, + unflatten_fn, + serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, + ) + + try: + from . import _cxx_pytree as cxx + except ImportError: + pass + else: + cxx._private_register_pytree_node( + cls, + flatten_fn, + unflatten_fn, + serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, + ) + + +def _register_pytree_node( + cls: Type[Any], + flatten_fn: FlattenFunc, + unflatten_fn: UnflattenFunc, + to_str_fn: Optional[ToStrFunc] = None, # deprecated + maybe_from_str_fn: Optional[MaybeFromStrFunc] = None, # deprecated + *, + serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, +) -> None: + """Register a container-like type as pytree node for the Python pytree only. + + Args: + cls: the type to register + flatten_fn: A callable that takes a pytree and returns a flattened + representation of the pytree and additional context to represent the + flattened pytree. + unflatten_fn: A callable that takes a flattened version of the pytree, + additional context, and returns an unflattened pytree. + serialized_type_name: A keyword argument used to specify the fully qualified + name used when serializing the tree spec. + to_dumpable_context: An optional keyword argument to custom specify how + to convert the context of the pytree to a custom json dumpable + representation. This is used for json serialization, which is being + used in torch.export right now. + from_dumpable_context: An optional keyword argument to custom specify how + to convert the custom json dumpable representation of the context + back to the original context. This is used for json deserialization, + which is being used in torch.export right now. + """ + warnings.warn( + "torch.utils._pytree._register_pytree_node is deprecated. " + "Please use torch.utils._pytree.register_pytree_node instead.", + stacklevel=2, + ) + + if to_str_fn is not None or maybe_from_str_fn is not None: + warnings.warn( + "to_str_fn and maybe_from_str_fn is deprecated. " + "Please use to_dumpable_context and from_dumpable_context instead." + ) + + _private_register_pytree_node( + cls, + flatten_fn, + unflatten_fn, + serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, + ) + + +def _private_register_pytree_node( + cls: Type[Any], + flatten_fn: FlattenFunc, + unflatten_fn: UnflattenFunc, + *, + serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, +) -> None: + """This is an internal function that is used to register a pytree node type + for the Python pytree only. End-users should use :func:`register_pytree_node` + instead. + """ + with _NODE_REGISTRY_LOCK: + if cls in SUPPORTED_NODES: + # TODO: change this warning to an error after OSS/internal stabilize + warnings.warn( + f"{cls} is already registered as pytree node. " + "Overwriting the previous registration.", + ) + + node_def = NodeDef( + cls, + flatten_fn, + unflatten_fn, + ) + SUPPORTED_NODES[cls] = node_def + + if (to_dumpable_context is None) ^ (from_dumpable_context is None): + raise ValueError( + f"Both to_dumpable_context and from_dumpable_context for {cls} must " + "be None or registered." + ) + + if serialized_type_name is None: + serialized_type_name = f"{cls.__module__}.{cls.__qualname__}" + + serialize_node_def = _SerializeNodeDef( + cls, + serialized_type_name, + to_dumpable_context, + from_dumpable_context, + ) + SUPPORTED_SERIALIZED_TYPES[cls] = serialize_node_def + SERIALIZED_TYPE_TO_PYTHON_TYPE[serialized_type_name] = cls + + +def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]: + return list(d.values()), list(d.keys()) + + +def _dict_unflatten(values: Iterable[Any], context: Context) -> Dict[Any, Any]: + return dict(zip(context, values)) + + +def _list_flatten(d: List[Any]) -> Tuple[List[Any], Context]: + return d, None + + +def _list_unflatten(values: Iterable[Any], context: Context) -> List[Any]: + return list(values) + + +def _tuple_flatten(d: Tuple[Any, ...]) -> Tuple[List[Any], Context]: + return list(d), None + + +def _tuple_unflatten(values: Iterable[Any], context: Context) -> Tuple[Any, ...]: + return tuple(values) + + +def _namedtuple_flatten(d: NamedTuple) -> Tuple[List[Any], Context]: + return list(d), type(d) + + +def _namedtuple_unflatten(values: Iterable[Any], context: Context) -> NamedTuple: + return cast(NamedTuple, context(*values)) + + +def _namedtuple_serialize(context: Context) -> DumpableContext: + json_namedtuple = { + "class_name": context.__name__, + "fields": context._fields, + } + return json_namedtuple + + +def _namedtuple_deserialize(dumpable_context: DumpableContext) -> Context: + class_name = dumpable_context["class_name"] + assert isinstance(class_name, str) + # type: ignore[misc] + context = namedtuple(class_name, dumpable_context["fields"]) + return context + + +def _ordereddict_flatten(d: GenericOrderedDict[Any, Any]) -> Tuple[List[Any], Context]: + return list(d.values()), list(d.keys()) + + +def _ordereddict_unflatten( + values: Iterable[Any], + context: Context, +) -> GenericOrderedDict[Any, Any]: + return OrderedDict((key, value) for key, value in zip(context, values)) + + +_odict_flatten = _ordereddict_flatten +_odict_unflatten = _ordereddict_unflatten + + +def _defaultdict_flatten(d: DefaultDict[Any, Any]) -> Tuple[List[Any], Context]: + values, dict_context = _dict_flatten(d) + return values, [d.default_factory, dict_context] + + +def _defaultdict_unflatten( + values: Iterable[Any], + context: Context, +) -> DefaultDict[Any, Any]: + default_factory, dict_context = context + return defaultdict(default_factory, _dict_unflatten(values, dict_context)) + + +def _defaultdict_serialize(context: Context) -> DumpableContext: + default_factory, dict_context = context + json_defaultdict = { + "default_factory_module": default_factory.__module__, + "default_factory_name": default_factory.__qualname__, + "dict_context": dict_context, + } + return json_defaultdict + + +def _defaultdict_deserialize(dumpable_context: DumpableContext) -> Context: + assert isinstance(dumpable_context, dict) + assert set(dumpable_context) == { + "default_factory_module", + "default_factory_name", + "dict_context", + } + + default_factory_module = dumpable_context["default_factory_module"] + default_factory_name = dumpable_context["default_factory_name"] + assert isinstance(default_factory_module, str) + assert isinstance(default_factory_name, str) + module = importlib.import_module(default_factory_module) + default_factory = getattr(module, default_factory_name) + + dict_context = dumpable_context["dict_context"] + return [default_factory, dict_context] + + +def _deque_flatten(deq: Deque[Any]) -> Tuple[List[Any], Context]: + return list(deq), deq.maxlen + + +def _deque_unflatten(values: Iterable[Any], context: Context) -> Deque[Any]: + return deque(values, maxlen=context) + + +_private_register_pytree_node( + tuple, + _tuple_flatten, + _tuple_unflatten, + serialized_type_name="builtins.tuple", +) +_private_register_pytree_node( + list, + _list_flatten, + _list_unflatten, + serialized_type_name="builtins.list", +) +_private_register_pytree_node( + dict, + _dict_flatten, + _dict_unflatten, + serialized_type_name="builtins.dict", +) +_private_register_pytree_node( + namedtuple, # type: ignore[arg-type] + _namedtuple_flatten, + _namedtuple_unflatten, + serialized_type_name="collections.namedtuple", + to_dumpable_context=_namedtuple_serialize, + from_dumpable_context=_namedtuple_deserialize, +) +_private_register_pytree_node( + OrderedDict, + _ordereddict_flatten, + _ordereddict_unflatten, + serialized_type_name="collections.OrderedDict", +) +_private_register_pytree_node( + defaultdict, + _defaultdict_flatten, + _defaultdict_unflatten, + serialized_type_name="collections.defaultdict", + to_dumpable_context=_defaultdict_serialize, + from_dumpable_context=_defaultdict_deserialize, +) +_private_register_pytree_node( + deque, + _deque_flatten, + _deque_unflatten, + serialized_type_name="collections.deque", +) + + +STANDARD_DICT_TYPES: FrozenSet[type] = frozenset( + {dict, OrderedDict, defaultdict}, +) +BUILTIN_TYPES: FrozenSet[type] = frozenset( + {tuple, list, dict, namedtuple, OrderedDict, + defaultdict, deque}, # type: ignore[arg-type] +) + + +# h/t https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple +def _is_namedtuple_instance(tree: Any) -> bool: + typ = type(tree) + bases = typ.__bases__ + if len(bases) != 1 or bases[0] != tuple: + return False + fields = getattr(typ, "_fields", None) + if not isinstance(fields, tuple): + return False + return all(type(entry) == str for entry in fields) + + +def _get_node_type(tree: Any) -> Any: + if _is_namedtuple_instance(tree): + return namedtuple + return type(tree) + + +# A leaf is defined as anything that is not a Node. +def _is_leaf(tree: PyTree) -> bool: + return _get_node_type(tree) not in SUPPORTED_NODES + + +# A TreeSpec represents the structure of a pytree. It holds: +# "type": the type of root Node of the pytree +# context: some context that is useful in unflattening the pytree +# children_specs: specs for each child of the root Node +# num_leaves: the number of leaves +@dataclasses.dataclass +class TreeSpec: + type: Any + context: Context + children_specs: List["TreeSpec"] + + num_nodes: int = dataclasses.field(init=False) + num_leaves: int = dataclasses.field(init=False) + num_children: int = dataclasses.field(init=False) + + def __post_init__(self) -> None: + self.num_nodes = 1 + \ + sum(spec.num_nodes for spec in self.children_specs) + self.num_leaves = sum(spec.num_leaves for spec in self.children_specs) + self.num_children = len(self.children_specs) + + def __repr__(self, indent: int = 0) -> str: + repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, [" + children_specs_str: str = "" + if self.num_children > 0: + indent += 2 + children_specs_str += self.children_specs[0].__repr__(indent) + children_specs_str += "," if self.num_children > 1 else "" + children_specs_str += ",".join( + [ + "\n" + " " * indent + child.__repr__(indent) + for child in self.children_specs[1:] + ] + ) + repr_suffix: str = f"{children_specs_str}])" + return repr_prefix + repr_suffix + + def is_leaf(self) -> bool: + return self.num_nodes == 1 and self.num_leaves == 1 + + def _flatten_up_to_helper(self, tree: PyTree, subtrees: List[PyTree]) -> None: + if self.is_leaf(): + subtrees.append(tree) + return + + node_type = _get_node_type(tree) + if self.type not in BUILTIN_TYPES: + # Always require custom node types to match exactly + if node_type != self.type: + raise ValueError( + f"Type mismatch; " + f"expected {self.type!r}, but got {node_type!r}.", + ) + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, context = flatten_fn(tree) + if len(child_pytrees) != self.num_children: + raise ValueError( + f"Node arity mismatch; " + f"expected {self.num_children}, but got {len(child_pytrees)}.", + ) + if context != self.context: + raise ValueError( + f"Node context mismatch for custom node type {self.type!r}.", + ) + else: + # For builtin dictionary types, we allow some flexibility + # Otherwise, we require exact matches + both_standard_dict = ( + self.type in STANDARD_DICT_TYPES and node_type in STANDARD_DICT_TYPES + ) + if node_type != self.type and not both_standard_dict: + raise ValueError( + f"Node type mismatch; " + f"expected {self.type!r}, but got {node_type!r}.", + ) + if len(tree) != self.num_children: + raise ValueError( + f"Node arity mismatch; " + f"expected {self.num_children}, but got {len(tree)}.", + ) + + if both_standard_dict: # dictionary types are compatible with each other + dict_context = ( + self.context + if self.type is not defaultdict + # ignore mismatch of `default_factory` for defaultdict + else self.context[1] + ) + expected_keys = dict_context + got_key_set = set(tree) + expected_key_set = set(expected_keys) + if got_key_set != expected_key_set: + missing_keys = expected_key_set.difference(got_key_set) + extra_keys = got_key_set.difference(expected_key_set) + message = "" + if missing_keys: + message += f"; missing key(s): {missing_keys}" + if extra_keys: + message += f"; extra key(s): {extra_keys}" + raise ValueError(f"Node keys mismatch{message}.") + child_pytrees = [tree[key] for key in expected_keys] + else: + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, context = flatten_fn(tree) + if ( + context != self.context + and self.type is not deque # ignore mismatch of `maxlen` for deque + ): + raise ValueError( + f"Node context mismatch for node type {self.type!r}; " + # namedtuple type mismatch + f"expected {self.context!r}, but got {context!r}.", + ) + + for child_pytree, child_spec in zip(child_pytrees, self.children_specs): + child_spec._flatten_up_to_helper(child_pytree, subtrees) + + def flatten_up_to(self, tree: PyTree) -> List[PyTree]: + subtrees: List[PyTree] = [] + self._flatten_up_to_helper(tree, subtrees) + return subtrees + + def unflatten(self, leaves: Iterable[Any]) -> PyTree: + if not isinstance(leaves, (list, tuple)): + leaves = list(leaves) + if len(leaves) != self.num_leaves: + raise ValueError( + f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} " + f"but the spec refers to a pytree that holds {self.num_leaves} " + f"items ({self}).", + ) + if self.is_leaf(): + return leaves[0] + + unflatten_fn = SUPPORTED_NODES[self.type].unflatten_fn + + # Recursively unflatten the children + start = 0 + end = 0 + child_pytrees = [] + for child_spec in self.children_specs: + end += child_spec.num_leaves + child_pytrees.append(child_spec.unflatten(leaves[start:end])) + start = end + + return unflatten_fn(child_pytrees, self.context) + + +class LeafSpec(TreeSpec): + def __init__(self) -> None: + super().__init__(None, None, []) + + def __post_init__(self) -> None: + self.num_nodes = 1 + self.num_leaves = 1 + self.num_children = 0 + + def __repr__(self, indent: int = 0) -> str: + return "*" + + +# All leaves are equivalent, so represent with a single object to save on +# object construction time +_LEAF_SPEC = LeafSpec() + + +def _tree_flatten_helper(tree: PyTree, leaves: List[Any]) -> TreeSpec: + if hasattr(tree, "keys"): # type(tree) == edict: + tree = {**tree} + + if _is_leaf(tree): + leaves.append(tree) + return _LEAF_SPEC + + node_type = _get_node_type(tree) + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, context = flatten_fn(tree) + + # Recursively flatten the children + children_specs = [_tree_flatten_helper( + child, leaves) for child in child_pytrees] + + return TreeSpec(node_type, context, children_specs) + + +def tree_flatten(tree: PyTree) -> Tuple[List[Any], TreeSpec]: + """Flattens a pytree into a list of values and a TreeSpec that can be used + to reconstruct the pytree. + """ + leaves: List[Any] = [] + spec = _tree_flatten_helper(tree, leaves) + return leaves, spec + + +def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree: + """Given a list of values and a TreeSpec, builds a pytree. + This is the inverse operation of `tree_flatten`. + """ + if not isinstance(treespec, TreeSpec): + raise TypeError( + f"tree_unflatten(leaves, treespec): Expected `treespec` to be " + f"instance of TreeSpec but got item of type {type(treespec)}.", + ) + return treespec.unflatten(leaves) + + +def _tree_leaves_helper(tree: PyTree, leaves: List[Any]) -> None: + if _is_leaf(tree): + leaves.append(tree) + return + + node_type = _get_node_type(tree) + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, _ = flatten_fn(tree) + + # Recursively flatten the children + for child in child_pytrees: + _tree_leaves_helper(child, leaves) + + +def tree_leaves(tree: PyTree) -> List[Any]: + """Get a list of leaves of a pytree.""" + leaves: List[Any] = [] + _tree_leaves_helper(tree, leaves) + return leaves + + +def tree_structure(tree: PyTree) -> TreeSpec: + """Get the TreeSpec for a pytree.""" + return tree_flatten(tree)[1] + + +def tree_map(func: Callable[..., Any], tree: PyTree, *rests: PyTree) -> PyTree: + """Map a multi-input function over pytree args to produce a new pytree. + + See also :func:`tree_map_`. + + >>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)}) + {'x': 8, 'y': (43, 65)} + >>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None}) + {'x': False, 'y': (False, False), 'z': True} + + If multiple inputs are given, the structure of the tree is taken from the first input; + subsequent inputs need only have ``tree`` as a prefix: + + >>> tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]]) + [[5, 7, 9], [6, 1, 2]] + + Args: + func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the + corresponding leaves of the pytrees. + tree (pytree): A pytree to be mapped over, with each leaf providing the first positional + argument to function ``func``. + rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as + ``tree`` or has ``tree`` as a prefix. + + Returns: + A new pytree with the same structure as ``tree`` but with the value at each leaf given by + ``func(x, *xs)`` where ``x`` is the value at the corresponding leaf in ``tree`` and ``xs`` + is the tuple of values at corresponding nodes in ``rests``. + """ + leaves, treespec = tree_flatten(tree) + flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests] + return treespec.unflatten(map(func, *flat_args)) + + +def tree_map_(func: Callable[..., Any], tree: PyTree, *rests: PyTree) -> PyTree: + """Like :func:`tree_map`, but do an inplace call on each leaf and return the original tree. + + See also :func:`tree_map`. + + Args: + func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the + corresponding leaves of the pytrees. + tree (pytree): A pytree to be mapped over, with each leaf providing the first positional + argument to function ``func``. + rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as + ``tree`` or has ``tree`` as a prefix. + + Returns: + The original ``tree`` with the value at each leaf is given by the side-effect of function + ``func(x, *xs)`` (not the return value) where ``x`` is the value at the corresponding leaf + in ``tree`` and ``xs`` is the tuple of values at values at corresponding nodes in ``rests``. + """ + leaves, treespec = tree_flatten(tree) + flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests] + deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable + return tree + + +Type2 = Tuple[Type[T], Type[S]] +Type3 = Tuple[Type[T], Type[S], Type[U]] +TypeAny = Union[Type[Any], Tuple[Type[Any], ...]] + +Fn2 = Callable[[Union[T, S]], R] +Fn3 = Callable[[Union[T, S, U]], R] +Fn = Callable[[T], R] +FnAny = Callable[[Any], R] + +MapOnlyFn = Callable[[T], Callable[[Any], Any]] + + +# These specializations help with type inference on the lambda passed to this +# function +@overload +def map_only(__type_or_types: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]: + ... + + +@overload +def map_only(__type_or_types: Type3[T, S, U]) -> MapOnlyFn[Fn3[T, S, U, Any]]: + ... + + +@overload +def map_only(__type_or_types: Type[T]) -> MapOnlyFn[Fn[T, Any]]: + ... + + +# This specialization is needed for the implementations below that call +@overload +def map_only(__type_or_types: TypeAny) -> MapOnlyFn[FnAny[Any]]: + ... + + +def map_only(__type_or_types: TypeAny) -> MapOnlyFn[FnAny[Any]]: + """ + Suppose you are writing a tree_map over tensors, leaving everything + else unchanged. Ordinarily you would have to write: + + def go(t): + if isinstance(t, Tensor): + return ... + else: + return t + + With this function, you only need to write: + + @map_only(Tensor) + def go(t): + return ... + + You can also directly use 'tree_map_only' + """ + + def wrapper(func: Callable[[T], Any]) -> Callable[[Any], Any]: + # @functools.wraps(func) # torch dynamo doesn't support this yet + def wrapped(x: T) -> Any: + if isinstance(x, __type_or_types): + return func(x) + return x + + return wrapped + + return wrapper + + +@overload +def tree_map_only( + __type_or_types: Type[T], + func: Fn[T, Any], + tree: PyTree, +) -> PyTree: + ... + + +@overload +def tree_map_only( + __type_or_types: Type2[T, S], + func: Fn2[T, S, Any], + tree: PyTree, +) -> PyTree: + ... + + +@overload +def tree_map_only( + __type_or_types: Type3[T, S, U], + func: Fn3[T, S, U, Any], + tree: PyTree, +) -> PyTree: + ... + + +def tree_map_only( + __type_or_types: TypeAny, + func: FnAny[Any], + tree: PyTree, +) -> PyTree: + return tree_map(map_only(__type_or_types)(func), tree) + + +@overload +def tree_map_only_( + __type_or_types: Type[T], + func: Fn[T, Any], + tree: PyTree, +) -> PyTree: + ... + + +@overload +def tree_map_only_( + __type_or_types: Type2[T, S], + func: Fn2[T, S, Any], + tree: PyTree, +) -> PyTree: + ... + + +@overload +def tree_map_only_( + __type_or_types: Type3[T, S, U], + func: Fn3[T, S, U, Any], + tree: PyTree, +) -> PyTree: + ... + + +def tree_map_only_( + __type_or_types: TypeAny, + func: FnAny[Any], + tree: PyTree, +) -> PyTree: + return tree_map_(map_only(__type_or_types)(func), tree) + + +def tree_all(pred: Callable[[Any], bool], tree: PyTree) -> bool: + flat_args = tree_leaves(tree) + return all(map(pred, flat_args)) + + +def tree_any(pred: Callable[[Any], bool], tree: PyTree) -> bool: + flat_args = tree_leaves(tree) + return any(map(pred, flat_args)) + + +@overload +def tree_all_only( + __type_or_types: Type[T], + pred: Fn[T, bool], + tree: PyTree, +) -> bool: + ... + + +@overload +def tree_all_only( + __type_or_types: Type2[T, S], + pred: Fn2[T, S, bool], + tree: PyTree, +) -> bool: + ... + + +@overload +def tree_all_only( + __type_or_types: Type3[T, S, U], + pred: Fn3[T, S, U, bool], + tree: PyTree, +) -> bool: + ... + + +def tree_all_only( + __type_or_types: TypeAny, + pred: FnAny[bool], + tree: PyTree, +) -> bool: + flat_args = tree_leaves(tree) + return all(pred(x) for x in flat_args if isinstance(x, __type_or_types)) + + +@overload +def tree_any_only( + __type_or_types: Type[T], + pred: Fn[T, bool], + tree: PyTree, +) -> bool: + ... + + +@overload +def tree_any_only( + __type_or_types: Type2[T, S], + pred: Fn2[T, S, bool], + tree: PyTree, +) -> bool: + ... + + +@overload +def tree_any_only( + __type_or_types: Type3[T, S, U], + pred: Fn3[T, S, U, bool], + tree: PyTree, +) -> bool: + ... + + +def tree_any_only( + __type_or_types: TypeAny, + pred: FnAny[bool], + tree: PyTree, +) -> bool: + flat_args = tree_leaves(tree) + return any(pred(x) for x in flat_args if isinstance(x, __type_or_types)) + + +# Broadcasts a pytree to the provided TreeSpec and returns the flattened +# values. If this is not possible, then this function returns None. +# +# For example, given pytree=0 and spec=TreeSpec(list, None, [LeafSpec(), LeafSpec()]), +# would return [0, 0]. This is useful for part of the vmap implementation: +# a user can pass in vmap(fn, in_dims)(*inputs). `in_dims` should be +# broadcastable to the tree structure of `inputs` and we use +# _broadcast_to_and_flatten to check this. +def _broadcast_to_and_flatten(tree: PyTree, treespec: TreeSpec) -> Optional[List[Any]]: + assert isinstance(treespec, TreeSpec) + + if _is_leaf(tree): + return [tree] * treespec.num_leaves + if isinstance(treespec, LeafSpec): + return None + node_type = _get_node_type(tree) + if node_type != treespec.type: + return None + + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, ctx = flatten_fn(tree) + + # Check if the Node is different from the spec + if len(child_pytrees) != treespec.num_children or ctx != treespec.context: + return None + + # Recursively flatten the children + result: List[Any] = [] + for child, child_spec in zip(child_pytrees, treespec.children_specs): + flat = _broadcast_to_and_flatten(child, child_spec) + if flat is not None: + result += flat + else: + return None + + return result + + +@dataclasses.dataclass +class _TreeSpecSchema: + """ + _TreeSpecSchema is the schema used to serialize the TreeSpec + It contains the following fields: + - type: A string name of the type. null for the case of a LeafSpec. + - context: Any format which is json dumpable + - children_spec: A list of children serialized specs. + """ + + type: Optional[str] + context: DumpableContext + children_spec: List["_TreeSpecSchema"] + + +class _ProtocolFn(NamedTuple): + treespec_to_json: Callable[[TreeSpec], DumpableContext] + json_to_treespec: Callable[[DumpableContext], TreeSpec] + + +_SUPPORTED_PROTOCOLS: Dict[int, _ProtocolFn] = {} + + +def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema: + if isinstance(treespec, LeafSpec): + return _TreeSpecSchema(None, None, []) + + if treespec.type not in SUPPORTED_SERIALIZED_TYPES: + raise NotImplementedError( + f"Serializing {treespec.type} in pytree is not registered.", + ) + + serialize_node_def = SUPPORTED_SERIALIZED_TYPES[treespec.type] + + serialized_type_name = serialize_node_def.serialized_type_name + + if serialized_type_name == NO_SERIALIZED_TYPE_NAME_FOUND: + raise NotImplementedError( + f"No registered serialization name for {treespec.type} found. " + "Please update your _register_pytree_node call with a `serialized_type_name` kwarg." + ) + + if serialize_node_def.to_dumpable_context is None: + try: + serialized_context = json.dumps(treespec.context) + except TypeError as e: + raise TypeError( + "Unable to serialize context. " + "Please make the context json dump-able, or register a " + "custom serializer using _register_pytree_node." + ) from e + else: + serialized_context = serialize_node_def.to_dumpable_context( + treespec.context) + + child_schemas = [_treespec_to_json(child) + for child in treespec.children_specs] + + return _TreeSpecSchema(serialized_type_name, serialized_context, child_schemas) + + +def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec: + if ( + json_schema["type"] is None + and json_schema["context"] is None + and len(json_schema["children_spec"]) == 0 + ): + return LeafSpec() + + if json_schema["type"] not in SERIALIZED_TYPE_TO_PYTHON_TYPE: + raise NotImplementedError( + f'Deserializing {json_schema["type"]} in pytree is not registered.', + ) + + typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[json_schema["type"]] + serialize_node_def = SUPPORTED_SERIALIZED_TYPES[typ] + + if serialize_node_def.from_dumpable_context is None: + try: + context = json.loads(json_schema["context"]) + except TypeError as ex: + raise TypeError( + "Unable to deserialize context. " + "Please make the context json load-able, or register a " + "custom serializer using _register_pytree_node.", + ) from ex + else: + context = serialize_node_def.from_dumpable_context( + json_schema["context"]) + + children_specs = [] + for child_string in json_schema["children_spec"]: + children_specs.append(_json_to_treespec(child_string)) + + return TreeSpec(typ, context, children_specs) + + +_SUPPORTED_PROTOCOLS[1] = _ProtocolFn(_treespec_to_json, _json_to_treespec) + + +def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str: + if not isinstance(treespec, TreeSpec): + raise TypeError( + f"treespec_dumps(treespec, protocol): Expected `treespec` to be instance of " + f"TreeSpec but got item of type {type(treespec)}.", + ) + + if protocol is None: + protocol = DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL + + if protocol in _SUPPORTED_PROTOCOLS: + json_spec = _SUPPORTED_PROTOCOLS[protocol].treespec_to_json(treespec) + else: + raise ValueError( + f"Unknown protocol {protocol}. " + f"Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}", + ) + + str_spec = json.dumps((protocol, dataclasses.asdict(json_spec))) + return str_spec + + +def treespec_loads(serialized: str) -> TreeSpec: + protocol, json_schema = json.loads(serialized) + + if protocol in _SUPPORTED_PROTOCOLS: + return _SUPPORTED_PROTOCOLS[protocol].json_to_treespec(json_schema) + raise ValueError( + f"Unknown protocol {protocol}. " + f"Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}", + ) + + +class _DummyLeaf: + def __repr__(self) -> str: + return "*" + + +def treespec_pprint(treespec: TreeSpec) -> str: + dummy_tree = tree_unflatten( + [_DummyLeaf() for _ in range(treespec.num_leaves)], + treespec, + ) + return repr(dummy_tree) + + +# TODO(angelayi): remove this function after OSS/internal stabilize +def pytree_to_str(treespec: TreeSpec) -> str: + warnings.warn("pytree_to_str is deprecated. Please use treespec_dumps") + return treespec_dumps(treespec) + + +# TODO(angelayi): remove this function after OSS/internal stabilize +def str_to_pytree(json: str) -> TreeSpec: + warnings.warn("str_to_pytree is deprecated. Please use treespec_loads") + return treespec_loads(json) + + +def arg_tree_leaves(*args: PyTree, **kwargs: PyTree) -> List[Any]: + """Get a flat list of arguments to this function + + A slightly faster version of tree_leaves((args, kwargs)) + """ + leaves: List[Any] = [] + for a in args: + _tree_leaves_helper(a, leaves) + for a in kwargs.values(): + _tree_leaves_helper(a, leaves) + return leaves \ No newline at end of file diff --git a/utils.py b/utils/utils.py similarity index 58% rename from utils.py rename to utils/utils.py index 4854509..13e1216 100644 --- a/utils.py +++ b/utils/utils.py @@ -1,6 +1,8 @@ import hashlib import time from pathlib import PurePosixPath, Path, PureWindowsPath +import base64 +import re def windows_to_linux_path(windows_path): return PureWindowsPath(windows_path).as_posix() @@ -17,4 +19,17 @@ def compute_sha256(file_path, chunk_size=1024 ** 2): sha256.update(chunk) print("finish compute sha256 for", file_path, f"time: {time.time() - start}") # Return the hexadecimal digest of the hash - return sha256.hexdigest() \ No newline at end of file + return sha256.hexdigest() + + +def get_alphanumeric_hash(input_string: str) -> str: + # Generate a SHA-256 hash of the input string + sha256_hash = hashlib.sha256(input_string.encode()).digest() + + # Encode the hash in base64 to get a string with [A-Za-z0-9+/=] + base64_hash = base64.b64encode(sha256_hash).decode('ascii') + + # Remove any non-alphanumeric characters (+, /, =) + alphanumeric_hash = re.sub(r'[^a-zA-Z0-9]', '', base64_hash) + + return alphanumeric_hash \ No newline at end of file diff --git a/web/shellagent.js b/web/shellagent.js index e21a7d1..8480e4a 100644 --- a/web/shellagent.js +++ b/web/shellagent.js @@ -1,6 +1,9 @@ import { app } from "../../scripts/app.js"; import { api } from "../../scripts/api.js"; +var __defProp = Object.defineProperty; +var __name = (target, value) => __defProp(target, "name", { value, configurable: true }); + app.registerExtension({ name: "Shellagent.extension", async setup() { @@ -37,11 +40,15 @@ app.registerExtension({ }); }, async beforeRegisterNodeDef(nodeType, nodeData, app) { + if (["ShellAgentPluginOutputText", "ShellAgentPluginOutputFloat", "ShellAgentPluginOutputInteger"].indexOf(nodeData.name) > -1) { + chainCallback(nodeType.prototype, "onNodeCreated", function () { + this.convertWidgetToInput(this.widgets[0]) + }) + } if (["ShellAgentPluginInputText", "ShellAgentPluginInputFloat", "ShellAgentPluginInputInteger"].indexOf(nodeData.name) > -1) { chainCallback(nodeType.prototype, "onNodeCreated", function () { const widget = this.widgets.find(w => w.name === 'choices') - this.addWidget('button', 'manage choices', null, () => { const container = document.createElement("div"); Object.assign(container.style, { @@ -93,7 +100,7 @@ app.registerExtension({ try { arr = JSON.parse(widget.value) } catch { } - } else if(Array.isArray(widget.value)) { + } else if (Array.isArray(widget.value)) { arr = widget.value } @@ -137,12 +144,45 @@ app.registerExtension({ }) } + if (['LoadImage', 'LoadImageMask'].indexOf(nodeData.name) > -1) { + addMenuHandler(nodeType, function (_, options) { + options.unshift({ + content: "Replace with ShellAgent Input Image", + callback: () => { + const node = addNode("ShellAgentPluginInputImage", this, { before: true }); + + const dvn = node.widgets.find(w => w.name === 'default_value') + dvn.value = this.widgets.find(w => w.name === 'image')?.value + + app.graph.links.filter(l => l != null) + .forEach(l => { + const tn = app.graph._nodes_by_id[l.target_id] + node.connect(0, tn, 0) + }) + app.graph.remove(this); + } + }) + }) + } + if (nodeData.name === "ShellAgentPluginInputImage") { if ( nodeData?.input?.required?.default_value?.[1]?.image_upload === true ) { nodeData.input.required.upload = [ "IMAGEUPLOAD", + { widget: "default_value", imageInputName: "default_value", image_upload: true }, + ]; + } + } + + if (nodeData.name === "ShellAgentPluginInputAudio") { + if ( + nodeData?.input?.required?.default_value?.[1]?.audio_upload === true + ) { + nodeData.input.required.audioUI = ["AUDIO_UI"]; + nodeData.input.required.upload = [ + "SHELLAGENT_AUDIOUPLOAD", { widget: "default_value" }, ]; } @@ -151,7 +191,6 @@ app.registerExtension({ if (nodeData.name === "ShellAgentPluginInputVideo") { addUploadWidget(nodeType, nodeData, "default_value"); chainCallback(nodeType.prototype, "onNodeCreated", function () { - // const pathWidget = this.widgets.find((w) => w.name === "video"); const pathWidget = this.widgets.find((w) => w.name === "default_value"); chainCallback(pathWidget, "callback", (value) => { if (!value) { @@ -174,59 +213,238 @@ app.registerExtension({ if (nodeData.name.indexOf('ShellAgentPlugin') === -1) { addMenuHandler(nodeType, function (_, options) { + if (this.widgets) { let toInput = []; - for (const w of this.widgets) { if (["customtext"].indexOf(w.type) > -1) { toInput.push({ - content: `${w.name} <- Input Text`, - callback: () => { - this.convertWidgetToInput(w); - const node = addNode("ShellAgentPluginInputText", this, { before: true }); - const dvn = node.widgets.find(w => w.name === 'default_value') - dvn.value = w.value; - node.connect(0, this, this.inputs.length - 1); - } + content: w.name, + submenu: { + options: [ + { + content: 'Input Text', + callback: () => { + this.convertWidgetToInput(w); + const node = addNode("ShellAgentPluginInputText", this, { before: true }); + const dvn = node.widgets.find(w => w.name === 'default_value') + dvn.value = w.value; + node.connect(0, this, this.inputs.length - 1); + } + } + ] + }, }) } if (["number"].indexOf(w.type) > -1) { toInput.push({ - content: `${w.name} <- Input Interger`, - callback: () => { - this.convertWidgetToInput(w); - const node = addNode("ShellAgentPluginInputInteger", this, { before: true }); - const dvn = node.widgets.find(w => w.name === 'default_value') - dvn.value = w.value; - node.connect(0, this, this.inputs.length - 1); + content: w.name, + submenu: { + options: [ + { + content: 'Input Interger', + callback: () => { + this.convertWidgetToInput(w); + const node = addNode("ShellAgentPluginInputInteger", this, { before: true }); + const dvn = node.widgets.find(w => w.name === 'default_value') + dvn.value = w.value; + node.connect(0, this, this.inputs.length - 1); + } + }, + { + content: 'Input Float', + callback: () => { + this.convertWidgetToInput(w); + const node = addNode("ShellAgentPluginInputFloat", this, { before: true }); + const dvn = node.widgets.find(w => w.name === 'default_value') + dvn.value = w.value; + node.connect(0, this, this.inputs.length - 1); + } + } + ] + } + }) + } + } + if (toInput.length) { + options.unshift({ + content: "Convert to ShellAgent (Input)", + submenu: { + options: toInput + } + }) + } + } + + if (this.outputs) { + let toOutput = []; + for (const o of this.outputs) { + if (o.type === 'IMAGE') { + toOutput.push({ + content: o.name, + submenu: { + options: [ + { + content: 'Save Image', + callback: () => { + const node = addNode("ShellAgentPluginSaveImage", this); + this.connect(0, node, 0); + } + }, + { + content: 'Save Images', + callback: () => { + const node = addNode("ShellAgentPluginSaveImages", this); + this.connect(0, node, 0); + } + } + ] } + }) + } - toInput.push({ - content: `${w.name} <- Input Float`, - callback: () => { - this.convertWidgetToInput(w); - const node = addNode("ShellAgentPluginInputFloat", this, { before: true }); - const dvn = node.widgets.find(w => w.name === 'default_value') - dvn.value = w.value; - node.connect(0, this, this.inputs.length - 1); + if (o.type === 'STRING') { + toOutput.push({ + content: o.name, + submenu: { + options: [ + { + content: `Output Text`, + callback: () => { + const node = addNode("ShellAgentPluginOutputText", this); + this.connect(0, node, 0); + } + }, + { + content: `Output Float`, + callback: () => { + const node = addNode("ShellAgentPluginOutputFloat", this); + this.connect(0, node, 0); + } + }, + { + content: `Output Integer`, + callback: () => { + const node = addNode("ShellAgentPluginOutputInteger", this); + this.connect(0, node, 0); + } + } + ] + } + }) + } + + if (o.type === "VHS_FILENAMES") { + toOutput.push({ + content: o.name, + submenu: { + options: [ + { + content: `Save Video - VHS`, + callback: () => { + const node = addNode("ShellAgentPluginSaveVideoVHS", this); + this.connect(0, node, 0); + } + } + ] } }) } } - if (toInput.length) { + if (toOutput.length) { options.unshift({ - content: "Convert to ShellAgent", + content: "Connect to ShellAgent (Output)", submenu: { - options: toInput + options: toOutput } }) } + } }) } }, + + afterConfigureGraph(missingNodeTypes, app) { + function addIn(type, nodeId) { + if(LiteGraph.slot_types_default_in[type] == null) { + LiteGraph.slot_types_default_in[type] = [] + } + if (LiteGraph.slot_types_default_in[type].indexOf(nodeId) === -1) { + LiteGraph.slot_types_default_in[type].unshift(nodeId) + } + } + + function addOut(type, nodeId) { + if(LiteGraph.slot_types_default_out[type] == null) { + LiteGraph.slot_types_default_out[type] = [] + } + if (LiteGraph.slot_types_default_out[type].indexOf(nodeId) === -1) { + LiteGraph.slot_types_default_out[type].unshift(nodeId) + } + } + + addIn('IMAGE', 'ShellAgentPluginInputImage') + addIn('AUDIO', 'ShellAgentPluginInputAudio') + addOut('IMAGE', 'ShellAgentPluginSaveImage') + addOut('IMAGE', 'ShellAgentPluginSaveImages') + addOut('AUDIO', 'ShellAgentPluginSaveAudios') + addOut('AUDIO', 'ShellAgentPluginSaveAudio') + addOut('STRING', 'ShellAgentPluginOutputInteger') + addOut('STRING', 'ShellAgentPluginOutputFloat') + addOut('STRING', 'ShellAgentPluginOutputText') + }, + getCustomWidgets() { + return { + SHELLAGENT_AUDIOUPLOAD(node, inputName) { + const audioWidget = node.widgets.find( + (w) => w.name === "default_value" + ); + const audioUIWidget = node.widgets.find( + (w) => w.name === "audioUI" + ); + const onAudioWidgetUpdate = /* @__PURE__ */ __name(() => { + audioUIWidget.element.src = api.apiURL( + getResourceURL(...splitFilePath(audioWidget.value)) + ); + }, "onAudioWidgetUpdate"); + if (audioWidget.value) { + onAudioWidgetUpdate(); + } + audioWidget.callback = onAudioWidgetUpdate; + const onGraphConfigured = node.onGraphConfigured; + node.onGraphConfigured = function() { + onGraphConfigured?.apply(this, arguments); + if (audioWidget.value) { + onAudioWidgetUpdate(); + } + }; + const fileInput = document.createElement("input"); + fileInput.type = "file"; + fileInput.accept = "audio/*"; + fileInput.style.display = "none"; + fileInput.onchange = () => { + if (fileInput.files.length) { + uploadFileAudio(audioWidget, audioUIWidget, fileInput.files[0], true); + } + }; + const uploadWidget = node.addWidget( + "button", + inputName, + /* value=*/ + "", + () => { + fileInput.click(); + }, + { serialize: false } + ); + uploadWidget.label = "choose file to upload"; + return { widget: uploadWidget }; + } + }; + } }); function addMenuHandler(nodeType, cb) { @@ -599,4 +817,55 @@ function addLoadVideoCommon(nodeType, nodeData) { } }); }); +} + +function getResourceURL(subfolder, filename, type = "input") { + const params = [ + "filename=" + encodeURIComponent(filename), + "type=" + type, + "subfolder=" + subfolder, + app.getRandParam().substring(1) + ].join("&"); + return `/view?${params}`; +} + +function splitFilePath(path) { + const folder_separator = path.lastIndexOf("/"); + if (folder_separator === -1) { + return ["", path]; + } + return [ + path.substring(0, folder_separator), + path.substring(folder_separator + 1) + ]; +} + +async function uploadFileAudio(audioWidget, audioUIWidget, file2, updateNode, pasted = false) { + try { + const body = new FormData(); + body.append("image", file2); + if (pasted) body.append("subfolder", "pasted"); + const resp = await api.fetchApi("/upload/image", { + method: "POST", + body + }); + if (resp.status === 200) { + const data = await resp.json(); + let path = data.name; + if (data.subfolder) path = data.subfolder + "/" + path; + if (!audioWidget.options.values.includes(path)) { + audioWidget.options.values.push(path); + } + if (updateNode) { + audioUIWidget.element.src = api.apiURL( + getResourceURL(...splitFilePath(path)) + ); + audioWidget.value = path; + } + } else { + window.alert(resp.status + " - " + resp.statusText); + } + } catch (error) { + window.alert(error); + } } \ No newline at end of file