In [None]:
import numpy as np
from baselines.scripts.variable_density_swiss_roll import non_uniform_swiss
# Random seeds: seeds[0] seeds the data generation and the first t-SNE run,
# seeds[1] seeds the second t-SNE run.
seeds = [20251106, 20251108]

In [None]:
# Passed to non_uniform_swiss as its 4th positional argument; presumably the
# number of points to sample — TODO confirm against the generator's signature.
N = 1500
# Passed through as `K`; presumably the number of density components/regions
# of the variable-density roll — TODO confirm in variable_density_swiss_roll.
K = 2
# Seed NumPy's global RNG. NOTE(review): assumes non_uniform_swiss draws from
# the global RNG (it takes no seed/rng argument here) — confirm.
np.random.seed(seeds[0])
# X: sampled points (3 columns given dim=3 and the X[:, 2] access in the plot
# below); t: per-point value used to color every plot in this notebook.
# NOTE(review): the positional arguments (2, 1, .2) and pi0=0.25 are generator
# parameters whose meaning is defined in non_uniform_swiss — name or document them.
X, t = non_uniform_swiss(2, 1, .2, N, pi0=0.25, K=K, dim=3)

## Data Generation

The 3D scatter plot below shows the generated Swiss roll, with points colored by $t$.

In [None]:
import matplotlib.pyplot as plt

# 3D view of the raw data, colored by t. Use the explicit Figure/Axes API and
# label the axes so the figure stands on its own when the notebook is skimmed.
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection="3d")
sc = ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=t, s=20, alpha=0.7, cmap="viridis")
ax.set_xlabel("$x_1$")
ax.set_ylabel("$x_2$")
ax.set_zlabel("$x_3$")
ax.set_title("Swiss roll data")
fig.colorbar(sc, ax=ax, label="t")
plt.show()

In [None]:
import altair as alt
import pandas as pd

def plot_swiss_emb(X_emb, t):
    """Scatter a 2D embedding as an Altair chart colored by ``t``.

    Parameters
    ----------
    X_emb : array-like with exactly two columns
        Embedding coordinates; the columns become the ``x`` and ``y`` encodings.
    t : array-like
        Per-point values used for the color encoding.

    Returns
    -------
    alt.Chart
        A 400x300 chart of circle marks (size 60).
    """
    points = pd.DataFrame(X_emb, columns=["x", "y"]).assign(t=t)
    encoding = {"x": "x", "y": "y", "color": "t"}
    chart = alt.Chart(points).mark_circle(size=60).encode(**encoding)
    return chart.properties(width=400, height=300)

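## $t$-SNE Distortions

We fit $t$-SNE twice on the same data, changing only the random seed, and then compare the two runs and their local distortions.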

In [None]:
from sklearn.manifold import TSNE

# Fit t-SNE twice on the same data, differing only in the random seed, so the
# two runs can be compared for run-to-run variability. A single loop replaces
# the two copy-pasted fits, so the settings cannot drift apart.
tsne_kwargs = dict(n_components=2, perplexity=50, learning_rate="auto", init="random")
Z1, Z2 = (
    TSNE(random_state=seed, **tsne_kwargs).fit_transform(X)
    for seed in seeds[:2]
)

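Center both embeddings and apply orthogonal Procrustes to rotate (or reflect) the second onto the first, so the two runs can be compared point by point.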

In [None]:
from scipy.linalg import orthogonal_procrustes

# Center both embeddings, then find the orthogonal matrix that best maps Z2
# onto Z1 in the least-squares sense and apply it to Z2.
# The second return value of orthogonal_procrustes (sum of the singular values
# of Z2.T @ Z1) is not used. The rotation gets its own name so it is not
# confused with the distance-ratio matrix computed later in the notebook.
Z1 = Z1 - Z1.mean(axis=0)
Z2 = Z2 - Z2.mean(axis=0)
R_align, _ = orthogonal_procrustes(Z2, Z1)
Z2 = Z2 @ R_align

In [None]:
# The two aligned t-SNE runs, colored by t.
plots = [plot_swiss_emb(Z, t) for Z in (Z1, Z2)]
# display() accepts several objects. Unlike a list comprehension run only for
# its side effect, this does not also echo a list of Nones as the cell output.
display(*plots)

In [None]:
from distortions.geometry import Geometry, bind_metric, local_distortions, neighborhoods
from distortions.visualization import dplot
from anndata import AnnData
from sklearn.neighbors import NearestNeighbors

def distortion_plot(Z, X, t, n_neighbors=40, geom_radius=1, threshold=0.1, outlier_factor=3,
                    link_threshold=10, radius_max=25, height=400, width=600):
    """Build an interactive distortion plot for an embedding ``Z`` of data ``X``.

    Parameters
    ----------
    Z : array-like
        Embedding coordinates. They are passed to ``local_distortions`` and
        ``bind_metric``, and stored as ``adata.obsm["X_tsne"]``.
    X : array-like
        Original data. It is used for the ``Geometry`` distortions and for the
        kNN distance graph.
    t : array-like
        Per-point values, added to the embedding frame as column ``"t"`` and
        used for color.
    n_neighbors : int
        Number of neighbors, used both for the ``Geometry`` adjacency and for
        the kNN distance graph.
    geom_radius : float
        ``radius`` passed to the ``Geometry`` affinity.
    threshold, outlier_factor : float
        Passed to ``neighborhoods``.
    link_threshold : float
        ``threshold`` passed to ``inter_edge_link``. This is distinct from
        ``threshold`` above; it was previously hard-coded to 10.
    radius_max : float
        ``radiusMax`` passed to ``geom_ellipse`` (previously hard-coded to 25).
    height, width : int
        Plot size, passed to ``dplot`` (previously hard-coded to 400 x 600).

    Returns
    -------
    tuple
        ``(plot, H, embedding)``:
        - ``plot`` is the ``dplot`` chart.
        - ``H`` is the first output of ``local_distortions``.
        - ``embedding`` is the frame from ``bind_metric``, with the ``"t"``
          column added.
    """
    geom = Geometry(affinity_kwds={"radius": geom_radius}, adjacency_kwds={"n_neighbors": n_neighbors})
    H, Hvv, Hs = local_distortions(Z, X, geom)
    embedding = bind_metric(Z, Hvv, Hs)
    embedding["t"] = t

    # kNN distance graph over X. Because X is passed explicitly as the query
    # set, each point is returned as one of its own neighbors (distance 0).
    # NOTE(review): confirm `neighborhoods` expects these self-entries;
    # scanpy-style "distances" graphs usually exclude them.
    adata = AnnData(X=X)
    knn = NearestNeighbors(n_neighbors=n_neighbors, metric="euclidean").fit(X)
    adata.obsp["distances"] = knn.kneighbors_graph(X, mode="distance")
    adata.obsm["X_tsne"] = Z

    # Named `neighbors` rather than `N` so it does not shadow the
    # notebook-level sample count.
    neighbors = neighborhoods(adata, threshold=threshold, outlier_factor=outlier_factor, embed_key="X_tsne")

    plot = (
        dplot(embedding, height=height, width=width)
        .mapping(x="embedding_0", y="embedding_1", color="t")
        .inter_edge_link(N=neighbors, strokeWidth=.2, opacity=0.9, threshold=link_threshold,
                         stroke="#F25E7A", highlightColor="#C83F58", backgroundOpacity=0.6)
        .geom_ellipse(radiusMin=1, radiusMax=radius_max)
    )
    return plot, H, embedding

In [None]:
# One (plot, H, embedding) triple per aligned t-SNE run, in seed order.
distortion_data = [distortion_plot(Z, X, t) for Z in (Z1, Z2)]

# Show only the plots. A single display() call avoids echoing a list of Nones.
display(*(plot for plot, _, _ in distortion_data))

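## Neighbor Distance Preservation

For each point, we take its nearest neighbors in the original space and compute the ratio of its distances to them in the first embedding versus the second. If both runs distort the neighborhood in the same way, this ratio is nearly constant, so its variance measures how inconsistently that neighborhood is preserved. We then compare this score with how much the local metric $H$ changes between runs.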

In [None]:
from scipy.spatial.distance import cdist
from sklearn.neighbors import NearestNeighbors

# For each point i, take its nearest neighbors j in the original space X and
# form the ratio d_ij = ||Z1_i - Z1_j|| / ||Z2_i - Z2_j|| of its distances in
# the two aligned embeddings. The variance of d_ij over j measures how
# inconsistently the two runs distort i's neighborhood.
# The ratio and its inverse are both scored and the larger variance is kept,
# so the result does not depend on which run is the reference.
# Only neighbor pairs are computed. This is O(n*k) instead of four dense n x n
# matrices, and it avoids the 0/0 on the diagonal.
n_neighbors = 15  # neighborhood size counting the point itself (14 others)

# kneighbors() called without X excludes each point from its own neighbor
# list. This stays correct even with duplicate points, where "first
# neighbor == self" can fail. So request n_neighbors - 1 true neighbors.
nn = NearestNeighbors(n_neighbors=n_neighbors, metric='euclidean').fit(X)
knn_indices = nn.kneighbors(n_neighbors=n_neighbors - 1, return_distance=False)  # (n, k)

# Distances from each point to its X-neighbors in each embedding: shape (n, k).
dist1 = np.linalg.norm(Z1[:, None, :] - Z1[knn_indices], axis=-1)
dist2 = np.linalg.norm(Z2[:, None, :] - Z2[knn_indices], axis=-1)
ratio = dist1 / dist2
ratio_inv = dist2 / dist1

# Per-point population variance (ddof=0, as np.var) over the neighbors.
V = ratio.var(axis=1)
V_inv = ratio_inv.var(axis=1)
V_max = np.maximum(V, V_inv)

In [None]:
from scipy.linalg import fractional_matrix_power, logm

# Per-point metric matrices from the two runs: element [1] of each
# distortion_plot result, i.e. the first output `H` of local_distortions.
# NOTE(review): the `Hs` names suggest the third output `Hs` may have been
# intended — confirm which metric is meant. Assumed stacked as (n, 2, 2).
Hs1 = distortion_data[0][1]
Hs2 = distortion_data[1][1]
norm = 'fro'
H_instability = np.linalg.norm(Hs1 - Hs2, ord=norm, axis=(1, 2))
Hs1_norm = np.linalg.norm(Hs1, ord=norm, axis=(1, 2))
Hs2_norm = np.linalg.norm(Hs2, ord=norm, axis=(1, 2))

# NOTE(review): ||H - H'|| / (||H|| * ||H'||) is not scale-invariant. If both
# metrics are multiplied by s, the value scales like 1/s. A relative
# difference such as ||H - H'|| / sqrt(||H|| * ||H'||) may have been intended.
# It is kept as-is here.
stability_data = pd.DataFrame({
    "v_d": V_max,
    "n_H": H_instability / (Hs1_norm * Hs2_norm),
    "t": t,
})

# n_H_det = |log det H - log det H'| = |log det(H^{-1} H')|, for all points at
# once: np.linalg.det operates on the stacked leading axis. As with a
# per-point loop, non-positive determinants give -inf/NaN.
stability_data["n_H_det"] = np.abs(np.log(np.linalg.det(Hs1)) - np.log(np.linalg.det(Hs2)))

# n_H_sim = ||logm(H^{-1/2} H' H^{-1/2})||_F. For SPD H and H', this is the
# affine-invariant distance between the two metrics.
n = Hs1.shape[0]
n_H_sim = np.empty(n)
for i in range(n):
    H1_inv_sqrt = fractional_matrix_power(Hs1[i], -0.5)
    sim = H1_inv_sqrt @ Hs2[i] @ H1_inv_sqrt
    n_H_sim[i] = np.linalg.norm(logm(sim), ord='fro')
stability_data["n_H_sim"] = n_H_sim

In [None]:
# Neighbor-distance instability (x) vs. the change in log-determinant of the
# local metric between the two runs (y), colored by t.
# Raw strings keep "\log" from being parsed as the invalid escape "\l". The
# title includes the absolute value that n_H_det actually contains.
# NOTE(review): Vega-Lite renders titles as plain text; $...$ is not typeset.
alt.Chart(stability_data).mark_circle(size=60).encode(
    x=alt.X('v_d', title='Variance of $d_i$ among neighbors'),
    y=alt.Y('n_H_det', title=r"$|\log|H| - \log|H'||$"),
    color=alt.Color('t', scale=alt.Scale(range=['#3b75af', '#ef8636'], interpolate='lab')),
).properties(
    width=400, height=300,
    title='Neighbor Preservation vs. Distortion Stability'
)

In [None]:
# Neighbor-distance instability (x) vs. the affine-invariant change in the
# local metric (y), both on log scales, colored by t.
# The y title matches the computed quantity, which uses H^{-1/2} on both sides.
# The raw string avoids the invalid "\l" escape.
alt.Chart(stability_data).mark_circle(size=20).encode(
    x=alt.X('v_d', title='Variance of $d_i$ among neighbors', scale=alt.Scale(type='log')),
    y=alt.Y('n_H_sim', title=r"$||\log H^{-1/2} H' H^{-1/2}||_{F}$", scale=alt.Scale(type='log')),
    color=alt.Color('t', scale=alt.Scale(range=['#3b75af', '#ef8636'], interpolate='lab')),
).properties(
    width=400, height=300,
    title='Neighbor Preservation vs. Distortion Stability'
)

In [None]:
# Tag each run's embedding frame for the combined plots below. `.assign`
# returns copies, so the frames held in `distortion_data` (and the plots built
# from them) are not mutated, and re-running this cell is idempotent.
groups = ["seed1", "seed2"]
sample_ids = distortion_data[0][2].index  # run 0's index is reused for both runs
log_v = np.log(V_max)  # NOTE(review): a zero variance would give -inf
embedding_list = [
    emb.assign(sample=sample_ids, group=group, V=log_v)
    for (_, _, emb), group in zip(distortion_data, groups)
]

In [None]:
# Overlay both runs in one frame. plot_var is colored by the log
# neighbor-distance variance and plot_group by run. The embeddings are t-SNE,
# so the axes are labelled as such (they were previously mislabelled "UMAP").
combined_embedding = pd.concat(embedding_list)


def _overlay(color, **scale_kwargs):
    """Ellipse overlay of `combined_embedding`, colored by column `color`."""
    return dplot(combined_embedding, height=350, width=450)\
        .mapping(x="embedding_0", y="embedding_1", color=color)\
        .geom_ellipse(opacity=0.9, radiusMin=1, radiusMax=20, stroke=True)\
        .scale_color(stroke=True, **scale_kwargs)\
        .labs(x="t-SNE1", y="t-SNE2")


plot_var = _overlay("V")
plot_group = _overlay("group", scheme=["green", "purple"])

In [None]:
# Both runs overlaid, colored by V (log of the neighbor-distance variance V_max).
plot_var

In [None]:
# Both runs overlaid, colored by run ("seed1" / "seed2").
plot_group

In [None]:
import os
from pathlib import Path

# The export directory is configurable via the FIG_DIR environment variable,
# rather than a hard-coded absolute path into one user's Downloads folder.
fig_dir = Path(os.environ.get("FIG_DIR", "figures"))
fig_dir.mkdir(parents=True, exist_ok=True)
# plot_var can be exported the same way, e.g. to fig_dir / "v3.svg".
plot_group.save(str(fig_dir / "v4.svg"))