In [None]:
import sys; sys.path.append(".")
from tsne import *
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

In [None]:
def make_dataset_simplex(
    n_clusters: int,
    N_per: int = 32,
    D: int = 20,
    cluster_std: float = 0.05,
    min_sep: float = 10.0,
    device: str = "cuda",
    seed: int = 0,
) -> torch.Tensor:
    """
    Tight Gaussian clusters whose means sit on vertices of a regular simplex.

    Geometry:
      - A regular simplex with K vertices lives in R^{K-1}.
      - We embed it into R^D (requires D >= K-1).
      - We scale it so pairwise distances between means are exactly min_sep.

    Args:
      n_clusters  : number of clusters K; must be >= 2.
      N_per       : number of points sampled per cluster.
      D           : ambient dimension of the returned points; must be >= K-1.
      cluster_std : scale of the standard-normal noise added to every mean.
      min_sep     : Euclidean distance between every pair of cluster means.
      device      : torch device used for the generator and all tensors.
      seed        : seed of the generator that drives both the random
                    orientation of the simplex and the cluster noise.

    Returns:
      X : (K * N_per, D) on `device`. Rows are grouped by cluster: rows
          [i * N_per, (i + 1) * N_per) are the samples of cluster i.

    Raises:
      ValueError : if n_clusters < 2 or D < n_clusters - 1.
    """
    # Local import: `math` is not imported explicitly in the visible notebook
    # cells, so do not depend on it leaking in through a star-import.
    import math

    if n_clusters < 2:
        raise ValueError("n_clusters must be >= 2.")
    if D < n_clusters - 1:
        raise ValueError(f"Need D >= n_clusters-1 to embed a K-simplex. Got D={D}, K={n_clusters}.")

    g = torch.Generator(device=device).manual_seed(seed)

    K = n_clusters

    # --- Build a regular simplex in R^{K-1} ---
    # Take centered basis vectors in R^K: v_i = e_i - (1/K) 1
    # These lie in the (K-1)-dimensional subspace orthogonal to 1.
    E = torch.eye(K, device=device)
    one = torch.ones((K, 1), device=device)
    V = E - (1.0 / K) * (one @ one.T)  # (K,K), rows are centered vertices in R^K

    # Pairwise squared distance between distinct rows is constant: ||v_i - v_j||^2 = 2
    # So pairwise distance is sqrt(2). We scale to min_sep.
    V = V * (min_sep / math.sqrt(2.0))

    # --- Map the K points from R^K down to R^{K-1} explicitly ---
    # Build an orthonormal basis for the subspace orthogonal to 1 (dimension K-1),
    # then project: U has shape (K, K-1), with U^T 1 = 0, U^T U = I.
    # QR of a matrix whose first column is 1 makes Q[:, 0] parallel to 1, so the
    # remaining columns of Q span the orthogonal complement of 1.
    A = torch.randn(K, K, generator=g, device=device)
    A[:, 0] = 1.0  # force span to include the all-ones vector
    Q, _ = torch.linalg.qr(A)          # Q is (K,K) orthonormal
    U = Q[:, 1:K]                      # (K, K-1) basis for the orthogonal complement of 1
    means_low = V @ U                  # (K, K-1), simplex vertices in R^{K-1}

    # --- Embed into R^D by padding zeros ---
    means = torch.zeros((K, D), device=device)
    means[:, : K - 1] = means_low      # occupy first K-1 coords

    # --- Sample clusters ---
    # The noise is drawn from the same generator, after the QR matrix, so a
    # given seed fixes both the simplex orientation and the samples.
    X = means.repeat_interleave(N_per, dim=0) + cluster_std * torch.randn(
        K * N_per, D, generator=g, device=device
    )

    return X

In [None]:
def run(X, config, plot=False):
    """
    Run the SNEGD optimiser on X and return the tracked rho(R) series.

    Args:
      X      : data handed to SNEGD unchanged.
      config : configuration object forwarded to SNEGD as `cfg`.
      plot   : when True, draw two new figures from the tracked history:
               the loss against step, and rho(R) together with
               |lambda_min(G)| against step (both with a log y-axis).

    Returns:
      rho_R : 1-D numpy array with one entry per record returned by
              `history_dicts()`, equal to max(|R_lmin|, |R_lmax|)
              (presumably the extreme eigenvalues of R -- confirm against SNEGD).
    """
    sne = SNEGD(X, cfg=config, seed=0)
    # The values returned by run() are not used here; every diagnostic below
    # is read from the tracked history instead.
    sne.run()
    rows = pd.DataFrame(sne.history_dicts())

    rho_R = np.maximum(np.abs(rows["R_lmin"].to_numpy()), np.abs(rows["R_lmax"].to_numpy()))

    # NOTE(review): callers collect this return value under the name `ratios`,
    # but the function returns rho(R) itself. An earlier version also computed
    # rho_R / (G_lmin + eps) and rho_R / max(G_lmin, eps) without using them;
    # confirm whether one of those ratios was meant to be returned.

    if plot:
        eps = 1e-12  # keeps the log-scale curve finite when lambda_min(G) is exactly 0
        steps = rows["step"].to_numpy()
        lam_min_G = rows["G_lmin"].to_numpy()

        # Each plot gets its own figure. Drawing the first one on the current
        # figure would overlay it on whatever was plotted last (for example the
        # second figure of a previous run(..., plot=True) call).
        _, ax_loss = plt.subplots()
        ax_loss.plot(steps, rows["loss"].to_numpy(), marker="o")
        ax_loss.set_xlabel("step")
        ax_loss.set_ylabel("loss")
        ax_loss.set_yscale("log")
        ax_loss.set_title("SNE objective vs time")
        ax_loss.grid(True, which="both", linestyle="--", linewidth=0.5)

        _, ax_curv = plt.subplots()
        ax_curv.plot(steps, rho_R, marker="o", label="rho(R)")
        ax_curv.plot(steps, np.abs(lam_min_G) + eps, marker="o", label="|lambda_min(G)|")
        ax_curv.set_yscale("log")
        ax_curv.set_xlabel("step")
        ax_curv.set_ylabel("magnitude (log scale)")
        ax_curv.set_title("Curvature ingredients vs time")
        ax_curv.grid(True, which="both", linestyle="--", linewidth=0.5)
        ax_curv.legend()

    return rho_R

In [None]:
# Baseline configuration, used by the random-noise sanity run.
cfg = SNEConfig(
    # embedding
    d_out=2,
    init="pca",
    perplexity=30.0,
    max_N_for_exact=256,
    # optimisation
    lr=0.5,
    n_steps=1000,
    # early exaggeration (switched off here via the step count)
    early_exaggeration_steps=0,   # set to 0 to disable
    early_exaggeration=12.0,      # set to 1.0 to disable
    # logging / tracking cadence
    track_every=10,
    verbose_every=100,
)

In [None]:
# Sanity run on unstructured data: 64 standard-normal points in 20 dimensions.
# The draw uses a seeded generator (same pattern as make_dataset_simplex) so a
# fresh-kernel re-run reproduces the same dataset.
X = torch.randn(
    64, 20, generator=torch.Generator(device="cuda").manual_seed(0), device="cuda"
)
run(X, cfg)

In [None]:
# Configuration for the separation sweep: smaller perplexity and learning
# rate, more steps than the baseline run.
cfg = SNEConfig(
    # embedding
    d_out=2,
    init="pca",
    perplexity=5,
    max_N_for_exact=256,
    # optimisation
    lr=0.1,
    n_steps=3000,
    # early exaggeration (switched off here via the step count)
    early_exaggeration_steps=0,   # set to 0 to disable
    early_exaggeration=12.0,      # set to 1.0 to disable
    # logging / tracking cadence
    track_every=10,
    verbose_every=100,
)

# Distances between cluster means to sweep over, spanning seven decades.
min_seps = [0.01, 0.1, 1, 10, 100, 1000, 10000, 100_000]

# One entry per min_sep. Despite the name, each entry is the rho(R) series
# that run() returns for that dataset.
ratios = []
for min_sep in min_seps:
    # Three unit-variance clusters in 5 dimensions, means min_sep apart.
    X = make_dataset_simplex(n_clusters=3, D=5, cluster_std=1, min_sep=min_sep)
    ratios.append(run(X, cfg))

In [None]:
# Tabulate the sweep: one row per tracked history record, one column per min_sep.
ratios_df = pd.DataFrame(ratios).T
ratios_df.columns = min_seps

# Overlay the series of every min_sep on a single log-scale figure.
# NOTE(review): the x values are positions in the tracked history, not
# optimiser steps; confirm the "step" label against the tracking cadence.
plt.figure()
for ratio, column in zip(ratios, min_seps):
    plt.plot(ratio, label=column)
plt.yscale("log")
plt.xlabel("step")
plt.ylabel("magnitude (log scale)")
plt.title("Curvature ingredients vs time")
plt.grid(True, which="both", linestyle="--", linewidth=0.5)
# Title the legend so the bare numeric labels are identifiable; the trailing
# semicolon suppresses the Legend repr in the cell output.
plt.legend(title="min_sep");

In [None]:
# Display the sweep table: one row per tracked history record, one column per min_sep.
ratios_df