From d7785bea293c57225d7c98a18c6d7519af3fd4b1 Mon Sep 17 00:00:00 2001
From: JPPhoto
Date: Mon, 6 Feb 2023 06:57:41 -0600
Subject: [PATCH] In exception handlers, clear the torch CUDA cache (if we're
 using CUDA) to free up memory for other programs using the GPU and to reduce
 fragmentation.

---
 invokeai/backend/invoke_ai_web_server.py | 12 ++++++++++++
 ldm/generate.py                          | 10 +++++++++-
 ldm/invoke/generator/txt2img2img.py      |  4 ++++
 3 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/invokeai/backend/invoke_ai_web_server.py b/invokeai/backend/invoke_ai_web_server.py
index 9dd18ebe65d..c08dee596a8 100644
--- a/invokeai/backend/invoke_ai_web_server.py
+++ b/invokeai/backend/invoke_ai_web_server.py
@@ -1208,12 +1208,18 @@ def diffusers_step_callback_adapter(*cb_args, **kwargs):
             )
 
         except KeyboardInterrupt:
+            # Clear the CUDA cache on an exception
+            self.empty_cuda_cache()
             self.socketio.emit("processingCanceled")
             raise
         except CanceledException:
+            # Clear the CUDA cache on an exception
+            self.empty_cuda_cache()
             self.socketio.emit("processingCanceled")
             pass
         except Exception as e:
+            # Clear the CUDA cache on an exception
+            self.empty_cuda_cache()
             print(e)
             self.socketio.emit("error", {"message": (str(e))})
             print("\n")
@@ -1221,6 +1227,12 @@ def diffusers_step_callback_adapter(*cb_args, **kwargs):
             traceback.print_exc()
             print("\n")
 
+    def empty_cuda_cache(self):
+        if self.generate.device.type == "cuda":
+            import torch.cuda
+
+            torch.cuda.empty_cache()
+
     def parameters_to_generated_image_metadata(self, parameters):
         try:
             # top-level metadata minus `image` or `images`
diff --git a/ldm/generate.py b/ldm/generate.py
index 002ba47a97b..9101ac3f01f 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -211,7 +211,7 @@ def __init__(
                 print('>> xformers memory-efficient attention is available but disabled')
             else:
                 print('>> xformers not installed')
-    
+
         # model caching system for fast switching
         self.model_manager = ModelManager(mconfig,self.device,self.precision,max_loaded_models=max_loaded_models)
         # don't accept invalid models
@@ -565,11 +565,19 @@ def process_image(image,seed):
                                image_callback = image_callback)
 
         except KeyboardInterrupt:
+            # Clear the CUDA cache on an exception
+            if self._has_cuda():
+                torch.cuda.empty_cache()
+
             if catch_interrupts:
                 print('**Interrupted** Partial results will be returned.')
             else:
                 raise KeyboardInterrupt
         except RuntimeError:
+            # Clear the CUDA cache on an exception
+            if self._has_cuda():
+                torch.cuda.empty_cache()
+
             print(traceback.format_exc(), file=sys.stderr)
             print('>> Could not generate image.')
 
diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py
index 1c398fb95d8..0e9493aa44d 100644
--- a/ldm/invoke/generator/txt2img2img.py
+++ b/ldm/invoke/generator/txt2img2img.py
@@ -65,6 +65,10 @@ def make_image(x_T):
                 mode="bilinear"
             )
 
+            # Free up memory from the last generation.
+            if self.model.device.type == 'cuda':
+                torch.cuda.empty_cache()
+
             second_pass_noise = self.get_noise_like(resized_latents)
 
             verbosity = get_verbosity()
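
Note: for readers unfamiliar with the pattern this patch applies, below is a
minimal standalone sketch of it. The helper and driver names here are
illustrative only (not from the InvokeAI codebase); the torch.cuda calls are
the real PyTorch APIs. torch's caching allocator keeps freed blocks reserved
for reuse, so after an interrupted or failed generation,
torch.cuda.empty_cache() hands that memory back to the driver for other GPU
programs and reduces fragmentation.

    import torch

    def empty_cuda_cache(device: torch.device) -> None:
        # Return cached-but-unused allocator blocks to the CUDA driver.
        # Deliberately a no-op on CPU/MPS devices.
        if device.type == "cuda":
            torch.cuda.empty_cache()

    def run_generation(step, device: torch.device):
        # `step` stands in for the actual image-generation work.
        try:
            return step()
        except KeyboardInterrupt:
            empty_cuda_cache(device)  # clear the cache on an exception
            raise
        except RuntimeError as e:
            empty_cuda_cache(device)  # e.g. recover memory after a CUDA OOM
            print(f"Could not generate image: {e}")
            return None

    if __name__ == "__main__":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        run_generation(lambda: torch.zeros(8, device=device), device)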