In [20]:
%matplotlib qt
import numpy as np
import matplotlib.pyplot as plt

In [21]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import LassoSelector
from matplotlib.path import Path
from collections import Counter

def colors_from_lbs(lbs, colors=None):
    """Map integer labels to colours by cycling through a palette.

    Parameters
    ----------
    lbs : array-like of int
        Labels; each one is taken modulo the palette length, so any integer
        (including negatives) selects a palette entry.
    colors : sequence, optional
        Palette to cycle through. Defaults to 20 hex colours (matplotlib's
        default colour cycle followed by ten lighter shades).

    Returns
    -------
    numpy.ndarray
        ``colors[label % len(colors)]`` for every label, same shape as ``lbs``.
    """
    mpl_20 = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
              '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
              '#3397dc', '#ff993e', '#3fca3f', '#df5152', '#a985ca',
              '#ad7165', '#e992ce', '#999999', '#dbdc3c', '#35d8e9']

    palette = np.array(mpl_20 if colors is None else colors)
    labels = np.asarray(lbs)
    if labels.size == 0:
        # np.asarray([]) is float64, which numpy refuses as an index array.
        labels = labels.astype(int)
    return palette[labels % len(palette)]

def _update_histogram(ax, features, feature_name):
    ax.clear()
    feature_counts = Counter(features)
    sorted_counts = sorted(feature_counts.items(), key=lambda x: x[1], reverse=True)
    ax.bar([item[0] for item in sorted_counts], [item[1] for item in sorted_counts], color='red')
    ax.set_ylabel('Count', fontsize=8)
    ax.set_title(f'Histogram of {feature_name}', fontsize=8)
    ax.tick_params(axis='x', labelrotation=45, labelsize=8)
    ax.tick_params(axis='y', labelsize=8)

def _update_histograms(axs, features, feature_names):
    """Redraw each axes with ``_update_histogram`` for the matching feature column.

    Parameters
    ----------
    axs : sequence of Axes
        Target axes, paired with feature columns and names by ``zip`` (so any
        surplus axes, columns or names are ignored).
    features : 2-D array, shape (N, F)
        One column per feature.
    feature_names : sequence of str
        Title for each column.
    """
    axs = list(axs)
    for ax, column, name in zip(axs, features.T, feature_names):
        _update_histogram(ax, column, name)
    if axs:
        # Lay out the figure that owns these axes; plt.tight_layout() would act
        # on whichever figure is current, which need not be this one.
        axs[0].figure.tight_layout()

class InteractiveCluster:
    """Lasso-driven cluster labelling on a scatter plot with linked feature histograms.

    ``fig.axes[0]`` receives a scatter of ``pts``; ``fig.axes[1:]`` receive one
    histogram per feature. Drawing a lasso on the scatter redraws the histograms
    for the enclosed points; pressing Enter assigns those points the next cluster
    id in ``self.lbs`` (points never assigned stay at -1).

    Parameters
    ----------
    fig : matplotlib Figure
        Must already contain the scatter axes followed by the histogram axes.
    pts : array, shape (N, 2)
        Point coordinates (column 0 -> x, column 1 -> y).
    features : 2-D array, shape (N, F)
        One column per feature; row i belongs to ``pts[i]``.
    feature_names : sequence of str
        Histogram titles, one per feature column.
    ps : optional
        Stored as ``self.ps``; not otherwise used by this class.
    lbs : array-like of int, optional
        Initial labels, used only to colour the scatter (all 0 if omitted).
    **kwargs
        Forwarded to ``Axes.scatter``.
    """

    def __init__(self, fig, pts, features, feature_names, ps=None, lbs=None, **kwargs):
        self.fig = fig
        self.ax_cluster = fig.axes[0]
        self.ax_histograms = fig.axes[1:]

        self.pts = pts
        self.ps = ps
        self.features = features
        self.feature_names = feature_names

        # Labels used only for the initial colouring of the scatter.
        if lbs is None:
            self.lbs_ = np.zeros(len(pts), dtype=int)
        else:
            self.lbs_ = np.asarray(lbs)
        self.colors = self.colors_from_lbs(self.lbs_)

        self.path_collection = self.ax_cluster.scatter(pts[:, 0], pts[:, 1], c=self.colors, **kwargs)
        self.ax_cluster.axis('equal')

        # Current selection; every point until the first lasso is drawn.
        self.ind = np.arange(len(pts))
        self.pts_selected = self.pts

        # Cluster ids assigned with Enter; -1 means "not assigned yet".
        self.lbs = np.full(len(pts), -1)
        self.num_clusters = 0

        # Keep the selector referenced: matplotlib widgets stop responding once
        # garbage-collected. `self.press` is the key-press connection id.
        self.lasso = LassoSelector(self.ax_cluster, onselect=self.onselect)
        self.press = self.fig.canvas.mpl_connect("key_press_event", self.press_key)

        self.plot_initial_histograms()

    def plot_initial_histograms(self):
        """Draw the histograms for the full (unselected) data set."""
        self._update_histograms(self.ax_histograms, self.features, self.feature_names)

    def onselect(self, verts):
        """LassoSelector callback: select the points inside the drawn path.

        The histograms are redrawn for the selected subset. An empty selection
        empties ``self.ind`` (so Enter does nothing) and leaves the plots as they were.
        """
        path = Path(verts)
        self.ind = np.nonzero(path.contains_points(self.pts))[0]
        if self.ind.size == 0:
            return
        self.pts_selected = self.pts[self.ind]
        selected_features = self.features[self.ind]
        self._update_histograms(self.ax_histograms, selected_features, self.feature_names)
        self.fig.canvas.draw_idle()

    def press_key(self, event):
        """Key-press callback: on Enter, label the current selection as a new cluster."""
        if event.key != "enter":
            return
        # Check the selection's size, not its values: ``ind.any()`` is False
        # when the only selected index is 0.
        if self.ind.size > 0:
            self.lbs[self.ind] = self.num_clusters
            self.num_clusters += 1
            print("One cluster has been selected.")

    def colors_from_lbs(self, lbs):
        """Map labels to RGBA colours from the ``jet`` colormap.

        Labels are divided by the largest label (or by 1 if that is smaller), so
        non-negative labels land in [0, 1].
        """
        lbs = np.asarray(lbs)
        # initial=1 reproduces max(lbs.max(), 1) and also copes with no labels.
        return plt.cm.jet(lbs / lbs.max(initial=1))

    def _update_histogram(self, ax, data, feature_name, relayout=True):
        """Redraw ``ax`` as a 30-bin histogram of ``data`` titled after ``feature_name``.

        With ``relayout`` (the default) the axes' own figure is re-laid out;
        batch callers pass False and lay out once at the end.
        """
        ax.clear()
        ax.hist(data, bins=30, alpha=0.7)
        ax.set_title(f'Histogram of {feature_name}', fontsize=8)
        ax.tick_params(axis='x', labelrotation=45, labelsize=8)
        ax.tick_params(axis='y', labelsize=8)
        if relayout:
            # Target the figure owning `ax`; plt.tight_layout() acts on
            # whichever figure happens to be current.
            ax.figure.tight_layout()

    def _update_histograms(self, ax_histograms, features, feature_names):
        """Redraw one histogram per feature column, then lay out ``self.fig`` once."""
        for ax, column, name in zip(ax_histograms, features.T, feature_names):
            self._update_histogram(ax, column, name, relayout=False)
        self.fig.tight_layout()
# `ps` is optional in interactive_clusters (defaults to None).
def interactive_clusters(pts, features, feature_names, ps=None, lbs=None, **kwargs):
    """Build the scatter + histogram figure and attach an InteractiveCluster to it.

    The figure gets one panel for the scatter plus one per feature column.
    With more than two features the panels form a 2-column grid on a 9x9 in
    figure; otherwise they sit in a single row, 6 in wide per panel.

    Parameters
    ----------
    pts : array, shape (N, 2)
        Point coordinates for the scatter.
    features : 2-D array, shape (N, F)
        Per-point feature values; one histogram per column.
    feature_names : sequence of str
        Histogram titles.
    ps : optional
        Passed through to InteractiveCluster.
    lbs : array-like of int, optional
        Initial labels used to colour the scatter.
    **kwargs
        Forwarded to InteractiveCluster.

    Returns
    -------
    InteractiveCluster
        Keep a reference to it; matplotlib widgets stop responding once
        garbage-collected.
    """
    num_features = features.shape[1]
    num_panels = num_features + 1  # one scatter panel + one histogram per feature

    if num_features > 2:
        # 2-column grid with ceil(num_panels / 2) rows.
        n_rows = (num_panels + 1) // 2
        fig, axes = plt.subplots(n_rows, 2, figsize=(9, 9), squeeze=False)
    else:
        fig, axes = plt.subplots(1, num_panels, figsize=(6 * num_panels, 6), squeeze=False)

    # An even feature count leaves one unused grid cell; hide its empty frame.
    for spare in axes.ravel()[num_panels:]:
        spare.axis('off')

    # NOTE(review): InteractiveCluster calls tight_layout while drawing its
    # histograms, which likely supersedes these margins — confirm which is intended.
    fig.subplots_adjust(left=0.1, bottom=0.1, right=0.9, top=0.9, wspace=0.4, hspace=0.4)

    # InteractiveCluster puts the scatter on fig.axes[0] and draws the
    # histograms on fig.axes[1:] itself, so they are not pre-drawn here.
    app = InteractiveCluster(fig, pts, features, feature_names, ps, lbs, **kwargs)
    return app

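From here on is a real-world data example.
It expects two inputs, `test` and `metadata`, both indexed by the shared `pts` ID column:
`test` is an (N, 2) DataFrame of point coordinates, and `metadata` is an (N, M) DataFrame of per-point attributes that are treated as categorical (they are cast to strings below).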

In [42]:
# Load coordinates and metadata, both keyed by the shared `pts` ID column.
test = pd.read_csv('blindedtest.csv', index_col='pts')
meta = pd.read_csv('blind_metadata.csv', index_col='pts')
meta.dtypes


class      object
cogdx       int64
msex        int64
braaksc     int64
dtype: object

In [43]:
# Cast every metadata attribute to text so each one is treated as categorical.
meta = meta.astype(str)
print(meta)

       class cogdx msex braaksc
pts                            
53       AD1     3    1       2
22       AD1     4    0       3
40       AD1     1    0       3
89       AD1     1    0       3
55       AD1     2    1       4
..       ...   ...  ...     ...
92   control     1    0       3
63   control     2    0       4
64   control     1    0       4
34   control     1    0       2
99   control     1    1       4

[569 rows x 4 columns]


In [44]:

def reorder_checkingindex(test, metadata):
    """Return ``test`` with its rows reordered to match ``metadata``'s index.

    Rows are aligned on the index labels themselves (in this notebook, the
    ``pts`` IDs set as the index of both frames). The result keeps that index
    and only the original data columns, so ``test.to_numpy()`` still yields
    just the coordinates.

    IDs in ``metadata`` with no row in ``test`` come back as all-NaN rows; rows
    of ``test`` whose ID is not in ``metadata`` are dropped. Both cases are
    reported with a printed message.
    """
    if test.index.equals(metadata.index):
        return test
    if not test.index.isin(metadata.index).all():
        print("Indexes in test are not in metadata")
    if not metadata.index.isin(test.index).all():
        print("Indexes in metadata are not in test")
    # Align on the index itself: there is no 'ID' column once the ID is the
    # index, and reset_index() would turn the ID into an extra data column.
    return test.reindex(metadata.index)

def make_df_to_np(test, metadata):
    """Convert the coordinate frame and the object-dtype metadata columns to numpy.

    Parameters
    ----------
    test : DataFrame
        Point coordinates; converted wholesale with ``to_numpy()``.
    metadata : DataFrame
        Only its ``object``-dtype columns are kept; other dtypes are ignored.
        NOTE(review): with pandas >= 3, ``astype(str)`` yields the new string
        dtype, which ``select_dtypes('object')`` may not match — confirm version.

    Returns
    -------
    colnames : list of str
        Names of the kept metadata columns, in frame order.
    test : numpy.ndarray
        ``test.to_numpy()``.
    testmeta : numpy.ndarray, shape (len(metadata), len(colnames))
        The kept metadata columns side by side.
    """
    kept = metadata.select_dtypes('object')
    colnames = kept.columns.tolist()
    # to_numpy() matches stacking the columns, and still returns an (N, 0)
    # array when no column qualifies (np.column_stack([]) would raise).
    return colnames, test.to_numpy(), kept.to_numpy()

In [53]:
test_df = reorder_checkingindex(test, meta)
colnames, test_array, meta_array = make_df_to_np(test_df, meta)

# Demo labels for the initial scatter colouring: one of 0..4 per point, seeded
# for reproducibility. Defined here so the cell also runs on a fresh kernel.
lbs = np.random.default_rng(0).integers(0, 5, size=len(test_array))

# Keep `testapp` bound: matplotlib widgets stop responding once garbage-collected.
testapp = interactive_clusters(pts=test_array, ps=None, features=meta_array,
                               feature_names=colnames, lbs=lbs)

In [52]:
# Random demo labels (0..4), one per metadata row; seeded for reproducibility.
lbs = np.random.default_rng(0).integers(0, 5, size=len(meta))