Add numerically stable cross entropy loss #856

Merged (1 commit) on Apr 26, 2022
umap/parametric_umap.py (16 additions, 13 deletions)
@@ -633,9 +633,9 @@ def init_embedding_from_graph(
     return embedding


-def convert_distance_to_probability(distances, a=1.0, b=1.0):
+def convert_distance_to_log_probability(distances, a=1.0, b=1.0):
     """
-    convert distance representation into probability,
+    convert distance representation into log probability,
     as a function of a, b params

     Parameters
@@ -650,13 +650,13 @@ def convert_distance_to_probability(distances, a=1.0, b=1.0):
     Returns
     -------
     float
-        probability in embedding space
+        log probability in embedding space
     """
-    return 1.0 / (1.0 + a * distances ** (2 * b))
+    return -tf.math.log1p(a * distances ** (2 * b))


 def compute_cross_entropy(
-    probabilities_graph, probabilities_distance, EPS=1e-4, repulsion_strength=1.0
+    probabilities_graph, log_probabilities_distance, EPS=1e-4, repulsion_strength=1.0
 ):
     """
     Compute cross entropy between low and high probability
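
The new return value follows directly from the old one: the low-dimensional similarity is p = 1 / (1 + a * d^(2b)), so log(p) = -log(1 + a * d^(2b)) = -log1p(a * d^(2b)). log1p is the numerically careful piece, since it stays accurate when its argument is tiny, where forming 1 + z first rounds back to 1.0. A small NumPy check, illustrative only and not code from this PR:

```python
import numpy as np

# Illustrative only (not code from this PR).  The old function returned
# p = 1 / (1 + a * d**(2 * b)); the new one returns log(p) directly as
# -log1p(a * d**(2 * b)).  log1p(z) stays accurate for tiny z, where
# forming 1 + z first would round back to 1.0.
z = np.array([1e-20, 0.25, 100.0])          # stands in for a * d**(2 * b)

log_p_old_route = np.log(1.0 / (1.0 + z))   # rounds to 0.0 at the small end
log_p_new_route = -np.log1p(z)              # keeps -1e-20 at the small end

print(log_p_old_route)   # approx [ 0.       -0.22314 -4.61512]
print(log_p_new_route)   # approx [-1.0e-20  -0.22314 -4.61512]
```
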
@@ -665,8 +665,8 @@
     ----------
     probabilities_graph : array
         high dimensional probabilities
-    probabilities_distance : array
-        low dimensional probabilities
+    log_probabilities_distance : array
+        low dimensional log probabilities
     EPS : float, optional
         offset to ensure log is taken of a positive number, by default 1e-4
     repulsion_strength : float, optional
@@ -683,12 +683,15 @@

     """
     # cross entropy
-    attraction_term = -probabilities_graph * tf.math.log(
-        tf.clip_by_value(probabilities_distance, EPS, 1.0)
+    attraction_term = -probabilities_graph * tf.math.log_sigmoid(
+        log_probabilities_distance
     )
+    # use numerically stable repellant term
+    # Shi et al. 2022 (https://arxiv.org/abs/2111.08851)
+    # log(1 - sigmoid(logits)) = log(sigmoid(logits)) - logits
     repellant_term = (
         -(1.0 - probabilities_graph)
-        * tf.math.log(tf.clip_by_value(1.0 - probabilities_distance, EPS, 1.0))
+        * (tf.math.log_sigmoid(log_probabilities_distance) - log_probabilities_distance)
         * repulsion_strength
     )

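The repulsive term needs log(1 - p). The old code formed 1 - probabilities_distance directly and clipped it to [EPS, 1.0]; when p is close to 1 that subtraction cancels catastrophically, and the clip zeroes the gradient. The replacement uses the identity cited in the comment, log(1 - sigmoid(x)) = log_sigmoid(x) - x, which never forms the difference. A minimal TensorFlow demonstration of that identity's stability (illustrative, not part of the diff):

```python
import tensorflow as tf

# Illustrative only (not code from this PR): the identity used for the
# repellant term, log(1 - sigmoid(x)) = log_sigmoid(x) - x, avoids ever
# forming 1 - sigmoid(x), which rounds to 0 in float32 for large x.
x = tf.constant([0.0, 10.0, 40.0])

naive = tf.math.log(1.0 - tf.sigmoid(x))   # -inf at x = 40: 1 - sigmoid(40) rounds to 0
stable = tf.math.log_sigmoid(x) - x        # finite everywhere, approaches -x for large x

print(naive.numpy())    # approx [-0.693, -10.0, -inf]
print(stable.numpy())   # approx [-0.693, -10.0, -40.0]
```

The naive form returns -inf (and a useless gradient) as soon as 1 - sigmoid(x) underflows, while the identity degrades gracefully to -x.
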
@@ -759,8 +762,8 @@ def loss(placeholder_y, embed_to_from):
             axis=0,
         )

-        # convert probabilities to distances
-        probabilities_distance = convert_distance_to_probability(
+        # convert distances to log probabilities
+        log_probabilities_distance = convert_distance_to_log_probability(
             distance_embedding, _a, _b
         )

@@ -772,7 +775,7 @@
         # compute cross entropy
         (attraction_loss, repellant_loss, ce_loss) = compute_cross_entropy(
             probabilities_graph,
-            probabilities_distance,
+            log_probabilities_distance,
             repulsion_strength=repulsion_strength,
         )
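
Taken together, the loss path after this change is: pairwise embedding distances, then convert_distance_to_log_probability, then compute_cross_entropy. A minimal end-to-end sketch, assuming a umap-learn build that includes this change; the shapes, curve parameters, Euclidean distance, and graph weights below are made up for illustration, not lines from the PR:

```python
import tensorflow as tf
from umap.parametric_umap import (
    convert_distance_to_log_probability,
    compute_cross_entropy,
)

_a, _b = 1.577, 0.895                              # example curve parameters
embedding_to = tf.random.normal((128, 2))          # embedded "to" points
embedding_from = tf.random.normal((128, 2))        # embedded "from" points
probabilities_graph = tf.random.uniform((128,))    # stand-in for graph edge weights

# distances between embedded pairs -> log probabilities -> stable cross entropy
distance_embedding = tf.norm(embedding_to - embedding_from, axis=1)
log_probabilities_distance = convert_distance_to_log_probability(
    distance_embedding, _a, _b
)
attraction_loss, repellant_loss, ce_loss = compute_cross_entropy(
    probabilities_graph,
    log_probabilities_distance,
    repulsion_strength=1.0,
)
print(float(tf.reduce_mean(ce_loss)))
```
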