In [10]:
import numpy as np
import torch  # needed here: torch is otherwise only imported in a later cell, so a fresh kernel would raise NameError

SEED = 0  # single seed shared by every RNG seeded below

# Seed each library that draws random numbers so Restart & Run All is reproducible.
torch.manual_seed(SEED)
np.random.seed(SEED)

In [18]:
import numpy as np

def cosine_similarity(z_i, z_j):
    """Cosine of the angle between two vectors (intended for 1-D arrays).

    Computed as dot(z_i, z_j) / (||z_i|| * ||z_j||). There is no guard against
    zero-length inputs, so a zero vector produces nan/inf (NumPy warns).
    """
    length_product = np.linalg.norm(z_i) * np.linalg.norm(z_j)
    return np.dot(z_i, z_j) / length_product

def contrastive_loss_python(z_all, i, j, tau):
    """Contrastive (InfoNCE / NT-Xent style) loss for one anchor row of a batch.

        loss = -log( exp(s_ij / tau) / sum_{k != i} exp(s_ik / tau) )

    where s_ab is the cosine similarity between rows a and b of ``z_all``.

    Args:
        z_all: array-like of shape (n, d), one embedding per row.
        i: index of the anchor row (negative indices are allowed).
        j: index of the anchor's positive row (negative indices are allowed).
        tau: temperature dividing the similarities (expected to be positive).

    Returns:
        The loss as a Python float.

    Raises:
        ValueError: if ``z_all`` has fewer than two rows (empty denominator).
        IndexError: if ``i`` or ``j`` is out of range.
    """
    z = np.asarray(z_all, dtype=np.float64)
    n = z.shape[0]
    if n < 2:
        raise ValueError("contrastive_loss_python needs at least two rows in z_all")
    # range(n)[idx] resolves negative indices and raises IndexError when out of
    # range, so the anchor is excluded from the denominator even when i < 0.
    i = range(n)[i]
    j = range(n)[j]

    # L2-normalise every row; the eps clamp mirrors torch.nn.functional.normalize,
    # so an all-zero row has similarity 0 instead of producing nan.
    norms = np.linalg.norm(z, axis=1, keepdims=True)
    z = z / np.maximum(norms, 1e-12)

    logits = (z @ z[i]) / tau          # s_ik / tau for every row k
    negatives = np.delete(logits, i)   # denominator runs over k != i

    # Max-shifted log-sum-exp: exp(s / tau) on its own overflows for small tau.
    shift = negatives.max()
    log_denominator = shift + np.log(np.exp(negatives - shift).sum())
    return float(log_denominator - logits[j])

# Example usage: the loss for anchor 0 with positive 1, then the batch mean using the
# same positive pairing as contrastive_loss_pytorch (row r is paired with row
# (r - n // 2) % n, from mask.roll(shifts=n // 2)), so it is comparable to that cell.
np.random.seed(0)  # seeds the global NumPy RNG; nothing in this cell draws from it
temperature = 0.5
fixed_features = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 1]], dtype=np.float32)
batch_size, hidden_dim = fixed_features.shape  # 4, 3 — derived rather than hard-coded
i, j = 0, 1
loss_python = contrastive_loss_python(fixed_features, i, j, tau=temperature)
print("Contrastive Loss (Python):", loss_python)

loss_python_batch_mean = np.mean([
    contrastive_loss_python(
        fixed_features, anchor, (anchor - batch_size // 2) % batch_size, tau=temperature
    )
    for anchor in range(batch_size)
])
print("Contrastive Loss (Python, batch mean):", loss_python_batch_mean)


Contrastive Loss (Python): 1.6434669749791164


In [17]:
import torch
import torch.nn.functional as F

def contrastive_loss_pytorch(feats, temperature=0.5):
    """Batch-mean contrastive (NT-Xent style) loss over every row of ``feats``.

    Row r is an anchor whose positive is row (r - n // 2) % n, with n = feats.size(0).
    All other non-self rows act as negatives. For an even batch laid out as
    [view_1; view_2], this pairs each sample with its other view. That layout is
    assumed, not checked.

    Args:
        feats: 2-D tensor of shape (n, d), one embedding per row.
        temperature: divisor applied to the cosine similarities.

    Returns:
        0-dim tensor: mean over anchors r of
        -s_pos(r) / T + logsumexp_{k != r}(s_rk / T).

    Raises:
        ValueError: if ``feats`` has fewer than two rows. Otherwise the positive
            would be the masked self entry and the loss would silently be nan.
    """
    batch_size = feats.size(0)
    if batch_size < 2:
        raise ValueError("contrastive_loss_pytorch needs at least two rows in feats")

    feats = F.normalize(feats, dim=-1)  # unit rows -> dot product is cosine similarity
    similarity_matrix = torch.mm(feats, feats.T) / temperature
    self_mask = torch.eye(batch_size, dtype=torch.bool, device=feats.device)
    # -inf on the diagonal removes self-similarity from the logsumexp below.
    similarity_matrix = similarity_matrix.masked_fill(self_mask, float('-inf'))
    # Rolling the identity down by n // 2 rows marks column (r - n // 2) % n in row r.
    pos_mask = self_mask.roll(shifts=batch_size // 2, dims=0)
    # Exactly one True per row, so boolean indexing yields one value per row, in row order.
    positive_similarities = similarity_matrix[pos_mask]
    denominator = torch.logsumexp(similarity_matrix, dim=-1)
    loss = -positive_similarities + denominator
    return loss.mean()

# Example usage on the same fixed 4x3 batch as the NumPy cell.
torch.manual_seed(0)  # seeds torch's global RNG; nothing below draws from it
temperature = 0.5
fixed_features = torch.tensor(
    [[1, 0, 0],
     [0, 1, 0],
     [0, 0, 1],
     [1, 1, 1]],
    dtype=torch.float32,
)
batch_size, hidden_dim = fixed_features.shape
loss_pytorch = contrastive_loss_pytorch(fixed_features, temperature=temperature)
print("Contrastive Loss (PyTorch):", loss_pytorch.item())


Contrastive Loss (PyTorch): 1.2185781002044678
