Skip to content

Commit

Permalink
supported ensembles for libsvm (radial kernel)
Browse files Browse the repository at this point in the history
  • Loading branch information
gtesei committed Oct 1, 2015
1 parent 2820248 commit b6cd638
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 7 deletions.
3 changes: 2 additions & 1 deletion R-package/DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,5 @@ Imports:
plyr,
xgboost,
magrittr,
stringr
stringr,
e1071
18 changes: 12 additions & 6 deletions R-package/R/fastRegression.R
Original file line number Diff line number Diff line change
Expand Up @@ -973,7 +973,7 @@ ff.createEnsemble = function(Xtrain,
test_i = Xtrain[ indexOut[[i]] , ]

## classification
if (! regression) {
if (! regression && caretModelName != "libsvm" ) {
y_i = y.cat[ index[[i]] ]
}

Expand All @@ -995,7 +995,10 @@ ff.createEnsemble = function(Xtrain,
}


if (! is.null(bestTune) ) {
if (caretModelName == "libsvm" && ! regression) {
model = e1071::svm(x = train_i , y = y_i , kernel = "radial" , gamma = bestTune$gamma , cost = bestTune$C)

} else if (! is.null(bestTune) ) {
model <- caret::train(y = y_i, x = train_i ,
method = caretModelName,
tuneGrid = bestTune,
Expand All @@ -1013,7 +1016,7 @@ ff.createEnsemble = function(Xtrain,

##
ret = NULL
if (regression) {
if ( regression || (! regression && caretModelName == "libsvm") ) {
ret = predict(model,test_i)
} else {
ret = predict(model,test_i,type = "prob")[,fact.sign]
Expand Down Expand Up @@ -1074,13 +1077,16 @@ ff.createEnsemble = function(Xtrain,
}

ytrain = NULL
if (regression) {
if (regression || (! regression && caretModelName == "libsvm") ) {
ytrain = y
} else {
ytrain = y.cat
}

if (! is.null(bestTune) ) {
if (caretModelName == "libsvm" && ! regression) {
model = e1071::svm(x = Xtrain , y = ytrain , kernel = "radial" , gamma = bestTune$gamma , cost = bestTune$C)

} else if (! is.null(bestTune) ) {
model <- caret::train(y = ytrain, x = Xtrain ,
method = caretModelName,
tuneGrid = bestTune,
Expand All @@ -1098,7 +1104,7 @@ ff.createEnsemble = function(Xtrain,

##
predTest = NULL
if (regression) {
if (regression || (! regression && caretModelName == "libsvm") ) {
predTest = predict(model,Xtest)
} else {
predTest = predict(model,Xtest,type = "prob")[,fact.sign]
Expand Down

0 comments on commit b6cd638

Please sign in to comment.