In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

In [3]:
class ParametricReward(nn.Module):
    """Four-parameter reward model that scores batches of trajectories.

    Learnable scalar parameters (initial value in parentheses):
        rho   (0.5): benchmark-tracking weight in the target portfolio value.
        eta   (1.0): internal growth rate applied to the current holdings.
        lamb  (0.1): penalty weight on the cashflow mismatch.
        omega (0.1): penalty weight on the squared trades.
    """

    def __init__(self):
        super().__init__()
        # 4 reward parameters
        self.rho = nn.Parameter(torch.tensor(0.5))    # benchmark tracking
        self.eta = nn.Parameter(torch.tensor(1.0))    # internal growth rate
        self.lamb = nn.Parameter(torch.tensor(0.1))   # cashflow mismatch penalty
        self.omega = nn.Parameter(torch.tensor(0.1))  # trade cost penalty

    def forward(self, traj_batch):
        """Return the cumulative reward of every trajectory in ``traj_batch``.

        Args:
            traj_batch: iterable of trajectory dicts. For every time step ``t``
                a dict must provide ``'x_t'[t]``, ``'u_t'[t]`` and ``'r_t'[t]``
                (1-D tensors of equal length, as required by ``torch.dot``)
                plus the scalars ``'B_t'[t]`` and ``'C_t'[t]``. The number of
                steps is ``len(traj['x_t'])``. Other keys (e.g. ``'cash_t'``)
                are not read by the reward.

        Returns:
            1-D tensor holding one summed reward per trajectory, in input
            order. A trajectory with zero steps scores 0; an empty batch
            yields an empty 1-D tensor.
        """
        total_rewards = []

        for traj in traj_batch:
            # Start from a tensor rather than a Python float so a trajectory
            # with zero steps still produces something torch.stack accepts.
            reward = torch.zeros((), dtype=self.rho.dtype, device=self.rho.device)

            for t in range(len(traj['x_t'])):
                x_t = traj['x_t'][t]      # portfolio weights ($)
                u_t = traj['u_t'][t]      # trades
                r_t = traj['r_t'][t]      # realized sector returns
                B_t = traj['B_t'][t]      # benchmark value
                C_t = traj['C_t'][t]      # cash inflow (external)

                trade_amount = torch.sum(u_t)

                # Value of new portfolio (after trade)
                V_t = torch.dot(1 + r_t, x_t + u_t)

                # Target portfolio value (PM's goal)
                P_hat_t = self.rho * B_t + (1 - self.rho) * self.eta * torch.sum(x_t)

                # Per-step reward: negative squared tracking error minus the
                # cashflow-mismatch and trade-size penalties.
                step_reward = - (P_hat_t - V_t) ** 2 \
                              - self.lamb * (trade_amount - C_t) ** 2 \
                              - self.omega * torch.sum(u_t ** 2)

                # Out-of-place accumulation keeps the autograd graph simple.
                reward = reward + step_reward

            total_rewards.append(reward)

        if not total_rewards:
            # torch.stack rejects an empty list; return an empty reward vector.
            return torch.zeros(0, dtype=self.rho.dtype, device=self.rho.device)

        return torch.stack(total_rewards)


In [4]:
# Loss: Softmax ranking (T-REX)
def trex_loss(rewards, pair_indices):
    """Mean pairwise softmax ranking loss (T-REX) over preference pairs.

    For each pair ``(i, j)`` with traj[i] ≺ traj[j] the loss term is
    ``-log(exp(r_j) / (exp(r_i) + exp(r_j)))``, which equals
    ``softplus(r_i - r_j)``. The softplus form is used because exponentiating
    the raw rewards overflows or underflows for large magnitudes and turns
    the loss into NaN.

    Args:
        rewards: indexable collection of scalar reward tensors (for example a
            1-D tensor), one entry per trajectory.
        pair_indices: sized iterable of ``(i, j)`` index pairs where
            trajectory ``i`` is ranked below trajectory ``j``.

    Returns:
        0-dim tensor: the average loss over all pairs.

    Raises:
        ValueError: if ``pair_indices`` is empty (the mean is undefined).
    """
    if len(pair_indices) == 0:
        raise ValueError("pair_indices must contain at least one (i, j) pair")

    # Reward gap of the worse trajectory over the better one, per pair.
    reward_gaps = torch.stack([rewards[i] - rewards[j] for i, j in pair_indices])

    # softplus(gap) = log(1 + exp(gap)), evaluated in a numerically stable way.
    return torch.nn.functional.softplus(reward_gaps).mean()

In [5]:
def train_trex_model(traj_batch, pair_indices, lr=0.01, epochs=1000):
    """Fit a fresh ``ParametricReward`` to ranked trajectory pairs using Adam.

    Args:
        traj_batch: trajectories, passed unchanged to the reward model on
            every epoch.
        pair_indices: ranking pairs, passed unchanged to ``trex_loss``.
        lr: learning rate handed to ``optim.Adam``.
        epochs: number of full-batch gradient steps to take.

    Returns:
        The trained ``ParametricReward`` instance.
    """
    model = ParametricReward()
    adam = optim.Adam(model.parameters(), lr=lr)
    report_every = 100  # progress is printed on epochs 0, 100, 200, ...

    for epoch in range(epochs):
        # One full-batch update: reset grads, score, rank, backprop, step.
        adam.zero_grad()
        loss = trex_loss(model(traj_batch), pair_indices)
        loss.backward()
        adam.step()

        if epoch % report_every != 0:
            continue

        # The loss shown was computed before the step; the parameter values
        # are read after it.
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}, "
              f"ρ={model.rho.item():.3f}, η={model.eta.item():.3f}, "
              f"λ={model.lamb.item():.3f}, ω={model.omega.item():.3f}")

    return model

In [None]:
# Example of what a single trajectory dict (`traj`) looks like.
# The values are illustrative placeholders: EXAMPLE_STEPS time steps (T) and
# EXAMPLE_SECTORS entries per step.
EXAMPLE_STEPS = 3
EXAMPLE_SECTORS = 11
traj = {
    'x_t': [torch.ones(EXAMPLE_SECTORS) for _ in range(EXAMPLE_STEPS)],    # T tensors of shape [11]: portfolio weights ($)
    'u_t': [torch.zeros(EXAMPLE_SECTORS) for _ in range(EXAMPLE_STEPS)],   # T tensors of shape [11]: trades
    'r_t': [torch.zeros(EXAMPLE_SECTORS) for _ in range(EXAMPLE_STEPS)],   # T tensors of shape [11]: realized sector returns
    'B_t': [float(EXAMPLE_SECTORS)] * EXAMPLE_STEPS,                       # T floats: benchmark value
    'C_t': [0.0] * EXAMPLE_STEPS,                                          # T floats: external cash inflow
    'cash_t': [0.0] * (EXAMPLE_STEPS + 1),                                 # T+1 floats: cash balance
}