In [78]:
import tensorflow as tf
import torch
import logging
import numpy as np
from common.psm import psm_f_fast

# Code from Paper

In [79]:
def contrastive_loss_paper(similarity_matrix, metric_values, temperature, beta=1.0):
    """Contrastive loss with embedding similarity (Equation 4 of the paper).

    For every row the positive pair is the column with the smallest metric
    value. All other columns act as negatives, weighted by
    log(1 - exp(-metric / beta)).

    Args:
      similarity_matrix: 2-D matrix of pairwise similarities between two sets
        of embeddings (in this notebook: cosine similarities). It is divided
        by `temperature` before use.
      metric_values: 2-D matrix of pairwise metric values for the same state
        pairs (in this notebook: PSM values). It is divided by `beta` before
        the weights are computed; the positive indices are picked from the
        unscaled values.
      temperature: divisor applied to `similarity_matrix`.
      beta: divisor applied to `metric_values`.

    Returns:
      Scalar tensor: mean over rows of
      logsumexp(weighted logits of the row) - weighted positive logit.
    """
    num_rows = tf.shape(metric_values)[0]
    similarity_matrix /= temperature
    row_logits = similarity_matrix

    # Positive pair of each row: the column where the metric is smallest.
    closest_cols = tf.cast(tf.argmin(metric_values, axis=1), dtype=tf.int32)
    row_ids = tf.range(num_rows, dtype=tf.int32)
    positive_idx = tf.stack((row_ids, closest_cols), axis=1)
    positive_logits = tf.gather_nd(similarity_matrix, positive_idx)

    metric_values /= beta
    positive_weights = -tf.gather_nd(metric_values, positive_idx)
    positive_logits += positive_weights

    # Negatives get log(1 - exp(-metric)); the small constant keeps the log
    # finite when the metric is zero.
    similarity_measure = tf.exp(-metric_values)
    negative_weights = tf.math.log((1.0 - similarity_measure) + 1e-8)

    # The positive entry of each row keeps its own (positive) weight.
    logit_weights = tf.tensor_scatter_nd_update(
        negative_weights, positive_idx, positive_weights)
    row_logits += logit_weights

    log_normalizer = tf.math.reduce_logsumexp(row_logits, axis=1)
    return tf.reduce_mean(log_normalizer - positive_logits)  # Equation 4

# Code from repository

In [80]:
# https://github.com/google-research/google-research/blob/6574e2ca3fab2b76f08566709aae2721110a3b5d/pse/jumping_task/training_helpers.py#L97
EPS = 1e-9


def contrastive_loss_repository(similarity_matrix,
                                metric_values,
                                temperature,
                                coupling_temperature=1.0,
                                use_coupling_weights=True):
    """Contrastive loss with soft coupling, summed over both directions.

    Direction 1 treats every row as an anchor (positive = column with the
    smallest metric value in that row). Direction 2 treats every column as an
    anchor (positive = row with the smallest metric value in that column).

    Args:
      similarity_matrix: 2-D matrix of pairwise embedding similarities. It is
        divided by `temperature` before use.
      metric_values: 2-D matrix of pairwise metric values for the same pairs.
      temperature: divisor applied to `similarity_matrix`.
      coupling_temperature: divisor applied to `metric_values` before the
        coupling weights are computed (only when `use_coupling_weights`).
      use_coupling_weights: if True, add -metric to the positive logits and
        log(1 - exp(-metric) + EPS) to the negative logits. If False, the
        plain scaled similarities are used.

    Returns:
      Scalar tensor: row-wise loss plus column-wise loss.
    """
    logging.info('Using alternative contrastive loss.')
    shape = tf.shape(metric_values)
    num_rows, num_cols = shape[0], shape[1]

    similarity_matrix /= temperature
    neg_logits_rows = neg_logits_cols = similarity_matrix

    # Direction 1: for each row, the best-matching column.
    best_cols = tf.cast(tf.argmin(metric_values, axis=1), dtype=tf.int32)
    row_anchor_idx = tf.stack(
        (tf.range(num_rows, dtype=tf.int32), best_cols), axis=1)
    pos_logits_rows = tf.gather_nd(similarity_matrix, row_anchor_idx)

    # Direction 2: for each column, the best-matching row.
    best_rows = tf.cast(tf.argmin(metric_values, axis=0), dtype=tf.int32)
    col_anchor_idx = tf.stack(
        (best_rows, tf.range(num_cols, dtype=tf.int32)), axis=1)
    pos_logits_cols = tf.gather_nd(similarity_matrix, col_anchor_idx)

    if use_coupling_weights:
        metric_values /= coupling_temperature
        coupling = tf.exp(-metric_values)

        pos_weights_rows = -tf.gather_nd(metric_values, row_anchor_idx)
        pos_weights_cols = -tf.gather_nd(metric_values, col_anchor_idx)
        pos_logits_rows += pos_weights_rows
        pos_logits_cols += pos_weights_cols

        # EPS keeps the log finite where the coupling equals one.
        negative_weights = tf.math.log((1.0 - coupling) + EPS)
        # Positive entries keep their own weight; all others are negatives.
        neg_logits_rows += tf.tensor_scatter_nd_update(
            negative_weights, row_anchor_idx, pos_weights_rows)
        neg_logits_cols += tf.tensor_scatter_nd_update(
            negative_weights, col_anchor_idx, pos_weights_cols)

    log_norm_rows = tf.math.reduce_logsumexp(neg_logits_rows, axis=1)
    log_norm_cols = tf.math.reduce_logsumexp(neg_logits_cols, axis=0)

    row_loss = tf.reduce_mean(log_norm_rows - pos_logits_rows)
    col_loss = tf.reduce_mean(log_norm_cols - pos_logits_cols)
    return row_loss + col_loss

def cosine_similarity_tensor(x, y):
    """Computes cosine similarity between all pairs of vectors in x and y.

    Entry (i, j) of the result is the dot product of x[i] and y[j] divided by
    the product of their norms plus EPS (EPS guards against division by zero).
    """
    x_rows = x[:, tf.newaxis]
    y_cols = y[tf.newaxis, :]
    pairwise_dots = tf.reduce_sum(x_rows * y_cols, axis=-1)
    norm_products = tf.norm(x_rows, axis=-1) * tf.norm(y_cols, axis=-1)
    return pairwise_dots / (norm_products + EPS)

In [81]:
# Fixed seed so the random embeddings and action sequences are reproducible.
np.random.seed(1)

# The embedding is two-dimensional, n_states x contrastive_loss_head size.
N_STATES = 56
EMBEDDING_DIM = 64
embedding_shape = (N_STATES, EMBEDDING_DIM)
e1 = np.random.randint(low=0, high=255, size=embedding_shape).astype(np.float32)
e2 = np.random.randint(low=0, high=255, size=embedding_shape).astype(np.float32)

# Action sequences, dimension is (n_states,); values are drawn from [0, 8).
a1 = np.random.randint(0, 8, size=(N_STATES,))
a2 = np.random.randint(0, 8, size=(N_STATES,))
# Just make the action sequences partly similar.
a2[10:20] = a1[30:40]

temp = 0.1
gamma = 0.8

In [82]:
# TensorFlow copies of the embeddings and action sequences.
e1_tensor = tf.convert_to_tensor(e1)
e2_tensor = tf.convert_to_tensor(e2)
a1_tensor = tf.convert_to_tensor(a1)
a2_tensor = tf.convert_to_tensor(a2)

# PyTorch counterparts built from the same NumPy arrays.
e1_torch = torch.from_numpy(e1)
e2_torch = torch.from_numpy(e2)
a1_torch = torch.from_numpy(a1)
a2_torch = torch.from_numpy(a2)

# The metric matrix is computed once with the PyTorch helper and then handed
# to TensorFlow, so both frameworks see the same metric values.
psm_matrix_torch = psm_f_fast(a1_torch, a2_torch)
psm_matrix_tensor = tf.convert_to_tensor(psm_matrix_torch.numpy())

In [83]:
# TensorFlow reference value for the paper loss (default beta).
sim_matrix_tensor = cosine_similarity_tensor(e1_tensor, e2_tensor)
loss_paper_tf = contrastive_loss_paper(
    sim_matrix_tensor, psm_matrix_tensor, temperature=temp)
print(loss_paper_tf)

# TensorFlow reference value for the repository loss without coupling weights.
sim_matrix_tensor = cosine_similarity_tensor(e1_tensor, e2_tensor)
loss_repository_tf = contrastive_loss_repository(
    sim_matrix_tensor,
    psm_matrix_tensor,
    temperature=temp,
    use_coupling_weights=False,
)
print(loss_repository_tf)

tf.Tensor(21.747622, shape=(), dtype=float32)
tf.Tensor(8.222357, shape=(), dtype=float32)


In [84]:
# PyTorch ports under test. They are imported with a `_torch` suffix so that
# they do not rebind `contrastive_loss_paper` / `contrastive_loss_repository`,
# which are the TensorFlow reference implementations defined earlier in this
# notebook. Without the aliases, re-running the TensorFlow cell after this one
# would call the PyTorch functions with TensorFlow tensors.
from common.training_helpers import (
    cosine_similarity as cosine_similarity_torch,
    contrastive_loss_paper as contrastive_loss_paper_torch,
    contrastive_loss_repository as contrastive_loss_repository_torch,
)

# The similarity matrices of both frameworks must agree before the losses are
# compared.
sim_matrix_torch = cosine_similarity_torch(e1_torch, e2_torch)
assert np.allclose(sim_matrix_tensor.numpy(), sim_matrix_torch.numpy()), (
    'TensorFlow and PyTorch cosine-similarity matrices differ')

# Each loss receives its own clone of the similarity matrix so that one call
# cannot affect the input of the next.
# NOTE(review): psm_matrix_torch is passed without a clone; confirm the
# PyTorch helpers do not modify it in place.
print(contrastive_loss_paper_torch(
    sim_matrix_torch.clone(), psm_matrix_torch, temperature=temp))
# NOTE(review): the TensorFlow call above passes use_coupling_weights=False;
# this call relies on the defaults of the PyTorch port - confirm they match.
print(contrastive_loss_repository_torch(
    sim_matrix_torch.clone(), psm_matrix_torch, temperature=temp))

tensor(21.7476)
tensor(8.2224)
