diff --git a/code/mtcnn/+mtcnn/+util/DagNetworkStrategy.m b/code/mtcnn/+mtcnn/+util/DagNetworkStrategy.m index ec7d049..e6cec7c 100644 --- a/code/mtcnn/+mtcnn/+util/DagNetworkStrategy.m +++ b/code/mtcnn/+mtcnn/+util/DagNetworkStrategy.m @@ -5,10 +5,16 @@ Pnet Rnet Onet + ExecutionEnvironment end methods - function obj = DagNetworkStrategy() + function obj = DagNetworkStrategy(useGpu) + if useGpu + obj.ExecutionEnvironment = "gpu"; + else + obj.ExecutionEnvironment = "cpu"; + end end function load(obj) @@ -18,19 +24,24 @@ function load(obj) obj.Onet = importdata(fullfile(mtcnnRoot(), "weights", "dagONet.mat")); end - function pnet = getPNet(obj) - pnet = obj.Pnet; + function [probability, correction] = applyPNet(obj, im) + % need to use activations as we don't know what size it will be + result = obj.Pnet.activations(im, "concat", ... + "ExecutionEnvironment", obj.ExecutionEnvironment); + + probability = result(:,:,1:2,:); + correction = result(:,:,3:end,:); end function [probs, correction] = applyRNet(obj, im) - output = obj.Rnet.predict(im); + output = obj.Rnet.predict(im, "ExecutionEnvironment", obj.ExecutionEnvironment); probs = output(:,1:2); correction = output(:,3:end); end function [probs, correction, landmarks] = applyONet(obj, im) - output = obj.Onet.predict(im); + output = obj.Onet.predict(im, "ExecutionEnvironment", obj.ExecutionEnvironment); probs = output(:,1:2); correction = output(:,3:6); diff --git a/code/mtcnn/+mtcnn/+util/DlNetworkStrategy.m b/code/mtcnn/+mtcnn/+util/DlNetworkStrategy.m index b5d3a97..8e375b4 100644 --- a/code/mtcnn/+mtcnn/+util/DlNetworkStrategy.m +++ b/code/mtcnn/+mtcnn/+util/DlNetworkStrategy.m @@ -26,8 +26,13 @@ function load(obj) end end - function pnet = getPNet(obj) - pnet = obj.PnetWeights; + function [probability, correction] = applyPNet(obj, im) + im = dlarray(im, "SSCB"); + + [probability, correction] = mtcnn.pnet(im, obj.PnetWeights); + + probability = extractdata(gather(probability)); + correction = extractdata(gather(correction)); end function [probs, correction] = applyRNet(obj, im) diff --git a/code/mtcnn/+mtcnn/Detector.m b/code/mtcnn/+mtcnn/Detector.m index 61f56a0..0510d85 100644 --- a/code/mtcnn/+mtcnn/Detector.m +++ b/code/mtcnn/+mtcnn/Detector.m @@ -49,7 +49,7 @@ end if obj.UseDagNet - obj.Networks = mtcnn.util.DagNetworkStrategy(); + obj.Networks = mtcnn.util.DagNetworkStrategy(obj.UseGPU); else obj.Networks = mtcnn.util.DlNetworkStrategy(obj.UseGPU); end @@ -89,7 +89,7 @@ [thisBox, thisScore] = ... mtcnn.proposeRegions(im, scale, ... obj.ConfidenceThresholds(1), ... - obj.Networks.getPNet()); + obj.Networks); bboxes = cat(1, bboxes, thisBox); scores = cat(1, scores, thisScore); end @@ -183,7 +183,7 @@ "Input image is of unsupported type '%s'", class(im)); end - if obj.UseGPU() + if obj.UseGPU && ~obj.UseDagNet outIm = gpuArray(outIm); end diff --git a/code/mtcnn/+mtcnn/proposeRegions.m b/code/mtcnn/+mtcnn/proposeRegions.m index 24daa85..592024d 100644 --- a/code/mtcnn/+mtcnn/proposeRegions.m +++ b/code/mtcnn/+mtcnn/proposeRegions.m @@ -1,4 +1,4 @@ -function [bboxes, scores] = proposeRegions(im, scale, threshold, weightsOrNet) +function [bboxes, scores] = proposeRegions(im, scale, threshold, networkStrategy) % proposeRegions Generate region proposals at a given scale. % % Args: @@ -13,14 +13,6 @@ % Copyright 2019 The MathWorks, Inc. - useDagNet = isa(weightsOrNet, "DAGNetwork"); - if isa(im, "gpuArray") - imClass = classUnderlying(im); - else - imClass = class(im); - end - assert(imClass == "single", "mtcnn:proposeRegions:wrongImageType", ... - "Input image should be a single scale -1 to 1"); % Stride of the proposal network stride = 2; @@ -28,19 +20,8 @@ pnetSize = 12; im = imresize(im, 1/scale); - - if useDagNet - % need to use activations as we don't know what size it will be - result = weightsOrNet.activations(im, "concat"); - probability = gather(result(:,:,1:2,:)); - correction = gather(result(:,:,3:end,:)); - else - im = dlarray(im, "SSCB"); - [probability, correction] = mtcnn.pnet(im, weightsOrNet); - probability = extractdata(gather(probability)); - correction = extractdata(gather(correction)); - end - + + [probability, correction] = networkStrategy.applyPNet(im); faces = probability(:,:,2) > threshold; if sum(faces, 'all') == 0 diff --git a/test/+tests/DetectorTest.m b/test/+tests/DetectorTest.m index d75aa25..3f11ab7 100644 --- a/test/+tests/DetectorTest.m +++ b/test/+tests/DetectorTest.m @@ -112,12 +112,14 @@ function testNmsThresholds(test) end %% GPU - function testGpuDetect(test) + function testGpuDetect(test, imageTypeConversion, useDagNet) + % filter if no GPU present test.assumeGreaterThan(gpuDeviceCount, 0); - detector = mtcnn.Detector("UseGPU", true); - [bboxes, scores, landmarks] = detector.detect(test.Image); + inputImage = imageTypeConversion(test.Image); + detector = mtcnn.Detector("UseGPU", true, "UseDagNet", useDagNet); + [bboxes, scores, landmarks] = detector.detect(inputImage); test.verifyEqual(size(bboxes), [6, 4]); test.verifyEqual(size(scores), [6, 1]); diff --git a/test/+tests/ProposeRegionsTest.m b/test/+tests/ProposeRegionsTest.m index f98b7ed..853b510 100644 --- a/test/+tests/ProposeRegionsTest.m +++ b/test/+tests/ProposeRegionsTest.m @@ -8,17 +8,18 @@ end properties (TestParameter) - getNet = struct("weights", @() load(fullfile(mtcnnRoot, "weights", "pnet.mat")), ... - "net", @() importdata(fullfile(mtcnnRoot, "weights", "dagPNet.mat"))); + getNet = struct("dl", @() mtcnn.util.DlNetworkStrategy(false) , ... + "dag", @() mtcnn.util.DagNetworkStrategy(false)); end methods (Test) function testOutputs(test, getNet) scale = 2; conf = 0.5; - weights = getNet(); + strategy = getNet(); + strategy.load(); - [box, score] = mtcnn.proposeRegions(test.Image, scale, conf, weights); + [box, score] = mtcnn.proposeRegions(test.Image, scale, conf, strategy); test.verifyOutputs(box, score); end @@ -29,9 +30,10 @@ function test1DActivations(test, getNet) cropped = imcrop(test.Image, [300, 42, 65, 38]); scale = 3; conf = 0.5; - weights = getNet(); + strategy = getNet(); + strategy.load(); - [box, score] = mtcnn.proposeRegions(cropped, scale, conf, weights); + [box, score] = mtcnn.proposeRegions(cropped, scale, conf, strategy); test.verifyOutputs(box, score); end