# Visualize ROC AUC and enrichment in 3D (contour maps of each metric over Nsim × V)
Ilya Balabin <ibalabin@avicenna-bio.com>

In [None]:
import os, sys
import numpy as np
import pandas as pd

from dash import Dash, dcc, html, callback, Output, Input
import dash_bootstrap_components as dbc
from dash import dash_table as dt

# import plotly.express as px
import plotly
import plotly.graph_objs as go
from plotly.subplots import make_subplots

# import warnings
# warnings.filterwarnings('ignore')

# Input table (gzip-compressed CSV, judging by the extension) loaded below.
f_in = "assembled.csv.gz"

In [None]:
# Load the assembled metrics table; presumably one row per (Nsim, V) pair,
# since View3D pivots it on those two columns -- verify against the producer.
df = pd.read_csv(f_in)

# Work on a copy of plotly's "Rainbow" scale (so the library's own list is
# left untouched) and force its lowest stop to a dark blue.
myrainbow = list(plotly.colors.PLOTLY_SCALES["Rainbow"])
myrainbow[0] = [0.0, 'rgb(0,0,150)']

In [None]:
def _build_layout():
    """Return the static page layout: a metric selector next to an empty graph."""
    return html.Div([
        html.Hr(),
        dbc.Row([
            dbc.Col([
                html.Br(),
                # Empty heading kept as a spacer/placeholder title.
                html.H1(''),
                html.Br(),
            ])
        ]),
        dbc.Row([
            dbc.Col([
                html.Div([
                    html.H4('Metric'),
                    dcc.RadioItems(
                        id='metric',
                        # Option values are the column prefixes used by
                        # _metric_contour_figure ('AUC_bin_mean', 'Enr_bin_mean').
                        options=[
                            # Dash style dicts use camelCase keys (not 'padding-right').
                            {'label': html.Span("ROC AUC", style={'paddingRight': 27}), 'value': 'AUC'},
                            {'label': html.Span("Enrichment", style={'paddingRight': 0}), 'value': 'Enr'},
                        ],
                        value='AUC',
                        inline=True,
                    ),
                ], id='metric-div', style={'display': 'block'}),
            ], width='auto'),

            dbc.Col([
                dcc.Graph(figure={}, id='graph'),
            ], width='auto'),
        ]),
    ])


def _metric_contour_figure(df, metric, sizex, sizey):
    """Build a contour plot of ``<metric>_bin_mean`` over the (Nsim, V) grid.

    Args:
        df: table with columns ``Nsim``, ``V`` and ``<metric>_bin_mean``.
            Each (Nsim, V) pair must occur at most once, otherwise
            ``DataFrame.pivot`` raises ``ValueError``.
        metric: column prefix selected in the UI, e.g. ``'AUC'`` or ``'Enr'``.
        sizex: figure width in pixels.
        sizey: figure height in pixels.

    Returns:
        A ``plotly.graph_objs.Figure`` holding the contour trace plus black
        hash markers at every point that attains the global maximum.
    """
    col_mean = f'{metric}_bin_mean'

    # Transposed pivot: rows are V, columns are Nsim, so x=Nsim and y=V below.
    grid = df.pivot(index='Nsim', columns='V', values=col_mean).T

    fig = go.Figure()
    fig.add_trace(
        go.Contour(
            x=grid.columns, y=grid.index, z=grid,
            colorscale=myrainbow,
            ncontours=25,
            colorbar=dict(len=0.315, x=1.05, y=0.55),
            # "%{x:d}" is the d3-format integer spec; the former "%{x:.d}" is invalid.
            hovertemplate='V=%{y:.1f}, Nsim=%{x:d}:    %{z:.3f}<extra></extra>',
        ))

    # Mark every grid point that attains the global maximum of the metric.
    best = df[df[col_mean] == df[col_mean].max()]
    fig.add_trace(
        go.Scatter(
            x=best.Nsim, y=best.V, showlegend=False, hovertemplate=None, hoverinfo='skip',
            mode='markers', marker_size=20, marker_line_width=3, marker_symbol='hash',
            marker_color='black',
        ))

    # Figure styling used for the paper.
    fig.update_layout(
        width=sizex, height=sizey, autosize=True, hovermode='closest',
        xaxis_range=[df.Nsim.min(), df.Nsim.max()],
        yaxis_range=[df.V.min(), df.V.max()],
        title='%s average' % metric,
        # NOTE(review): the y data is the V column but is labelled "T (nM)";
        # confirm that V is the threshold in nM.
        xaxis_title='Nsim', yaxis_title='T (nM)', font=dict(size=18),
    )
    return fig


def View3D(df, sizex=900, sizey=750, host='0.0.0.0', port=8873, debug=False):
    """Serve a Dash app showing a contour map of the selected metric.

    Args:
        df: metrics table (see ``_metric_contour_figure`` for required columns).
        sizex: figure width in pixels.
        sizey: figure height in pixels.
        host: interface to bind; the default '0.0.0.0' listens on all
            interfaces, so the app is reachable from other machines.
        port: TCP port for the Dash server.
        debug: forwarded to ``app.run``.

    Returns:
        None. Starts the Dash server via ``app.run``.
    """
    app = Dash(__name__, external_stylesheets=[dbc.themes.UNITED])
    app.layout = _build_layout()

    # Redraw the graph whenever the selected metric changes.
    @app.callback(
        Output("graph", component_property="figure"),
        Input('metric', 'value'),
    )
    def update_graph(metric):
        return _metric_contour_figure(df, metric, sizex, sizey)

    app.run(debug=debug, host=host, port=port)

# Start the server only when run as a notebook/script, not when this module is imported.
if __name__ == "__main__":
    View3D(df)