In [None]:
# ===========================================
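#  VWAP optimal execution × NN × q_T=0 target (enforced as a soft quadratic penalty)
#  Concise version in which the NN outputs the trading rate v_t directly (experiment parameters are in the section near the top)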
# ===========================================

import math
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("device =", device)


# ===== Experiment parameters (edit only this section) =====
# --- market / order ---
S0 = 50.0           # initial price
q0 = 4.0e5          # initial inventory; trading at rate v reduces it by v*dt
V = 4.0e6           # market volume rate (held constant over the horizon)
sigma = 0.45        # price volatility; S moves by sigma*dW (arithmetic, not log)
eta = 0.12          # temporary-cost scale: cost rate V * eta * |v/V|**(1+phi)
phi = 0.63          # temporary-cost exponent
k = 5e-7            # permanent impact: price drifts by -k*v*dt each step
gamma = 3e-6        # risk aversion in the exp(-gamma * pnl) loss
lambda_qT = 1e-8    # weight of the quadratic penalty on terminal inventory q_T

# --- discretisation / training ---
T = 1.0             # trading horizon
J = 200             # number of time steps (dt = T / J)
num_epochs = 300    # number of optimiser steps
batch_size = 256    # Monte Carlo paths per optimiser step
hidden_dim = 128    # width of each hidden layer of the policy MLP


# ===== Policy network (PolicyNet) =====
class PolicyNet(nn.Module):
    """MLP mapping the 4-feature normalised state to a trading rate v_t.

    The final ``tanh`` bounds the output to ``[-v_max, v_max]`` with
    ``v_max = 2 * q0 / T``, i.e. twice the constant rate that would
    liquidate ``q0`` exactly over the horizon.
    """

    def __init__(self):
        super().__init__()
        self.v_max = 2.0 * q0 / T
        # Layers are created in the same order as a single literal
        # Sequential would create them, so initialisation is unchanged.
        layers = [nn.Linear(4, hidden_dim), nn.ReLU()]
        layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]
        layers += [nn.Linear(hidden_dim, 1), nn.Tanh()]
        self.net = nn.Sequential(*layers)

    def forward(self, state):
        """Return the rate v_t, shape (batch, 1); ``state`` is expected as (batch, 4)."""
        bounded = self.net(state)  # tanh output, in [-1, 1]
        return bounded * self.v_max


# ===== Monte Carlo simulation (training) =====
def simulate_batch(policy):
    """Simulate ``batch_size`` execution paths under ``policy`` and return the loss.

    Each path starts at price ``S0`` with inventory ``q0``. At each of the
    ``J`` steps the policy chooses a rate ``v``, which:

    - reduces inventory by ``v*dt``;
    - earns ``v*S*dt`` minus the temporary cost ``V * eta * |v/V|**(1+phi) * dt``;
    - moves the price by ``sigma*dW - k*v*dt``.

    The loss is the batch mean of ``exp(-gamma * pnl) + lambda_qT * q_T**2``,
    where ``pnl = X_T - q0 * VWAP_T`` is clamped to +/-1e6.

    Returns:
        Scalar tensor, differentiable w.r.t. the policy parameters.
    """
    dt = T / J
    sqrt_dt = math.sqrt(dt)
    S = torch.full((batch_size, 1), S0, device=device)
    q = torch.full((batch_size, 1), q0, device=device)
    X = torch.zeros((batch_size, 1), device=device)
    Q = torch.zeros((batch_size, 1), device=device)
    vw_num = torch.zeros((batch_size, 1), device=device)
    vw_den = torch.zeros((batch_size, 1), device=device)
    eps = torch.randn(batch_size, J, device=device)
    # Normalisation constants so every state feature is O(1).
    S_ref, V_ref, Q_ref = S0, V, V * T
    q_ref = max(abs(q0), 1.0)
    policy.train()
    for j in range(J):
        V_t = torch.full_like(S, V)
        state = torch.cat([S / S_ref, V_t / V_ref, Q / Q_ref, q / q_ref], dim=1)
        v = policy(state)
        # Sample the VWAP benchmark at the same pre-move price S_j at which
        # this step's trade executes. Accumulating after the price update
        # (S_{j+1}) would leave even a perfect VWAP tracker with a residual
        # of q0/J * (S_0 - S_J) against its own benchmark.
        vw_num = vw_num + S * V_t * dt
        vw_den = vw_den + V_t * dt
        rho = v / V_t  # participation rate
        cost = eta * torch.abs(rho) ** (1 + phi)
        X = X + (v * S - V_t * cost) * dt
        dW = eps[:, j:j + 1] * sqrt_dt
        S = S + sigma * dW - k * v * dt
        Q = Q + V_t * dt
        q = q - v * dt
    VWAP_T = vw_num / vw_den
    pnl = torch.clamp(X - q0 * VWAP_T, -1e6, 1e6)
    utility = torch.exp(-gamma * pnl)  # (batch, 1)
    # Keep q as (batch, 1). Squeezing it to (batch,) would broadcast against
    # the (batch, 1) utility into a (batch, batch) matrix.
    penalty = lambda_qT * q.pow(2)  # (batch, 1)
    loss = (utility + penalty).mean()
    return loss


# ===== Training =====
def _optimisation_step(policy, optimizer):
    """Run one clipped Adam update on a freshly simulated batch; return its loss."""
    optimizer.zero_grad()
    loss = simulate_batch(policy)
    loss.backward()
    nn.utils.clip_grad_norm_(policy.parameters(), 1.0)
    optimizer.step()
    return loss


def train():
    """Train a fresh PolicyNet with Adam (lr=1e-3) and return it.

    Performs ``num_epochs`` optimiser steps. Each step draws a new Monte
    Carlo batch via ``simulate_batch`` and clips the gradient norm to 1.0.
    The loss is printed every 50 epochs.
    """
    policy = PolicyNet().to(device)
    optimizer = optim.Adam(policy.parameters(), lr=1e-3)
    log_every = 50
    for epoch in range(1, num_epochs + 1):
        loss = _optimisation_step(policy, optimizer)
        if epoch % log_every == 0:
            print(f"epoch {epoch}, loss={loss.item():.6f}")
    return policy


# ===== Single test path (for visualisation) =====
def sample_single_path(policy):
    """Simulate one path under ``policy`` without gradients and record its history.

    Uses the same state features and price/inventory dynamics as
    ``simulate_batch``, with a batch of one.

    Returns:
        t_grid: tensor of the J+1 grid times 0, dt, ..., T (pairs with S and q).
        t_mid: tensor of the J interval midpoints t_j + dt/2 (pairs with v,
            which applies over [t_j, t_{j+1}]).
        S_hist, q_hist: numpy arrays of length J+1.
        v_hist: numpy array of length J.
    """
    dt = T / J
    sqrt_dt = math.sqrt(dt)
    S = torch.full((1, 1), S0, device=device)
    q = torch.full((1, 1), q0, device=device)
    Q = torch.zeros((1, 1), device=device)
    eps = torch.randn(1, J, device=device)
    S_ref, V_ref, Q_ref = S0, V, V * T
    q_ref = max(abs(q0), 1.0)
    S_hist = torch.zeros(J + 1)
    q_hist = torch.zeros(J + 1)
    v_hist = torch.zeros(J)
    S_hist[0] = S.item()
    q_hist[0] = q.item()
    policy.eval()
    with torch.no_grad():
        for j in range(J):
            V_t = torch.full((1, 1), V, device=device)
            state = torch.cat([S / S_ref, V_t / V_ref, Q / Q_ref, q / q_ref], dim=1)
            v = policy(state)
            v_hist[j] = v.item()
            q = q - v * dt
            q_hist[j + 1] = q.item()
            dW = eps[:, j:j + 1] * sqrt_dt
            S = S + sigma * dW - k * v * dt
            S_hist[j + 1] = S.item()
            Q = Q + V_t * dt
    t_grid = torch.linspace(0, T, J + 1)
    # v_hist[j] is the rate over [t_j, t_{j+1}], so plot it at that interval's
    # midpoint. linspace(0, T, J) would use spacing T/(J-1), which does not
    # match the dt grid.
    t_mid = t_grid[:-1] + dt / 2
    return t_grid, t_mid, S_hist.numpy(), q_hist.numpy(), v_hist.numpy()


# ===== Run =====
policy = train()
t_grid, t_mid, S_p, q_p, v_p = sample_single_path(policy)

# One figure per series: inventory on the full grid, rate on interval times.
for xs, ys, title in ((t_grid, q_p, "q_t"), (t_mid, v_p, "v_t")):
    plt.figure(figsize=(10, 4))
    plt.plot(xs, ys)
    plt.grid(True)
    plt.title(title)
plt.show()
