
Conversation

@lstein (Collaborator) commented Apr 10, 2024

Summary

If a CUDA Out-Of-Memory (OOM) exception occurs during model loading, this commit recovers from the error by clearing the model manager's cache entry for that model. This keeps the partially-loaded model from getting "stuck" in VRAM and blocking further generations.
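
A minimal sketch of the recovery path, wrapped in a hypothetical _load_into_vram method for illustration; cache_entry, target_device, and _delete_cache_entry follow the snippets quoted later in this conversation, and the actual code in the commit may differ:

import torch

def _load_into_vram(self, cache_entry, target_device: torch.device) -> None:
    try:
        # Moving the weights onto the GPU is where the OOM can occur.
        cache_entry.model.to(target_device)
    except torch.cuda.OutOfMemoryError:
        # Drop the partially-loaded entry so it no longer pins VRAM,
        # then re-raise so the caller still sees the failure.
        self._delete_cache_entry(cache_entry)
        raise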

Related Issues / Discussions

To reproduce the underlying issue on current main, be sure to use a Linux system (the Windows NVIDIA driver behaves differently).

  1. Do a generation with an SDXL model. Run nvidia-smi afterward. It should show a small amount of VRAM being used by invokeai, typically about 600 MB.
  2. Run another process that uses lots of VRAM. For my tests, I load a large language model using ollama (if you don't have such a process handy, see the sketch after this list).
  3. Trigger an OOM error by running a generation with the same SDXL model. You should get an error like torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 26.00 MiB. GPU 0 has a total capacity of 11.73 GiB of which 7.12 MiB is free. If you check nvidia-smi now, you'll see that a significant amount of VRAM (>1 GB) is still allocated to the invokeai process.
  4. Free sufficient VRAM by exiting the process from step 2. I do this by shutting down the ollama server or by unloading its current LLM.
  5. Try to run a generation with the same SDXL model. Even though there is now sufficient space to load the model, you'll get an error like this one: Error while invoking session 95917faf-1faf-4094-9dcd-86cfd1c943e4, invocation fb7a92ba-a6e8-4345-837c-ed842665fccd (denoise_latents): Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument weight in method wrapper_CUDA__native_group_norm)
  6. You cannot clear this problem without either (1) switching to a new model and then switching back (you may need to do this repeatedly with different models, depending on VRAM capacity), or (2) killing invokeai and restarting.
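
If you don't have a convenient VRAM-hungry process for step 2, a throwaway Python REPL that pins a large tensor works as well. This is just an illustrative alternative (adjust the element count to your GPU's capacity):

import torch

# Allocate roughly 8 GB of float32 on the GPU and keep a reference so it stays resident.
hog = torch.empty(2 * 1024**3, dtype=torch.float32, device="cuda")
input("VRAM pinned; press Enter to release it and exit...")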

QA Instructions

Follow the recipe given above before and after applying this PR. With this PR, when the OOM error occurs, the RAM cache is cleared of the partially-loaded model, and as soon as there is sufficient unused VRAM the model should load and generate.

Note that this is a CUDA-specific fix. If there is an equivalent problem on MPS systems, this will not fix it. However, we haven't had any reports from Mac users yet.

Merge Plan

Merge when approved.

Checklist

  • The PR has a short but descriptive title, suitable for a changelog
  • Tests added / updated (if applicable)
  • Documentation added / updated (if applicable)

@github-actions bot added the python, invocations, backend, and services labels Apr 10, 2024
@psychedelicious (Collaborator) commented Apr 10, 2024

Here's my testing procedure for main. My Invoke uses ~670 MB VRAM at idle.

  1. Open a python REPL
  2. Load a few models using diffusers:
from diffusers import StableDiffusionPipeline
i = StableDiffusionPipeline.from_pretrained("/home/bat/invokeai-4.0.0/models/sd-1/main/dreamshaper-8-inpainting", use_safetensors=True, variant="fp16").to("cuda")
j = StableDiffusionPipeline.from_pretrained("/home/bat/invokeai-4.0.0/models/sdxl/main/Juggernaut-XL-v9", use_safetensors=True, variant="fp16").to("cuda")
k = StableDiffusionPipeline.from_pretrained("/home/bat/invokeai-4.0.0/models/sd-1/main/Analog-Diffusion2", use_safetensors=True, variant="fp16").to("cuda")

  3. This uses about 22 GB VRAM (nvidia-smi screenshot omitted).
  4. Run Invoke and generate with a model, get a torch.cuda.OutOfMemoryError. Sometimes I get a different error: RuntimeError: CUDA error: out of memory. I couldn't figure out a pattern.
  5. Quit the REPL.
  6. Generate again. If I use the same model from step 4, I get one of these two errors; there could be others I missed:

  • RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.HalfTensor) should be the same
  • RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument weight in method wrapper_CUDA__native_group_norm).

Takeaway

We need to handle at least torch.cuda.OutOfMemoryError and RuntimeError. Note that torch.cuda.OutOfMemoryError inherits from RuntimeError.

try:
    cache_entry.model.to(target_device)
except (torch.cuda.OutOfMemoryError, RuntimeError) as e:  # blow away cache entry
    if isinstance(e, torch.cuda.OutOfMemoryError) or "CUDA out of memory" in str(e):
        self._delete_cache_entry(cache_entry)
    raise e

But to really be safe, we probably should clear the cache on any exception when calling cache_entry.model.to(target_device). If I just catch any Exception, the issue appears to be resolved, and my VRAM usage drops down to the idle level.

Is there a reason to not make it a catch-all?

Side-note: I used the default VRAM cache setting of 0.25 for testing. I had it at 12 at first but that made it harder to understand what was happening.

@lstein (Collaborator, Author) commented Apr 11, 2024

> We need to handle at least torch.cuda.OutOfMemoryError and RuntimeError. Note that torch.cuda.OutOfMemoryError inherits from RuntimeError. [...] But to really be safe, we probably should clear the cache on any exception when calling cache_entry.model.to(target_device). If I just catch any Exception, the issue appears to be resolved, and my VRAM usage drops down to the idle level.
>
> Is there a reason to not make it a catch-all?

Interesting, I consistently get torch.cuda.OutOfMemoryError. Might be differences in the torch library version.

Catching all Exceptions might swallow KeyboardInterrupt. How about we catch all RuntimeErrors?

        try:
            cache_entry.model.to(target_device)
        except RuntimeError as e:  # blow away cache entry
            self._delete_cache_entry(cache_entry)
            raise e

@psychedelicious (Collaborator) commented:

> Catching all Exceptions might swallow KeyboardInterrupt. How about we catch all RuntimeErrors?

We re-raise the error immediately afterwards anyways - what's the harm in catching everything and clearing the cache in the event of any error? Probably oughta do this anyways. There could be some other kind of exception that results in a borked cache for that model.

@lstein (Collaborator, Author) commented Apr 11, 2024

> > Catching all Exceptions might swallow KeyboardInterrupt. How about we catch all RuntimeErrors?
>
> We re-raise the error immediately afterwards anyways - what's the harm in catching everything and clearing the cache in the event of any error? Probably oughta do this anyways. There could be some other kind of exception that results in a borked cache for that model.

Ok. Done in latest commit.
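
For reference, a minimal sketch of the catch-all behaviour agreed on above (names follow the earlier snippets; the exact code in the merged commit may differ):

try:
    cache_entry.model.to(target_device)
except Exception:
    # Any failure while moving the model can leave it partially loaded and pinned in VRAM,
    # so drop the cache entry before propagating the error.
    self._delete_cache_entry(cache_entry)
    raise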

@psychedelicious enabled auto-merge (rebase) April 11, 2024 21:10
@psychedelicious (Collaborator) commented:

Thanks for tracking this down @lstein

@psychedelicious force-pushed the lstein/bugfix/recover-from-oom-errors branch from 6076b15 to 7605632 on April 11, 2024 21:11
@psychedelicious merged commit 651c0b3 into main Apr 11, 2024
@psychedelicious deleted the lstein/bugfix/recover-from-oom-errors branch April 11, 2024 21:19