In [None]:
# Reload edited imported modules (e.g. volumembo) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.lines import Line2D

import volumembo

# Three moons dataset

### Load/create the dataset

In [None]:
# Three-moons toy dataset (N=100, noise=0.1); report the array shapes as a quick check.
data, labels = volumembo.datasets.load_dataset("3_moons", N=100, noise=0.1)
for caption, arr in (("Data:\t", data), ("Labels:\t", labels)):
    print(caption, arr.shape)

### Construct a VolumeMBO object

In [None]:
# Build the solver on the dataset loaded above.
# Optional keyword arguments that were tried and are currently not passed:
#   lower_limit=[20, 20, 20], upper_limit=[40, 40, 40], temperature=1
MBO = volumembo.MBO(
    # inputs
    data=data,
    labels=labels,
    number_of_known_labels=5,
    # graph / diffusion
    number_of_neighbors=6,
    diffusion_time=1.0,
    diffusion_method="A_3",
    # algorithm choices
    initial_clustering_method="random",
    threshold_method="fit_median_cpp",
)

In [None]:
# Show the solver configuration as reported by volumembo.
MBO.print_parameters()

In [None]:
# Visualise the graph weight matrix as a grey-scale image.
cmap = "Greys"  # matplotlib colormap name passed to imshow

# .todense() on a scipy sparse *matrix* (presumably what MBO.weight_matrix is) returns
# the discouraged np.matrix subclass; np.asarray gives a plain ndarray either way.
weight_matrix = np.asarray(MBO.weight_matrix.todense())

print(
    "W ∈ [{}, {}] | mean: {}".format(
        np.min(weight_matrix), np.max(weight_matrix), np.mean(weight_matrix)
    )
)
#############################################################################################
fig = plt.figure(figsize=(3.5, 3.5))
gs = gridspec.GridSpec(nrows=1, ncols=1)
#############################################################################################
ax0 = fig.add_subplot(gs[0, 0])
ax0.tick_params(
    direction="in", which="both", bottom=True, top=True, left=True, right=True
)
ax0.minorticks_on()
ax0.set_aspect("equal")
ax0.set_title(r"Weight matrix")

# vmin/vmax pin the grey scale to [0, 1]; entries outside that range are clipped,
# so check the printed min/max above if weights are not normalised.
ax0.imshow(weight_matrix, cmap=cmap, vmin=0, vmax=1)

#############################################################################################
plt.subplots_adjust(left=0.05, right=0.95, top=0.98, bottom=0.115)
# fig.savefig("./weight_matrix.pdf", transparent=True, dpi=600)

### Do a single MBO run and plot the result

In [None]:
# Draw a fidelity set (presumably the points whose labels are revealed to the
# solver, cf. number_of_known_labels) and run one MBO solve with progress output.
MBO.make_fidelity_set()
MBO.run_mbo(verbose=True)
# NOTE(review): new_volume presumably holds the per-class sizes after the run — confirm.
print(MBO.new_volume)

In [None]:
# Fetch the initial labels and fidelity set for plotting; the first element
# (the indicator passed to MBO.diffuse in the median-fitter section) is unused here.
# NOTE(review): with initial_clustering_method="random", confirm this returns the
# initial state run_mbo actually used rather than drawing a fresh random one.
_, initial_labels, fidelity_set = MBO.get_initial_cluster()

In [None]:
# Colours: outline colour for fidelity-set points, and one fill colour per class.
# (The former `cmap = "hot"` here was never used and only clobbered `cmap`.)
color_fidelity_set = "lime"

label_colors = np.array(["deepskyblue", "gold", "magenta"])
# Integer labels index straight into label_colors, so they must lie in 0..2.
point_colors = label_colors[MBO.labels]
point_colors_init = label_colors[initial_labels]
point_colors_new = label_colors[MBO.new_labels]


def plot_fidelity_set(ax, data=None, indices=None, color=None):
    """Outline the fidelity-set points on ``ax`` as hollow circles.

    Every optional argument falls back to the notebook global of the same role,
    resolved at call time, so ``plot_fidelity_set(ax)`` behaves as before. Pass
    them explicitly to avoid depending on whichever cell last rebound the globals.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    data : array, optional
        Point coordinates; columns 0 and 1 are plotted. Defaults to ``MBO.data``.
    indices : index array, optional
        Rows of ``data`` to highlight. Defaults to the global ``fidelity_set``.
    color : colour spec, optional
        Edge colour of the circles. Defaults to the global ``color_fidelity_set``.

    Returns
    -------
    The artist returned by ``ax.scatter``.
    """
    if data is None:
        data = MBO.data
    if indices is None:
        indices = fidelity_set
    if color is None:
        color = color_fidelity_set
    return ax.scatter(
        data[indices, 0],
        data[indices, 1],
        ec=color,
        fc="none",
        s=10,
        label="Fidelity Set",
    )


#############################################################################################
# Three side-by-side panels: target labels, initial clustering, and final MBO labels.


def _draw_configuration(ax, colors, title, ylabel=None):
    """Draw one configuration panel on ``ax``.

    Scatters every point of ``MBO.data`` (columns 0 and 1) with per-point ``colors``,
    outlines the fidelity set, and sets the axis labels. ``ylabel`` is only applied
    when given, so that only the left-most panel carries a y label.
    """
    ax.tick_params(
        direction="in", which="both", bottom=True, top=True, left=True, right=True
    )
    ax.minorticks_on()
    ax.set_aspect("equal")
    ax.set_title(title)
    ax.scatter(MBO.data[:, 0], MBO.data[:, 1], c=colors, edgecolor="k")
    # Highlight fidelity set
    plot_fidelity_set(ax)
    ax.set_xlabel(r"x")
    if ylabel is not None:
        ax.set_ylabel(ylabel)


fig = plt.figure(figsize=(8, 1.75))
gs = gridspec.GridSpec(nrows=1, ncols=3)
panels = [
    (point_colors, r"Target configuration", r"y"),
    (point_colors_init, r"Initial configuration", None),
    (point_colors_new, r"Final configuration", None),
]
for column, (colors, title, ylabel) in enumerate(panels):
    _draw_configuration(fig.add_subplot(gs[0, column]), colors, title, ylabel)
#############################################################################################
plt.subplots_adjust(left=0.075, right=0.95, top=0.98, bottom=0.115)
# fig.savefig("./volumeMBO_spectral_random_init.pdf", transparent=True, dpi=600)

### Do several iterations of MBO for each of the four threshold methods (argmax, fit_median_cpp, fit_median, fit_median_legacy) and print the execution time

In [None]:
# Number of MBO runs per threshold method in the timing comparison below.
iterations = 100

In [None]:
# Timing run 1/4: "argmax" threshold rule.
MBO.set_threshold_function("argmax")
MBO.run(save_results=False, iterations=iterations)

In [None]:
# Timing run 2/4: "fit_median_cpp" threshold rule.
MBO.set_threshold_function("fit_median_cpp")
MBO.run(save_results=False, iterations=iterations)

In [None]:
# Timing run 3/4: "fit_median" threshold rule.
MBO.set_threshold_function("fit_median")
MBO.run(save_results=False, iterations=iterations)

In [None]:
# Timing run 4/4: "fit_median_legacy" threshold rule.
MBO.set_threshold_function("fit_median_legacy")
MBO.run(save_results=False, iterations=iterations)

# Visualize the iterative method for finding the median $m$

In [None]:
from volumembo.median_fitter import VolumeMedianFitter
from _volumembo import fit_median_cpp

In [None]:
# Fresh fidelity set and initial clustering, diffused once: `u` is the input
# the median fitter is demonstrated on below.
MBO.make_fidelity_set()
# The label vector is deliberately discarded. Binding it to `labels` would overwrite
# the dataset's ground-truth labels, and re-running the MBO constructor cell
# afterwards would then silently build the solver on the wrong labels.
chi, _, fidelity_set = MBO.get_initial_cluster()
u = MBO.diffuse(chi)
# Same value for both limits. NOTE(review): presumably an exact-volume constraint.
upper = MBO.volume
lower = MBO.volume
print("χ: {}".format(chi.shape))
print("u: {} | (min, max) = ({}, {})".format(u.shape, np.min(u), np.max(u)))
# Extremes of the per-point row sums of u (presumably ≈ 1 if u stays on the simplex).
print(np.min(np.sum(u, axis=1)))
print(np.max(np.sum(u, axis=1)))
print("upper limit: {}".format(upper))
print("lower limit: {}".format(lower))

In [None]:
# Fit the median for u under the lower/upper volume limits. With return_history=True the
# fitter also returns median_history (presumably the sequence of median estimates; its
# last entry is used as the final median and the whole sequence is traced in the plot below).
fitter = VolumeMedianFitter(u, lower, upper)
labels_diffused, median_history = fitter.run(return_history=True)

In [None]:
# Sanity check: re-derive labels from the final median estimate and compare them with
# the labels the fitter returned.
test_labels = np.argmax(u - median_history[-1], axis=1)
# Class sizes; minlength follows u's column count rather than a hard-coded 3, so
# empty classes still get a 0 entry for any number of classes.
new_sizes = np.bincount(test_labels, minlength=u.shape[1])
print(new_sizes)
# array_equal also handles a shape mismatch cleanly (returns False).
np.array_equal(labels_diffused, test_labels)

In [None]:
#############################################################################################
fig = plt.figure(figsize=(14, 5))
gs = gridspec.GridSpec(nrows=1, ncols=1)
#############################################################################################
ax0 = fig.add_subplot(gs[0, 0])
ax0.tick_params(
    direction="in", which="both", bottom=True, top=True, left=True, right=True
)
ax0.minorticks_on()

# Diffused points on the simplex coloured by the fitter's labels, the final median
# estimate, and the path of median estimates leading to it.
simplex = volumembo.plot.SimplexPlotter(ax=ax0)
simplex.plot_simplex_outline(lw=2)
simplex.add_grid_lines()
simplex.add_ticks(n=10, show_labels=True)
simplex.set_axis_labels()
simplex.plot_points(points=u, labels=labels_diffused, ec="k", s=45)
simplex.plot_median(point=median_history[-1], s=50)

# One style shared by the trace and its legend entry so the two cannot drift apart.
trace_style = dict(color="lime", linestyle=":", linewidth=1.5)
simplex.plot_trace(median_history, zorder=10, **trace_style)

# The legend uses a proxy Line2D (dotted lime line) rather than the trace artist itself.
trace_handle = Line2D([], [], label="Trace of order statistic", **trace_style)
ax0.legend(handles=[trace_handle], loc=(0, 0.8))
#############################################################################################
# NOTE(review): top=1.1 places the axes' upper edge above the figure, presumably to
# crop empty space above the simplex. Confirm this is intended.
plt.subplots_adjust(left=0.05, right=0.95, top=1.1, bottom=0.025)
# fig.savefig("./clustering0.png", transparent=False, dpi=300)

In [None]:
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

# Corners of the 2-simplex embedded in R^3 (the unit basis vectors).
A, B, C = [1, 0, 0], [0, 1, 0], [0, 0, 1]
triangle = [A, B, C]

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")

# Translucent face marking the plane spanned by the three corners.
ax.add_collection3d(
    Poly3DCollection([triangle], facecolors="lightgray", alpha=0.3, edgecolors="k")
)

# Columns 0, 1, 2 of u as x, y, z: one 3-D point per data point.
ax.scatter(*(u[:, k] for k in range(3)), color="blue", s=20)

# Unit cube limits, plain axis names, and a side-on camera angle.
ax.set(xlim=(0, 1), ylim=(0, 1), zlim=(0, 1), xlabel="X", ylabel="Y", zlabel="Z")
ax.view_init(elev=20, azim=0)

plt.tight_layout()
plt.show()