In [1]:
import torch

from pytorch_metric_learning.distances import LpDistance, CosineSimilarity, DotProductSimilarity
from pytorch_metric_learning.losses import NTXentLoss, SupConLoss
from pytorch_metric_learning.reducers import AvgNonZeroReducer, PerAnchorReducer
from pytorch_metric_learning.utils import common_functions as c_f

import os

TEST_DEVICE = torch.device(os.environ.get("TEST_DEVICE", "cuda"))

dtypes_from_environ = os.environ.get("TEST_DTYPES", "float16,float32,float64").split(
    ","
)
TEST_DTYPES = [getattr(torch, x) for x in dtypes_from_environ]

In [2]:
# embeddings
batch_size = 6
hidden_dim = 13

embeddings = torch.randn((batch_size, hidden_dim))
print(embeddings)

tensor([[ 0.1889, -2.0954, -2.1622, -0.9798, -0.2999,  0.2575, -0.5817, -1.2394,
          0.0636,  0.6927,  1.4599, -0.7563, -0.7357],
        [ 0.1996,  0.8045, -0.1687,  0.2859,  0.8299,  0.3390, -1.0150,  0.0772,
          1.9016, -0.7265, -0.2821,  1.4018,  1.9424],
        [ 1.3193, -1.2767,  0.8210, -1.9712, -0.9279, -1.7080,  0.6796, -0.8085,
         -0.8814, -1.5720, -1.0460,  1.9520, -0.3893],
        [ 1.5262, -1.4883, -0.8948,  0.3876, -0.0660,  0.0590,  0.8836, -0.2283,
          0.8427, -0.8171, -0.3715,  0.6567,  0.8577],
        [ 1.0530, -0.5048, -0.0338, -0.0904,  0.1101, -0.3497, -1.4319,  0.1921,
          0.2075, -1.4596, -0.6132, -0.4351,  1.1964],
        [-0.6290, -0.3426, -1.1992,  0.4503, -0.7431,  0.1220, -1.9100,  1.2934,
          2.0955,  0.1623,  1.5234,  1.1079, -0.2712]])


In [3]:
# labels
labels = torch.LongTensor([0, 0, 0, 1, 1, 2])

In [4]:
# loss_funcA
temperature = 0.1

loss_funcA = SupConLoss(temperature=temperature)
print(loss_funcA)


SupConLoss(
  (distance): CosineSimilarity()
  (reducer): AvgNonZeroReducer()
)


In [5]:
# 计算 loss A
loss_A = loss_funcA(embeddings, labels).float()
print(loss_A, type(loss_A))

tensor(3.7643) <class 'torch.Tensor'>


In [6]:
# 复现计算过程
# 注意：distance默认是 CosineSimilarity()

pos_pairs = [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1), (3, 4), (4, 3)]
neg_pairs = [
    (0, 3), (0, 4), (0, 5),
    (1, 3), (1, 4), (1, 5),
    (2, 3), (2, 4), (2, 5),
    (3, 0), (3, 1), (3, 2), (3, 5),
    (4, 0), (4, 1), (4, 2), (4, 5),
    (5, 0), (5, 1), (5, 2), (5, 3), (5, 4),
]

total_lossA = torch.zeros(5, device="cpu", dtype=torch.float64)

for a1, p in pos_pairs:
    # a1: anchor;
    # p: positive instance
    
    anchor, positive = embeddings[a1], embeddings[p]
    numeratorA = torch.exp(torch.matmul(anchor, positive) / (temperature * torch.norm(anchor) * torch.norm(positive)))
    
    denominatorA = 0
    
    for a2, n in pos_pairs + neg_pairs:
        # n: negative instance
        
        if a2 == a1:
            negative = embeddings[n]
            curr_denomD = torch.exp(
                torch.matmul(anchor, negative) / (temperature * torch.norm(anchor) * torch.norm(negative))
            )
            denominatorA += curr_denomD
            
        else:
            continue
    
    print(numeratorA, denominatorA, - torch.log(numeratorA / denominatorA))
    curr_lossA = -torch.log(numeratorA / denominatorA)
    total_lossA[a1] += curr_lossA

pos_pair_per_anchor = torch.tensor(
    [2, 2, 2, 1, 1], device="cpu", dtype=torch.float64
)
total_lossA = torch.mean(total_lossA / pos_pair_per_anchor)
print("total_lossA: ", total_lossA, type(total_lossA))

rtol = 1e-4
assert torch.isclose(loss_A, total_lossA, rtol=rtol)

tensor(0.0409) tensor(25.4132) tensor(6.4322)
tensor(0.8321) tensor(25.4132) tensor(3.4190)
tensor(0.0409) tensor(185.0638) tensor(8.4176)
tensor(0.3266) tensor(185.0638) tensor(6.3398)
tensor(0.8321) tensor(50.0056) tensor(4.0959)
tensor(0.3266) tensor(50.0056) tensor(5.0312)
tensor(54.9212) tensor(125.7912) tensor(0.8287)
tensor(54.9212) tensor(169.1595) tensor(1.1249)
total_lossA:  tensor(3.7643, dtype=torch.float64) <class 'torch.Tensor'>


RuntimeError: Float did not match Double