In [5]:
import torch
import torch.nn.functional as F

def simclr_loss(z: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    """
    Compute SimCLR (NT-Xent) loss.

    Args:
        z (torch.Tensor): Tensor of shape (2N, D), where 2N is the number of augmented samples,
                          and D is the embedding dimension.
                          The first N and last N are positive pairs.
        temperature (float): Temperature parameter τ.

    Returns:
        torch.Tensor: The scalar loss value.
    """
    print(z)
    # Normalize the representations
    z = F.normalize(z, dim=1)
    print(z)
    # Compute cosine similarity matrix (2N x 2N)
    similarity_matrix = torch.matmul(z, z.T)
    print(similarity_matrix)
    # Get batch size
    batch_size = z.size(0)
    assert batch_size % 2 == 0, "Batch size should be even for SimCLR"
    N = batch_size // 2

    # Create labels for positive pairs: (0, N), (1, N+1), ..., (N-1, 2N-1)
    pos_indices = torch.arange(N)
    print(pos_indices)
    positives = torch.cat([
        torch.stack([pos_indices, pos_indices + N], dim=1),
        torch.stack([pos_indices + N, pos_indices], dim=1)
    ], dim=0)
    print(positives)
    # Mask to remove self-comparisons from denominator
    mask = torch.eye(batch_size, dtype=torch.bool, device=z.device)
    print(mask)
    logits = similarity_matrix / temperature
    print(logits)
    logits.masked_fill_(mask, -1e9)


    # Compute log-softmax
    log_probs = F.log_softmax(logits, dim=1)
    softmax = F.softmax(logits, dim=1)
    print("log probability",log_probs)
    print("softmax", softmax)
    # Compute the mean loss over all positive pairs
    loss = -log_probs[positives[:, 0], positives[:, 1]]
    print(loss)
    return loss.mean()


In [6]:
def nt_xent_loss(z1, z2, temperature=.5):
    batch_size = z1.shape[0]
    
    # Normalize the embeddings
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    
    # Concatenate
    z = torch.cat([z1, z2], dim=0)  # (2N, D)
    
    # Similarity matrix
    sim_matrix = torch.matmul(z, z.T)  # (2N, 2N)
    sim_matrix = sim_matrix / temperature

    # Remove similarity of samples to themselves
    mask = torch.eye(sim_matrix.size(0), device=z.device).bool()
    sim_matrix = sim_matrix.masked_fill(mask, -float('inf'))

    # Positive pairs (i, i+N) and (i+N, i)
    positives = torch.cat([
        torch.arange(batch_size, device=z.device) + batch_size,
        torch.arange(batch_size, device=z.device)
    ])

    # Labels
    labels = positives

    # Loss
    loss = F.cross_entropy(sim_matrix, labels)
    return loss

In [7]:
original = torch.randn(4,2)
batch_1 = original @ (torch.ones(2,2)*0.5).T

In [8]:
original, batch_1

(tensor([[-0.5662,  0.3491],
         [-1.2195, -1.2805],
         [-0.3878,  1.1199],
         [ 1.8077,  1.1549]]),
 tensor([[-0.1086, -0.1086],
         [-1.2500, -1.2500],
         [ 0.3660,  0.3660],
         [ 1.4813,  1.4813]]))

In [9]:
batch_2 = original+2

In [10]:
batch_2

tensor([[1.4338, 2.3491],
        [0.7805, 0.7195],
        [1.6122, 3.1199],
        [3.8077, 3.1549]])

In [11]:
simclr_loss(torch.cat([batch_1, batch_2]))

tensor([[-0.1086, -0.1086],
        [-1.2500, -1.2500],
        [ 0.3660,  0.3660],
        [ 1.4813,  1.4813],
        [ 1.4338,  2.3491],
        [ 0.7805,  0.7195],
        [ 1.6122,  3.1199],
        [ 3.8077,  3.1549]])
tensor([[-0.7071, -0.7071],
        [-0.7071, -0.7071],
        [ 0.7071,  0.7071],
        [ 0.7071,  0.7071],
        [ 0.5210,  0.8536],
        [ 0.7353,  0.6778],
        [ 0.4591,  0.8884],
        [ 0.7700,  0.6380]])
tensor([[ 1.0000,  1.0000, -1.0000, -1.0000, -0.9720, -0.9992, -0.9528, -0.9956],
        [ 1.0000,  1.0000, -1.0000, -1.0000, -0.9720, -0.9992, -0.9528, -0.9956],
        [-1.0000, -1.0000,  1.0000,  1.0000,  0.9720,  0.9992,  0.9528,  0.9956],
        [-1.0000, -1.0000,  1.0000,  1.0000,  0.9720,  0.9992,  0.9528,  0.9956],
        [-0.9720, -0.9720,  0.9720,  0.9720,  1.0000,  0.9616,  0.9975,  0.9458],
        [-0.9992, -0.9992,  0.9992,  0.9992,  0.9616,  1.0000,  0.9397,  0.9986],
        [-0.9528, -0.9528,  0.9528,  0.9528,  0.9975,  0.9

tensor(3.2122)

In [12]:
nt_xent_loss(batch_1, batch_2)

tensor(3.2122)

In [31]:
import torch
from torch import nn

def ntxent_loss_with_labels(out0, labels, temperature=0.5):
    """
    NT-Xent loss for a batch with labels indicating class/group similarity.

    Args:
        out0 (Tensor): Output embeddings for a batch of images (batch_size, embedding_size).
        labels (Tensor): Labels for the batch (batch_size,). Each label indicates the class/group of each sample.
        temperature (float): Scaling factor for logits.

    Returns:
        Tensor: Contrastive Cross-Entropy Loss value.
    """
    
    # Normalize the output to unit length (cosine similarity)
    out0 = nn.functional.normalize(out0, dim=1)

    # Calculate cosine similarity (pairwise similarity matrix)
    logits = torch.einsum("nc,mc->nm", out0, out0) / temperature
    print(logits)
    # Mask diagonal (self-similarity)
    batch_size = out0.size(0)
    mask = torch.eye(batch_size, device=out0.device, dtype=torch.bool)
    logits = logits[~mask].view(batch_size, -1)  # Remove self-similarities

    # Generate positive pair labels: Same label is a positive pair
    labels = labels.unsqueeze(0) == labels.unsqueeze(1)  # Shape: (batch_size, batch_size)
    print(labels)
    # Convert boolean mask to integer (1 for positive pairs, 0 for negative pairs)
    labels = labels.float()

    # Cross-entropy loss: maximize similarity for positive pairs, minimize for negative pairs
    cross_entropy = nn.CrossEntropyLoss(reduction="mean")
    print(labels)
    # The target labels should be indices where labels are 1 (positive pairs)
    target = labels.argmax(dim=1)  # Get the index of the positive pair for each sample
    print(target)
    # Calculate the loss
    loss = cross_entropy(logits, target)

    return loss


In [32]:
batch_3 = torch.randn(8,2)
labels = torch.tensor([1,2,3,3,2,1,4,4])

ntxent_loss_with_labels(batch_3, labels)

tensor([[ 2.0000,  1.2998,  1.8844,  0.1167,  1.8589,  1.3779,  0.9956,  1.9544],
        [ 1.2998,  2.0000,  1.7340, -1.4416,  1.7690,  1.9972,  1.9653,  1.5928],
        [ 1.8844,  1.7340,  2.0000, -0.5590,  1.9987,  1.7840,  1.5192,  1.9837],
        [ 0.1167, -1.4416, -0.5590,  2.0000, -0.6282, -1.3667, -1.6735, -0.3097],
        [ 1.8589,  1.7690,  1.9987, -0.6282,  2.0000,  1.8155,  1.5653,  1.9732],
        [ 1.3779,  1.9972,  1.7840, -1.3667,  1.8155,  2.0000,  1.9431,  1.6542],
        [ 0.9956,  1.9653,  1.5192, -1.6735,  1.5653,  1.9431,  2.0000,  1.3411],
        [ 1.9544,  1.5928,  1.9837, -0.3097,  1.9732,  1.6542,  1.3411,  2.0000]])
tensor([[ True, False, False, False, False,  True, False, False],
        [False,  True, False, False,  True, False, False, False],
        [False, False,  True,  True, False, False, False, False],
        [False, False,  True,  True, False, False, False, False],
        [False,  True, False, False,  True, False, False, False],
        [ Tru

tensor(2.3013)