In [None]:
import numpy as np
import torch
import matplotlib as mpl
import matplotlib.pyplot as plt
from tslearn.generators import random_walks
from tslearn.metrics import dtw_path, dtw_subsequence_path

from decdtw.decdtw import DecDTWLayer
from decdtw.gdtw_solver.utils import expand_global_constraint
from decdtw.utils import BatchedSignal
# Render all figure text through LaTeX; the axis labels and legends below use
# LaTeX commands such as \mathbf and \phi.
# NOTE(review): this needs a working LaTeX installation on the machine running the notebook.
mpl.rcParams['text.usetex'] = True

import warnings
# Silence every warning for the rest of the kernel session.
# NOTE(review): this also hides deprecation notices from matplotlib / torch.
warnings.filterwarnings('ignore')

In [None]:
%matplotlib notebook

# Seed the global NumPy RNG so the figure is reproducible (20 for fig 1, 105 for fig 2).
np.random.seed(20)

fig, axs = plt.subplots(nrows=1, ncols=4, figsize=(12, 3))

# Two random walks of 8 samples each, with singleton dimensions dropped.
x = random_walks(n_ts=1, sz=8).squeeze()
y = random_walks(n_ts=1, sz=8).squeeze()

# Pairwise squared differences: D[i, j] = (x[i] - y[j]) ** 2.
D = np.square(np.subtract.outer(x, y))

# Overwrite every cell on the optimal DTW path with inf so it can be
# highlighted separately from the cost values when the matrix is drawn.
pth, dtw_discr = dtw_path(x, y)
path_rows, path_cols = (list(indices) for indices in zip(*pth))
D[path_rows, path_cols] = np.inf

# Reverse the row order so that row 0 (x index 0) ends up at the bottom.
D = D[::-1]

# Panel (a): both series over the sample index, plus one black segment per
# (x index, y index) pair on the DTW path.
ax_align = axs[0]
sample_idx = list(range(8))
for series, tex_label, colour in ((x, r'$\mathbf{x}$', u'#CC79A7'),
                                  (y, r'$\mathbf{y}$', u'#E69F00')):
    ax_align.plot(sample_idx, series, '--o', label=tex_label, c=colour, linewidth=2)
for x_ind, y_ind in pth:
    ax_align.plot([x_ind, y_ind], [x[x_ind], y[y_ind]], c='black', linewidth=1)
ax_align.legend(fontsize=12)
ax_align.set_xlabel(r'$i$', fontsize=16)
ax_align.set_ylabel(r'$\mathbf{x}/\mathbf{y}[i]$', fontsize=16)
ax_align.set_title('a) Classic DTW Alignment', fontsize=16)

# Panel (b): pairwise cost matrix with the classic DTW warping path highlighted.
# `mpl.cm.get_cmap` was deprecated in Matplotlib 3.7 and later removed;
# `plt.get_cmap` is the supported lookup. Work on a copy so the registered
# "binary" colormap itself is not modified.
cmap = plt.get_cmap("binary").copy()
# Non-finite cells (the DTW path was set to inf) are drawn in this colour.
cmap.set_bad(u'#0072B2')
axs[1].matshow(D, cmap=cmap)
# Pin the tick positions before labelling them. Labelling the automatically
# chosen ticks only lines up by accident (hence the old '' padding entry) and
# breaks as soon as the locator picks different positions.
n_rows, n_cols = D.shape
axs[1].set_xticks(list(range(n_cols)))
axs[1].set_xticklabels([str(k) for k in range(n_cols)])
# D was flipped along axis 0, so the top row carries the largest index.
axs[1].set_yticks(list(range(n_rows)))
axs[1].set_yticklabels([str(k) for k in reversed(range(n_rows))])
axs[1].set_title('b) Classic DTW Warping Path', fontsize=16)
axs[1].set_xlabel(r'$i$', fontsize=16)
axs[1].set_ylabel(r'$j$', fontsize=16)
# matshow puts the x ticks on top by default; move them to the bottom.
axs[1].xaxis.set_ticks_position('bottom')

# Panel (c): GDTW correspondences between the two signals.

# Wrap both series as batched signals on a shared time grid over [0, 1]:
# values get a leading batch axis and a trailing feature axis, times a leading batch axis.
t = np.linspace(0., 1., 8)
s_x = BatchedSignal(torch.from_numpy(x).unsqueeze(0).unsqueeze(-1), times=torch.from_numpy(t).unsqueeze(0))
s_y = BatchedSignal(torch.from_numpy(y).unsqueeze(0).unsqueeze(-1), times=torch.from_numpy(t).unsqueeze(0))

# Solve for the warp function; the third argument is 0.0 here.
# NOTE(review): presumably the regularisation weight -- confirm against DecDTWLayer.forward.
decdtw = DecDTWLayer(n_warp_fn_times=15)
gdtw_warp_fn = decdtw.forward(s_x, s_y, 0.0)

ax_gdtw = axs[2]
ax_gdtw.plot(t, x, '-o', label=r'$\mathbf{x}$', c=u'#CC79A7', linewidth=2)
ax_gdtw.plot(t, y, '-o', label=r'$\mathbf{y}$', c=u'#E69F00', linewidth=2)

# One black segment per warp-function knot: from x evaluated at the knot time
# to y evaluated at the warped time.
for t_w, w_v in zip(gdtw_warp_fn.times.T, gdtw_warp_fn.values.T):
    x_at_t = s_x(t_w[:, None]).squeeze().numpy()
    y_at_w = s_y(w_v[:, None]).squeeze().numpy()
    ax_gdtw.plot([t_w, w_v], [x_at_t, y_at_w], c='black', linewidth=1)
ax_gdtw.set_title('c) GDTW Alignment', fontsize=16)
ax_gdtw.set_xlabel(r'$t$', fontsize=16)
ax_gdtw.set_ylabel(r'$\mathbf{x}/\mathbf{y}(t)$', fontsize=16)
ax_gdtw.legend(fontsize=12)

# Panel (d): the GDTW warp function itself, warped time against knot time.
ax_warp = axs[3]
knot_times = gdtw_warp_fn.times.squeeze()
warped_times = gdtw_warp_fn.values.squeeze()
ax_warp.plot(knot_times, warped_times, '-o', c=u'#0072B2')
ax_warp.set_xlabel(r'$t$', fontsize=16)
ax_warp.set_ylabel(r'$\phi(t)$', fontsize=16)
ax_warp.set_title('d) GDTW Warp Function', fontsize=16)

# Tighten the layout, then write the four-panel figure to disk.
fig.tight_layout()
plt.savefig('gdtw_example.pdf')

In [None]:
%matplotlib notebook

# Seed the global NumPy RNG so the figure is reproducible (20 for fig 1, 105 for fig 2).
np.random.seed(105)

fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(9, 3))

# Two random walks of 8 samples each on a shared time grid over [0, 1].
x = random_walks(n_ts=1, sz=8).squeeze()
y = random_walks(n_ts=1, sz=8).squeeze()
t = np.linspace(0., 1., 8)
s_x = BatchedSignal(torch.from_numpy(x).unsqueeze(0).unsqueeze(-1), times=torch.from_numpy(t).unsqueeze(0))
s_y = BatchedSignal(torch.from_numpy(y).unsqueeze(0).unsqueeze(-1), times=torch.from_numpy(t).unsqueeze(0))

# Solve the same alignment three times, varying only the third argument
# (0, 1, 10); these are the lambda values shown in the panel (b) legend.
gdtw = DecDTWLayer(n_warp_fn_times=8)
gdtw_warp_fn_0, gdtw_warp_fn_1, gdtw_warp_fn_10 = (
    gdtw.forward(s_x, s_y, lam) for lam in (0.0, 1.0, 10.0)
)

# Panel (a): the two underlying signals on their shared time grid.
ax_signals = axs[0]
for samples, tex_label, colour in ((x, r'$\mathbf{x}$', u'#CC79A7'),
                                   (y, r'$\mathbf{y}$', u'#E69F00')):
    ax_signals.plot(t, samples, '-o', label=tex_label, linewidth=2, c=colour)
ax_signals.legend(fontsize=12)
ax_signals.set_xlabel(r'$t$', fontsize=16)
ax_signals.set_ylabel(r'$\mathbf{x}/\mathbf{y}(t)$', fontsize=16)
ax_signals.set_title('a) Underlying Signals', fontsize=16)

# Panel (b): warp functions obtained with the three regularisation settings.
ax_reg = axs[1]
reg_curves = (
    (gdtw_warp_fn_0, u'#56B4E9', r'$\lambda=0$'),
    (gdtw_warp_fn_1, u'#009E73', r'$\lambda=1$'),
    (gdtw_warp_fn_10, u'#F0E442', r'$\lambda=10$'),
)
for warp_fn, colour, tex_label in reg_curves:
    ax_reg.plot(warp_fn.times.squeeze(), warp_fn.values.squeeze(), '-o', c=colour, label=tex_label)
ax_reg.set_xlabel(r'$t$', fontsize=16)
ax_reg.set_ylabel(r'$\phi(t)$', fontsize=16)
ax_reg.set_title('b) Effect of Regularisation', fontsize=16)
ax_reg.legend(fontsize=12)

# Panel (c): effect of local/global constraints on the warp function.
# Band constraint with one entry per warp-function time (0.1 for the first
# four entries, 0.3 for the last four). `torch.tensor` with an explicit dtype
# replaces the legacy `torch.DoubleTensor` constructor.
glb_band = torch.tensor([0.1, 0.1, 0.1, 0.1, 0.3, 0.3, 0.3, 0.3], dtype=torch.double)
# The band bounds used to be expanded here on the signal times as well, but
# that result was overwritten below before anything read it, so it is dropped.
gdtw_warp_fn_glb = gdtw.forward(s_x, s_y, 0.0, band_constr=glb_band)
gdtw_warp_fn_lcl = gdtw.forward(s_x, s_y, 0.0, grad_constr=2.)

# plotting
# Expand the band on a dense 100-point grid so its outline can be drawn.
t_band = torch.linspace(0., 1., 100, dtype=torch.double)
glb_lo, glb_up = expand_global_constraint(glb_band, t_band[None, :], s_y.times)

constraint_curves = (
    (gdtw_warp_fn_0, u'#000000', r'unconstr.'),
    (gdtw_warp_fn_glb, u'#0072B2', r'global'),
    (gdtw_warp_fn_lcl, u'#D55E00', r'local'),
)
for warp_fn, colour, curve_label in constraint_curves:
    axs[2].plot(warp_fn.times.squeeze(), warp_fn.values.squeeze(), '-o', c=colour, label=curve_label)
# Lower/upper edge of the global band, lightly shaded in between.
axs[2].plot(t_band, glb_lo.squeeze(), c=u'#E69F00')
axs[2].plot(t_band, glb_up.squeeze(), c=u'#E69F00')
axs[2].fill_between(t_band, glb_lo.squeeze(), glb_up.squeeze(), color=u'#E69F00', alpha=0.1)
axs[2].set_xlabel(r'$t$', fontsize=16)
axs[2].set_ylabel(r'$\phi(t)$', fontsize=16)
axs[2].set_title('c) Effect of Constraints', fontsize=16)
axs[2].legend(fontsize=12)
axs[2].set_ylim(0., 1.)
axs[2].set_xlim(0., 1.)

fig.tight_layout()
plt.savefig('gdtw_constraints.pdf')

In [None]:
%matplotlib notebook
from decdtw.utils import batch_interp, batched_linspace
from decdtw.gdtw_solver.dp_solver import solve_gdtw_with_costs_cpu, refine_warp_fn_bounds_inplace
from decdtw.gdtw_solver.utils import expand_local_constraint, expand_reg_wt
from decdtw.gdtw_solver.dp_graph import warp_fn_glb_bounds, feature_node_cost
from decdtw.gdtw_solver.nlp_solver import refine_soln_nlp

np.random.seed(29)

# ---- solver configuration (everything a reader might tune) ----
reg_wt = 0.1            # regularisation weight handed to the solvers
subseq_enabled = True   # forwarded to the bound computation and the NLP refinement
warp_fn_times = None    # None: reuse signal2's times; int: that many evenly spaced times
grad_constr = None      # local (gradient) constraint; None is passed through to expand_local_constraint
# Global band constraint, one entry per sample of signal2 (10 samples).
# `torch.tensor` with an explicit dtype replaces the legacy `torch.FloatTensor` constructor.
glb_constr = torch.tensor([0.3, 0.3, 0.3, 0.3, 0.3, 0.4, 0.5, 0.5, 0.5, 0.5], dtype=torch.float32)
M = 25                  # number of candidate warp values per warp-function time in the DP grid
feature_loss_p = 2      # `p` of the feature (signal) loss
reg_loss_p = 2          # `p` of the regularisation loss
refine_iters = 3        # bound-refinement passes after the initial solve
refine_factor = 0.25    # forwarded to refine_warp_fn_bounds_inplace
mem_efficient = False   # only forwarded to the non-CPU solver

# ---- synthetic data: x2 is a noisy, time-warped slice of x1 ----
x1 = np.squeeze(random_walks(n_ts=1, sz=15))
x2 = x1[2:12] + np.random.normal(0, 0.2, size=10)
# Add batch and feature axes and convert to float32 tensors.
x1 = torch.from_numpy(x1)[None, :, None].float()
x2 = torch.from_numpy(x2)[None, :, None].float()
t1 = torch.linspace(0., 1., x1.shape[1]).unsqueeze(0)
t2 = torch.linspace(0., 1., x2.shape[1]).unsqueeze(0)
x2 = batch_interp(t2 ** 0.6, t2, x2.squeeze(2)).unsqueeze(2)  # apply simple time warp

# signal1 uses BatchedSignal's default times; signal2 is placed on [0, 0.7].
# NOTE(review): x1 is later plotted against t1 (linspace over [0, 1]) -- confirm
# that this matches BatchedSignal's default time grid.
signal1 = BatchedSignal(x1)
signal2 = BatchedSignal(x2, times=t2 * 0.7)

# One column per solver pass (the initial solve plus each refinement).
# Deriving the grid from `refine_iters` keeps the figure and the plotting loop
# in sync when that setting changes; `squeeze=False` keeps `axs` two-dimensional
# even with a single column.
n_solver_passes = refine_iters + 1
fig, axs = plt.subplots(2, n_solver_passes, figsize=(3 * n_solver_passes, 6), squeeze=False)

device = signal1.values.device

# If no warp function times are provided, generate them automatically:
# an int means "that many evenly spaced times spanning signal2's time range",
# None means "reuse signal2's own sample times".
if isinstance(warp_fn_times, int):
    warp_fn_times = batched_linspace(signal2.times[:, 0], signal2.times[:, -1], warp_fn_times)
elif warp_fn_times is None:
    warp_fn_times = signal2.times

# regularisation of warp derivative is away from this value
exp_warp_grad = torch.ones((signal1.shape[0], 1), dtype=warp_fn_times.dtype, device=warp_fn_times.device)
if not subseq_enabled:
    # Without subsequence alignment the expected slope is the ratio of the
    # two signals' final times rather than 1.
    exp_warp_grad = signal1.times.max(dim=1, keepdim=True).values / warp_fn_times.max(dim=1, keepdim=True).values

# setup warp fn value constraints
glb_band_lb, glb_band_ub = expand_global_constraint(glb_constr, warp_fn_times, signal1.times)
local_lb_vals, local_ub_vals = expand_local_constraint(grad_constr, warp_fn_times)  # gradient constraints
local_lb_vals *= exp_warp_grad  # local constraints relative to identity slope
local_ub_vals *= exp_warp_grad

# automatically set solver discretisation if none provided
if M is None:
    # Previously `glb_constr` was overwritten with 1. and then compared against
    # None, so the "no global constraint" arm could never run and a config
    # variable was silently mutated. The effective behaviour is kept: a missing
    # constraint counts as full width 1.
    # NOTE(review): the unreachable arm suggested M = signal1.times.shape[1]
    # for the unconstrained case -- confirm which value is intended.
    if glb_constr is None:
        max_width = 1.
    elif isinstance(glb_constr, float):
        max_width = glb_constr
    else:
        max_width = glb_constr.max()
    M = max(int(signal1.times.shape[1] * max_width) + 1, 50)

warp_fn_lb, warp_fn_ub = warp_fn_glb_bounds(warp_fn_times, signal1.times, local_lb_vals, local_ub_vals,
                                            glb_band_lb, glb_band_ub, subseq_enabled)

reg_wt = expand_reg_wt(reg_wt, warp_fn_times)

# Coarse-to-fine DP solve. Each pass discretises the current
# [warp_fn_lb, warp_fn_ub] interval into M candidate warp values per
# warp-function time, solves the DP over those candidates, plots the result in
# column i, and then updates the bounds around the solution for the next pass.
for i in range(refine_iters + 1):
    # Candidate warp values for this pass, spaced evenly between the current bounds.
    warp_fn_vals_dp = batched_linspace(warp_fn_lb, warp_fn_ub, M)
    signal_loss_node_costs = feature_node_cost(signal1, signal2, warp_fn_times, warp_fn_vals_dp, p=feature_loss_p)
    # use below instead if running into OOM issues for long time series with high dim embeddings
    # signal_loss_node_costs = feature_node_cost_l2_alt(signal1, signal2, warp_fn_times, warp_fn_vals_dp)
    if not subseq_enabled:
        # Leave only the first candidate finite at the first time and only the
        # last candidate finite at the last time, pinning both path endpoints.
        # NOTE(review): assumes the cost axes are (batch, time, candidate) -- confirm.
        signal_loss_node_costs[:, 0, 1:] = np.inf
        signal_loss_node_costs[:, -1, :-1] = np.inf

    if device.type == 'cpu':
        opt_cost, opt_dp_path = solve_gdtw_with_costs_cpu(signal_loss_node_costs, warp_fn_times, warp_fn_vals_dp,
                                                          warp_fn_lb, warp_fn_ub, local_lb_vals, local_ub_vals,
                                                          exp_warp_grad, reg_wt, reg_loss_p)
    else:
        # NOTE(review): `solve_gdtw_with_costs` is not imported in any cell visible
        # here, so this branch would raise NameError. The tensors in this notebook
        # are built from NumPy arrays (presumably CPU-resident), so it is not
        # expected to run -- import the solver before reusing this on another device.
        opt_cost, opt_dp_path = solve_gdtw_with_costs(signal_loss_node_costs, warp_fn_times, warp_fn_vals_dp,
                                                      local_lb_vals, local_ub_vals, exp_warp_grad, reg_wt, reg_loss_p,
                                                      mem_efficient)

    # recover actual warp fn values from optimal path indices
    opt_warp_vals = torch.gather(warp_fn_vals_dp, dim=2, index=opt_dp_path.unsqueeze(2)).squeeze(2)
    
    # Run the NLP refinement once, starting from the first (coarsest) DP solution.
    # Its result is re-drawn in every column below as the 'SLSQP' reference curve.
    if i == 0:
        optimal_align_warp_fn = BatchedSignal(opt_warp_vals, times=warp_fn_times)
        opt_cost_nlp, opt_warp_vals_nlp = refine_soln_nlp(optimal_align_warp_fn, signal1, signal2, glb_band_lb, glb_band_ub,
                                                      local_lb_vals, local_ub_vals, reg_wt, exp_warp_grad,
                                                      feature_loss_p, reg_loss_p, subseq_enabled)
        
    
    ############ plot bounds and discretisation ############
    # Top row: every DP candidate as a small dot (0.4 is the marker size),
    # overlaid with the NLP reference curve and this pass's DP solution.
    axs[0, i].scatter(warp_fn_times.squeeze().repeat_interleave(M), warp_fn_vals_dp.reshape(-1), 0.4, c=u'#56B4E9')
    axs[0, i].plot(warp_fn_times.squeeze(), opt_warp_vals_nlp.squeeze(), label='SLSQP', c=u'#009E73')
    axs[0, i].plot(warp_fn_times.squeeze(), opt_warp_vals.squeeze(), label='DP', c=u'#E69F00')
    axs[0, i].set_ylim(0., 1.)
    axs[0, i].set_xlabel(r'$t$', fontsize=16)
    axs[0, i].set_ylabel(r'$\phi(t)$', fontsize=16)
    axs[0, i].set_title(f'Iteration {i+1}: ' + r'$\hat{f}(\phi^\ast)$' + f': {opt_cost.item():.4}', fontsize=16)
    axs[0, i].legend(fontsize=12, loc='lower right')
    
    # Bottom row: the two signals and the alignment implied by this pass.
    # NOTE(review): the 0.7 factor repeats the time scaling used when signal2
    # was constructed; the two must be kept in sync by hand.
    axs[1, i].plot(t1.squeeze(), x1.squeeze(), '-o', c=u'#CC79A7', label=r'$\mathbf{x}$')
    axs[1, i].plot(t2.squeeze() * 0.7, x2.squeeze(), '-o', c=u'#E69F00', label=r'$\mathbf{y}$')
    axs[1, i].set_xlabel(r'$t$', fontsize=16)
    axs[1, i].set_ylabel(r'$\mathbf{x}/\mathbf{y}(t)$', fontsize=16)
    axs[1, i].legend(fontsize=12)
    
    # One black segment per warp-function time: from signal2 evaluated at that
    # time to signal1 evaluated at the corresponding warped time.
    for t_w, w_v in zip(warp_fn_times.T, opt_warp_vals.T):
        axs[1, i].plot([t_w, w_v], [signal2(t_w[:, None]).squeeze().numpy(), signal1(w_v[:, None]).squeeze().numpy()],
                       c='black', linewidth=1)
    
    # Update the bounds in place around the current solution for the next pass.
    # This also runs after the final pass, where the updated bounds are not
    # read again in the visible code.
    refine_warp_fn_bounds_inplace(warp_fn_lb, warp_fn_ub, opt_warp_vals, refine_factor=refine_factor)

# Cost reported by the one-off NLP refinement, for comparison with the DP
# costs shown in the panel titles.
print(f'SLSQP solver cost: {opt_cost_nlp.item():.4}')

fig.tight_layout()
plt.savefig('gdtw_solver.pdf')