Skip to content

Commit

Permalink
Load weights from hugging face hub
Browse files Browse the repository at this point in the history
  • Loading branch information
xianbaoqian committed Jul 16, 2022
1 parent 5657be1 commit c1bd87b
Showing 1 changed file with 19 additions and 10 deletions.
29 changes: 19 additions & 10 deletions mmocr/utils/ocr.py
Expand Up @@ -14,6 +14,8 @@
from mmcv.utils.config import Config
from PIL import Image

from huggingface_hub import hf_hub_url

try:
import tesserocr
except ImportError:
Expand Down Expand Up @@ -258,9 +260,9 @@ def __init__(self,
'PANet_IC15': {
'config':
'panet/panet_r18_fpem_ffm_600e_icdar2015.py',
'ckpt':
'panet/'
'panet_r18_fpem_ffm_sbn_600e_icdar2015_20210219-42dbe46a.pth'
'ckpt': hf_hub_url(
repo_id="xianbao/mmocr",
filename="panet_r18_fpem_ffm_sbn_600e_icdar2015_20210219-42dbe46a.pth")
},
'PS_CTW': {
'config': 'psenet/psenet_r50_fpnf_600e_ctw1500.py',
Expand Down Expand Up @@ -329,7 +331,8 @@ def __init__(self,
},
'SEG': {
'config': 'seg/seg_r31_1by16_fpnocr_academic.py',
'ckpt': 'seg/seg_r31_1by16_fpnocr_academic-72235b11.pth'
'ckpt': hf_hub_url(repo_id="xianbao/mmocr",
filename="seg_r31_1by16_fpnocr_academic-72235b11.pth")
},
'CRNN_TPS': {
'config': 'tps/crnn_tps_academic_dataset.py',
Expand Down Expand Up @@ -388,8 +391,10 @@ def __init__(self,
det_config = os.path.join(config_dir, 'textdet/',
textdet_models[self.td]['config'])
if not det_ckpt:
det_ckpt = 'https://download.openmmlab.com/mmocr/textdet/' + \
textdet_models[self.td]['ckpt']
det_ckpt = textdet_models[self.td]['ckpt']
if not det_ckpt.startswith('https://'):
det_ckpt = 'https://download.openmmlab.com/mmocr/textdet/' + \
det_ckpt

self.detect_model = init_detector(
det_config, det_ckpt, device=self.device)
Expand All @@ -409,8 +414,10 @@ def __init__(self,
config_dir, 'textrecog/',
textrecog_models[self.tr]['config'])
if not recog_ckpt:
recog_ckpt = 'https://download.openmmlab.com/mmocr/' + \
'textrecog/' + textrecog_models[self.tr]['ckpt']
recog_ckpt = textrecog_models[self.tr]['ckpt']
if not recog_ckpt.startswith('https://'):
recog_ckpt = 'https://download.openmmlab.com/mmocr/' + \
'textrecog/' + recog_ckpt

self.recog_model = init_detector(
recog_config, recog_ckpt, device=self.device)
Expand All @@ -423,8 +430,10 @@ def __init__(self,
kie_config = os.path.join(config_dir, 'kie/',
kie_models[self.kie]['config'])
if not kie_ckpt:
kie_ckpt = 'https://download.openmmlab.com/mmocr/' + \
'kie/' + kie_models[self.kie]['ckpt']
kie_ckpt = kie_models[self.kie]['ckpt']
if not kie_ckpt.startswith('https://'):
kie_ckpt = 'https://download.openmmlab.com/mmocr/' + \
'kie/' + kie_ckpt

kie_cfg = Config.fromfile(kie_config)
self.kie_model = build_detector(
Expand Down

0 comments on commit c1bd87b

Please sign in to comment.