% linear_pop_risk.m
clear; close all; clc;
% Put the project folders on the MATLAB search path. Helpers used below
% (get_dataname, generate_data_trn_val, theta0_fun, pop_risk_fun,
% define_color) are presumably defined in one of them -- confirm.
% Two separate calls are kept on purpose: their order fixes path precedence.
addpath('./functions/');
addpath('./shaded_plots/');
%% hyperparameters
% T, N, d and s are forwarded to the data generator through TNds; T is also
% the number of per-task slices iterated over when computing risks, d the
% side length of each per-task matrix, and s scales gamma in BaMAML.
% NOTE(review): Nm is not referenced anywhere else in this script.
[T, Nm, N, d, s] = deal(100, 20, 20, 1, 0.5);
%% generate / load data
% Data files are cached under ./data/, keyed by the (T, N, d, s)
% configuration through get_dataname. Reuse a cached file when present,
% otherwise generate (and presumably save) a fresh one, then load it.
TNds.T = T; TNds.N = N; TNds.d = d; TNds.s = s;
dataname = get_dataname(TNds);
filename = ['./data/', dataname, '.mat'];
if isfile(filename)
    load(filename);
    disp(['load data from ', filename]);
else
    dataname = generate_data_trn_val(TNds);
    % Rebuild the path from the name the generator returned, so we load the
    % file it actually produced rather than a possibly stale path.
    % NOTE(review): assumes the returned name is the saved file's base name.
    filename = ['./data/', dataname, '.mat'];
    load(filename);
    disp(['generate and save data ', filename]);
end
% NOTE(review): load() writes every variable of the .mat into the
% workspace; the file is expected to provide datagen_para with the fields
% read below -- confirm against generate_data_trn_val.
Q_T = datagen_para.Q_T;
var_T = datagen_para.var_T;
theta_gt_T = datagen_para.theta_gt_T;
%% compute risks
% For each task i = 1..T, build d-by-d matrices from Q_T and var_T. Then sweep
% the method hyperparameters and record each method's population risk with
% the project helpers theta0_fun / pop_risk_fun (definitions not shown here).
Id = eye(d);
% Per-task matrix stacks, T x d x d; every slice is overwritten below.
A_er_T = zeros(T, d, d);
A_ma_T = zeros(T, d, d);
A_bi_T = zeros(T, d, d);
A_ba_T = zeros(T, d, d);
% method hyper: alpha_all(k) (MAML) is paired with gamma_all(k) (iMAML/BaMAML)
alpha_all = 0:0.02:2 - 0.02;
gamma_all = 10.^(-5:0.1:4.9);
len_hyper = numel(alpha_all);   % 100 grid points
assert(numel(gamma_all) == len_hyper, ...
    'alpha_all and gamma_all must have the same length');
R_ma = zeros(1, len_hyper);
R_bi = zeros(1, len_hyper);
R_ba = zeros(1, len_hyper);
% A_er_T depends on neither alpha nor gamma, so build it once, outside the sweep.
for i = 1:T
    A_er_T(i, :, :) = (1 / var_T(i)) * Q_T(i, :, :);
end
for hyper_idx = 1:len_hyper
    alpha = alpha_all(hyper_idx);
    gamma = gamma_all(hyper_idx);
    for i = 1:T
        A_er = reshape(A_er_T(i, :, :), d, d);
        % MAML: (I - alpha*A) * A * (I - alpha*A)
        M_ma = Id - alpha * A_er;
        A_ma_T(i, :, :) = M_ma * A_er * M_ma;
        % iMAML: (I + A/gamma)^-1 * A * (I + A/gamma)^-1, computed with \ and /
        % rather than an explicit inv().
        M_bi = Id + (1 / gamma) * A_er;
        A_bi_T(i, :, :) = (M_bi \ A_er) / M_bi;
        % BaMAML: (I + A/(gamma*s))^-1 * A * (I + A/gamma)^-1
        % NOTE(review): the left factor uses 1/(gamma*s) but the right factor
        % uses 1/gamma, unlike the symmetric iMAML form -- confirm intended.
        % (an alternative kept from an earlier version: gamma * Q_T(i, :, :))
        M_ba = Id + (1 / gamma / s) * A_er;
        A_ba_T(i, :, :) = (M_ba \ A_er) / M_bi;
    end
    % MAML
    theta0_ma = theta0_fun(theta_gt_T, A_ma_T);
    R_ma(hyper_idx) = pop_risk_fun(theta0_ma, theta_gt_T, A_ma_T);
    % biMAML
    theta0_bi = theta0_fun(theta_gt_T, A_bi_T);
    R_bi(hyper_idx) = pop_risk_fun(theta0_bi, theta_gt_T, A_bi_T);
    % baMAML
    theta0_ba = theta0_fun(theta_gt_T, A_ba_T);
    R_ba(hyper_idx) = pop_risk_fun(theta0_ba, theta_gt_T, A_ba_T);
end
% ERM has no hyperparameter; its single risk is replicated across the grid
% so it can be drawn as a flat reference line.
theta0_er = theta0_fun(theta_gt_T, A_er_T);
R_er_ = pop_risk_fun(theta0_er, theta_gt_T, A_er_T);
R_er = R_er_ * ones(1, len_hyper);
%% save results
% Persist the hyperparameter grids and all four risk curves (ERM, MAML,
% iMAML, BaMAML). R_bi must be included: the plotting section reloads this
% file and plots R_bi, so it has to be self-contained to run on its own.
if ~exist('./results', 'dir')
    mkdir('./results');
end
save('./results/linear_pop_risk.mat', ...
    'alpha_all', 'gamma_all', 'R_er', 'R_ma', 'R_bi', 'R_ba');
%% plot figure
% population risks vs alpha gamma
% Two overlaid axes in one tile share the y-axis (population risk):
%   ax1 (x on top):    MAML and ERM versus alpha
%   ax2 (x on bottom): iMAML and BaMAML versus log10(gamma)
% Reload the saved curves. R_bi is taken from the file if it was saved there,
% otherwise from the current workspace.
load('./results/linear_pop_risk.mat')
define_color;   % project script; assumed to define color_k/g/b/r -- confirm
color_er = color_k;
color_ma = color_g;
color_ba = color_b;
color_bi = color_r;
fontsize = 12;
linewidth = 2;
set(0, 'defaultTextFontName', 'Times')
set(0, 'defaultAxesFontName', 'Times')
set(0, 'defaultAxesFontSize', fontsize)
f = figure;
f.Position = [10 10 280 300];
t = tiledlayout(1, 1);
% Top axis: alpha sweep (MAML) with the constant ERM reference line.
ax1 = axes(t);
h_ma = line('parent', ax1, 'xdata', alpha_all, 'ydata', R_ma, ...
    'color', color_ma, 'LineStyle', '-', 'LineWidth', linewidth);
h_er = line('parent', ax1, 'xdata', alpha_all, 'ydata', R_er, ...
    'color', color_er, 'LineStyle', '--', 'LineWidth', linewidth);
ax1.XAxisLocation = 'top';
ax1.Color = 'none';
ax1.YColor = 'k';
% Bottom axis: gamma sweep (iMAML, BaMAML) on a log10 scale.
ax2 = axes(t);
h_bi = line('parent', ax2, 'xdata', log10(gamma_all), 'ydata', R_bi, ...
    'color', color_bi, 'LineStyle', '-', 'LineWidth', linewidth);
h_ba = line('parent', ax2, 'xdata', log10(gamma_all), 'ydata', R_ba, ...
    'color', color_ba, 'LineStyle', '-', 'LineWidth', linewidth);
ax2.XAxisLocation = 'bottom';
ax2.YAxisLocation = 'left';
ax2.Color = 'none';
% Color each x-axis like the method it parameterizes.
ax1.XColor = color_ma;
ax2.XColor = color_ba;
ax1.Box = 'on';
ax2.Box = 'off';
linkaxes([ax1, ax2], 'y')
xlabel(ax1, '\alpha', 'FontSize', fontsize)
xlabel(ax2, '$\log_{10} (\gamma)$', 'FontSize', fontsize, 'Interpreter', 'latex')
ylabel(ax2, 'Optimal population risk', 'FontSize', fontsize)
legend([h_er; h_ma; h_bi; h_ba], {'ERM', 'MAML','iMAML', 'BaMAML'}, 'FontSize', fontsize, 'Location', 'northwest', 'color','none');
% exportgraphics does not create missing folders, so create it first.
if ~exist('./figures', 'dir')
    mkdir('./figures');
end
exportgraphics(t, './figures/pop_risks.pdf')