In [1]:
import torch

In [2]:
def global_mean_pool(hidden_states, labels):
    """
    全局均值池化（所有有效token取平均）
    
    Args:
        hidden_states: 模型最后一层隐藏状态 [batch_size, seq_len, hidden_dim]
        labels: 每个token的标签 [batch_size, seq_len]（-100表示无效）
        
    Returns:
        pooled_features: 池化后的特征 [batch_size, hidden_dim]
    """
    # 创建有效掩码（排除-100），并转换为与hidden_states相同的数据类型
    valid_mask = (labels != -100).unsqueeze(-1)  # [batch_size, seq_len, 1]
    valid_mask = valid_mask.to(hidden_states.dtype)  # 确保数据类型一致
    
    # 计算加权和（自动广播valid_mask到hidden_dim维度）
    sum_hidden = torch.sum(hidden_states * valid_mask, dim=1)  # [batch_size, hidden_dim]
    
    # 计算有效token数量（转换为浮点型）
    num_valid = torch.sum(valid_mask, dim=1)  # [batch_size, 1]
    num_valid = torch.clamp(num_valid, min=1e-7)  # 防止除零
    
    # 均值池化（广播除法）
    pooled_features = sum_hidden / num_valid
    return pooled_features

In [3]:
# 示例 2
batch_size = 3
seq_len = 100
hidden_dim = 2
hidden_states = torch.randint(low=0, high=10, size=(batch_size, seq_len, hidden_dim))
# 生成随机的 labels，部分值设为 -100 表示无效
labels = torch.randint(low=-100, high=5, size=(batch_size, seq_len))
# labels = torch.tensor([[1, 1, -100, 2], [3, -100, 3, 4], [1, 1, -100, -100]])

pooled_features= global_mean_pool(hidden_states, labels)
print(hidden_states.shape)
print("池化后的特征:", pooled_features.shape)

torch.Size([3, 100, 2])
池化后的特征: torch.Size([3, 2])


### 测试损失函数

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [8]:
class FocalInfoNCE(nn.Module):
    def __init__(self, m=0.5, tau=0.1, eps=1e-8):
        super().__init__()
        self.m = m
        self.tau = tau
        self.eps = eps  # 数值稳定性参数

    def forward(self, emb_features, nce_labels):
        """
        Args:
            emb_features (Tensor): 特征向量，(batch_size, feature_dim)
            nce_labels (Tensor): 类别标签，(batch_size,)
        Returns:
            loss (Tensor): 标量损失值
        """
        batch_size = emb_features.size(0)
        device = emb_features.device

        # 1. 特征归一化（余弦相似度）
        emb_features = nn.functional.normalize(emb_features, p=2, dim=1)  # L2归一化
        sim_matrix = torch.matmul(emb_features, emb_features.T)  # (b, b)

        # 2. 构造正/负样本掩码
        pos_mask = (nce_labels.unsqueeze(1) == nce_labels.unsqueeze(0)) & \
                  (~torch.eye(batch_size, dtype=torch.bool, device=device))
        print("pos_mask:\n",pos_mask)
        neg_mask = (nce_labels.unsqueeze(1) != nce_labels.unsqueeze(0))
        print("neg_mask:\n",neg_mask)

        # 3. 正样本得分（平均处理）
        s_p = torch.sum(sim_matrix * pos_mask, dim=1) / (pos_mask.sum(dim=1) + self.eps)

        # 4. 负样本调制项（仅保留负样本位置）
        s_n_terms = sim_matrix.clone()  # 避免修改原始张量
        s_n_terms = s_n_terms * neg_mask.float()  # 清零正样本位置
        s_n_terms = s_n_terms * (s_n_terms + self.m) / self.tau  # 仅负样本参与计算

        # 5. 数值稳定损失计算
        max_val = torch.max((s_p** 2) / self.tau)  # 防止指数溢出
        numerator = torch.exp((s_p** 2) / self.tau - max_val)
        denominator = numerator + torch.sum(torch.exp(s_n_terms - max_val), dim=1)
        loss = - (torch.log(numerator / (denominator + self.eps)) + max_val)  # 恢复偏移

        return loss.mean().item()

# 示例数据（归一化后）
batch_size = 4
feature_dim = 8
emb_features = torch.randn(batch_size, feature_dim)
nce_labels = torch.tensor([0, 0, 1, 1])
print("emb_features\n",emb_features)
print("nce_labels\n",nce_labels)
# 初始化损失函数
criterion = FocalInfoNCE(m=0.5, tau=0.1)

# 计算损失
loss = criterion(emb_features, nce_labels)
print("改进后的损失值:", loss)

emb_features
 tensor([[ 2.7103,  1.7868,  0.3986, -0.5020, -1.1758, -1.4855,  1.0214,  1.9834],
        [-0.2004,  0.4443, -1.7188, -0.4812,  1.4300, -0.8460, -1.5075,  0.6973],
        [ 1.1274,  0.1804, -0.4552, -0.3153, -0.5012, -1.4247,  0.4055,  0.4647],
        [-0.4812, -0.9740,  1.4976, -0.0072, -0.4334, -0.4316, -0.8480, -1.2850]])
nce_labels
 tensor([0, 0, 1, 1])
pos_mask:
 tensor([[False,  True, False, False],
        [ True, False, False, False],
        [False, False, False,  True],
        [False, False,  True, False]])
neg_mask:
 tensor([[False, False,  True,  True],
        [False, False,  True,  True],
        [ True,  True, False, False],
        [ True,  True, False, False]])
改进后的损失值: 4.902640342712402


In [None]:
class FusedLoss(nn.Module):
    def __init__(self, alpha=1.0, tau=0.07, eps=1e-8):
        super().__init__()
        self.alpha = alpha  # 正样本项的加权系数
        self.tau = tau      # 温度参数
        self.eps = eps      # 数值稳定性常数
        # 确保参数合法性
        assert self.alpha > 0, "Alpha must be greater than 0"
        assert self.tau > 0, "Tau must be greater than 0"
        assert self.eps > 0, "Epsilon must be greater than 0"

    def forward(self, cos_sim, labels):
        """
        Args:
            cos_sim: (b, n) 输入的相似度矩阵，未经过温度缩放
            labels: (b,) 每个样本的正样本位置索引
        Returns:
            loss: 计算得到的损失值，保留梯度信息
        """
        device = cos_sim.device
        b, n = cos_sim.shape

        # 提取正样本分数并增强 (b,)
        pos_scores = cos_sim.gather(1, labels.unsqueeze(1)).squeeze(1)
        enhanced_pos = self._enhance_pos(pos_scores)  # 正样本增强

        # 调制负样本分数 (b, n)
        mod_neg_scores = self._modulate_neg(cos_sim)
        weighted_neg = self._weight_neg(cos_sim)  # 负样本加权
        mod_neg_scores_weight = mod_neg_scores * weighted_neg  # 应用对样本进行加权

        # 生成 mask 以排除正样本位置 (b, n)
        mask = torch.zeros_like(mod_neg_scores_weight, dtype=torch.bool)
        mask.scatter_(1, labels.unsqueeze(1), True)  # 正样本位置标记为True

        # 将正样本位置分数设为-无穷，避免其参与负样本求和
        mod_neg_scores_weight_masked = mod_neg_scores_weight.masked_fill(mask, -float('inf'))
        
        # 计算负样本的指数和 (b,)
        sum_neg_logits = torch.logsumexp(mod_neg_scores_weight_masked, dim=1)
        sum_neg_exp = sum_neg_logits.exp()

        # 正样本加权调整 (b,)
        weighted_pos = self._weight_pos(pos_scores)
        adjusted_pos = enhanced_pos + weighted_pos

        # 构造分母项
        pos_exp_term = weighted_pos
        denominator = pos_exp_term + sum_neg_exp + self.eps

        # 计算最终损失 (log(denominator) - adjusted_pos) 的均值
        loss = (torch.log(denominator) - adjusted_pos).mean().item()

        return loss
    # 定义各调制函数

    def _enhance_pos(self, pos_scores):
        """
        正样本增强函数
        Args:
            pos_scores: (b,) 正样本分数
        Returns:
            enhanced_pos: (b,) 增强后的正样本分数
        """
        enhanced_pos = pos_scores.clone().to(pos_scores.device)  # 深拷贝，避免原地操作
        enhanced_pos = enhanced_pos/self.tau
        return enhanced_pos
    
    def _modulate_neg(self, neg_scores):
        """
        负样本调制函数
        Args:
            neg_scores: (b, n) 负样本分数
        Returns:
            modulated_neg: (b, n) 调制后的负样本分数
        """
        modulated_neg = neg_scores.clone().to(neg_scores.device)  # 深拷贝，避免原地操作
        modulated_neg = modulated_neg/self.tau
        return modulated_neg

    def _weight_pos(self, pos_scores):
        """
        正样本加权函数
        Args:
            pos_scores: (b,) 正样本分数
        Returns:
            weighted_pos: (b,) 加权后的正样本分数
        """
        # weighted_pos = pos_scores.clone()
        weighted_pos = torch.full_like(pos_scores, self.alpha).to(pos_scores.device)
        return weighted_pos
    
    def _weight_neg(self, neg_scores):
        """
        负样本加权函数
        Args:
            neg_scores: (b, n) 负样本分数
        Returns:
            weighted_neg: (b, n) 加权后的负样本分数
        """
        # weighted_neg = neg_scores.clone()
        weighted_neg = torch.full_like(neg_scores, self.alpha).to(neg_scores.device)
        return weighted_neg