
Commit

Merge pull request #856 from jgraving/master
Add numerically stable cross entropy loss
lmcinnes committed Apr 26, 2022
2 parents 4398844 + 17b2089 commit 2c5232f
Showing 1 changed file with 16 additions and 13 deletions.
umap/parametric_umap.py (16 additions, 13 deletions)
@@ -633,9 +633,9 @@ def init_embedding_from_graph(
     return embedding
 
 
-def convert_distance_to_probability(distances, a=1.0, b=1.0):
+def convert_distance_to_log_probability(distances, a=1.0, b=1.0):
     """
-    convert distance representation into probability,
+    convert distance representation into log probability,
     as a function of a, b params
     Parameters
@@ -650,13 +650,13 @@ def convert_distance_to_probability(distances, a=1.0, b=1.0):
     Returns
     -------
     float
-        probability in embedding space
+        log probability in embedding space
     """
-    return 1.0 / (1.0 + a * distances ** (2 * b))
+    return -tf.math.log1p(a * distances ** (2 * b))
 
 
 def compute_cross_entropy(
-    probabilities_graph, probabilities_distance, EPS=1e-4, repulsion_strength=1.0
+    probabilities_graph, log_probabilities_distance, EPS=1e-4, repulsion_strength=1.0
 ):
     """
     Compute cross entropy between low and high probability
@@ -665,8 +665,8 @@ def compute_cross_entropy(
     ----------
     probabilities_graph : array
         high dimensional probabilities
-    probabilities_distance : array
-        low dimensional probabilities
+    log_probabilities_distance : array
+        low dimensional log probabilities
     EPS : float, optional
        offset to ensure log is taken of a positive number, by default 1e-4
     repulsion_strength : float, optional
@@ -683,12 +683,15 @@
     """
     # cross entropy
-    attraction_term = -probabilities_graph * tf.math.log(
-        tf.clip_by_value(probabilities_distance, EPS, 1.0)
+    attraction_term = -probabilities_graph * tf.math.log_sigmoid(
+        log_probabilities_distance
     )
+    # use numerically stable repellant term
+    # Shi et al. 2022 (https://arxiv.org/abs/2111.08851)
+    # log(1 - sigmoid(logits)) = log(sigmoid(logits)) - logits
     repellant_term = (
         -(1.0 - probabilities_graph)
-        * tf.math.log(tf.clip_by_value(1.0 - probabilities_distance, EPS, 1.0))
+        * (tf.math.log_sigmoid(log_probabilities_distance) - log_probabilities_distance)
         * repulsion_strength
     )
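The identity cited in the new comment, log(1 - sigmoid(x)) = log_sigmoid(x) - x, is what lets the repellant term drop the EPS clipping: the stable form never materialises a probability that can round to 0 or 1. A minimal standalone sketch (not part of this commit; the logit values are illustrative) comparing the old clipped computation with the new one:

import tensorflow as tf

logits = tf.constant([-50.0, 0.0, 50.0])
EPS = 1e-4

# old style: form the probability, clip it away from 0, then take the log
# (the clip floors the result at log(EPS) once sigmoid(logits) rounds to 1.0)
naive = tf.math.log(tf.clip_by_value(1.0 - tf.math.sigmoid(logits), EPS, 1.0))

# new style: log(1 - sigmoid(x)) = log_sigmoid(x) - x, no clipping required
stable = tf.math.log_sigmoid(logits) - logits

print(naive.numpy())   # approx [ 0.0, -0.693, -9.21 ]  (last entry saturated by the clip)
print(stable.numpy())  # approx [ 0.0, -0.693, -50.0 ]  (exact up to float32 rounding)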

@@ -759,8 +762,8 @@ def loss(placeholder_y, embed_to_from):
             axis=0,
         )
 
-        # convert probabilities to distances
-        probabilities_distance = convert_distance_to_probability(
+        # convert distances to probabilities
+        log_probabilities_distance = convert_distance_to_log_probability(
             distance_embedding, _a, _b
         )

@@ -772,7 +775,7 @@
         # compute cross entropy
         (attraction_loss, repellant_loss, ce_loss) = compute_cross_entropy(
             probabilities_graph,
-            probabilities_distance,
+            log_probabilities_distance,
             repulsion_strength=repulsion_strength,
         )
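For reference, the renamed converter returns exactly the log of the value the old one produced, so callers like the loss above only change which variable they pass along. A quick standalone check (not part of this commit; a = b = 1.0 are illustrative defaults, in UMAP they are fitted from min_dist):

import tensorflow as tf

a, b = 1.0, 1.0  # illustrative; UMAP fits a and b from min_dist
distances = tf.constant([0.0, 0.5, 1.0, 10.0])

# old return value: a probability in (0, 1]
prob = 1.0 / (1.0 + a * distances ** (2 * b))

# new return value: the log of the same quantity, via log1p for accuracy at small distances
log_prob = -tf.math.log1p(a * distances ** (2 * b))

print(prob.numpy())              # approx [1.0, 0.8, 0.5, 0.0099]
print(tf.exp(log_prob).numpy())  # matches prob up to float32 rounding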

