Permalink
Browse files

Support for sparse matrices and added some tests.

Closes first issue(1): fixes #1
  • Loading branch information...
lianos committed Sep 10, 2011
1 parent dc47df4 commit 8d218a5575a7665cf7be061f2954f6f7ddcbb646
Showing with 164 additions and 27 deletions.
  1. +1 −0 NAMESPACE
  2. +2 −0 R/AllS4.R
  3. +42 −11 R/buckshot-data.R
  4. +2 −2 R/buckshot.R
  5. +25 −5 inst/tests/test-basic.R
  6. +86 −7 src/buckshot.cpp
  7. +6 −2 src/buckshot.h
View
@@ -18,6 +18,7 @@ exportClasses(
exportMethods(
buckshot,
BuckshotData,
+ designMatrix,
labels,
coef,
View
@@ -3,6 +3,8 @@
## ----------------------------------------------------------------------------
setClassUnion("ptrOrNULL", c("externalptr", "NULL"))
+setClassUnion("MatrixLike", c("matrix", "Matrix"))
+
##' Base object for Buckshot classes
setClass("BuckshotObject", contains="VIRTUAL",
representation=representation(cache="environment"))
View
@@ -47,47 +47,78 @@ function(x, data=NULL, ..., na.action=na.omit, scaled=TRUE) {
ret
})
-setMethod("BuckshotData", c(x="matrix"),
-function(x, y, scaled=TRUE, ...) {
- if (!is.numeric(x)) {
+preprocess.xy <- function(x, y) {
+ if (!is.numeric(x[1L])) {
stop("Only numeric data is supported")
}
+ if (!is.matrix(x) && !inherits(x, 'Matrix')) {
+ stop("x needs to be a matrix (or Matrix)")
+ }
if (missing(y) || !is.numeric(y)) {
stop("Numberic labels (y) are required")
}
- if (length(y) != nrow(x)) {
+ if (nrow(x) != length(y)) {
stop("Number of labels != number of observations")
}
- storage.mode(x) <- 'numeric'
- y <- as.numeric(y)
-
rm.cols <- which(colSums(x) == 0)
if (length(rm.cols) > 0L) {
warning("Removing ", length(rm.cols), " columns from design matrix ",
"since they are all 0s. See ?BuckshotData for more info.")
x <- x[, -rm.cols]
}
- ret <- .Call("create_shotgun_data_dense", x, y, PACKAGE="buckshot")
- new("BuckshotData", ptr=ret$ptr, dim=dim(x), nnz=ret$nnz, rm.cols=rm.cols)
+ list(x=x, y=as.numeric(y), rm.cols=rm.cols)
+}
+
+setMethod("BuckshotData", c(x="matrix"),
+function(x, y, scaled=TRUE, ...) {
+ dat <- preprocess.xy(x, y)
+ storage.mode(dat$x) <- 'numeric'
+ ret <- .Call("create_shotgun_data_dense", dat$x, dat$y, PACKAGE="buckshot")
+ new("BuckshotData", ptr=ret$ptr, dim=dim(dat$x), nnz=ret$nnz,
+ rm.cols=dat$rm.cols)
})
setMethod("BuckshotData", c(x="Matrix"),
function(x, y, ...) {
BuckshotData(as(x, "matrix"), y, ...)
})
+setMethod("BuckshotData", c(x="TsparseMatrix"),
+function(x, y, ...) {
+ dat <- preprocess.xy(x, y)
+ ret <- .Call("create_shotgun_data_tsparse", as.numeric(dat$x@x),
+ dat$x@i, dat$x@j, nrow(dat$x), ncol(dat$x), dat$y,
+ PACKAGE="buckshot")
+ new("BuckshotData", ptr=ret$ptr, dim=dim(x), nnz=ret$nnz,
+ rm.cols=dat$rm.cols)
+})
+
setMethod("BuckshotData", c(x="CsparseMatrix"),
function(x, y, ...) {
- stop("TODO: BuckshotData,sparseMatrix")
+ return(BuckshotData(as(x, "TsparseMatrix"), y, ...))
+
+ ## TODO: Figure out how to properly index CsparseMatrix in order to skip copy
+ validate.xy(x, y)
rm.cols <- which(colSums(x) == 0)
if (length(rm.cols) > 0L) {
x <- x[, -rm.cols]
}
- ret <- .Call("create_shotgun_data_csparse", x)
+ vals <- as.numeric(x@x)
+ y <- as.numeric(y)
+
+ ## x@x : The non-zero values of the matrix, column order
+ ## x@p : 0-based-index position of the *first* element the n-th column
+ ## x@i : 0-based row index for the i-5h value in @x
+ ## x@Dimnames : length(2) of dimnames
+ ## x@factors : empty list (of wut?)
+ ## x@Dim : length 2 integer vector of row,col dims
+
+ ret <- .Call("create_shotgun_data_csparse", as.numeric(x@x), x@i, x@p,
+ nrow(x), ncol(x), y, PACKAGE="buckshot")
new("BuckshotData", ptr=ret$ptr, dim=dim(x), nnz=ret$nnz, rm.cols=rm.cols)
})
View
@@ -18,7 +18,7 @@ function(x, data=NULL, type='lasso', ..., na.action=na.omit, scaled=TRUE) {
buckshot(bdata, type=type, ...)
})
-setMethod("buckshot", c(x="matrix"),
+setMethod("buckshot", c(x="MatrixLike"),
function(x, y, type='lasso', na.action=na.omit, scaled=TRUE, ...) {
type <- buckshot:::matchLearningAlgo(type)
if (missing(y)) {
@@ -115,7 +115,7 @@ function(object, newdata=NULL, type="decision", ...) {
}
}
- y <- newdata %*% x
+ y <- as.vector(newdata %*% x)
if (object@type == 'logistic' && type == 'decision') {
y <- sign(y)
View
@@ -1,8 +1,28 @@
context("basic tests")
+data(arcene, package="buckshot")
-test_that("logistic regression works on arcene data", {
- data(arcene, package="buckshot")
- model <- buckshot(A.arcene, y.arcene, 'logistic', lambda=0.166)
- accuracy <- sum(predict(model, A.arcene) == y.arcene) / length(y.arcene)
- expect_equal(accuracy, 1)
+test_that("logistic regression on dense data", {
+ suppressWarnings({
+ model <- buckshot(A.arcene, y.arcene, 'logistic', lambda=0.166)
+ bingo <- all(predict(model, A.arcene) == y.arcene)
+ })
+ expect_true(bingo, info="Dense model")
+})
+
+test_that("logistic regression on sparse data", {
+ ms <- Matrix(A.arcene, sparse=TRUE)
+ suppressWarnings({
+ msparse <- buckshot(ms, y.arcene, 'logistic', lambda=0.166)
+ bingo.sparse <- all(predict(msparse, ms) == y.arcene)
+ })
+ expect_true(bingo.sparse, info="Sparse model")
+})
+
+test_that("warnings emmited when design matrices have all-zero columns", {
+ expect_warning({
+ bd <- BuckshotData(A.arcene, y.arcene)
+ }, "Removing.*?69", info="Warn on building data object")
+
+ model <- buckshot(bd, lambda=0.166)
+ expect_warning(predict(model, A.arcene), "remove", info="warn on predict")
})
View
@@ -49,14 +49,84 @@ END_RCPP
}
SEXP
-create_shotgun_data_csparse(SEXP matrix_, SEXP nnz_, SEXP nrows_, SEXP ncols_,
- SEXP labels_) {
+create_shotgun_data_csparse(SEXP vals_, SEXP rows_, SEXP cols_, SEXP nrows_,
+ SEXP ncols_, SEXP labels_) {
BEGIN_RCPP
- int nnz = Rcpp::as<int>(nnz_);
+ throw std::runtime_error("Use TsparseMatrix instead");
+ Rcpp::NumericVector vals(vals_);
+ Rcpp::IntegerVector rows(rows_);
+ Rcpp::IntegerVector cols(cols_);
+ int nrows = Rcpp::as<int>(nrows_);
+ int ncols = Rcpp::as<int>(ncols_);
+ Rcpp::NumericVector labels(labels_);
+ int nnz = vals.size();
int nnz_out = 0;
+
+ SEXP ptr;
+ Rcpp::List out;
+
+ shotgun_data *prob = new shotgun_data;
+ prob->A_rows.resize(nrows);
+ prob->A_cols.resize(ncols);
+ prob->y.resize(nrows);
+ prob->nx = ncols;
+ prob->ny = nrows;
+
+ int prevrow = -1;
+ int row;
+ int colidx = 0;
+ int col = cols[0];
+ double val;
+
+ Rprintf("nnz: %d\n", nnz);
+
+ for (int i = 0; i < nnz; i++) {
+ row = rows[i];
+ val = vals[i];
+ if (row < prevrow) {
+ while (col == cols[colidx]) {
+ // Guard against all-zero rows (shouldn't happen)
+ // Rprintf("col, colidx, cols[colidx] : %d %d %d\n",
+ // col, colidx, cols[colidx]);
+ colidx++;
+ }
+ col = colidx;
+ Rprintf(" col: %d, i: %d\n", col, i);
+ }
+ prob->A_cols[col].add(row, val);
+ prob->A_rows[row].add(col, val);
+ nnz_out++;
+ prevrow = row;
+ }
+
+ for (int i = 0; i < nrows; i++) {
+ prob->y[i] = labels[i];
+ }
+
+ ptr = R_MakeExternalPtr(prob, R_NilValue, R_NilValue);
+ R_RegisterCFinalizer(ptr, shotgun_data_finalizer);
+
+ out = Rcpp::List::create(
+ Rcpp::Named("nnz") = Rcpp::wrap(nnz_out),
+ Rcpp::Named("ptr") = ptr);
+
+ return out;
+END_RCPP
+}
+
+SEXP
+create_shotgun_data_tsparse(SEXP vals_, SEXP rows_, SEXP cols_, SEXP nrows_,
+ SEXP ncols_, SEXP labels_) {
+BEGIN_RCPP
+ Rcpp::NumericVector vals(vals_);
+ Rcpp::IntegerVector rows(rows_);
+ Rcpp::IntegerVector cols(cols_);
int nrows = Rcpp::as<int>(nrows_);
int ncols = Rcpp::as<int>(ncols_);
Rcpp::NumericVector labels(labels_);
+ int nnz = vals.size();
+ int nnz_out = 0;
+
SEXP ptr;
Rcpp::List out;
@@ -67,11 +137,20 @@ BEGIN_RCPP
prob->nx = ncols;
prob->ny = nrows;
+ int row;
+ int col;
+ double val;
+
+ for (int i = 0; i < nnz; i++) {
+ row = rows[i];
+ col = cols[i];
+ val = vals[i];
+
+ prob->A_cols[col].add(row, val);
+ prob->A_rows[row].add(col, val);
+ nnz_out++;
+ }
- // TODO: Fill in sparse matrix stuff
- // for () {
- // nnz_out++;
- // }
for (int i = 0; i < nrows; i++) {
prob->y[i] = labels[i];
}
View
@@ -21,8 +21,12 @@ RcppExport SEXP
create_shotgun_data_dense(SEXP matrix_, SEXP labels_);
RcppExport SEXP
-create_shotgun_data_csparse(SEXP matrix_, SEXP nnz_, SEXP nrows_, SEXP ncols_,
- SEXP labels_);
+create_shotgun_data_csparse(SEXP vals_, SEXP rows_, SEXP cols_,
+ SEXP nrows_, SEXP ncols_, SEXP labels_);
+
+RcppExport SEXP
+create_shotgun_data_tsparse(SEXP vals_, SEXP rows_, SEXP cols_,
+ SEXP nrows_, SEXP ncols_, SEXP labels_);
RcppExport SEXP
shotgun_data_labels(SEXP prob_);

0 comments on commit 8d218a5

Please sign in to comment.