Skip to content
Browse files

[MRG] Object Detection (#396)

* object detection v0.1 added

* add import object_detection in under autokeras

* bug fixed

* add support for customized VOC-like dataset

* add load() and predict() in

* update .gitignore

* add download_file support & bug fixed

* update load() and predict(); add pretrained ABC

* update data utils

* update predict

* update predict()

* add test cases

* update get_device and remove code not used

* update load()

* remove redundant code

* test object detection

* pretrain

* bug fixed in load() when no cuda device available

* comment cudnn & update object detection tests

* bug fix

* bug fixed (#366)

* change

* refactor

* fixed unittests import
  • Loading branch information...
jhfjhfj1 committed Dec 26, 2018
1 parent b9ac220 commit 0076797c5f4730ddd96cab9083bab345791af786
@@ -1,3 +1,8 @@
# vim swp files
# caffe/pytorch model files

# Mkdocs
@@ -1,4 +1,8 @@
from autokeras.image.image_supervised import ImageClassifier, ImageRegressor, PortableImageSupervised
from autokeras.text.text_supervised import TextClassifier, TextRegressor
from autokeras.tabular.tabular_supervised import TabularClassifier, TabularRegressor
from autokeras.net_module import CnnModule, MlpModule

from autokeras.net_module import CnnModule, MlpModule

from autokeras.pretrained.object_detector import ObjectDetector
from autokeras.pretrained.face_detector import FaceDetector
@@ -63,6 +63,8 @@ class Constant:
PRE_TRAIN_FILE_NAME = "glove.6B.100d.txt"


# constants for pretrained model of face detection

This file was deleted.

Oops, something went wrong.
No changes.
@@ -0,0 +1,31 @@
from abc import ABC, abstractmethod

class Pretrained(ABC):
    """The base class for all pretrained tasks.

    Subclasses wrap a downloadable pretrained model and expose a uniform
    load()/predict() interface.

    Attributes:
        model: The underlying pretrained model object; None until load() runs.
    """

    def __init__(self):
        """Initialize the instance with no model loaded yet."""
        self.model = None

    @abstractmethod
    def load(self):
        """Load the pretrained model into self.model.

        Concrete subclasses download weight files (if needed) and build the
        model here.
        """
        pass

    @abstractmethod
    def predict(self, x_predict):
        """Return prediction results for the given input.

        Args:
            x_predict: An instance of numpy.ndarray containing the testing data.

        Returns:
            A numpy.ndarray containing the results.
        """
        pass

@@ -1,5 +1,7 @@
# This is DFace's implementation of MTCNN modified for AutoKeras
# Link to DFace:
import os

import cv2
import torch
import torch.nn as nn
@@ -10,6 +12,10 @@
import matplotlib.pyplot as plt
import matplotlib.patches as patches

from autokeras.constant import Constant
from autokeras.pretrained.base import Pretrained
from autokeras.utils import download_file

def weights_init(m):
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
@@ -599,3 +605,35 @@ def detect_faces(pnet_path, rnet_path, onet_path, img_path, output_file_path):
if output_file_path is not None:
vis_face(img_bg, bboxs, output_file_path, landmarks)
return bboxs, landmarks

class FaceDetector(Pretrained):
    """A class to predict faces using the MTCNN pre-trained model."""

    def __init__(self):
        super(FaceDetector, self).__init__()

    def load(self, model_path=None):
        """Download (if missing) the pretrained MTCNN weight files and record their paths.

        Args:
            model_path: Unused; kept for interface compatibility. NOTE(review):
                presumably intended as an override location for the weight files —
                confirm against callers.
        """
        # Fetch each of the three network weight files (PNet, RNet, ONet).
        for model_link, file_path in zip(Constant.FACE_DETECTION_PRETRAINED['PRETRAINED_MODEL_LINKS'],
                                         Constant.FACE_DETECTION_PRETRAINED['FILE_PATHS']):
            download_file(model_link, file_path)
        self.pnet, self.rnet, self.onet = Constant.FACE_DETECTION_PRETRAINED['FILE_PATHS']

    def predict(self, img_path, output_file_path=None):
        """Predicts faces in an image.

        Args:
            img_path: A string. The path to the image on which the prediction is to be done.
            output_file_path: A string. The path where the output image is to be saved after
                the prediction. `None` by default.

        Returns:
            A tuple containing numpy arrays of bounding boxes and landmarks. Bounding boxes
            are of shape `(n, 5)` and landmarks are of shape `(n, 10)` where `n` is the
            number of faces predicted. Each bounding box is defined by its first four values;
            each face has five landmarks represented by 10 coordinates.

        Raises:
            ValueError: If `img_path` does not point to an existing file.
        """
        if not os.path.exists(img_path):
            raise ValueError('Image does not exist')
        return detect_faces(self.pnet, self.rnet, self.onet, img_path, output_file_path)
Oops, something went wrong.

0 comments on commit 0076797

Please sign in to comment.
You can’t perform that action at this time.