-
Notifications
You must be signed in to change notification settings - Fork 1
/
ConvSparseLearning2.m
136 lines (113 loc) · 5.32 KB
/
ConvSparseLearning2.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
%%% If a dictionary has already been learned and saved under this name,
%%% load its contents into the workspace and stop the script here.
if exist([dictionaryPath filesep str '.mat'], 'file')
    % Load exactly the file that was just found. Without the explicit '.mat',
    % load() treats any text after a '.' in str as a file extension and
    % would try to read a different file as ASCII data.
    load([dictionaryPath filesep str '.mat']);
    return;
end
%%% Settings for this layer, taken from the "...2" variables provided by the
%%% calling workspace.
vneighbors  = vneighbors2;
poolstride  = poolstride2;
noutputmaps = numfeatures2;
randConnect = wiredConnect;

%%% Epoch count and error-averaging window.
dgap     = 10;   % number of most recent errors averaged into dmerror
maxepoch = 10;   % number of training epochs

%%% Hyper-parameters passed on to TiedRecstConvNets (through params). Any
%%% tiedparams already in the workspace is discarded first so that no stale
%%% fields carry over into the new struct.
clear tiedparams;
tiedparams = struct( ...
    'lambda',    0,    ... % sparsity regularization
    'alpha',     0,    ... % L1 decay
    'beta',      0,    ... % L2 decay
    'epsilonw',  1e-3, ... % for unnormalized MNIST 1e-3
    'epsilonb',  1e-4, ... % for unnormalized MNIST 1e-3
    'epsilono',  1e-4, ... % for unnormalized MNIST 1e-2
    'momentum',  0.5,  ... % starting momentum
    'momentumf', 0.99);    % momentum used once the warm-up is over
params = tiedparams;

%%% Initial kernels and loop constants.
kernels  = InitialKernels(kernelsize2, noutputmaps, randConnect);
tiedflag = 1;
count    = 0;
%%% Locate the second 'conv' layer in net.layers and read its padding flag.
plf = 0;  % number of 'conv' layers seen so far
for pl = 1:length(net.layers)
    if strcmpi(net.layers{pl}.type, 'conv')
        plf = plf + 1;
        if plf == 2, break; end
    end
end
% Without this check, a net with fewer than two conv layers would silently
% take the padding flag of whichever layer the loop stopped on (or fail with
% an obscure indexing error when net.layers is empty). assert is used rather
% than error() because this script later assigns a workspace variable named
% `error`, which would shadow the function on a re-run.
assert(plf == 2, 'ConvSparseLearning2:missingConvLayer', ...
    'net.layers must contain at least two ''conv'' layers (found %d).', plf);
padflag = net.layers{pl}.pad;
%%% Split the input maps into ngroup groups of randConnect consecutive maps.
%%% Each group is trained separately and owns `groups` output filters.
ngroup = nummaps/randConnect;
groups = noutputmaps/ngroup; % number of filters in each group
% Both counts are used as array sizes and index bounds below, so they must
% be whole numbers. Fail early with a readable message instead of a later
% size/indexing error.
assert(ngroup == fix(ngroup) && groups == fix(groups), ...
    'ConvSparseLearning2:badGrouping', ...
    ['nummaps (%g) must be divisible by randConnect (%g), and noutputmaps ' ...
     '(%g) by the resulting number of groups.'], nummaps, randConnect, noutputmaps);
% cnt(:, omap) lists the input maps wired to output map omap. In this script
% it is only referenced by the commented-out TiedRecstConvNets2 call.
% Preallocating it stops a stale cnt left in the shared script workspace from
% breaking the column assignment or leaking extra columns into the result.
% (The loop index is not named `i`, to avoid shadowing the imaginary unit.)
cnt = zeros(randConnect, noutputmaps);
for omap = 1:noutputmaps
    grp = ceil(omap/groups);
    cnt(:, omap) = (grp - 1) * randConnect + 1 : grp * randConnect;
end
% params.winc is sized like one group's slice of kernels (randConnect*groups
% columns). Then each group gets its own copy of the parameter struct.
params.winc = zeros(size(kernels,1), randConnect * groups);
params = repmat(params, ngroup, 1);
hbias = zeros(noutputmaps,1);
obias = zeros(ngroup,1);
derror = cell(ngroup,1);   % per-group history of per-sample errors
dmerror = cell(ngroup,1);  % per-group history of windowed mean errors
% Each group trains on at most 5000 samples per epoch.
nloop = min(numTrains, 5000);
%%% Train for maxepoch epochs. Within an epoch the groups are trained one after
%%% another. Each group sees the first nloop samples of a fresh random
%%% permutation of the training set.
for epoch = 1:maxepoch
    fprintf('epoch: %d/%d\n', epoch, maxepoch);
    index = randperm(numTrains);

    %%% Stage this epoch's data.
    %%% Caltech maps are read from disk (one file per sample) and stacked into
    %%% utrain. CIFAR10/MNIST maps are read directly from the in-memory `map`
    %%% array in the inner loop, so no copy of them is made here.
    if strcmpi(datasetName, 'Caltech101') || strcmpi(datasetName, 'Caltech256')
        % padarray(..., [1,1], 'post') below adds one row and one column only
        % when padflag is set, so the buffer is sized by the same condition
        % (assumes mapsize is the unpadded map size -- TODO confirm).
        padextra = double(any(padflag));
        utrain = zeros(mapsize(1)+padextra, mapsize(2)+padextra, mapsize(3), nloop);
        for ut = 1:nloop
            % Read only the `map` variable, so nothing else stored in the file
            % can overwrite variables of this script.
            loaded = load(database.path{index(ut)}, 'map');
            map = loaded.map;
            if padextra
                map = padarray(map, [1,1], 'post');
            end
            utrain(:,:,:, ut) = map;
        end
    end

    %%% Slice the shared kernels/hidden biases into per-group pieces. Fresh
    %%% cell arrays are created each epoch so that a gkernels/ghbias left in
    %%% the (shared script) workspace cannot add stale cells to the cat()
    %%% calls after training.
    gkernels = cell(1, ngroup);
    ghbias = cell(1, ngroup);
    for gup = 1:ngroup
        gkernels{gup} = kernels(:, (gup - 1) * randConnect * groups + 1 : gup * randConnect * groups);
        ghbias{gup} = hbias((gup - 1) * groups + 1 : gup * groups,1);
    end

    for gup = 1:ngroup
        valindin = (gup - 1) * randConnect + 1 : gup * randConnect; % this group's input maps
        tcount = (epoch - 1) * nloop;  % samples this group saw in earlier epochs
        for loop = 1:nloop
            tcount = tcount + 1;
            if strcmpi(datasetName, 'CIFAR10') || strcmpi(datasetName, 'MNIST')
                inputdata = squeeze(map(:,:,valindin,index(loop)));
            else
                inputdata = squeeze(utrain(:,:,valindin,loop));
            end
            % The error output is stored as `recerror`. Naming it `error` would
            % leave a variable that shadows MATLAB's error() function in the
            % caller's workspace.
            [gkernels{gup}, ghbias{gup}, obias(gup), params(gup), ri, recerror] = TiedRecstConvNets(...
                inputdata, acttype, gkernels{gup},...
                ghbias{gup}, obias(gup), params(gup), poolstride, tiedflag, vneighbors);
            % [kernels, hbias, obias, params, ri, error] = TiedRecstConvNets2(inputdata, acttype, kernels,...
            %     hbias, obias, params, cnt, poolstride, tiedflag, vneighbors);

            % Momentum warm-up: over this group's first 200 samples the
            % momentum is blended toward momentumf, and after that it is
            % momentumf. NOTE(review): the blend reuses the already-updated
            % momentum, so this is not a linear ramp from the starting value;
            % confirm this is intended.
            if tcount < 200
                params(gup).momentum = tcount/200 * params(gup).momentumf + (1 - tcount/200) * params(gup).momentum;
            else
                params(gup).momentum = params(gup).momentumf;
            end

            % derror{gup} gains one entry per sample (assuming a scalar error),
            % so it holds tcount entries. tcount > dgap therefore keeps the
            % averaging window in range for any dgap.
            derror{gup} = [derror{gup}, recerror];
            if tcount > dgap
                dmerror{gup} = [dmerror{gup}, sum(derror{gup}(end-dgap+1:end))/dgap];
                % if ~mod(tcount,100), figure(1); plot(dmerror{gup}); drawnow; end
            end
            % if ~mod(tcount, 100)
            %     figure(2); display_network(gkernels{gup}); %subplot(122), display_network(dkernels); title 'Kernels'
            %     figure(3); display_network(reshape(inputdata,size(inputdata,1)^2,[])); title 'Original image'%imagesc(img); title 'Original image'
            %     figure(4); display_network(reshape(ri,size(ri,1)^2,[])); title 'Reconstructed image'%imagesc(img+E); title 'Reconstructed image'
            %     figure(5); display_network(reshape(inputdata - ri,size(inputdata,1)^2,[])); title 'Residue image'%imagesc(-E); title 'Residue image'
            % end
        end
    end

    %%% Reassemble the full kernel matrix and hidden-bias vector from the groups.
    kernels = cat(2, gkernels{:});
    hbias = cat(1, ghbias{:});

    %%% Decay the step sizes once per epoch (author's note: updating them in
    %%% every epoch seemed to work better). The factor 0.95^epoch multiplies the
    %%% already-decayed value, so after epoch e each rate is 0.95^(e*(e+1)/2)
    %%% times its initial value. NOTE(review): confirm this compounding is
    %%% intended rather than a constant 0.95 per epoch.
    for gup = 1:ngroup
        params(gup).epsilonw = 0.95^epoch * params(gup).epsilonw;
        params(gup).epsilonb = 0.95^epoch * params(gup).epsilonb;
        params(gup).epsilono = 0.95^epoch * params(gup).epsilono;
    end
end
% Save the learned kernels and biases together with the settings used.
% tiedparams holds the initial hyper-parameters; the per-group params
% (decayed step sizes, final momentum) are not saved.
% The '.mat' extension is spelled out so that the file name always matches
% the exist() check at the top of this script, even when str contains a '.'
% (save() appends '.mat' only when the given name has no extension).
save([dictionaryPath filesep str '.mat'], 'kernels','hbias','obias','poolstride','vneighbors','tiedparams','dmerror', '-v7.3');