/
surf_ann.m
80 lines (65 loc) · 1.98 KB
/
surf_ann.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
clear
% Build the training matrices for the pattern-recognition network.
% LoadTrainSet is a project helper. From its use below it returns a struct
% whose .class(k).image(j).features is a matrix with one sample per row
% (TODO confirm against LoadTrainSet).
trainingSet = LoadTrainSet();
%load('data/trainingSet_40_small.mat');

% Only the first 10 classes (digits 0-9) are used. Set this to
% size(trainingSet.class,2) to train on every class in the set.
nbClasses = 10;

% Gather per-image blocks in cell arrays and concatenate once afterwards.
% Growing a numeric matrix with [A; B] inside the loop copies the whole
% matrix on every iteration (quadratic in the number of samples).
featureBlocks = {};
targetBlocks = {};
for i = 1:nbClasses
    images = trainingSet.class(i).image;
    for j = 1:size(images, 2)
        features = images(j).features;
        % One-hot target: one row per feature row, with a 1 in column i
        % (the index of the correct class).
        target = zeros(size(features, 1), nbClasses);
        target(:, i) = 1;
        featureBlocks{end+1, 1} = features; %#ok<AGROW>
        targetBlocks{end+1, 1} = target; %#ok<AGROW>
    end
end

% NOTE(review): 'input' shadows the built-in input() function. The name is
% kept because it is part of the workspace this script leaves behind.
input = vertcat(featureBlocks{:})';   % features are lines, samples are columns
output = vertcat(targetBlocks{:})';   % one column per sample, nbClasses rows
x = input;
t = output;
% Retrain a pattern-recognition network nbRuns times. patternnet/train pick
% a new initialisation and data split each run. Track the worst and best
% accuracy seen, and keep the best model.
nbRuns = 100;
hiddenLayerSize = 10;
plotThreshold = 0.9;   % plot a confusion matrix for runs above this accuracy
% Named minCorrect/maxCorrect (not min/max) so the built-in min() and max()
% functions are not shadowed in the workspace.
minCorrect = 1;
maxCorrect = 0;
bestModel = [];        % model (net + training record) with the highest accuracy
for i = 1:nbRuns
    % Create a Pattern Recognition Network
    net = patternnet(hiddenLayerSize);
    % Setup Division of Data for Training, Validation, Testing
    net.divideParam.trainRatio = 70/100;
    net.divideParam.valRatio = 15/100;
    net.divideParam.testRatio = 15/100;
    % Train the Network
    [net, tr] = train(net, x, t);
    model = struct('net', net, 'tr', tr);
    disp('--- Finished the training of the model for the algorithm ---');
    % Test the Network.
    % NOTE(review): this scores ALL samples (training + validation + test),
    % not only tr.testInd, so percentCorrect is an optimistic estimate.
    y = net(x);
    e = gsubtract(t, y);
    tind = vec2ind(t);
    yind = vec2ind(y);
    percentErrors = sum(tind ~= yind) / numel(tind);
    percentCorrect = 1 - percentErrors;
    performance = perform(net, t, y);
    % Keep the best-so-far model (ties go to the later run).
    if isempty(bestModel) || percentCorrect >= maxCorrect
        bestModel = model;
    end
    if percentCorrect > maxCorrect
        maxCorrect = percentCorrect;
    end
    if percentCorrect < minCorrect
        minCorrect = percentCorrect;
    end
    if percentCorrect > plotThreshold
        figure, plotconfusion(t, y)
    end
end
fprintf('min = %g\n', minCorrect);
fprintf('max = %g\n', maxCorrect);
% View the Network (net, tr, y and e refer to the LAST run;
% use bestModel.net / bestModel.tr for the best one)
%view(net)
% Plots
% Uncomment these lines to enable various plots.
%figure, plotperform(tr)
%figure, plottrainstate(tr)
%disp('--- Display the confusion matrix ---');
%figure, plotconfusion(t,y)
%figure, plotroc(t,y)
%figure, ploterrhist(e)