Permalink
Browse files

Add an efficient writer.

  • Loading branch information...
mblondel committed Apr 18, 2013
1 parent c4d3649 commit 02c227b61df3483e2e9bbf9f110796b15d2ae439
Showing with 90 additions and 36 deletions.
  1. +4 −2 README.rst
  2. +62 −1 _svmlight_loader.cpp
  3. +14 −22 svmlight_loader.py
  4. +10 −11 tests/test_svmlight_loader.py
View
@@ -3,7 +3,8 @@
svmlight-loader
===============
This is a fast and memory efficient loader for the svmlight / libsvm sparse data file format in Python.
This is a fast and memory efficient loader (and writer...) for the svmlight /
libsvm sparse data file format in Python.
Install
@@ -24,7 +25,8 @@ http://scikit-learn.org/dev/datasets/index.html#datasets-in-svmlight-libsvm-form
Public datasets
===============
Public datasets in svmlight / libsvm format available at http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/
Public datasets in svmlight / libsvm format available at
http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/
License
=======
View
@@ -308,6 +308,63 @@ static PyObject *load_svmlight_file(PyObject *self, PyObject *args)
}
}
static const char dump_svmlight_file_doc[] =
"Dump CSR matrix to a file in svmlight format.";
extern "C" {
// Write a CSR matrix and its label vector to a text file in svmlight format.
//
// Python-side arguments, in the order parsed below:
//   file_path  (str)            path of the file to write
//   data       (numpy array)    CSR non-zero values
//   indices    (numpy array)    CSR column index of each value
//   indptr     (numpy array)    CSR row pointers (one more entry than rows)
//   label      (numpy array)    one target value per row
//   zero_based (int)            non-zero: write column indices as stored;
//                               zero: write them shifted up by one
//
// Each output line has the form "<label> <idx>:<value> <idx>:<value> ... ".
//
// NOTE(review): the array buffers are reinterpreted as double / int without
// any dtype, contiguity or length check in this function — it relies on the
// caller passing float64 data/labels and C-int index arrays; confirm at the
// call site.
static PyObject *dump_svmlight_file(PyObject *self, PyObject *args)
{
try {
// Read function arguments: a path string, four objects that must be numpy
// arrays (the "O!" format type-checks against PyArray_Type), and an int flag.
char const *file_path;
PyArrayObject *indices_array, *indptr_array, *data_array, *label_array;
int zero_based;
if (!PyArg_ParseTuple(args,
"sO!O!O!O!i",
&file_path,
&PyArray_Type, &data_array,
&PyArray_Type, &indices_array,
&PyArray_Type, &indptr_array,
&PyArray_Type, &label_array,
&zero_based))
// Argument parsing failed; return NULL so the pending Python exception
// propagates to the caller.
return 0;
// indptr holds one entry per row plus a trailing end marker.
int n_samples = indptr_array->dimensions[0] - 1;
// Raw views onto the numpy buffers (no copies, no dtype conversion).
double *data = (double*) data_array->data;
int *indices = (int*) indices_array->data;
int *indptr = (int*) indptr_array->data;
double *y = (double*) label_array->data;
// NOTE(review): the stream is not checked after open(); if the path cannot
// be opened, the writes below have no effect and no error is reported here.
std::ofstream fout;
fout.open(file_path, std::ofstream::out);
int idx;
for (int i=0; i < n_samples; i++) {
// Label first, then the row's non-zero entries.
fout << y[i] << " ";
// Entries of row i occupy positions indptr[i] .. indptr[i+1]-1.
for (int jj=indptr[i]; jj < indptr[i+1]; jj++) {
idx = indices[jj];
if (!zero_based)
// One-based output requested: shift the stored column index up by one.
idx++;
// NOTE(review): no precision or format flags are set on fout, so doubles use
// the stream's default formatting (6 significant digits by default in
// standard iostreams); values may not round-trip exactly — confirm whether
// that is acceptable.
fout << idx << ":" << data[jj] << " ";

This comment has been minimized.

Show comment
Hide comment
@seamusabshere

seamusabshere Apr 18, 2013

i don't know c++ - how will idx and data[jj] be formatted?

@seamusabshere

seamusabshere Apr 18, 2013

i don't know c++ - how will idx and data[jj] be formatted?

This comment has been minimized.

Show comment
Hide comment
@mblondel

mblondel Apr 18, 2013

Owner

I don't know... but at least, the writer should be compatible with the loader, see https://github.com/mblondel/svmlight-loader/blob/master/_svmlight_loader.cpp#L235. @larsmans, any comment?

@mblondel

mblondel Apr 18, 2013

Owner

I don't know... but at least, the writer should be compatible with the loader, see https://github.com/mblondel/svmlight-loader/blob/master/_svmlight_loader.cpp#L235. @larsmans, any comment?

This comment has been minimized.

Show comment
Hide comment
@seamusabshere

seamusabshere Apr 18, 2013

FWIW, pull request scikit-learn/scikit-learn#1849 aimed to reduce the file size by using a specific printf type:

if X.dtype.kind == 'i':
    value_pattern = u("%d:%d")
else:
    value_pattern = u("%d:%.16g")

if y.dtype.kind == 'i':
    line_pattern = u("%d")
else:
    line_pattern = u("%.16g")

So if that was replicated in the C++, then you'd have something fast AND small.

@seamusabshere

seamusabshere Apr 18, 2013

FWIW, pull request scikit-learn/scikit-learn#1849 aimed to reduce the file size by using a specific printf type:

if X.dtype.kind == 'i':
    value_pattern = u("%d:%d")
else:
    value_pattern = u("%d:%.16g")

if y.dtype.kind == 'i':
    line_pattern = u("%d")
else:
    line_pattern = u("%.16g")

So if that was replicated in the C++, then you'd have something fast AND small.

This comment has been minimized.

Show comment
Hide comment
@mblondel

mblondel Apr 18, 2013

Owner

Since << is an overloaded method, the output will definitely be a double for data[jj] and an int for idx. But for double, I don't know what's the default string representation.

@mblondel

mblondel Apr 18, 2013

Owner

Since << is an overloaded method, the output will definitely be a double for data[jj] and an int for idx. But for double, I don't know what's the default string representation.

This comment has been minimized.

Show comment
Hide comment
@larsmans

larsmans Apr 18, 2013

Output formatting is controlled with <iomanip>, e.g. std::setprecision and std::fixed. I'm not sure what the default is exactly, I don't do much I/O in C++.

@larsmans

larsmans Apr 18, 2013

Output formatting is controlled with <iomanip>, e.g. std::setprecision and std::fixed. I'm not sure what the default is exactly, I don't do much I/O in C++.

}
// End of one sample: terminate the line. std::endl also flushes the stream,
// so the file is flushed once per row.
fout << std::endl;
}
// NOTE(review): the stream state is not inspected after writing/closing, so
// a failed write is not reported to Python from here.
fout.close();
// Return None to Python (new reference).
Py_INCREF(Py_None);
return Py_None;
} catch (std::exception const &e) {
// Translate any C++ standard exception into a Python RuntimeError carrying
// the exception's message, and signal failure by returning NULL.
// NOTE(review): the stream used above does not have exceptions enabled in
// this function, so I/O failures presumably never reach this handler.
std::string msg("error in SVMlight/libSVM writer: ");
msg += e.what();
PyErr_SetString(PyExc_RuntimeError, msg.c_str());
return 0;
}
}
}
/*
* Python module setup.
@@ -316,11 +373,15 @@ static PyObject *load_svmlight_file(PyObject *self, PyObject *args)
// Method table for the extension module. Each entry gives the Python-visible
// name (underscore-prefixed), the C function implementing it, the calling
// convention (positional-args tuple) and the docstring.
static PyMethodDef svmlight_format_methods[] = {
{"_load_svmlight_file", load_svmlight_file,
METH_VARARGS, load_svmlight_file_doc},
{"_dump_svmlight_file", dump_svmlight_file,
METH_VARARGS, dump_svmlight_file_doc},
// All-NULL sentinel entry marking the end of the table.
{NULL, NULL, 0, NULL}
};
static const char svmlight_format_doc[] =
"Loader for svmlight / libsvm datasets - C++ helper routines";
"Loader/Writer for svmlight / libsvm datasets - C++ helper routines";
extern "C" {
PyMODINIT_FUNC init_svmlight_loader(void)
View
@@ -11,6 +11,7 @@
import scipy.sparse as sp
from _svmlight_loader import _load_svmlight_file
from _svmlight_loader import _dump_svmlight_file
def load_svmlight_file(file_path, n_features=None, dtype=None,
@@ -117,20 +118,6 @@ def load_svmlight_files(files, n_features=None, dtype=None, buffer_mb=40):
return result
def _dump_svmlight(X, y, f, zero_based):
if X.shape[0] != y.shape[0]:
raise ValueError("X.shape[0] and y.shape[0] should be the same, "
"got: %r and %r instead." % (X.shape[0], y.shape[0]))
is_sp = int(hasattr(X, "tocsr"))
one_based = not zero_based
for i in xrange(X.shape[0]):
s = u" ".join([u"%d:%f" % (j + one_based, X[i, j])
for j in X[i].nonzero()[is_sp]])
f.write((u"%f %s\n" % (y[i], s)).encode('ascii'))
def dump_svmlight_file(X, y, f, zero_based=True):
"""Dump the dataset in svmlight / libsvm file format.
@@ -142,23 +129,28 @@ def dump_svmlight_file(X, y, f, zero_based=True):
Parameters
----------
X : {array-like, sparse matrix}, shape = [n_samples, n_features]
X : CSR sparse matrix, shape = [n_samples, n_features]
Training vectors, where n_samples is the number of samples and
n_features is the number of features.
y : array-like, shape = [n_samples]
Target values.
f : str or file-like in binary mode
If string it specifies the path that will contain the data.
If f is a file-like then data will be written to f.
f : str
Specifies the path that will contain the data.
zero_based : boolean, optional
Whether column indices should be written zero-based (True) or one-based
(False).
"""
if hasattr(f, "write"):
_dump_svmlight(X, y, f, zero_based)
else:
with open(f, "wb") as f:
_dump_svmlight(X, y, f, zero_based)
raise ValueError("File handler not supported. Use a file path.")
if X.shape[0] != y.shape[0]:
raise ValueError("X.shape[0] and y.shape[0] should be the same, "
"got: %r and %r instead." % (X.shape[0], y.shape[0]))
X = sp.csr_matrix(X, dtype=np.float64)
y = np.array(y, dtype=np.float64)
_dump_svmlight_file(f, X.data, X.indices, X.indptr, y, int(zero_based))
@@ -1,6 +1,5 @@
import numpy as np
import os.path
from StringIO import StringIO
import os
from numpy.testing import assert_equal, assert_array_equal
from nose.tools import raises
@@ -96,13 +95,13 @@ def test_invalid_filename():
def test_dump():
Xs, y = load_svmlight_file(datafile)
Xd = Xs.toarray()
for X in (Xs, Xd):
f = StringIO()
dump_svmlight_file(X, y, f, zero_based=False)
f.seek(0)
X2, y2 = sk_load_svmlight_file(f)
assert_array_equal(Xd, X2.toarray())
try:
Xs, y = load_svmlight_file(datafile)
tmpfile = "tmp_dump.txt"
dump_svmlight_file(Xs, y, tmpfile, zero_based=False)
X2, y2 = sk_load_svmlight_file(tmpfile)
assert_array_equal(Xs.toarray(), X2.toarray())
assert_array_equal(y, y2)
finally:
os.remove(tmpfile)

0 comments on commit 02c227b

Please sign in to comment.