Skip to content

Commit

Permalink
Merge pull request #74 from ivirshup/fix-transformer
Browse files Browse the repository at this point in the history
Start fix for tranformer return shape
  • Loading branch information
lmcinnes committed Sep 5, 2019
2 parents ec8a539 + eb676a5 commit 546f710
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 4 deletions.
6 changes: 2 additions & 4 deletions pynndescent/pynndescent_.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,7 +1024,7 @@ def transform(self, X, y=None):
Returns
-------
Xt : CSR sparse matrix, shape (n_samples_fit, n_samples_transform)
Xt : CSR sparse matrix, shape (n_samples_transform, n_samples_fit)
Xt[i, j] is assigned the weight of edge that connects i to j.
Only the neighbors have an explicit value.
"""
Expand All @@ -1041,9 +1041,7 @@ def transform(self, X, y=None):
X, k=self.n_neighbors, queue_size=self.search_queue_size
)

result = lil_matrix(
(n_samples_transform, n_samples_transform), dtype=np.float32
)
result = lil_matrix((n_samples_transform, self.n_samples_fit), dtype=np.float32)
result.rows = indices
result.data = distances

Expand Down
23 changes: 23 additions & 0 deletions pynndescent/tests/test_pynndescent_.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,29 @@ def test_sparse_nn_descent_query_accuracy():
)


def test_transformer_equivalence():
N_NEIGHBORS = 15
QUEUE_SIZE = 5.0
train = nn_data[:400]
test = nn_data[:200]

nnd = NNDescent(data=train, n_neighbors=N_NEIGHBORS, random_state=42)
indices, dists = nnd.query(test, k=N_NEIGHBORS, queue_size=QUEUE_SIZE)
sort_idx = np.argsort(indices, axis=1)
indices_sorted = np.vstack(
[indices[i, sort_idx[i]] for i in range(sort_idx.shape[0])]
)
dists_sorted = np.vstack([dists[i, sort_idx[i]] for i in range(sort_idx.shape[0])])

transformer = PyNNDescentTransformer(
n_neighbors=N_NEIGHBORS, search_queue_size=QUEUE_SIZE, random_state=42
).fit(train)
Xt = transformer.transform(test).sorted_indices()

assert np.all(Xt.indices == indices_sorted.flat)
assert np.allclose(Xt.data, dists_sorted.flat)


def test_random_state_none():
knn_indices, _ = NNDescent(
nn_data, "euclidean", {}, 10, random_state=None
Expand Down

0 comments on commit 546f710

Please sign in to comment.