# Imports

In [2]:
import sys, os
sys.path.append(os.path.abspath(".."))
import torch
import numpy as np
import matplotlib.pylab as plt
from src.dataset import Dataset
from src.soft_dtw_torch import SoftDTWTorch, squared_euclidean_distances, jacobian_sq_euc
from src.softdtw_barycenter import softdtw_barycenter
import torch.nn as nn
import torch, numpy as np, random
from fastdtw import fastdtw
import numpy as np

def set_seed(seed=1234):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)




In [None]:
def split_60_40(X):
    X = torch.tensor(X, dtype=torch.float32)  
    
    N = X.shape[0]
    T = X.shape[1]
    idx = int(0.6 * T)

    X_in = X[:, :idx]      
    X_out = X[:, idx:]     

    return X_in, X_out


# Load data

In [4]:
ds = Dataset("ECG200")
X_train, y_train, X_test, y_test = ds.load_dataset()



Loading UCR dataset: ECG200


In [None]:
class SimpleMLP(nn.Module): # very simple model
    def __init__(self, input_len, output_len, hidden_dim=64):
        
        super().__init__()
        
        self.fc1 = nn.Linear(input_len, hidden_dim)
        self.act = nn.Sigmoid()
        self.fc2 = nn.Linear(hidden_dim, output_len)

    def forward(self, x):
        
        h = self.act(self.fc1(x))
        out = self.fc2(h)
        return out


In [6]:
Xtrain_in, Xtrain_out = split_60_40(X_train)
Xtest_in, Xtest_out = split_60_40(X_test)

Xtrain_in_flat  = Xtrain_in.squeeze(-1)
Xtrain_out_flat = Xtrain_out.squeeze(-1)

Xtest_in_flat  = Xtest_in.squeeze(-1)
Xtest_out_flat = Xtest_out.squeeze(-1)

T_in  = Xtrain_in_flat.shape[1]
T_out = Xtrain_out_flat.shape[1]

set_seed(42)
model = SimpleMLP(input_len=T_in, output_len=T_out, hidden_dim=64)


# Random init

In [None]:
#Random Init
def train_model(model,criterion, name_model,n_epochs=1000):
    loss_prec = float('inf')
    criterion = criterion

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(n_epochs):  
        optimizer.zero_grad()

        y_pred_flat = model(Xtrain_in_flat)      
        loss = criterion(y_pred_flat, Xtrain_out_flat)
        loss.backward()
        optimizer.step()

        if loss_prec - loss.item() < 1e-6:
            print(f"Converged at epoch {epoch+1}")
            torch.save(model.state_dict(), f"{name_model}.pth")
            break

        loss_prec = loss.item()
        
        
        if epoch % 10 == 0:
            print(f"Epoch {epoch+1}/{n_epochs}, Loss: {loss.item():.6f}")

    return model

        

In [8]:
def make_softdtw_criterion(gamma):
    sdtw = SoftDTWTorch(gamma=gamma)
    
    def criterion(y_pred_flat, y_true_flat):
        y_pred = y_pred_flat.unsqueeze(-1) 
        y_true = y_true_flat.unsqueeze(-1)
        return sdtw(y_pred, y_true).mean()

    return criterion


In [9]:
def evaluate_dtw(model, X_in, X_out):
    dtw_vals = []
    model.eval()

    with torch.no_grad():
        y_pred_flat = model(X_in)

        for i in range(len(X_in)):
            pred = y_pred_flat[i].cpu().numpy()
            true = X_out[i].cpu().numpy()

            dist, _ = fastdtw(pred, true)
            dtw_vals.append(dist)

    return float(np.mean(dtw_vals))


In [10]:
model = train_model(model,nn.MSELoss(),"simple_mlp_ecg200",n_epochs=5000)


Epoch 1/5000, Loss: 0.424313
Epoch 11/5000, Loss: 0.224199
Epoch 21/5000, Loss: 0.173191
Epoch 31/5000, Loss: 0.153992
Epoch 41/5000, Loss: 0.142147
Epoch 51/5000, Loss: 0.130656
Epoch 61/5000, Loss: 0.120439
Epoch 71/5000, Loss: 0.111548
Epoch 81/5000, Loss: 0.104047
Epoch 91/5000, Loss: 0.097827
Epoch 101/5000, Loss: 0.092728
Epoch 111/5000, Loss: 0.088522
Epoch 121/5000, Loss: 0.084998
Epoch 131/5000, Loss: 0.082017
Epoch 141/5000, Loss: 0.079474
Epoch 151/5000, Loss: 0.077280
Epoch 161/5000, Loss: 0.075364
Epoch 171/5000, Loss: 0.073667
Epoch 181/5000, Loss: 0.072144
Epoch 191/5000, Loss: 0.070760
Epoch 201/5000, Loss: 0.069486
Epoch 211/5000, Loss: 0.068302
Epoch 221/5000, Loss: 0.067192
Epoch 231/5000, Loss: 0.066142
Epoch 241/5000, Loss: 0.065144
Epoch 251/5000, Loss: 0.064190
Epoch 261/5000, Loss: 0.063274
Epoch 271/5000, Loss: 0.062394
Epoch 281/5000, Loss: 0.061544
Epoch 291/5000, Loss: 0.060724
Epoch 301/5000, Loss: 0.059928
Epoch 311/5000, Loss: 0.059157
Epoch 321/5000, Los

In [None]:
random_results = {}

criterion_euclid = torch.nn.MSELoss()
model = SimpleMLP(T_in, T_out, hidden_dim=64)
model = train_model(model, criterion_euclid, "rand_euclid",n_epochs=5000)

random_results["euclid"] = evaluate_dtw(model, Xtest_in_flat, Xtest_out.squeeze(-1))


gammas = [0.001]

for gamma in gammas:
    print(f"\n=== SoftDTW(gamma={gamma}) — Random init ===")
    
    criterion = make_softdtw_criterion(gamma)
    model = SimpleMLP(T_in, T_out, hidden_dim=64)   
    model = train_model(model, criterion, f"rand_softdtw_gamma{gamma}",n_epochs=5000)

    dtw_err = evaluate_dtw(model, Xtest_in_flat, Xtest_out.squeeze(-1))
    random_results[f"softdtw_gamma{gamma}"] = dtw_err

print("\n--- RANDOM INITIALIZATION RESULTS ---")
for k, v in random_results.items():
    print(f"{k:20s}: {v:.4f}")


Epoch 1/5000, Loss: 0.426614
Epoch 11/5000, Loss: 0.220988
Epoch 21/5000, Loss: 0.161612
Epoch 31/5000, Loss: 0.145491
Epoch 41/5000, Loss: 0.133501
Epoch 51/5000, Loss: 0.122526
Epoch 61/5000, Loss: 0.113346
Epoch 71/5000, Loss: 0.105521
Epoch 81/5000, Loss: 0.098974
Epoch 91/5000, Loss: 0.093675
Epoch 101/5000, Loss: 0.089412
Epoch 111/5000, Loss: 0.085950
Epoch 121/5000, Loss: 0.083084
Epoch 131/5000, Loss: 0.080659
Epoch 141/5000, Loss: 0.078569
Epoch 151/5000, Loss: 0.076738
Epoch 161/5000, Loss: 0.075109
Epoch 171/5000, Loss: 0.073639
Epoch 181/5000, Loss: 0.072297
Epoch 191/5000, Loss: 0.071057
Epoch 201/5000, Loss: 0.069902
Epoch 211/5000, Loss: 0.068815
Epoch 221/5000, Loss: 0.067784
Epoch 231/5000, Loss: 0.066801
Epoch 241/5000, Loss: 0.065858
Epoch 251/5000, Loss: 0.064947
Epoch 261/5000, Loss: 0.064065
Epoch 271/5000, Loss: 0.063206
Epoch 281/5000, Loss: 0.062367
Epoch 291/5000, Loss: 0.061546
Epoch 301/5000, Loss: 0.060739
Epoch 311/5000, Loss: 0.059945
Epoch 321/5000, Los

In [12]:
random_results

{'euclid': 6.307397786915826, 'softdtw_gamma0.001': 6.388235841140904}

# Euclidean init 

In [None]:
def train_euclidean_init_then_softdtw(gamma, n_epochs_euclid=3000, n_epochs_sdtw=3000):
    print(f"Euclidean initialization + SoftDTW(gamma={gamma})")

    model = SimpleMLP(T_in, T_out, hidden_dim=64)

    print("Pretraining with Euclidean loss")
    criterion_euclid = torch.nn.MSELoss()
    model = train_model(model, criterion_euclid, 
                        f"euclid_init_stage1_gamma{gamma}", 
                        n_epochs=n_epochs_euclid)

    print("Fine-tuning with Soft-DTW")
    criterion_sdtw = make_softdtw_criterion(gamma)
    model = train_model(model, criterion_sdtw, 
                        f"euclid_init_stage2_gamma{gamma}", 
                        n_epochs=n_epochs_sdtw)

    dtw_err = evaluate_dtw(model, Xtest_in_flat, Xtest_out.squeeze(-1))
    
    return dtw_err


In [12]:
euclid_init_results = {}
gammas = [1.0, 0.1, 0.01, 0.001]

for gamma in gammas:
    dtw_err = train_euclidean_init_then_softdtw(
        gamma,
        n_epochs_euclid=3000,
        n_epochs_sdtw=3000
    )
    euclid_init_results[f"softdtw_gamma{gamma}"] = dtw_err

print("\n--- EUCLIDEAN INITIALIZATION RESULTS ---")
for k, v in euclid_init_results.items():
    print(f"{k:25s}: {v:.4f}")


Euclidean initialization + SoftDTW(gamma=1.0)
Pretraining with Euclidean loss
Epoch 1/3000, Loss: 0.500295
Epoch 11/3000, Loss: 0.239167
Epoch 21/3000, Loss: 0.172820
Epoch 31/3000, Loss: 0.154919
Epoch 41/3000, Loss: 0.143262
Epoch 51/3000, Loss: 0.132220
Epoch 61/3000, Loss: 0.122694
Epoch 71/3000, Loss: 0.114790
Epoch 81/3000, Loss: 0.107922
Epoch 91/3000, Loss: 0.102087
Epoch 101/3000, Loss: 0.097190
Epoch 111/3000, Loss: 0.093069
Epoch 121/3000, Loss: 0.089599
Epoch 131/3000, Loss: 0.086672
Epoch 141/3000, Loss: 0.084193
Epoch 151/3000, Loss: 0.082078
Epoch 161/3000, Loss: 0.080252
Epoch 171/3000, Loss: 0.078650
Epoch 181/3000, Loss: 0.077223
Epoch 191/3000, Loss: 0.075931
Epoch 201/3000, Loss: 0.074743
Epoch 211/3000, Loss: 0.073638
Epoch 221/3000, Loss: 0.072596
Epoch 231/3000, Loss: 0.071606
Epoch 241/3000, Loss: 0.070657
Epoch 251/3000, Loss: 0.069739
Epoch 261/3000, Loss: 0.068848
Epoch 271/3000, Loss: 0.067977
Epoch 281/3000, Loss: 0.067125
Epoch 291/3000, Loss: 0.066287
Epo

In [13]:
euclid_init_results

{'softdtw_gamma1.0': 6.028368127532158,
 'softdtw_gamma0.1': 5.933686647854047,
 'softdtw_gamma0.01': 5.681924049584777,
 'softdtw_gamma0.001': 6.2091918509258175}