# RAPIDS + Plotly Dash on Paperspace: ML-Driven Web Apps - Tutorial 3
## Install Extra Libraries
The RAPIDSAI container is our starting point. However, to build this dashboard we need some extra libraries from Plotly and Jupyter. You only need to install them once per container. On your first run, uncomment the `!pip install` lines below, and the `!wget` line if you still need the data file. On later runs, leave them commented out to save time.

In [1]:
# !pip install jupyter-dash tables plotly
# !pip install dash dash-bootstrap-components dash-html-components matplotlib plotly
#
# To download the data file:
# !wget https://rapidsai-data.s3.us-east-2.amazonaws.com/community-examples/plotly-webapp-demo/scrna_data.h5

In [2]:
import cudf
import cuml
import cupy as cp
import numpy as np

import plotly.graph_objects as go

import dash
import dash_bootstrap_components as dbc
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import Input, Output, State

from flask import request

from jupyter_dash import JupyterDash
from IPython.display import display, HTML

import warnings
warnings.filterwarnings('ignore', 'Expected ')
warnings.simplefilter('ignore')

In the previous two tutorials we learned how to:
 - Construct an interactive application using Plotly and Dash components
 - Use RAPIDS for data preprocessing and ML

In this tutorial we will intertwine the UI and ML functions, so that user interactions in the dashboard trigger recomputations with RAPIDS.

# Import the Data
We load single-cell RNA data containing gene counts, with one row per cell and one column per gene. We will use RAPIDS to infer information from this gene count matrix. Once the data is loaded, we arbitrarily select the first three genes as MARKER genes. We will compare the expression of these marker genes against the processed data.
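The marker selection in the next cell takes the first three column names and appends a `_val` suffix before the marker columns are reattached to the processed data. Stripped of cuDF, the naming step is just list slicing; in this sketch, plain Python lists stand in for the dataframe's columns, and the gene names after the first three are hypothetical placeholders.

```python
# Illustrative only: a plain list stands in for df.columns.
# The first three names match the tutorial's output; "gene_4" and "gene_5" are hypothetical.
gene_columns = ["Xkr4", "Gm1992", "Sox17", "gene_4", "gene_5"]

def marker_columns(columns, n_markers=3, suffix="_val"):
    """Take the first n_markers column names and append a suffix to each."""
    return [name + suffix for name in columns[:n_markers]]

MARKERS = marker_columns(gene_columns)
print(MARKERS)  # ['Xkr4_val', 'Gm1992_val', 'Sox17_val']
```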

In [3]:
%%time
# If this cell fails, please download 'scrna_data.h5' from
# https://rapidsai-data.s3.us-east-2.amazonaws.com/community-examples/plotly-webapp-demo/scrna_data.h5
df = cudf.read_hdf('scrna_data.h5')
df_markers = df.iloc[:, [0, 1, 2]]

MARKERS = df.columns[0:3].tolist()
MARKERS = [x + '_val' for x in MARKERS]

df_markers.columns = MARKERS

additional_cols = [x for x in MARKERS]
additional_cols

CPU times: user 4.81 s, sys: 2.75 s, total: 7.57 s
Wall time: 7.62 s


['Xkr4_val', 'Gm1992_val', 'Sox17_val']

# Process
Next, we use [cuML](https://docs.rapids.ai/api/cuml/stable/api.html) from RAPIDS to process the gene count matrix. First, PCA reduces the dimensionality of the matrix. The reduced matrix is then clustered with KMeans and embedded in two dimensions with UMAP. We also store each row's original index in an `orig_index` column, so that points selected in the dashboard can be mapped back to rows of the dataframe.
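To make the clustering step concrete, here is a minimal pure-Python sketch of Lloyd's k-means, the classic algorithm that KMeans implementations are built on. It is not cuML's code: cuML's own initialization and stopping criteria are not reproduced, the random initialization and fixed iteration count are assumptions for illustration, and it works on 2-D points rather than the 50 PCA components used below.

```python
import random

def kmeans(points, k, n_iter=20, seed=0):
    """Lloyd's k-means on a list of (x, y) tuples; returns (centroids, labels)."""
    rng = random.Random(seed)
    centroids = rng.sample(points, k)  # simple random init (assumption, not cuML's default)
    labels = [0] * len(points)
    for _ in range(n_iter):
        # Assignment step: each point goes to its nearest centroid (squared Euclidean distance)
        labels = [
            min(range(k), key=lambda c: (p[0] - centroids[c][0]) ** 2 + (p[1] - centroids[c][1]) ** 2)
            for p in points
        ]
        # Update step: move each centroid to the mean of the points assigned to it
        for c in range(k):
            members = [p for p, lab in zip(points, labels) if lab == c]
            if members:  # keep the previous centroid if a cluster ends up empty
                centroids[c] = (sum(p[0] for p in members) / len(members),
                                sum(p[1] for p in members) / len(members))
    return centroids, labels

# Two well-separated groups of points
points = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (100.0, 100.0), (100.0, 101.0), (101.0, 100.0)]
centroids, labels = kmeans(points, k=2)
```

On this input the first three points share one label and the last three share the other. In the tutorial, the same idea runs on the GPU over 100,000 cells in PCA space.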

In [4]:
%%time
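n_pca = 50       # number of PCA components to keep
n_clusters = 30  # number of KMeans clusters

def cluster(df, n_pca_comp=50, n_clusters=30, additional_cols=None):
    """Run PCA, then KMeans and UMAP on the PCA embedding; return it with labels, x, y and orig_index."""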

    if 'orig_index' in df.columns:
        print('Deleting orig_index')
        idx = df['orig_index']
        df.drop(['orig_index'], inplace=True, axis=1)
    else:
        idx = df.index
    
    prop_series, df = remove_non_feature_cols(df, additional_cols)
    
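    # Reduce the gene count matrix to n_pca_comp principal components
    pca = cuml.PCA(n_components=n_pca_comp)
    pca.fit(df)
    embedding = pca.transform(df)

    # Cluster the cells in PCA space
    kmeans_cuml = cuml.KMeans(n_clusters=n_clusters)
    kmeans_cuml.fit(embedding)
    kmeans_labels = kmeans_cuml.predict(embedding)

    # Compute a 2-D UMAP embedding of the PCA space for plotting
    umap = cuml.manifold.UMAP()
    Xt = umap.fit_transform(embedding)

    embedding['labels'] = kmeans_labels
    embedding['x'] = Xt[0]
    embedding['y'] = Xt[1]
    # Keep the original row index so dashboard selections can be mapped back to rows
    embedding.index = idx
    embedding['orig_index'] = idx

    # Re-attach the non-feature columns (e.g. the marker genes)
    for col in prop_series.keys():
        embedding[col] = prop_series[col]
    return embedding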

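def remove_non_feature_cols(df, additional_cols):
    """Split off non-feature columns; returns ({name: Series}, feature-only dataframe)."""
    prop_series = {}
    # additional_cols defaults to None in cluster(), so treat None as "no extra columns"
    for col_name in additional_cols or []: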
        if col_name in df.columns:
            prop_series[col_name] = df[col_name]
            df = df.drop([col_name], axis=1)

    for col_name in ['labels', 'x', 'y']:
        if col_name in df.columns:
            df.drop([col_name], inplace=True, axis=1)
            
    return prop_series, df


orig_df = cluster(df, n_pca_comp=n_pca, n_clusters=n_clusters, additional_cols=additional_cols)

# Add marker genes back
orig_df = cudf.concat([orig_df, df_markers], axis=1)
orig_df.shape

CPU times: user 5.41 s, sys: 487 ms, total: 5.9 s
Wall time: 5.9 s


(100000, 57)
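The 57 columns in the shape above can be accounted for from the code: `cluster()` returns the 50 PCA components plus four bookkeeping columns, and the three marker columns are concatenated afterwards. A quick check of that arithmetic:

```python
n_pca = 50  # PCA components returned by cluster()
added_by_cluster = ["labels", "x", "y", "orig_index"]  # columns appended inside cluster()
markers = ["Xkr4_val", "Gm1992_val", "Sox17_val"]  # marker columns concatenated afterwards

n_columns = n_pca + len(added_by_cluster) + len(markers)
print(n_columns)  # 57
```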

In [5]:
orig_df.head()

| | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 47 | 48 | 49 | labels | x | y | orig_index | Xkr4_val | Gm1992_val | Sox17_val |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1.868425 | 0.336145 | 1.421227 | -2.847102 | -0.633875 | 3.891086 | 3.511035 | -3.105289 | -3.548421 | 0.086103 | ... | 0.845912 | 1.19573 | -1.716885 | 23 | -2.693675 | -0.198515 | 0 | -0.139578 | -0.005757 | 0.049112 |
| 1 | 0.660085 | -0.41182 | 2.13223 | -4.078625 | -1.024646 | 3.650779 | 3.14976 | -5.575364 | 0.577942 | -0.278277 | ... | -0.046324 | -0.700149 | -0.310466 | 9 | -0.709105 | 0.085514 | 1 | -0.130967 | 0.001012 | 0.045747 |
| 2 | 13.293064 | 3.452299 | -5.101409 | -3.4598 | 0.248678 | 0.371149 | 1.881885 | -2.333594 | 0.43003 | 2.331941 | ... | -0.615356 | -2.6874 | -1.101635 | 12 | -4.732382 | 5.286013 | 2 | -0.133182 | 0.000788 | 0.018643 |
| 3 | -4.962597 | -1.0948 | -1.184917 | 12.477499 | 0.852307 | -2.062974 | -0.705089 | -7.453671 | -0.3262 | -1.523405 | ... | 2.053468 | -0.682838 | 1.046343 | 1 | 6.900117 | 4.495668 | 3 | -0.155867 | -0.009243 | -0.116312 |
| 4 | -9.520083 | 0.326495 | 2.546679 | -5.504425 | -0.62128 | 0.145227 | -0.697787 | -4.823132 | 3.346006 | 0.645233 | ... | 0.951622 | -0.019756 | -0.619416 | 3 | 2.419586 | -2.063457 | 4 | -0.159265 | -0.004439 | -0.25279 |


In [6]:
# Utility function
def generate_colors(num_colors):
    """
    Generates random colors (random red and green channels, fixed blue)
    """
    a = ((np.random.random(size=num_colors) * 255))
    b = ((np.random.random(size=num_colors) * 255))
    return ["#%02x%02x%02x" % (int(r), int(g), 125) for r, g in zip(a, b)]

# One color per KMeans cluster; with only 24 colors, clusters 24-29 would reuse colors
colors = generate_colors(n_clusters)
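The `generate_colors` helper above draws its red and green channels at random, so two clusters can end up with similar colors. If evenly spaced colors are preferred, one alternative (not part of the original tutorial) is to space hues evenly around the HSV color wheel using Python's standard `colorsys` module. The saturation and value defaults here are arbitrary assumptions.

```python
import colorsys

def generate_even_colors(num_colors, saturation=0.65, value=0.85):
    """Return num_colors hex color strings with hues evenly spaced around the color wheel."""
    colors = []
    for i in range(num_colors):
        hue = i / num_colors  # evenly spaced hue in [0, 1)
        r, g, b = colorsys.hsv_to_rgb(hue, saturation, value)
        colors.append("#%02x%02x%02x" % (int(r * 255), int(g * 255), int(b * 255)))
    return colors

print(generate_even_colors(3))
```

Calling `generate_even_colors(n_clusters)` would give each of the 30 clusters a distinct hue.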

## Let's make a scatter plot!

The following function uses Plotly to create a [scatter plot](https://plotly.com/python/line-and-scatter/) with these properties:
- Each cell is rendered as a point at the (x, y) location computed by UMAP
- Each cell is colored by the cluster label (`labels`) assigned by KMeans, with one trace per cluster
- The cluster label is displayed as hover text
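Because each cluster becomes its own `go.Scattergl` trace, every cluster gets its own legend entry and color. `generate_main_graph` below does the grouping with `df.query('labels == ...')` on the GPU and attaches `orig_index` as `customdata`. The same grouping step in plain Python, with hypothetical input lists standing in for the dataframe columns, looks like this:

```python
from collections import defaultdict

def group_points_by_label(xs, ys, labels, orig_index):
    """Group UMAP coordinates by cluster label: one dict of x/y/customdata lists per label."""
    traces = defaultdict(lambda: {"x": [], "y": [], "customdata": []})
    for x, y, label, idx in zip(xs, ys, labels, orig_index):
        trace = traces[label]
        trace["x"].append(x)
        trace["y"].append(y)
        trace["customdata"].append(idx)  # original row index, used to map selections back
    return dict(traces)

# Hypothetical inputs: three cells in two clusters
groups = group_points_by_label([0.5, 1.0, 2.0], [1.0, 2.0, 3.0], [0, 1, 0], [10, 11, 12])
```

Here `groups[0]` holds the two cells of cluster 0 (customdata `[10, 12]`), and each entry would become one trace.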

In [7]:
def generate_main_graph(df):
    fig = go.Figure(layout = {'colorscale' : {}})
    label_count = len(df['labels'].unique())
    
    for i in range(0, label_count):
        si = str(i)
        gdf = df.query('labels == ' + si)
        fig.add_trace(
            go.Scattergl({
            'x': gdf['x'].to_array(),
            'y': gdf['y'].to_array(),
            'text': gdf['labels'].to_array(),
            'customdata': gdf['orig_index'].to_array(),
            'name': 'Cluster ' + si,
            'mode': 'markers',
            'marker': {'size': 2, 'color': colors[i % len(colors)]}
        }))

    fig.update_layout(
            showlegend=True, clickmode='event', title='UMAP', dragmode='select',
            annotations=[
                dict(x=0.5, y=-0.07, showarrow=False, text='UMAP_1', xref="paper", yref="paper"),
                dict(x=-0.05, y=0.5, showarrow=False, text="UMAP_2", xref="paper", yref="paper", textangle=-90, )])
    return fig

Next, we implement functions that filter the data based on the user's action, then recluster the filtered data and recompute its UMAP embedding.
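The two filter functions below keep the rows whose `orig_index` (for selected points) or `labels` (for selected clusters) appears in the user's selection, using cuDF's `isin`. As a plain-Python sketch of that membership filter, where `filter_rows` is a hypothetical helper and the rows reuse values from the `head()` output above:

```python
def filter_rows(rows, key, allowed):
    """Keep rows (dicts) whose value under `key` is in `allowed`, like Series.isin."""
    allowed = set(allowed)  # set membership is O(1) per row
    return [row for row in rows if row[key] in allowed]

rows = [
    {"orig_index": 0, "labels": 23},
    {"orig_index": 1, "labels": 9},
    {"orig_index": 2, "labels": 12},
    {"orig_index": 3, "labels": 1},
]
by_points = filter_rows(rows, "orig_index", [1, 3])  # selection by point indexes
by_clusters = filter_rows(rows, "labels", [23, 12])  # selection by cluster ids
```

Filtering by `orig_index` rather than by position is what keeps selections valid after a recluster, since the surviving rows keep their original indexes.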

In [8]:
def update_scatter_plots(df, additional_cols=None):
    localdf = cluster(df, n_pca_comp=n_pca, n_clusters=n_clusters, additional_cols=additional_cols)

    fig = generate_main_graph(localdf)
    fig_marker1 = graph_scatter(localdf, MARKERS[0])
    fig_marker2 = graph_scatter(localdf, MARKERS[1])
    fig_marker3 = graph_scatter(localdf, MARKERS[2])

    return localdf, fig, fig_marker1, fig_marker2, fig_marker3


def update_scatter_plots_by_points(df, selected_point_indexes, additional_cols=None):
    df = df[df['orig_index'].isin(selected_point_indexes)]
    return update_scatter_plots(df, additional_cols=additional_cols)


def update_scatter_plots_by_cluster_ids(df, selected_cluster_ids, additional_cols=None):
    df = df[df['labels'].isin(selected_cluster_ids)]
    return update_scatter_plots(df, additional_cols=additional_cols)

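Create a plot for each marker gene. For every cluster, it shows the marker's (rounded) expression values, colored by how many cells share each value.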
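The core of `graph_scatter` is a count of cells per (cluster, rounded value) pair, computed on the GPU with `df.groupby(['labels', label], as_index=False).agg({'x': 'count'})`. A plain-Python equivalent using `collections.Counter`, with hypothetical labels and values, is:

```python
from collections import Counter

def count_by_cluster_and_value(labels, values, ndigits=3):
    """Count cells per (cluster label, rounded expression value) pair."""
    return Counter((label, round(value, ndigits)) for label, value in zip(labels, values))

# Hypothetical data: five cells in two clusters
labels = [0, 0, 0, 1, 1]
values = [0.12341, 0.1234, 0.5, 0.0, 0.0]
counts = count_by_cluster_and_value(labels, values)
```

Both of the first two values round to 0.123, so `counts[(0, 0.123)]` is 2. Each key becomes one point in the marker plot, and its count drives the point's color.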

In [9]:
def graph_scatter(df, label):
    """Per-cluster scatter of a marker gene: one point per (cluster, rounded value), colored by cell count."""
    scatter_fig = go.Figure(layout = {'colorscale' : {}})

    # Round a copy of the marker column so the caller's dataframe is not modified
    df = df[['labels', 'x', label]].copy()
    df[label] = df[label].round(3)

    # Color scale limits come from the counts of non-zero expression values
    nonzero = df.query(label + ' != 0').groupby(['labels', label], as_index=False).agg({'x': 'count'})
    xmax = nonzero['x'].max()
    xmin = nonzero['x'].min()

    # Number of cells per (cluster, rounded value) pair; 'x' now holds the count
    gdf = df.groupby(['labels', label], as_index=False).agg({'x': 'count'})

    for i in range(0, len(gdf['labels'].unique())):
        sdf = gdf.query('labels == ' + str(i))
        # Use each point's own count for its hover text and color, so they stay aligned with x/y
        counts = sdf['x'].to_array()
        scatter_fig.add_trace(
                go.Scattergl({
                'x': sdf[label].to_array(),
                'y': sdf['labels'].to_array(),
                'text': counts,
                'mode': 'markers',
                'marker':{
                    'size':4,
                    'color': counts,
                    'colorscale':'Viridis',
                    'cmin': xmin,
                    'cmax': xmax,
                    'showscale':True
                }
            }))

    scatter_fig.update_layout(showlegend=False, clickmode='select', title=label, dragmode='select',
        annotations=[
            dict(x=0.5,   y=-0.15, showarrow=False, xref="paper", yref="paper", text=label),
            dict(x=-0.15, y=0.5,   showarrow=False, xref="paper", yref="paper", text="Clusters", textangle=-90)
        ],)
    return scatter_fig

In [10]:
tdf = orig_df.copy()
tdf.head()

| | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 47 | 48 | 49 | labels | x | y | orig_index | Xkr4_val | Gm1992_val | Sox17_val |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1.868425 | 0.336145 | 1.421227 | -2.847102 | -0.633875 | 3.891086 | 3.511035 | -3.105289 | -3.548421 | 0.086103 | ... | 0.845912 | 1.19573 | -1.716885 | 23 | -2.693675 | -0.198515 | 0 | -0.139578 | -0.005757 | 0.049112 |
| 1 | 0.660085 | -0.41182 | 2.13223 | -4.078625 | -1.024646 | 3.650779 | 3.14976 | -5.575364 | 0.577942 | -0.278277 | ... | -0.046324 | -0.700149 | -0.310466 | 9 | -0.709105 | 0.085514 | 1 | -0.130967 | 0.001012 | 0.045747 |
| 2 | 13.293064 | 3.452299 | -5.101409 | -3.4598 | 0.248678 | 0.371149 | 1.881885 | -2.333594 | 0.43003 | 2.331941 | ... | -0.615356 | -2.6874 | -1.101635 | 12 | -4.732382 | 5.286013 | 2 | -0.133182 | 0.000788 | 0.018643 |
| 3 | -4.962597 | -1.0948 | -1.184917 | 12.477499 | 0.852307 | -2.062974 | -0.705089 | -7.453671 | -0.3262 | -1.523405 | ... | 2.053468 | -0.682838 | 1.046343 | 1 | 6.900117 | 4.495668 | 3 | -0.155867 | -0.009243 | -0.116312 |
| 4 | -9.520083 | 0.326495 | 2.546679 | -5.504425 | -0.62128 | 0.145227 | -0.697787 | -4.823132 | 3.346006 | 0.645233 | ... | 0.951622 | -0.019756 | -0.619416 | 3 | 2.419586 | -2.063457 | 4 | -0.159265 | -0.004439 | -0.25279 |


## Let's get that view on

Note the `recluster` callback, which connects UI events (data selection and button clicks) to the recomputations performed with RAPIDS.
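Both data callbacks below decide what to do by reading `dash.callback_context.triggered[0]['prop_id']`, a string of the form `component_id.property`. The routing inside `recluster` can be sketched without Dash as follows. `parse_trigger`, `route`, and the returned action strings are hypothetical names for illustration. `rpartition` is used instead of `split('.')` so that a component id containing a dot would still parse.

```python
def parse_trigger(prop_id):
    """Split a Dash-style prop_id such as 'recluster_by_id.n_clicks' into (component_id, property)."""
    comp_id, _, prop = prop_id.rpartition(".")
    return comp_id, prop

def route(prop_id):
    """Map a triggered prop_id to the action the recluster callback would take."""
    comp_id, prop = parse_trigger(prop_id)
    if prop != "n_clicks":
        return "ignore"
    return {"recluster_by_point": "filter by points",
            "recluster_by_id": "filter by clusters"}.get(comp_id, "ignore")

print(route("recluster_by_id.n_clicks"))  # filter by clusters
```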

In [12]:
# Please change the proxy_port if the port is already in use.
proxy_port = 8080

app = JupyterDash(__name__, 
                  external_stylesheets=['https://codepen.io/chriddyp/pen/bWLwgP.css', dbc.themes.BOOTSTRAP], 
                  requests_pathname_prefix='/proxy/' + str(proxy_port) + '/')
# external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css', dbc.themes.BOOTSTRAP]
# app = dash.Dash(__name__, external_stylesheets=external_stylesheets)

fig = generate_main_graph(tdf)
fig_ace = graph_scatter(tdf, MARKERS[0])
fig_tmp = graph_scatter(tdf, MARKERS[1])
fig_epc = graph_scatter(tdf, MARKERS[2])

app.layout = html.Div(id='site_body', children=[
    html.Div(className='row', children=[
        html.Div([
            dcc.Graph(id='basic-interactions', figure=fig),
        ], className='nine columns', style={'verticalAlign': 'text-top',}),

        html.Div([
            html.Div(className='row', children=[
                dbc.Button("Directions", id="bt_open_directions"),
                dbc.Modal([
                        dbc.ModalHeader("Directions"),
                        dbc.ModalBody(
                            dcc.Markdown("""
                                The main scatterplot shows the UMAP visualization of single cells. 
                                
                                ### Re-running Clustering and Visualization
                                #### Reclustering by Clicking on Groups:
                                1. Click any point in a cluster of interest. The cluster that point belongs to will populate the field under **Click Data**. Clicking a point in a cluster that is already listed removes it.
                                2. Click **Recluster on Selected Clusters**.
                                #### Reclustering by Selecting Points:
                                1. Use the **Box Select** or **Lasso Select** tool to select your points of interest. The indexes of the selected points will populate the field under **Selection Data**, and the number of selected points is shown below it.
                                2. Click **Recluster on Selected Points**.

                                You can also type cluster numbers or point indexes in manually: enter one or more comma-separated numbers, e.g. “1” or “1,2,3”, in the text box above the button for the operation you want to perform.

                                ### Using the Toolbar
                                Hover the mouse over the top right corner of the plot to see a toolbar. Hover over each tool to see its name. The tool options from left to right are:
                                - **Camera:** download a snapshot of the current view as .png
                                - **Zoom:** Click and drag to select a region of the plot to zoom into
                                - **Pan:** Click and drag to shift the current view to a different region of the plot
                                - **Box Select/Lasso Select:** both these tools can be used to select a region on the plot. The selected points are listed under **Selection Data**.
                                - **Zoom In/Zoom Out:** Zoom in and out centered on the current view.
                            """),
                        ),
                        dbc.ModalFooter(
                            dbc.Button("Close", id="bt_close_directions", className="ml-auto")
                        ),
                    ], id="md_directions"),
            ]),
            
            html.Div(className='row', children=[
                dcc.Markdown("""
                    **Click Data**

                    Click on points in the graph.
                """),
            ], style={'marginTop': 18,}),
            
            html.Div(className='row', children=[
                dcc.Input(id='cluster_id', type='text', style={'width': '80%',}),
            ]),
            
            html.Div(className='row', children=[
                dbc.Button('Recluster on Selected Clusters', id='recluster_by_id', n_clicks=0),
            ], style={'marginTop': 6,}),
            
            html.Div(className='row', children=[
                dcc.Markdown("""
                    **Selection Data**

                    Choose the lasso or rectangle tool in the graph's menu
                    bar and then select points in the graph.
                """),
            ], style={'marginTop': 18,}),
            html.Div(className='row', children=[
                html.Div(dcc.Input(id='point_index', type='text')),
            ]),
            
            html.Div(className='row', children=[html.Div(id='point_index_cnt'),]),
            
            html.Div(className='row', children=[
                dbc.Button('Recluster on Selected Points', id='recluster_by_point', n_clicks=0),
            ], style={'marginTop': 6,}),

            html.Div(className='row', children=[dbc.Button("Exit", id="bt_close"),], style={'marginTop': 6,}),

        ], className='three columns', style={'marginTop': 90, 'verticalAlign': 'text-top',}),
    ]),

    html.Div(className='row', children=[
        html.Div([dcc.Graph(id='ace-interactions',figure=fig_ace)], className='four columns'),
        html.Div([dcc.Graph(id='tmp-interactions',figure=fig_tmp)], className='four columns'),
        html.Div([dcc.Graph(id='epc-interactions',figure=fig_epc)], className='four columns')
    ]),
])


@app.callback(Output("site_body", "children"), Input('bt_close', 'n_clicks'))
def close_app(export_clicks):
    if not dash.callback_context.triggered:
        raise dash.exceptions.PreventUpdate

    func = request.environ.get('werkzeug.server.shutdown')
    if func is None:
        raise RuntimeError('Not running with the Werkzeug Server')
    func()
    return 'Closed'


@app.callback(
    Output("md_directions", "is_open"),
    [Input("bt_open_directions", "n_clicks"), Input("bt_close_directions", "n_clicks")],
    [State("md_directions", "is_open")])
def toggle_directions_dialog(n1, n2, is_open):
    if n1 or n2:
        return not is_open
    return is_open


@app.callback(
    [Output('cluster_id', 'value'),
     Output('point_index_cnt', 'children'),
     Output('point_index', 'value')],
    [Input('basic-interactions', 'clickData'),
     Input('basic-interactions', 'selectedData')],
    [State("cluster_id", "value"), 
     State('point_index', 'value')])
def display_selected_data(clicked_cluster, selected_point_index, 
                          selected_clusters, point_index):
    if not dash.callback_context.triggered:
        raise dash.exceptions.PreventUpdate

    comp_id, event_type = dash.callback_context.triggered[0]['prop_id'].split('.')
    print('Event %s from %s' % (event_type, comp_id))
    
    cluster_id = ''
    point_cnt_str = ''
    point_indexes = ''
    
    if comp_id == 'basic-interactions' and event_type == 'clickData':
        # Event - On selecting cluster on the main scatter plot
        if not selected_clusters:
            selected_labels = []
        else:
            selected_labels = list(map(int, selected_clusters.split(","))) 
        points = clicked_cluster['points']
        for point in points:
            selected_label = point['text']
            if selected_label in selected_labels:
                selected_labels.remove(selected_label)
            else:
                selected_labels.append(selected_label)
        cluster_id = ','.join(map(str, selected_labels))
        
    elif comp_id == 'basic-interactions' and event_type == 'selectedData':
        # Event - On selection on the main scatterplot
        print('Event - On selecting points on the main scatterplot')
        if not selected_point_index:
            raise dash.exceptions.PreventUpdate

        selected_point_indexes = []
        for point in selected_point_index['points']:
            selected_point_indexes.append(point['customdata'])
            
        if len(selected_point_indexes) <= 1:
            raise dash.exceptions.PreventUpdate
            
        point_cnt_str = str(len(selected_point_indexes)) + ' points selected'
        point_indexes = ', '.join(map(str, selected_point_indexes))
     
    else:
        print('Unhandled event')
        raise dash.exceptions.PreventUpdate
        
    return cluster_id, point_cnt_str, point_indexes


@app.callback([Output('basic-interactions', 'figure'),
               Output('ace-interactions', 'figure'),
               Output('tmp-interactions', 'figure'),
               Output('epc-interactions', 'figure')],
              [Input('recluster_by_id', 'n_clicks'),
               Input('recluster_by_point', 'n_clicks')],
              [State("cluster_id", "value"), 
               State('point_index', 'value')])
def recluster(recluster_by_id, recluster_by_point, 
              cluster_id, point_index):
    global tdf, fig_ace, fig_tmp, fig_epc

    if not dash.callback_context.triggered:
        raise dash.exceptions.PreventUpdate
    comp_id, event_type = dash.callback_context.triggered[0]['prop_id'].split('.')
    print('Event %s from %s' % (event_type, comp_id))

    if comp_id == 'recluster_by_point' and event_type == 'n_clicks':
        if not point_index:
            raise dash.exceptions.PreventUpdate
        # Event - On click 'recluster' button
        selected_point_indexes = list(map(int, point_index.split(",")))
        
        tdf, fig, fig_ace, fig_tmp, fig_epc = update_scatter_plots_by_points(tdf, 
                                                                             selected_point_indexes, 
                                                                             additional_cols=additional_cols)
    elif comp_id == 'recluster_by_id' and event_type == 'n_clicks':
        if not cluster_id:
            raise dash.exceptions.PreventUpdate
        # Event - On click 'recluster' button
        selected_cluster_ids = list(map(int, cluster_id.split(",")))
        tdf, fig, fig_ace, fig_tmp, fig_epc = update_scatter_plots_by_cluster_ids(tdf, 
                                                                                  selected_cluster_ids,
                                                                                  additional_cols=additional_cols)
    else:
        print('Unhandled event')
        raise dash.exceptions.PreventUpdate
    return fig, fig_ace, fig_tmp, fig_epc

js = "<b style='color: red'>Please click <a href='/proxy/" + str(proxy_port) + "/' target='_blank'>here</a> to open the dashboard</b>"
display(HTML(js))

app.run_server(debug=True, use_reloader=False, port=proxy_port)

Dash app running on http://127.0.0.1:8080/proxy/8080/
Event selectedData from basic-interactions
Event - On selecting points on the main scatterplot
Event selectedData from basic-interactions
Event - On selecting points on the main scatterplot
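In `display_selected_data`, clicking a point toggles its cluster in the comma-separated field under **Click Data**: clicking a cluster that is already listed removes it. The same toggle, isolated as a hypothetical helper `toggle_cluster` (not part of the tutorial code), is shown below. Unlike the callback's `list(map(int, ...split(",")))`, it skips empty entries, so a stray trailing comma typed by hand does not raise a `ValueError`.

```python
def toggle_cluster(selected, clicked_label):
    """Add clicked_label to a comma-separated id string, or remove it if already present."""
    # Parse the current field value, ignoring empty entries such as a trailing comma
    labels = [int(s) for s in selected.split(",") if s.strip()] if selected else []
    if clicked_label in labels:
        labels.remove(clicked_label)
    else:
        labels.append(clicked_label)
    return ",".join(map(str, labels))

state = ""
state = toggle_cluster(state, 3)  # '3'
state = toggle_cluster(state, 7)  # '3,7'
state = toggle_cluster(state, 3)  # '7'
```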
