In [1]:
# autoreload: with mode 2 (set below), imported modules are reloaded before
# each cell executes, so edits to the project sources are picked up without
# restarting the kernel.
%load_ext autoreload
# line_profiler supplies the %lprun magic used by the profiling cells below.
%load_ext line_profiler
%autoreload 2
# Optional interactive matplotlib backend (disabled).
# %matplotlib ipympl

import sys
import numpy as np
import time

# Make the project's ./src directory importable. The membership guard keeps the
# cell idempotent: re-running it no longer appends a duplicate entry each time.
if './src' not in sys.path:
    sys.path.append('./src')

# Fixed seed so that every run reproduces the same sample paths.
# The previous `np.random.randint(100)` draw was overwritten immediately (dead
# store) and only advanced the global RNG; reinstate a random draw here if
# run-to-run variability is actually wanted.
random_seed = 0

In [2]:
from src.utils_solver import Lmatrix2paths,  adapted_wasserstein_squared, list_repr_mu_x, path2adaptedpath, qpath2mu_x, quantization, nested, plot_V, sort_qpath
# Experiment size.
n_sample, T = 5000, 3

# Lower-triangular matrices that parameterize the two path distributions.
L = np.array([[1, 0, 0],
              [2, 4, 0],
              [3, 2, 1]])
M = np.array([[1, 0, 0],
              [2, 3, 0],
              [3, 1, 2]])

# Draw the sample paths for both matrices with the same seed; the second return
# value of each call is what the benchmark below consumes
# (presumably the matrix derived from L / M -- confirm in utils_solver).
X, A = Lmatrix2paths(L, n_sample, seed=random_seed, verbose=False)
Y, B = Lmatrix2paths(M, n_sample, seed=random_seed, verbose=False)

# Reference value that the numerical solver is compared against.
dist_bench = adapted_wasserstein_squared(A, B)
print("Theoretical AW_2^2: ", dist_bench)

# Adapted versions of both samples, built with the same delta_n.
adaptedX, adaptedY = (path2adaptedpath(paths, delta_n=0.1) for paths in (X, Y))

Theoretical AW_2^2:  3.0


In [3]:
# Quantization map: q2v[i] is the i-th smallest distinct value occurring in
# either adapted sample (np.unique returns the sorted, de-duplicated values).
q2v = np.unique(np.concatenate([adaptedX, adaptedY], axis=0))
v2q = {value: index for index, value in enumerate(q2v)}  # Value to Quantization

# Quantized paths: replace every value by its index in q2v.
# q2v is sorted, has no duplicates and contains every entry of adaptedX and
# adaptedY, so a binary search returns exactly the index v2q would give, without
# a Python-level loop over all sample points.
qX = np.searchsorted(q2v, adaptedX)
qY = np.searchsorted(q2v, adaptedY)

# Sort paths and transpose to (n_sample, T+1)
qX = sort_qpath(qX.T)
qY = sort_qpath(qY.T)

# Get conditional distribution mu_{x_{1:t}} = mu_x[t][(x_1,...,x_t)] = {x_{t+1} : mu_{x_{1:t}}(x_{t+1}), ...}
mu_x = qpath2mu_x(qX)
nu_y = qpath2mu_x(qY)

# List representations of the conditional distributions, consumed by the solvers.
mu_x_c, mu_x_cn, mu_x_v, mu_x_w, mu_x_cumn = list_repr_mu_x(mu_x, q2v)
nu_y_c, nu_y_cn, nu_y_v, nu_y_w, nu_y_cumn = list_repr_mu_x(nu_y, q2v)
# All lists except the weights should be increasing!

In [4]:
from src.utils_solver import nested2, nested2_parallel, solve_ot
# Line-profile the serial solver `nested2` on the list representations built
# above. The profiled call's result is not bound to a name in this cell; the
# commented-out block is the wall-clock-timed variant that stores the result in
# AW_2square and prints it. Enable one variant or the other.
# start_time = time.perf_counter()
# AW_2square = nested2(mu_x_cn, mu_x_v, mu_x_w, mu_x_cumn, nu_y_cn, nu_y_v, nu_y_w, nu_y_cumn)
# end_time = time.perf_counter()
# print("Elapsed time (Adapted OT): {:.4f} seconds".format(end_time - start_time))
# print("Numerical AW_2^2: ", AW_2square)
%lprun -f nested2 nested2(mu_x_cn, mu_x_v, mu_x_w, mu_x_cumn, nu_y_cn, nu_y_v, nu_y_w, nu_y_cumn)

Timestep 2: 100%|██████████| 3298/3298 [01:32<00:00, 35.70it/s]
Timestep 1: 100%|██████████| 68/68 [00:00<00:00, 70.55it/s] 
Timestep 0: 100%|██████████| 1/1 [00:00<00:00, 885.81it/s]


Timer unit: 1e-09 s

Total time: 87.6917 s
File: /Users/hous/Github/AWD_numerics/src/utils_solver.py
Function: nested2 at line 145

Line #      Hits         Time  Per Hit   % Time  Line Contents
   145                                           def nested2(mu_x_cn, mu_x_v, mu_x_w, mu_x_cumn, nu_y_cn, nu_y_v, nu_y_w, nu_y_cumn):
   146                                           
   147         1       1000.0   1000.0      0.0      T = len(mu_x_cn)
   148         1      27000.0  27000.0      0.0      V = [np.zeros([mu_x_cn[t], nu_y_cn[t]]) for t in range(T)]  # V_t(x_{1:t},y_{1:t})
   149         4       4000.0   1000.0      0.0      for t in range(T - 1, -1, -1):
   150         3   10651000.0    4e+06      0.0          x_bar = tqdm(range(mu_x_cn[t]))
   151         3    1122000.0 374000.0      0.0          x_bar.set_description(f"Timestep {t}")
   152      3373  307358000.0  91123.0      0.4          for cx, vx, wx, ix, jx in zip(
   153         3      21000.0   7000.0      0.0           

In [5]:
from src.utils_solver import nested2, nested2_parallel, chunk_process
# Line-profile `nested2_parallel` with the same arguments as the serial run so
# the two profiles can be compared. As in the serial cell, the profiled call's
# result is not bound to a name; the commented-out block is the
# wall-clock-timed variant that stores the result in AW_2square and prints it.
# start_time = time.perf_counter()
# AW_2square = nested2_parallel(mu_x_cn, mu_x_v, mu_x_w, mu_x_cumn, nu_y_cn, nu_y_v, nu_y_w, nu_y_cumn)
# end_time = time.perf_counter()
# print("Elapsed time (Adapted OT): {:.4f} seconds".format(end_time - start_time))
# print("Numerical AW_2^2: ", AW_2square)

%lprun -f nested2_parallel nested2_parallel(mu_x_cn, mu_x_v, mu_x_w, mu_x_cumn, nu_y_cn, nu_y_v, nu_y_w, nu_y_cumn)

100%|██████████| 550/550 [00:09<00:00, 57.69it/s]
100%|██████████| 549/549 [00:10<00:00, 54.61it/s] 
100%|██████████| 549/549 [00:13<00:00, 40.91it/s]
100%|██████████| 550/550 [00:14<00:00, 38.82it/s]
100%|██████████| 550/550 [00:15<00:00, 35.83it/s]
100%|██████████| 550/550 [00:15<00:00, 35.14it/s]
100%|██████████| 68/68 [00:00<00:00, 73.57it/s] 
100%|██████████| 1/1 [00:00<00:00, 2896.62it/s]


Timer unit: 1e-09 s

Total time: 21.1729 s
File: /Users/hous/Github/AWD_numerics/src/utils_solver.py
Function: nested2_parallel at line 179

Line #      Hits         Time  Per Hit   % Time  Line Contents
   179                                           def nested2_parallel(
   180                                               mu_x_cn, mu_x_v, mu_x_w, mu_x_cumn, nu_y_cn, nu_y_v, nu_y_w, nu_y_cumn
   181                                           ):
   182         1          0.0      0.0      0.0      T = len(mu_x_cn)
   183         1    6167000.0    6e+06      0.0      V = [np.zeros([mu_x_cn[t], nu_y_cn[t]]) for t in range(T)]  # V_t(x_{1:t},y_{1:t})
   184         4      12000.0   3000.0      0.0      for t in range(T - 1, -1, -1):
   185         3       1000.0    333.3      0.0          n_processes = 6 if t > 1 else 1  # HERE WE NEED TO CHANGE BACK TO t>1
   186         3     563000.0 187666.7      0.0          chunks = np.array_split(range(mu_x_cn[t]), n_processes)
   187         3   