In [1]:
import os, sys
sys.path.append("..")

import torch
import numpy as np
from TrajectoryNet.dataset import EBData

from src.light_sb_ou import LightSB_OU
from src.distributions import LoaderSampler, TensorSampler
from tqdm import tqdm
from sklearn.decomposition import PCA
from TrajectoryNet.optimal_transport.emd import earth_mover_distance

## Config

In [2]:
# --- dimensionality -------------------------------------------------------
DIM = 5  # model dimension passed to LightSB_OU(dim=...)
assert DIM > 1

# --- reproducibility ------------------------------------------------------
SEED = 42  # NOTE(review): not referenced by the visible cells; evaluation seeding uses its own seed

# --- optimisation ---------------------------------------------------------
BATCH_SIZE = 128
SAMPLING_BATCH_SIZE = 128
D_LR = 1e-2
D_GRADIENT_MAX_NORM = float("inf")  # an infinite max-norm means clip_grad_norm_ never rescales
MAX_STEPS = 2000
CONTINUE = -1  # the training loop starts at step CONTINUE + 1, i.e. 0

# --- LightSB_OU hyper-parameters -----------------------------------------
EPSILON = 0.1
N_POTENTIALS = 100
INIT_BY_SAMPLES = True  # initialise the potentials from target samples
IS_DIAGONAL = True

# --- experiment selection --------------------------------------------------
T = 3  # intermediate frame set up by the data-loading cell
DEVICE = "cpu"

In [3]:
def setup_consistent_evaluation(seed: int = 0xBADBEEF) -> None:
    """Reset the global random number generators so evaluation is reproducible.

    Seeds PyTorch's default generator and NumPy's legacy global generator with
    the same value. When CUDA is available, the generators of all visible GPUs
    are seeded as well.

    Args:
        seed: Seed applied to every generator. Defaults to the fixed evaluation
            seed ``0xBADBEEF``, so existing no-argument calls behave as before.
            Must fit in 32 bits, which ``np.random.seed`` requires.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        # Newer torch.manual_seed already covers CUDA devices. Seeding every GPU
        # explicitly keeps older versions and multi-GPU setups deterministic.
        torch.cuda.manual_seed_all(seed)

In [4]:
# Start and end epsilon are identical, so epsilon stays at EPSILON throughout.
EPS = EPSILON_END = EPSILON

## Data loading

In [5]:
# Load the data with the same width the model uses (max_dim tied to DIM;
# presumably the number of leading 'pcs' columns kept - confirm in EBData).
ds = EBData('pcs', max_dim=DIM)
eb_data = ds.get_data()

n_frames = 5  # time points are labelled 0..4

# Inclusive [start, end] row range of each frame. Selecting a frame by
# range slicing is only correct when its rows are contiguous, so check that
# here rather than silently mixing rows from neighbouring frames.
frame_bounds = []
for label in range(n_frames):
    rows = np.where(ds.labels == label)[0]
    if rows.size == 0:
        raise ValueError(f"no samples with label {label}")
    if rows[-1] - rows[0] + 1 != rows.size:
        raise ValueError(f"samples with label {label} are not contiguous")
    frame_bounds.append((rows[0], rows[-1]))

(
    (frame_0_start, frame_0_end),
    (frame_1_start, frame_1_end),
    (frame_2_start, frame_2_end),
    (frame_3_start, frame_3_end),
    (frame_4_start, frame_4_end),
) = frame_bounds

frames = [eb_data[start:end + 1] for start, end in frame_bounds]
X_mid_1, X_mid_2, X_mid_3 = frames[1], frames[2], frames[3]

# Frame T is the intermediate target; frames T-1 and T+1 are the endpoints.
if T not in (1, 2, 3):
    raise ValueError(f"T must be 1, 2 or 3, got {T}")
X_mid = frames[T]
X_0_f = frames[T - 1]
X_1_f = frames[T + 1]

X_sampler = TensorSampler(torch.tensor(X_0_f).float(), device="cpu")
Y_sampler = TensorSampler(torch.tensor(X_1_f).float(), device="cpu")



## Model training

In [6]:
# Per-time-point LightSB_OU coefficients, indexed by T - 1 for T in 1..3:
# b_T[T - 1] is passed as `b` and mu_T[T - 1] as `m`.
b_T, mu_T = [0.001, -0.1, -0.1], [-0.3, -0.001, -2]

In [7]:
# Leave-one-frame-out evaluation. For each intermediate frame T, a
# LightSB_OU bridge is trained from frame T-1 to frame T+1 and sampled at
# t = 0.5. Its EMD to the true frame T is averaged over N_TRIALS runs.
result = []
setup_consistent_evaluation()

# Frames rebuilt from the row bounds found while loading the data
# (index == frame label).
eb_frames = [
    ds.get_data()[start:end + 1]
    for start, end in (
        (frame_0_start, frame_0_end),
        (frame_1_start, frame_1_end),
        (frame_2_start, frame_2_end),
        (frame_3_start, frame_3_end),
        (frame_4_start, frame_4_end),
    )
]
N_TRIALS = 5  # independent training runs averaged per time point

# NOTE: the loop rebinds the config-level T and the X_* / sampler globals,
# so after the cell they hold the values of the last time point.
for T in [1, 2, 3]:
    X_mid = eb_frames[T]
    X_0_f = eb_frames[T - 1]
    X_1_f = eb_frames[T + 1]

    X_sampler = TensorSampler(torch.tensor(X_0_f).float(), device="cpu")
    Y_sampler = TensorSampler(torch.tensor(X_1_f).float(), device="cpu")

    trial_results = []
    for i in range(N_TRIALS):
        # Place the model on DEVICE. The batches below are moved there too,
        # so a non-CPU DEVICE no longer causes a device mismatch.
        D = LightSB_OU(dim=DIM, n_potentials=N_POTENTIALS, epsilon=EPSILON,
                       m=mu_T[T - 1], b=b_T[T - 1], sampling_batch_size=SAMPLING_BATCH_SIZE,
                       is_diagonal=IS_DIAGONAL).to(DEVICE)

        if INIT_BY_SAMPLES:
            D.init_r_by_samples(Y_sampler.sample(N_POTENTIALS).to(DEVICE))

        D_opt = torch.optim.Adam(D.parameters(), lr=D_LR)

        for step in range(CONTINUE + 1, MAX_STEPS):
            D_opt.zero_grad()
            X0, X1 = X_sampler.sample(BATCH_SIZE).to(DEVICE), Y_sampler.sample(BATCH_SIZE).to(DEVICE)

            log_potential = D.get_log_potential(X1)
            log_C = D.get_log_C(X0)

            D_loss = (-log_potential + log_C).mean()
            D_loss.backward()
            torch.nn.utils.clip_grad_norm_(D.parameters(), max_norm=D_GRADIENT_MAX_NORM)
            D_opt.step()

        with torch.no_grad():
            X = X_sampler.sample(X_0_f.shape[0]).to(DEVICE)
            # Query time 0.5 for every sample, built on the same device as X.
            t_mid = torch.full((X.shape[0], 1), 0.5, device=X.device)
            X_mid_pred = D.sample_at_time_moment(X, t_mid).detach().cpu().numpy()
            EMD = earth_mover_distance(X_mid_pred, X_mid)
            trial_results.append(EMD)

    mean_emd = np.mean(trial_results)
    print("T mean_emd ", T, mean_emd)
    result.append(mean_emd)

print(np.mean(result), np.std(result))

T mean_emd  1 0.791927631880616
T mean_emd  2 0.826060842090936
T mean_emd  3 0.8283597020924356
0.8154493920213293 0.016658853290158898
