Skip to content

Commit

Permalink
Merge pull request #18 from matlab-deep-learning/bugfix/selectStronge…
Browse files Browse the repository at this point in the history
…stOrder/17

Fixes issue with ordering of boxes and landmarks
  • Loading branch information
justinpinkney committed Sep 14, 2020
2 parents 3240147 + 423a49f commit 3ec4d03
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 6 deletions.
13 changes: 8 additions & 5 deletions code/mtcnn/+mtcnn/Detector.m
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@
landmarks = cat(3, x, y);
landmarks(probs(:, 2) < obj.ConfidenceThresholds(3), :, :) = [];

[scores, bboxes] = obj.processOutputs(probs, correction, bboxes, 3);
[scores, bboxes, landmarks] = obj.processOutputs(probs, correction, bboxes, 3, landmarks);

% Gather and cast the outputs
bboxes= gather(double(bboxes));
Expand All @@ -151,17 +151,20 @@
cropped = mtcnn.util.cropImage(im, bboxes, outputSize);
end

function [scores, bboxes] = ...
processOutputs(obj, probs, correction, bboxes, netIdx)
function [scores, bboxes, landmarks] = ...
processOutputs(obj, probs, correction, bboxes, netIdx, landmarks)
% processOutputs Post-process the output values.
faceProbs = probs(:, 2);
bboxes = mtcnn.util.applyCorrection(bboxes, correction);
bboxes(faceProbs < obj.ConfidenceThresholds(netIdx), :) = [];
scores = faceProbs(faceProbs > obj.ConfidenceThresholds(netIdx));
scores = faceProbs(faceProbs >= obj.ConfidenceThresholds(netIdx));
if ~isempty(scores)
[bboxes, ~] = selectStrongestBbox(gather(bboxes), scores, ...
[bboxes, scores, index] = selectStrongestBbox(gather(bboxes), scores, ...
"RatioType", "Min", ...
"OverlapThreshold", obj.NmsThresholds(netIdx));
if netIdx == 3
landmarks = landmarks(index, :, :);
end
end
end

Expand Down
2 changes: 1 addition & 1 deletion test/+tests/DetectorTest.m
Original file line number Diff line number Diff line change
Expand Up @@ -131,4 +131,4 @@ function testGpuDetect(test, imageTypeConversion, useDagNet)
test.verifyEqual(landmarks, test.Reference.landmarks, "RelTol", 1e-1);
end
end
end
end
29 changes: 29 additions & 0 deletions test/+tests/RegressionTest.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
classdef RegressionTest < matlab.unittest.TestCase
% Test cases for known bugs that have been fixed

% Copyright 2020 The MathWorks, Inc.

methods (Test)
function testSelectStrongestBug(test)
% GitHub issue #17
im = imread("visionteam1.jpg");

[bboxes, scores, landmarks] = mtcnn.detectFaces(im, "ConfidenceThresholds", repmat(0.01, [3, 1]));
for iBox = 1:size(bboxes, 1)
test.assertInBox(landmarks(iBox, :, :), bboxes(iBox, :));
end
end
end

methods
function assertInBox(test, landmarks, box)
% check that all landmarks are within the bounding box
tf = all(inpolygon(landmarks(1, :, 1), ...
landmarks(1, :, 2), ...
[box(1), box(1) + box(3), box(1) + box(3), box(1)], ...
[box(2), box(2), box(2) + box(4), box(2) + box(4)]));
test.assertTrue(tf, "Landmarks should all be inside bounding box");
end
end

end

0 comments on commit 3ec4d03

Please sign in to comment.