Merge pull request #696 from matthieuheitz/multiple_epochs

Get intermediate results at different epochs

lmcinnes committed Sep 15, 2022
2 parents e664003 + d5d8af7 commit 70d0132
Showing 2 changed files with 61 additions and 11 deletions.
26 changes: 23 additions & 3 deletions umap/layouts.py
@@ -256,8 +256,12 @@ def optimize_layout_euclidean(
         The indices of the heads of 1-simplices with non-zero membership.
     tail: array of shape (n_1_simplices)
         The indices of the tails of 1-simplices with non-zero membership.
-    n_epochs: int
-        The number of training epochs to use in optimization.
+    n_epochs: int, or list of int
+        The number of training epochs to use in optimization, or a list of
+        epochs at which to save the embedding. Given a list, the optimization
+        runs for the maximum number of epochs in the list and returns a list
+        of embeddings in order of increasing epoch, regardless of the order
+        of epochs in the list.
     n_vertices: int
         The number of vertices (0-simplices) in the dataset.
     epochs_per_sample: array of shape (n_1_simplices)
@@ -332,6 +336,12 @@ def optimize_layout_euclidean(
         dens_phi_sum = np.zeros(1, dtype=np.float32)
         dens_re_sum = np.zeros(1, dtype=np.float32)
 
+    epochs_list = None
+    embedding_list = []
+    if isinstance(n_epochs, list):
+        epochs_list = n_epochs
+        n_epochs = max(epochs_list)
+
     if "disable" not in tqdm_kwds:
         tqdm_kwds["disable"] = not verbose
 
@@ -398,7 +408,17 @@ def optimize_layout_euclidean(

         alpha = initial_alpha * (1.0 - (float(n) / float(n_epochs)))
 
-    return head_embedding
+        if verbose and n % max(1, int(n_epochs / 10)) == 0:
+            print("\tcompleted ", n, " / ", n_epochs, "epochs")
+
+        if epochs_list is not None and n in epochs_list:
+            embedding_list.append(head_embedding.copy())
+
+    # Add the last embedding to the list as well
+    if epochs_list is not None:
+        embedding_list.append(head_embedding.copy())
+
+    return head_embedding if epochs_list is None else embedding_list
 
 
 def _optimize_layout_generic_single_epoch(
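For readers skimming the diff, the layouts.py change boils down to one pattern: when n_epochs is a list, optimize up to its maximum, snapshot the embedding whenever the loop counter appears in the list, and append the final state. Below is a minimal, self-contained sketch of that pattern; toy_optimize and its random-walk update are hypothetical stand-ins for illustration, not the actual UMAP layout code.

import numpy as np

def toy_optimize(embedding, n_epochs):
    # n_epochs may be an int, or a list of ints at which to snapshot
    epochs_list = None
    embedding_list = []
    if isinstance(n_epochs, list):
        epochs_list = n_epochs
        n_epochs = max(epochs_list)  # run up to the largest requested epoch

    for n in range(n_epochs):
        # Stand-in for the real per-epoch layout update
        embedding += 0.01 * np.random.standard_normal(embedding.shape)
        if epochs_list is not None and n in epochs_list:
            embedding_list.append(embedding.copy())  # snapshot at a requested epoch

    # Mirror the patch: the final embedding is always appended too
    if epochs_list is not None:
        embedding_list.append(embedding.copy())

    return embedding if epochs_list is None else embedding_list

snapshots = toy_optimize(np.zeros((10, 2)), [5, 10, 20])
print(len(snapshots))  # 3: epochs 5 and 10 from the loop, plus the final state

Note one subtlety the sketch reproduces: the loop runs over range(n_epochs), so the largest epoch in the list is never matched inside the loop; the trailing append supplies that last snapshot, and the result pairs one-to-one with the sorted epoch list.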
46 changes: 38 additions & 8 deletions umap/umap_.py
@@ -992,11 +992,14 @@ def simplicial_set_embedding(
         in greater repulsive force being applied, greater optimization
         cost, but slightly more accuracy.
-    n_epochs: int (optional, default 0)
+    n_epochs: int (optional, default 0), or list of int
         The number of training epochs to be used in optimizing the
         low dimensional embedding. Larger values result in more accurate
         embeddings. If 0 is specified a value will be selected based on
         the size of the input dataset (200 for large datasets, 500 for small).
+        If a list of int is specified, then the intermediate embeddings at the
+        epochs given in that list are returned in
+        ``aux_data["embedding_list"]``.
     init: string
         How to initialize the low dimensional embedding. Options are:
@@ -1076,8 +1079,11 @@
     if n_epochs is None:
         n_epochs = default_epochs
 
-    if n_epochs > 10:
-        graph.data[graph.data < (graph.data.max() / float(n_epochs))] = 0.0
+    # If n_epochs is a list, get the maximum epoch to reach
+    n_epochs_max = max(n_epochs) if isinstance(n_epochs, list) else n_epochs
+
+    if n_epochs_max > 10:
+        graph.data[graph.data < (graph.data.max() / float(n_epochs_max))] = 0.0
     else:
         graph.data[graph.data < (graph.data.max() / float(default_epochs))] = 0.0
 
@@ -1122,7 +1128,7 @@
     else:
         embedding = init_data
 
-    epochs_per_sample = make_epochs_per_sample(graph.data, n_epochs)
+    epochs_per_sample = make_epochs_per_sample(graph.data, n_epochs_max)
 
     head = graph.row
     tail = graph.col
@@ -1213,6 +1219,11 @@
         tqdm_kwds=tqdm_kwds,
         move_other=True,
     )
+
+    if isinstance(embedding, list):
+        aux_data["embedding_list"] = embedding
+        embedding = embedding[-1].copy()
+
     if output_dens:
         if verbose:
             print(ts() + " Computing embedding densities")
@@ -1761,10 +1772,19 @@ def _validate_parameters(self):
raise ValueError("n_components must be an int")
if self.n_components < 1:
raise ValueError("n_components must be greater than 0")
if self.n_epochs is not None and (
self.n_epochs < 0 or not isinstance(self.n_epochs, int)
self.n_epochs_list = None
if isinstance(self.n_epochs, list) or isinstance(self.n_epochs, tuple) or \
isinstance(self.n_epochs, np.ndarray):
if not issubclass(np.array(self.n_epochs).dtype.type, np.integer) or \
not np.all(np.array(self.n_epochs) >= 0):
raise ValueError("n_epochs must be a nonnegative integer "
"or a list of nonnegative integers")
self.n_epochs_list = list(self.n_epochs)
elif self.n_epochs is not None and (
self.n_epochs < 0 or not isinstance(self.n_epochs, int)
):
raise ValueError("n_epochs must be a nonnegative integer")
raise ValueError("n_epochs must be a nonnegative integer "
"or a list of nonnegative integers")
if self.metric_kwds is None:
self._metric_kwds = {}
else:
@@ -2723,12 +2743,22 @@ def fit(self, X, y=None):
print(ts(), "Construct embedding")

if self.transform_mode == "embedding":
epochs = self.n_epochs_list if self.n_epochs_list is not None else self.n_epochs
self.embedding_, aux_data = self._fit_embed_data(
self._raw_data[index],
self.n_epochs,
epochs,
init,
random_state, # JH why raw data?
)

if self.n_epochs_list is not None:
if "embedding_list" not in aux_data:
raise KeyError("No list of embedding were found in 'aux_data'. "
"It is likely the layout optimization function "
"doesn't support the list of int for 'n_epochs'.")
else:
self.embedding_list_ = [e[inverse] for e in aux_data["embedding_list"]]

# Assign any points that are fully disconnected from our manifold(s) to have embedding
# coordinates of np.nan. These will be filtered by our plotting functions automatically.
# They also prevent users from being deceived a distance query to one of these points.
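Taken together, the two files expose the feature through the estimator interface: fit() forwards the epoch list to the layout optimizer and stores the intermediate results on embedding_list_. A short usage sketch follows, under stated assumptions: the input is random placeholder data, and only the n_epochs argument and the embedding_list_ attribute come from this commit.

import numpy as np
from umap import UMAP

X = np.random.rand(200, 16)  # placeholder data, for illustration only

# A list of epochs optimizes up to max(n_epochs) and keeps snapshots
reducer = UMAP(n_epochs=[50, 100, 200])
final_embedding = reducer.fit_transform(X)  # the epoch-200 embedding

# Snapshots come back in order of increasing epoch, per the layouts.py change
for n, emb in zip([50, 100, 200], reducer.embedding_list_):
    print(f"epoch {n}: shape {emb.shape}")

Because the trailing append in optimize_layout_euclidean covers the maximum epoch, the snapshots here pair one-to-one with the sorted epoch list.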
