Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions comfy/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,6 +796,8 @@ def archive_model_dtypes(model):
for name, module in model.named_modules():
for param_name, param in module.named_parameters(recurse=False):
setattr(module, f"{param_name}_comfy_model_dtype", param.dtype)
for buf_name, buf in module.named_buffers(recurse=False):
setattr(module, f"{buf_name}_comfy_model_dtype", buf.dtype)


def cleanup_models():
Expand Down
22 changes: 18 additions & 4 deletions comfy/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ def __init__(self, model, load_device, offload_device, size=0, weight_inplace_up

self.patches = {}
self.backup = {}
self.backup_buffers = {}
self.object_patches = {}
self.object_patches_backup = {}
self.weight_wrapper_patches = {}
Expand Down Expand Up @@ -309,7 +310,7 @@ def get_free_memory(self, device):
return comfy.model_management.get_free_memory(device)

def get_clone_model_override(self):
return self.model, (self.backup, self.object_patches_backup, self.pinned)
return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned)

def clone(self, disable_dynamic=False, model_override=None):
class_ = self.__class__
Expand All @@ -336,7 +337,7 @@ def clone(self, disable_dynamic=False, model_override=None):

n.force_cast_weights = self.force_cast_weights

n.backup, n.object_patches_backup, n.pinned = model_override[1]
n.backup, n.backup_buffers, n.object_patches_backup, n.pinned = model_override[1]

# attachments
n.attachments = {}
Expand Down Expand Up @@ -1579,11 +1580,22 @@ def force_load_param(self, param_key, device_to):
weight, _, _ = get_key_weight(self.model, key)
if key not in self.backup:
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight, False)
comfy.utils.set_attr_param(self.model, key, weight.to(device=device_to))
self.model.model_loaded_weight_memory += weight.numel() * weight.element_size()
model_dtype = getattr(m, param + "_comfy_model_dtype", None)
casted_weight = weight.to(dtype=model_dtype, device=device_to)
comfy.utils.set_attr_param(self.model, key, casted_weight)
self.model.model_loaded_weight_memory += casted_weight.numel() * casted_weight.element_size()

move_weight_functions(m, device_to)

for key, buf in self.model.named_buffers(recurse=True):
if key not in self.backup_buffers:
self.backup_buffers[key] = buf
module, buf_name = comfy.utils.resolve_attr(self.model, key)
model_dtype = getattr(module, buf_name + "_comfy_model_dtype", None)
casted_buf = buf.to(dtype=model_dtype, device=device_to)
comfy.utils.set_attr_buffer(self.model, key, casted_buf)
self.model.model_loaded_weight_memory += casted_buf.numel() * casted_buf.element_size()

force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else ""
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}")

Expand All @@ -1607,6 +1619,8 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=Fals
for key in list(self.backup.keys()):
bk = self.backup.pop(key)
comfy.utils.set_attr_param(self.model, key, bk.weight)
for key in list(self.backup_buffers.keys()):
comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key))
freed += self.model.model_loaded_weight_memory
self.model.model_loaded_weight_memory = 0

Expand Down
2 changes: 1 addition & 1 deletion comfy/supported_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1127,7 +1127,7 @@ class ZImagePixelSpace(ZImage):
latent_format = latent_formats.ZImagePixelSpace

# Much lower memory than latent-space models (no VAE, small patches).
memory_usage_factor = 0.05 # TODO: figure out the optimal value for this.
memory_usage_factor = 0.03 # TODO: figure out the optimal value for this.

def get_model(self, state_dict, prefix="", device=None):
return model_base.ZImagePixelSpace(self, device=device)
Expand Down
19 changes: 15 additions & 4 deletions comfy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,20 +869,31 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024):

# Sentinel marking "attribute was absent" / "delete this attribute".
ATTR_UNSET = {}

def resolve_attr(obj, attr):
    """Walk a dotted attribute path and return (parent_object, final_name)."""
    *parents, leaf = attr.split(".")
    target = obj
    for part in parents:
        target = getattr(target, part)
    return target, leaf

def set_attr(obj, attr, value):
    """Set (or delete, when value is ATTR_UNSET) a dotted attribute.

    Returns the previous value, or ATTR_UNSET if the attribute did not exist.
    """
    target, leaf = resolve_attr(obj, attr)
    previous = getattr(target, leaf, ATTR_UNSET)
    if value is not ATTR_UNSET:
        setattr(target, leaf, value)
    else:
        delattr(target, leaf)
    return previous

def set_attr_param(obj, attr, value):
    """Replace a dotted attribute with a frozen (no-grad) nn.Parameter wrapping value."""
    frozen = torch.nn.Parameter(value, requires_grad=False)
    return set_attr(obj, attr, frozen)

def set_attr_buffer(obj, attr, value):
    """Re-register a module buffer at a dotted path, preserving its persistence.

    Returns the previous buffer value, or ATTR_UNSET if none existed.
    """
    module, leaf = resolve_attr(obj, attr)
    previous = getattr(module, leaf, ATTR_UNSET)
    # Keep the buffer out of the state_dict if it was registered non-persistent.
    non_persistent = getattr(module, "_non_persistent_buffers_set", set())
    module.register_buffer(leaf, value, persistent=leaf not in non_persistent)
    return previous

def copy_to_param(obj, attr, value):
# inplace update tensor instead of replacing it
attrs = attr.split(".")
Expand Down
14 changes: 14 additions & 0 deletions comfy_api/latest/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1240,6 +1240,19 @@ def as_dict(self):
return d


@comfytype(io_type="CURVE")
class Curve(ComfyTypeIO):
    """IO type "CURVE": a list of (x, y) float control points.

    Defaults to the identity line from (0.0, 0.0) to (1.0, 1.0).
    """
    # One control point as an (x, y) pair; Type is the full point list.
    CurvePoint = tuple[float, float]
    Type = list[CurvePoint]

    class Input(WidgetInput):
        def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
                     socketless: bool=True, default: list[tuple[float, float]]=None, advanced: bool=None):
            # NOTE(review): the positional None arguments fill WidgetInput's
            # remaining parameters — confirm ordering against WidgetInput.__init__.
            super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
            # Identity-curve fallback; a fresh list is built per instance, so
            # no mutable default is shared between inputs.
            if default is None:
                self.default = [(0.0, 0.0), (1.0, 1.0)]


DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
DYNAMIC_INPUT_LOOKUP[io_type] = func
Expand Down Expand Up @@ -2226,5 +2239,6 @@ def as_dict(self):
"PriceBadgeDepends",
"PriceBadge",
"BoundingBox",
"Curve",
"NodeReplace",
]
14 changes: 8 additions & 6 deletions execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,12 +876,14 @@ async def validate_inputs(prompt_id, prompt, item, validated):
continue
else:
try:
# Unwraps values wrapped in __value__ key. This is used to pass
# list widget value to execution, as by default list value is
# reserved to represent the connection between nodes.
if isinstance(val, dict) and "__value__" in val:
val = val["__value__"]
inputs[x] = val
# Unwraps values wrapped in __value__ key or typed wrapper.
# This is used to pass list widget values to execution,
# as by default list value is reserved to represent the
# connection between nodes.
if isinstance(val, dict):
if "__value__" in val:
val = val["__value__"]
inputs[x] = val

if input_type == "INT":
val = int(val)
Expand Down
Loading