This code comes from: https://github.com/hanbingyan/FVIOT/tree/main

In [1]:
import torch
import numpy as np
import random

# Seed every random number generator used below (Python, NumPy, PyTorch) with
# one shared value so that a fresh run draws the same random numbers.
SEED = 12345
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Config
MEM_SIZE = 3000    # capacity handed to Memory(...)
BATCH_SIZE = 128   # number of transitions sampled per gradient step
DISCOUNT = 1.0     # multiplier applied to the continuation value
N_INSTANCE = 10    # number of independent repetitions of the experiment

Utility functions for more compact code and for the optimisation of the neural networks.

In [2]:
from collections import namedtuple
import torch.nn as nn

# def sinkhorn_knopp(mu, nu, C, reg, niter):
#     K = np.exp(-C/C.max()/reg)
#     u = np.ones((len(mu), ))
#     for i in range(1, niter):
#         v = nu/np.dot(K.T, u)
#         u = mu/(np.dot(K, v))
#     Pi = np.diag(u) @ K @ np.diag(v)
#     return Pi


# One stored regression sample: a time index, the two states and the target value.
Transition = namedtuple('Transition', ('time', 'x', 'y', 'value'))


class Memory(object):
    """Fixed-capacity ring buffer of ``Transition`` records.

    Records are appended until ``capacity`` entries are stored; after that
    the write cursor wraps around and each new record replaces the entry at
    the cursor.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def __len__(self):
        """Number of records currently stored."""
        return len(self.memory)

    def push(self, *args):
        """Saves a transition built from ``args`` (time, x, y, value)."""
        buffer = self.memory
        if len(buffer) < self.capacity:
            # Still growing: reserve one more slot for the cursor to write into.
            buffer.append(None)
        cursor = self.position
        buffer[cursor] = Transition(*args)
        self.position = (cursor + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored records drawn with ``random.sample``."""
        return random.sample(self.memory, batch_size)

    def clear(self):
        """Drop every stored record and rewind the write cursor."""
        del self.memory[:]
        self.position = 0


def optimize_model(policy_net, memory, optimizer, Trunc_flag, batch_size=None):
    """Run one gradient step fitting ``policy_net`` to targets held in ``memory``.

    Args:
        policy_net: module called as ``policy_net(time, x, y)``; its output is
            compared against the stored ``value`` targets.
        memory: container supporting ``len(memory)`` and
            ``memory.sample(batch_size)``; every sampled item must expose
            tensor attributes ``time``, ``x``, ``y`` and ``value``.
        optimizer: optimizer over the parameters of ``policy_net``.
        Trunc_flag: when true, every gradient entry is clamped to ``[-1, 1]``
            before the optimizer step.
        batch_size: number of transitions drawn for this step. ``None``
            (the default) uses the module-level ``BATCH_SIZE``.

    Returns:
        The scalar SmoothL1 loss tensor of this step, or ``None`` when
        ``memory`` holds fewer than ``batch_size`` transitions (no step is
        taken in that case).
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    if len(memory) < batch_size:
        return None

    # Turn the list of sampled records into one stacked tensor per field.
    transitions = memory.sample(batch_size)
    time_batch = torch.stack([tr.time for tr in transitions])
    x_batch = torch.stack([tr.x for tr in transitions])
    y_batch = torch.stack([tr.y for tr in transitions])
    values_batch = torch.stack([tr.value for tr in transitions])

    left_values = policy_net(time_batch, x_batch, y_batch)

    Loss_fn = nn.SmoothL1Loss()
    # Loss_fn = nn.MSELoss()
    loss = Loss_fn(left_values, values_batch)

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    if Trunc_flag:
        # Element-wise clamp of every gradient to [-1, 1]. Unlike touching
        # ``param.grad`` directly, this skips parameters that received no
        # gradient (``grad is None``) instead of raising.
        nn.utils.clip_grad_value_(policy_net.parameters(), 1.0)
    optimizer.step()
    return loss

The neural network at hand.

In [3]:
import torch.nn as nn
import torch.nn.functional as F
import nbimporter

# Width of the two hidden layers of DQN.
h = 8


class DQN(nn.Module):
    """Value network evaluated as ``net(time, x, y)``.

    A two-hidden-layer ReLU MLP maps the concatenated ``(x, y)`` to a single
    sigmoid output. That output is multiplied by one learned linear function
    of the features ``(T - time, (T - time)**2)`` and shifted by a second
    learned linear function of the same features.
    """

    def __init__(self, x_dim, y_dim, T):
        super().__init__()
        self.T = T
        # Layers are created in the same order as before so that the
        # default (seeded) parameter initialisation is unchanged.
        self.linear1 = nn.Linear(x_dim + y_dim, h)
        self.linear2 = nn.Linear(h, h)
        self.linear3 = nn.Linear(h, 1)
        # Two independent linear maps acting on the time features.
        self.linear5 = nn.Linear(2, 1)
        self.linear6 = nn.Linear(2, 1)

    def forward(self, time, x, y):
        """Return the network value for each row of ``(time, x, y)``."""
        hidden = torch.cat((x, y), dim=1)
        for layer in (self.linear1, self.linear2):
            hidden = torch.relu(layer(hidden))
        gate = torch.sigmoid(self.linear3(hidden))

        # Time enters only through the remaining horizon and its square.
        remaining = self.T - time
        time_features = torch.cat((remaining, remaining ** 2), dim=1)
        scale = self.linear5(time_features)
        shift = self.linear6(time_features)
        return gate * scale + shift

The code to compute the Adapted Wasserstein distance (AW_2) between two Brownian paths.

In [4]:
import torch.optim as optim
from torch.distributions.multivariate_normal import MultivariateNormal
import ot
import time as Clock

start = Clock.time()

####### One-dimensional case #########
# with parameter constraint
Trunc_flag = True
# No. of gradient descent steps (G)
N_OPT = 50
# No. of sample paths (N)
smp_size = 2000
# Sample size for empirical OT (B)
in_sample_size = 50

time_horizon = 4
x_dim = 1
y_dim = 1
x_vol = 1.0
y_vol = 0.5
x_init = 1.0
y_init = 2.0


###### Multidimensional case #########
## no parameter constraint
# Trunc_flag = False
# time_horizon = 5
# x_dim = 5
# y_dim = 5
# x_vol = 1.1
# y_vol = 0.1
# x_init = 1.0
# y_init = 2.0
# N_OPT = 400
# smp_size = 4000
# in_sample_size = 300


final_result = np.zeros(N_INSTANCE)

# Loop-invariant inputs: one-step covariance matrices and the uniform
# marginals of the empirical OT problem (B atoms on each side).
x_cov = torch.eye(x_dim, device=device) * x_vol**2
y_cov = torch.eye(y_dim, device=device) * y_vol**2
x_weights = np.ones(in_sample_size) / in_sample_size
y_weights = np.ones(in_sample_size) / in_sample_size

for n_ins in range(N_INSTANCE):

    val_hist = np.zeros(time_horizon+1)
    loss_hist = np.zeros(time_horizon+1)

    memory = Memory(MEM_SIZE)
    policy_net = DQN(x_dim, y_dim, time_horizon).to(device)
    target_net = DQN(x_dim, y_dim, time_horizon).to(device)
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()
    # optimizer = optim.SGD(policy_net.parameters(), lr=0.1, momentum=0.9)
    optimizer = optim.Adam(policy_net.parameters(), lr=1e-2) # weight_decay=1e-3)

    x_path_pool = torch.zeros(smp_size, time_horizon+1, x_dim, device=device)
    y_path_pool = torch.zeros(smp_size, time_horizon+1, y_dim, device=device)
    x_path_pool[:, 0, :] = x_init
    y_path_pool[:, 0, :] = y_init

    for smp_id in range(smp_size):
        # sample many paths in advance (Gaussian random walks with step std x_vol / y_vol)
        for t in range(1, time_horizon + 1):
            x_path_pool[smp_id, t, :] = x_path_pool[smp_id, t - 1, :] + x_vol * torch.randn(x_dim, device=device)
            y_path_pool[smp_id, t, :] = y_path_pool[smp_id, t - 1, :] + y_vol * torch.randn(y_dim, device=device)

    # Backward induction: fit the network at `time` using targets that rely on
    # target_net evaluated at `time + 1`.
    for time in range(time_horizon, -1, -1):

        for smp_id in range(smp_size):
            if time == time_horizon:
                # Terminal time: nothing left to transport, the target is zero.
                expected_v = 0.0
            else:
                x_mvn = MultivariateNormal(loc=x_path_pool[smp_id, time, :], covariance_matrix=x_cov)
                y_mvn = MultivariateNormal(loc=y_path_pool[smp_id, time, :], covariance_matrix=y_cov)
                # Draw exactly B = in_sample_size next states per marginal. The
                # pairwise cost below is reshaped to (B, B), so drawing any other
                # number of points makes that reshape fail.
                next_x = x_mvn.sample((in_sample_size,))
                next_y = y_mvn.sample((in_sample_size,))

                # Flattened B*B grid of pairs: entry k pairs next_x[k // B] with next_y[k % B].
                x_batch = torch.repeat_interleave(next_x, repeats=in_sample_size, dim=0)
                y_batch = torch.tile(next_y, (in_sample_size, 1))
                min_obj = torch.sum((x_batch - y_batch)**2, dim=1)

                if time < time_horizon - 1:
                    # Add the continuation value at time + 1. No gradient is
                    # needed here: the result is only used as a regression target.
                    with torch.no_grad():
                        val = target_net(torch.full((x_batch.shape[0], 1), time + 1.0, device=device),
                                         x_batch, y_batch).reshape(-1)
                    min_obj = min_obj + DISCOUNT*val

                min_obj = min_obj.reshape(in_sample_size, in_sample_size)
                expected_v = ot.emd2(x_weights, y_weights, min_obj.detach().cpu().numpy())

            # dtype is set explicitly so targets are float32 whatever scalar type emd2 returns.
            memory.push(torch.tensor([time], dtype=torch.float32, device=device), x_path_pool[smp_id, time, :],
                        y_path_pool[smp_id, time, :],
                        torch.tensor([float(expected_v)], dtype=torch.float32, device=device))

        # Optimize at time t
        for opt_step in range(N_OPT):
            loss = optimize_model(policy_net, memory, optimizer, Trunc_flag)
            if Trunc_flag:
                with torch.no_grad():
                    for param in policy_net.parameters():
                        ## param.add_(torch.randn(param.size(), device=device)/50)
                        param.clamp_(-1.0, 1.0)
            # optimize_model returns None when no step was taken; test for that
            # explicitly instead of relying on the truth value of a tensor.
            if loss is not None:
                loss_hist[time] += loss.item()


        loss_hist[time] /= N_OPT

        # update target network
        target_net.load_state_dict(policy_net.state_dict())
        # test initial value
        with torch.no_grad():
            val = target_net(torch.zeros(1, 1, device=device), x_path_pool[0, 0, :].reshape(1, x_dim),
                             y_path_pool[0, 0, :].reshape(1, y_dim)).reshape(-1)
        val_hist[time] = val.item()

        # empty memory
        memory.clear()
        print('Time step', time, 'Loss', loss_hist[time])

        # print('Shift vector in the last layer:', target_net.linear3.bias.sum().item())


    # for name, param in target_net.named_parameters():
    #     if param.requires_grad:
    #         print(name, param.data)


    print('Instance', n_ins)
    # print('Time elapsed', end - start)
    print('Last values', val_hist[0])
    final_result[n_ins] = val_hist[0]

print('All final value:', final_result)
print('Final mean:', final_result.mean())
print('Final std:', final_result.std())
end = Clock.time()
print('Average time for one instance:', (end-start)/N_INSTANCE)
# plt.figure(figsize=(8, 6))
# plt.plot(val_hist)
# plt.xlabel('Steps', fontsize=16)
# plt.ylabel(r'$V_0$', fontsize=16)
# # plt.tick_params(axis = 'both', which = 'major', labelsize = 16)
# plt.legend(bbox_to_anchor=(1, 1), title='', fontsize=16, title_fontsize=16)
# plt.savefig('conti_val.pdf', format='pdf', dpi=1000, bbox_inches='tight', pad_inches=0.1)
# plt.show()


# plt.figure(figsize=(8, 6))
# plt.plot(loss_hist)
# plt.xlabel('Steps', fontsize=16)
# plt.ylabel('Loss', fontsize=16)
# plt.savefig('conti_loss.pdf', format='pdf', dpi=1000, bbox_inches='tight', pad_inches=0.1)
# plt.show()

Time step 4 Loss 0.00026412227157379675
Time step 3 Loss 4.137784266471863
Time step 2 Loss 1.8712179160118103
Time step 1 Loss 0.6323760211467743
Time step 0 Loss 0.3541708070039749
Instance 0
Last values 6.563196182250977
Time step 4 Loss 0.0007923892914686804
Time step 3 Loss 3.531408405303955
Time step 2 Loss 1.8780092191696167
Time step 1 Loss 0.7234289163351059
Time step 0 Loss 0.4100919580459595
Instance 1
Last values 7.604437351226807
Time step 4 Loss 0.007531039901805343
Time step 3 Loss 3.5540512704849245
Time step 2 Loss 2.5861518812179565
Time step 1 Loss 1.1958688259124757
Time step 0 Loss 0.40108677059412
Instance 2
Last values 7.97860860824585
Time step 4 Loss 0.007473419991874834
Time step 3 Loss 3.7993326807022094
Time step 2 Loss 1.4090513098239899
Time step 1 Loss 0.630340029001236
Time step 0 Loss 0.3933335566520691
Instance 3
Last values 7.170917510986328
Time step 4 Loss 0.16258780620992183
Time step 3 Loss 4.519766855239868
Time step 2 Loss 2.5808471608161927
Tim

## SUPER LONG TO RUN DISCRETE CDE USING KMEANS

In [6]:
from sklearn.cluster import KMeans
import numpy as np

import torch.optim as optim
from torch.distributions.multivariate_normal import MultivariateNormal
import ot
import time as Clock

start = Clock.time()

####### One-dimensional case #########
# with parameter constraint
Trunc_flag = True
# No. of gradient descent steps (G)
N_OPT = 50
# No. of sample paths (N)
smp_size = 2000
# No. of K-means clusters, i.e. atoms of each discretised marginal in the OT problem (B)
in_sample_size = 50
# No. of next-state draws per marginal that are fed to K-means
num_samples = 150

time_horizon = 4
x_dim = 1
y_dim = 1
x_vol = 1.0
y_vol = 0.5
x_init = 1.0
y_init = 2.0


final_result = np.zeros(N_INSTANCE)

# Loop-invariant one-step covariance matrices.
x_cov = torch.eye(x_dim, device=device) * x_vol**2
y_cov = torch.eye(y_dim, device=device) * y_vol**2

for n_ins in range(N_INSTANCE):

    val_hist = np.zeros(time_horizon+1)
    loss_hist = np.zeros(time_horizon+1)

    memory = Memory(MEM_SIZE)
    policy_net = DQN(x_dim, y_dim, time_horizon).to(device)
    target_net = DQN(x_dim, y_dim, time_horizon).to(device)
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()
    # optimizer = optim.SGD(policy_net.parameters(), lr=0.1, momentum=0.9)
    optimizer = optim.Adam(policy_net.parameters(), lr=1e-2) # weight_decay=1e-3)

    x_path_pool = torch.zeros(smp_size, time_horizon+1, x_dim, device=device)
    y_path_pool = torch.zeros(smp_size, time_horizon+1, y_dim, device=device)
    x_path_pool[:, 0, :] = x_init
    y_path_pool[:, 0, :] = y_init

    for smp_id in range(smp_size):
        # sample many paths in advance (Gaussian random walks with step std x_vol / y_vol)
        for t in range(1, time_horizon + 1):
            x_path_pool[smp_id, t, :] = x_path_pool[smp_id, t - 1, :] + x_vol * torch.randn(x_dim, device=device)
            y_path_pool[smp_id, t, :] = y_path_pool[smp_id, t - 1, :] + y_vol * torch.randn(y_dim, device=device)

    # Backward induction: fit the network at `time` using targets that rely on
    # target_net evaluated at `time + 1`.
    for time in range(time_horizon, -1, -1):

        for smp_id in range(smp_size):

            if time == time_horizon:
                # Terminal time: nothing left to transport, the target is zero.
                expected_v = 0.0
            else:
                if smp_id % 100 ==0:
                    print(smp_id)
                x_mvn = MultivariateNormal(loc=x_path_pool[smp_id, time, :], covariance_matrix=x_cov)
                y_mvn = MultivariateNormal(loc=y_path_pool[smp_id, time, :], covariance_matrix=y_cov)
                # Draw num_samples next states for x and y from the respective Gaussian distributions
                sampled_x = x_mvn.sample((num_samples,))  # shape: (num_samples, x_dim)
                sampled_y = y_mvn.sample((num_samples,))  # shape: (num_samples, y_dim)

                # Transfer to CPU and convert to numpy arrays (needed for sklearn's KMeans)
                sampled_x_np = sampled_x.detach().cpu().numpy()
                sampled_y_np = sampled_y.detach().cpu().numpy()

                # Cluster the x samples
                kmeans_x = KMeans(n_clusters=in_sample_size, n_init=10).fit(sampled_x_np)
                centers_x = torch.tensor(kmeans_x.cluster_centers_, device=device, dtype=sampled_x.dtype)
                # Cluster weights = fraction of samples per cluster (nonuniform probabilities).
                # Kept in float64 so both marginals sum to one as accurately as possible.
                weights_x_np = np.bincount(kmeans_x.labels_, minlength=in_sample_size).astype(np.float64) / num_samples

                # Cluster the y samples
                kmeans_y = KMeans(n_clusters=in_sample_size, n_init=10).fit(sampled_y_np)
                centers_y = torch.tensor(kmeans_y.cluster_centers_, device=device, dtype=sampled_y.dtype)
                weights_y_np = np.bincount(kmeans_y.labels_, minlength=in_sample_size).astype(np.float64) / num_samples

                # Squared Euclidean cost between cluster centers: entry (i, j) pairs
                # centers_x[i] with centers_y[j].
                cost_matrix = torch.cdist(centers_x, centers_y, p=2)**2

                if time < time_horizon - 1:
                    # For intermediate times, incorporate the continuation value from the target network.
                    # Build the (i, j) grid of center pairs in the same layout as cost_matrix.
                    X_grid = centers_x.unsqueeze(1).expand(in_sample_size, in_sample_size, x_dim)
                    Y_grid = centers_y.unsqueeze(0).expand(in_sample_size, in_sample_size, y_dim)

                    # Reshape for batch evaluation
                    x_input = X_grid.reshape(-1, x_dim)
                    y_input = Y_grid.reshape(-1, y_dim)
                    time_input = torch.full((in_sample_size * in_sample_size, 1), time + 1.0, device=device)
                    # No gradient is needed: the result is only used as a regression target.
                    with torch.no_grad():
                        val = target_net(time_input, x_input, y_input).reshape(in_sample_size, in_sample_size)

                    # Add the discounted value function to the cost matrix
                    cost_matrix = cost_matrix + DISCOUNT * val

                # With time == time_horizon - 1 the cost is the distance term only.
                expected_v = ot.emd2(weights_x_np, weights_y_np, cost_matrix.detach().cpu().numpy())


            # dtype is set explicitly so targets are float32 whatever scalar type emd2 returns.
            memory.push(torch.tensor([time], dtype=torch.float32, device=device), x_path_pool[smp_id, time, :],
                        y_path_pool[smp_id, time, :],
                        torch.tensor([float(expected_v)], dtype=torch.float32, device=device))

        # Optimize at time t
        for opt_step in range(N_OPT):
            loss = optimize_model(policy_net, memory, optimizer, Trunc_flag)
            if Trunc_flag:
                with torch.no_grad():
                    for param in policy_net.parameters():
                        ## param.add_(torch.randn(param.size(), device=device)/50)
                        param.clamp_(-1.0, 1.0)
            # optimize_model returns None when no step was taken; test for that
            # explicitly instead of relying on the truth value of a tensor.
            if loss is not None:
                loss_hist[time] += loss.item()


        loss_hist[time] /= N_OPT

        # update target network
        target_net.load_state_dict(policy_net.state_dict())
        # test initial value
        with torch.no_grad():
            val = target_net(torch.zeros(1, 1, device=device), x_path_pool[0, 0, :].reshape(1, x_dim),
                             y_path_pool[0, 0, :].reshape(1, y_dim)).reshape(-1)
        val_hist[time] = val.item()

        # empty memory
        memory.clear()
        print('Time step', time, 'Loss', loss_hist[time])

        # print('Shift vector in the last layer:', target_net.linear3.bias.sum().item())


    # for name, param in target_net.named_parameters():
    #     if param.requires_grad:
    #         print(name, param.data)


    print('Instance', n_ins)
    # print('Time elapsed', end - start)
    print('Last values', val_hist[0])
    final_result[n_ins] = val_hist[0]

print('All final value:', final_result)
print('Final mean:', final_result.mean())
print('Final std:', final_result.std())
end = Clock.time()
print('Average time for one instance:', (end-start)/N_INSTANCE)

Time step 4 Loss 3.4736964527724014e-05
0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
Time step 3 Loss 3.8936922359466553
0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
Time step 2 Loss 2.537625653743744
0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
Time step 1 Loss 1.4277370071411133
0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
Time step 0 Loss 0.048351617008447645
Instance 0
Last values 7.248353004455566
Time step 4 Loss 0.03434914529090747
0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
Time step 3 Loss 4.1678907012939455
0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
Time step 2 Loss 2.3007030630111696
0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
Time step 1 Loss 0.6179903799295425
0
100
200
300