diff --git a/code/mtcnn/+mtcnn/Detector.m b/code/mtcnn/+mtcnn/Detector.m index 2cc7df9..89f36fa 100644 --- a/code/mtcnn/+mtcnn/Detector.m +++ b/code/mtcnn/+mtcnn/Detector.m @@ -70,9 +70,7 @@ % % See also: mtcnn.detectFaces - if obj.UseGPU() - im = gpuArray(single(im)); - end + im = obj.prepImage(im); bboxes = []; scores = []; @@ -103,7 +101,7 @@ end %% Stage 2 - Refinement - [cropped, bboxes] = obj.prepImages(im, bboxes, obj.RnetSize); + [cropped, bboxes] = obj.prepBbox(im, bboxes, obj.RnetSize); [probs, correction] = mtcnn.rnet(cropped, obj.RnetWeights); [scores, bboxes] = obj.processOutputs(probs, correction, bboxes, 2); @@ -112,7 +110,7 @@ end %% Stage 3 - Output - [cropped, bboxes] = obj.prepImages(im, bboxes, obj.OnetSize); + [cropped, bboxes] = obj.prepBbox(im, bboxes, obj.OnetSize); % Adjust bboxes for the behaviour of imcrop bboxes(:, 1:2) = bboxes(:, 1:2) - 0.5; @@ -144,12 +142,12 @@ function loadWeights(obj) obj.OnetWeights = load(fullfile(mtcnnRoot(), "weights", "onet.mat")); end - function [cropped, bboxes] = prepImages(obj, im, bboxes, outputSize) + function [cropped, bboxes] = prepBbox(obj, im, bboxes, outputSize) % prepImages Pre-process the images and bounding boxes. bboxes = mtcnn.util.makeSquare(bboxes); bboxes = round(bboxes); cropped = mtcnn.util.cropImage(im, bboxes, outputSize); - cropped = dlarray(single(cropped)./255*2 - 1, "SSCB"); + cropped = dlarray(cropped, "SSCB"); end @@ -167,5 +165,29 @@ function loadWeights(obj) "OverlapThreshold", obj.NmsThresholds(netIdx)); end end + + function outIm = prepImage(obj, im) + % convert the image to the correct scaling and type + % All images should be scaled to -1 to 1 and of single type + % also place on the GPU if required + + switch class(im) + case "uint8" + outIm = single(im)/255*2 - 1; + case "single" + % expect floats to be 0-1 scaled + outIm = im*2 - 1; + case "double" + outIm = single(im)*2 - 1; + otherwise + error("mtcnn:Detector:UnsupportedType", ... + "Input image is of unsupported type '%s'", class(im)); + end + + if obj.UseGPU() + outIm = gpuArray(outIm); + end + + end end end \ No newline at end of file diff --git a/code/mtcnn/+mtcnn/proposeRegions.m b/code/mtcnn/+mtcnn/proposeRegions.m index e2ff283..0f9c7bd 100644 --- a/code/mtcnn/+mtcnn/proposeRegions.m +++ b/code/mtcnn/+mtcnn/proposeRegions.m @@ -2,7 +2,7 @@ % proposeRegions Generate region proposals at a given scale. % % Args: -% im - Input image 0-255 range +% im - Input image -1 to 1 range, type single % scale - Scale to run proposal at % threshold - Confidence threshold to accept proposal % weights - P-Net weights struct @@ -19,7 +19,7 @@ pnetSize = 12; im = imresize(im, 1/scale); - im = dlarray(single(im)./255*2 - 1, "SSCB"); + im = dlarray(im, "SSCB"); [probability, correction] = mtcnn.pnet(im, weights); diff --git a/test/+tests/DetectorTest.m b/test/+tests/DetectorTest.m index 7dffdb0..462e59a 100644 --- a/test/+tests/DetectorTest.m +++ b/test/+tests/DetectorTest.m @@ -8,6 +8,12 @@ Reference end + properties (TestParameter) + imageTypeConversion = struct("uint8", @(x) x, ... + "single", @(x) single(x)/255, ... + "double", @(x) double(x)/255) + end + methods (TestClassSetup) function setupTestImage(test) test.Image = imread("visionteam.jpg"); @@ -26,10 +32,14 @@ function testCreate(test) detector = mtcnn.Detector(); end - function testDetectwithDefaults(test) + function testDetectwithDefaults(test, imageTypeConversion) + % Test expected inputs with images of type uint8, single, + % double (float images are scaled 0-1); detector = mtcnn.Detector(); - [bboxes, scores, landmarks] = detector.detect(test.Image); + inputImage = imageTypeConversion(test.Image); + + [bboxes, scores, landmarks] = detector.detect(inputImage); test.verifyEqual(size(bboxes), [6, 4]); test.verifyEqual(size(scores), [6, 1]); @@ -118,4 +128,4 @@ function testGpuDetect(test) test.verifyEqual(landmarks, test.Reference.landmarks, "RelTol", 1e-1); end end -end \ No newline at end of file + end diff --git a/test/makeDetectionReference.m b/test/makeDetectionReference.m new file mode 100644 index 0000000..0f96a2d --- /dev/null +++ b/test/makeDetectionReference.m @@ -0,0 +1,10 @@ +function makeDetectionReference() + % Run the detector in known good config to create reference boxes, + % scores and landmarks for regression tests. + im = imread("visionteam.jpg"); + [bboxes, scores, landmarks] = mtcnn.detectFaces(im); + + filename = fullfile(mtcnnTestRoot(), "resources", "ref.mat"); + save(filename, "bboxes", "scores", "landmarks"); + +end \ No newline at end of file diff --git a/test/resources/ref.mat b/test/resources/ref.mat index 585811a..11c595b 100644 Binary files a/test/resources/ref.mat and b/test/resources/ref.mat differ