diff --git a/comfy/model_management.py b/comfy/model_management.py index 0e0e96672c53..0f5966371ccf 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -796,6 +796,8 @@ def archive_model_dtypes(model): for name, module in model.named_modules(): for param_name, param in module.named_parameters(recurse=False): setattr(module, f"{param_name}_comfy_model_dtype", param.dtype) + for buf_name, buf in module.named_buffers(recurse=False): + setattr(module, f"{buf_name}_comfy_model_dtype", buf.dtype) def cleanup_models(): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index e380e406bd21..70f78a08978f 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -241,6 +241,7 @@ def __init__(self, model, load_device, offload_device, size=0, weight_inplace_up self.patches = {} self.backup = {} + self.backup_buffers = {} self.object_patches = {} self.object_patches_backup = {} self.weight_wrapper_patches = {} @@ -309,7 +310,7 @@ def get_free_memory(self, device): return comfy.model_management.get_free_memory(device) def get_clone_model_override(self): - return self.model, (self.backup, self.object_patches_backup, self.pinned) + return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned) def clone(self, disable_dynamic=False, model_override=None): class_ = self.__class__ @@ -336,7 +337,7 @@ def clone(self, disable_dynamic=False, model_override=None): n.force_cast_weights = self.force_cast_weights - n.backup, n.object_patches_backup, n.pinned = model_override[1] + n.backup, n.backup_buffers, n.object_patches_backup, n.pinned = model_override[1] # attachments n.attachments = {} @@ -1579,11 +1580,22 @@ def force_load_param(self, param_key, device_to): weight, _, _ = get_key_weight(self.model, key) if key not in self.backup: self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight, False) - comfy.utils.set_attr_param(self.model, key, weight.to(device=device_to)) - self.model.model_loaded_weight_memory += weight.numel() * weight.element_size() + model_dtype = getattr(m, param + "_comfy_model_dtype", None) + casted_weight = weight.to(dtype=model_dtype, device=device_to) + comfy.utils.set_attr_param(self.model, key, casted_weight) + self.model.model_loaded_weight_memory += casted_weight.numel() * casted_weight.element_size() move_weight_functions(m, device_to) + for key, buf in self.model.named_buffers(recurse=True): + if key not in self.backup_buffers: + self.backup_buffers[key] = buf + module, buf_name = comfy.utils.resolve_attr(self.model, key) + model_dtype = getattr(module, buf_name + "_comfy_model_dtype", None) + casted_buf = buf.to(dtype=model_dtype, device=device_to) + comfy.utils.set_attr_buffer(self.model, key, casted_buf) + self.model.model_loaded_weight_memory += casted_buf.numel() * casted_buf.element_size() + force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else "" logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}") @@ -1607,6 +1619,8 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=Fals for key in list(self.backup.keys()): bk = self.backup.pop(key) comfy.utils.set_attr_param(self.model, key, bk.weight) + for key in list(self.backup_buffers.keys()): + comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key)) freed += self.model.model_loaded_weight_memory self.model.model_loaded_weight_memory = 0 diff --git a/comfy/supported_models.py b/comfy/supported_models.py index c0d3f387f35a..07feb31b345e 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1127,7 +1127,7 @@ class ZImagePixelSpace(ZImage): latent_format = latent_formats.ZImagePixelSpace # Much lower memory than latent-space models (no VAE, small patches). - memory_usage_factor = 0.05 # TODO: figure out the optimal value for this. + memory_usage_factor = 0.03 # TODO: figure out the optimal value for this. def get_model(self, state_dict, prefix="", device=None): return model_base.ZImagePixelSpace(self, device=device) diff --git a/comfy/utils.py b/comfy/utils.py index 0769cef447df..6e1d14419535 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -869,20 +869,31 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024): ATTR_UNSET={} -def set_attr(obj, attr, value): +def resolve_attr(obj, attr): attrs = attr.split(".") for name in attrs[:-1]: obj = getattr(obj, name) - prev = getattr(obj, attrs[-1], ATTR_UNSET) + return obj, attrs[-1] + +def set_attr(obj, attr, value): + obj, name = resolve_attr(obj, attr) + prev = getattr(obj, name, ATTR_UNSET) if value is ATTR_UNSET: - delattr(obj, attrs[-1]) + delattr(obj, name) else: - setattr(obj, attrs[-1], value) + setattr(obj, name, value) return prev def set_attr_param(obj, attr, value): return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False)) +def set_attr_buffer(obj, attr, value): + obj, name = resolve_attr(obj, attr) + prev = getattr(obj, name, ATTR_UNSET) + persistent = name not in getattr(obj, "_non_persistent_buffers_set", set()) + obj.register_buffer(name, value, persistent=persistent) + return prev + def copy_to_param(obj, attr, value): # inplace update tensor instead of replacing it attrs = attr.split(".") diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 189d7d9bc39f..050031dc0546 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1240,6 +1240,19 @@ def as_dict(self): return d +@comfytype(io_type="CURVE") +class Curve(ComfyTypeIO): + CurvePoint = tuple[float, float] + Type = list[CurvePoint] + + class Input(WidgetInput): + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, + socketless: bool=True, default: list[tuple[float, float]]=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced) + if default is None: + self.default = [(0.0, 0.0), (1.0, 1.0)] + + DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {} def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]): DYNAMIC_INPUT_LOOKUP[io_type] = func @@ -2226,5 +2239,6 @@ def as_dict(self): "PriceBadgeDepends", "PriceBadge", "BoundingBox", + "Curve", "NodeReplace", ] diff --git a/execution.py b/execution.py index 75b021892de0..7ccdbf93e071 100644 --- a/execution.py +++ b/execution.py @@ -876,12 +876,14 @@ async def validate_inputs(prompt_id, prompt, item, validated): continue else: try: - # Unwraps values wrapped in __value__ key. This is used to pass - # list widget value to execution, as by default list value is - # reserved to represent the connection between nodes. - if isinstance(val, dict) and "__value__" in val: - val = val["__value__"] - inputs[x] = val + # Unwraps values wrapped in __value__ key or typed wrapper. + # This is used to pass list widget values to execution, + # as by default list value is reserved to represent the + # connection between nodes. + if isinstance(val, dict): + if "__value__" in val: + val = val["__value__"] + inputs[x] = val if input_type == "INT": val = int(val)