From c32be4395905f29fbea2f94f2d0561a3989c645e Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Sat, 4 May 2024 22:31:05 -0400
Subject: [PATCH 1/8] avoid copying model back from cuda to cpu

---
 .../model_manager/load/model_cache/model_cache_base.py |  1 +
 .../load/model_cache/model_cache_default.py            | 10 +++++++++-
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
index a8c2dd3e92e..bd9286d79ba 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
@@ -46,6 +46,7 @@ class CacheRecord(Generic[T]):

     key: str
     model: T
+    state_dict: Optional[Dict[str, torch.Tensor]]  # this is a copy that stays in CPU
     size: int
     loaded: bool = False
     _locks: int = 0

diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
index 2ffe954e11e..25983cd463f 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
@@ -162,7 +162,8 @@ def put(
         if key in self._cached_models:
             return
         self.make_room(size)
-        cache_record = CacheRecord(key, model, size)
+        state_dict = model.state_dict() if hasattr(model, "state_dict") else None
+        cache_record = CacheRecord(key=key, model=model, state_dict=state_dict, size=size)
         self._cached_models[key] = cache_record
         self._cache_stack.append(key)
@@ -267,6 +268,13 @@ def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device
         start_model_to_time = time.time()
         snapshot_before = self._capture_memory_snapshot()
         try:
+            if target_device == self.storage_device:
+                cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
+            else:
+                new_dict: Dict[str, torch.Tensor] = {}
+                for k, v in cache_entry.state_dict.items():
+                    new_dict[k] = v.to(torch.device(target_device), copy=True)
+                cache_entry.model.load_state_dict(new_dict, assign=True)
             cache_entry.model.to(target_device)
         except Exception as e:  # blow away cache entry
             self._delete_cache_entry(cache_entry)
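The caching pattern this patch introduces can be summarized with a small standalone sketch (hypothetical names, not InvokeAI's API; assumes PyTorch 2.1+ for `load_state_dict(..., assign=True)`): keep a CPU-resident state dict as a template, copy it tensor-by-tensor onto the execution device, and inject the copies into the module; offloading just reinjects the cached CPU template, so nothing is copied back from VRAM.

    from typing import Dict

    import torch


    class CachedModule:
        """Toy illustration of the CPU-template / VRAM-copy scheme, not the real cache."""

        def __init__(self, model: torch.nn.Module, storage_device: torch.device) -> None:
            self.model = model.to(storage_device)
            self.storage_device = storage_device
            # CPU template of the weights; reused every time the model is offloaded.
            self.cpu_state_dict: Dict[str, torch.Tensor] = self.model.state_dict()

        def move_to(self, target_device: torch.device) -> None:
            if target_device == self.storage_device:
                # Offload: drop the VRAM tensors by reinjecting the cached CPU copy.
                self.model.load_state_dict(self.cpu_state_dict, assign=True)
            else:
                # Load: copy (not move) each tensor so the CPU template stays intact.
                vram_copy = {k: v.to(target_device, copy=True) for k, v in self.cpu_state_dict.items()}
                self.model.load_state_dict(vram_copy, assign=True)
            self.model.to(target_device)  # catches buffers not covered by the state dict

    # Usage sketch (requires a CUDA device):
    #   cached = CachedModule(torch.nn.Linear(8, 8), torch.device("cpu"))
    #   cached.move_to(torch.device("cuda"))   # host-to-device copies only
    #   cached.move_to(torch.device("cpu"))    # no device-to-host copy at all
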
From f8d5f13ef32b1156f87fe4145e89a8cc522e691a Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Tue, 7 May 2024 00:44:16 -0400
Subject: [PATCH 2/8] handle models that don't have state dicts

---
 .../load/model_cache/model_cache_default.py | 33 ++++++++++++++-----
 1 file changed, 25 insertions(+), 8 deletions(-)

diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
index 25983cd463f..41a19bb81d1 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
@@ -162,7 +162,12 @@ def put(
         if key in self._cached_models:
             return
         self.make_room(size)
-        state_dict = model.state_dict() if hasattr(model, "state_dict") else None
+
+        if isinstance(model, torch.nn.Module):
+            state_dict = model.state_dict()
+            assert model.device == self.storage_device
+        else:
+            state_dict = None
         cache_record = CacheRecord(key=key, model=model, state_dict=state_dict, size=size)
         self._cached_models[key] = cache_record
         self._cache_stack.append(key)
@@ -265,16 +270,28 @@ def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device
         if torch.device(source_device).type == torch.device(target_device).type:
             return

+        # This roundabout method for moving the model around is done to avoid
+        # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
+        # When moving to VRAM, we copy (not move) each element of the state dict from
+        # RAM to a new state dict in VRAM, and then inject it into the model.
+        # This operation is slightly faster than running `to()` on the whole model.
+        #
+        # When the model needs to be removed from VRAM we simply delete the copy
+        # of the state dict in VRAM, and reinject the state dict that is cached
+        # in RAM into the model. So this operation is very fast.
        start_model_to_time = time.time()
         snapshot_before = self._capture_memory_snapshot()
+
         try:
-            if target_device == self.storage_device:
-                cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
-            else:
-                new_dict: Dict[str, torch.Tensor] = {}
-                for k, v in cache_entry.state_dict.items():
-                    new_dict[k] = v.to(torch.device(target_device), copy=True)
-                cache_entry.model.load_state_dict(new_dict, assign=True)
+            if cache_entry.state_dict is not None:
+                assert hasattr(cache_entry.model, "load_state_dict")
+                if target_device == self.storage_device:
+                    cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
+                else:
+                    new_dict: Dict[str, torch.Tensor] = {}
+                    for k, v in cache_entry.state_dict.items():
+                        new_dict[k] = v.to(torch.device(target_device), copy=True)
+                    cache_entry.model.load_state_dict(new_dict, assign=True)
             cache_entry.model.to(target_device)
         except Exception as e:  # blow away cache entry
             self._delete_cache_entry(cache_entry)

From 33590ef7a375260284f046ff38040302f9a6b28a Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Sun, 12 May 2024 20:31:00 -0600
Subject: [PATCH 3/8] add assertions that models need a `device()` method

---
 .../model_manager/load/model_cache/model_cache_default.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
index 41a19bb81d1..d73d943b991 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
@@ -165,7 +165,7 @@ def put(

         if isinstance(model, torch.nn.Module):
             state_dict = model.state_dict()
-            assert model.device == self.storage_device
+            assert hasattr(model, "device") and model.device == self.storage_device
         else:
             state_dict = None
         cache_record = CacheRecord(key=key, model=model, state_dict=state_dict, size=size)
@@ -263,7 +263,7 @@ def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device
         if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")):
             return

-        source_device = cache_entry.model.device
+        source_device = cache_entry.model.device if hasattr(cache_entry.model, "device") else self.StorageDevice

         # Note: We compare device types only so that 'cuda' == 'cuda:0'.
         # This would need to be revised to support multi-GPU.
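The "compare device types only" note that keeps showing up in these hunks is easy to check in isolation; a minimal sketch in plain PyTorch (no InvokeAI code involved) shows why `.type` is compared rather than full device equality:

    import torch

    # 'cuda' and 'cuda:0' belong to the same device family but are not equal as devices.
    assert torch.device("cuda").type == torch.device("cuda:0").type
    assert torch.device("cuda") != torch.device("cuda:0")
    assert torch.device("cuda:0").type != torch.device("cpu").type

Comparing only `.type` therefore treats an unindexed 'cuda' request and an explicit 'cuda:0' as the same placement, which is the behavior the cache wants on single-GPU systems.
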
From 5c46343c0d2c2531214981f7a989b3eb8f212e8a Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Sun, 12 May 2024 20:38:42 -0600
Subject: [PATCH 4/8] do not rely on torch.nn.Module having the device() method

---
 .../load/model_cache/model_cache_base.py    |  1 +
 .../load/model_cache/model_cache_default.py | 11 ++++-------
 2 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
index bd9286d79ba..5c4925b6025 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
@@ -46,6 +46,7 @@ class CacheRecord(Generic[T]):

     key: str
     model: T
+    device: torch.device
     state_dict: Optional[Dict[str, torch.Tensor]]  # this is a copy that stays in CPU
     size: int
     loaded: bool = False

diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
index d73d943b991..465d73b1fbc 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
@@ -163,12 +163,8 @@ def put(
             return
         self.make_room(size)

-        if isinstance(model, torch.nn.Module):
-            state_dict = model.state_dict()
-            assert hasattr(model, "device") and model.device == self.storage_device
-        else:
-            state_dict = None
-        cache_record = CacheRecord(key=key, model=model, state_dict=state_dict, size=size)
+        state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
+        cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
         self._cached_models[key] = cache_record
         self._cache_stack.append(key)
@@ -263,7 +259,7 @@ def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device
         if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")):
             return

-        source_device = cache_entry.model.device if hasattr(cache_entry.model, "device") else self.StorageDevice
+        source_device = cache_entry.device

         # Note: We compare device types only so that 'cuda' == 'cuda:0'.
         # This would need to be revised to support multi-GPU.
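The reason for tracking `device` on the cache record itself, rather than asking the model, is that a vanilla `torch.nn.Module` exposes no `.device` attribute; that convenience comes from wrappers such as the transformers/diffusers base classes. A quick illustration (plain PyTorch, toy module):

    import torch

    m = torch.nn.Linear(2, 2)
    print(hasattr(m, "device"))          # False: plain modules don't know where they live
    print(next(m.parameters()).device)   # cpu -- the usual workaround, which fails for
                                         # parameterless modules and mixed-device models

Storing the device in the record sidesteps both problems and also covers cached objects that are not torch modules at all.
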
@@ -293,6 +289,7 @@ def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device
                     new_dict[k] = v.to(torch.device(target_device), copy=True)
                 cache_entry.model.load_state_dict(new_dict, assign=True)
             cache_entry.model.to(target_device)
+            cache_entry.device = target_device
         except Exception as e:  # blow away cache entry
             self._delete_cache_entry(cache_entry)
             raise e

From 90e5dfb2ad919c864a70c2650ddebf8025edfab6 Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Sat, 18 May 2024 00:17:55 -0400
Subject: [PATCH 5/8] apply all patches after model is on the execution device

---
 invokeai/app/invocations/compel.py          | 41 +++++++++----------
 .../load/model_cache/model_cache_default.py |  2 +-
 2 files changed, 20 insertions(+), 23 deletions(-)

diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index 158f11a58e8..211bb5aa539 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -65,11 +65,7 @@ class CompelInvocation(BaseInvocation):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
         tokenizer_info = context.models.load(self.clip.tokenizer)
-        tokenizer_model = tokenizer_info.model
-        assert isinstance(tokenizer_model, CLIPTokenizer)
         text_encoder_info = context.models.load(self.clip.text_encoder)
-        text_encoder_model = text_encoder_info.model
-        assert isinstance(text_encoder_model, CLIPTextModel)

         def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in self.clip.loras:
@@ -84,19 +80,21 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
         ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context)

         with (
-            ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
-                tokenizer,
-                ti_manager,
-            ),
+            # apply all patches while the model is on the target device
             text_encoder_info as text_encoder,
-            # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
+            tokenizer_info as tokenizer,
             ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
-            ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers),
+            ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
+            ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
+                patched_tokenizer,
+                ti_manager,
+            ),
         ):
             assert isinstance(text_encoder, CLIPTextModel)
+            assert isinstance(tokenizer, CLIPTokenizer)
             compel = Compel(
-                tokenizer=tokenizer,
+                tokenizer=patched_tokenizer,
                 text_encoder=text_encoder,
                 textual_inversion_manager=ti_manager,
                 dtype_for_device_getter=TorchDevice.choose_torch_dtype,
@@ -136,11 +134,7 @@ def run_clip_compel(
         zero_on_empty: bool,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         tokenizer_info = context.models.load(clip_field.tokenizer)
-        tokenizer_model = tokenizer_info.model
-        assert isinstance(tokenizer_model, CLIPTokenizer)
         text_encoder_info = context.models.load(clip_field.text_encoder)
-        text_encoder_model = text_encoder_info.model
-        assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))

         # return zero on empty
         if prompt == "" and zero_on_empty:
@@ -177,20 +171,23 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
         ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context)

         with (
-            ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
-                tokenizer,
-                ti_manager,
-            ),
+            # apply all patches while the model is on the target device
             text_encoder_info as text_encoder,
-            # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
+            tokenizer_info as tokenizer,
             ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
-            ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
+            ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
+            ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
+                patched_tokenizer,
+                ti_manager,
+            ),
         ):
             assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
+            assert isinstance(tokenizer, CLIPTokenizer)
+            text_encoder = cast(CLIPTextModel, text_encoder)
             compel = Compel(
-                tokenizer=tokenizer,
+                tokenizer=patched_tokenizer,
                 text_encoder=text_encoder,
                 textual_inversion_manager=ti_manager,
                 dtype_for_device_getter=TorchDevice.choose_torch_dtype,

diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
index 465d73b1fbc..ba4c31c1fdd 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
@@ -163,7 +163,7 @@ def put(
             return
         self.make_room(size)

-        state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None 
+        state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
         cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
         self._cached_models[key] = cache_record
         self._cache_stack.append(key)
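A plausible reading of why all patches are now applied to the context-managed (on-device) model instead of `*_info.model`: `state_dict()` does not clone tensors, so an in-place LoRA/TI/CLIP-skip patch applied to the RAM-resident model would also mutate the cached CPU template, whereas a patch applied after the `copy=True` transfer only touches the VRAM copy and is discarded for free when the cached state dict is reinjected. The aliasing itself is easy to demonstrate with a toy module (plain PyTorch, not InvokeAI code):

    import torch

    m = torch.nn.Linear(2, 2)
    cached = m.state_dict()             # snapshot kept "in the cache"
    with torch.no_grad():
        m.weight.add_(1.0)              # patch the live model in place
    print(torch.equal(cached["weight"], m.weight))  # True: the snapshot aliases the same storage
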
From 7000c0b5fb16ecdc377873d70ff6cca4d169f665 Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Sat, 18 May 2024 00:25:24 -0400
Subject: [PATCH 6/8] fix model patching in latents too

---
 invokeai/app/invocations/latent.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index b3ac3973bf3..a88eff0fcb6 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -930,9 +930,9 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
         assert isinstance(unet_info.model, UNet2DConditionModel)
         with (
             ExitStack() as exit_stack,
-            ModelPatcher.apply_freeu(unet_info.model, self.unet.freeu_config),
-            set_seamless(unet_info.model, self.unet.seamless_axes),  # FIXME
             unet_info as unet,
+            ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
+            set_seamless(unet, self.unet.seamless_axes),  # FIXME
             # Apply the LoRA after unet has been moved to its target device for faster patching.
             ModelPatcher.apply_lora_unet(unet, _lora_loader()),
         ):
From 0a0c1d1c8ca1a055e6d9273fe36c6e003a0bf041 Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Sat, 18 May 2024 13:40:30 -0400
Subject: [PATCH 7/8] log patched tokenizer

---
 invokeai/app/invocations/compel.py               |  4 ++--
 .../load/model_cache/model_cache_base.py         | 18 ++++++++++++++++--
 2 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index 211bb5aa539..766b44fdc8a 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -104,7 +104,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             conjunction = Compel.parse_prompt_string(self.prompt)

             if context.config.get().log_tokenization:
-                log_tokenization_for_conjunction(conjunction, tokenizer)
+                log_tokenization_for_conjunction(conjunction, patched_tokenizer)

             c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)
@@ -200,7 +200,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             if context.config.get().log_tokenization:
                 # TODO: better logging for and syntax
-                log_tokenization_for_conjunction(conjunction, tokenizer)
+                log_tokenization_for_conjunction(conjunction, patched_tokenizer)

             # TODO: ask for optimizations? to not run text_encoder twice
             c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)

diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
index 5c4925b6025..2ecb3b5d794 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
@@ -42,12 +42,26 @@ def model(self) -> AnyModel:

 @dataclass
 class CacheRecord(Generic[T]):
-    """Elements of the cache."""
+    """
+    Elements of the cache:
+
+    key: Unique key for each model, same as used in the models database.
+    model: Model in memory.
+    state_dict: A read-only copy of the model's state dict in RAM. It will be
+        used as a template for creating a copy in the VRAM.
+    size: Size of the model
+    loaded: True if the model's state dict is currently in VRAM
+
+    Before a model is executed, the state_dict template is copied into VRAM,
+    and then injected into the model. When the model is finished, the VRAM
+    copy of the state dict is deleted, and the RAM version is reinjected
+    into the model.
+    """

     key: str
     model: T
     device: torch.device
-    state_dict: Optional[Dict[str, torch.Tensor]]  # this is a copy that stays in CPU
+    state_dict: Optional[Dict[str, torch.Tensor]]
     size: int
     loaded: bool = False
     _locks: int = 0

From c775b5974b031e678bf170542abda56a502bc75c Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Sat, 18 May 2024 20:18:20 -0400
Subject: [PATCH 8/8] closes #6375

---
 .../load/model_cache/model_cache_default.py | 36 ++-----------------
 1 file changed, 2 insertions(+), 34 deletions(-)

diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
index ba4c31c1fdd..a3016a63ef8 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
@@ -20,7 +20,6 @@

 import gc
 import math
-import sys
 import time
 from contextlib import suppress
 from logging import Logger
@@ -369,43 +368,12 @@ def make_room(self, size: int) -> None:
         while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
             model_key = self._cache_stack[pos]
             cache_entry = self._cached_models[model_key]
-
-            refs = sys.getrefcount(cache_entry.model)
-
-            # HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
-            # going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
-            # https://docs.python.org/3/library/gc.html#gc.get_referrers
-
-            # manualy clear local variable references of just finished function calls
-            # for some reason python don't want to collect it even by gc.collect() immidiately
-            if refs > 2:
-                while True:
-                    cleared = False
-                    for referrer in gc.get_referrers(cache_entry.model):
-                        if type(referrer).__name__ == "frame":
-                            # RuntimeError: cannot clear an executing frame
-                            with suppress(RuntimeError):
-                                referrer.clear()
-                                cleared = True
-                                # break
-
-                    # repeat if referrers changes(due to frame clear), else exit loop
-                    if cleared:
-                        gc.collect()
-                    else:
-                        break
-
             device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
             self.logger.debug(
-                f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded},"
-                f" refs: {refs}"
+                f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
             )

-            # Expected refs:
-            # 1 from cache_entry
-            # 1 from getrefcount function
-            # 1 from onnx runtime object
-            if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
+            if not cache_entry.locked:
                 self.logger.debug(
                     f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
                 )
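With the refcount/gc workaround gone, the eviction condition reduces to `not cache_entry.locked`. Roughly, and with hypothetical names (the real `make_room()` also handles logging, statistics, and cache-stack bookkeeping), the remaining loop amounts to:

    from typing import Dict, List


    def make_room(records: Dict[str, "CacheRecord"], stack: List[str],
                  current_size: int, bytes_needed: int, max_size: int) -> int:
        """Evict unlocked entries, oldest first, until the request fits; return the new size."""
        for key in list(stack):
            if current_size + bytes_needed <= max_size:
                break
            entry = records[key]
            if entry.locked:
                continue                 # in-use models are never evicted
            current_size -= entry.size
            stack.remove(key)
            del records[key]             # drops the model and its cached CPU state dict
        return current_size
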