In [1]:
import torch

from pytorch_metric_learning.distances import LpDistance, CosineSimilarity, DotProductSimilarity
from pytorch_metric_learning.losses import NTXentLoss, SupConLoss
from pytorch_metric_learning.reducers import AvgNonZeroReducer, PerAnchorReducer
from pytorch_metric_learning.utils import common_functions as c_f

import os

TEST_DEVICE = torch.device(os.environ.get("TEST_DEVICE", "cuda"))

dtypes_from_environ = os.environ.get("TEST_DTYPES", "float16,float32,float64").split(
    ","
)
TEST_DTYPES = [getattr(torch, x) for x in dtypes_from_environ]

In [2]:
# embeddings
batch_size = 6
hidden_dim = 13

embeddings = torch.randn((batch_size, hidden_dim))
print(embeddings)

tensor([[ 0.3272,  0.8661, -1.3020, -0.3160, -1.4932, -0.6192,  0.8469, -0.6799,
         -1.6485,  2.2359,  0.7302,  1.4681,  0.0466],
        [ 1.2040,  1.9452,  1.0550, -0.5377, -0.3891, -1.2407, -1.2048, -0.7506,
          0.4507, -0.9812, -1.2910,  0.4463, -0.7544],
        [ 0.7839,  0.7656, -0.8558, -1.8627,  1.4193,  1.9446, -1.3686,  1.5245,
          1.5849, -0.7058, -0.1134, -1.3433, -0.6476],
        [ 0.5674, -1.7444, -1.3082, -0.2297, -0.2414, -1.4376,  0.8838,  1.8955,
          1.2364,  0.4783,  0.2757,  0.0828,  1.2334],
        [-1.2474,  0.1740, -1.0800, -1.1750,  0.1283, -0.1786, -0.5509,  0.6440,
         -1.0947,  1.6392,  1.8218, -0.1523,  0.5351],
        [-1.2571, -0.0577, -0.4294, -1.4806,  0.5772,  0.6457,  2.9357,  0.4129,
         -0.5948,  1.3127,  1.9486,  0.7775,  0.1448]])


In [3]:
# labels
labels = torch.LongTensor([0, 0, 0, 1, 1, 2])

In [4]:
# loss_funcA
temperature = 0.1

loss_funcA = SupConLoss(temperature=temperature)
print(loss_funcA)


SupConLoss(
  (distance): CosineSimilarity()
  (reducer): AvgNonZeroReducer()
)


In [5]:
# 计算 loss A
loss_A = loss_funcA(embeddings, labels).float()
print(loss_A, type(loss_A))

tensor(3.6946) <class 'torch.Tensor'>


In [6]:
# 复现计算过程
# 注意：distance默认是 CosineSimilarity()

pos_pairs = [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1), (3, 4), (4, 3)]
neg_pairs = [
    (0, 3), (0, 4), (0, 5),
    (1, 3), (1, 4), (1, 5),
    (2, 3), (2, 4), (2, 5),
    (3, 0), (3, 1), (3, 2), (3, 5),
    (4, 0), (4, 1), (4, 2), (4, 5),
    (5, 0), (5, 1), (5, 2), (5, 3), (5, 4),
]

total_lossA = torch.zeros(5, device="cpu", dtype=torch.float64)

for a1, p in pos_pairs:
    # a1: anchor;
    # p: positive instance
    
    anchor, positive = embeddings[a1], embeddings[p]
    numeratorA = torch.exp(torch.matmul(anchor, positive) / (temperature * torch.norm(anchor) * torch.norm(positive)))
    
    denominatorA = 0
    
    for a2, n in pos_pairs + neg_pairs:
        # n: negative instance
        
        if a2 == a1:
            negative = embeddings[n]
            curr_denomD = torch.exp(
                torch.matmul(anchor, negative) / (temperature * torch.norm(anchor) * torch.norm(negative))
            )
            denominatorA += curr_denomD
            
        else:
            continue
    
    print(numeratorA, denominatorA, - torch.log(numeratorA / denominatorA))
    curr_lossA = -torch.log(numeratorA / denominatorA)
    total_lossA[a1] += curr_lossA

pos_pair_per_anchor = torch.tensor(
    [2, 2, 2, 1, 1], device="cpu", dtype=torch.float64
)
total_lossA = torch.mean(total_lossA / pos_pair_per_anchor).cpu().numpy().tolist()
print("total_lossA: ", total_lossA, type(total_lossA))

rtol = 1e-4
assert torch.isclose(loss_A, total_lossA, rtol=rtol)

tensor(0.3649) tensor(218.6342) tensor(6.3954)
tensor(0.0073) tensor(218.6342) tensor(10.3051)
tensor(0.3649) tensor(2.8354) tensor(2.0502)
tensor(2.4404) tensor(2.8354) tensor(0.1500)
tensor(0.0073) tensor(5.1348) tensor(6.5537)
tensor(2.4404) tensor(5.1348) tensor(0.7439)
tensor(5.0586) tensor(14.0618) tensor(1.0224)
tensor(5.0586) tensor(392.4407) tensor(4.3513)
total_lossA:  3.6945752680301664 <class 'float'>


TypeError: isclose(): argument 'other' (position 2) must be Tensor, not float