In [1]:
%matplotlib qt
import numpy as np
import matplotlib.pyplot as plt

In [15]:

def colors_from_lbs(lbs, colors=None):
    mpl_20 = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
          '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
          '#3397dc', '#ff993e', '#3fca3f', '#df5152', '#a985ca',
          '#ad7165', '#e992ce', '#999999', '#dbdc3c', '#35d8e9']
    
    if colors is None:
        colors = np.array(mpl_20)
    else:
        colors = np.array(colors)
    lbs = np.array(lbs) % len(colors)
    return colors[lbs]

def _update_mean_patch(ax, p, cmap):
    if ax.images:  # ax.images not empty
        ax.images[0].set_data(p)
        ax.images[0].set_cmap(cmap)
    else:
        ax.imshow(p, cmap=cmap)

class InteractiveCluster:

    def __init__(self, fig, pts, metadata=None, column_name=None, lbs=None, **kwargs):
        self.fig = fig
        self.ax_cluster = fig.add_subplot(121)  # Left subplot for scatter plot
        self.ax_patch = fig.add_subplot(122)    # Right subplot for mean patch

        if lbs is None:
            self.lbs_ = np.array([0] * len(pts))
        else:
            self.lbs_ = lbs
        self.colors = colors_from_lbs(self.lbs_)

        self.path_collection = self.ax_cluster.scatter(pts[:, 0], pts[:, 1], c=self.colors, **kwargs)
        self.ax_cluster.axis('equal')
        self.ax_cluster.set_title('Scatter Plot')

        self.ps = metadata[column_name] if metadata is not None and column_name in metadata.columns else None

        if self.ps is not None:
            self.ax_patch.set_xlim(0 - 0.5, len(self.ps) - 0.5)
            self.ax_patch.set_ylim(self.ps.min() - 0.5, self.ps.max() - 0.5)
            self.ax_patch.set_title(f'Scatter Plot of {column_name}')
        else:
            self.ax_patch.set_title('Mean Patch')

        self.pts = pts

        self.ind = None
        self.pts_selected = None

        self.lbs = np.full(len(self.pts), -1)

        self.num_clusters = 0

        self.lasso = LassoSelector(self.ax_cluster, onselect=self.onselect)
        self.press = self.fig.canvas.mpl_connect("key_press_event", self.press_key)

    def onselect(self, event):
        if self.ps is None:
            return
        
        path = Path(event)
        self.ind = np.nonzero(path.contains_points(self.pts))[0]
        if self.ind.size != 0:
            self.pts_selected = self.pts[self.ind]
            
            # Get selected data from metadata column
            selected_data = self.ps.iloc[self.ind]
            
            # Plot scatter points on ax_patch
            self.ax_patch.clear()
            self.ax_patch.scatter(np.arange(len(selected_data)), selected_data, c=self.colors[self.ind])
            self.ax_patch.set_xlabel('Index')
            self.ax_patch.set_ylabel(self.ps.name if self.ps is not None else '')
            self.ax_patch.set_title(f'Scatter Plot of {self.ps.name if self.ps is not None else "Unknown"}')
            self.ax_patch.grid(True)
            
            self.fig.canvas.draw_idle()

    def press_key(self, event):
        if event.key == "enter":
            if self.ind is not None and len(self.ind) > 0:
                self.lbs[self.ind] = self.num_clusters
                self.num_clusters += 1
                print("One cluster has been selected.")

def interactive_clusters(pts, metadata=None, column_name=None, lbs=None, **kwargs):
    fig = plt.figure(figsize=(12, 6))
    app = InteractiveCluster(fig, pts, metadata, column_name, lbs, **kwargs)
    plt.show()
    return app

In [16]:
import pandas as pd

In [19]:
# load test data
test = pd.read_csv('blindedtest.csv')
test.set_index('pts')
test

Unnamed: 0,pts,1,2
0,53,-10.938460,-2.056227
1,22,-5.455390,-1.962838
2,40,1.376928,-2.914821
3,89,14.770596,-3.323953
4,55,10.929103,-3.954391
...,...,...,...
564,92,2.357191,12.662691
565,63,5.174474,8.100725
566,64,5.509953,8.610787
567,34,2.781563,2.612346


In [20]:
metadata= pd.read_csv('blind_metadata.csv')
metadata

Unnamed: 0,pts,msex,educ,race,spanish,apoe_genotype,age_at_visit_max,age_first_ad_dx,age_death,cts_mmse30_first_ad_dx,cts_mmse30_lv,pmi,braaksc,ceradsc,cogdx,dcfdx_lv,class
0,53,1,14,1,2,33.0,73.489390828199859,,74.450376454483234,,27.0,7.016667,2,2,3,3,AD1
1,22,0,21,1,2,34.0,80.558521560574945,79.39493498,80.687200547570157,28.0,28.0,6.750000,3,1,4,4,AD1
2,40,0,15,1,2,34.0,78.143737166324442,,78.444900752908964,,30.0,7.450000,3,2,1,1,AD1
3,89,0,13,1,1,33.0,76.755646817248461,,77.163586584531146,,27.0,5.500000,3,2,1,1,AD1
4,55,1,20,1,2,34.0,85.215605749486656,,85.850787132101303,,27.0,10.633333,4,2,2,2,AD1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
564,92,0,21,1,2,33.0,75.32101300479124,,80.435318275154003,,30.0,4.916667,3,4,1,1,control
565,63,0,19,1,2,33.0,87.553730321697472,,88.290212183435997,,28.0,19.083333,4,3,2,2,control
566,64,0,14,1,2,33.0,87.813826146475023,,88.39151266255989,,28.0,10.216667,4,3,1,1,control
567,34,0,20,1,2,23.0,87.750855578370974,,88.637919233401774,,29.0,18.166667,2,4,1,1,control


In [21]:
interactive_clusters(test, metadata, column_name='msex')

InvalidIndexError: (slice(None, None, None), 0)