In [1]:
from itertools import cycle, islice

import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets

from corc.graph_metrics import gwg, gwgmara

In [2]:
# Toy dataset: 1000 noisy points generated by scikit-learn's `make_circles`,
# seeded so the notebook is reproducible. The result is unpacked later as `X, y`.
noisy_circles = datasets.make_circles(n_samples=1000, noise=0.05, factor=0.5, random_state=42)

In [3]:
# Value passed as `n_components` to both models built below; it is also the
# (exclusive) upper bound of the cluster-count sweep in the last cell.
n_components = 15

In [4]:
# Settings common to both models, kept in one place so the two stay comparable.
shared_params = dict(
    latent_dim=2,
    n_clusters=2,
    n_components=n_components,
    n_neighbors=3,
    seed=42,
)

# GWG model with the shared settings only.
mgwg = gwg.GWG(**shared_params)

# GWGMara model: same settings, plus `filter_edges` explicitly switched off.
mgwgmara = gwgmara.GWGMara(**shared_params, filter_edges=False)

In [5]:
# Pick the model and the dataset that every following cell works with.
algorithm, dataset = mgwgmara, noisy_circles

In [6]:
# Split the selected dataset into samples and (optional) reference labels.
X, y = dataset
if y is None:
    # No labels available: use the placeholder label 0 for every sample.
    y = [0] * len(X)
else:
    y = np.array(y, dtype='int')

# Fit the selected model on the samples only; `y` is not passed to `fit`.
algorithm.fit(X)

In [None]:
# Visualise the fitted clustering: colour each sample by its predicted label
# and let the algorithm draw its graph afterwards.

if hasattr(algorithm, "labels_"):
    y_pred = algorithm.labels_.astype(int)
else:
    # `y_pred` is used as an index array below, so coerce it to integers
    # exactly like the `labels_` branch does.
    y_pred = np.asarray(algorithm.predict(X)).astype(int)

# Base palette; it is repeated cyclically when more labels than colours exist.
palette = [
    "#377eb8",
    "#ff7f00",
    "#4daf4a",
    "#f781bf",
    "#a65628",
    "#984ea3",
    "#999999",
    "#e41a1c",
    "#dede00",
]

# One colour per label id up to the largest id found in either the prediction
# or the reference labels.
n_colors = int(max(np.max(y_pred), np.max(y))) + 1
colors = np.array(list(islice(cycle(palette), n_colors)))

# Trailing black entry: a label of -1 indexes the last element of `colors`
# (presumably meant for noise/unassigned samples - confirm against the model).
colors = np.append(colors, ["#000000"])

plt.scatter(X[:, 0], X[:, 1], s=10, color=colors[y_pred])
algorithm.plot_graph()

In [8]:
# Ask the fitted model for its thresholds together with the matching cluster
# numbers and clusterings (three values, unpacked in that order).
(
    thresholds,
    cluster_numbers,
    clusterings,
) = algorithm.get_thresholds_and_cluster_numbers()

In [None]:
# Let the model plot its thresholds, given the cluster numbers returned by
# `get_thresholds_and_cluster_numbers()` in the previous cell.
algorithm.plot_thresholds(cluster_numbers)

In [10]:
# Sweep the requested number of clusters and draw one figure per value.

# Base palette, defined once outside the loop because it never changes; it is
# repeated cyclically when more labels than colours exist.
palette = [
    "#377eb8",
    "#ff7f00",
    "#4daf4a",
    "#f781bf",
    "#a65628",
    "#984ea3",
    "#999999",
    "#e41a1c",
    "#dede00",
]

# `n` runs from 1 up to and including `n_components - 1`.
for n in np.arange(1, n_components):
    # The prediction is used as an index array below, so coerce it to integers.
    y_pred = np.asarray(algorithm.predict(X, target_number_clusters=n)).astype(int)

    # One colour per label id up to the largest id found in either the
    # prediction or the reference labels.
    n_colors = int(max(np.max(y_pred), np.max(y))) + 1
    colors = np.array(list(islice(cycle(palette), n_colors)))

    # Trailing black entry: a label of -1 indexes the last element of `colors`
    # (presumably meant for noise/unassigned samples - confirm against the model).
    colors = np.append(colors, ["#000000"])

    plt.scatter(X[:, 0], X[:, 1], s=10, color=colors[y_pred])
    # Label the figure so the otherwise similar plots can be told apart.
    plt.title(f"target_number_clusters = {n}")
    algorithm.plot_graph()
    plt.show()