In [None]:
import umap
import numpy as np
import scipy as sci
import matplotlib.pyplot as plt
import scipy.spatial.distance as dist
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd
from scipy.interpolate import interp1d

In [None]:
# Number of simulated observations drawn for each of the two clusters in every slice.
n_obs_per_cluster = 100

# Seed NumPy's global RNG: create_clustered_data draws with np.random.binomial,
# so seeding here makes the simulated slices reproducible on Restart & Run All.
random_seed = 42
np.random.seed(random_seed)

In [None]:
# Utility functions: simulate clustered binary symptom data and compute padded plot limits.

def create_clustered_data(p1, p2, n_obs_per_cluster, n_features=8):
    """Simulate binary data for two clusters of observations.

    Every entry is an independent Bernoulli draw (``np.random.binomial`` with
    ``n=1``), so the result contains only 0s and 1s.

    Parameters
    ----------
    p1, p2 : float or array-like broadcastable to (n_features,)
        Per-column success probabilities for cluster 1 and cluster 2. An
        array is broadcast across all rows of its cluster.
    n_obs_per_cluster : int
        Number of rows drawn for each cluster.
    n_features : int, default 8
        Number of binary columns per observation. Must be compatible with
        the length of ``p1``/``p2`` when those are arrays.

    Returns
    -------
    numpy.ndarray of shape (2 * n_obs_per_cluster, n_features)
        The cluster-1 rows followed by the cluster-2 rows.
    """
    shape = (n_obs_per_cluster, n_features)
    # Draw cluster 1 before cluster 2 so the global RNG stream is consumed in
    # the same order as before.
    cluster_1 = np.random.binomial(n=1, p=p1, size=shape)
    cluster_2 = np.random.binomial(n=1, p=p2, size=shape)
    return np.concatenate((cluster_1, cluster_2), axis=0)

def axis_bounds(embedding):
    """Return ``[xmin, xmax, ymin, ymax]`` limits padded by 10% of each range.

    Column 0 of ``embedding`` is treated as x and column 1 as y. The list is
    in the order accepted by ``matplotlib.axes.Axes.axis``.
    """
    columns = embedding.T
    x_lo, x_hi = columns[0].min(), columns[0].max()
    y_lo, y_hi = columns[1].min(), columns[1].max()

    # Pad each side by a tenth of the data span so points are not drawn on the frame.
    x_pad = (x_hi - x_lo) * 0.1
    y_pad = (y_hi - y_lo) * 0.1

    return [x_lo - x_pad, x_hi + x_pad, y_lo - y_pad, y_hi + y_pad]

In [None]:
n_slices = 10

# Simulate the same two clusters at n_slices time points. For cluster 1, the
# probabilities of symptoms E-H (columns 4-7) rise linearly with the slice index,
# by 0.5 * slice_idx / n_slices. Cluster 2's probabilities stay fixed.
cluster1_base_p = np.array([.4, .3, .4, .35, .1, .2, .05, .1])
cluster1_drift = np.array([0, 0, 0, 0, .5, .5, .5, .5])
cluster2_p = np.array([.1, .05, .1, .15, .5, .6, .45, .35])

slice_list = []
for slice_idx in range(n_slices):
    slice_list.append(
        create_clustered_data(
            p1=cluster1_base_p + cluster1_drift * slice_idx / n_slices,
            p2=cluster2_p,
            n_obs_per_cluster=n_obs_per_cluster,
        )
    )

In [None]:
# Pairwise Jaccard distances between the symptom columns of each slice.
# Transposing makes every symptom a row, so each result is a square
# (n_symptoms x n_symptoms) matrix with zeros on the diagonal.
distance_matrix_list = []
for symptom_data in slice_list:
    condensed = dist.pdist(symptom_data.T, metric='jaccard')
    distance_matrix_list.append(dist.squareform(condensed))

In [None]:
# AlignedUMAP's `relations` is a list with one dict per consecutive pair of
# slices. Each dict maps an item's index in slice t to its index in slice t+1.
# The embedded items (symptoms) are the same in every slice, so each mapping
# is the identity over every symptom, including the last one.
n_symptoms = distance_matrix_list[0].shape[0]
relationship_dict = {i: i for i in range(n_symptoms)}
relationships = [relationship_dict.copy() for _ in range(len(distance_matrix_list) - 1)]

In [None]:
%%time
aligned_mapper = umap.AlignedUMAP(
    n_neighbors=3,
    min_dist = 0.1,
    n_components = 2,
    metric='euclidean')

aligned_mapper.fit(distance_matrix_list, relations = relationships)

In [None]:
# One panel per slice. All panels share the same padded axis limits, so
# movement of the symptoms between slices is directly comparable. Colour
# encodes the symptom group: A-D (value 1) vs E-H (value 2).
embeddings = aligned_mapper.embeddings_
n_panels = len(embeddings)
symptom_groups = [1, 1, 1, 1, 2, 2, 2, 2]

# squeeze=False keeps `axs` a 2-D array even when there is a single slice.
fig, axs = plt.subplots(n_panels, 1, figsize=(5, 1.5 * n_panels), squeeze=False)
ax_bound = axis_bounds(np.vstack(embeddings))
for i, ax in enumerate(axs.flatten()):
    ax.scatter(*embeddings[i].T, c=symptom_groups, cmap='Spectral')
    ax.axis(ax_bound)
    ax.set_ylabel(f"slice {i}")
plt.tight_layout()
plt.show()

In [None]:
# Refit with alignment_window_size=4 and alignment_regularisation=1e-2 passed
# explicitly. Then draw each symptom's path through the slices as a smooth 3-D
# curve: x and y come from the embedding, z from the slice order.
aligned_mapper = umap.AlignedUMAP(
    n_neighbors=3,
    min_dist=0.1,
    n_components=2,
    alignment_window_size=4,
    alignment_regularisation=1e-2,
    metric='precomputed')  # inputs are Jaccard distance matrices, not feature vectors

aligned_mapper.fit(distance_matrix_list, relations=relationships)

es = aligned_mapper.embeddings_
n_embeddings = len(es)      # number of slices
n_points = es[0].shape[0]   # number of items (symptoms) embedded per slice

# Long-format table with one row per (slice, symptom). `z` is the slice
# position rescaled to [0, 1]; `id` is the symptom index within its slice.
slice_positions = np.linspace(0, 1.0, n_embeddings)
embedding_df = pd.DataFrame(np.vstack(es), columns=('x', 'y'))
embedding_df['z'] = np.repeat(slice_positions, n_points)
embedding_df['id'] = np.tile(np.arange(n_points), n_embeddings)

# Reshaping gives an (n_embeddings, n_points) grid. `.T` makes each row one
# symptom's coordinate along the slices, which interp1d then interpolates.
# Cubic interpolation needs at least 4 knots, so fall back to linear otherwise.
interp_kind = "cubic" if n_embeddings >= 4 else "linear"
fx = interp1d(
    slice_positions, embedding_df.x.values.reshape(n_embeddings, n_points).T, kind=interp_kind
)
fy = interp1d(
    slice_positions, embedding_df.y.values.reshape(n_embeddings, n_points).T, kind=interp_kind
)
z = np.linspace(0, 1.0, 100)

# Colour trajectories by symptom group: A-D (indices 0-3) take one end of the
# Spectral palette and E-H the other.
palette = px.colors.diverging.Spectral
symptom_colors = [palette[0] if i < 4 else palette[-1] for i in range(n_points)]
interpolated_traces = [fx(z), fy(z)]
traces = [
    go.Scatter3d(
        x=interpolated_traces[0][i],
        y=interpolated_traces[1][i],
        z=z * 3.0,  # stretch the slice axis so trajectories are easier to follow
        mode="lines",
        line=dict(color=symptom_colors[i], width=3.0),
        opacity=1.0,
    )
    for i in range(n_points)
]
fig = go.Figure(data=traces)
fig.update_layout(
    width=800,
    height=700,
    autosize=False,
    showlegend=False,
)
fig.show()