In [1]:
import sys
from pathlib import Path

In [2]:
import numpy as np
import tensorflow as tf

In [3]:
BASE_DIR = Path("..")
sys.path.append(str(BASE_DIR.resolve()))

In [4]:
from losses_dev import ntxent_loss

In [5]:
tf.random.set_seed(42)

# Define Loss functions

In [6]:
def pairwise_similarity(u: tf.Tensor, v: tf.Tensor) -> tf.Tensor:
    """
        Calculate pairwise similarity between two vectors
            using cos distance.

        args:
            u: tf.Tensor - First input vector.
            v: tf.Tensor - Second input vector.
        returns:
            score: tf.Tensor - Similarity score scalar.
    """
    
    numer = tf.tensordot(u, v, axes = [[1],[1]])
    denom = tf.norm(u, 2, axis = 1) * tf.norm(v, 2, axis = 1)

    score =  numer / denom
    return score

In [7]:
def pairwise_similarity_old(u: tf.Tensor, v: tf.Tensor) -> tf.Tensor:
    """
        Calculate pairwise similarity between two vectors
            using cos distance.

        args:
            u: tf.Tensor - First input bector.
            v: tf.Tensor - Second input vector.
        returns:
            score: tf.Tensor - Similarity score scalar.
    """

    score = tf.tensordot(tf.transpose(u), v, axes = 1) / (tf.norm(u, 2) * tf.norm(v, 2))
    return score

In [30]:
def ntxent_loss_new(batch: tf.Tensor,
                temp: float = 1.0) -> tf.Tensor:
    """
        Normalised temperature-scaled cross entropy loss.

        args:
            batch: tf.Tensor - Batch of augmented tensors of size 2N.
                Where N is the minibatch size.
            temp: float - Temperate scale coefficient.
        returns:
            loss: tf.Tensor - Loss scalar.
    """

    n_batch = tf.cast(tf.shape(batch)[0], tf.float32)
    n_minibatch = tf.cast(n_batch / 2, np.int32)
    loss = tf.constant(0.0)
    
    # get similarity matrix
    sim_mat = pairwise_similarity(batch, batch)
    sim_mat = tf.math.exp(sim_mat / temp)
    
    # calculate loss
    for k in tf.range(n_minibatch, dtype = tf.int32):
        loss_pairwise_1 = -1.0 * tf.math.log(sim_mat[k, (n_minibatch + k)] / (tf.reduce_sum(sim_mat[k,:k]) + tf.reduce_sum(sim_mat[k,(k + 1):])))
        loss_pairwise_2 = -1.0 * tf.math.log(sim_mat[(n_minibatch + k), k] / (tf.reduce_sum(sim_mat[(n_minibatch + k),:(n_minibatch + k)]) + tf.reduce_sum(sim_mat[(n_minibatch + k),(n_minibatch + k + 1):])))
        loss = loss + loss_pairwise_1 + loss_pairwise_2
        
    return loss / n_batch

        


In [72]:
def ntxent_loss_alt(batch_u: tf.Tensor,
                    batch_v: tf.Tensor,
                    temp: float = 1.0) -> tf.Tensor:
    """
        Normalised temperature-scaled cross entropy loss.

        args:
            batch: tf.Tensor - Batch of augmented tensors of size 2N.
                Where N is the minibatch size.
            temp: float - Temperate scale coefficient.
        returns:
            loss: tf.Tensor - Loss scalar.
    """

    n_minibatch = batch_u.shape[0]
    loss = tf.constant(0.0)
    
    # get similarity matrix
    sim_mat = pairwise_similarity(batch_u, batch_v)
    sim_mat = tf.math.exp(sim_mat / temp)
    sim_mat_t = tf.transpose(sim_mat)
    
    # calculate loss
    for k in tf.range(n_minibatch, dtype = tf.int32):
        loss_pairwise_1 = -1.0 * tf.math.log(sim_mat[k, k] / tf.reduce_sum(sim_mat[k,:]))
        loss_pairwise_2 = -1.0 * tf.math.log(sim_mat_t[k, k] / tf.reduce_sum(sim_mat_t[k,:]))
        loss = loss + loss_pairwise_1 + loss_pairwise_2
        
    return loss / (2 * n_minibatch)

# Define Mock Inputs

In [69]:
test_input_1 = tf.random.normal((128,1024), dtype = tf.float32)
test_input_2 = tf.random.normal((128,1024), dtype = tf.float32)

In [65]:
test_input_stacked = tf.concat([test_input_1,test_input_2], axis = 0)

In [42]:
%%time
loss_1 = ntxent_loss(test_input_1, test_input_2)

CPU times: user 1min 8s, sys: 218 ms, total: 1min 8s
Wall time: 1min 8s


In [76]:
%%time
loss_2 = ntxent_loss_new(test_input_stacked)

CPU times: user 276 ms, sys: 7.4 ms, total: 283 ms
Wall time: 278 ms


In [77]:
%%time
loss_3 = ntxent_loss_alt(test_input_1, test_input_2)

CPU times: user 159 ms, sys: 3.65 ms, total: 162 ms
Wall time: 158 ms


## Precompute pairwise similarities

In [22]:
%%time
# vectorised
ps_vect = pairwise_similarity(test_input_stacked, test_input_stacked)

CPU times: user 9.81 ms, sys: 3.53 ms, total: 13.3 ms
Wall time: 4.65 ms


In [23]:
%%time
# sequential
for k_outer in tf.range(0, test_input_stacked.shape[0], dtype = tf.int32):
    u = test_input_stacked[k_outer]
    for k_inner in tf.range(0, test_input_stacked.shape[0], dtype = tf.int32):
        v = test_input_stacked[k_inner]
        score = pairwise_similarity_old(u,v)

CPU times: user 1min, sys: 104 ms, total: 1min
Wall time: 1min
