Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions code/mtcnn/+mtcnn/+util/DagNetworkStrategy.m
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,16 @@
Pnet
Rnet
Onet
ExecutionEnvironment
end

methods
function obj = DagNetworkStrategy()
function obj = DagNetworkStrategy(useGpu)
if useGpu
obj.ExecutionEnvironment = "gpu";
else
obj.ExecutionEnvironment = "cpu";
end
end

function load(obj)
Expand All @@ -18,19 +24,24 @@ function load(obj)
obj.Onet = importdata(fullfile(mtcnnRoot(), "weights", "dagONet.mat"));
end

function pnet = getPNet(obj)
pnet = obj.Pnet;
function [probability, correction] = applyPNet(obj, im)
% need to use activations as we don't know what size it will be
result = obj.Pnet.activations(im, "concat", ...
"ExecutionEnvironment", obj.ExecutionEnvironment);

probability = result(:,:,1:2,:);
correction = result(:,:,3:end,:);
end

function [probs, correction] = applyRNet(obj, im)
output = obj.Rnet.predict(im);
output = obj.Rnet.predict(im, "ExecutionEnvironment", obj.ExecutionEnvironment);

probs = output(:,1:2);
correction = output(:,3:end);
end

function [probs, correction, landmarks] = applyONet(obj, im)
output = obj.Onet.predict(im);
output = obj.Onet.predict(im, "ExecutionEnvironment", obj.ExecutionEnvironment);

probs = output(:,1:2);
correction = output(:,3:6);
Expand Down
9 changes: 7 additions & 2 deletions code/mtcnn/+mtcnn/+util/DlNetworkStrategy.m
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,13 @@ function load(obj)
end
end

function pnet = getPNet(obj)
pnet = obj.PnetWeights;
function [probability, correction] = applyPNet(obj, im)
im = dlarray(im, "SSCB");

[probability, correction] = mtcnn.pnet(im, obj.PnetWeights);

probability = extractdata(gather(probability));
correction = extractdata(gather(correction));
end

function [probs, correction] = applyRNet(obj, im)
Expand Down
6 changes: 3 additions & 3 deletions code/mtcnn/+mtcnn/Detector.m
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
end

if obj.UseDagNet
obj.Networks = mtcnn.util.DagNetworkStrategy();
obj.Networks = mtcnn.util.DagNetworkStrategy(obj.UseGPU);
else
obj.Networks = mtcnn.util.DlNetworkStrategy(obj.UseGPU);
end
Expand Down Expand Up @@ -89,7 +89,7 @@
[thisBox, thisScore] = ...
mtcnn.proposeRegions(im, scale, ...
obj.ConfidenceThresholds(1), ...
obj.Networks.getPNet());
obj.Networks);
bboxes = cat(1, bboxes, thisBox);
scores = cat(1, scores, thisScore);
end
Expand Down Expand Up @@ -183,7 +183,7 @@
"Input image is of unsupported type '%s'", class(im));
end

if obj.UseGPU()
if obj.UseGPU && ~obj.UseDagNet
outIm = gpuArray(outIm);
end

Expand Down
25 changes: 3 additions & 22 deletions code/mtcnn/+mtcnn/proposeRegions.m
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function [bboxes, scores] = proposeRegions(im, scale, threshold, weightsOrNet)
function [bboxes, scores] = proposeRegions(im, scale, threshold, networkStrategy)
% proposeRegions Generate region proposals at a given scale.
%
% Args:
Expand All @@ -13,34 +13,15 @@

% Copyright 2019 The MathWorks, Inc.

useDagNet = isa(weightsOrNet, "DAGNetwork");
if isa(im, "gpuArray")
imClass = classUnderlying(im);
else
imClass = class(im);
end
assert(imClass == "single", "mtcnn:proposeRegions:wrongImageType", ...
"Input image should be a single scale -1 to 1");

% Stride of the proposal network
stride = 2;
% Field of view of the proposal network in pixels
pnetSize = 12;

im = imresize(im, 1/scale);

if useDagNet
% need to use activations as we don't know what size it will be
result = weightsOrNet.activations(im, "concat");
probability = gather(result(:,:,1:2,:));
correction = gather(result(:,:,3:end,:));
else
im = dlarray(im, "SSCB");
[probability, correction] = mtcnn.pnet(im, weightsOrNet);
probability = extractdata(gather(probability));
correction = extractdata(gather(correction));
end


[probability, correction] = networkStrategy.applyPNet(im);

faces = probability(:,:,2) > threshold;
if sum(faces, 'all') == 0
Expand Down
8 changes: 5 additions & 3 deletions test/+tests/DetectorTest.m
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,14 @@ function testNmsThresholds(test)
end

%% GPU
function testGpuDetect(test)
function testGpuDetect(test, imageTypeConversion, useDagNet)

% filter if no GPU present
test.assumeGreaterThan(gpuDeviceCount, 0);

detector = mtcnn.Detector("UseGPU", true);
[bboxes, scores, landmarks] = detector.detect(test.Image);
inputImage = imageTypeConversion(test.Image);
detector = mtcnn.Detector("UseGPU", true, "UseDagNet", useDagNet);
[bboxes, scores, landmarks] = detector.detect(inputImage);

test.verifyEqual(size(bboxes), [6, 4]);
test.verifyEqual(size(scores), [6, 1]);
Expand Down
14 changes: 8 additions & 6 deletions test/+tests/ProposeRegionsTest.m
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,18 @@
end

properties (TestParameter)
getNet = struct("weights", @() load(fullfile(mtcnnRoot, "weights", "pnet.mat")), ...
"net", @() importdata(fullfile(mtcnnRoot, "weights", "dagPNet.mat")));
getNet = struct("dl", @() mtcnn.util.DlNetworkStrategy(false) , ...
"dag", @() mtcnn.util.DagNetworkStrategy(false));
end

methods (Test)
function testOutputs(test, getNet)
scale = 2;
conf = 0.5;
weights = getNet();
strategy = getNet();
strategy.load();

[box, score] = mtcnn.proposeRegions(test.Image, scale, conf, weights);
[box, score] = mtcnn.proposeRegions(test.Image, scale, conf, strategy);

test.verifyOutputs(box, score);
end
Expand All @@ -29,9 +30,10 @@ function test1DActivations(test, getNet)
cropped = imcrop(test.Image, [300, 42, 65, 38]);
scale = 3;
conf = 0.5;
weights = getNet();
strategy = getNet();
strategy.load();

[box, score] = mtcnn.proposeRegions(cropped, scale, conf, weights);
[box, score] = mtcnn.proposeRegions(cropped, scale, conf, strategy);

test.verifyOutputs(box, score);
end
Expand Down