Skip to content

Commit

Permalink
[MRG] Unified download function for pretrained models (resolves #400) (
Browse files Browse the repository at this point in the history
…#417)

* added unified download function

* renamed constants
  • Loading branch information
droidadroit authored and haifeng-jin committed Jan 9, 2019
1 parent 392b389 commit f32d270
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 19 deletions.
12 changes: 8 additions & 4 deletions autokeras/constant.py
Expand Up @@ -63,20 +63,24 @@ class Constant:

PRE_TRAIN_DETECTION_FILE_LINK = "https://s3.amazonaws.com/amdegroot-models/ssd300_mAP_77.43_v2.pth"

# constants for pretrained model of face detection
FACE_DETECTION_PRETRAINED = {
'PRETRAINED_MODEL_LINKS': [
FACE_DETECTOR = {
'MODEL_LINKS': [
'https://raw.githubusercontent.com/kuaikuaikim/DFace/master/model_store/pnet_epoch.pt',
'https://raw.githubusercontent.com/kuaikuaikim/DFace/master/model_store/rnet_epoch.pt',
'https://raw.githubusercontent.com/kuaikuaikim/DFace/master/model_store/onet_epoch.pt'
],
'FILE_NAMES': [
'MODEL_NAMES': [
'pnet.pt',
'rnet.pt',
'onet.pt'
]
}

OBJECT_DETECTOR = {
'MODEL_LINK': 'https://s3.amazonaws.com/amdegroot-models/ssd300_mAP_77.43_v2.pth',
'MODEL_NAME': 'object_detection_pretrained.pth'
}

# Image Resize

MAX_IMAGE_SIZE = 128 * 128
Expand Down
14 changes: 5 additions & 9 deletions autokeras/pretrained/face_detector.py
Expand Up @@ -13,7 +13,7 @@

from autokeras.constant import Constant
from autokeras.pretrained.base import Pretrained
from autokeras.utils import download_file, temp_path_generator, ensure_dir, get_device
from autokeras.utils import download_model, get_device


def weights_init(m):
Expand Down Expand Up @@ -276,10 +276,8 @@ class FaceDetector(Pretrained):
def __init__(self):
super(FaceDetector, self).__init__()

self.load()
pnet, rnet, onet = self.load()
self.device = get_device()
pnet, rnet, onet = list(map(lambda file_name: f'{temp_path_generator()}/{file_name}',
Constant.FACE_DETECTION_PRETRAINED['FILE_NAMES']))

self.pnet_detector = PNet()
if torch.cuda.is_available():
Expand Down Expand Up @@ -311,11 +309,9 @@ def __init__(self):
self.scale_factor = 0.709

def load(self, model_path=None):
temp_path = temp_path_generator()
ensure_dir(temp_path)
for model_link, file_path in zip(Constant.FACE_DETECTION_PRETRAINED['PRETRAINED_MODEL_LINKS'],
Constant.FACE_DETECTION_PRETRAINED['FILE_NAMES']):
download_file(model_link, f'{temp_path}/{file_path}')
model_paths = [download_model(model_link, file_name) for model_link, file_name in zip(
Constant.FACE_DETECTOR['MODEL_LINKS'], Constant.FACE_DETECTOR['MODEL_NAMES'])]
return model_paths

def predict(self, img_path, output_file_path=None):
"""Predicts faces in an image.
Expand Down
8 changes: 2 additions & 6 deletions autokeras/pretrained/object_detector.py
Expand Up @@ -5,7 +5,7 @@
# ----------------------------------

from autokeras.pretrained.base import Pretrained
from autokeras.utils import download_file, temp_path_generator, get_device
from autokeras.utils import download_model, get_device
from autokeras.constant import Constant
import numpy as np
import cv2
Expand Down Expand Up @@ -517,12 +517,8 @@ def _build_ssd(self, phase, size=300, num_classes=21):
return SSD(phase, size, base_, extras_, head_, num_classes, self.device)

def load(self, model_path=None):
# https://s3.amazonaws.com/amdegroot-models/ssd300_mAP_77.43_v2.pth
if model_path is None:
file_link = Constant.PRE_TRAIN_DETECTION_FILE_LINK
# model_path = os.path.join(temp_path_generator(), "object_detection_pretrained.pth")
model_path = temp_path_generator() + '_object_detection_pretrained.pth'
download_file(file_link, model_path)
model_path = download_model(Constant.OBJECT_DETECTOR['MODEL_LINK'], Constant.OBJECT_DETECTOR['MODEL_NAME'])
# load net
num_classes = len(VOC_CLASSES) + 1 # +1 for background
self.model = self._build_ssd('test', 300, num_classes) # initialize SSD
Expand Down
8 changes: 8 additions & 0 deletions autokeras/utils.py
Expand Up @@ -130,6 +130,14 @@ def download_file(file_link, file_path):
sys.stdout.flush()


def download_model(model_link, model_file_name):
temp_path = temp_path_generator()
ensure_dir(temp_path)
model_path = f'{temp_path}/{model_file_name}'
download_file(model_link, model_path)
return model_path


def download_file_with_extract(file_link, file_path, extract_path):
"""Download the file specified in `file_link`, save to `file_path` and extract to the directory `extract_path`."""
if not os.path.exists(extract_path):
Expand Down

0 comments on commit f32d270

Please sign in to comment.