Skip to content
This repository has been archived by the owner on Dec 2, 2022. It is now read-only.

Commit

Permalink
fix empty_cuda_cache (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
fcakyon committed May 14, 2020
1 parent 1930e6a commit 0e98056
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 21 deletions.
40 changes: 22 additions & 18 deletions README.md
Expand Up @@ -60,6 +60,7 @@ from craft_text_detector import (
get_prediction,
export_detected_regions,
export_extra_results,
empty_cuda_cache
)

# set image path and export folder directory
Expand All @@ -75,31 +76,34 @@ craft_net = load_craftnet_model(cuda=True)

# perform prediction
prediction_result = get_prediction(
image=image,
craft_net=craft_net,
refine_net=refine_net,
text_threshold=0.7,
link_threshold=0.4,
low_text=0.4,
cuda=True,
long_size=1280
image=image,
craft_net=craft_net,
refine_net=refine_net,
text_threshold=0.7,
link_threshold=0.4,
low_text=0.4,
cuda=True,
long_size=1280
)

# export detected text regions
exported_file_paths = export_detected_regions(
image_path=image_path,
image=image,
regions=prediction_result["boxes"],
output_dir=output_dir,
rectify=True
image_path=image_path,
image=image,
regions=prediction_result["boxes"],
output_dir=output_dir,
rectify=True
)

# export heatmap, detection points, box visualization
export_extra_results(
image_path=image_path,
image=image,
regions=prediction_result["boxes"],
heatmaps=prediction_result["heatmaps"],
output_dir=output_dir
image_path=image_path,
image=image,
regions=prediction_result["boxes"],
heatmaps=prediction_result["heatmaps"],
output_dir=output_dir
)

# unload models from gpu
empty_cuda_cache()
```
9 changes: 6 additions & 3 deletions craft_text_detector/__init__.py
Expand Up @@ -4,8 +4,9 @@
import craft_text_detector.file_utils as file_utils
import craft_text_detector.image_utils as image_utils
import craft_text_detector.predict as predict
import craft_text_detector.torch_utils as torch_utils

__version__ = "0.3.0"
__version__ = "0.3.1"


__all__ = [
Expand All @@ -15,6 +16,7 @@
"get_prediction",
"export_detected_regions",
"export_extra_results",
"empty_cuda_cache",
"Craft",
]

Expand All @@ -24,6 +26,7 @@
get_prediction = predict.get_prediction
export_detected_regions = file_utils.export_detected_regions
export_extra_results = file_utils.export_extra_results
empty_cuda_cache = torch_utils.empty_cuda_cache


class Craft:
Expand Down Expand Up @@ -90,14 +93,14 @@ def unload_craftnet_model(self):
Unloads craftnet model
"""
self.craft_net = None
craft_utils.empty_cuda_cache()
empty_cuda_cache()

def unload_refinenet_model(self):
"""
Unloads refinenet model
"""
self.refine_net = None
craft_utils.empty_cuda_cache()
empty_cuda_cache()

def detect_text(self, image_path):
"""
Expand Down

0 comments on commit 0e98056

Please sign in to comment.