-
Notifications
You must be signed in to change notification settings - Fork 1
/
nfold.m
70 lines (61 loc) · 2.33 KB
/
nfold.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
function [auc_res,aupr_res]=nfold(Y,Sd,St,pred_fn,nr_fold,seed,cv_setting,use_WKNKN,K,eta,use_W_matrix)
%nfold is a helper function of crossValidation.m. Depending on the
%specified CV setting (or scenario) and supplied "seed", it divides the
%interaction matrix into "nr_fold" folds, performs a cross validation
%experiment and then reports the results (AUPR/AUC) back to
%crossValidation.m.
%
% INPUT:
% Y: matrix to be modified
% Sd: pairwise row similarities matrix
% St: pairwise column similarities matrix
% pred_fn: function/script of algorithm to be used for MDA prediction
% nr_fold: number of folds in cross validation experiment
% seed: seed used for random sampling
% cv_setting: miRNA prediction case or Disease prediction case
%
% OUTPUT:
% auc_res: AUC result
% aupr_res: AUPR result
%
if strcmp(cv_setting,'cv_p')
len = numel(Y);
end
rng('default')
rng(seed);
rand_ind = randperm(len);
% TIME PRINT ----------------------------
fprintf('n-fold experiment start: \t'); timeprint();
% ---------------------------------------
AUCs = zeros(1,nr_fold);
AUPRs = zeros(1,nr_fold);
for i=1:nr_fold
% leave out random miRNA/Disease pairs
if strcmp(cv_setting,'cv_p')
test_ind = rand_ind((floor((i-1)*len/nr_fold)+1:floor(i*len/nr_fold))');
left_out = test_ind;
end
% predict with test set being left out
y2 = Y;
y2(test_ind) = 0; % test set = ZERO
fprintf('****');
y3 = pred_fn(y2,Sd,St,cv_setting,nr_fold,left_out,use_WKNKN,K,eta,use_W_matrix); % predict!
% compute evaluation metrics based on obtained prediction scores
[AUCs(i), AUPRs(i)] = returnEvaluationMetrics(Y(test_ind)',y3(test_ind)');
fprintf('%.3g\t\t\t\tTIME: ',AUPRs(i)); timeprint();
diary off; diary on;
end
% TIME PRINT ----------------------------
fprintf('n-fold experiment end: \t'); timeprint();
% ---------------------------------------
auc_res = mean(AUCs);
aupr_res = mean(AUPRs);
fprintf('\n');
fprintf(' AUC: %g\n', auc_res);
fprintf(' AUPR: %g\n', aupr_res);
disp('==========================');
end
function timeprint()
clk = clock;
fprintf('%g : %g : %g %s\n',clk(4:6),date);
end