In [1]:
import torch

In [2]:
def simu_trade(output, target):
    capital = 5e8                                                           # 总资金
    target = target.to(output.device)
    buyable_amount = target[:, 4].unsqueeze(-1).float().to(output.device)   # 每只股票最大可买入的资金数额
    true_yields = target[:, 0].unsqueeze(-1).float().to(output.device)      # 真实收益率
    predicted_yields = output.unsqueeze(-1).float()                         # 预测收益率
    valid_mask = ~torch.isnan(buyable_amount) & ~torch.isnan(true_yields)   # 过滤掉缺失值
    buyable_amount = buyable_amount[valid_mask]
    true_yields = true_yields[valid_mask]
    predicted_yields = predicted_yields[valid_mask]
    top500_values, _ = torch.topk(predicted_yields, 500, largest=True, sorted=True)
    value_500 = top500_values[-1]
    buy_amount = buyable_amount[predicted_yields >= value_500]
    true_yields = true_yields[predicted_yields >= value_500]
    total_profit = torch.sum(buy_amount * true_yields) / capital            # 计算总收益率：股票收益总和 / 总资金
    return total_profit

def simu_trade_loss(output, target, temperature=1e-8):
    capital = 5e8                                                           # 总资金
    target = target.to(output.device)
    buyable_amount = target[:, 4].unsqueeze(-1).float().to(output.device)   # 每只股票最大可买入的资金数额
    true_yields = target[:, 0].unsqueeze(-1).float().to(output.device)      # 真实收益率
    predicted_yields = output.unsqueeze(-1).float()                         # 预测收益率
    valid_mask = ~torch.isnan(buyable_amount) & ~torch.isnan(true_yields)   # 过滤掉缺失值
    buyable_amount = buyable_amount[valid_mask]
    true_yields = true_yields[valid_mask]
    predicted_yields = predicted_yields[valid_mask]
    top500_values, _ = torch.topk(predicted_yields, 500, largest=True, sorted=True)
    value_500 = top500_values[-1]
    diff = (predicted_yields - value_500) / temperature
    weights = torch.sigmoid(diff)
    weighted_profit = torch.sum(buyable_amount * true_yields * weights)
    total_profit = weighted_profit / capital
    loss = -total_profit
    return loss

In [3]:
output = torch.randn(1000, requires_grad=True)  # 模型输出
target = torch.randn(1000, 5)  # 目标数据

loss = simu_trade_loss(output, target)
loss.backward()  # 反向传播
print(loss.item())

loss = simu_trade(output, target)
# loss.backward()  # 无法反向传播
print(loss.item())

-7.774396237891779e-08
7.610611163499925e-08
