From 0e98056f7258474b6ed78f269e8e5474b7f6b4c3 Mon Sep 17 00:00:00 2001 From: fcakyon <34196005+fcakyon@users.noreply.github.com> Date: Thu, 14 May 2020 22:21:27 +0300 Subject: [PATCH] fix empty_cuda_cache (#13) --- README.md | 40 ++++++++++++++++++--------------- craft_text_detector/__init__.py | 9 +++++--- 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index 164cf96..d115a21 100755 --- a/README.md +++ b/README.md @@ -60,6 +60,7 @@ from craft_text_detector import ( get_prediction, export_detected_regions, export_extra_results, + empty_cuda_cache ) # set image path and export folder directory @@ -75,31 +76,34 @@ craft_net = load_craftnet_model(cuda=True) # perform prediction prediction_result = get_prediction( - image=image, - craft_net=craft_net, - refine_net=refine_net, - text_threshold=0.7, - link_threshold=0.4, - low_text=0.4, - cuda=True, - long_size=1280 + image=image, + craft_net=craft_net, + refine_net=refine_net, + text_threshold=0.7, + link_threshold=0.4, + low_text=0.4, + cuda=True, + long_size=1280 ) # export detected text regions exported_file_paths = export_detected_regions( - image_path=image_path, - image=image, - regions=prediction_result["boxes"], - output_dir=output_dir, - rectify=True + image_path=image_path, + image=image, + regions=prediction_result["boxes"], + output_dir=output_dir, + rectify=True ) # export heatmap, detection points, box visualization export_extra_results( - image_path=image_path, - image=image, - regions=prediction_result["boxes"], - heatmaps=prediction_result["heatmaps"], - output_dir=output_dir + image_path=image_path, + image=image, + regions=prediction_result["boxes"], + heatmaps=prediction_result["heatmaps"], + output_dir=output_dir ) + +# unload models from gpu +empty_cuda_cache() ``` diff --git a/craft_text_detector/__init__.py b/craft_text_detector/__init__.py index 5792e60..cfdcb91 100644 --- a/craft_text_detector/__init__.py +++ b/craft_text_detector/__init__.py @@ -4,8 +4,9 @@ import craft_text_detector.file_utils as file_utils import craft_text_detector.image_utils as image_utils import craft_text_detector.predict as predict +import craft_text_detector.torch_utils as torch_utils -__version__ = "0.3.0" +__version__ = "0.3.1" __all__ = [ @@ -15,6 +16,7 @@ "get_prediction", "export_detected_regions", "export_extra_results", + "empty_cuda_cache", "Craft", ] @@ -24,6 +26,7 @@ get_prediction = predict.get_prediction export_detected_regions = file_utils.export_detected_regions export_extra_results = file_utils.export_extra_results +empty_cuda_cache = torch_utils.empty_cuda_cache class Craft: @@ -90,14 +93,14 @@ def unload_craftnet_model(self): Unloads craftnet model """ self.craft_net = None - craft_utils.empty_cuda_cache() + empty_cuda_cache() def unload_refinenet_model(self): """ Unloads refinenet model """ self.refine_net = None - craft_utils.empty_cuda_cache() + empty_cuda_cache() def detect_text(self, image_path): """