(Visit the
[documentation](https://datafold-dev.gitlab.io/datafold/tutorial_index.html) page
to view the executed notebook.)

# Jointly Smooth Functions: An Example

For a detailed introduction, see the paper:

Or Yair, Felix Dietrich, Rotem Mulayoff, Ronen Talmon, Ioannis G. Kevrekidis, Spectral Discovery of Jointly Smooth Features for Multimodal Data, ArXiv, 2020, Available at: https://arxiv.org/abs/2004.04386

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import scipy
import scipy.sparse.linalg

# NOTE: make sure "path/to/datafold" is in sys.path or PYTHONPATH if datafold is not installed
import datafold.pcfold as pfold
from datafold.dynfold.jsf import JointlySmoothFunctions, ColumnSplitter, JsfDataset
from datafold.utils.plot import plot_parameters_and_observations

from scipy.sparse import csr_matrix, SparseEfficiencyWarning
import warnings
warnings.simplefilter('ignore', SparseEfficiencyWarning)

## Generate Parameters and Observations

We generate data for the parameters, the observations, and the effective parameters with the three functions below.

In [None]:
def generate_parameters(_x, _y):
    """Stack the two parameter arrays side by side as columns.

    Parameters
    ----------
    _x, _y
        Arrays handed to ``np.column_stack``; 1-D inputs each become one
        column, 2-D inputs are joined along the column axis.

    Returns
    -------
    numpy.ndarray
        2-D array holding ``_x`` first and ``_y`` second.
    """
    columns = (_x, _y)
    return np.column_stack(columns)


def generate_observations(_x, _z, div=5, mult=6):
    """Map ``_x`` and ``_z`` to two observation columns.

    Each sample is placed at polar angle ``mult * pi * _z`` with radius
    ``(div / 2 * _z + _x / 2 + 2 / 3) / 2``; the two returned columns are the
    cosine and sine components of that point.

    Parameters
    ----------
    _x, _z
        Arrays combined elementwise and then passed to ``np.column_stack``.
    div
        Factor (halved) applied to ``_z`` inside the radius term.
    mult
        Multiplier of ``pi * _z`` inside the angle term.

    Returns
    -------
    numpy.ndarray
        2-D array with the cosine component first and the sine component
        second.
    """
    radius = div / 2 * _z + _x / 2 + 2 / 3
    angle = mult * np.pi * _z

    cos_component = radius * np.cos(angle) / 2
    sin_component = radius * np.sin(angle) / 2
    return np.column_stack([cos_component, sin_component])


def generate_points(n_samples, seed=42):
    """Sample random parameters and the observations derived from them.

    Three values per sample are drawn uniformly from ``[-0.5, 0.5)``. The
    first two form the parameters; the third is a latent variable that only
    enters the observations.

    Parameters
    ----------
    n_samples
        Number of samples to draw.
    seed
        Seed handed to ``np.random.default_rng``. The default of 42 keeps the
        original, reproducible behaviour. Because the generator is re-created
        on every call, two calls with the same seed return the same leading
        samples; pass a different seed to obtain independent data.

    Returns
    -------
    parameters : numpy.ndarray
        Array of shape ``(n_samples, 2)`` built by ``generate_parameters``.
    observations : numpy.ndarray
        Array of shape ``(n_samples, 2)`` built by ``generate_observations``
        with ``div=2`` and ``mult=2``.
    effective_parameter : numpy.ndarray
        1-D array ``x + y**2`` of length ``n_samples``.
    """
    rng = np.random.default_rng(seed)
    xyz = rng.uniform(low=-0.5, high=0.5, size=(n_samples, 3))
    # One 1-D array per coordinate; column_stack turns 1-D inputs into columns.
    x, y, z = xyz.T

    parameters = generate_parameters(x, y)
    effective_parameter = x + y**2
    observations = generate_observations(effective_parameter, z, 2, 2)

    return parameters, observations, effective_parameter

In [None]:
# Draw the training set and plot it.
n_samples = 6000
training_points = generate_points(n_samples)
parameters, observations, effective_parameter = training_points

plot_parameters_and_observations(
    parameters,
    observations,
    effective_parameter,
)

## The JsfDataset Class

`JsfDataset` handles the slicing of the data. This is needed because `.fit`, `.transform`, and `.fit_transform` each accept only a single data array `X`. The multimodal data is therefore passed in as one array and split into its separate datasets inside these methods; `JsfDataset.fit_transform` provides this splitting functionality. The constructor of `JsfDataset` expects:
- a name
- a slice or list (the columns of `X` corresponding to this dataset)
- an optional kernel (default: `GaussianKernel`)
- optional `dist_kwargs` for the `PCManifold` created in `.fit_transform`

In [None]:
# Continuous-nearest-neighbour kernel settings shared by both datasets.
cknn_delta = 1
cknn_k_neighbor = 50

# One separately constructed kernel instance per dataset.
kernel1, kernel2 = (
    pfold.kernels.ContinuousNNKernel(k_neighbor=cknn_k_neighbor, delta=cknn_delta)
    for _ in range(2)
)

# Single array handed to the model: parameter columns first, then observations.
X = np.column_stack((parameters, observations))

# Extra keyword arguments given to both datasets (presumably the dist_kwargs
# used for the PCManifold -- confirm against the JsfDataset documentation).
dist_kwargs = {'backend': 'scipy.kdtree', 'cut_off': 1e-8}

datasets = [
    JsfDataset('parameters', slice(0, 2), kernel1, **dist_kwargs),
    JsfDataset('observations', slice(2, 4), kernel2, **dist_kwargs),
]

## Fit JointlySmoothFunctions model

In [None]:
# Model configuration, collected in one place before constructing the estimator.
jsf_settings = dict(
    datasets=datasets,
    n_kernel_eigenvectors=200,
    n_jointly_smooth_functions=10,
    kernel_eigenvalue_cut_off=1e-8,
    eigenvector_tolerance=1e-10,
)
jsf = JointlySmoothFunctions(**jsf_settings)

In [None]:
# Fit the model on the stacked array X (parameter and observation columns).
jsf.fit(X)

# E0 is drawn as a horizontal reference line over the eigenvalues in the plot
# below; presumably it is a threshold on the eigenvalues -- see the paper and
# the datafold documentation to confirm its exact meaning.
E0 = jsf.calculate_E0()

In [None]:
# Eigenvalues of the fitted model, with the E0 level overlaid in red.
plt.plot(jsf.eigenvalues, marker='.', linestyle='-')
plt.axhline(y=E0, color='r')
plt.show()

In [None]:
# Shuffle the samples so the order in which points are drawn is random.
idx_plot = np.random.permutation(n_samples)

n_plots = 8
fig, ax = plt.subplots(1, n_plots, figsize=(n_plots * 3, 3), sharey=True)

# The x-values and colours are the same in every panel; index them once.
shuffled_effective = effective_parameter[idx_plot]
shuffled_colors = parameters[idx_plot, 0]

# One panel per jointly smooth function, plotted against the effective parameter.
for k, panel in enumerate(ax):
    panel.scatter(
        shuffled_effective,
        jsf.jointly_smooth_functions[idx_plot, k],
        c=shuffled_colors,
        s=5,
        cmap=plt.cm.Spectral,
    )

In [None]:
# Re-draw the data, now using the jointly smooth function in column 1 in place
# of the true effective parameter.
jsf_column = jsf.jointly_smooth_functions[:, 1]

plot_parameters_and_observations(
    parameters=parameters,
    observations=observations,
    effective_parameter=jsf_column,
)

In [None]:
n_new_samples = 3000

# By default generate_points seeds its generator identically on every call, so
# generate_points(n_new_samples) would simply reproduce the first
# n_new_samples training points. Draw a longer sequence instead and keep only
# the samples after the n_samples training points, so that the points below
# are genuinely out-of-sample.
all_parameters, all_observations, all_effective_parameter = generate_points(
    n_samples + n_new_samples
)
new_parameters = all_parameters[n_samples:]
new_observations = all_observations[n_samples:]
new_effective_parameter = all_effective_parameter[n_samples:]

# Same column layout as the training array X: parameters first, then observations.
new_X = np.column_stack([new_parameters, new_observations])

# Evaluate the fitted model on the unseen points.
oos_jsfs = jsf.transform(new_X)

plot_parameters_and_observations(
    parameters=new_parameters,
    observations=new_observations,
    effective_parameter=oos_jsfs[:, 1],
)