In [None]:
%matplotlib inline

import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import torch as pt
import time
import sys

sys.path.insert(0, '..')

from NNsolver import NNSolver
from NNarchitectures import DenseNet_g
from problems import HJB
from utilities import get_X_process, plot_NN_evaluation

%load_ext autoreload
%autoreload 2

# Use the GPU when one is available; otherwise fall back to the CPU so the notebook still
# runs (slowly) instead of failing at the first `.to(device)` call.
device = pt.device('cuda' if pt.cuda.is_available() else 'cpu')

In [None]:
# Seed the global torch / numpy generators so network initialisation and any sampling that
# relies on them are reproducible under Restart & Run All.
SEED = 0
pt.manual_seed(SEED)
np.random.seed(SEED)

T = 1.0
problem = HJB(d=10, T=T)
K = 2000
K_batch = 2000
print_every = 500
delta_t = 0.01
sq_delta_t = pt.sqrt(pt.tensor(delta_t))
N = int(np.ceil(T / delta_t))  # number of time steps on [0, T]
# N entries (one per time step); the last entry gets a larger budget (40000 instead of 8000).
gradient_steps = (N - 1) * [8000] + [40000]
# Kept in the same "(N - 1) entries + last entry" shape as gradient_steps so the last step's
# learning rate can be tuned independently (currently identical).
learning_rates = [0.0001] * (N - 1) + [0.0001]


model = NNSolver(problem, 'HJB', learning_rates=learning_rates, gradient_steps=gradient_steps, NN_class=DenseNet_g, K=K,
                 K_batch=K_batch, delta_t=delta_t, print_every=print_every, method='implicit')

# Replace the solver's networks with DenseNet_g instances of a custom architecture on `device`.
# The list has N + 1 entries: one network per time step n = 0, ..., N - 1, followed by problem.g,
# which the trajectory plot below evaluates at the final grid point n = N.
model.Y_n = [DenseNet_g(problem.d, 1, lr=learning_rates[n], arch=[110, 110, 50, 50], problem=problem).to(device)
             for n in range(N)] + [problem.g]

In [None]:
# Run the solver's training loop. The per-step budgets are presumably taken from the
# `gradient_steps` / `learning_rates` passed to NNSolver above. This is the expensive cell.
model.train()

In [None]:
# Reference value handed to plot_NN_evaluation as `Y_0_true`.
# TODO: record where this number comes from (reference solver / Monte Carlo estimate).
Y_0_TRUE = 2.1589400569

fig = plot_NN_evaluation(
    model,
    n=N - 2,
    reference_solution=False,
    Y_0_true=Y_0_TRUE,
)

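### Compare with the TT solution along trajectories

The left panel shows the precomputed TT approximation against its stored reference values; the right panel shows the NN approximation against a Monte Carlo reference, both evaluated along simulated trajectories of $X$.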

In [None]:
# Simulate sample paths with the problem in numpy mode, then switch it back to torch mode for
# the network evaluations below. try/finally guarantees the switch back even if simulation fails,
# so a re-run does not silently leave `problem` in numpy mode.
problem.modus = 'np'
try:
    # Second argument presumably is the number of paths; fixed seed keeps the paths reproducible.
    X, xi = get_X_process(problem, 10, delta_t, seed=44)
finally:
    problem.modus = 'pt'

# Precomputed TT values and reference values along two trajectories, loaded from data/.
# NOTE(review): confirm these files were generated on the same paths as X[:, 0] and X[:, 1];
# the plot below draws both on shared axes.
TT_traj = [np.load(f'data/v_tt_traj{j}.npy') for j in (1, 2)]
TT_ref = [np.load(f'data/v_ref_traj{j}.npy') for j in (1, 2)]

In [None]:
COLORS = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple', 'tab:brown',
          'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan']

# Time grid t_n = n * delta_t, n = 0, ..., N. Defined here so the cell runs on a fresh kernel
# (previously `t_val` only existed as leftover kernel state).
t_val = np.arange(N + 1) * delta_t

# Monte Carlo sample size per grid point and a seeded generator, so the reference curves are
# reproducible across runs.
n_mc_samples = 1000
mc_rng = np.random.default_rng(0)

SAVE_FIGURE = False  # set True to write the figure to img/

fig, ax = plt.subplots(1, 2, figsize=(6, 3))
fig.suptitle(r'Evaluation along trajectories')
ax[0].set_xlabel(r'$t$')
ax[1].set_xlabel(r'$t$')
ax[0].set_title('TTs')
ax[1].set_title('NNs')

# Left panel: precomputed TT approximation vs. stored reference for both trajectories.
# Only the first trajectory is labelled, so the legend lists each line style once.
for i in range(2):
    is_first = i == 0
    ax[0].plot(t_val, TT_traj[i], label=r'$\widehat{V}(X_t, t)$' if is_first else None,
               color=COLORS[0], linewidth=1.2)
    ax[0].plot(t_val, TT_ref[i], '--', label=r'$V_{\mathrm{ref}}(X_t, t)$' if is_first else None,
               color=COLORS[1], linewidth=1.2)

ax[0].set_ylim(1.55, 2.8)
ax[0].legend()

# Right panel: NN approximation along paths k = 0, 1 of X (indexed X[n, k, :]) vs. a Monte Carlo
# reference  -log mean(1 / (0.5 + 0.5 |X_T|^2))  with  X_T = X_t + sqrt(2 (T - t)) * Z,  Z ~ N(0, I_d).
# Evaluation only, so no autograd graph is needed.
with pt.no_grad():
    for k in [0, 1]:
        y_ref_traj = []
        y_nn_traj = []
        for n in range(N + 1):
            x_n = X[n, k, :]
            noise = mc_rng.standard_normal((n_mc_samples, problem.d))
            X_T_t = x_n[np.newaxis, :] + np.sqrt(2 * (problem.T - n * delta_t)) * noise
            y_ref_traj.append(-np.log(np.mean(1 / (0.5 + 0.5 * np.sum(X_T_t**2, 1)))))
            # model.Y_n[N] is problem.g; the earlier entries are the trained networks.
            x_in = pt.tensor(x_n).unsqueeze(0).float().to(device)
            y_nn_traj.append(model.Y_n[n](x_in).item())
        ax[1].plot(t_val, y_nn_traj, color=COLORS[0], linewidth=1.2)
        ax[1].plot(t_val, y_ref_traj, '--', color=COLORS[1], linewidth=1.2)

ax[1].set_ylim(1.55, 2.8)

fig.tight_layout(rect=[0, 0.03, 1, 0.95])
if SAVE_FIGURE:
    fig.savefig('img/HJB_100_trajectory_plots_d_10.pdf')