Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions code/mtcnn/+mtcnn/Detector.m
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,7 @@
%
% See also: mtcnn.detectFaces

if obj.UseGPU()
im = gpuArray(single(im));
end
im = obj.prepImage(im);

bboxes = [];
scores = [];
Expand Down Expand Up @@ -103,7 +101,7 @@
end

%% Stage 2 - Refinement
[cropped, bboxes] = obj.prepImages(im, bboxes, obj.RnetSize);
[cropped, bboxes] = obj.prepBbox(im, bboxes, obj.RnetSize);
[probs, correction] = mtcnn.rnet(cropped, obj.RnetWeights);
[scores, bboxes] = obj.processOutputs(probs, correction, bboxes, 2);

Expand All @@ -112,7 +110,7 @@
end

%% Stage 3 - Output
[cropped, bboxes] = obj.prepImages(im, bboxes, obj.OnetSize);
[cropped, bboxes] = obj.prepBbox(im, bboxes, obj.OnetSize);

% Adjust bboxes for the behaviour of imcrop
bboxes(:, 1:2) = bboxes(:, 1:2) - 0.5;
Expand Down Expand Up @@ -144,12 +142,12 @@ function loadWeights(obj)
obj.OnetWeights = load(fullfile(mtcnnRoot(), "weights", "onet.mat"));
end

function [cropped, bboxes] = prepImages(obj, im, bboxes, outputSize)
function [cropped, bboxes] = prepBbox(obj, im, bboxes, outputSize)
% prepImages Pre-process the images and bounding boxes.
bboxes = mtcnn.util.makeSquare(bboxes);
bboxes = round(bboxes);
cropped = mtcnn.util.cropImage(im, bboxes, outputSize);
cropped = dlarray(single(cropped)./255*2 - 1, "SSCB");
cropped = dlarray(cropped, "SSCB");

end

Expand All @@ -167,5 +165,29 @@ function loadWeights(obj)
"OverlapThreshold", obj.NmsThresholds(netIdx));
end
end

function outIm = prepImage(obj, im)
% convert the image to the correct scaling and type
% All images should be scaled to -1 to 1 and of single type
% also place on the GPU if required

switch class(im)
case "uint8"
outIm = single(im)/255*2 - 1;
case "single"
% expect floats to be 0-1 scaled
outIm = im*2 - 1;
case "double"
outIm = single(im)*2 - 1;
otherwise
error("mtcnn:Detector:UnsupportedType", ...
"Input image is of unsupported type '%s'", class(im));
end

if obj.UseGPU()
outIm = gpuArray(outIm);
end

end
end
end
4 changes: 2 additions & 2 deletions code/mtcnn/+mtcnn/proposeRegions.m
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
% proposeRegions Generate region proposals at a given scale.
%
% Args:
% im - Input image 0-255 range
% im - Input image -1 to 1 range, type single
% scale - Scale to run proposal at
% threshold - Confidence threshold to accept proposal
% weights - P-Net weights struct
Expand All @@ -19,7 +19,7 @@
pnetSize = 12;

im = imresize(im, 1/scale);
im = dlarray(single(im)./255*2 - 1, "SSCB");
im = dlarray(im, "SSCB");

[probability, correction] = mtcnn.pnet(im, weights);

Expand Down
16 changes: 13 additions & 3 deletions test/+tests/DetectorTest.m
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
Reference
end

properties (TestParameter)
imageTypeConversion = struct("uint8", @(x) x, ...
"single", @(x) single(x)/255, ...
"double", @(x) double(x)/255)
end

methods (TestClassSetup)
function setupTestImage(test)
test.Image = imread("visionteam.jpg");
Expand All @@ -26,10 +32,14 @@ function testCreate(test)
detector = mtcnn.Detector();
end

function testDetectwithDefaults(test)
function testDetectwithDefaults(test, imageTypeConversion)
% Test expected inputs with images of type uint8, single,
% double (float images are scaled 0-1);
detector = mtcnn.Detector();

[bboxes, scores, landmarks] = detector.detect(test.Image);
inputImage = imageTypeConversion(test.Image);

[bboxes, scores, landmarks] = detector.detect(inputImage);

test.verifyEqual(size(bboxes), [6, 4]);
test.verifyEqual(size(scores), [6, 1]);
Expand Down Expand Up @@ -118,4 +128,4 @@ function testGpuDetect(test)
test.verifyEqual(landmarks, test.Reference.landmarks, "RelTol", 1e-1);
end
end
end
end
10 changes: 10 additions & 0 deletions test/makeDetectionReference.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
function makeDetectionReference()
% Run the detector in known good config to create reference boxes,
% scores and landmarks for regression tests.
im = imread("visionteam.jpg");
[bboxes, scores, landmarks] = mtcnn.detectFaces(im);

filename = fullfile(mtcnnTestRoot(), "resources", "ref.mat");
save(filename, "bboxes", "scores", "landmarks");

end
Binary file modified test/resources/ref.mat
Binary file not shown.