In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Basic feature extractor.
class FeatureExtractor(nn.Module):
    """Map input features to an embedding with a single linear layer.

    The linear layer is a stand-in for a more elaborate backbone network.

    Args:
        input_dim: Size of the last dimension of the input (default 256).
        output_dim: Size of the last dimension of the output (default 128).
    """

    def __init__(self, input_dim: int = 256, output_dim: int = 128):
        super().__init__()
        # Sizes used to be hard-coded; the defaults reproduce that behaviour.
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Apply the linear layer to ``x`` (last dimension must be ``input_dim``)."""
        return self.fc(x)

# Projection head; the script instantiates it twice, once for the causal
# features and once for the non-causal features.
class MLP(nn.Module):
    """Linear projection from ``input_dim`` to ``output_dim``.

    Despite the name, this is a single ``nn.Linear`` layer: there is no
    hidden layer and no activation function.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Return the projection of ``x`` through ``fc1``."""
        projected = self.fc1(x)
        return projected

# InfoNCE loss with a bilinear critic.
def info_nce_loss(features, labels, weight_matrix):
    """Compute the InfoNCE loss with the bilinear critic ``T(f, y) = f W y``.

    Row ``i`` of ``features`` is paired with row ``i`` of ``labels`` (the
    positive pair); every other row of ``labels`` serves as a negative for it.

    Args:
        features: 2-D tensor of shape ``(N, d_f)``.
        labels: 2-D tensor of shape ``(M, d_y)`` with ``M >= N``.
        weight_matrix: Tensor of shape ``(d_f, d_y)``.

    Returns:
        Scalar tensor equal to the mean over rows ``i`` of
        ``logsumexp(scores[i, :]) - scores[i, i]``, where
        ``scores = features @ weight_matrix @ labels.T``.

    Raises:
        ValueError: If ``features`` or ``labels`` is not 2-D, or if ``labels``
            has fewer rows than ``features``.
    """
    if features.dim() != 2 or labels.dim() != 2:
        raise ValueError(
            "features and labels must both be 2-D, got shapes "
            f"{tuple(features.shape)} and {tuple(labels.shape)}"
        )
    # Every feature row needs its own positive label row. Without this check,
    # a single-row `labels` makes the length-1 diagonal broadcast against the
    # per-row logsumexp and silently produces a meaningless loss.
    if labels.shape[0] < features.shape[0]:
        raise ValueError(
            f"labels has {labels.shape[0]} rows but features has "
            f"{features.shape[0]}; each feature row needs a matching label row"
        )

    # scores[i, j] = T(f_i, y_j) = f_i W y_j
    scores = torch.matmul(features, weight_matrix)
    scores = torch.matmul(scores, labels.T)
    # The diagonal holds the positive-pair scores; logsumexp over each row is
    # the log-partition term over all candidate labels.
    nce_loss = -torch.mean(torch.diag(scores) - torch.logsumexp(scores, dim=1))
    return nce_loss

# Instantiate the models and the critic's weight matrix.
feature_extractor = FeatureExtractor()
mlp_causal = MLP(128, 64)  # causal features: 64 dimensions
mlp_noncausal = MLP(128, 64)  # non-causal features: 64 dimensions
# Wrapped in nn.Parameter so the loss can back-propagate into W; a bare
# torch.randn tensor has requires_grad=False and could never be trained.
weight_matrix = nn.Parameter(torch.randn(64, 64))

# Synthetic inputs and labels.
x = torch.randn(10, 256)  # 10 samples, 256 features each
y = torch.randn(10, 64)  # 10 label vectors, 64 dimensions each

# Feature extraction.
z = feature_extractor(x)

# Causal and non-causal features (f_n is not used by the loss below).
f_c = mlp_causal(z)
f_n = mlp_noncausal(z)

# InfoNCE loss between the causal features and the labels.
loss = info_nce_loss(f_c, y, weight_matrix)

loss.item()  # the loss as a Python float

45.6485710144043