Commit 04bc761
Merge pull request #863 from hndgzkn/fix_aligned_umap_update
fix relations dictionary problems that prevent aligned_umap from updating correctly
lmcinnes committed Aug 21, 2023
2 parents 8e54853 + 885f184 commit 04bc761
Showing 1 changed file with 39 additions and 14 deletions.
umap/aligned_umap.py: 53 changes (39 additions & 14 deletions)
@@ -1,7 +1,7 @@
 import numpy as np
 import numba
 from sklearn.base import BaseEstimator
-from sklearn.utils import check_random_state, check_array
+from sklearn.utils import check_array

 from umap.sparse import arr_intersect as intersect1d
 from umap.sparse import arr_union as union1d
@@ -314,13 +314,19 @@ def fit(self, X, y=None, **fit_params):

         self.n_models_ = len(X)

+        if self.n_epochs is None:
+            self.n_epochs = 200
+
+        n_epochs = self.n_epochs
+
         self.mappers_ = [
             UMAP(
                 n_neighbors=get_nth_item_or_val(self.n_neighbors, n),
                 min_dist=get_nth_item_or_val(self.min_dist, n),
+                n_epochs=get_nth_item_or_val(self.n_epochs, n),
                 repulsion_strength=get_nth_item_or_val(self.repulsion_strength, n),
                 learning_rate=get_nth_item_or_val(self.learning_rate, n),
+                init=self.init,
                 spread=get_nth_item_or_val(self.spread, n),
                 negative_sample_rate=get_nth_item_or_val(self.negative_sample_rate, n),
                 local_connectivity=get_nth_item_or_val(self.local_connectivity, n),
@@ -346,11 +352,6 @@ def fit(self, X, y=None, **fit_params):
             for n in range(self.n_models_)
         ]

-        if self.n_epochs is None:
-            n_epochs = 200
-        else:
-            n_epochs = self.n_epochs
-
         window_size = fit_params.get("window_size", self.alignment_window_size)
         relations = expand_relations(self.dict_relations_, window_size)
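For context on the per-model parameters threaded through above: get_nth_item_or_val resolves a parameter that may be either a scalar or a sequence with one entry per model. A minimal sketch of that behaviour, as an illustration rather than the library's exact implementation:

```python
# Sketch of a per-model parameter lookup like get_nth_item_or_val;
# illustration only, not the library's exact code.
def get_nth_item_or_val(param, n):
    # A list or tuple supplies one value per model in the sequence ...
    if isinstance(param, (list, tuple)):
        return param[n]
    # ... while a scalar (or None) applies to every model.
    return param

assert get_nth_item_or_val([100, 200, 300], 1) == 200
assert get_nth_item_or_val(0.1, 5) == 0.1
```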

@@ -448,10 +449,21 @@ def update(self, X, y=None, **fit_params):
             )

         new_dict_relations = fit_params["relations"]
+        assert isinstance(new_dict_relations, dict)

         X = check_array(X)

         self.__dict__ = set_aligned_params(fit_params, self.__dict__, self.n_models_)

+        # We need n_components to be constant or this won't work
+        if type(self.n_components) in (list, tuple, np.ndarray):
+            raise ValueError("n_components must be a single integer, and cannot vary")
+
+        if self.n_epochs is None:
+            self.n_epochs = 200
+
+        n_epochs = self.n_epochs
+
         new_mapper = UMAP(
             n_neighbors=get_nth_item_or_val(self.n_neighbors, self.n_models_),
             min_dist=get_nth_item_or_val(self.min_dist, self.n_models_),
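The assert added above pins down update()'s contract: fit() takes a list of relations dicts, one per consecutive pair of datasets, whereas update() takes a single dict relating the last fitted dataset to the new one. A small illustration with invented values:

```python
# Hypothetical values for illustration: fit() receives one relations dict
# per consecutive pair of datasets, while update() receives a single dict
# for the newly appended dataset.
relations_for_fit = [{0: 0, 1: 2}, {0: 1, 2: 0}]  # 3 datasets -> 2 dicts
relation_for_update = {0: 0, 1: 1}                # last dataset -> new dataset

assert isinstance(relations_for_fit, list)
assert isinstance(relation_for_update, dict)      # the contract asserted above
```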
@@ -460,6 +472,7 @@ def update(self, X, y=None, **fit_params):
                 self.repulsion_strength, self.n_models_
             ),
             learning_rate=get_nth_item_or_val(self.learning_rate, self.n_models_),
+            init=self.init,
             spread=get_nth_item_or_val(self.spread, self.n_models_),
             negative_sample_rate=get_nth_item_or_val(
                 self.negative_sample_rate, self.n_models_
@@ -470,20 +483,30 @@ def update(self, X, y=None, **fit_params):
             set_op_mix_ratio=get_nth_item_or_val(self.set_op_mix_ratio, self.n_models_),
             unique=get_nth_item_or_val(self.unique, self.n_models_),
             n_components=self.n_components,
+            metric=self.metric,
+            metric_kwds=self.metric_kwds,
+            low_memory=self.low_memory,
             random_state=self.random_state,
+            angular_rp_forest=self.angular_rp_forest,
+            transform_queue_size=self.transform_queue_size,
+            target_n_neighbors=self.target_n_neighbors,
+            target_metric=self.target_metric,
+            target_metric_kwds=self.target_metric_kwds,
+            target_weight=self.target_weight,
+            transform_seed=self.transform_seed,
+            force_approximation_algorithm=self.force_approximation_algorithm,
+            verbose=self.verbose,
+            a=self.a,
+            b=self.b,
         ).fit(X, y)

         self.n_models_ += 1
         self.mappers_ += [new_mapper]

         # TODO: We can likely make this more efficient and not recompute each time
-        self.dict_relations_ += [invert_dict(new_dict_relations)]
+        self.dict_relations_ += [new_dict_relations]

-        if self.n_epochs is None:
-            n_epochs = 200
-        else:
-            n_epochs = self.n_epochs
         window_size = fit_params.get("window_size", self.alignment_window_size)
         new_relations = expand_relations(self.dict_relations_, window_size)

         indptr_list = numba.typed.List.empty_list(numba.types.int32[::1])
         indices_list = numba.typed.List.empty_list(numba.types.int32[::1])
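This hunk and the next carry the core fix: dict_relations_ now stores the caller's relations dict unchanged, the same orientation fit() uses, and the inversion happens only where init_from_existing needs the reverse mapping. A minimal sketch of inverting a one-to-one relations dict (an illustration; the library's invert_dict may differ in detail):

```python
# Minimal sketch of inverting a relations dictionary, assuming the mapping
# is one-to-one; invert_dict as used in the diff plausibly behaves like this.
def invert_dict(d):
    return {value: key for key, value in d.items()}

relations = {0: 2, 1: 0, 2: 1}                       # old row -> new row
assert invert_dict(relations) == {2: 0, 0: 1, 1: 2}  # new row -> old row
```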
@@ -505,15 +528,17 @@ def update(self, X, y=None, **fit_params):
                 np.full(mapper.embedding_.shape[0], n_epochs + 1, dtype=np.float64)
             )

-        new_relations = expand_relations(self.dict_relations_)
         new_regularisation_weights = build_neighborhood_similarities(
             indptr_list,
             indices_list,
             new_relations,
         )

+        # TODO: We can likely make this more efficient and not recompute each time
+        inv_dict_relations = invert_dict(new_dict_relations)
+
         new_embedding = init_from_existing(
-            self.embeddings_[-1], new_mapper.graph_, new_dict_relations
+            self.embeddings_[-1], new_mapper.graph_, inv_dict_relations
         )

         self.embeddings_.append(new_embedding)
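Taken together, the changes let update() be driven the same way as fit(). A hedged end-to-end sketch of that workflow, using synthetic data and invented names (slices, new_slice):

```python
import numpy as np
from umap import AlignedUMAP

rng = np.random.default_rng(42)
slices = [rng.normal(size=(100, 8)) for _ in range(3)]

# One relations dict per consecutive pair of slices; the rows here are
# assumed to correspond one-to-one, so each mapping is the identity.
relations = [{i: i for i in range(100)} for _ in range(2)]

aligned = AlignedUMAP(n_neighbors=10).fit(slices, relations=relations)

# With the fix, update() appends a new slice given a single relations dict
# tying the last fitted slice to the new one.
new_slice = rng.normal(size=(100, 8))
aligned.update(new_slice, relations={i: i for i in range(100)})

print(len(aligned.embeddings_))  # one embedding per slice, now 4
```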