-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
121 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
|
||
import numpy as np | ||
|
||
|
||
def create_rank_k_dataset( | ||
n_rows=5, | ||
n_cols=5, | ||
k=3, | ||
fraction_missing=0.1, | ||
symmetric=False, | ||
random_seed=0): | ||
np.random.seed(random_seed) | ||
x = np.random.randn(n_rows, k) | ||
y = np.random.randn(k, n_cols) | ||
|
||
XY = np.dot(x, y) | ||
|
||
if symmetric: | ||
assert n_rows == n_cols | ||
XY = 0.5 * XY + 0.5 * XY.T | ||
|
||
missing_raw_values = np.random.uniform(0, 1, (n_rows, n_cols)) | ||
missing_mask = missing_raw_values < fraction_missing | ||
|
||
XY_incomplete = XY.copy() | ||
# fill missing entries with NaN | ||
XY_incomplete[missing_mask] = np.nan | ||
|
||
return XY, XY_incomplete, missing_mask | ||
|
||
|
||
# create some default data to be shared across tests | ||
XY, XY_incomplete, missing_mask = create_rank_k_dataset( | ||
n_rows=500, | ||
n_cols=10, | ||
k=3, | ||
fraction_missing=0.25) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from time import time | ||
import numpy as np | ||
from nose.tools import eq_ | ||
|
||
from knnimpute import ( | ||
all_pairs_normalized_distances, | ||
all_pairs_normalized_distances_reference | ||
) | ||
from low_rank_data import XY_incomplete | ||
|
||
def test_normalized_distance_same_results_as_reference_implementation(): | ||
|
||
D_reference = all_pairs_normalized_distances_reference(XY_incomplete) | ||
D_fast = all_pairs_normalized_distances(XY_incomplete) | ||
|
||
eq_(D_fast.shape, D_reference.shape) | ||
|
||
assert not np.isnan(D_reference).any(), "NaN in distance matrix" | ||
assert not np.isnan(D_fast).any(), "NaN in distance matrix" | ||
|
||
reference_finite_mask = np.isfinite(D_reference) | ||
fast_finite_mask = np.isfinite(D_fast) | ||
n_inf_reference = (~reference_finite_mask).sum() | ||
n_inf_fast = (~fast_finite_mask).sum() | ||
print("# infinity reference=%d fast=%d" % ( | ||
n_inf_reference, | ||
n_inf_fast)) | ||
eq_(n_inf_reference, n_inf_fast) | ||
|
||
assert (reference_finite_mask == fast_finite_mask).all() | ||
|
||
finite_diffs = (D_fast[fast_finite_mask] - D_reference[reference_finite_mask]) | ||
abs_diff = np.abs(finite_diffs) | ||
print(np.where(abs_diff > 0.1)) | ||
mae = np.mean(abs_diff) | ||
print("MAE", mae) | ||
assert mae < 0.0001, \ | ||
"Difference between distance matrices (MAE=%0.4f)" % mae | ||
|
||
def test_normalized_distance_faster_than_reference_implementation(): | ||
start_t = time() | ||
all_pairs_normalized_distances(XY_incomplete, verbose=False) | ||
fast_t = time() - start_t | ||
start_t = time() | ||
all_pairs_normalized_distances_reference(XY_incomplete) | ||
reference_t = time() - start_t | ||
print("Fast implementation: %0.2fs, reference implementation: %0.2fs" % ( | ||
fast_t, reference_t)) | ||
assert reference_t / fast_t > 2, "Expected 2x performance gain" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import numpy as np | ||
from nose.tools import eq_ | ||
|
||
from knnimpute import ( | ||
knn_impute_few_observed, | ||
knn_impute_with_argpartition, | ||
knn_impute_optimistic, | ||
knn_impute_reference, | ||
) | ||
from low_rank_data import XY_incomplete, missing_mask | ||
|
||
def _knn_implementation(impute_fn): | ||
X_filled_reference = knn_impute_reference(XY_incomplete.copy(), missing_mask, k=3) | ||
X_filled_other = impute_fn(XY_incomplete.copy(), missing_mask, k=3) | ||
eq_(X_filled_reference.shape, X_filled_other.shape) | ||
diff = X_filled_reference - X_filled_other | ||
abs_diff = np.abs(diff) | ||
mae = np.mean(abs_diff) | ||
assert mae < 0.1, \ | ||
"Difference between imputed values! MAE=%0.4f, 1st rows: %s vs. %s" % ( | ||
mae, | ||
X_filled_reference[0], | ||
X_filled_other[0] | ||
) | ||
|
||
def test_knn_argpartition_same_as_reference(): | ||
_knn_implementation(knn_impute_with_argpartition) | ||
|
||
|
||
def test_knn_optimistic_same_as_reference(): | ||
_knn_implementation(knn_impute_optimistic) | ||
|
||
|
||
def test_knn_optimistic_few_observed(): | ||
_knn_implementation(knn_impute_few_observed) |