In [1]:
import torch

# Toy embeddings standing in for model outputs. Each sample (A, B, C, D) has two
# embeddings: row k of z_i and row k of z_j form the positive pair (A1/A2, B1/B2, ...).
# Presumably these represent two augmented views of the same input — TODO confirm.
z_i = torch.tensor([
    [0.1, 0.2, 0.3],  # A1
    [0.4, 0.5, 0.6],  # B1
    [0.7, 0.8, 0.9],  # C1
    [0.2, 0.1, 0.3],  # D1
])

z_j = torch.tensor([
    [0.1, 0.2, 0.4],  # A2
    [0.4, 0.6, 0.5],  # B2
    [0.8, 0.7, 0.9],  # C2
    [0.3, 0.1, 0.2],  # D2
])

# Both views must describe the same samples in the same order, with the same width.
assert z_i.shape == z_j.shape, "z_i and z_j must have identical shapes"

# Derive the batch size from the data rather than hard-coding it, so the
# positive-pair offsets and the mask built later stay consistent with the rows.
batch_size = z_i.shape[0]
# Scale factor: the dot-product similarities are divided by this value.
temperature = 0.5

In [3]:
# Stack both views into a single matrix along the row axis: the first
# batch_size rows are z_i (A1..D1), the remaining rows are z_j (A2..D2).
representations = torch.cat((z_i, z_j))
representations


tensor([[0.1000, 0.2000, 0.3000],
        [0.4000, 0.5000, 0.6000],
        [0.7000, 0.8000, 0.9000],
        [0.2000, 0.1000, 0.3000],
        [0.1000, 0.2000, 0.4000],
        [0.4000, 0.6000, 0.5000],
        [0.8000, 0.7000, 0.9000],
        [0.3000, 0.1000, 0.2000]])

In [4]:
# All-pairs similarity: entry [a, b] is the dot product of row a and row b of
# `representations`, divided by the temperature. The result is symmetric.
# NOTE(review): the embeddings are not L2-normalised before this product, so these
# are raw dot products rather than cosine similarities (NT-Xent/SimCLR usually
# normalises first, e.g. torch.nn.functional.normalize(x, dim=1)) — confirm intended.
similarity_matrix = (representations @ representations.T) / temperature
similarity_matrix

tensor([[0.2800, 0.6400, 1.0000, 0.2600, 0.3400, 0.6200, 0.9800, 0.2200],
        [0.6400, 1.5400, 2.4400, 0.6200, 0.7600, 1.5200, 2.4200, 0.5800],
        [1.0000, 2.4400, 3.8800, 0.9800, 1.1800, 2.4200, 3.8600, 0.9400],
        [0.2600, 0.6200, 0.9800, 0.2800, 0.3200, 0.5800, 1.0000, 0.2600],
        [0.3400, 0.7600, 1.1800, 0.3200, 0.4200, 0.7200, 1.1600, 0.2600],
        [0.6200, 1.5200, 2.4200, 0.5800, 0.7200, 1.5400, 2.3800, 0.5600],
        [0.9800, 2.4200, 3.8600, 1.0000, 1.1600, 2.3800, 3.8800, 0.9800],
        [0.2200, 0.5800, 0.9400, 0.2600, 0.2600, 0.5600, 0.9800, 0.2800]])

In [5]:
# Positive-pair similarities: row k (first view of sample k) against row
# batch_size + k (its second view), i.e. the batch_size-th superdiagonal of the
# matrix — A1-A2, B1-B2, C1-C2, D1-D2 for the toy batch. Using an offset diagonal
# instead of hard-coded [0, 4], [1, 5], ... keeps this correct for any batch_size.
# NOTE(review): similarity_matrix is symmetric, so the reverse pairs
# (row batch_size + k vs row k) carry the same values. If a later step needs one
# positive per row of the N-row `negatives`, use torch.cat([positives, positives]).
positives = torch.diag(similarity_matrix, batch_size)
positives

tensor([0.3400, 1.5200, 3.8600, 0.2600])

In [6]:
# Boolean mask selecting, for each row (anchor), the entries that count as
# negatives: everything except the anchor itself and its positive partner.
N = 2 * batch_size
# Begin with all entries True except the main diagonal (self-similarity).
mask = ~torch.eye(N, dtype=torch.bool)
# Knock out the positive pairs in both directions: k <-> batch_size + k.
first_view_rows = torch.arange(batch_size)
second_view_rows = first_view_rows + batch_size
mask[first_view_rows, second_view_rows] = False
mask[second_view_rows, first_view_rows] = False
mask

tensor([[False,  True,  True,  True, False,  True,  True,  True],
        [ True, False,  True,  True,  True, False,  True,  True],
        [ True,  True, False,  True,  True,  True, False,  True],
        [ True,  True,  True, False,  True,  True,  True, False],
        [False,  True,  True,  True, False,  True,  True,  True],
        [ True, False,  True,  True,  True, False,  True,  True],
        [ True,  True, False,  True,  True,  True, False,  True],
        [ True,  True,  True, False,  True,  True,  True, False]])

In [14]:
# Negative similarities per anchor: every entry of a row except self and its
# positive partner. Boolean indexing returns the kept values as a flat, row-major
# vector, so they are reshaped back into one row per anchor.
# The reshape only lines up with the anchors if every mask row keeps exactly
# N - 2 entries; a malformed mask would otherwise be silently split across rows.
assert bool((mask.sum(dim=1) == N - 2).all()), "each mask row must exclude exactly self and its positive"
negatives = similarity_matrix[mask].view(N, N - 2)
negatives

tensor([[0.6400, 1.0000, 0.2600, 0.6200, 0.9800, 0.2200],
        [0.6400, 2.4400, 0.6200, 0.7600, 2.4200, 0.5800],
        [1.0000, 2.4400, 0.9800, 1.1800, 2.4200, 0.9400],
        [0.2600, 0.6200, 0.9800, 0.3200, 0.5800, 1.0000],
        [0.7600, 1.1800, 0.3200, 0.7200, 1.1600, 0.2600],
        [0.6200, 2.4200, 0.5800, 0.7200, 2.3800, 0.5600],
        [0.9800, 2.4200, 1.0000, 1.1600, 2.3800, 0.9800],
        [0.2200, 0.5800, 0.9400, 0.2600, 0.5600, 0.9800]])