# Finding coherent sets in the Bickley jet

In [None]:
import numpy as np
from scipy.cluster.vq import kmeans2

import matplotlib.pyplot as plt

import sktime
import sktime.decomposition.vampnet as vnet

import torch
import torch.nn as nn

The dataset contains 10000 particles tracked over 401 timesteps in two dimensions.

In [None]:
# Bickley jet dataset with 10000 particles
# (n_jobs presumably sets the number of parallel workers -- confirm in sktime docs).
dataset = sktime.data.bickley_jet(
    n_particles=10000,
    n_jobs=16,
)

In [None]:
# Reduce every trajectory to its endpoints: the first and the last timestep.
ds_2d = dataset.endpoints_dataset()

# Lift the 2d endpoints to 3d, onto the surface of a cylinder.
ds_3d = ds_2d.to_3d()

# Uniformly cluster the 3d space and bin the particles into those clusters.
ds_3d_clusters = ds_3d.cluster(16)

# VAMP on clustered 3d data

This is the VAMP estimator applied to the binned 3d particles.

In [None]:
# Feed the binned (data, data_lagged) pair to the VAMP covariance estimator.
cov_est = sktime.decomposition.VAMP.covariance_estimator(lagtime=1)
binned_pair = (ds_3d_clusters.data, ds_3d_clusters.data_lagged)
cov_est.partial_fit(binned_pair)
cov = cov_est.fetch_model()

# Fit VAMP with dim=12 on the estimated covariances.
vamp_est = sktime.decomposition.VAMP(dim=12)
vamp_model_3d_cluster = vamp_est.fit(cov).fetch_model()

print("score:", vamp_model_3d_cluster.score())
# Transformed binned data; plotted below as the left singular functions.
lsf = vamp_model_3d_cluster.transform(ds_3d_clusters.data)

In [None]:
# Leading singular values of the VAMP model fitted on the 3d bins.
f, ax = plt.subplots()
ax.plot(vamp_model_3d_cluster.singular_values[:15], 'x')
ax.set_title('First 15 singular values')

# One panel per column of lsf: particles at their 2d position, colored by that column.
f, axes = plt.subplots(nrows=4, ncols=3, figsize=(15, 16))
for i, ax in enumerate(axes.ravel()):
    ax.scatter(ds_2d.data[:, 0], ds_2d.data[:, 1], c=lsf[:, i])
    ax.set_title(f'left singular function {i}')

In [None]:
# k-means with 14 clusters on the first 11 columns of lsf.
c_ref, l_ref = kmeans2(lsf[:, :11], 14)

# Particles at their 2d position, colored by cluster label.
f, ax = plt.subplots()
ax.set_title('Clustering in singular function space')
plt.scatter(ds_2d.data[:, 0], ds_2d.data[:, 1], c=l_ref)

# VAMPNets and VAMP on raw 3d data

In [None]:
# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs (more slowly) on machines without a GPU instead of aborting on an assert.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# cuDNN benchmark mode only matters when CUDA is used; setting it is harmless otherwise.
torch.backends.cudnn.benchmark = True
# Limit intra-op CPU parallelism to 12 threads.
torch.set_num_threads(12)

In [None]:
# Randomly hold out 1000 samples for validation; everything else is used for training.
n_val = 1000
train_data, val_data = torch.utils.data.random_split(ds_3d, [len(ds_3d) - n_val, n_val])

In [None]:
# Fully connected network with ELU activations and 15 outputs; the input width
# is read off the first element of the first dataset entry.
n_features_in = ds_3d[0][0].shape[0]
lobe = nn.Sequential(
    nn.Linear(n_features_in, 64),
    nn.ELU(),
    nn.Linear(64, 32),
    nn.ELU(),
    nn.Linear(32, 15),
)
lobe = lobe.to(device=device)

In [None]:
# Shuffled mini-batches of 512 for training; the whole validation split is one batch.
loader = torch.utils.data.DataLoader(
    train_data, batch_size=512, shuffle=True, num_workers=8,
)
loader_val = torch.utils.data.DataLoader(
    val_data, batch_size=len(val_data), shuffle=False, num_workers=8,
)
# Adam with learning rate 1e-3 over all parameters of the lobe.
opt = torch.optim.Adam(lobe.parameters(), lr=1e-3)

In [None]:
# Loss histories (training, validation); kept in their own cell so that
# re-running the training cell appends to them instead of resetting them.
losses, losses_val = [], []

In [None]:
mode = 'regularize'
epsilon = 1e-8
# Single source of truth for the epoch count (loop bound and progress message).
n_epochs = 1500


def _vamp2_loss(batch_0, batch_t):
    """Move one batch pair to ``device``, run both through ``lobe``, return the VAMP2 loss."""
    chi_0 = lobe(batch_0.to(device=device))
    chi_t = lobe(batch_t.to(device=device))
    return vnet.loss(chi_0, chi_t, method='VAMP2', epsilon=epsilon, mode=mode)


for epoch in range(n_epochs):
    # --- training pass ---
    lobe.train()
    lvals = []
    for batch_0, batch_t in loader:
        opt.zero_grad()
        loss = _vamp2_loss(batch_0, batch_t)
        loss.backward()
        opt.step()
        lvals.append(loss.detach().cpu().numpy())
    # one entry per epoch: mean loss over all training batches
    losses.append(np.mean(lvals))

    # --- validation pass, without gradient tracking ---
    lobe.eval()
    with torch.no_grad():
        for batch_0, batch_t in loader_val:
            # one entry per validation batch
            losses_val.append(_vamp2_loss(batch_0, batch_t).cpu().numpy())

    # '\r' keeps the progress report on a single, continuously overwritten line
    print(f"Epoch {epoch+1}/{n_epochs}: loss={losses[-1]:.3f}, validation loss={losses_val[-1]:.3f}",
          end='\r')

In [None]:
# Negated loss histories (labelled as score) on log-log axes.
for history, label in ((losses, 'K train'), (losses_val, 'K val')):
    plt.loglog(-np.array(history), label=label)
plt.xlabel('step')
plt.ylabel('score')
plt.legend();

In [None]:
def _featurize(part):
    """Run element ``part`` of ``ds_3d[...]`` through ``lobe`` and return a NumPy array."""
    inputs = torch.from_numpy(ds_3d[...][part]).to(device=device)
    return lobe(inputs).cpu().numpy()


# Inference only: evaluation mode and no gradient tracking.
lobe.eval()
with torch.no_grad():
    chi_X = _featurize(0)
    chi_Y = _featurize(1)

In [None]:
# Feed the network outputs (chi_X, chi_Y) to the VAMP covariance estimator.
cov_est = sktime.decomposition.VAMP.covariance_estimator(lagtime=1)
cov_est.partial_fit((chi_X, chi_Y))
cov = cov_est.fetch_model()

vampnet_model_3d = sktime.decomposition.VAMP(dim=12).fit(cov).fetch_model()

# Score of the VAMPNet-based model fitted just above (not `vamp_model_3d_cluster`).
print("score:", vampnet_model_3d.score())
# Transformed network features; plotted below as the left singular functions.
lsf = vampnet_model_3d.transform(chi_X)

In [None]:
# Leading singular values of the VAMP model fitted on the network features.
f, ax = plt.subplots()
ax.plot(vampnet_model_3d.singular_values[:15], 'x')
ax.set_title('First 15 singular values')

# One panel per column of lsf: particles at their 2d position, colored by that column.
f, axes = plt.subplots(nrows=4, ncols=3, figsize=(15, 16))
for i, ax in enumerate(axes.ravel()):
    ax.scatter(ds_2d.data[:, 0], ds_2d.data[:, 1], c=lsf[:, i])
    ax.set_title(f'left singular function {i}')

In [None]:
# k-means with 10 clusters on the first 12 columns of lsf.
c_ref, l_ref = kmeans2(lsf[:, :12], 10)

# Particles at their 2d position, colored by cluster label.
f, ax = plt.subplots()
ax.set_title('Clustering in singular function space')
plt.scatter(ds_2d.data[:, 0], ds_2d.data[:, 1], c=l_ref)

In [None]:
from IPython.display import HTML

# Cluster labels divided by the largest label, used as per-particle colors.
label_colors = l_ref.astype(np.float32) / float(np.max(l_ref))
animation = dataset.make_animation(c=label_colors, cmap='viridis')
# Keep this expression last so the notebook displays the embedded video.
HTML(animation.to_html5_video())

# KVAD on 3d clustered data

In [None]:
from sktime.decomposition.kvad import kvad

In [None]:
# KVAD on the binned 3d endpoints; the lagged 2d positions are passed as Y.
kvad_3d_cluster_model = kvad(
    ds_3d_clusters.data,
    ds_3d_clusters.data_lagged,
    Y=ds_2d.data_lagged,
    bandwidth=1.0,
)
# SVD of the model's K matrix, then express fX in the basis of left singular vectors.
u, s_kvad_3d_cluster, v = np.linalg.svd(kvad_3d_cluster_model.K)
lsf = kvad_3d_cluster_model.fX @ u

In [None]:
# Leading singular values of the KVAD model on the 3d bins.
f, ax = plt.subplots()
ax.plot(s_kvad_3d_cluster[:15], 'x')
ax.set_title('First 15 singular values')

# One panel per column of lsf: particles at their 2d position, colored by that column.
f, axes = plt.subplots(nrows=4, ncols=3, figsize=(15, 16))
for i, ax in enumerate(axes.ravel()):
    ax.scatter(ds_2d.data[:, 0], ds_2d.data[:, 1], c=lsf[:, i])
    ax.set_title(f'left singular function {i}')

In [None]:
# k-means with 10 clusters on the first 12 columns of lsf.
c_ref, l_ref = kmeans2(lsf[:, :12], 10)

# Particles at their 2d position, colored by cluster label.
f, ax = plt.subplots()
ax.set_title('Clustering in singular function space')
plt.scatter(ds_2d.data[:, 0], ds_2d.data[:, 1], c=l_ref)

# KVAD with random basis functions on 2d data

In [None]:
def nonlinearity(x):
    """Return ``exp(-x**2)``, evaluated elementwise by NumPy."""
    squared = x * x
    return np.exp(-squared)

n_basis = 500

# Random weights and offsets of the first layer (2 input coordinates -> n_basis).
W = np.random.normal(size=(2, n_basis))
b = np.random.uniform(-1, 1, size=(n_basis,))

# Random weights and offsets of the second, purely affine layer (n_basis -> n_basis).
W2 = np.random.normal(size=(n_basis, n_basis))
b2 = np.random.uniform(-1, 1, size=(n_basis,))


def _random_features(points):
    """Apply the fixed random two-layer map to ``points``."""
    hidden = nonlinearity(points @ W + b)
    return hidden @ W2 + b2


# Same random map for the initial and the lagged positions.
chi_X = _random_features(ds_2d.data)
chi_Y = _random_features(ds_2d.data_lagged)

In [None]:
# KVAD on the random-feature representation; the lagged 2d positions are passed as Y.
kvad_2d_rnd_model = kvad(
    chi_X,
    chi_Y,
    Y=ds_2d.data_lagged,
    bandwidth=1.3,
)
# SVD of the model's K matrix, then express fX in the basis of left singular vectors.
u, s_kvad_2d_rnd, v = np.linalg.svd(kvad_2d_rnd_model.K)
lsf = kvad_2d_rnd_model.fX @ u

In [None]:
# Leading singular values of the KVAD model on the random features.
f, ax = plt.subplots()
ax.plot(s_kvad_2d_rnd[:15], 'x')
ax.set_title('First 15 singular values')

# One panel per column of lsf: particles at their 2d position, colored by that column.
f, axes = plt.subplots(nrows=4, ncols=3, figsize=(15, 16))
for i, ax in enumerate(axes.ravel()):
    ax.scatter(ds_2d.data[:, 0], ds_2d.data[:, 1], c=lsf[:, i])
    ax.set_title(f'left singular function {i}')

In [None]:
# k-means with 10 clusters on the first 12 columns of lsf.
c_ref, l_ref = kmeans2(lsf[:, :12], 10)

# Particles at their 2d position, colored by cluster label.
f, ax = plt.subplots()
ax.set_title('Clustering in singular function space')
plt.scatter(ds_2d.data[:, 0], ds_2d.data[:, 1], c=l_ref)

In [None]:
from IPython.display import HTML

# Cluster labels divided by the largest label, used as per-particle colors.
label_colors = l_ref.astype(np.float32) / float(np.max(l_ref))
animation = dataset.make_animation(c=label_colors, cmap='viridis')
# Keep this expression last so the notebook displays the embedded video.
HTML(animation.to_html5_video())

# VAMP on 2d dataset with random basis functions

In [None]:
# Feed the random-feature pair (chi_X, chi_Y) to the VAMP covariance estimator.
cov_est = sktime.decomposition.VAMP.covariance_estimator(lagtime=1)
feature_pair = (chi_X, chi_Y)
cov_est.partial_fit(feature_pair)
cov = cov_est.fetch_model()

# VAMP with its default constructor arguments (no explicit dim here).
vamp_2d_rnd_model = sktime.decomposition.VAMP().fit(cov).fetch_model()

s = vamp_2d_rnd_model.singular_values
print("score:", vamp_2d_rnd_model.score())
# Transformed random features; plotted below as the left singular functions.
lsf = vamp_2d_rnd_model.transform(chi_X)

In [None]:
# Leading singular values of the VAMP model on the random features.
f, ax = plt.subplots()
ax.plot(vamp_2d_rnd_model.singular_values[:15], 'x')
ax.set_title('First 15 singular values')

# One panel per column of lsf: particles at their 2d position, colored by that column.
f, axes = plt.subplots(nrows=4, ncols=3, figsize=(15, 16))
for i, ax in enumerate(axes.ravel()):
    ax.scatter(ds_2d.data[:, 0], ds_2d.data[:, 1], c=lsf[:, i])
    ax.set_title(f'left singular function {i}')

In [None]:
# k-means with 10 clusters on the first 12 columns of lsf.
c_ref, l_ref = kmeans2(lsf[:, :12], 10)

# Particles at their 2d position, colored by cluster label.
f, ax = plt.subplots()
ax.set_title('Clustering in singular function space')
plt.scatter(ds_2d.data[:, 0], ds_2d.data[:, 1], c=l_ref)

# Comparison of estimated singular values

In [None]:
# (singular values, legend label) for every estimator, in plotting order.
spectra = (
    (vamp_2d_rnd_model.singular_values, 'VAMP with random transformations on 2D'),
    (s_kvad_2d_rnd, 'KVAD with random transformations on 2D'),
    (s_kvad_3d_cluster, 'KVAD on 3D bins'),
    (vampnet_model_3d.singular_values, 'VAMPNet + VAMP on 3D data'),
    (vamp_model_3d_cluster.singular_values, 'VAMP on 3D bins'),
)

# Overlay the first 15 singular values of each estimator in one figure.
f, ax = plt.subplots(figsize=(14, 14))
ax.set_title('First 15 singular values')
for values, label in spectra:
    ax.plot(values[:15], marker='o', linestyle='dashed', label=label)
ax.legend();