Skip to content

Commit

Permalink
Fix bugs.
Browse files: browse the repository at this point in the history
  • Loading branch information
mblondel committed Dec 13, 2011
1 parent 8c81446 commit 0659d43
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
11 changes: 7 additions & 4 deletions svmlight_loader.py
Expand Up @@ -7,12 +7,14 @@

import os.path

import numpy as np
import scipy.sparse as sp

from _svmlight_loader import _load_svmlight_file


def load_svmlight_file(file_path, n_features=None, buffer_mb=40):
def load_svmlight_file(file_path, n_features=None, dtype=np.float64,
buffer_mb=40):
"""Load datasets in the svmlight / libsvm format into sparse CSR matrix
This format is a text-based format, with one sample per line. It does
Expand Down Expand Up @@ -56,12 +58,13 @@ def load_svmlight_file(file_path, n_features=None, buffer_mb=40):
else:
shape = None # inferred

data = np.array(data, dtype=dtype)
X_train = sp.csr_matrix((data, indices, indptr), shape)

return (X_train, labels)


def load_svmlight_files(files, n_features=None, buffer_mb=40):
def load_svmlight_files(files, n_features=None, dtype=np.float64, buffer_mb=40):
"""Load dataset from multiple files in SVMlight format
This function is equivalent to mapping load_svmlight_file over a list of
Expand Down Expand Up @@ -99,11 +102,11 @@ def load_svmlight_files(files, n_features=None, buffer_mb=40):
load_svmlight_file
"""
files = iter(files)
result = list(load_svmlight_file(files.next(), n_features, dtype))
result = list(load_svmlight_file(files.next(), n_features, dtype, buffer_mb))
n_features = result[0].shape[1]

for f in files:
result += load_svmlight_file(f, n_features, buffer_mb)
result += load_svmlight_file(f, n_features, dtype, buffer_mb)

return result

Expand Down
7 changes: 4 additions & 3 deletions tests/test_svmlight_loader.py
Expand Up @@ -5,8 +5,9 @@
from numpy.testing import assert_equal, assert_array_equal
from nose.tools import raises

from sklearn.datasets import (load_svmlight_file, load_svmlight_files,
dump_svmlight_file)
from svmlight_loader import (load_svmlight_file, load_svmlight_files,
dump_svmlight_file)
from sklearn.datasets import load_svmlight_file as sk_load_svmlight_file

currdir = os.path.dirname(os.path.abspath(__file__))
datafile = os.path.join(currdir, "data", "svmlight_classification.txt")
Expand Down Expand Up @@ -102,6 +103,6 @@ def test_dump():
f = StringIO()
dump_svmlight_file(X, y, f)
f.seek(0)
X2, y2 = load_svmlight_file(f)
X2, y2 = sk_load_svmlight_file(f)
assert_array_equal(Xd, X2.toarray())
assert_array_equal(y, y2)

0 comments on commit 0659d43

Please sign in to comment.