# Feature Extraction and Morphology Clustering using Vista2D and RAPIDS

For background on the methods outlined here, see the accompanying [blog post]().

## 1. Feature Extraction with Vista2D

### Loading a Pre-trained Model

This lab uses a pre-trained Vista2D model, which is provided as a model checkpoint below. 

In [None]:
model_ckpt = "cell_vista_segmentation/results/model.pt"

### Running Segmentation

To keep the notebook clean, most of the Vista2D code is wrapped in helper functions, which can be found in `segmentation.py` in the root directory of this repository. For now, let's import the functions we will use.

In [None]:
from segmentation import segment_cells, plot_segmentation, feature_extract, plot_cells_by_class

This lab uses data from [LIVECell](https://www.nature.com/articles/s41592-021-01249-6), a dataset of live cell images designed for this type of segmentation task.

In the code cell below, we single out one of the images from this database and run it through segmentation. 

In [None]:
img_path="example_livecell_image.tif"
patch, segmentation, pred_mask = segment_cells(img_path, model_ckpt)

This function returns the original cell image (`patch`), the segmentations (`segmentation`), and a mask of the individual cell instances (`pred_mask`). We will plot that information below.

In [None]:
plot_segmentation(patch, segmentation, pred_mask)

Now that segmentation is complete, we can see how many cells are in the image, and we have enough information to separate out each individual cell. In the next section of the lab, we will run a feature extractor on each individual cell to get a representation of that cell in the feature space of the model. These feature vectors will be used later for clustering.
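
As a rough illustration of what "separating out each individual cell" can mean, the sketch below crops one cell from a toy image using an instance mask. It assumes a 2D grayscale image and a 2D integer mask where 0 is background and each cell has its own positive ID; `crop_cell`, `toy_image`, and `toy_mask` are hypothetical names, and this is not necessarily how the helpers in `segmentation.py` do it.

```python
import numpy as np

def crop_cell(image, instance_mask, cell_id):
    """Return the bounding-box crop of one cell, with all other pixels zeroed."""
    cell_pixels = instance_mask == cell_id
    rows, cols = np.nonzero(cell_pixels)
    if rows.size == 0:
        raise ValueError(f"cell_id {cell_id} not found in mask")
    r0, r1 = rows.min(), rows.max() + 1
    c0, c1 = cols.min(), cols.max() + 1
    # Zero out everything that is not this cell, then crop to its bounding box
    return np.where(cell_pixels, image, 0)[r0:r1, c0:c1]

# Toy 4x4 image and mask containing two cells (IDs 1 and 2)
toy_image = np.arange(16).reshape(4, 4)
toy_mask = np.array([
    [0, 0, 0, 0],
    [0, 1, 1, 0],
    [0, 1, 2, 0],
    [0, 0, 2, 0],
])

# With consecutive IDs starting at 1, the largest ID is the cell count
n_cells = int(toy_mask.max())
crop = crop_cell(toy_image, toy_mask, cell_id=1)
print(n_cells, crop.tolist())
```

If `pred_mask` follows the same convention, the same idea would apply to it, but check its shape and dtype first.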

### Feature Extraction

The helper function `feature_extract` runs each cell segmentation through a feature extractor. For this lab, we use the Vista2D model again, this time as a feature extractor. Each cell is passed through only the first half of the model, the encoder, so that the resulting feature vector encodes information about the cell's shape.

In [None]:
cell_features = feature_extract(pred_mask, patch, model_ckpt) 

In `cell_features`, each row represents a different cell segmentation from the image, and each column represents a feature in the embedding space of the Vista2D model. We now have a full feature matrix for the cell segmentations and are ready for clustering.

## 2. Morphology Clustering with RAPIDS

### What is RAPIDS

[RAPIDS](https://developer.nvidia.com/rapids) is a suite of tools from NVIDIA for accelerated data science. In this lab, we will use it to run clustering on our cell segmentations.

Below are the libraries we will need for this section of the lab. 

In [None]:
import numpy as np
from cuml import DBSCAN, TruncatedSVD

### Running clustering

To make the clustering problem easier, and so that we can visualize the results later, let's reduce each feature vector from 256 dimensions to 3 (one each for the x, y, and z axes). To do this, we use the `TruncatedSVD` algorithm from the `cuml` package in RAPIDS with `n_components=3`.

In [None]:
dim_red_model = TruncatedSVD(n_components=3)
X = dim_red_model.fit_transform(cell_features)
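
To make the reduction step less of a black box, here is a minimal CPU sketch of a truncated SVD in plain NumPy. It uses random stand-in data (`example_features` is hypothetical, not the real `cell_features`) and shows that projecting onto the top 3 right singular vectors gives the reduced coordinates. The signs and exact values from `cuml` may differ, and note that truncated SVD does not mean-center the data the way PCA does.

```python
import numpy as np

# Stand-in for a feature matrix: 20 hypothetical cells x 256 features
rng = np.random.default_rng(0)
example_features = rng.normal(size=(20, 256))

# Thin SVD: example_features == U @ np.diag(S) @ Vt
U, S, Vt = np.linalg.svd(example_features, full_matrices=False)

k = 3
# Project every row onto the top-k right singular vectors
X_example = example_features @ Vt[:k].T

# The same coordinates can be read directly off U and S
assert np.allclose(X_example, U[:, :k] * S[:k])

# Share of the total squared magnitude kept by the first k components
retained = float((S[:k] ** 2).sum() / (S ** 2).sum())
print(X_example.shape, retained)
```

A low `retained` value on real data would suggest that 3 components discard much of the structure, which is worth keeping in mind when reading the clusters.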

Now we are ready to run clustering. For this we will use `DBSCAN` from the `cuml` package in RAPIDS. 

In [None]:
model = DBSCAN(eps=0.003, min_samples=2)
labels = model.fit_predict(X)
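
The two DBSCAN parameters control what counts as a dense region: `eps` is the neighborhood radius, and `min_samples` is how many points (counting the point itself) must fall within that radius for a point to be a core point. One common heuristic for choosing `eps` is to look at each point's distance to its nearest neighbors. The sketch below does this in NumPy on toy data; `kth_neighbor_distances` and `toy_points` are hypothetical names, and the pairwise-distance approach assumes the number of points is small enough to fit an n x n matrix in memory.

```python
import numpy as np

def kth_neighbor_distances(points, k):
    """Sorted distance from each point to its k-th nearest other point."""
    points = np.asarray(points, dtype=float)
    diffs = points[:, None, :] - points[None, :, :]
    dists = np.sqrt((diffs ** 2).sum(axis=-1))
    # After sorting each row, column 0 is the point itself (distance 0),
    # so column k is the k-th nearest other point
    return np.sort(np.sort(dists, axis=1)[:, k])

# Toy 3D data: two tight groups and one isolated point
toy_points = np.array([
    [0.0, 0.0, 0.0],
    [0.001, 0.0, 0.0],
    [0.0, 0.001, 0.0],
    [1.0, 1.0, 1.0],
    [1.001, 1.0, 1.0],
    [5.0, 5.0, 5.0],
])

min_samples = 2
# min_samples counts the point itself, so look at the (min_samples - 1)-th neighbor
k_distances = kth_neighbor_distances(toy_points, k=min_samples - 1)
print(k_distances)
```

Points whose k-distance is well above the chosen `eps` have no dense neighborhood and would be labeled as noise. Applying the same function to `X` is one way to check whether `eps=0.003` is a sensible scale for the reduced features.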

### Visualizing the clusters

The clustering is done, but to interpret the results we need to visualize them. In this section we will show a few ways to look at the data.

In [None]:
import plotly
import plotly.graph_objs as go

# Configure Plotly to be rendered inline in the notebook.
plotly.offline.init_notebook_mode()

We can create a dictionary of cluster labels and the cell segmentations they contain. This will also make it easier for downstream analysis so we can index into clusters and data points. 

In [None]:
# Background is 0, so cell IDs start at 1
labels_dict = {x:np.add(np.where(labels==x),1) for x in np.unique(labels)}

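# Label -1 marks noise points that DBSCAN did not assign to any cluster, so we remove it
# (the default of None avoids a KeyError when there are no noise points)
labels_dict.pop(-1, None)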
labels_dict
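
The small example below walks through the same bookkeeping on made-up labels, so the 1-based cell IDs and the handling of the noise label are easy to check by hand. `example_labels` is hypothetical, and this variant stores flat arrays of IDs via `np.flatnonzero` rather than the nested arrays produced by `np.where` above.

```python
import numpy as np

# Hypothetical DBSCAN output for 8 cells; -1 marks noise points
example_labels = np.array([0, 0, 1, -1, 1, 1, 0, -1])

# Row i of the feature matrix corresponds to cell ID i + 1 (ID 0 is background)
example_dict = {
    int(label): np.flatnonzero(example_labels == label) + 1
    for label in np.unique(example_labels)
}

# Drop the noise label if present
example_dict.pop(-1, None)

# Number of cells in each remaining cluster
cluster_sizes = {label: len(ids) for label, ids in example_dict.items()}
print(cluster_sizes)  # {0: 3, 1: 3}
```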

We can now see that we have 4 clusters and how many cell segmentations belong to each one. Let's visualize these clusters in 3D space using a scatterplot.

In [None]:
data = []

for label, cell_ids in labels_dict.items():

    # Convert 1-based cell IDs back to 0-based row indices into X
    cluster_indices = cell_ids[0] - 1

    # Configure the trace
    trace = go.Scatter3d(
        x=X[cluster_indices, 0],
        y=X[cluster_indices, 1],
        z=X[cluster_indices, 2],
        name=f"Cluster {label}",
        mode='markers',
        marker={
            'size': 10,
            'opacity': 0.8,
        }
    )

    data.append(trace)

# Configure the layout
layout = go.Layout(
    margin={'l': 0, 'r': 0, 'b': 0, 't': 0}
)

plot_figure = go.Figure(data=data, layout=layout)

# Render the plot
plotly.offline.iplot(plot_figure)

With these parameters, DBSCAN finds 4 distinct clusters.

## Conclusion