diff --git a/nodes.py b/nodes.py index 4683514..225097a 100644 --- a/nodes.py +++ b/nodes.py @@ -67,15 +67,40 @@ def patch_weight_to_device(self, key, device_to=None, inplace_update=False): comfy.utils.set_attr_param(self.model, key, out_weight) def unpatch_model(self, device_to=None, unpatch_weights=True): + self.eject_model() if unpatch_weights: for p in self.model.parameters(): if is_torch_compatible(p): continue - patches = getattr(p, "patches", []) - if len(patches) > 0: + if len(getattr(p, "patches", [])) > 0: p.patches = [] - # TODO: Find another way to not unload after patches - return super().unpatch_model(device_to=device_to, unpatch_weights=unpatch_weights) + # Mirror of base unpatch_model's unpatch_weights block, skipping the + # self.model.to(device_to) walk that faults on mmap'd quantized + # tensors (#444). Tracking upstream at Comfy-Org/ComfyUI#14142. + self.unpatch_hooks() + self.unpin_all_weights() + if self.model.model_lowvram: + for m in self.model.modules(): + comfy.model_patcher.move_weight_functions(m, device_to) + comfy.model_patcher.wipe_lowvram_weight(m) + self.model.model_lowvram = False + self.model.lowvram_patch_counter = 0 + for k in list(self.backup.keys()): + bk = self.backup[k] + if bk.inplace_update: + comfy.utils.copy_to_param(self.model, k, bk.weight) + else: + comfy.utils.set_attr_param(self.model, k, bk.weight) + self.model.current_weight_patches_uuid = None + self.backup.clear() + if device_to is not None: + self.model.device = device_to + self.model.model_loaded_weight_memory = 0 + self.model.model_offload_buffer_memory = 0 + for m in self.model.modules(): + if hasattr(m, "comfy_patched_weights"): + del m.comfy_patched_weights + return super().unpatch_model(device_to=device_to, unpatch_weights=False) def pin_weight_to_device(self, key):