-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #11 from matlab-deep-learning/feature/pre-19b
Support R2019a
- Loading branch information
Showing
15 changed files
with
328 additions
and
54 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
*.mltbx | ||
code/mtcnn/weights/dag*.mat |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
classdef DagNetworkStrategy < handle
    % DagNetworkStrategy   Network strategy backed by networks stored on disk.
    %   load() reads three networks from the dag*.mat files in the "weights"
    %   folder under mtcnnRoot(). applyRNet and applyONet call predict on the
    %   stored networks and split the columns of the result into separate
    %   outputs.

    properties (SetAccess=private)
        % Trained Dag networks (populated by load)
        Pnet
        Rnet
        Onet
    end

    methods
        function obj = DagNetworkStrategy()
            % Construct the strategy. Nothing is read from disk here; call
            % load() to populate the networks.
        end

        function load(obj)
            % load   Read the P, R and O networks from their .mat files.
            weightsFolder = fullfile(mtcnnRoot(), "weights");

            obj.Pnet = importdata(fullfile(weightsFolder, "dagPNet.mat"));
            obj.Rnet = importdata(fullfile(weightsFolder, "dagRNet.mat"));
            obj.Onet = importdata(fullfile(weightsFolder, "dagONet.mat"));
        end

        function pnet = getPNet(obj)
            % getPNet   Return the stored P network.
            pnet = obj.Pnet;
        end

        function [probs, correction] = applyRNet(obj, im)
            % applyRNet   Run the R network on im and split its output.
            %   probs      - columns 1:2 of the prediction
            %   correction - every remaining column (3:end)
            prediction = predict(obj.Rnet, im);

            probs = prediction(:, 1:2);
            correction = prediction(:, 3:end);
        end

        function [probs, correction, landmarks] = applyONet(obj, im)
            % applyONet   Run the O network on im and split its output.
            %   probs      - columns 1:2 of the prediction
            %   correction - columns 3:6
            %   landmarks  - every remaining column (7:end)
            prediction = predict(obj.Onet, im);

            probs = prediction(:, 1:2);
            correction = prediction(:, 3:6);
            landmarks = prediction(:, 7:end);
        end

    end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
classdef DlNetworkStrategy < handle
    % DlNetworkStrategy   Network strategy that evaluates the function-based
    % networks (mtcnn.rnet / mtcnn.onet) with weights loaded from .mat files.

    properties (SetAccess=private)
        % Flag supplied at construction; when true, load() passes the
        % weights through gpuArray
        UseGPU
        % Weights for the networks
        PnetWeights
        RnetWeights
        OnetWeights
    end

    methods
        function obj = DlNetworkStrategy(useGpu)
            % obj = DlNetworkStrategy(useGpu) stores the GPU flag. Weights
            % are not read until load() is called.
            obj.UseGPU = useGpu;
        end

        function load(obj)
            % load   Read the weights for all three networks from file.
            %   The calls to load below receive a file path rather than
            %   this object, so they are expected to resolve to MATLAB's
            %   built-in load and not to this method.
            weightsFolder = fullfile(mtcnnRoot(), "weights");

            obj.PnetWeights = load(fullfile(weightsFolder, "pnet.mat"));
            obj.RnetWeights = load(fullfile(weightsFolder, "rnet.mat"));
            obj.OnetWeights = load(fullfile(weightsFolder, "onet.mat"));

            if obj.UseGPU
                % Apply gpuArray to every entry of each weights structure.
                obj.PnetWeights = dlupdate(@gpuArray, obj.PnetWeights);
                obj.RnetWeights = dlupdate(@gpuArray, obj.RnetWeights);
                obj.OnetWeights = dlupdate(@gpuArray, obj.OnetWeights);
            end
        end

        function pnet = getPNet(obj)
            % getPNet   Return the stored P network weights.
            pnet = obj.PnetWeights;
        end

        function [probs, correction] = applyRNet(obj, im)
            % applyRNet   Evaluate mtcnn.rnet on im with the stored weights.
            %   im is labelled "SSCB" before the call; both outputs are
            %   unwrapped from dlarray and transposed before being returned.
            labelled = dlarray(im, "SSCB");

            [rawProbs, rawCorrection] = mtcnn.rnet(labelled, obj.RnetWeights);

            probs = extractdata(rawProbs);
            probs = probs';
            correction = extractdata(rawCorrection);
            correction = correction';
        end

        function [probs, correction, landmarks] = applyONet(obj, im)
            % applyONet   Evaluate mtcnn.onet on im with the stored weights.
            %   im is labelled "SSCB" before the call; all three outputs are
            %   unwrapped from dlarray and transposed before being returned.
            labelled = dlarray(im, "SSCB");

            [rawProbs, rawCorrection, rawLandmarks] = mtcnn.onet(labelled, obj.OnetWeights);

            probs = extractdata(rawProbs);
            probs = probs';
            correction = extractdata(rawCorrection);
            correction = correction';
            landmarks = extractdata(rawLandmarks);
            landmarks = landmarks';
        end

    end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
classdef (Abstract) NetworkStrategy < handle
    % NetworkStrategy   Interface for objects that load and apply the
    % P, R and O networks.
    %
    %   Concrete subclasses must implement every method below. The methods
    %   block carries the Abstract attribute so that MATLAB treats the
    %   signature-only declarations as abstract methods rather than as
    %   declarations of methods implemented in separate files.

    methods (Abstract)
        % load   Load whatever network data the strategy needs.
        load(obj)

        % getPNet   Return the P network (or its weights).
        pnet = getPNet(obj)

        % applyRNet   Apply the R network to im.
        [probs, correction] = applyRNet(obj, im)

        % applyONet   Apply the O network to im.
        [probs, correction, landmarks] = applyONet(obj, im)
    end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
function net = convertToDagNet(stage)
% convertToDagNet   Build an assembled network for one stage from its
% function-based implementation.
%
%   net = convertToDagNet(stage) loads <stage>net.mat from the "weights"
%   folder under mtcnnRoot, traces the matching function (mtcnn.pnet,
%   mtcnn.rnet or mtcnn.onet) into a layer graph with functionToLayerGraph,
%   swaps each "plus_<i>" layer for a mtcnn.util.preluLayer, adds an image
%   input layer, a concatenation layer and a regression output layer, and
%   assembles the result.
%
%   Input:
%       stage - "p", "r" or "o"
%
%   Output:
%       net   - network returned by assembleNetwork
%
%   Errors:
%       mtcnn:convertToDagNet:unknownStage   - stage is not "p", "r" or "o"
%       mtcnn:convertToDagNet:outputMismatch - the assembled network and
%           the original function disagree on an all-zeros input

    % Silence the placeholder warning for the duration of this call only;
    % the onCleanup object restores the previous warning state on exit.
    warnId = "deep:functionToLayerGraph:Placeholder";
    warnState = warning('off', warnId);
    restoreWarn = onCleanup(@() warning(warnState));

    % Per-stage settings: input size, number of PReLU blocks, the
    % connections to make for the final PReLU, and the layers that feed
    % the concatenation layer.
    switch stage
        case "p"
            inputSize = 12;
            nBlocks = 3;
            finalConnections = [sprintf("conv_%d", nBlocks), sprintf("prelu_%d", nBlocks)];
            catConnections = ["sm_1", "conv_5"];
        case "r"
            inputSize = 24;
            nBlocks = 4;
            finalConnections = [sprintf("prelu_%d", nBlocks-1), "fc_1";
                "fc_1", sprintf("prelu_%d", nBlocks)];
            catConnections = ["sm_1", "fc_3"];
        case "o"
            inputSize = 48;
            nBlocks = 5;
            finalConnections = ["fc_1", sprintf("prelu_%d", nBlocks)];
            catConnections = ["sm_1", "fc_3", "fc_4"];
        otherwise
            error("mtcnn:convertToDagNet:unknownStage", ...
                "Stage '%s' is not recognised", stage)
    end

    matFilename = strcat(stage, "net.mat");
    weightsFile = load(fullfile(mtcnnRoot, "weights", matFilename));
    input = dlarray(zeros(inputSize, inputSize, 3, "single"), "SSCB");

    % Evaluate the function form once on the all-zeros input; the
    % concatenated outputs are the reference for the final check.
    switch stage
        case "p"
            netFunc = @(x) mtcnn.pnet(x, weightsFile);
            [a, b] = netFunc(input);
            output = cat(3, a, b);
        case "r"
            netFunc = @(x) mtcnn.rnet(x, weightsFile);
            [a, b] = netFunc(input);
            output = cat(1, a, b);
        case "o"
            netFunc = @(x) mtcnn.onet(x, weightsFile);
            [a, b, c] = netFunc(input);
            output = cat(1, a, b, c);
    end

    lgraph = functionToLayerGraph(netFunc, input);
    placeholders = findPlaceholderLayers(lgraph);
    lgraph = removeLayers(lgraph, {placeholders.Name});

    for iPrelu = 1:nBlocks
        name = sprintf("prelu_%d", iPrelu);
        weightName = sprintf("features_prelu%d_weight", iPrelu);
        if iPrelu ~= nBlocks
            weights = weightsFile.(weightName);
        else
            % The final PReLU's weights are reshaped to 1-by-1-by-N.
            weights = reshape(weightsFile.(weightName), 1, 1, []);
        end
        prelu = mtcnn.util.preluLayer(weights, name);
        lgraph = replaceLayer(lgraph, sprintf("plus_%d", iPrelu), prelu, "ReconnectBy", "order");

        if iPrelu ~= nBlocks
            lgraph = connectLayers(lgraph, sprintf("conv_%d", iPrelu), sprintf("prelu_%d", iPrelu));
        else
            % need to make different connections at the end of the
            % repeating blocks
            for iConnection = 1:size(finalConnections, 1)
                lgraph = connectLayers(lgraph, ...
                    finalConnections(iConnection, 1), ...
                    finalConnections(iConnection, 2));
            end
        end
    end

    lgraph = addLayers(lgraph, imageInputLayer([inputSize, inputSize, 3], ...
        "Name", "input", ...
        "Normalization", "none"));
    lgraph = connectLayers(lgraph, "input", "conv_1");

    lgraph = addLayers(lgraph, concatenationLayer(3, numel(catConnections), "Name", "concat"));
    for iConnection = 1:numel(catConnections)
        lgraph = connectLayers(lgraph, ...
            catConnections(iConnection), ...
            sprintf("concat/in%d", iConnection));
    end
    lgraph = addLayers(lgraph, regressionLayer("Name", "output"));
    lgraph = connectLayers(lgraph, "concat", "output");

    net = assembleNetwork(lgraph);
    result = net.predict(zeros(inputSize, inputSize, 3, "single"));

    % Compare the two outputs element by element. Both are flattened so
    % the check does not depend on their orientation, and the largest
    % absolute deviation is used so that errors of opposite sign cannot
    % cancel and a large negative deviation cannot pass.
    expected = extractdata(output);
    sameCount = numel(expected) == numel(result);

    assert(sameCount && max(abs(expected(:) - result(:))) < 1e-6, ...
        "mtcnn:convertToDagNet:outputMismatch", ...
        "Outputs of function and dag net do not match")
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
classdef preluLayer < nnet.layer.Layer
    % preluLayer   Custom PReLU layer with a learnable scaling coefficient.
    %   Taken from "Define Custom Deep Learning Layer with Learnable
    %   Parameters".

    % Copyright 2020 The MathWorks, Inc.

    properties (Learnable)
        % Scaling coefficient applied to the negative part of the input
        Alpha
    end

    methods
        function layer = preluLayer(weights, name)
            % layer = preluLayer(weights, name) creates a PReLU layer with
            % the given name whose scaling coefficient Alpha is set to
            % weights.

            layer.Alpha = weights;
            layer.Name = name;
        end

        function Z = predict(layer, X)
            % Z = predict(layer, X) forwards the input data X through the
            % layer and outputs the result Z: positive values pass through
            % unchanged and negative values are scaled by Alpha.
            positivePart = max(X, 0);
            negativePart = min(0, X);
            Z = positivePart + layer.Alpha .* negativePart;
        end

        function [dLdX, dLdAlpha] = backward(layer, X, ~, dLdZ, ~)
            % [dLdX, dLdAlpha] = backward(layer, X, ~, dLdZ, ~) computes
            % the derivatives of the loss with respect to the input X and
            % to Alpha from the upstream derivative dLdZ.

            % Input derivative: Alpha-scaled everywhere, then overwritten
            % with the unscaled value where the input was positive.
            isPositive = X > 0;
            dLdX = layer.Alpha .* dLdZ;
            dLdX(isPositive) = dLdZ(isPositive);

            % Alpha derivative: summed over the first two dimensions and
            % then over all observations in the mini-batch (dimension 4).
            perElement = min(0, X) .* dLdZ;
            dLdAlpha = sum(sum(sum(perElement, 1), 2), 4);
        end
    end
end
Oops, something went wrong.