-
Notifications
You must be signed in to change notification settings - Fork 0
/
glcic_add_conv_transpose_block.m
63 lines (61 loc) · 2.08 KB
/
glcic_add_conv_transpose_block.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
% GLCIC_ADD_CONV_TRANSPOSE_BLOCK  Add a transposed-convolution block
% (ConvTranspose + optional BatchNorm + optional leaky ReLU) to a network.
%
% The total crop is KSIZE - UPSAMPLE per spatial dimension, split between
% the two sides. With the vl_nnconvt output size
%   Hout = UPSAMPLE*(H-1) + KSIZE - cropTop - cropBottom
% this is intended to make the output exactly UPSAMPLE times the input size.
%
% input:
%   net       : the network; layers are appended with net.addLayer.
%   opts      : options structure; only opts.cudnnWorkspaceLimit is read.
%   lastAdded : a structure with two fields {var, depth}: var is the name
%               of the last added variable, depth is its channel count.
%   name      : name prefix for the layers, variables and parameters added.
%   ksize     : odd spatial size of the square filter; must be >= upsample.
%   upsample  : positive integer upsampling factor.
%   depth     : number of output channels of the block.
%   varargin  : optional name/value pairs (parsed by vl_argparse):
%               'relu'      (true)  append a leaky ReLU layer
%               'bias'      (true)  give the transposed conv a bias param
%               'bn'        (true)  append a BatchNorm layer
%               'leak'      (0.2)   negative slope of the leaky ReLU
%               'bnEpsilon' (1e-5)  epsilon of the BatchNorm layer
% output:
%   net       : the network with the block appended.
%   lastAdded : var names the block's final output variable; depth is
%               set to DEPTH.
function [net, lastAdded] = glcic_add_conv_transpose_block(net, opts, lastAdded, name, ksize, upsample, depth, varargin)
% Helper function to add a ConvTranspose + BatchNorm + leaky ReLU
% sequence to the network.
args.relu = true ;
args.bias = true ;
args.bn = true ;
args.leak = 0.2 ;       % ReLU negative slope (previously hard-coded)
args.bnEpsilon = 1e-5 ; % BatchNorm epsilon (previously hard-coded)
args = vl_argparse(args, varargin) ;

% Parameter names: filters always, bias only when requested.
if args.bias
  pars = {[name '_f'], [name '_b']} ;
else
  pars = {[name '_f']} ;
end

% Validate the geometry before touching the network.
if mod(ksize, 2) ~= 1
  error('ksize should be a odd')
end
if upsample < 1 || mod(upsample, 1) ~= 0
  error('upsample should be a positive integer')
end
if ksize < upsample
  error('ksize should >= upsample')
end

% Square filter and isotropic upsampling, so the same crop applies to both
% spatial dimensions; floor() puts any odd extra pixel on the bottom/right.
crop = ksize - upsample ;
cropTop = floor(crop / 2) ;
cropBottom = crop - cropTop ;

% Filter size is [h w out in]: DEPTH output channels computed from the
% lastAdded.depth input channels.
% addLayer(name, block, inputs, outputs, params, varargin)
net.addLayer([name '_conv_transpose'], ...
  dagnn.ConvTranspose('size', [ksize ksize depth lastAdded.depth], ...
                      'upsample', upsample, ...
                      'crop', [cropTop, cropBottom, cropTop, cropBottom], ...
                      'hasBias', args.bias, ...
                      'opts', {'cudnnworkspacelimit', opts.cudnnWorkspaceLimit}), ...
  lastAdded.var, ...
  [name '_conv_transpose'], ...
  pars) ;
lastAdded.var = [name '_conv_transpose'] ;
lastAdded.depth = depth ;

if args.bn
  net.addLayer([name '_bn'], ...
    dagnn.BatchNorm('numChannels', depth, 'epsilon', args.bnEpsilon), ...
    lastAdded.var, ...
    [name '_bn'], ...
    {[name '_bn_w'], [name '_bn_b'], [name '_bn_m']}) ;
  lastAdded.var = [name '_bn'] ;
end

if args.relu
  net.addLayer([name '_relu'], ...
    dagnn.ReLU('leak', args.leak), ...
    lastAdded.var, ...
    [name '_relu']) ;
  lastAdded.var = [name '_relu'] ;
end
end