Skip to content
This repository has been archived by the owner on Dec 2, 2022. It is now read-only.

Commit

Permalink
Enable package to load model from local path (#53)
Browse files Browse the repository at this point in the history
* Use headless version of opencv

* Provide possibility to load net from local path

* Remove headless again for merge to official repo

Co-authored-by: Tanja Bayer <tanja.bayer@widas.de>
  • Loading branch information
TanjaBayer and Tanja Bayer committed May 9, 2022
1 parent c856736 commit 725af71
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
result*
weights*
.vscode
.pypirc

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
15 changes: 9 additions & 6 deletions craft_text_detector/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import absolute_import

import os
from typing import Optional

import craft_text_detector.craft_utils as craft_utils
import craft_text_detector.file_utils as file_utils
Expand Down Expand Up @@ -44,6 +45,8 @@ def __init__(
long_size=1280,
refiner=True,
crop_type="poly",
weight_path_craft_net: Optional[str] = None,
weight_path_refine_net: Optional[str] = None,
):
"""
Arguments:
Expand Down Expand Up @@ -72,22 +75,22 @@ def __init__(
self.crop_type = crop_type

# load craftnet
self.load_craftnet_model()
self.load_craftnet_model(weight_path_craft_net)
# load refinernet if required
if refiner:
self.load_refinenet_model()
self.load_refinenet_model(weight_path_refine_net)

def load_craftnet_model(self):
def load_craftnet_model(self, weight_path: Optional[str] = None):
"""
Loads craftnet model
"""
self.craft_net = load_craftnet_model(self.cuda)
self.craft_net = load_craftnet_model(self.cuda, weight_path=weight_path)

def load_refinenet_model(self):
def load_refinenet_model(self, weight_path: Optional[str] = None):
"""
Loads refinenet model
"""
self.refine_net = load_refinenet_model(self.cuda)
self.refine_net = load_refinenet_model(self.cuda, weight_path=weight_path)

def unload_craftnet_model(self):
"""
Expand Down

0 comments on commit 725af71

Please sign in to comment.