In [48]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.retrieval import RetrievalMRR
from torchmetrics.functional import retrieval_reciprocal_rank
for _ in range(10):
    torch.cuda.empty_cache()

In [169]:
import torch
import torch.nn as nn
import torchsort

class MRRLoss(nn.Module):
    def __init__(self):
        super(MRRLoss, self).__init__()

    def forward(self, logits, target):
        # logits: (batch_size, num_classes)
        # target: (batch_size, )

        # 使用 torchsort 进行可微分的排名
        print(logits)
        ranks = torchsort.soft_rank(-logits,regularization_strength=0.001)
        print(ranks)
        target_unsqueezed = target.unsqueeze(1)
        
        # 获取目标标签的排名
        target_ranks = ranks.gather(1, target_unsqueezed).squeeze(1)
        print(target_ranks)
        # 计算倒数排名
        reciprocal_ranks = 1.0 / target_ranks.float()
        
        # 计算 MRR
        mrr = torch.mean(reciprocal_ranks)

        
        # 损失是负的 MRR（因为我们希望最大化 MRR）
        loss = 1-mrr
        
        return loss*10
batch_size = 5
dim = 5
    
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
embed = nn.Embedding(5, 5).to(device)
n_id = torch.tensor([1, 2, 3, 4, 0]).to(device)
tail = torch.tensor([2, 3, 4, 0, 1]).to(device)

# 确保嵌入层的权重需要梯度
embed.weight.requires_grad = True

# 创建 MRRLoss 实例
criterion = MRRLoss().to(device)

# 前向传播
logits = embed(n_id)
loss = criterion(logits, tail)

# 反向传播
loss.backward()

print("Loss:", loss.item())
target_one_hot = F.one_hot(target, num_classes=dim).float().to(device)

tensor([[ 2.4437, -2.1697, -0.4736,  0.0405,  1.9425],
        [ 0.2956, -0.8279, -1.2893,  0.3296,  0.1969],
        [ 0.9472, -1.1201,  0.4256,  1.8251,  0.8675],
        [ 0.2216, -2.7837, -0.0654, -0.2471,  0.2239],
        [-0.9022, -0.5339, -1.7308,  0.7359,  0.9498]], device='cuda:0',
       grad_fn=<EmbeddingBackward0>)
tensor([[1., 5., 4., 3., 2.],
        [2., 4., 5., 1., 3.],
        [2., 5., 4., 1., 3.],
        [2., 5., 3., 4., 1.],
        [4., 3., 5., 2., 1.]], device='cuda:0', grad_fn=<SoftRankBackward>)
tensor([4., 1., 3., 2., 3.], device='cuda:0', grad_fn=<SqueezeBackward1>)
Loss: 5.166666507720947


In [167]:
target_one_hot = F.one_hot(tail, num_classes=dim).float().to(device)
retrieval_reciprocal_rank(-logits,target_one_hot)

tensor(0.1667, device='cuda:0')

In [87]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SoftMRRLoss(nn.Module):
    def __init__(self, sigma=1.0):
        super(SoftMRRLoss, self).__init__()
        self.sigma = sigma

    def forward(self, logits, target):
        # logits: (batch_size, num_classes)
        # target: (batch_size, num_classes) in one-hot encoded form
        
        # Apply softmax to logits to get probabilities
        probs = torch.softmax(logits, dim=1)
        
        # Compute the "soft" ranks
        batch_size, num_classes = probs.shape
        ranks = torch.arange(1, num_classes + 1, device=logits.device).view(1, -1).repeat(batch_size, 1)
        ranks = ranks.float()
        
        sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1)
        sorted_ranks = torch.gather(ranks, 1, sorted_indices.argsort(dim=1))
        
        # Compute soft ranks
        soft_ranks = torch.sum(sorted_probs * sorted_ranks, dim=1)
        
        # Compute reciprocal ranks
        reciprocal_ranks = 1.0 / soft_ranks
        
        # Compute MRR
        mrr = torch.mean(torch.sum(target * reciprocal_ranks.view(-1, 1), dim=1))
        
        # The loss is the negative MRR (since we want to maximize MRR)
        loss = -mrr
        
        return loss

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 模型和数据
batch_size = 5
dim = 5
model = nn.Embedding(10, dim).to(device)

n_id = torch.tensor([1, 2, 3, 4, 0]).to(device)
target = torch.tensor([2, 3, 4, 0, 1]).to(device)  # 目标编号形状为 (batch_size,)
target_one_hot = F.one_hot(target, num_classes=dim).float().to(device)  # 目标转换为多热编码

# 前向传播
logits = model(n_id)  # 输出形状为 (batch_size, dim)

# 初始化 SoftMRRLoss
soft_mrr_loss = SoftMRRLoss().to(device)

# 计算 MRR 损失
loss = soft_mrr_loss(logits, target_one_hot)

# 反向传播
loss.backward()

print(f"Loss: {loss.item()}")


Loss: -0.3551948666572571


tensor([10000000000, 10000000000,           1,           0])


tensor([[False, False],
        [False, False],
        [False,  True],
        [ True, False]])
