-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_aux_sketch.m
110 lines (101 loc) · 3.08 KB
/
main_aux_sketch.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
%%
% Choose which saved experiment to plot. The tag names a dataset; the
% results file read at the end of this section is 'sketch-rank_<tag>' in
% the current folder.
%
% Machine-specific folders for (re)running the experiment. Uncomment the
% pair for your machine. NOTE: a workspace variable named `path` shadows
% MATLAB's built-in PATH function for as long as it exists.
% path = '/h2/yijun/Documents/MATLAB/RandNLA/CUR/';
% path_target = '/h2/yijun/Documents/MATLAB/RandNLA/dataset/';
% path_target = '/Users/ydong/Documents/MATLAB/OdenUT/RandNLA/dataset/';
% path = '/Users/ydong/Documents/MATLAB/OdenUT/RandNLA/CUR/';
%
% Other available tags:
% tag = 'large';
% tag = 'snn-1e3-1e3_a2b1_k100_r1e3_s1e-3';
% tag = 'mnist-train';
tag = 'yaleface-64x64';
% target = load(fullfile(path_target, sprintf('target_%s',tag)));
%%
% Driver call, presumably the one that produces the results file loaded
% below (confirm in test_sketch_rank). One ks range per dataset:
% n = 1000;
% ks = 30:30:339; % snn-1e3-1e3_a2b1_k100_r1e3_s1e-3
% ks = 10:15:164; % yaleface-64x64
% ks = 50:50:713; % mnist-train
% ks = 40:40:400; % large
% embeds = {'gauss','srft','sparse3'};
% repeat = 10;
% test_sketch_rank(tag, ks, embeds, path, path_target, repeat);
%%
% Load the saved results. Later sections read the fields embeds, sigma,
% ks, err2s and errfs from this struct.
out = load(fullfile(pwd, ['sketch-rank_', tag]));
%%
% Plot styling shared by the figures: the eidx-th embedding is drawn with
% markers{eidx}, and its legend text comes from legmap.
markers = {'o', 's', 'd', '^', 'v', '>', '<', ...
           'p', 'h', '+', '*', 'x'};
legmap = struct();
legmap.gauss   = 'Gauss';
legmap.srft    = 'SRFT';
legmap.sparse3 = 'Sparse sign $\zeta=3$';
embeds = out.embeds;   % embedding names; each must be a field of legmap
% LaTeX legend label for every embedding, as a 1-by-N cell row.
labels = reshape(cellfun(@(name) legmap.(name), embeds, 'UniformOutput', false), 1, []);
%% time - error
% Disabled: error versus wall-clock time on log-log axes.
% xdata = out.times;
% ydata = out.errfs;
%
% figure()
% e = embeds{1};
% loglog(xdata.(e), ydata.(e), strcat(markers{1},'-'), 'LineWidth', 1.5)
% hold on
% for eidx = 2:length(embeds)
% e = embeds{eidx};
% plot(xdata.(e), ydata.(e), strcat(markers{eidx},'-'), 'LineWidth', 1.5)
% end
% hold off
% legend(labels{:}, 'interpreter', 'latex')
% ylabel('$||A - Q Q^T A||_2$', 'interpreter', 'latex')
% xlabel('time (sec)', 'interpreter', 'latex')
% set(gca,'fontsize',16)
%% err-rank
% Approximation error versus rank for every embedding, drawn against a
% reference curve built from out.sigma. Figure 1 shows the spectral norm
% and figure 2 the Frobenius norm.
% Reads from the workspace: out (fields sigma, ks, err2s, errfs), tag,
% embeds, labels, markers.
sig2 = out.sigma;   % presumably the singular values of A, sorted descending -- confirm
% sigf(j) = sqrt(sum_{i >= j} sig2(i)^2), i.e. the Frobenius tail of the spectrum.
sigf = sqrt(cumsum(sig2.^2, 'reverse'));
% Shift between an x-axis value ks and the index of its reference value
% (presumably the oversampling used in the experiment -- confirm).
if strcmpi(tag, 'yaleface-64x64')
    over = 5;
else
    over = 10;
end
% NOTE(review): at x = ks the reference is sigma_{ks-over} and the tail sum
% starting at i = ks-over, so the legend's k equals ks-over-1. If k is meant
% to be the target rank ks-over, the index should be ks-over+1 -- confirm.
fs = 20;
xdata = out.ks;
% One entry per figure:
%   field -- error field of out
%   ref   -- reference curve
%   leg   -- reference legend entry ('' = no legend; set it to
%            '$\sigma_{k+1}$' to show a legend on the spectral figure)
%   ylab  -- y-axis label
panels = struct( ...
    'field', {'err2s', 'errfs'}, ...
    'ref',   {sig2, sigf}, ...
    'leg',   {'', '$\sqrt{\sum_{i > k} \sigma_i^2}$'}, ...
    'ylab',  {'$||A - Q Q^T A||_2$', '$||A - Q Q^T A||_F$'});
for p = 1:numel(panels)
    figure()
    ydata = out.(panels(p).field);
    % semilogy sets the log y-scale; hold on keeps it for the plot() calls.
    semilogy(xdata, panels(p).ref(xdata - over), '.-', 'LineWidth', 1.5)
    hold on
    for eidx = 1:numel(embeds)
        e = embeds{eidx};
        % Cycle through the marker list so any number of embeddings can be plotted.
        mk = markers{mod(eidx - 1, numel(markers)) + 1};
        plot(xdata, ydata.(e), [mk '-'], 'LineWidth', 1.5)
    end
    hold off
    if ~isempty(panels(p).leg)
        legend(panels(p).leg, labels{:}, 'interpreter', 'latex')
    end
    ylabel(panels(p).ylab, 'interpreter', 'latex')
    xlabel('rank', 'interpreter', 'latex')
    xlim([xdata(1), xdata(end)])
    set(gca, 'fontsize', fs)
end
% Disabled: wall-clock time versus rank.
% figure()
% xdata = out.ks;
% ydata = out.times;
% e = embeds{1};
% semilogy(xdata, ydata.(e), strcat(markers{1},'-'), 'LineWidth', 1.5)
% hold on
% for eidx = 2:length(embeds)
% e = embeds{eidx};
% plot(xdata, ydata.(e), strcat(markers{eidx},'-'), 'LineWidth', 1.5)
% end
% hold off
% legend(labels{:}, 'interpreter', 'latex')
% ylabel('time (sec)', 'interpreter', 'latex')
% xlabel('rank', 'interpreter', 'latex')
% xlim([xdata(1), xdata(end)])
% set(gca,'fontsize',fs)
%%