Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 20 additions & 15 deletions comfy/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,13 @@ def lowvram_patch_counter(self):
return self.model.lowvram_patch_counter

def get_free_memory(self, device):
return comfy.model_management.get_free_memory(device)
#Prioritize batching (incl. CFG/conds etc) over keeping the model resident. In
#the vast majority of setups a little bit of offloading on the giant model more
#than pays for CFG. So return everything both torch and Aimdo could give us
aimdo_mem = 0
if comfy.memory_management.aimdo_enabled:
aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze()
return comfy.model_management.get_free_memory(device) + aimdo_mem

def get_clone_model_override(self):
return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned)
Expand Down Expand Up @@ -699,7 +705,7 @@ def unpin_all_weights(self):
for key in list(self.pinned):
self.unpin_weight(key)

def _load_list(self, prio_comfy_cast_weights=False, default_device=None):
def _load_list(self, for_dynamic=False, default_device=None):
loading = []
for n, m in self.model.named_modules():
default = False
Expand Down Expand Up @@ -727,8 +733,13 @@ def check_module_offload_mem(key):
return 0
module_offload_mem += check_module_offload_mem("{}.weight".format(n))
module_offload_mem += check_module_offload_mem("{}.bias".format(n))
prepend = (not hasattr(m, "comfy_cast_weights"),) if prio_comfy_cast_weights else ()
loading.append(prepend + (module_offload_mem, module_mem, n, m, params))
# Dynamic: small weights (<64KB) first, then larger weights prioritized by size.
# Non-dynamic: prioritize by module offload cost.
if for_dynamic:
sort_criteria = (module_offload_mem >= 64 * 1024, -module_offload_mem)
else:
sort_criteria = (module_offload_mem,)
loading.append(sort_criteria + (module_mem, n, m, params))
return loading

def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
Expand Down Expand Up @@ -1460,12 +1471,6 @@ def loaded_size(self):
vbar = self._vbar_get()
return (vbar.loaded_size() if vbar is not None else 0) + self.model.model_loaded_weight_memory

def get_free_memory(self, device):
#NOTE: on high condition / batch counts, estimate should have already vacated
#all non-dynamic models so this is safe even if it's not 100% true that this
#would all be available for inference use.
return comfy.model_management.get_total_memory(device) - self.model_size()

#Pinning is deferred to ops time. Assert against this API to avoid pin leaks.

def pin_weight_to_device(self, key):
Expand Down Expand Up @@ -1508,11 +1513,11 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
if vbar is not None:
vbar.prioritize()

loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
loading.sort(reverse=True)
loading = self._load_list(for_dynamic=True, default_device=device_to)
loading.sort()

for x in loading:
_, _, _, n, m, params = x
*_, module_mem, n, m, params = x

def set_dirty(item, dirty):
if dirty or not hasattr(item, "_v_signature"):
Expand Down Expand Up @@ -1627,9 +1632,9 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=Fals
return freed

def partially_unload_ram(self, ram_to_unload):
loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
for x in loading:
_, _, _, _, m, _ = x
*_, m, _ = x
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
if ram_to_unload <= 0:
return
Expand Down
4 changes: 2 additions & 2 deletions comfy/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,8 @@ def uncast_bias_weight(s, weight, bias, offload_stream):
return
os, weight_a, bias_a = offload_stream
device=None
#FIXME: This is not good RTTI
if not isinstance(weight_a, torch.Tensor):
#FIXME: This is really bad RTTI
if weight_a is not None and not isinstance(weight_a, torch.Tensor):
comfy_aimdo.model_vbar.vbar_unpin(s._v)
device = weight_a
if os is None:
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ alembic
SQLAlchemy
av>=14.2.0
comfy-kitchen>=0.2.7
comfy-aimdo>=0.2.4
comfy-aimdo>=0.2.5
requests

#non essential dependencies:
Expand Down
Loading