-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathkCNN.r
106 lines (98 loc) · 3.39 KB
/
kCNN.r
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
## The k conditional nearest neighbor approach (kcnn)
##
## Classifies each row of `test` from class-wise nearest-neighbor distances
## computed within `train`.  For a test point and a class, let d_j be the
## Euclidean distance (plus `eps`) to the j-th nearest training point of that
## class.  The class weight is k / d_j^(p / r), where p is the number of
## feature columns, and the weights are normalised over the classes to give
## class probabilities.
##
## Arguments:
##   train    matrix or data frame of feature variables X (one row per
##            training observation); coerced with as.matrix().
##   test     matrix or data frame with the same columns as `train`.
##   cl       true classes of the training rows, one per row of `train`.
##   k        number of nearest neighbors considered within each class.
##   r        divisor of the exponent p / r; NULL (default) means r = p,
##            i.e. an exponent of 1.
##   eps      small constant added to every distance so it is never zero.
##   ensemble TRUE: average the probabilities obtained for j = 1, ..., k.
##            FALSE: use the k-th neighbor only.  The strings "TRUE" and
##            "FALSE" are accepted as well.
##
## Value: a list with the predicted classes (character vector) and the
## estimated class probabilities (one row per test row, one column per class
## in the order of names(table(cl))).  The elements are named
## predict_ekcnn / probability_ekcnn when `ensemble` is TRUE and
## predict_kcnn / probability_kcnn otherwise.
##
## Notes:
##   * A class with fewer than j training rows has no j-th neighbor; it gets
##     weight 0 for that j (it is never treated as a certain match).
##   * Ties between classes are broken at random with sample(); call
##     set.seed() beforehand for reproducible predictions.
kcnn <- function(train, test, cl, k = 1, r = NULL, eps = 0.0000001, ensemble = TRUE)
{
  X_tr <- as.matrix(train)
  X_ts <- as.matrix(test)
  # as.logical() keeps both the logical and the "TRUE"/"FALSE" string
  # spellings working and gives one consistent flag for the whole function.
  use_ensemble <- as.logical(ensemble)
  stopifnot(
    length(use_ensemble) == 1L, !is.na(use_ensemble),
    NROW(cl) == nrow(X_tr),
    ncol(X_ts) == ncol(X_tr),
    length(k) == 1L, k >= 1
  )

  freq_label <- table(cl)
  label <- names(freq_label)
  n_class <- length(label)
  stopifnot(n_class >= 1L)
  p <- ncol(X_tr)
  if (is.null(r)) r <- p
  n_test <- nrow(X_ts)

  # Training rows of each class, stored transposed (p x n_class_rows) so a
  # test point can be subtracted column-wise.  drop = FALSE keeps a matrix
  # even for a single row or a single feature; the transpose is hoisted out
  # of the loop over test rows.
  class_t <- lapply(label, function(lv) {
    t(X_tr[which(cl == lv), , drop = FALSE])
  })

  # k x n_class matrix whose entry [j, c] is the distance from `x` to its
  # j-th nearest training point of class c; NA when the class has fewer than
  # j rows (including empty factor levels).
  nearest_dist <- function(x) {
    out <- matrix(NA_real_, nrow = k, ncol = n_class)
    for (j in seq_len(n_class)) {
      m <- class_t[[j]]
      if (ncol(m) > 0L) {
        d <- sqrt(colSums((m - x)^2)) + eps
        out[, j] <- sort(d)[seq_len(k)]
      }
    }
    out
  }

  # Class probabilities from one row `d` of the distance matrix.
  # `single_winner` selects how several overflowing classes are resolved.
  class_probs <- function(d, single_winner) {
    w <- k / d^(p / r)
    w[is.na(d)] <- 0              # no such neighbor -> no support
    total <- sum(w)
    out <- numeric(n_class)
    if (total == 0) {
      # Every weight underflowed to 0: the closest class takes all the mass.
      out[which.min(d)] <- 1
      return(out)
    }
    overflow <- which(is.infinite(w))
    if (length(overflow) == 0L) {
      return(w / total)
    }
    # d^(p / r) underflowed to 0 for at least one class (test point is
    # practically on top of a training point): such a class gets
    # probability 1 and every other class 0.
    if (single_winner && length(overflow) > 1L) {
      overflow <- which.min(d)
    }
    out[overflow] <- 1
    out
  }

  predict_kcnn <- character(n_test)
  prob_kcnn <- matrix(0, nrow = n_test, ncol = n_class)
  for (i in seq_len(n_test)) {
    result <- nearest_dist(X_ts[i, ])
    if (use_ensemble) {
      prob <- matrix(0, nrow = k, ncol = n_class)
      for (j in seq_len(k)) {
        prob[j, ] <- class_probs(result[j, ], single_winner = FALSE)
      }
      vote_prob <- apply(prob, 2, mean)
    } else {
      vote_prob <- class_probs(result[k, ], single_winner = TRUE)
    }
    prob_kcnn[i, ] <- vote_prob

    best <- which(vote_prob == max(vote_prob))
    # length(best) > 1 here, so sample() draws among the tied classes.
    if (length(best) > 1L) best <- sample(best, 1)
    predict_kcnn[i] <- label[best]
  }

  if (use_ensemble) {
    list(predict_ekcnn = predict_kcnn, probability_ekcnn = prob_kcnn)
  } else {
    list(predict_kcnn = predict_kcnn, probability_kcnn = prob_kcnn)
  }
}