In [None]:
import os, sys
sys.path.append("..")

import torch
import numpy as np
from TrajectoryNet.dataset import EBData

from src.light_sb import LightSB
from src.distributions import LoaderSampler, TensorSampler
from tqdm import tqdm
from sklearn.decomposition import PCA
from TrajectoryNet.optimal_transport.emd import earth_mover_distance

import wandb

## Config

In [None]:
# --- Data / model size -------------------------------------------------------
# Number of data dimensions; passed to LightSB as `dim`.
DIM = 5
assert DIM > 1, "DIM must be greater than 1"

# --- Reproducibility ---------------------------------------------------------
SEED = 42

# --- Optimisation ------------------------------------------------------------
# Number of samples drawn from each marginal per training step.
BATCH_SIZE = 128
# Passed to LightSB as `epsilon`.
EPSILON = 0.1
# Adam learning rate.
D_LR = 1e-2
# `max_norm` for gradient clipping. With inf the gradients are left unscaled;
# the call is still made because its return value (the total norm) is logged.
D_GRADIENT_MAX_NORM = float("inf")

# --- LightSB parameterisation ------------------------------------------------
N_POTENTIALS = 100
SAMPLING_BATCH_SIZE = 128
# When True, the model is initialised from N_POTENTIALS target samples.
INIT_BY_SAMPLES = True
IS_DIAGONAL = True

# --- Experiment selection ----------------------------------------------------
# Index of the held-out middle frame: the model is trained between frames
# T - 1 and T + 1 and evaluated against frame T.
T = 1
# Fail fast here: the data-loading cell only defines its outputs for these values.
assert T in (1, 2, 3), "T must be 1, 2 or 3"
DEVICE = "cpu"

# --- Schedule ----------------------------------------------------------------
MAX_STEPS = 2000
# -1 starts a fresh run; a value >= 0 resumes from the checkpoint tagged with
# that step and continues training at step CONTINUE + 1.
CONTINUE = -1

In [None]:
# Seed NumPy and PyTorch global RNGs from the same configured value.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Aliases of EPSILON; neither name is read by the cells shown in this notebook.
EPS = EPSILON_END = EPSILON

In [None]:
# Run identifier: used as the wandb run name and as the checkpoint sub-directory.
EXP_NAME = f'LightSB_cell_T_{T}_EPSILON_{EPSILON}_SEED_{SEED}'
OUTPUT_PATH = '../checkpoints/{}'.format(EXP_NAME)

# Hyper-parameters recorded with the wandb run. T, SAMPLING_BATCH_SIZE and
# MAX_STEPS are included as well because they change what the run computes.
config = dict(
    DIM=DIM,
    D_LR=D_LR,
    BATCH_SIZE=BATCH_SIZE,
    EPSILON=EPSILON,
    D_GRADIENT_MAX_NORM=D_GRADIENT_MAX_NORM,
    N_POTENTIALS=N_POTENTIALS,
    INIT_BY_SAMPLES=INIT_BY_SAMPLES,
    IS_DIAGONAL=IS_DIAGONAL,
    SEED=SEED,
    T=T,
    SAMPLING_BATCH_SIZE=SAMPLING_BATCH_SIZE,
    MAX_STEPS=MAX_STEPS,
)

# exist_ok=True makes the cell idempotent without a separate existence check.
os.makedirs(OUTPUT_PATH, exist_ok=True)

## Data loading

In [None]:
# Load with max_dim tied to DIM so the data cannot silently disagree with the
# `dim` the model is built with.
ds = EBData('pcs', max_dim=DIM)


def _frame_bounds(labels, frame):
    """Return (first, last) row index whose label equals `frame`.

    Raises ValueError when no row carries that label.
    """
    rows = np.where(labels == frame)[0]
    if rows.size == 0:
        raise ValueError(f'no samples with label {frame} in the dataset')
    return rows[0], rows[-1]


frame_0_start, frame_0_end = _frame_bounds(ds.labels, 0)
frame_1_start, frame_1_end = _frame_bounds(ds.labels, 1)
frame_2_start, frame_2_end = _frame_bounds(ds.labels, 2)
frame_3_start, frame_3_end = _frame_bounds(ds.labels, 3)
frame_4_start, frame_4_end = _frame_bounds(ds.labels, 4)

# Each frame is taken as the contiguous row range [first, last] of its label.
# NOTE(review): this presumes rows are grouped by label - confirm against EBData.
all_points = ds.get_data()
frames = {
    0: all_points[frame_0_start:frame_0_end + 1],
    1: all_points[frame_1_start:frame_1_end + 1],
    2: all_points[frame_2_start:frame_2_end + 1],
    3: all_points[frame_3_start:frame_3_end + 1],
    4: all_points[frame_4_start:frame_4_end + 1],
}

X_mid_1 = frames[1]
X_mid_2 = frames[2]
X_mid_3 = frames[3]

if T not in (1, 2, 3):
    raise ValueError(f'T must be 1, 2 or 3, got {T!r}')

# Frame T is held out for evaluation; its two neighbours are the marginals
# the model is trained between.
X_mid = frames[T]
X_0_f = frames[T - 1]
X_1_f = frames[T + 1]

X_sampler = TensorSampler(torch.tensor(X_0_f).float(), device="cpu")
Y_sampler = TensorSampler(torch.tensor(X_1_f).float(), device="cpu")

## Model initialisation

In [None]:
# Build the model on the configured device so it matches the batches, which
# are moved to DEVICE before being fed to it.
D = LightSB(
    dim=DIM,
    n_potentials=N_POTENTIALS,
    epsilon=EPSILON,
    sampling_batch_size=SAMPLING_BATCH_SIZE,
    is_diagonal=IS_DIAGONAL,
).to(DEVICE)

if INIT_BY_SAMPLES:
    D.init_r_by_samples(Y_sampler.sample(N_POTENTIALS).to(DEVICE))

D_opt = torch.optim.Adam(D.parameters(), lr=D_LR)

if CONTINUE > -1:
    # Restore the model weights alongside the optimizer state; optimizer
    # moments are only meaningful for the parameters they were computed on.
    # NOTE(review): the model checkpoint name mirrors the optimizer's naming
    # scheme - confirm it matches how checkpoints are written.
    D_ckpt_path = os.path.join(OUTPUT_PATH, f'D_{SEED}_{CONTINUE}.pt')
    if os.path.exists(D_ckpt_path):
        D.load_state_dict(torch.load(D_ckpt_path, map_location=DEVICE))
    else:
        print(f'WARNING: {D_ckpt_path} not found; only the optimizer state is restored.')

    D_opt_ckpt_path = os.path.join(OUTPUT_PATH, f'D_opt_{SEED}_{CONTINUE}.pt')
    # map_location keeps loading independent of the device the file was saved from.
    D_opt.load_state_dict(torch.load(D_opt_ckpt_path, map_location=DEVICE))

## Model training

In [None]:
wandb.init(name=EXP_NAME, config=config)

train_steps = range(CONTINUE + 1, MAX_STEPS)
# `step` is read after the loop; give it a value for the case where the range
# is empty (CONTINUE + 1 >= MAX_STEPS) so evaluation does not hit a NameError.
step = max(CONTINUE, 0)

for step in tqdm(train_steps):
    # training cycle
    D_opt.zero_grad()

    X0 = X_sampler.sample(BATCH_SIZE).to(DEVICE)
    X1 = Y_sampler.sample(BATCH_SIZE).to(DEVICE)

    log_potential = D.get_log_potential(X1)
    log_C = D.get_log_C(X0)

    D_loss = (-log_potential + log_C).mean()
    D_loss.backward()
    # Returns the total gradient norm, which is what gets logged below.
    D_gradient_norm = torch.nn.utils.clip_grad_norm_(D.parameters(), max_norm=D_GRADIENT_MAX_NORM)
    D_opt.step()

    # One call per step: both metrics belong to the same wandb row.
    wandb.log({'D gradient norm': D_gradient_norm.item(), 'D_loss': D_loss.item()}, step=step)

# eval and plots
with torch.no_grad():
    X = X_sampler.sample(X_0_f.shape[0]).to(DEVICE)
    # Y and XN do not feed the metric below; they are kept so the cell-level
    # names and the order of sampling calls stay as they were.
    Y = Y_sampler.sample(X_1_f.shape[0]).to(DEVICE)

    XN = D(X)

    # Time 0.5 for every sample, created on the same device as X.
    t_mid = torch.full((X.shape[0], 1), 0.5, device=X.device)
    X_mid_pred = D.sample_at_time_moment(X, t_mid).detach().cpu().numpy()

    # Distance between the predicted intermediate samples and the held-out frame.
    EMD = earth_mover_distance(X_mid_pred, X_mid)

    wandb.log({f'EMD_{T}': EMD}, step=step)

torch.save(D.state_dict(), os.path.join(OUTPUT_PATH, 'D.pt'))
torch.save(D_opt.state_dict(), os.path.join(OUTPUT_PATH, 'D_opt.pt'))

if len(train_steps) > 0:
    # Step-tagged copies: the resume path (CONTINUE > -1) reads
    # D_opt_{SEED}_{CONTINUE}.pt, which the untagged files above do not provide.
    torch.save(D.state_dict(), os.path.join(OUTPUT_PATH, f'D_{SEED}_{step}.pt'))
    torch.save(D_opt.state_dict(), os.path.join(OUTPUT_PATH, f'D_opt_{SEED}_{step}.pt'))

wandb.finish()