Part of the training data pipeline (triplet mining) requires calculating pairwise distance between a list of anchor embeddings and a list of negative embeddings. The function `src.models.train_model.get_anc_neg_distance` implements this logic. This notebook demonstrates the jax-numpy magic inside that function.

Input: 
- anc_embeddings (train_batch_size, embedding_dim)
- neg_embeddings (eval_batch_size, embedding_dim)

Output:
- Array (eval_batch_size, train_batch_size), where the ($j$, $k$) entry is the squared L2 distance between the $j$-th negative embedding and the $k$-th anchor embedding.

In [None]:
import jax.numpy as jnp

In [None]:
# Toy sizes used throughout the demonstration.
embedding_dim = 2
train_batch_size = 5  # number of anchor embeddings
eval_batch_size = 7  # number of negative embeddings

In [None]:
# Only anchor 0 and negative 1 are non-zero (both all ones), so the expected
# distance matrix is non-zero exactly along row 1 and column 0, with a zero
# where they cross at (1, 0).
neg_embeddings = jnp.zeros((eval_batch_size, embedding_dim)).at[1].set(1)
anc_embeddings = jnp.zeros((train_batch_size, embedding_dim)).at[0].set(1)

In [None]:
# Shape (train_batch_size, embedding_dim): row 0 is all ones, the rest zeros.
print(anc_embeddings)

In [None]:
# Shape (eval_batch_size, embedding_dim): row 1 is all ones, the rest zeros.
print(neg_embeddings)

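With its default `axis=None`, `numpy.repeat` flattens the array before repeating, so an explicit `axis` is passed below. When reshaping the repeated result, the dimensions should be listed in the same order as in the original array.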

In [None]:
# Tile each anchor n_neg times to get shape (n_anc, n_neg, embedding_dim),
# where entry [i, j, :] is the i-th anchor embedding.
#
# The repeat must run along axis 0 (whole rows). Repeating along the last
# axis duplicates individual components instead, and the reshape then pairs
# up copies of the same component, e.g. [a0, a0] rather than [a0, a1]; that
# only goes unnoticed when all components of a row are equal, as in the toy
# input used here.
anc_embeddings_repeated = jnp.repeat(
    anc_embeddings, neg_embeddings.shape[0], axis=0
).reshape((anc_embeddings.shape[0], neg_embeddings.shape[0], embedding_dim))

print(anc_embeddings_repeated.shape)
print(anc_embeddings_repeated)

In [None]:
# Swap the two batch axes so that anchors vary along axis 1:
# (n_anc, n_neg, embedding_dim) -> (n_neg, n_anc, embedding_dim), i.e.
# (eval_batch_size, train_batch_size, embedding_dim) in this notebook.
# NOTE(review): the earlier comment listed both batch axes scaled by
# num_devices, presumably mirroring the multi-device training code; no
# device factor appears in this notebook.
anc_embeddings_repeated_transposed = anc_embeddings_repeated.transpose(
    (1, 0, 2)
)

print(anc_embeddings_repeated_transposed.shape)
print(anc_embeddings_repeated_transposed)

In [None]:
# Tile each negative n_anc times to get shape (n_neg, n_anc, embedding_dim),
# where entry [j, i, :] is the j-th negative embedding.
#
# The repeat must run along axis 0 (whole rows). Repeating along the last
# axis duplicates individual components instead, and the reshape then pairs
# up copies of the same component, e.g. [n0, n0] rather than [n0, n1]; that
# only goes unnoticed when all components of a row are equal, as in the toy
# input used here.
neg_embeddings_repeated = jnp.repeat(
    neg_embeddings, anc_embeddings.shape[0], axis=0
).reshape((neg_embeddings.shape[0], anc_embeddings.shape[0], embedding_dim))

print(neg_embeddings_repeated.shape)
print(neg_embeddings_repeated)

In [None]:
# Element-wise squared difference between the tiled anchors and the tiled
# negatives; the result keeps the operands' 3-D shape.
squared_difference = jnp.square(
    anc_embeddings_repeated_transposed - neg_embeddings_repeated
)

print(squared_difference.shape)
print(squared_difference)

In [None]:
# Collapse the embedding axis: summing the squared differences over the last
# axis gives one squared L2 distance per (negative, anchor) pair.
anc_neg_l2_difference = squared_difference.sum(axis=-1)
print(anc_neg_l2_difference.shape)
print(anc_neg_l2_difference)

In [None]:
Array = jnp.ndarray


def squared_l2_distance(x_1: Array, x_2: Array) -> Array:
    """
    Return the squared Euclidean distance between x_1 and x_2.

    The difference x_1 - x_2 is squared element-wise and summed over the
    last axis, so the result has the shape of the difference with that
    axis removed.

    Args:
     x_1: (..., n), e.g. (a, b, n).
     x_2: (..., n), same shape as x_1.

    Returns:
     (...), e.g. (a, b). || x_1 - x_2 ||^{2}.
    """
    delta = x_1 - x_2
    return jnp.sum(jnp.square(delta), axis=-1)


In [None]:
# One positive per anchor; only the last one is non-zero (all entries -1/2).
pos_embeddings = jnp.zeros((train_batch_size, embedding_dim)).at[-1].set(-1/2)

n_anc, n_neg = anc_embeddings.shape[0], neg_embeddings.shape[0]
embedding_dim = anc_embeddings.shape[-1]

# (n_anc,): squared distance between each anchor and its own positive.
anc_pos_distances = squared_l2_distance(anc_embeddings, pos_embeddings)

# (n_neg, n_anc): repeat each anchor's distance n_neg times, fold the flat
# result into one row per anchor, then transpose so that every row is a copy
# of anc_pos_distances.
anc_pos_distances_repeated = (
    jnp.repeat(anc_pos_distances, n_neg, axis=0).reshape(n_anc, n_neg).T
)

print(anc_pos_distances_repeated.shape)
print(anc_pos_distances_repeated)

In [None]:
# Entry (j, k): squared distance between negative j and anchor k, minus the
# squared distance between anchor k and its positive.
print(anc_neg_l2_difference - anc_pos_distances_repeated)