In [2]:
import os
import sys
import numpy as np
import time
import warnings
import random

# Define paths
# The current working directory is taken as the notebooks folder; the project
# sources are expected in the sibling "src" directory.
notebooks_path = os.path.abspath(os.getcwd())
src_path = os.path.abspath(os.path.join(notebooks_path, "../src"))

# Prepend src/ to the import search path once, so re-running this cell does
# not keep adding duplicate entries.
already_importable = src_path in sys.path
if not already_importable:
    sys.path.insert(0, src_path)

# Import modules
from multi_dimension.Multidimension_trees import *
from multi_dimension.Multidimension_solver import *
from multi_dimension.Multidimension_adapted_empirical_measure import *

from measure_sampling.Gen_Path_and_AdaptedTrees import generate_adapted_tree
from trees.Tree_Node import *
from trees.TreeAnalysis import *
from trees.TreeVisualization import *
from trees.Save_Load_trees import *
from trees.Tree_AWD_utilities import *
from trees.Build_trees_from_paths import build_tree_from_paths

from adapted_empirical_measure.AEM_grid import *
from adapted_empirical_measure.AEM_kMeans import *
from benchmark_value_gaussian.Comp_AWD2_Gaussian import *
# from awd_trees.Gurobi_AOT import *
from awd_trees.Nested_Dist_Algo import compute_nested_distance, compute_nested_distance_parallel

# Suppress warnings
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# Import custom modules from src
from utils_solver import Lmatrix2paths, adapted_empirical_measure, adapted_wasserstein_squared, quantization, nested, plot_V

In [3]:
# Reproducibility: seed NumPy's global RNG and the stdlib RNG with the same value.
verbose = False
for _seed_rng in (np.random.seed, random.seed):
    _seed_rng(0)

# Seed value passed on to the path generation further down. It is drawn from
# the freshly seeded NumPy RNG, so it is identical on every run.
random_seed = np.random.randint(100)
print("Random seed for this run:", random_seed)

Random seed for this run: 44


## Generating paths (same random seed for both measures)

In [4]:
n_sample = 200
normalize = False  # Not used explicitly here

# Lower-triangular matrices that define the two measures; they are handed to
# Lmatrix2paths together with the sample size and the shared seed.
L = np.array(
    [
        [1, 0, 0, 0],
        [1, 2, 0, 0],
        [1, 2, 3, 0],
        [1, 2, 3, 4],
    ]
)
M = np.array(
    [
        [1, 0, 0, 0],
        [2, 1, 0, 0],
        [3, 2, 1, 0],
        [4, 3, 2, 1],
    ]
)

# For measure "mu"
print("mu")
X, A = Lmatrix2paths(L, n_sample, seed=random_seed)

# For measure "nu" (same seed as mu, so both use the same randomness)
print("nu")
Y, B = Lmatrix2paths(M, n_sample, seed=random_seed)

mu
Cholesky:
[[1 0 0 0]
 [1 2 0 0]
 [1 2 3 0]
 [1 2 3 4]]
Covariance:
[[ 1  1  1  1]
 [ 1  5  5  5]
 [ 1  5 14 14]
 [ 1  5 14 30]]
nu
Cholesky:
[[1 0 0 0]
 [2 1 0 0]
 [3 2 1 0]
 [4 3 2 1]]
Covariance:
[[ 1  2  3  4]
 [ 2  5  8 11]
 [ 3  8 14 20]
 [ 4 11 20 30]]


## Real distance (not particularly relevant here, as we are comparing speed rather than convergence). 

### What matters here is that the three methods produce the same output, as they should each solve the discrete AOT problem exactly.

In [5]:
# Benchmark value, reported below as the theoretical AW_2^2.
# A and B are the second outputs of Lmatrix2paths for mu and nu respectively.
dist_bench = adapted_wasserstein_squared(A, B)
print("Theoretical AW_2^2: ", dist_bench)

Theoretical AW_2^2:  30.0


## With your code

In [8]:
# Project the sampled paths of both measures with delta_n = 0.1
# (original note: grid projection or k-means projection).
adaptedX, adaptedY = (
    adapted_empirical_measure(paths, delta_n=0.1) for paths in (X, Y)
)

# Quantize the two adapted measures together (non-Markovian setting).
(
    q2v, v2q,
    mu_x, nu_y,
    q2v_x, v2q_x,
    q2v_y, v2q_y,
) = quantization(adaptedX, adaptedY, markovian=False)

# Only the nested solve is timed; projection and quantization are excluded.
start_time = time.time()
AW_2square, V = nested(
    mu_x,
    nu_y,
    v2q_x,
    v2q_y,
    q2v,
    markovian=False,
)
elapsed_time_pot = time.time() - start_time

# Recompute the benchmark so this cell can report it next to the numerical value.
dist_bench = adapted_wasserstein_squared(A, B)
print("Theoretical AW_2^2: ", dist_bench)
print("Numerical AW_2^2: ", AW_2square)
print("Elapsed time (Adapted OT): {:.4f} seconds".format(elapsed_time_pot))

Quantization ......
Number of distinct values in global quantization:  197
Number of condition subpaths of mu_x
Time 0: 1
Time 1: 49
Time 2: 192
Time 3: 200
Number of condition subpaths of nu_y
Time 0: 1
Time 1: 49
Time 2: 187
Time 3: 200
Nested backward induction .......


Timestep 3: 100%|██████████| 200/200 [00:03<00:00, 54.19it/s]
Timestep 2: 100%|██████████| 192/192 [00:03<00:00, 52.77it/s]
Timestep 1: 100%|██████████| 49/49 [00:00<00:00, 166.29it/s]
Timestep 0: 100%|██████████| 1/1 [00:00<00:00, 982.27it/s]

Theoretical AW_2^2:  30.0
Numerical AW_2^2:  25.109290833333336
Elapsed time (Adapted OT): 7.6320 seconds





## With my code (non-parallel)

In [14]:
# Uniform adapted empirical grid measure (with weights) of the transposed mu paths.
# NOTE(review): the following cell repeats this same call; this cell looks
# like a leftover scratch cell — confirm whether it is still needed.
adapted_X1, adapted_weights_X = uniform_empirical_grid_measure(
    X.T,
    delta_n=0.1,
    use_weights=True,
)

In [15]:
# Compute uniform adapted empirical grid measures with weights.
# X and Y are transposed before being passed in. The mu result is bound to
# `adapted_X` because that is the name the tree construction below reads;
# binding it to any other name makes this cell fail with a NameError on a
# fresh kernel (Restart & Run All).
adapted_X, adapted_weights_X = uniform_empirical_grid_measure(X.T, delta_n=0.1, use_weights=True)
adapted_Y, adapted_weights_Y = uniform_empirical_grid_measure(Y.T, delta_n=0.1, use_weights=True)

# Build one tree per measure from the adapted paths and their weights
adapted_tree_1 = build_tree_from_paths(adapted_X, adapted_weights_X)
adapted_tree_2 = build_tree_from_paths(adapted_Y, adapted_weights_Y)

In [16]:
# Sequential nested (adapted optimal transport) distance between the two trees,
# with wall-clock timing of the solver call only.
max_depth = get_depth(adapted_tree_1)
solver_kwargs = dict(
    method="solver_lp_pot",
    return_matrix=False,
    lambda_reg=0,
    power=2,
)

start_time = time.time()
distance_pot = compute_nested_distance(adapted_tree_1, adapted_tree_2, max_depth, **solver_kwargs)
elapsed_time_pot = time.time() - start_time

print("Numerical AW_2^2 (Adapted OT):", distance_pot)
print("Elapsed time (Adapted OT): {:.4f} seconds".format(elapsed_time_pot))

Depth 3: 100%|██████████| 200/200 [00:01<00:00, 124.12it/s]
Depth 2: 100%|██████████| 192/192 [00:01<00:00, 134.65it/s]
Depth 1: 100%|██████████| 49/49 [00:00<00:00, 515.03it/s]
Depth 0: 100%|██████████| 1/1 [00:00<00:00, 1968.23it/s]

Numerical AW_2^2 (Adapted OT): 25.10929083333333
Elapsed time (Adapted OT): 3.1508 seconds





## With my code (parallel)

In [17]:
# Compute the nested distance with the parallel solver and time the call.
# The depth is taken from the first tree and passed as the third argument.
max_depth_val = get_depth(adapted_tree_1)
start_time = time.time()
distance_pot = compute_nested_distance_parallel(adapted_tree_1, adapted_tree_2, max_depth_val, return_matrix=False, power=2)
elapsed_time_pot = time.time() - start_time

# Labels previously misspelled "Parallel" as "Parellel".
print("Nested distance Parallel:", distance_pot)
print("Computation time Parallel: {:.4f} seconds".format(elapsed_time_pot))

Depth: 3


Parallel Depth 3: 100%|██████████| 12/12 [00:03<00:00,  3.95it/s]


Depth: 2


Parallel Depth 2: 100%|██████████| 12/12 [00:02<00:00,  4.47it/s]


Depth: 1


Parallel Depth 1: 100%|██████████| 12/12 [00:02<00:00,  4.73it/s]


Depth: 0


Parallel Depth 0: 100%|██████████| 12/12 [00:02<00:00,  5.36it/s]


Nested distance Parellel: 25.10929083333333
Computation time Parellel: 13.6722 seconds
