In [9]:
%matplotlib qt
import numpy as np
import matplotlib.pyplot as plt

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import LassoSelector
from matplotlib.path import Path
from collections import Counter

In [10]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import LassoSelector
from matplotlib.path import Path
from collections import Counter
import seaborn as sns

def colors_from_lbs(lbs, colors=None):
    mpl_20 = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
              '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
              '#3397dc', '#ff993e', '#3fca3f', '#df5152', '#a985ca',
              '#ad7165', '#e992ce', '#999999', '#dbdc3c', '#35d8e9']

    if colors is None:
        colors = np.array(mpl_20)
    else:
        colors = np.array(colors)
    lbs = np.array(lbs) % len(colors)
    return colors[lbs]

def _update_histogram(ax, features, feature_name, color):
    ax.clear()
    feature_counts = Counter(features)
    sorted_counts = sorted(feature_counts.items(), key=lambda x: x[1], reverse=True)
    ax.bar([item[0] for item in sorted_counts], [item[1] for item in sorted_counts], color=color)
    ax.set_ylabel('Count')
    ax.set_title(f'Histogram of {feature_name}', fontsize=8)
    ax.tick_params(axis='x', labelrotation=45, labelsize=8)

def _update_histograms(axs, features, feature_names, color):
    for ax, feature, name in zip(axs, features.T, feature_names):
        _update_histogram(ax, feature, name, color)
    plt.tight_layout()

class InteractiveCluster:

    def __init__(self, fig, pts, features, feature_names, lbs=None, **kwargs):
        self.fig = fig
        self.ax_kde = fig.axes[0]
        self.ax_cluster = fig.axes[1]
        self.ax_histograms = fig.axes[2:]

        if lbs is None:
            self.lbs_ = np.array([0]*len(pts))
        else:
            self.lbs_ = lbs
        self.colors = colors_from_lbs(self.lbs_)

        # KDE plot (this is a test, it is also interactive) 
        sns.kdeplot(x=pts[:, 0], y=pts[:, 1], ax=self.ax_kde, cmap="viridis", fill=True)
        self.ax_kde.set_title("KDE Plot")
        self.kde_path_collection = self.ax_kde.scatter(pts[:, 0], pts[:, 1], c=self.colors, alpha=0)  # Invisible points for selection
        self.selected_points_kde, = self.ax_kde.plot([], [], 'ro', markersize=5)
        
        # Scatter plot for interactive selection
        self.path_collection = self.ax_cluster.scatter(pts[:, 0], pts[:, 1], c=self.colors, **kwargs)
        self.ax_cluster.axis('equal')
        self.ax_cluster.set_title("Interactive Scatter Plot")

        # Initialize histograms for all data points
        self.features = features
        self.feature_names = feature_names
        self.ind = np.arange(len(pts))  # Initially, all points are selected and a general histogram is displayed
        self.pts_selected = self.pts = pts
        self.lbs = np.array(len(pts) * [-1])
        self.num_clusters = 0

        self.lasso_kde = LassoSelector(self.ax_kde, onselect=self.onselect_kde)
        self.lasso_cluster = LassoSelector(self.ax_cluster, onselect=self.onselect_cluster)
        self.press = self.fig.canvas.mpl_connect("key_press_event", self.press_key)
        self.plot_initial_histograms()

    def plot_initial_histograms(self):
        # Plot histograms of all data points
        for i, feature_name in enumerate(self.feature_names):
            _update_histogram(self.ax_histograms[i], self.features[:, i], feature_name, 'blue')

    def onselect_kde(self, verts):
        path = Path(verts)
        self.ind = np.nonzero(path.contains_points(self.pts))[0]
        self.update_selection()

    def onselect_cluster(self, verts):
        path = Path(verts)
        self.ind = np.nonzero(path.contains_points(self.pts))[0]
        self.update_selection()

    def update_selection(self):
        if self.ind.size != 0:
            self.pts_selected = self.pts[self.ind]
            # Update KDE plot with selected points
            self.selected_points_kde.set_data(self.pts_selected[:, 0], self.pts_selected[:, 1])
            # Update scatter plot with selected points
            self.path_collection.set_offsets(self.pts)
            self.path_collection.set_facecolors(self.colors)
            self.kde_path_collection.set_facecolors(self.colors)
            self.fig.canvas.draw_idle()

    def press_key(self, event):
        if event.key == "enter":
            if self.ind.any():
                # The idea is to match the color selection with the histogram color, but for now is not working (also we must change selection (like the lasso or pen) color)
                cluster_color = self.colors[self.num_clusters % len(self.colors)]
                self.lbs[self.ind] = self.num_clusters
                selected_features = self.features[self.ind]
                self.update_colors()
                _update_histograms(self.ax_histograms, selected_features, self.feature_names, cluster_color)
                self.num_clusters += 1
                print("One cluster has been selected.")
        elif event.key == "escape":
            self.undo_last_selection()

    def update_colors(self):
        self.colors = colors_from_lbs(self.lbs)
        self.path_collection.set_color(self.colors)
        self.kde_path_collection.set_color(self.colors)
        self.fig.canvas.draw_idle()

    def undo_last_selection(self):
        if self.num_clusters > 0:
            self.num_clusters -= 1
            last_indices = np.where(self.lbs == self.num_clusters)[0]
            self.lbs[last_indices] = -1
            self.update_colors()
            self.plot_initial_histograms()
            print("Last selection has been undone.")

def interactive_clusters(pts, features, feature_names, lbs=None, **kwargs):
    num_features = features.shape[1]
    fig, ax = plt.subplots(1, num_features + 2, figsize=(6 * (num_features + 2), 6))
    app = InteractiveCluster(fig, pts, features, feature_names, lbs, **kwargs)
    plt.show() 
    return app


In [11]:
# Generate random points
pts = np.random.random((1000, 2))

# Generate random patches
ps = np.random.random((1000, 45, 45))

# Generate random non-numeric features
md = np.column_stack([
    np.random.choice(['10', '20', '30'], size=(1000,)),
    np.random.choice(['M', 'F', 'NB'], size=(1000,)),
    np.random.choice(['A', 'B'], size=(1000,))
])
md_names = ['Age', 'Sex', 'Stage']

# Create the interactive clusters application
app = interactive_clusters(pts, md, md_names)
plt.show()

  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


One cluster has been selected.
One cluster has been selected.
One cluster has been selected.
One cluster has been selected.
One cluster has been selected.


In [12]:
app.lbs

array([-1, -1, -1, -1,  2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
       -1,  1,  1, -1, -1, -1,  4,  4, -1, -1,  4, -1, -1, -1, -1, -1,  2,
       -1, -1, -1, -1, -1, -1, -1,  4, -1, -1,  3, -1, -1, -1, -1, -1, -1,
       -1,  4, -1, -1, -1,  1,  4,  4,  4, -1, -1, -1, -1,  4, -1, -1, -1,
       -1, -1, -1, -1,  1, -1, -1, -1,  4,  4,  4, -1,  4, -1, -1, -1,  2,
       -1,  4, -1,  4, -1, -1,  3, -1, -1,  4, -1, -1, -1, -1, -1, -1, -1,
       -1, -1, -1, -1, -1, -1, -1,  3, -1,  1, -1, -1, -1, -1, -1, -1, -1,
        1, -1, -1,  4, -1,  4, -1, -1, -1, -1, -1, -1, -1, -1, -1,  4, -1,
       -1, -1,  4, -1, -1,  4, -1, -1, -1, -1,  1, -1,  4, -1, -1, -1,  1,
       -1,  4, -1, -1, -1,  2,  4,  4, -1,  4,  1, -1, -1,  4, -1,  3, -1,
       -1, -1,  4,  4, -1, -1, -1,  4, -1, -1, -1, -1,  4,  4,  4, -1, -1,
       -1,  4, -1, -1, -1, -1, -1, -1, -1,  4, -1, -1,  4, -1, -1, -1, -1,
       -1, -1, -1,  4, -1,  1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
       -1, -1, -1, -1,  4