In [1]:
import torch
import numpy as np
import sys,os
import matplotlib.pyplot as plt
sys.path.append('..')
from utils.utils import set_seed
import warnings
from scipy.interpolate import interp1d
warnings.filterwarnings("ignore")
import pickle

In [2]:
device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
AE_model = torch.load('../checkpoints/AE_dim_4_2025_10_01_08:36:42/model.pt', map_location=device)
test_data = {}
with torch.no_grad():
    for path in os.listdir('../raw_data/test_data'):
        key = path.split('.pt')[0]
        file_path = os.path.join('../raw_data/test_data', path)
        data = torch.load(file_path, map_location=device)
        latent = []
        for tra in data:
            z = AE_model.encoder(tra)
            latent.append(z)

        latent = torch.stack(latent, dim=0)
        latent = (latent - AE_model.mean) / AE_model.std
        test_data[key] = latent

In [3]:
len(test_data.keys())

9

In [4]:
model_dict = {}
coeff_list = [1, 2.5, 5, 10, 15, 20]
seed_list = [0, 1, 2]
# seed_list = [0]
for coeff in coeff_list:
    for seed in seed_list:
        for path in os.listdir('../checkpoints'):
            if f'seed_{seed}' in path and f'coeff_{coeff:.1f}' in path:
                ckpt_path = os.path.join('../checkpoints', path, 'model.pt')
                model = torch.load(ckpt_path, map_location=device)
                print(f'Load model from {ckpt_path}')
                break
        model_dict[f'coeff_{coeff}_seed_{seed}'] = model

Load model from ../checkpoints/SDE_coeff_1.0_seed_0/model.pt
Load model from ../checkpoints/SDE_coeff_2.5_seed_0/model.pt
Load model from ../checkpoints/SDE_coeff_5.0_seed_0/model.pt
Load model from ../checkpoints/SDE_coeff_10.0_seed_0/model.pt
Load model from ../checkpoints/SDE_coeff_20.0_seed_0/model.pt


In [5]:
# ========== calculate error ==============
pred_dict = {}
with torch.no_grad():
    for key, model in model_dict.items():
        pred_list = []
        for _, true_tra in test_data.items():
            
            pred_tra = model.predict(true_tra[:, 0, :], true_tra.shape[1])
            pred_list.append(pred_tra)
            
        pred_dict[key] = pred_list

100%|██████████| 2999/2999 [00:01<00:00, 1656.28it/s]
100%|██████████| 2999/2999 [00:01<00:00, 1713.66it/s]
100%|██████████| 2999/2999 [00:01<00:00, 1736.72it/s]
100%|██████████| 2999/2999 [00:01<00:00, 1715.68it/s]
100%|██████████| 2999/2999 [00:01<00:00, 1717.63it/s]
100%|██████████| 2999/2999 [00:01<00:00, 1773.91it/s]
100%|██████████| 2999/2999 [00:01<00:00, 1760.54it/s]
100%|██████████| 2999/2999 [00:01<00:00, 1785.36it/s]
100%|██████████| 2999/2999 [00:01<00:00, 1776.47it/s]
100%|██████████| 2999/2999 [00:01<00:00, 1772.85it/s]
100%|██████████| 2999/2999 [00:01<00:00, 1695.33it/s]
100%|██████████| 2999/2999 [00:01<00:00, 1676.51it/s]
100%|██████████| 2999/2999 [00:01<00:00, 1677.96it/s]
100%|██████████| 2999/2999 [00:01<00:00, 1678.52it/s]
100%|██████████| 2999/2999 [00:01<00:00, 1675.39it/s]
100%|██████████| 2999/2999 [00:01<00:00, 1658.31it/s]
100%|██████████| 2999/2999 [00:01<00:00, 1655.89it/s]
100%|██████████| 2999/2999 [00:01<00:00, 1661.85it/s]
100%|██████████| 2999/2999 [

In [6]:
mmd_dict = {}
for key, pred_list in pred_dict.items():
    mmd_list = []
    for i, (_, true_tra) in enumerate(test_data.items()):
        pred_tra = pred_list[i].cpu().numpy()
        true_tra = true_tra.cpu().numpy()
        
        # Compute temporal MMD - MMD for each time step
        temporal_mmd_list = []
        for t in range(pred_tra.shape[1]):  # iterate over time steps (3000)
            # Get data at time step t: shape [50, 4]
            pred_t = pred_tra[:, t, :]  # [50, 4]
            true_t = true_tra[:, t, :]  # [50, 4]
            
            # Compute MMD for this time step
            xx, yy, zz = np.matmul(pred_t, pred_t.T), np.matmul(true_t, true_t.T), np.matmul(pred_t, true_t.T)
            rx = (xx.diagonal().reshape([pred_t.shape[0], 1]))
            ry = (yy.diagonal().reshape([true_t.shape[0], 1]))
            dxx = rx + rx.T - 2*xx
            dyy = ry + ry.T - 2*yy
            dxy = rx + ry.T - 2*zz

            # Multi-scale RBF kernels with different bandwidths
            sigmas = [0.25, 0.5, 1.0, 2.0, 4.0, 8.0]
            mmd_t = 0
            for sigma in sigmas:
                XX = np.exp(-dxx / (2 * sigma**2))
                YY = np.exp(-dyy / (2 * sigma**2))
                XY = np.exp(-dxy / (2 * sigma**2))
                mmd_t += XX.mean() + YY.mean() - 2*XY.mean()
            
            mmd_t /= len(sigmas)  # Average over all kernels
            temporal_mmd_list.append(mmd_t)
        
        # Take mean over all time steps
        temporal_mmd = np.mean(temporal_mmd_list)
        mmd_list.append({
            'temporal_mmd': temporal_mmd,
            'mmd_per_timestep': temporal_mmd_list
        })
    mmd_dict[key] = mmd_list

In [9]:
mmd_dict.keys()

dict_keys(['coeff_1_seed_0', 'coeff_1_seed_1', 'coeff_1_seed_2', 'coeff_2.5_seed_0', 'coeff_2.5_seed_1', 'coeff_2.5_seed_2', 'coeff_5_seed_0', 'coeff_5_seed_1', 'coeff_5_seed_2', 'coeff_10_seed_0', 'coeff_10_seed_1', 'coeff_10_seed_2', 'coeff_15_seed_0', 'coeff_15_seed_1', 'coeff_15_seed_2', 'coeff_20_seed_0', 'coeff_20_seed_1', 'coeff_20_seed_2'])

In [10]:
# Save mmd_dict to a file
with open('mmd_dict.pkl', 'wb') as f:
    pickle.dump(mmd_dict, f)