In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes and edits to local sources take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2
import sys
import numpy as np
import time
# Put the parent directory on the import path
# (presumably where the `pnot` package lives -- confirm against the repo layout).
sys.path.append('../')
# Seed handed to `Lmatrix2paths` (as `seed=`) when the sample paths are generated.
random_seed = 1

In [2]:
from pnot.solver import Lmatrix2paths,  adapted_wasserstein_squared, path2adaptedpath, sort_qpath
from pnot.solver import ConditionalLaw
from pnot.solver import nested2, nested2_parallel
from pnot.utils import nested, nestedmarkovian

# --- Problem setup -----------------------------------------------------------
n_sample = 500
T = 3

# Lower-triangular matrices passed to `Lmatrix2paths` to generate the two samples.
L = np.array([[1, 0, 0], [1, 1, 0], [1, 1, 1]])
M = np.array([[1, 0, 0], [2, 1, 0], [2, 1, 2]])

# NOTE(review): X and Y are drawn with the *same* seed -- confirm that sharing
# the seed between the two samples is intended.
X, A = Lmatrix2paths(L, n_sample, seed=random_seed, verbose=False)
Y, B = Lmatrix2paths(M, n_sample, seed=random_seed, verbose=False)

# Reference value the numerical solvers below are compared against.
dist_bench = adapted_wasserstein_squared(A, B)
print("Theoretical AW_2^2: ", dist_bench)

# --- Adapted (grid-projected) paths ------------------------------------------
# Grid width passed to `path2adaptedpath`. A previous assignment
# `delta_n = 1/n_sample**(1/T)` was immediately overwritten by this constant
# and never took effect, so it has been removed; 0.01 is the value in use.
delta_n = 0.01

adaptedX = path2adaptedpath(X, delta_n)
adaptedY = path2adaptedpath(Y, delta_n)

# --- Quantization map --------------------------------------------------------
# q2v: sorted array of every distinct value occurring in either adapted sample
#      (index -> value). np.unique flattens its input and returns sorted values.
# v2q: inverse lookup (value -> index into q2v).
q2v = np.unique(np.concatenate([adaptedX, adaptedY], axis=0))
v2q = {value: index for index, value in enumerate(q2v)}

# Quantized paths: replace every value by its integer index in q2v.
qX = np.array([[v2q[value] for value in row] for row in adaptedX])
qY = np.array([[v2q[value] for value in row] for row in adaptedY])

# Sort paths and transpose to (n_sample, T+1)
qX = sort_qpath(qX.T)
qY = sort_qpath(qY.T)

# Squared-difference cost between every pair of quantization values;
# entry (i, j) is (q2v[i] - q2v[j])**2.
cost_matrix = np.square(q2v[:, None] - q2v[None, :])

Theoretical AW_2^2:  3.0


# Non-Markovian Solver

In [3]:
from pnot.solver import nested2, nested2_parallel
from pnot.utils import nested

# Non-Markovian setting: the flag is forwarded to both `nested` and `ConditionalLaw`.
markovian = False

# First estimate: `nested` applied directly to the raw sample paths.
AW_2square = nested(X, Y, delta_n, markovian)
print("Numerical AW_2^2: ", AW_2square)

# Conditional laws built from the quantized paths (X first, then Y).
kernel_x, kernel_y = (ConditionalLaw(qpaths, markovian) for qpaths in (qX, qY))

# Time the serial solver and then the parallel one on identical inputs.
for solver in (nested2, nested2_parallel):
    start_time = time.perf_counter()
    AW_2square = solver(kernel_x, kernel_y, cost_matrix)
    end_time = time.perf_counter()
    print("Elapsed time (Adapted OT): {:.4f} seconds".format(end_time - start_time))
    print("Numerical AW_2^2: ", AW_2square)

Start computing
Timestep 2
Computing 250000 OTs .......
Timestep 1
Computing 76176 OTs .......
Timestep 0
Computing 1 OTs .......
0.0803173 seconds
AW_2^2: 2.98388
Finish
Numerical AW_2^2:  2.983883999999997


Timestep 2: 100%|██████████| 500/500 [00:01<00:00, 451.45it/s]
Timestep 1: 100%|██████████| 276/276 [00:01<00:00, 227.79it/s]
Timestep 0: 100%|██████████| 1/1 [00:00<00:00, 128.17it/s]

Elapsed time (Adapted OT): 2.3767 seconds
Numerical AW_2^2:  2.983884



100%|██████████| 63/63 [00:00<00:00, 287.12it/s]
100%|██████████| 63/63 [00:00<00:00, 293.72it/s]
100%|██████████| 62/62 [00:00<00:00, 287.01it/s]
100%|██████████| 62/62 [00:00<00:00, 372.97it/s]
100%|██████████| 63/63 [00:00<00:00, 332.89it/s]
100%|██████████| 62/62 [00:00<00:00, 356.42it/s]
100%|██████████| 62/62 [00:00<00:00, 325.80it/s]
100%|██████████| 63/63 [00:00<00:00, 242.27it/s]
100%|██████████| 35/35 [00:00<00:00, 140.04it/s]
100%|██████████| 35/35 [00:00<00:00, 134.84it/s]
100%|██████████| 34/34 [00:00<00:00, 210.39it/s]
100%|██████████| 34/34 [00:00<00:00, 155.31it/s]
100%|██████████| 35/35 [00:00<00:00, 82.62it/s]
100%|██████████| 34/34 [00:00<00:00, 127.95it/s]
100%|██████████| 34/34 [00:00<00:00, 109.73it/s]
100%|██████████| 35/35 [00:00<00:00, 88.05it/s]


Elapsed time (Adapted OT): 4.4690 seconds
Numerical AW_2^2:  2.983884


100%|██████████| 1/1 [00:00<00:00, 167.36it/s]


# Markovian Solver

In [4]:
# Markovian setting: the flag is forwarded to both `nested` and `ConditionalLaw`.
markovian = True

# First estimate: `nested` applied directly to the raw sample paths.
AW_2square = nested(X, Y, delta_n, markovian)
print("Numerical AW_2^2: ", AW_2square)

# Conditional laws built from the quantized paths (X first, then Y).
kernel_x, kernel_y = (ConditionalLaw(qpaths, markovian) for qpaths in (qX, qY))

# Time the serial solver and then the parallel one on identical inputs.
for solver in (nested2, nested2_parallel):
    start_time = time.perf_counter()
    AW_2square = solver(kernel_x, kernel_y, cost_matrix)
    end_time = time.perf_counter()
    print("Elapsed time (Adapted OT): {:.4f} seconds".format(end_time - start_time))
    print("Numerical AW_2^2: ", AW_2square)



Start computing
Timestep 2
Computing 117198 OTs .......
Timestep 1
Computing 76176 OTs .......
Timestep 0
Computing 1 OTs .......
0.0909874 seconds
AW_2^2: 3.90203
Finish
Numerical AW_2^2:  3.9020268910158755


Timestep 2: 100%|██████████| 306/306 [00:00<00:00, 376.62it/s]
Timestep 1: 100%|██████████| 276/276 [00:01<00:00, 250.29it/s]
Timestep 0: 100%|██████████| 1/1 [00:00<00:00, 167.03it/s]

Elapsed time (Adapted OT): 1.9248 seconds
Numerical AW_2^2:  3.9020268910158733



100%|██████████| 38/38 [00:00<00:00, 225.08it/s]
100%|██████████| 39/39 [00:00<00:00, 132.11it/s]
100%|██████████| 39/39 [00:00<00:00, 168.88it/s]
100%|██████████| 38/38 [00:00<00:00, 106.55it/s]
100%|██████████| 38/38 [00:00<00:00, 146.41it/s]
100%|██████████| 38/38 [00:00<00:00, 157.40it/s]
100%|██████████| 38/38 [00:00<00:00, 346.55it/s]
100%|██████████| 38/38 [00:00<00:00, 159.26it/s]
100%|██████████| 35/35 [00:00<00:00, 116.73it/s]
100%|██████████| 35/35 [00:00<00:00, 56.89it/s]
100%|██████████| 34/34 [00:00<00:00, 255.25it/s]
100%|██████████| 34/34 [00:00<00:00, 120.27it/s]
100%|██████████| 34/34 [00:00<00:00, 175.60it/s]
100%|██████████| 35/35 [00:00<00:00, 90.01it/s]
100%|██████████| 34/34 [00:00<00:00, 106.75it/s]
100%|██████████| 35/35 [00:00<00:00, 117.85it/s]


Elapsed time (Adapted OT): 4.9708 seconds
Numerical AW_2^2:  3.9020268910158733


100%|██████████| 1/1 [00:00<00:00, 136.32it/s]
