In [120]:
import numpy as np
import plotly.express as px

from encoders import HexagonalSSPSpace
from decoders import train_decoder_net_sk
from utils import plot_path, generate_path

In [112]:
dt = 0.01
T = 10
num_timesteps = int(T/dt)
np.random.seed(0)
path = generate_path(num_timesteps, 2)
bounds = np.array([np.min(path, axis=0), np.max(path, axis=0)]).T
deltas = np.diff(path, axis=0)
encoder = HexagonalSSPSpace(domain_dim=2, length_scale=0.1)

In [113]:
decoder, history = train_decoder_net_sk(encoder, bounds)

Generating training examples ...
(200000, 2)



The balance properties of Sobol' points require n to be a power of 2.



Training decoder network ...
Iteration 1, loss = 0.13561451
Validation score: 0.829968
Iteration 2, loss = 0.04147201
Validation score: 0.852163
Iteration 3, loss = 0.03824883
Validation score: 0.858207
Iteration 4, loss = 0.03689686
Validation score: 0.862828
Iteration 5, loss = 0.03577485
Validation score: 0.865984
Iteration 6, loss = 0.03479423
Validation score: 0.869005
Iteration 7, loss = 0.03400452
Validation score: 0.871843
Iteration 8, loss = 0.03339124
Validation score: 0.873531
Iteration 9, loss = 0.03296987
Validation score: 0.874806
Iteration 10, loss = 0.03264298
Validation score: 0.876237
Iteration 11, loss = 0.03240048
Validation score: 0.877323
Iteration 12, loss = 0.03220590
Validation score: 0.878231
Iteration 13, loss = 0.03204052
Validation score: 0.878608
Iteration 14, loss = 0.03191637
Validation score: 0.878790
Iteration 15, loss = 0.03184289
Validation score: 0.879429
Iteration 16, loss = 0.03178002
Validation score: 0.879622
Iteration 17, loss = 0.03172648
Vali

In [127]:
x_t = encoder.encode(path[:1])
ssps = [np.copy(x_t)]
for i in range(num_timesteps - 1):
    dx, dy = deltas[i]
    dx_ssp = encoder.encode([[dx,dy]])
    x_t = x_t * dx_ssp
    ssps.append(x_t)

ssps = np.array(ssps).squeeze()
decoded = decoder.decode(ssps)
print(path.shape)
print(ssps.shape)

(1000, 2)
(1000, 151)


In [141]:
print(decoded.shape)
ts = np.linspace(0, T, num_timesteps)
plot_path(ts, decoded, path)

(1000, 2)


In [154]:
diffs = np.diff(decoded)
cleaned = np.copy(decoded)

for i in range(len(diffs)):
    if np.linalg.norm(diffs[i]) > 1:
        print(f"Shifting by {diffs[i]} ({i})")
        cleaned[i+1:] -= diffs[i]

plot_path(np.arange(num_timesteps), cleaned, path)

Shifting by [-1.11485989] (0)
Shifting by [1.00444091] (82)
Shifting by [1.01293348] (83)
Shifting by [1.02078693] (84)
Shifting by [1.02789933] (85)
Shifting by [1.03421857] (86)
Shifting by [1.0397437] (87)
Shifting by [1.04452012] (88)
Shifting by [1.04862958] (89)
Shifting by [1.0521789] (90)
Shifting by [1.05528667] (91)
Shifting by [1.05807271] (92)
Shifting by [1.06064966] (93)
Shifting by [1.06311643] (94)
Shifting by [1.06555544] (95)
Shifting by [1.06803119] (96)
Shifting by [1.07059111] (97)
Shifting by [1.07326702] (98)
Shifting by [1.07607741] (99)
Shifting by [1.07902984] (100)
Shifting by [1.08212345] (101)
Shifting by [1.10176746] (102)
Shifting by [1.10974256] (103)
Shifting by [1.11755027] (104)
Shifting by [1.12518916] (105)
Shifting by [1.13265726] (106)
Shifting by [1.13995273] (107)
Shifting by [1.14707436] (108)
Shifting by [1.15402124] (109)
Shifting by [1.16079309] (110)
Shifting by [1.16738966] (111)
Shifting by [1.12208779] (112)
Shifting by [1.1260174] (113)