In [1]:
%load_ext autoreload
%load_ext line_profiler
%autoreload 2
%matplotlib ipympl

import sys
import numpy as np
import time
from tqdm import tqdm
import ot

sys.path.append('./src')
verbose = False

random_seed = np.random.randint(100)
random_seed = 0

In [2]:
from src.utils_solver import Lmatrix2paths,  adapted_wasserstein_squared, list_repr_mu_x, path2adaptedpath, qpath2mu_x, quantization, nested, plot_V, sort_qpath
n_sample = 1000
T = 3
L = np.array([[1, 0, 0], [2, 4, 0], [3, 2, 1]])
X,A = Lmatrix2paths(L, n_sample, seed = random_seed, verbose = False)
M = np.array([[1, 0, 0], [2, 3, 0], [3, 1, 2]])
Y,B = Lmatrix2paths(M, n_sample, seed = random_seed, verbose = False)

adaptedX = path2adaptedpath(X, delta_n = 0.1)
adaptedY = path2adaptedpath(Y, delta_n = 0.1)

In [3]:
# Quantization map
q2v = np.unique(np.concatenate([adaptedX, adaptedY], axis=0))
v2q = {k: v for v, k in enumerate(q2v)}  # Value to Quantization

# Quantized paths
qX = np.array([[v2q[x] for x in y] for y in adaptedX])
qY = np.array([[v2q[x] for x in y] for y in adaptedY])

# Sort paths and transpose to (n_sample, T+1)
qX = sort_qpath(qX.T)
qY = sort_qpath(qY.T)

# Get conditional distribution mu_{x_{1:t}} = mu_x[t][(x_1,...,x_t)] = {x_{t+1} : mu_{x_{1:t}}(x_{t+1}), ...}
mu_x = qpath2mu_x(qX)
nu_y = qpath2mu_x(qY)

# represent mu_x[t] with 
# mu_x_c[t]: a list of x_{1:t}
# mu_x_v[t]: a list of [x_{t+1}, ...] (so a list of list of values)
# mu_x_w[t]: a list of [mu_{x_{1:t}}(x_{t+1}), ...] (so a list of list of weights)
# mu_x_n[t]: a list of #Number of next value x_{t+1} follows x_{1:t}
# mu_x_cumn[t]: a list of position where (x_{1:t},?) appears in x_{1:t+1}

mu_x_c, mu_x_cn, mu_x_v, mu_x_w, mu_x_cumn = list_repr_mu_x(mu_x)
nu_y_c, nu_y_cn, nu_y_v, nu_y_w, nu_y_cumn = list_repr_mu_x(nu_y)
# All list except weights should be increasing! 

In [4]:
def new_nested(mu_x_cn, mu_x_v, mu_x_w, mu_x_cumn, nu_y_cn, nu_y_v, nu_y_w, nu_y_cumn):
    T = len(mu_x_cn)
    square_cost_matrix = (q2v[None, :] - q2v[None, :].T) ** 2
    V = [np.zeros([mu_x_cn[t],nu_y_cn[t]]) for t in range(T)] # V_t(x_{1:t},y_{1:t})
    for t in range(T - 1, -1, -1):
        print(t)
        for cx, vx, wx, ix, jx in zip(range(mu_x_cn[t]),mu_x_v[t],mu_x_w[t],mu_x_cumn[t][:-1],mu_x_cumn[t][1:]):
            for cy, vy, wy, iy, jy in zip(range(nu_y_cn[t]),nu_y_v[t],nu_y_w[t],nu_y_cumn[t][:-1],nu_y_cumn[t][1:]):
                if len(vx)== 1 and len(vy) == 1 and t == T-1:
                    V[t][cx,cy] = square_cost_matrix[vx[0],vy[0]]  
                    # could make it faster for any vx that len(vx) = 1
                    # could make it faster for any vy that len(vy) = 1
                else:
                    idxy = np.ix_(vx, vy) # this is slow, can we make the slicing faster
                    cost = square_cost_matrix[idxy]
                    cost += V[t + 1][ix:jx,iy:jy] if t < T - 1 else 0

                    # V[t][cx,cy] = ot.emd2(wx, wy, cost)
                    V[t][cx,cy] = np.sum(cost* ot.lp.emd(wx, wy, cost))

    AW_2square = V[0][0,0]
    return AW_2square

# start_time = time.perf_counter()
# AW_2square = new_nested(mu_x_cn, mu_x_v, mu_x_w, mu_x_cumn, nu_y_cn, nu_y_v, nu_y_w, nu_y_cumn)
# end_time = time.perf_counter()
# print("Elapsed time (Adapted OT): {:.4f} seconds".format(end_time - start_time))

%lprun -f new_nested AW_2square = new_nested(mu_x_cn, mu_x_v, mu_x_w, mu_x_cumn, nu_y_cn, nu_y_v, nu_y_w, nu_y_cumn)
print("Numerical AW_2^2: ", AW_2square)

2
1
0
Numerical AW_2^2:  3.0517166666666666


Timer unit: 1e-09 s

Total time: 8.83707 s
File: /var/folders/gq/bts52kyn0v72cpz5c5489984006lvv/T/ipykernel_75320/791767963.py
Function: new_nested at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
     1                                           def new_nested(mu_x_cn, mu_x_v, mu_x_w, mu_x_cumn, nu_y_cn, nu_y_v, nu_y_w, nu_y_cumn):
     2         1       1000.0   1000.0      0.0      T = len(mu_x_cn)
     3         1     310000.0 310000.0      0.0      square_cost_matrix = (q2v[None, :] - q2v[None, :].T) ** 2
     4         1     442000.0 442000.0      0.0      V = [np.zeros([mu_x_cn[t],nu_y_cn[t]]) for t in range(T)] # V_t(x_{1:t},y_{1:t})
     5         4       2000.0    500.0      0.0      for t in range(T - 1, -1, -1):
     6         3     263000.0  87666.7      0.0          print(t)
     7       961     416000.0    432.9      0.0          for cx, vx, wx, ix, jx in zip(range(mu_x_cn[t]),mu_x_v[t],mu_x_w[t],mu_x_cumn[t][:-1],mu_x_cumn[t][1:]):
     8    7890

In [5]:
q2v, v2q, mu_x, nu_y, q2v_x, v2q_x, q2v_y, v2q_y = quantization(adaptedX, adaptedY, markovian=False, verbose = False)

%lprun -f nested AW_2square,_ = nested(mu_x, nu_y, v2q_x, v2q_y, q2v)
print("Numerical AW_2^2: ", AW_2square)

Nested backward induction .......


Timestep 2: 100%|██████████| 900/900 [00:19<00:00, 45.47it/s]
Timestep 1: 100%|██████████| 57/57 [00:00<00:00, 141.59it/s]
Timestep 0: 100%|██████████| 1/1 [00:00<00:00, 1024.25it/s]

Numerical AW_2^2:  3.0517166666666666





Timer unit: 1e-09 s

Total time: 18.5865 s
File: /Users/hous/Github/AWD_numerics/src/utils_solver.py
Function: nested at line 165

Line #      Hits         Time  Per Hit   % Time  Line Contents
   165                                           def nested(mu_x, nu_y, v2q_x, v2q_y, q2v, markovian=False, verbose=True):
   166         1       1000.0   1000.0      0.0      T = len(mu_x)
   167         1     102000.0 102000.0      0.0      square_cost_matrix = (q2v[None, :] - q2v[None, :].T) ** 2
   168                                           
   169         1     486000.0 486000.0      0.0      V = [np.zeros([len(v2q_x[t]), len(v2q_y[t])]) for t in range(T)]
   170         1       1000.0   1000.0      0.0      if verbose:
   171         1      78000.0  78000.0      0.0          print("Nested backward induction .......")
   172         4       3000.0    750.0      0.0      for t in range(T - 1, -1, -1):
   173         3   11472000.0    4e+06      0.1          tqdm_bar = tqdm(mu_x[t].items()

In [6]:
from adapted_empirical_measure.AEM_grid import uniform_empirical_grid_measure
from trees.Build_trees_from_paths import build_tree_from_paths
from trees.TreeAnalysis import get_depth
from awd_trees.Nested_Dist_Algo import compute_nested_distance, nested_optimal_transport_loop
# Compute uniform adapted empirical grid measures with weights
adapted_X, adapted_weights_X = uniform_empirical_grid_measure(X.T, delta_n=0.1, use_weights=True)
adapted_Y, adapted_weights_Y = uniform_empirical_grid_measure(Y.T, delta_n=0.1, use_weights=True)

# Build trees from the adapted paths
adapted_tree_1 = build_tree_from_paths(adapted_X, adapted_weights_X)
adapted_tree_2 = build_tree_from_paths(adapted_Y, adapted_weights_Y)

# Compute the nested (adapted optimal transport) distance and measure execution time
max_depth = get_depth(adapted_tree_1)

%lprun -f nested_optimal_transport_loop distance, probability_matrices = nested_optimal_transport_loop(adapted_tree_1, adapted_tree_2, max_depth, "solver_lp_pot", 0, 2)
print("Numerical AW_2^2: ", distance)

Depth 2: 100%|██████████| 900/900 [01:03<00:00, 14.08it/s]
Depth 1: 100%|██████████| 57/57 [00:00<00:00, 190.80it/s]
Depth 0: 100%|██████████| 1/1 [00:00<00:00, 1246.82it/s]

Numerical AW_2^2:  3.051716666666665





Timer unit: 1e-09 s

Total time: 61.6122 s
File: /Users/hous/Github/AWD_numerics/./src/awd_trees/Nested_Dist_Algo.py
Function: nested_optimal_transport_loop at line 21

Line #      Hits         Time  Per Hit   % Time  Line Contents
    21                                           def nested_optimal_transport_loop(
    22                                               tree1_root, tree2_root, max_depth, method, lambda_reg, power
    23                                           ):
    24                                               """
    25                                               Computes the nested optimal transport plan between two trees.
    26                                           
    27                                               Parameters:
    28                                               - tree1_root (TreeNode): Root of the first tree.
    29                                               - tree2_root (TreeNode): Root of the second tree.
    30                    