In [6]:
%load_ext autoreload
%autoreload 2
import sys
import numpy as np
import time
sys.path.append('../')
random_seed = 10

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
from pnot.solver import Lmatrix2paths,  adapted_wasserstein_squared, path2adaptedpath, sort_qpath
from pnot.solver import ConditionalLaw
from pnot.solver import nested2, nested2_parallel
from pnot.utils import nested

n_sample = 10000
T = 3
L = np.array([[1, 0, 0], [1, 1, 0], [1, 1, 1]])
X,A = Lmatrix2paths(L, n_sample, seed = random_seed, verbose = False)
M = np.array([[1, 0, 0], [2, 1, 0], [2, 1, 2]])
Y,B = Lmatrix2paths(M, n_sample, seed = random_seed, verbose = False)

dist_bench = adapted_wasserstein_squared(A, B)
print("Theoretical AW_2^2: ", dist_bench)

delta_n = 1/n_sample**(1/T)
delta_n = 0.01

adaptedX = path2adaptedpath(X, delta_n)
adaptedY = path2adaptedpath(Y, delta_n)

# Quantization map
q2v = np.unique(np.concatenate([adaptedX, adaptedY], axis=0))
v2q = {k: v for v, k in enumerate(q2v)}  # Value to Quantization

# Quantized paths
qX = np.array([[v2q[x] for x in y] for y in adaptedX])
qY = np.array([[v2q[x] for x in y] for y in adaptedY])

# Sort paths and transpose to (n_sample, T+1)
qX = sort_qpath(qX.T)
qY = sort_qpath(qY.T)

cost_matrix = np.square(q2v[:,None] - q2v[None,:])

Theoretical AW_2^2:  3.0


In [10]:
from pnot.solver import nested2, nested2_parallel
from pnot.utils import nested

kernel_x = ConditionalLaw(qX, False)
kernel_y = ConditionalLaw(qY, False)

# start_time = time.perf_counter()
# AW_2square = nested2(kernel_x, kernel_y, cost_matrix)
# end_time = time.perf_counter()
# print("Elapsed time (Adapted OT): {:.4f} seconds".format(end_time - start_time))
# print("Numerical AW_2^2: ", AW_2square)


start_time = time.perf_counter()
AW_2square = nested2_parallel(kernel_x, kernel_y, cost_matrix)
end_time = time.perf_counter()
print("Elapsed time (Adapted OT): {:.4f} seconds".format(end_time - start_time))
print("Numerical AW_2^2: ", AW_2square)

100%|██████████| 1203/1203 [01:16<00:00, 15.64it/s]
100%|██████████| 1202/1202 [01:16<00:00, 15.67it/s]
100%|██████████| 1203/1203 [01:17<00:00, 15.46it/s]
100%|██████████| 1202/1202 [01:17<00:00, 15.57it/s]
100%|██████████| 1202/1202 [01:17<00:00, 15.52it/s]
100%|██████████| 1202/1202 [01:17<00:00, 15.57it/s]
100%|██████████| 1202/1202 [01:19<00:00, 15.20it/s]
100%|██████████| 1202/1202 [01:19<00:00, 15.21it/s]
100%|██████████| 71/71 [00:01<00:00, 44.51it/s]
100%|██████████| 71/71 [00:03<00:00, 21.27it/s]
100%|██████████| 71/71 [00:04<00:00, 14.61it/s]
100%|██████████| 71/71 [00:06<00:00, 11.23it/s]
100%|██████████| 71/71 [00:05<00:00, 12.32it/s]
100%|██████████| 70/70 [00:02<00:00, 28.09it/s]
100%|██████████| 70/70 [00:04<00:00, 16.82it/s]
100%|██████████| 70/70 [00:01<00:00, 56.92it/s]


Elapsed time (Adapted OT): 96.0669 seconds
Numerical AW_2^2:  2.99963523


100%|██████████| 1/1 [00:00<00:00, 24.48it/s]


In [11]:
markovian = False
AW_2square = nested(X, Y, delta_n, markovian)
print("Numerical AW_2^2: ", AW_2square)

Start computing
Timestep 2
Computing 92554014 OTs .......
Timestep 1
Computing 319225 OTs .......
Timestep 0
Computing 1 OTs .......
5.26369 seconds
AW_2^2: 2.99964
Finish
Numerical AW_2^2:  2.9996352300000018


# Non-Markovian Solver

In [3]:
from pnot.solver import nested2, nested2_parallel
from pnot.utils import nested

markovian = False

AW_2square = nested(X, Y, delta_n, markovian)
print("Numerical AW_2^2: ", AW_2square)

kernel_x = ConditionalLaw(qX, markovian)
kernel_y = ConditionalLaw(qY, markovian)

start_time = time.perf_counter()
AW_2square = nested2(kernel_x, kernel_y, cost_matrix)
end_time = time.perf_counter()
print("Elapsed time (Adapted OT): {:.4f} seconds".format(end_time - start_time))
print("Numerical AW_2^2: ", AW_2square)


start_time = time.perf_counter()
AW_2square = nested2_parallel(kernel_x, kernel_y, cost_matrix)
end_time = time.perf_counter()
print("Elapsed time (Adapted OT): {:.4f} seconds".format(end_time - start_time))
print("Numerical AW_2^2: ", AW_2square)

Start computing
Timestep 2
Computing 10000 OTs .......
Timestep 1
Computing 7744 OTs .......
Timestep 0
Computing 1 OTs .......
0.00970092 seconds
AW_2^2: 2.54543
Finish
Numerical AW_2^2:  2.5454280000000007


Timestep 2: 100%|██████████| 100/100 [00:00<00:00, 1247.78it/s]
Timestep 1: 100%|██████████| 88/88 [00:00<00:00, 826.63it/s]
Timestep 0: 100%|██████████| 1/1 [00:00<00:00, 418.97it/s]


Elapsed time (Adapted OT): 0.2395 seconds
Numerical AW_2^2:  2.5454280000000002


100%|██████████| 13/13 [00:00<00:00, 38.01it/s]
100%|██████████| 13/13 [00:00<00:00, 677.57it/s]
100%|██████████| 13/13 [00:00<00:00, 1067.09it/s]
100%|██████████| 12/12 [00:00<00:00, 130.95it/s]
100%|██████████| 13/13 [00:00<00:00, 1485.03it/s]
100%|██████████| 12/12 [00:00<00:00, 1539.29it/s]
100%|██████████| 12/12 [00:00<00:00, 1633.35it/s]
100%|██████████| 12/12 [00:00<00:00, 2259.35it/s]
100%|██████████| 11/11 [00:00<00:00, 593.48it/s]
100%|██████████| 11/11 [00:00<00:00, 268.88it/s]
100%|██████████| 11/11 [00:00<00:00, 490.39it/s]
100%|██████████| 11/11 [00:00<00:00, 502.90it/s]
100%|██████████| 11/11 [00:00<00:00, 876.44it/s]
100%|██████████| 11/11 [00:00<00:00, 798.44it/s]
100%|██████████| 11/11 [00:00<00:00, 990.18it/s]
100%|██████████| 11/11 [00:00<00:00, 1176.22it/s]


Elapsed time (Adapted OT): 4.9584 seconds
Numerical AW_2^2:  2.5454280000000002


100%|██████████| 1/1 [00:00<00:00, 1085.48it/s]


# Markovian Solver

In [4]:
# from pnot.utils import nested
# markovian = True
# nested(X, Y, delta_n, markovian)

In [5]:
# qX

In [6]:
# kernel_x.q2idx[2]

In [7]:
# kernel_x.v[1]

In [8]:
# kernel_x.next_idx[1]

In [4]:
markovian = True

AW_2square = nested(X, Y, delta_n, markovian)
print("Numerical AW_2^2: ", AW_2square)

kernel_x = ConditionalLaw(qX, markovian)
kernel_y = ConditionalLaw(qY, markovian)

start_time = time.perf_counter()
AW_2square = nested2(kernel_x, kernel_y, cost_matrix)
end_time = time.perf_counter()
print("Elapsed time (Adapted OT): {:.4f} seconds".format(end_time - start_time))
print("Numerical AW_2^2: ", AW_2square)


start_time = time.perf_counter()
AW_2square = nested2_parallel(kernel_x, kernel_y, cost_matrix)
end_time = time.perf_counter()
print("Elapsed time (Adapted OT): {:.4f} seconds".format(end_time - start_time))
print("Numerical AW_2^2: ", AW_2square)



Start computing
Timestep 2
Computing 8460 OTs .......
Timestep 1
Computing 7744 OTs .......
Timestep 0
Computing 1 OTs .......
0.0103447 seconds
AW_2^2: 5.7677
Finish
Numerical AW_2^2:  5.767699999999999


Timestep 2: 100%|██████████| 90/90 [00:00<00:00, 1765.20it/s]
Timestep 1: 100%|██████████| 88/88 [00:00<00:00, 1174.43it/s]
Timestep 0: 100%|██████████| 1/1 [00:00<00:00, 555.91it/s]

Elapsed time (Adapted OT): 0.1313 seconds
Numerical AW_2^2:  2.9327935000000007



100%|██████████| 12/12 [00:00<00:00, 1402.43it/s]
100%|██████████| 11/11 [00:00<00:00, 1047.91it/s]
100%|██████████| 12/12 [00:00<00:00, 927.43it/s]
100%|██████████| 11/11 [00:00<00:00, 1415.47it/s]
100%|██████████| 11/11 [00:00<00:00, 2420.76it/s]
100%|██████████| 11/11 [00:00<00:00, 2243.05it/s]
100%|██████████| 11/11 [00:00<00:00, 2655.08it/s]
100%|██████████| 11/11 [00:00<00:00, 2310.33it/s]
100%|██████████| 11/11 [00:00<00:00, 320.57it/s]
100%|██████████| 11/11 [00:00<00:00, 834.67it/s]
100%|██████████| 11/11 [00:00<00:00, 831.50it/s]
100%|██████████| 11/11 [00:00<00:00, 523.71it/s]
100%|██████████| 11/11 [00:00<00:00, 1051.80it/s]
100%|██████████| 11/11 [00:00<00:00, 1308.42it/s]
100%|██████████| 11/11 [00:00<00:00, 1168.60it/s]
100%|██████████| 11/11 [00:00<00:00, 1388.84it/s]


Elapsed time (Adapted OT): 3.9887 seconds
Numerical AW_2^2:  2.9327935000000007


100%|██████████| 1/1 [00:00<00:00, 707.30it/s]
