In [13]:
from datasets2d import *

ModuleNotFoundError: No module named 'datasets2d'

In [2]:
import functools

import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.scipy.stats as jstats
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import sklearn.datasets
import sklearn.mixture
import tqdm


import datasets2d
import importlib

importlib.reload(datasets2d)

ModuleNotFoundError: No module named 'jax'

In [3]:
from cycler import cycler

# Set the color cycle to use 'tab20'
plt.rcParams["axes.prop_cycle"] = cycler(
    color=plt.cm.tab10.colors,  # type: ignore
    marker=[
        "o",
        "s",
        "D",
        "x",
        "P",
        "H",
        "d",
        "v",
        "^",
        "<",
        ">",
        "p",
        "*",
        "+",
        "X",
        "|",
        "_",
        ".",
        ",",
        "o",
    ][:10],
)

NameError: name 'plt' is not defined

In [4]:
@jax.jit
@functools.partial(jax.vmap, in_axes=(0, None, None, None))
def gmm_jax(x, means, covs, weights):
    p = []

    for mean, cov, weight in zip(means, covs, weights):
        p.append(jstats.multivariate_normal.logpdf(x, mean, cov) + jnp.log(weight))

    # logsumexp
    p = jnp.stack(p, axis=-1)
    p = jax.scipy.special.logsumexp(p, axis=-1)
    return p


@jax.jit
def tmm_jax(x, means, covs, weights):
    p = []

    # pdfs of the individual components
    for mean, cov, weight in zip(means, covs, weights):
        p.append(jstats.t.logpdf(x, mean, cov) + jnp.log(weight))

    # adding them all up
    p = jnp.stack(p, axis=-1)
    p = jax.scipy.special.logsumexp(p, axis=-1)
    return p


# the loss of the interpolation
def loss(ms, means, covs, weights, dif_refrence, mixture_model=gmm_jax):
    # maximize the negative log likelihood
    nll = -mixture_model(ms, means, covs, weights).sum()

    # calculate the difference between the points
    dif = jnp.diff(ms, axis=0)
    dif = (dif**2).sum(axis=1)

    # penalize the difference between the points compared to the reference
    # difference from the linear interpolation
    dif = (dif - dif_refrence) ** 2
    dif = dif.sum()

    return nll * 1e-1  # + dif * 1e-2


# perform gradient descent on the points
@jax.jit
def step(ms, means, covs, weights, dif_refrence):
    g = jax.grad(loss)(ms, means, covs, weights, dif_refrence)
    ms_new = ms - 1e-1 * g
    ms = ms.at[1:-1].set(ms_new[1:-1])
    return ms


# reinterpolate the points to have a uniform distance along the same path
@jax.jit
def reinterpolate(ms):
    dif = jnp.diff(ms, axis=0)
    dif = (dif**2).sum(axis=1) ** 0.5
    dif = jnp.cumsum(dif)
    dif = jnp.concatenate([jnp.array([0]), dif])
    dif = dif / dif[-1]

    ts = jnp.linspace(0, 1, ms.shape[0])
    rms = jax.vmap(jnp.interp, in_axes=(None, None, -1))(ts, dif, ms)
    return rms.T


# given two indicies i and j find interpolation between the ith and jth means
def compute_interpolation(i, j, means, covs, weights):
    m1, m2 = means[i], means[j]

    # linear interpolation between m1 and m2
    ts = jnp.linspace(0, 1, 1024)[..., None]
    ms = (1 - ts) * m1 + ts * m2

    # euclidean distance
    dif_refrence = ((ms[0] - ms[1]) ** 2).sum()

    # this is where the band becomes elastic and where the work is done.
    for i in range(1024):
        ms = step(ms, means, covs, weights, dif_refrence)
        ms = reinterpolate(ms)

    ps = gmm_jax(ms, means, covs, weights)
    return ms, ts, ps


# score the interpolation for the final graph
@jax.jit
def score_interpolation(xs):
    # approximate by the decreasing sequence
    min_ = lax.cummin(xs)
    # find largest difference to decreasing approximation
    min_ = jnp.abs(min_ - xs).max()

    # the same but for increasing
    max_ = lax.cummax(xs)
    max_ = jnp.abs(max_ - xs).max()

    # take the better approximation
    return jnp.minimum(min_, max_)

NameError: name 'jax' is not defined

In [5]:
data_X, data_y = datasets2d.DATASETS["Clusterlab10"]()
n_components = 12

gmm = sklearn.mixture.GaussianMixture(
    n_components=n_components,
    covariance_type="spherical",
)
gmm.fit(data_X)


NameError: name 'datasets2d' is not defined

In [6]:
# plot the data and density
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))

x = np.linspace(data_X[:, 0].min() - 0.1, data_X[:, 0].max() + 0.1, 128)
y = np.linspace(data_X[:, 1].min() - 0.1, data_X[:, 1].max() + 0.1, 128)

XY = np.stack(np.meshgrid(x, y), -1)

p = gmm.score_samples(XY.reshape(-1, 2)).reshape(128, 128)

p_jax = gmm_jax(XY.reshape(-1, 2), gmm.means_, gmm.covariances_, gmm.weights_)
p_jax = p_jax.reshape(128, 128)

print(np.abs(p - p_jax).max())

r = ax1.contourf(x, y, p, levels=64, cmap="coolwarm", alpha=0.5)

y = gmm.predict(data_X)

ax1.scatter(
    data_X[:, 0],
    data_X[:, 1],
    # c=y,
    # cmap="tab10",
    color="gray",
    alpha=0.5,
)

# plot centers
ax1.scatter(
    gmm.means_[:, 0],  # type: ignore
    gmm.means_[:, 1],  # type: ignore
    c="black",
    s=256,
    marker="x",
    lw=3,
)


adjacency = np.zeros((n_components, n_components))


# iterate through all pairs of indicies to build the graph
for i in range(n_components):
    for j in range(i + 1, n_components):

        ms, ts, ps = compute_interpolation(
            i, j, gmm.means_, gmm.covariances_, gmm.weights_
        )

        score = score_interpolation(ps)
        adjacency[i, j] = score

        # in case something goes wrong with the initialisation
        if np.isnan(ps).any():
            print("nan in", i, j)
            continue

        # only draw soem of the interpolations
        if i > 0:
            continue

        if j % 2 != 0:
            continue

        # print(ms)

        # if i == 0:
        # if j % 2 == 0:
        k = ax1.plot(ms[:, 0], ms[:, 1], lw=3, alpha=0.5)

        ax2.plot(
            ts,
            ps,
            color=k[0].get_color(),
            marker=k[0].get_marker(),
        )

        # print(i, j)
        # break

    # break

adjacency = adjacency + adjacency.T
adjacency = adjacency - np.diag(np.diag(adjacency))
adjacency = adjacency / adjacency.max()

adjacency = 1 - adjacency


for i in range(n_components):
    for j in range(i + 1, n_components):
        if (adjacency[i, j]) < 0.7:
            continue

        ax1.plot(
            [
                gmm.means_[i, 0],  # type: ignore
                gmm.means_[j, 0],  # type: ignore
            ],
            [
                gmm.means_[i, 1],  # type: ignore
                gmm.means_[j, 1],  # type: ignore
            ],
            color="black",
            alpha=(adjacency[i, j]),
        )

        # write value of adjacency nex to line
        ax1.text(
            (gmm.means_[i, 0] + gmm.means_[j, 0]) / 2,  # type: ignore
            (gmm.means_[i, 1] + gmm.means_[j, 1]) / 2,  # type: ignore
            f"{float(adjacency[i, j]):.2f}",
            alpha=(adjacency[i, j]),
        )


ax1.set_title("GMM")
ax2.set_title("Interpolation")

ax2.set_xlabel("t")
ax2.set_ylabel("log likelihood")

plt.tight_layout()
plt.show()

plt.hist(adjacency.flatten())

NameError: name 'plt' is not defined

In [7]:
!ls

2d_datasets_neb_examples.ipynb
README.md
[34mconfigs[m[m
environment.yml
graphdino_morphological_embeddings_tsne.pkl
main.py
[34mnotebooks[m[m
setup.py
[34msrc[m[m


In [8]:
from graphcut import disconnected_cut_threshold, label_components

ModuleNotFoundError: No module named 'graphcut'

In [9]:
unique_dist = sorted(list(set(adjacency.flatten())))

NameError: name 'adjacency' is not defined

In [10]:
cc_dict = {}
g = nx.from_numpy_array(adjacency)
prev_cc = nx.number_connected_components(g) 
cc_dict[prev_cc] = 0
for i in unique_dist:
    g = nx.from_numpy_array(adjacency > i)
    cc = nx.number_connected_components(g) 
    if cc >  prev_cc:
        print(cc, i)
        prev_cc = cc
        cc_dict[prev_cc] = i

NameError: name 'nx' is not defined

In [11]:
for i in range(2, n_components):
    t = cc_dict[i]
    adjacency_cut = adjacency > t
    cs = label_components(adjacency_cut)

    plt.scatter(
        data_X[:, 0],
        data_X[:, 1],
        # c=y,
        # cmap="tab10",
        color="gray",
        alpha=0.5,
    )

    plt.scatter(
        gmm.means_[:, 0],  # type: ignore
        gmm.means_[:, 1],  # type: ignore
        c=cs,
        cmap="tab10",
    )
    plt.title(f"Threshold: {t:.2f}, Clusters {i}")
    plt.show()




NameError: name 'n_components' is not defined

In [12]:
plt.xlabel("Number of clusters")
plt.ylabel("Normalised Threshold")

plt.plot([*range(2, n_components)], list(cc_dict.values())[1:-1])
plt.xticks([*range(2, n_components)])
plt.grid()
plt.show()

NameError: name 'plt' is not defined

In [None]:
# pip install -U "jax[cuda12]"