Skip to content

Commit

Permalink
Proper fix issue #88
Browse files Browse the repository at this point in the history
  • Loading branch information
lmcinnes committed Jul 20, 2018
1 parent f606dda commit 3ee1c5e
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions umap/umap_.py
Original file line number Diff line number Diff line change
Expand Up @@ -1413,20 +1413,24 @@ def fit(self, X, y=None):

if callable(self.metric):
self._distance_func = self.metric
elif self.metric in dist.named_distances or self.metric == 'precomputed':
elif self.metric in dist.named_distances:
self._distance_func = dist.named_distances[self.metric]
elif self.metric == 'precomputed':
warn('Using precomputed metric; transform will be unavailable for new data')
else:
raise ValueError(
"Metric is neither callable, " + "nor a recognised string"
)
self._dist_args = tuple(self._metric_kwds.values())

self._random_init, self._tree_init = make_initialisations(
self._distance_func, self._dist_args
)
self._search = make_initialized_nnd_search(
self._distance_func, self._dist_args
)
if self.metric != 'precomputed':
self._dist_args = tuple(self._metric_kwds.values())

self._random_init, self._tree_init = make_initialisations(
self._distance_func, self._dist_args
)
self._search = make_initialized_nnd_search(
self._distance_func, self._dist_args
)

if y is not None:
if self.target_metric == "categorical":
Expand Down Expand Up @@ -1542,6 +1546,9 @@ def transform(self, X):

if self._sparse_data:
raise ValueError("Transform not available for sparse input.")
elif self.metric == 'precomputed':
raise ValueError("Transform of new data not available for "
"precomputed metric.")

X = check_array(X, dtype=np.float32, order="C")
random_state = check_random_state(self.transform_seed)
Expand Down

0 comments on commit 3ee1c5e

Please sign in to comment.