In [None]:
# ================================
# VWAP optimal execution × NN × q_T=0 soft constraint
# PolicyNet outputs v_t directly (no intermediate u)
# ================================
import math
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("device:", device)


class PolicyNet(nn.Module):
    """Feed-forward policy that maps a state vector to a bounded trading speed.

    The network is two hidden ReLU layers followed by a single tanh output
    unit; the forward pass scales that output by ``v_max``, so the returned
    speed is signed and never exceeds ``v_max`` in absolute value.

    Args:
        state_dim: Number of input features.
        hidden_dim: Width of both hidden layers.
        v_max: Scale applied to the tanh output. When ``None`` it defaults to
            ``2 * q0 / T``.
        q0: Initial inventory, used only for the default ``v_max``.
        T: Time horizon, used only for the default ``v_max``.
    """

    def __init__(self, state_dim=4, hidden_dim=128, v_max=None, q0=400000.0, T=1.0):
        super().__init__()
        # The default cap is evaluated lazily, so q0 / T is only touched when
        # the caller did not supply an explicit v_max.
        self.v_max = float(2.0 * q0 / T if v_max is None else v_max)
        layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``v_max * tanh(...)`` for the given batch of states."""
        squashed = self.net(x)
        return self.v_max * squashed


def simulate_batch_signed_volume(
    policy: nn.Module,
    batch_size: int = 256,
    J: int = 200,
    T: float = 1.0,
    S0: float = 50.0,
    q0: float = 400000.0,
    V_const: float = 4_000_000.0,
    sigma: float = 0.45,
    eta: float = 0.12,
    phi: float = 0.63,
    k: float = 5e-7,
    gamma: float = 3e-6,
    lambda_qT: float = 1e-8,
    device: str = "cpu",
) -> torch.Tensor:
    """Simulate a batch of execution paths under ``policy`` and return the loss.

    Every path takes ``J`` steps of size ``dt = T / J``. At each step the
    policy is fed the normalised state
    ``[S / S0, V / V_const, Q / (V_const * T), q / max(|q0|, 1)]`` and its
    output ``v`` is used as the signed trading speed for that step. Inventory,
    cash, price and the VWAP accumulators are then advanced.

    The loss is the batch mean of
    ``exp(-gamma * pnl) + lambda_qT * q_T ** 2`` where
    ``pnl = X_T - q0 * VWAP_T`` is clamped to ``[-1e6, 1e6]`` before the
    exponential.

    Side effect: ``policy`` is switched to training mode.

    Args:
        policy: Module mapping a ``(batch_size, 4)`` state to the trading
            speed; assumed to return a ``(batch_size, 1)`` tensor.
        batch_size: Number of simulated paths.
        J: Number of time steps.
        T: Time horizon.
        S0: Initial price (also the price normalisation constant).
        q0: Initial inventory.
        V_const: Constant market volume rate.
        sigma: Volatility multiplying the Brownian increment of the price.
        eta: Scale of the execution cost ``eta * |v / V| ** (1 + phi)``.
        phi: Exponent parameter of the execution cost.
        k: Coefficient of the price drift ``-k * v * dt``.
        gamma: Coefficient in the exponential utility.
        lambda_qT: Weight of the quadratic terminal-inventory penalty.
        device: Device on which the simulation tensors are created.

    Returns:
        Scalar tensor holding the mean loss over the batch.
    """
    policy.train()
    dt = T / J
    sqrt_dt = math.sqrt(dt)

    # Per-path state, all stored as (batch_size, 1) columns.
    S = torch.full((batch_size, 1), S0, device=device)
    X = torch.zeros((batch_size, 1), device=device)
    q = torch.full((batch_size, 1), q0, device=device)
    Q = torch.zeros((batch_size, 1), device=device)
    vw_num = torch.zeros((batch_size, 1), device=device)
    vw_den = torch.zeros((batch_size, 1), device=device)
    eps_S = torch.randn(batch_size, J, device=device)

    # Normalisation constants for the policy inputs.
    S_ref = S0
    V_ref = V_const
    Q_ref = V_const * T
    q_ref = max(abs(q0), 1.0)

    # The market volume rate is constant, so build it (and its normalised
    # form) once instead of on every time step.
    V_t = torch.full_like(S, V_const)
    V_norm = V_t / V_ref

    for j in range(J):
        # Cumulative market volume is advanced before the policy observes it.
        Q = Q + V_t * dt

        state = torch.cat([S / S_ref, V_norm, Q / Q_ref, q / q_ref], dim=1)
        v = policy(state)

        # Cash: proceeds at the current price minus the execution cost.
        rho = v / V_t
        L_val = eta * torch.abs(rho) ** (1.0 + phi)
        dX = (v * S - V_t * L_val) * dt
        X = X + dX

        # Price: Brownian increment plus a drift proportional to v.
        dW = eps_S[:, j:j+1] * sqrt_dt
        dS = sigma * dW - k * v * dt
        S = S + dS

        q = q - v * dt

        # VWAP accumulators use the post-step price.
        vw_num = vw_num + S * V_t * dt
        vw_den = vw_den + V_t * dt

    VWAP_T = vw_num / vw_den
    pnl = torch.clamp(X - q0 * VWAP_T, -1e6, 1e6)
    utility = torch.exp(-gamma * pnl)
    # Keep the penalty as a column like ``utility``. Squeezing it to
    # (batch_size,) would broadcast the sum below to (batch_size, batch_size).
    penalty = lambda_qT * q ** 2
    return (utility + penalty).mean()


def train_policy_signed_volume(
    num_epochs: int = 300,
    batch_size: int = 256,
    lr: float = 1e-3,
    lambda_qT: float = 1e-8,
    q0: float = 400000.0,
    T: float = 1.0,
    device: str = "cpu",
) -> nn.Module:
    """Train a fresh ``PolicyNet`` on simulated batches and return it.

    Each epoch draws one batch via ``simulate_batch_signed_volume``,
    back-propagates its loss, clips the gradient norm to 1.0 and takes one
    Adam step. The loss is printed every 50 epochs.

    Args:
        num_epochs: Number of optimisation steps.
        batch_size: Paths simulated per step.
        lr: Adam learning rate.
        lambda_qT: Terminal-inventory penalty weight passed to the simulator.
        q0: Initial inventory, passed to both the network and the simulator.
        T: Time horizon, passed to both the network and the simulator.
        device: Device for the network and the simulation.

    Returns:
        The trained policy network.
    """
    policy = PolicyNet(state_dim=4, hidden_dim=128, v_max=None, q0=q0, T=T)
    policy = policy.to(device)
    optimizer = optim.Adam(policy.parameters(), lr=lr)

    # Simulator arguments are fixed for the whole run; every other simulator
    # parameter is left at that function's default.
    sim_kwargs = {
        "batch_size": batch_size,
        "T": T,
        "q0": q0,
        "lambda_qT": lambda_qT,
        "device": device,
    }

    for e in range(1, num_epochs + 1):
        optimizer.zero_grad()
        loss = simulate_batch_signed_volume(policy, **sim_kwargs)
        loss.backward()
        nn.utils.clip_grad_norm_(policy.parameters(), max_norm=1.0)
        optimizer.step()

        if e % 50 != 0:
            continue
        print(f"epoch {e}, loss = {loss.item():.6f}")

    return policy


def sample_single_test_path_signed(
    policy: nn.Module,
    J: int = 200,
    T: float = 1.0,
    S0: float = 50.0,
    q0: float = 400000.0,
    V_const: float = 4_000_000.0,
    sigma: float = 0.45,
    eta: float = 0.12,
    phi: float = 0.63,
    k: float = 5e-7,
    gamma: float = 3e-6,
    device: str = "cpu",
):
    """Roll out a single path under ``policy`` and return its trajectories.

    The rollout runs without gradients and uses ``J`` steps of size
    ``dt = T / J``. At each step the policy sees the normalised state
    ``[S / S0, V / V_const, Q / (V_const * T), q / max(|q0|, 1)]`` and its
    output ``v`` is applied as the trading speed over that step.

    Side effect: ``policy`` is switched to evaluation mode.

    Args:
        policy: Module mapping a ``(1, 4)`` state to the trading speed; its
            output must be a one-element tensor.
        J: Number of time steps.
        T: Time horizon.
        S0: Initial price.
        q0: Initial inventory.
        V_const: Constant market volume rate.
        sigma: Volatility multiplying the Brownian increment of the price.
        eta: Unused here; kept for signature compatibility (the cash account
            is not part of the returned trajectories).
        phi: Unused here; kept for signature compatibility.
        k: Coefficient of the price drift ``-k * v * dt``.
        gamma: Unused here; kept for signature compatibility.
        device: Device on which the rollout tensors are created.

    Returns:
        A 7-tuple of NumPy arrays
        ``(t_grid, t_mid, S_hist, V_hist, Q_hist, q_hist, v_hist)``:
        ``t_grid`` holds the ``J + 1`` step boundaries and indexes ``S_hist``,
        ``Q_hist`` and ``q_hist``; ``t_mid`` holds the ``J`` interval
        midpoints and indexes the per-step quantities ``V_hist`` and
        ``v_hist``.
    """
    policy.eval()
    dt = T / J
    sqrt_dt = math.sqrt(dt)

    S = torch.full((1, 1), S0, device=device)
    q = torch.full((1, 1), q0, device=device)
    Q = torch.zeros((1, 1), device=device)
    eps_S = torch.randn(1, J, device=device)

    # Normalisation constants for the policy inputs.
    S_ref = S0
    V_ref = V_const
    Q_ref = V_const * T
    q_ref = max(abs(q0), 1.0)

    # Boundary quantities have J + 1 samples, per-step quantities have J.
    S_hist = torch.zeros(J + 1, device=device)
    V_hist = torch.zeros(J, device=device)
    Q_hist = torch.zeros(J + 1, device=device)
    q_hist = torch.zeros(J + 1, device=device)
    v_hist = torch.zeros(J, device=device)
    S_hist[0] = S.item()
    Q_hist[0] = Q.item()
    q_hist[0] = q.item()

    # The market volume rate is constant, so build it once.
    V_t = torch.full((1, 1), V_const, device=device)
    V_norm = V_t / V_ref

    with torch.no_grad():
        for j in range(J):
            V_hist[j] = V_t.item()

            # Cumulative market volume is advanced before the policy sees it.
            Q = Q + V_t * dt
            Q_hist[j + 1] = Q.item()

            state = torch.cat([S / S_ref, V_norm, Q / Q_ref, q / q_ref], dim=1)
            v = policy(state)
            v_hist[j] = v.item()

            q = q - v * dt
            q_hist[j + 1] = q.item()

            # Price: Brownian increment plus a drift proportional to v.
            dW = eps_S[:, j:j+1] * sqrt_dt
            dS = sigma * dW - k * v * dt
            S = S + dS
            S_hist[j + 1] = S.item()

    t_grid = torch.linspace(0.0, T, J + 1).cpu().numpy()
    # Midpoints of the J intervals. linspace(0, T, J) would have spacing
    # T / (J - 1) and would not line up with the steps v_hist refers to.
    step_index = torch.arange(J, dtype=torch.get_default_dtype())
    t_mid = ((step_index + 0.5) * dt).cpu().numpy()

    return (
        t_grid,
        t_mid,
        S_hist.cpu().numpy(),
        V_hist.cpu().numpy(),
        Q_hist.cpu().numpy(),
        q_hist.cpu().numpy(),
        v_hist.cpu().numpy(),
    )


# --- Train the policy -------------------------------------------------------
print("Training start...")
policy = train_policy_signed_volume(
    num_epochs=300,
    batch_size=256,
    lr=1e-3,
    lambda_qT=1e-8,
    q0=400000.0,
    T=1.0,
    device=device,
)
print("Training done.")

# --- Roll out one evaluation path -------------------------------------------
t_grid, t_mid, S_p, V_p, Q_p, q_p, v_p = sample_single_test_path_signed(
    policy, J=200, T=1.0, device=device
)

# --- Plot price, inventory and trading speed, one figure each ---------------
# Each entry is (x values, y values, y-axis label, figure title).
panels = (
    (t_grid, S_p, "S_t", "Single path: S_t"),
    (t_grid, q_p, "q_t", "Single path: inventory q_t (signed)"),
    (t_mid, v_p, "v_t", "Single path: trading speed v_t (signed)"),
)
for xs, ys, y_label, title in panels:
    plt.figure(figsize=(10, 4))
    plt.plot(xs, ys)
    plt.xlabel("t")
    plt.ylabel(y_label)
    plt.title(title)
    plt.grid(True)

plt.show()
