Skip to content

Commit

Permalink
Added accuracy tests for c and n_trees variation.
Browse files Browse the repository at this point in the history
  • Loading branch information
maheshakya committed Jul 14, 2014
1 parent a7a5788 commit 6a0366a
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 4 deletions.
8 changes: 4 additions & 4 deletions sklearn/neighbors/lsh_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import itertools
from ..base import BaseEstimator
from ..utils.validation import safe_asarray
from sklearn.utils import check_random_state
from ..utils import check_random_state

from sklearn.random_projection import GaussianRandomProjection
from ..random_projection import GaussianRandomProjection

__all__ = ["LSHForest"]

Expand Down Expand Up @@ -105,7 +105,7 @@ class LSHForest(BaseEstimator):
lowest hash length to be searched when candidate selection is
performed for nearest neighbors.
random_state: float, optional(default = 0)
random_state: float, optional(default = 1)
A random value to initialize random number generator.
Attributes
Expand Down Expand Up @@ -134,7 +134,7 @@ class LSHForest(BaseEstimator):
>>> lshf = LSHForest()
>>> lshf.fit(X)
LSHForest(c=50, hashing_algorithm='random_projections', lower_bound=4,
max_label_length=32, n_neighbors=1, n_trees=10, seed=None)
max_label_length=32, n_neighbors=1, n_trees=10, random_state=None)
>>> lshf.kneighbors(X[:5], n_neighbors=3, return_distance=True)
(array([[0, 1, 2],
Expand Down
87 changes: 87 additions & 0 deletions sklearn/neighbors/tests/test_lsh_forest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""
Testing for the Locality Sensitive Hashing Forest
module (sklearn.neighbors.LSHForest).
"""

# Author: Gilles Louppe

import numpy as np

from sklearn.utils.testing import assert_array_equal
from sklearn.utils.testing import assert_array_almost_equal
from sklearn.utils.testing import assert_equal
from sklearn.utils.testing import assert_raises
from sklearn.utils.testing import assert_greater
from sklearn.utils.testing import assert_less
from sklearn.utils.testing import assert_true
from sklearn.utils.testing import assert_warns

from sklearn.metrics import euclidean_distances
from sklearn.neighbors import LSHForest


def test_neighbors_accuracy_with_c():
    """Check that neighbor accuracy does not decrease as `c` increases.

    For each value of `c`, an LSHForest is fitted on the same random data
    and queried with the same set of points. Accuracy for one query is the
    fraction of the true `n_points` nearest neighbors (computed by brute
    force with `euclidean_distances`) that also appear in the forest's
    answer; the per-`c` accuracy is the mean over all queries.

    All randomness is drawn from a locally seeded generator so the test is
    reproducible, and every `c` is evaluated on identical queries so the
    comparison between settings is not confounded by the choice of points.
    """
    c_values = np.array([10, 50, 250])
    samples = 1000
    dim = 50
    n_iter = 10
    n_points = 20

    # Local, seeded generator: avoids depending on (or disturbing) the
    # global NumPy random state shared with other tests.
    rng = np.random.RandomState(42)
    X = rng.rand(samples, dim)
    # Draw the query points once and reuse them for every value of `c`.
    query_indices = rng.randint(0, samples, size=n_iter)

    accuracies = np.zeros(c_values.shape[0], dtype=float)

    for i, c in enumerate(c_values):
        # Same seed for every forest so that `c` is the only thing varying.
        lshf = LSHForest(c=c, random_state=42)
        lshf.fit(X)
        for index in query_indices:
            point = X[index]
            neighbors = lshf.kneighbors(point, n_neighbors=n_points,
                                        return_distance=False)
            # Brute-force ground truth: indices of the closest points.
            distances = euclidean_distances(point, X)
            ranks = np.argsort(distances)[0, :n_points]

            intersection = np.intersect1d(ranks, neighbors).shape[0]
            accuracies[i] += intersection / float(n_points)

        accuracies[i] /= float(n_iter)

    # Sorted accuracies should be equal to original accuracies
    assert_array_equal(accuracies, np.sort(accuracies),
                       err_msg="Accuracies are not non-decreasing.")


def test_neighbors_accuracy_with_n_trees():
    """Check that neighbor accuracy does not decrease as `n_trees` increases.

    For each value of `n_trees`, an LSHForest (with `c=500`) is fitted on
    the same random data and queried with the same set of points. Accuracy
    for one query is the fraction of the true `n_points` nearest neighbors
    (computed by brute force with `euclidean_distances`) that also appear
    in the forest's answer; the per-setting accuracy is the mean over all
    queries.

    All randomness is drawn from a locally seeded generator so the test is
    reproducible, and every setting is evaluated on identical queries so
    the comparison is not confounded by the choice of points.
    """
    n_trees = np.array([1, 10, 100])
    samples = 1000
    dim = 50
    n_iter = 10
    n_points = 20

    # Local, seeded generator: avoids depending on (or disturbing) the
    # global NumPy random state shared with other tests.
    rng = np.random.RandomState(42)
    X = rng.rand(samples, dim)
    # Draw the query points once and reuse them for every forest size.
    query_indices = rng.randint(0, samples, size=n_iter)

    accuracies = np.zeros(n_trees.shape[0], dtype=float)

    for i, trees in enumerate(n_trees):
        # Same seed for every forest so that only the tree count varies.
        lshf = LSHForest(c=500, n_trees=trees, random_state=42)
        lshf.fit(X)
        for index in query_indices:
            point = X[index]
            neighbors = lshf.kneighbors(point, n_neighbors=n_points,
                                        return_distance=False)
            # Brute-force ground truth: indices of the closest points.
            distances = euclidean_distances(point, X)
            ranks = np.argsort(distances)[0, :n_points]

            intersection = np.intersect1d(ranks, neighbors).shape[0]
            accuracies[i] += intersection / float(n_points)

        accuracies[i] /= float(n_iter)

    # Sorted accuracies should be equal to original accuracies
    assert_array_equal(accuracies, np.sort(accuracies),
                       err_msg="Accuracies are not non-decreasing.")


if __name__ == "__main__":
    # Allow running this test module directly as a script through nose.
    from nose import runmodule
    runmodule()

0 comments on commit 6a0366a

Please sign in to comment.