diff --git a/comfy_extras/nodes_rebatch.py b/comfy_extras/nodes_rebatch.py index 0a9daf272761..46c9c3105bc6 100644 --- a/comfy_extras/nodes_rebatch.py +++ b/comfy_extras/nodes_rebatch.py @@ -3,11 +3,10 @@ class LatentRebatch: @classmethod def INPUT_TYPES(s): - return {"required": { "latents": ("LATENT",), + return {"required": { "latents": ("LATENT", { "is_list": True }), "batch_size": ("INT", {"default": 1, "min": 1, "max": 64}), }} - RETURN_TYPES = ("LATENT",) - INPUT_IS_LIST = True + RETURN_TYPES = ("LATENT", ) OUTPUT_IS_LIST = (True, ) FUNCTION = "rebatch" @@ -54,8 +53,6 @@ def cat_batch(batch1, batch2): return result def rebatch(self, latents, batch_size): - batch_size = batch_size[0] - output_list = [] current_batch = (None, None, None) processed = 0 @@ -105,4 +102,4 @@ def rebatch(self, latents, batch_size): NODE_DISPLAY_NAME_MAPPINGS = { "RebatchLatents": "Rebatch Latents", -} \ No newline at end of file +} diff --git a/execution.py b/execution.py index 218a84c36df8..989c1a5b55fe 100644 --- a/execution.py +++ b/execution.py @@ -13,21 +13,44 @@ import comfy.model_management +def slice_lists_into_dict(d, i): + """ + get a slice of inputs, repeat last input when list isn't long enough + d={ "seed": [ 1, 2, 3 ], "steps": [ 4, 8 ] }, i=2 -> { "seed": 3, "steps": 8 } + """ + d_new = {} + for k, v in d.items(): + d_new[k] = v[i if len(v) > i else -1] + return d_new + def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}): valid_inputs = class_def.INPUT_TYPES() input_data_all = {} + required = valid_inputs.get("required", {}) + optional = valid_inputs.get("optional", {}) for x in inputs: input_data = inputs[x] - if isinstance(input_data, list): - input_unique_id = input_data[0] - output_index = input_data[1] + input_type = input_data["type"] + + input_def = required.get(x) + if input_def is None: + input_def = optional.get(x) + + use_value_as_list = input_def is not None and len(input_def) > 1 and input_def[1].get("is_list", False) + + if input_type == "link": + input_unique_id = input_data["origin_id"] + output_index = input_data["origin_slot"] if input_unique_id not in outputs: return None obj = outputs[input_unique_id][output_index] + if use_value_as_list: + obj = [obj] input_data_all[x] = obj else: - if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]): - input_data_all[x] = [input_data] + value = input_data["value"] + if input_def is not None: + input_data_all[x] = [value] if "hidden" in valid_inputs: h = valid_inputs["hidden"] @@ -39,37 +62,23 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da input_data_all[x] = [extra_data['extra_pnginfo']] if h[x] == "UNIQUE_ID": input_data_all[x] = [unique_id] + return input_data_all def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): - # check if node wants the lists - intput_is_list = False - if hasattr(obj, "INPUT_IS_LIST"): - intput_is_list = obj.INPUT_IS_LIST - - max_len_input = max([len(x) for x in input_data_all.values()]) - - # get a slice of inputs, repeat last input when list isn't long enough - def slice_dict(d, i): - d_new = dict() - for k,v in d.items(): - d_new[k] = v[i if len(v) > i else -1] - return d_new - results = [] - if intput_is_list: + max_len_input = max([len(x) for x in input_data_all.values()]) + + for i in range(max_len_input): if allow_interrupt: nodes.before_node_execution() - results.append(getattr(obj, func)(**input_data_all)) - else: - for i in range(max_len_input): - if allow_interrupt: - nodes.before_node_execution() - results.append(getattr(obj, func)(**slice_dict(input_data_all, i))) + + args = slice_lists_into_dict(input_data_all, i) + results.append(getattr(obj, func)(**args)) + return results def get_output_data(obj, input_data_all): - results = [] uis = [] return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True) @@ -120,10 +129,10 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute for x in inputs: input_data = inputs[x] + input_type = input_data["type"] - if isinstance(input_data, list): - input_unique_id = input_data[0] - output_index = input_data[1] + if input_type == "link": + input_unique_id = input_data["origin_id"] if input_unique_id not in outputs: result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui) if result[0] is not True: @@ -192,9 +201,9 @@ def recursive_will_execute(prompt, outputs, current_item): for x in inputs: input_data = inputs[x] - if isinstance(input_data, list): - input_unique_id = input_data[0] - output_index = input_data[1] + input_type = input_data["type"] + if input_type == "link": + input_unique_id = input_data["origin_id"] if input_unique_id not in outputs: will_execute += recursive_will_execute(prompt, outputs, input_unique_id) @@ -235,10 +244,10 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item elif inputs == old_prompt[unique_id]['inputs']: for x in inputs: input_data = inputs[x] + input_type = input_data["type"] - if isinstance(input_data, list): - input_unique_id = input_data[0] - output_index = input_data[1] + if input_type == "link": + input_unique_id = input_data["origin_id"] if input_unique_id in outputs: to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id) else: @@ -366,6 +375,150 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): comfy.model_management.soft_empty_cache() +def validate_link(prompt, x, val, info, validated): + type_input = info[0] + + o_id = val.get("origin_id", None) + o_slot = val.get("origin_slot", None) + + if o_id is None or o_slot is None: + error = { + "type": "bad_linked_input", + "message": "Bad linked input, must be a dictionary like { type: 'link', origin_id: 1, origin_slot: 1 }", + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val + } + } + return (False, error) + + o_class_type = prompt[o_id]['class_type'] + r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES + if r[o_slot] != type_input: + received_type = r[val[1]] + details = f"{x}, {received_type} != {type_input}" + error = { + "type": "return_type_mismatch", + "message": "Return type mismatch between linked nodes", + "details": details, + "extra_info": { + "input_name": x, + "input_config": info, + "received_type": received_type, + "linked_node": val + } + } + return (False, error) + try: + r = validate_inputs(prompt, o_id, validated) + if r[0] is False: + # `r` will be set in `validated[o_id]` already + return (False, None) + except Exception as ex: + typ, _, tb = sys.exc_info() + exception_type = full_type_name(typ) + reasons = [{ + "type": "exception_during_inner_validation", + "message": "Exception when validating inner node", + "details": str(ex), + "extra_info": { + "input_name": x, + "input_config": info, + "exception_message": str(ex), + "exception_type": exception_type, + "traceback": traceback.format_tb(tb), + "linked_node": val, + "linked_node_inputs": prompt[o_id] + } + }] + validated[o_id] = (False, reasons, o_id) + return (False, None) + + return (True, val) + + +def validate_value(inputs, unique_id, x, val, info, obj_class): + type_input = info[0] + result_val = val + + try: + if type_input == "INT": + result_val = int(val) + if type_input == "FLOAT": + result_val = float(val) + if type_input == "STRING": + result_val = str(val) + except Exception as ex: + error = { + "type": "invalid_input_type", + "message": f"Failed to convert an input value to a {type_input} value", + "details": f"{x}, {val}, {ex}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + "exception_message": str(ex) + } + } + return (False, error) + + if len(info) > 1: + if "min" in info[1] and val < info[1]["min"]: + error = { + "type": "value_smaller_than_min", + "message": "Value {} smaller than min of {}".format(val, info[1]["min"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + return (False, error) + if "max" in info[1] and val > info[1]["max"]: + error = { + "type": "value_bigger_than_max", + "message": "Value {} bigger than max of {}".format(val, info[1]["max"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + return (False, error) + else: + # Validate combo widget + if isinstance(type_input, list): + if val not in type_input: + input_config = info + list_info = "" + + # Don't send back gigantic lists like if they're lots of + # scanned model filepaths + if len(type_input) > 20: + list_info = f"(list of length {len(type_input)})" + input_config = None + else: + list_info = str(type_input) + + error = { + "type": "value_not_in_list", + "message": "Value not in list", + "details": f"{x}: '{val}' not in {list_info}", + "extra_info": { + "input_name": x, + "input_config": input_config, + "received_value": val, + } + } + return (False, error) + + return (True, result_val) + + def validate_inputs(prompt, item, validated): unique_id = item if unique_id in validated: @@ -396,168 +549,84 @@ def validate_inputs(prompt, item, validated): val = inputs[x] info = required_inputs[x] - type_input = info[0] - if isinstance(val, list): - if len(val) != 2: - error = { - "type": "bad_linked_input", - "message": "Bad linked input, must be a length-2 list of [node_id, slot_index]", - "details": f"{x}", - "extra_info": { - "input_name": x, - "input_config": info, - "received_value": val - } + + input_type = None + if isinstance(val, dict): + input_type = val.get("type", None) + + if input_type not in ["link", "value"]: + error = { + "type": "bad_input_format", + "message": "Bad input format, must be a dictionary with 'type' set to 'link' or 'value'", + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val } - errors.append(error) + } + errors.append(error) + continue + + if input_type == "link": + result = validate_link(prompt, x, val, info, validated) + if result[0] is False: + valid = False + if result[1] is not None: + errors.append(result[1]) continue - o_id = val[0] - o_class_type = prompt[o_id]['class_type'] - r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES - if r[val[1]] != type_input: - received_type = r[val[1]] - details = f"{x}, {received_type} != {type_input}" + inputs[x] = result[1] + + elif input_type == "value": + inner_val = val.get("value", None) + if inner_val is None: error = { - "type": "return_type_mismatch", - "message": "Return type mismatch between linked nodes", - "details": details, + "type": "bad_value_input", + "message": "Bad value input, must be a dictionary like { type: 'value', value: 42 }", + "details": f"{x}, {val}", "extra_info": { "input_name": x, "input_config": info, - "received_type": received_type, - "linked_node": val + "received_value": val, } } - errors.append(error) - continue - try: - r = validate_inputs(prompt, o_id, validated) - if r[0] is False: - # `r` will be set in `validated[o_id]` already - valid = False - continue - except Exception as ex: - typ, _, tb = sys.exc_info() - valid = False - exception_type = full_type_name(typ) - reasons = [{ - "type": "exception_during_inner_validation", - "message": "Exception when validating inner node", - "details": str(ex), - "extra_info": { - "input_name": x, - "input_config": info, - "exception_message": str(ex), - "exception_type": exception_type, - "traceback": traceback.format_tb(tb), - "linked_node": val - } - }] - validated[o_id] = (False, reasons, o_id) + return (False, error) + + result = validate_value(inputs, unique_id, x, inner_val, info, obj_class) + + if result[0] is False: + errors.append(result[1]) continue - else: - try: - if type_input == "INT": - val = int(val) - inputs[x] = val - if type_input == "FLOAT": - val = float(val) - inputs[x] = val - if type_input == "STRING": - val = str(val) - inputs[x] = val - except Exception as ex: + + inputs[x] = { "type": "value", "value": result[1] } + + if hasattr(obj_class, "VALIDATE_INPUTS"): + input_data_all = get_input_data(inputs, obj_class, unique_id) + #ret = obj_class.VALIDATE_INPUTS(**input_data_all) + ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS") + for i, r in enumerate(ret): + if r is not True: + details = "" + if r is not False: + details += str(r) + + input_data_formatted = {} + if input_data_all is not None: + input_data_formatted = {} + for name, inputList in input_data_all.items(): + input_data_formatted[name] = [format_value(x) for x in inputList] + error = { - "type": "invalid_input_type", - "message": f"Failed to convert an input value to a {type_input} value", - "details": f"{x}, {val}, {ex}", + "type": "custom_validation_failed", + "message": "Custom validation failed for node", + "details": details, "extra_info": { - "input_name": x, "input_config": info, - "received_value": val, - "exception_message": str(ex) + "received_inputs": input_data_formatted, } } errors.append(error) - continue - - if len(info) > 1: - if "min" in info[1] and val < info[1]["min"]: - error = { - "type": "value_smaller_than_min", - "message": "Value {} smaller than min of {}".format(val, info[1]["min"]), - "details": f"{x}", - "extra_info": { - "input_name": x, - "input_config": info, - "received_value": val, - } - } - errors.append(error) - continue - if "max" in info[1] and val > info[1]["max"]: - error = { - "type": "value_bigger_than_max", - "message": "Value {} bigger than max of {}".format(val, info[1]["max"]), - "details": f"{x}", - "extra_info": { - "input_name": x, - "input_config": info, - "received_value": val, - } - } - errors.append(error) - continue - - if hasattr(obj_class, "VALIDATE_INPUTS"): - input_data_all = get_input_data(inputs, obj_class, unique_id) - #ret = obj_class.VALIDATE_INPUTS(**input_data_all) - ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS") - for i, r in enumerate(ret): - if r is not True: - details = f"{x}" - if r is not False: - details += f" - {str(r)}" - - error = { - "type": "custom_validation_failed", - "message": "Custom validation failed for node", - "details": details, - "extra_info": { - "input_name": x, - "input_config": info, - "received_value": val, - } - } - errors.append(error) - continue - else: - if isinstance(type_input, list): - if val not in type_input: - input_config = info - list_info = "" - - # Don't send back gigantic lists like if they're lots of - # scanned model filepaths - if len(type_input) > 20: - list_info = f"(list of length {len(type_input)})" - input_config = None - else: - list_info = str(type_input) - - error = { - "type": "value_not_in_list", - "message": "Value not in list", - "details": f"{x}: '{val}' not in {list_info}", - "extra_info": { - "input_name": x, - "input_config": input_config, - "received_value": val, - } - } - errors.append(error) - continue if len(errors) > 0 or valid is not True: ret = (False, errors, unique_id) @@ -644,7 +713,7 @@ def validate_prompt(prompt): node_errors[node_id]["dependent_outputs"].append(o) print("Output will be ignored") - if len(good_outputs) == 0: + if len(good_outputs) == 0 or node_errors: errors_list = [] for o, errors in errors: for error in errors: diff --git a/nodes.py b/nodes.py index b057504edae3..2bc76a5b3f9f 100644 --- a/nodes.py +++ b/nodes.py @@ -1083,41 +1083,84 @@ class LoadImage: @classmethod def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() - files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] + output_dir = folder_paths.get_output_directory() + file_dict = {} + file_dict["input"] = sorted(f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))) + file_dict["output"] = sorted(f for f in os.listdir(output_dir) if os.path.isfile(os.path.join(output_dir, f))) return {"required": - {"image": (sorted(files), )}, + {"images": ("MULTIIMAGEUPLOAD", { "filepaths": file_dict } )}, } CATEGORY = "image" RETURN_TYPES = ("IMAGE", "MASK") - FUNCTION = "load_image" - def load_image(self, image): - image_path = folder_paths.get_annotated_filepath(image) - i = Image.open(image_path) - i = ImageOps.exif_transpose(i) - image = i.convert("RGB") - image = np.array(image).astype(np.float32) / 255.0 - image = torch.from_numpy(image)[None,] - if 'A' in i.getbands(): - mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 - mask = 1. - torch.from_numpy(mask) - else: - mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") - return (image, mask) + FUNCTION = "load_images" + + def load_images(self, images): + output_images = [] + output_masks = [] + + loaded_images = [] + + for idx in range(len(images)): + image_path = folder_paths.get_annotated_filepath(images[idx]) + i = Image.open(image_path) + i = ImageOps.exif_transpose(i) + loaded_images.append(i) + + min_size = float('inf') + min_image = None + + for image in loaded_images: + size = image.size[0] * image.size[1] + if size < min_size: + min_size = size + min_image = image + + for idx in range(len(images)): + i = loaded_images[idx] + + if i != min_image: + i = i.resize(min_image.size) + + image = i.convert("RGB") + + image = np.array(image).astype(np.float32) / 255.0 + image = torch.from_numpy(image)[None,] + if 'A' in i.getbands(): + mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 + mask = 1. - torch.from_numpy(mask) + else: + mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") + + output_images.append(image) + output_masks.append(mask) + + return (torch.cat(output_images), torch.cat(output_masks), ) @classmethod - def IS_CHANGED(s, image): - image_path = folder_paths.get_annotated_filepath(image) - m = hashlib.sha256() - with open(image_path, 'rb') as f: - m.update(f.read()) - return m.digest().hex() + def IS_CHANGED(s, images): + hashes = [] + + for image in images: + image_path = folder_paths.get_annotated_filepath(image) + m = hashlib.sha256() + with open(image_path, 'rb') as f: + m.update(f.read()) + hashes.append(m.digest().hex()) + + return hashes @classmethod - def VALIDATE_INPUTS(s, image): - if not folder_paths.exists_annotated_filepath(image): - return "Invalid image file: {}".format(image) + def VALIDATE_INPUTS(s, images): + invalid = [] + + for image in images: + if not folder_paths.exists_annotated_filepath(image): + invalid.append(image) + + if len(invalid) > 0: + return "Invalid image file(s): {}".format(", ".join(invalid)) return True @@ -1193,7 +1236,6 @@ def upscale(self, image, upscale_method, width, height, crop): return (s,) class ImageInvert: - @classmethod def INPUT_TYPES(s): return {"required": { "image": ("IMAGE",)}} @@ -1209,7 +1251,6 @@ def invert(self, image): class ImagePadForOutpaint: - @classmethod def INPUT_TYPES(s): return { diff --git a/web/extensions/core/uploadImage.js b/web/extensions/core/uploadImage.js index 45fabb78ed74..614926f12882 100644 --- a/web/extensions/core/uploadImage.js +++ b/web/extensions/core/uploadImage.js @@ -5,8 +5,10 @@ import { app } from "/scripts/app.js"; app.registerExtension({ name: "Comfy.UploadImage", async beforeRegisterNodeDef(nodeType, nodeData, app) { - if (nodeData.name === "LoadImage" || nodeData.name === "LoadImageMask") { + switch (nodeData.name) { + case "LoadImageMask": nodeData.input.required.upload = ["IMAGEUPLOAD"]; + break; } }, }); diff --git a/web/scripts/app.js b/web/scripts/app.js index 385a54579552..f329ab1318bb 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1110,9 +1110,11 @@ export class ComfyApp { for (const inputName in inputs) { const inputData = inputs[inputName]; const type = inputData[0]; + const options = inputData[1] || {}; + const inputShape = options.is_list ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE; if(inputData[1]?.forceInput) { - this.addInput(inputName, type); + this.addInput(inputName, type, { shape: inputShape }); } else { if (Array.isArray(type)) { // Enums @@ -1125,7 +1127,7 @@ export class ComfyApp { Object.assign(config, widgets[type](this, inputName, inputData, app) || {}); } else { // Node connection inputs - this.addInput(inputName, type); + this.addInput(inputName, type, { shape: inputShape }); } } } @@ -1133,7 +1135,7 @@ export class ComfyApp { for (const o in nodeData["output"]) { const output = nodeData["output"][o]; const outputName = nodeData["output_name"][o] || output; - const outputShape = nodeData["output_is_list"][o] ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE ; + const outputShape = nodeData["output_is_list"][o] ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE; this.addOutput(outputName, output, { shape: outputShape }); } @@ -1312,7 +1314,8 @@ export class ComfyApp { for (const i in widgets) { const widget = widgets[i]; if (!widget.options || widget.options.serialize !== false) { - inputs[widget.name] = widget.serializeValue ? await widget.serializeValue(n, i) : widget.value; + const value = widget.serializeValue ? await widget.serializeValue(n, i) : widget.value; + inputs[widget.name] = { type: "value", value } } } } @@ -1332,7 +1335,11 @@ export class ComfyApp { } if (link) { - inputs[node.inputs[i].name] = [String(link.origin_id), parseInt(link.origin_slot)]; + inputs[node.inputs[i].name] = { + type: "link", + origin_id: String(link.origin_id), + origin_slot: parseInt(link.origin_slot) + }; } } } @@ -1376,6 +1383,9 @@ export class ComfyApp { message += "\n" + nodeError.class_type + ":" for (const errorReason of nodeError.errors) { message += "\n - " + errorReason.message + ": " + errorReason.details + if (errorReason.extra_info?.traceback) { + message += "\n" + errorReason.extra_info.traceback.join("") + } } } return message diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index dfa26aef430e..53b3d5c9b050 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -246,81 +246,261 @@ function addMultilineWidget(node, name, opts, app) { return { minWidth: 400, minHeight: 200, widget }; } -export const ComfyWidgets = { - "INT:seed": seedWidget, - "INT:noise_seed": seedWidget, - FLOAT(node, inputName, inputData) { - const { val, config } = getNumberDefaults(inputData, 0.5); - return { widget: node.addWidget("number", inputName, val, () => {}, config) }; - }, - INT(node, inputName, inputData) { - const { val, config } = getNumberDefaults(inputData, 1); - Object.assign(config, { precision: 0 }); - return { - widget: node.addWidget( - "number", - inputName, - val, - function (v) { - const s = this.options.step / 10; - this.value = Math.round(v / s) * s; - }, - config - ), +const FLOAT = (node, inputName, inputData) => { + const { val, config } = getNumberDefaults(inputData, 0.5); + return { widget: node.addWidget("number", inputName, val, () => {}, config) }; +} + +const INT = (node, inputName, inputData) => { + const { val, config } = getNumberDefaults(inputData, 1); + Object.assign(config, { precision: 0 }); + return { + widget: node.addWidget( + "number", + inputName, + val, + function (v) { + const s = this.options.step / 10; + this.value = Math.round(v / s) * s; + }, + config + ), + }; +} + +const STRING = (node, inputName, inputData, app) => { + const defaultVal = inputData[1].default || ""; + const multiline = !!inputData[1].multiline; + + if (multiline) { + return addMultilineWidget(node, inputName, { defaultVal, ...inputData[1] }, app); + } else { + return { widget: node.addWidget("text", inputName, defaultVal, () => {}, {}) }; + } +} + +const COMBO = (node, inputName, inputData) => { + const type = inputData[0]; + let defaultValue = type[0]; + let options = inputData[1] || {} + if (options.default) { + defaultValue = options.default + } + + if (options.is_list) { + defaultValue = [defaultValue] + const widget = node.addWidget("text", inputName, defaultValue, () => {}, { values: type }) + widget.disabled = true; + return { widget }; + } + else { + return { widget: node.addWidget("combo", inputName, defaultValue, () => {}, { values: type }) }; + } +} + +const IMAGEUPLOAD = (node, inputName, inputData, app) => { + const imageWidget = node.widgets.find((w) => w.name === "image"); + let uploadWidget; + + function showImage(name) { + const img = new Image(); + img.onload = () => { + node.imgs = [img]; + app.graph.setDirtyCanvas(true); }; - }, - STRING(node, inputName, inputData, app) { - const defaultVal = inputData[1].default || ""; - const multiline = !!inputData[1].multiline; - - if (multiline) { - return addMultilineWidget(node, inputName, { defaultVal, ...inputData[1] }, app); - } else { - return { widget: node.addWidget("text", inputName, defaultVal, () => {}, {}) }; + let folder_separator = name.lastIndexOf("/"); + let subfolder = ""; + if (folder_separator > -1) { + subfolder = name.substring(0, folder_separator); + name = name.substring(folder_separator + 1); } - }, - COMBO(node, inputName, inputData) { - const type = inputData[0]; - let defaultValue = type[0]; - if (inputData[1] && inputData[1].default) { - defaultValue = inputData[1].default; + img.src = `/view?filename=${name}&type=input&subfolder=${subfolder}`; + node.setSizeForImage?.(); + } + + var default_value = imageWidget.value; + Object.defineProperty(imageWidget, "value", { + set : function(value) { + this._real_value = value; + }, + + get : function() { + let value = ""; + if (this._real_value) { + value = this._real_value; + } else { + return default_value; + } + + if (value.filename) { + let real_value = value; + value = ""; + if (real_value.subfolder) { + value = real_value.subfolder + "/"; + } + + value += real_value.filename; + + if(real_value.type && real_value.type !== "input") + value += ` [${real_value.type}]`; + } + return value; } - return { widget: node.addWidget("combo", inputName, defaultValue, () => {}, { values: type }) }; - }, - IMAGEUPLOAD(node, inputName, inputData, app) { - const imageWidget = node.widgets.find((w) => w.name === "image"); - let uploadWidget; - - function showImage(name) { - const img = new Image(); - img.onload = () => { - node.imgs = [img]; - app.graph.setDirtyCanvas(true); - }; + }); + + // Add our own callback to the combo widget to render an image when it changes + const cb = node.callback; + imageWidget.callback = function () { + showImage(imageWidget.value); + if (cb) { + return cb.apply(this, arguments); + } + }; + + // On load if we have a value then render the image + // The value isnt set immediately so we need to wait a moment + // No change callbacks seem to be fired on initial setting of the value + requestAnimationFrame(() => { + if (imageWidget.value) { + showImage(imageWidget.value); + } + }); + + async function uploadFile(file, updateNode) { + try { + // Wrap file in formdata so it includes filename + const body = new FormData(); + body.append("image", file); + const resp = await fetch("/upload/image", { + method: "POST", + body, + }); + + if (resp.status === 200) { + const data = await resp.json(); + // Add the file as an option and update the widget value + if (!imageWidget.options.values.includes(data.name)) { + imageWidget.options.values.push(data.name); + } + + if (updateNode) { + showImage(data.name); + + imageWidget.value = data.name; + } + } else { + alert(resp.status + " - " + resp.statusText); + } + } catch (error) { + alert(error); + } + } + + const fileInput = document.createElement("input"); + Object.assign(fileInput, { + type: "file", + accept: "image/jpeg,image/png,image/webp", + style: "display: none", + onchange: async () => { + if (fileInput.files.length) { + await uploadFile(fileInput.files[0], true); + } + }, + }); + document.body.append(fileInput); + + // Create the button widget for selecting the files + uploadWidget = node.addWidget("button", "choose file to upload", "image", () => { + fileInput.value = null; + fileInput.click(); + }, { serialize: false }); + + // Add handler to check if an image is being dragged over our node + node.onDragOver = function (e) { + if (e.dataTransfer && e.dataTransfer.items) { + const image = [...e.dataTransfer.items].find((f) => f.kind === "file" && f.type.startsWith("image/")); + return !!image; + } + + return false; + }; + + // On drop upload files + node.onDragDrop = function (e) { + console.log("onDragDrop called"); + let handled = false; + for (const file of e.dataTransfer.files) { + if (file.type.startsWith("image/")) { + uploadFile(file, !handled); // Dont await these, any order is fine, only update on first one + handled = true; + } + } + + return handled; + }; + + return { widget: uploadWidget }; +} + +async function loadImageAsync(imageURL) { + return new Promise((resolve) => { + const e = new Image(); + e.setAttribute('crossorigin', 'anonymous'); + e.addEventListener("load", () => { resolve(e); }); + e.src = imageURL; + return e; + }); +} + +const MULTIIMAGEUPLOAD = (node, inputName, inputData, app) => { + let filepaths = { input: [], output: [] } + + if (inputData[1] && inputData[1].filepaths) { + filepaths = inputData[1].filepaths + } + + const update = function(v) { + this.value = v + } + + const imagesWidget = node.addWidget("combo", inputName, inputData, update, { values: filepaths["input"] }) + imagesWidget._filepaths = filepaths + imagesWidget._entries = filepaths["input"] + + async function showImages(names) { + node.imgs = [] + + for (const name of names) { let folder_separator = name.lastIndexOf("/"); let subfolder = ""; if (folder_separator > -1) { subfolder = name.substring(0, folder_separator); name = name.substring(folder_separator + 1); } - img.src = `/view?filename=${name}&type=input&subfolder=${subfolder}${app.getPreviewFormatParam()}`; + const src = `/view?filename=${name}&type=input&subfolder=${subfolder}${app.getPreviewFormatParam()}`; + const img = await loadImageAsync(src); + node.imgs.push(img) + node.imageIndex = null; node.setSizeForImage?.(); + app.graph.setDirtyCanvas(true); } + } - var default_value = imageWidget.value; - Object.defineProperty(imageWidget, "value", { - set : function(value) { - this._real_value = value; - }, + var default_value = imagesWidget.value; + Object.defineProperty(imagesWidget, "value", { + set : function(value) { + if (typeof value === "string") { + value = [value] + } + this._real_value = value; + }, - get : function() { - let value = ""; - if (this._real_value) { - value = this._real_value; - } else { - return default_value; - } + get : function() { + this._real_value ||= [] + const result = [] + + for (const value of this._real_value) { if (value.filename) { let real_value = value; value = ""; @@ -333,29 +513,36 @@ export const ComfyWidgets = { if(real_value.type && real_value.type !== "input") value += ` [${real_value.type}]`; } - return value; + + result.push(value) } - }); - // Add our own callback to the combo widget to render an image when it changes - const cb = node.callback; - imageWidget.callback = function () { - showImage(imageWidget.value); + this._real_value = result + return this._real_value; + } + }); + + // Add our own callback to the combo widget to render an image when it changes + const cb = node.callback; + imagesWidget.callback = () => { + showImages(imagesWidget.value).then(() => { if (cb) { return cb.apply(this, arguments); } - }; + }) + }; - // On load if we have a value then render the image - // The value isnt set immediately so we need to wait a moment - // No change callbacks seem to be fired on initial setting of the value - requestAnimationFrame(() => { - if (imageWidget.value) { - showImage(imageWidget.value); - } - }); + // On load if we have a value then render the image + // The value isnt set immediately so we need to wait a moment + // No change callbacks seem to be fired on initial setting of the value + requestAnimationFrame(async () => { + if (Array.isArray(imagesWidget.value) && imagesWidget.value.length > 0) { + await showImages(imagesWidget.value); + } + }); - async function uploadFile(file, updateNode) { + async function uploadFiles(files, updateNode) { + for (const file of files) { try { // Wrap file in formdata so it includes filename const body = new FormData(); @@ -367,15 +554,8 @@ export const ComfyWidgets = { if (resp.status === 200) { const data = await resp.json(); - // Add the file as an option and update the widget value - if (!imageWidget.options.values.includes(data.name)) { - imageWidget.options.values.push(data.name); - } - if (updateNode) { - showImage(data.name); - - imageWidget.value = data.name; + imagesWidget.value.push(data.name) } } else { alert(resp.status + " - " + resp.statusText); @@ -385,49 +565,260 @@ export const ComfyWidgets = { } } - const fileInput = document.createElement("input"); - Object.assign(fileInput, { - type: "file", - accept: "image/jpeg,image/png,image/webp", - style: "display: none", - onchange: async () => { - if (fileInput.files.length) { - await uploadFile(fileInput.files[0], true); - } - }, - }); - document.body.append(fileInput); + if (updateNode) { + await showImages(imagesWidget.value); + } + } + + const fileInput = document.createElement("input"); + Object.assign(fileInput, { + type: "file", + multiple: "multiple", + accept: "image/jpeg,image/png,image/webp", + style: "display: none", + onchange: async () => { + if (fileInput.files.length) { + await uploadFiles(fileInput.files, true); + } + }, + }); + document.body.append(fileInput); + + // Create the button widget for selecting the files + const pickWidget = node.addWidget("button", "pick files from ComfyUI folders", "images", () => { + const graphCanvas = LiteGraph.LGraphCanvas.active_canvas + if (graphCanvas == null) + return; - // Create the button widget for selecting the files - uploadWidget = node.addWidget("button", "choose file to upload", "image", () => { - fileInput.click(); + if (imagesWidget.panel != null) + return + + imagesWidget.panel = graphCanvas.createPanel("Pick Images", { closable: true }); + imagesWidget.panel.onClose = () => { + imagesWidget.panel = null; + } + imagesWidget.panel.node = node; + imagesWidget.panel.classList.add("multiimageupload_dialog"); + const swap = (arr, i, j) => { + const temp = arr[i]; + arr[i] = arr[j]; + arr[j] = temp; + } + + const rootHtml = ` +