Skip to content

Commit

Permalink
Merge pull request #5 from hukkelas/max_resolution_resize
Browse files Browse the repository at this point in the history
Max resolution resize
  • Loading branch information
hukkelas committed Apr 5, 2020
2 parents 22aa803 + 17b0838 commit 2528608
Show file tree
Hide file tree
Showing 10 changed files with 225 additions and 345 deletions.
103 changes: 100 additions & 3 deletions face_detection/base.py
@@ -1,23 +1,120 @@
import numpy as np
from abc import ABC
import torch
import typing
from abc import ABC, abstractmethod
from torchvision.ops import nms
from .box_utils import scale_boxes


class Detector(ABC):

def __init__(
self,
confidence_threshold: float,
nms_iou_threshold: float):
nms_iou_threshold: float,
device: torch.device,
max_resolution: int):
"""
Args:
confidence_threshold (float): Threshold to filter out bounding boxes
nms_iou_threshold (float): Intersection over union threshold for non-maxima threshold
device ([type], optional): Defaults to cuda if cuda capable device is available.
max_resolution (int, optional): Max image resolution to do inference to.
"""
self.confidence_threshold = confidence_threshold
self.nms_iou_threshold = nms_iou_threshold
self.device = device
self.max_resolution = max_resolution
self.mean = np.array(
[123, 117, 104], dtype=np.float32).reshape(1, 1, 1, 3)

def detect(
self, image: np.ndarray) -> np.ndarray:
self, image: np.ndarray, shrink=1.0) -> np.ndarray:
"""Takes an RGB image and performs and returns a set of bounding boxes as
detections
Args:
image (np.ndarray): shape [height, width, 3]
Returns:
np.ndarray: shape [N, 5] with (xmin, ymin, xmax, ymax, score)
"""
image = image[None]
boxes = self.batched_detect(image, shrink)
return boxes[0]

@abstractmethod
def _detect(self, image: torch.Tensor) -> torch.Tensor:
"""Takes N RGB image and performs and returns a set of bounding boxes as
detections
Args:
image (torch.Tensor): shape [N, 3, height, width]
Returns:
torch.Tensor: of shape [N, B, 5] with (xmin, ymin, xmax, ymax, score)
"""
raise NotImplementedError

def filter_boxes(self, boxes: torch.Tensor) -> typing.List[np.ndarray]:
"""Performs NMS and score thresholding
Args:
boxes (torch.Tensor): shape [N, B, 5] with (xmin, ymin, xmax, ymax, score)
Returns:
list: N np.ndarray of shape [B, 5]
"""
final_output = []
for i in range(len(boxes)):
scores = boxes[i, :, 4]
keep_idx = scores >= self.confidence_threshold
boxes_ = boxes[i, keep_idx, :-1]
scores = scores[keep_idx]
if scores.dim() == 0:
final_output.append(torch.empty(0, 5))
continue
keep_idx = nms(boxes_, scores, self.nms_iou_threshold)
scores = scores[keep_idx].view(-1, 1)
boxes_ = boxes_[keep_idx].view(-1, 4)
output = torch.cat((boxes_, scores), dim=-1)
final_output.append(output)
return final_output

def _pre_process(self, image: np.ndarray, shrink: float) -> torch.Tensor:
"""Takes N RGB image and performs and returns a set of bounding boxes as
detections
Args:
image (np.ndarray): shape [N, height, width, 3]
Returns:
torch.Tensor: shape [N, 3, height, width]
"""
assert image.dtype == np.uint8
height, width = image.shape[1:3]
image = image.astype(np.float32) - self.mean
image = np.moveaxis(image, -1, 1)
image = torch.from_numpy(image)
if self.max_resolution is not None:
shrink_factor = self.max_resolution / max((height, width))
if shrink_factor <= shrink:
shrink = shrink_factor
image = torch.nn.functional.interpolate(image, scale_factor=shrink)
image = image.to(self.device)
return image

def _batched_detect(self, image: np.ndarray) -> typing.List[np.ndarray]:
boxes = self._detect(image)
boxes = self.filter_boxes(boxes)
return boxes

@torch.no_grad()
def batched_detect(
self, image: np.ndarray, shrink=1.0) -> typing.List[np.ndarray]:
"""Takes N RGB image and performs and returns a set of bounding boxes as
detections
Args:
image (np.ndarray): shape [N, height, width, 3]
Returns:
np.ndarray: a list with N set of bounding boxes of
shape [B, 5] with (xmin, ymin, xmax, ymax, score)
"""
height, width = image.shape[1:3]
image = self._pre_process(image, shrink)
boxes = self._batched_detect(image)
boxes = [scale_boxes((height, width), box).cpu().numpy() for box in boxes]
return boxes
24 changes: 15 additions & 9 deletions face_detection/box_utils.py
@@ -1,23 +1,29 @@
import torch


# Adapted from https://github.com/Hakuyume/chainer-ssd
def decode(loc, priors, variances):
def batched_decode(loc, priors, variances):
"""Decode locations from predictions using priors to undo
the encoding we did for offset regression at train time.
Args:
loc (tensor): location predictions for loc layers,
Shape: [num_priors,4]
priors (tensor): Prior boxes in center-offset form.
Shape: [num_priors,4].
Shape: [N, num_priors,4].
variances: (list[float]) Variances of priorboxes
Return:
decoded bounding box predictions
"""

priors = priors[None]
boxes = torch.cat((
priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
boxes[:, :2] -= boxes[:, 2:] / 2
boxes[:, 2:] += boxes[:, :2]
priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:],
priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1])),
dim=2)
boxes[:, :, :2] -= boxes[:, :, 2:] / 2
boxes[:, :, 2:] += boxes[:, :, :2]
return boxes


def scale_boxes(imshape, boxes):
height, width = imshape
boxes[:, [0, 2]] *= width
boxes[:, [1, 3]] *= height
return boxes
10 changes: 8 additions & 2 deletions face_detection/build.py
@@ -1,5 +1,6 @@
from .registry import build_from_cfg, Registry
from .base import Detector
from .torch_utils import get_device

available_detectors = [
"DSFDDetector",
Expand All @@ -12,14 +13,19 @@
def build_detector(
name: str = "DSFDDetector",
confidence_threshold: float = 0.5,
nms_iou_threshold: float = 0.3) -> Detector:
nms_iou_threshold: float = 0.3,
device=get_device(),
max_resolution: int = None
) -> Detector:
assert name in available_detectors,\
f"Detector not available. Chooce one of the following"+\
",".join(available_detectors)
args = dict(
type=name,
confidence_threshold=confidence_threshold,
nms_iou_threshold=nms_iou_threshold
nms_iou_threshold=nms_iou_threshold,
device=device,
max_resolution=max_resolution
)
detector = build_from_cfg(args, DETECTOR_REGISTRY)
return detector

0 comments on commit 2528608

Please sign in to comment.