In [1]:
import os
import sys
# Make the parent directory of the current working directory importable.
sys.path.append(os.path.join(os.getcwd(), '..'))

# NOTE(review): several names imported below (yaml, plt, colors, plotting, sg, HTML,
# norm, ligands, edges, tqdm, unit_registry, widgets) are not used in the cells
# visible here -- confirm before pruning.
import yaml

import numpy as np
# import matplotlib
from matplotlib import pyplot as plt
import plotly
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from  plotly import colors
import pandas as pd

# `stats.bootstrap_statistic` provides the RMSE/MUE values and intervals shown in the plots.
from arsenic import plotting, stats

from svgutils import transform as sg

# NOTE(review): `IPython.core.display` is deprecated in recent IPython; `IPython.display`
# is the public location of HTML.
from IPython.core.display import HTML
from scipy.stats import norm

from PLBenchmarks import targets, ligands, edges

from tqdm.notebook import tqdm
import pint
unit_registry = pint.UnitRegistry()

from ipywidgets import widgets, interact

# Point PLBenchmarks at the data directory of the installed `benchmarkpl` package
# (its first package path entry).
import benchmarkpl
path = benchmarkpl.__path__[0]
targets.set_data_dir(path)



In [2]:
# Per-edge table with experimental and per-force-field DDG values;
# the first CSV column becomes the index.
all_edges = pd.read_csv(
    '../03_comparison_experiment/03a_all_edges_all_ffs.csv',
    index_col=0,
)
all_edges.head()

Unnamed: 0,target,edge,ligandA,ligandB,unit,DDG_Exp.,dDDG_Exp.,DDG_OpenFF-1.0,dDDG_OpenFF-1.0,DDG_OpenFF-1.0_converged,...,error_GAFF2,abserror_GAFF2,error_cGenFF,abserror_cGenFF,error_Consensus_OpenFF_GAFF2_cGenFF,abserror_Consensus_OpenFF_GAFF2_cGenFF,error_Consensus_OpenFF_GAFF2,abserror_Consensus_OpenFF_GAFF2,error_Consensus_all,abserror_Consensus_all
jnk1_edge_17124-1_18631-1,jnk1,edge_17124-1_18631-1,17124-1,18631-1,kilocalories / mole,0.26,0.37,1.19,0.096086,1.19,...,1.071262,1.071262,0.516769,0.516769,0.646112,0.646112,0.784876,0.784876,1.222263,1.222263
jnk1_edge_17124-1_18634-1,jnk1,edge_17124-1_18634-1,17124-1,18634-1,kilocalories / mole,-0.33,0.29,0.58,0.128639,0.58,...,0.829522,0.829522,0.580956,0.580956,0.852556,0.852556,0.928604,0.928604,0.798413,0.798413
jnk1_edge_18626-1_18624-1,jnk1,edge_18626-1_18624-1,18626-1,18624-1,kilocalories / mole,0.38,0.21,0.556667,0.099301,0.556667,...,0.745717,0.745717,-0.265277,0.265277,0.309516,0.309516,0.616033,0.616033,0.446727,0.446727
jnk1_edge_18626-1_18625-1,jnk1,edge_18626-1_18625-1,18626-1,18625-1,kilocalories / mole,0.77,0.21,-0.03,0.107462,-0.03,...,-0.062543,0.062543,-0.294379,0.294379,-0.388337,0.388337,-0.155679,0.155679,0.143932,0.143932
jnk1_edge_18626-1_18627-1,jnk1,edge_18626-1_18627-1,18626-1,18627-1,kilocalories / mole,0.39,0.22,0.14,0.046151,0.14,...,0.0426,0.0426,-0.232256,0.232256,-0.020344,0.020344,0.064101,0.064101,-0.12406,0.12406


In [3]:
# Identifiers of the reference and every method: the "DDG..." column names with
# their 4-character "DDG_" prefix removed.
identifiers = [column[len("DDG_"):] for column in all_edges.columns if column.startswith("DDG")]
identifiers

['Exp.',
 'OpenFF-1.0',
 'OpenFF-1.0_converged',
 'OpenFF-1.2',
 'OpenFF-1.2_converged',
 'OpenFF-2.0',
 'OpenFF-2.0_converged',
 'OPLS3e',
 'GAFF2',
 'cGenFF',
 'Consensus_OpenFF_GAFF2_cGenFF',
 'Consensus_OpenFF_GAFF2',
 'Consensus_all']

### Helper functions to create correlation plots

In [4]:
def prepare_ax(fig, ax_min, ax_max, col=None, row=None, guidelines=True, origins=True):
    """Draw reference lines/bands into a (sub)plot and fix its axis ranges.

    Traces are added in this order: the x = 0 and y = 0 lines (if ``origins``),
    the shaded +/-1 and +/-0.5 kcal/mol bands around the diagonal (if
    ``guidelines``), and finally the diagonal y = x. Both axes of the addressed
    subplot are then limited to ``[ax_min, ax_max]``.

    Parameters
    ----------
    fig : plotly figure, modified in place.
    ax_min, ax_max : lower and upper limit used for both axes.
    col, row : subplot position forwarded to ``add_trace`` / ``update_*axes``;
        ``None`` for a figure without subplots.
    guidelines : draw the two shaded margin bands.
    origins : draw the x = 0 and y = 0 lines.
    """
    where = dict(row=row, col=col)

    def black_line(xs, ys):
        # thin helper: a black, legend-less line trace
        return go.Scatter(x=xs, y=ys, line_color='black', mode='lines', showlegend=False)

    if origins:
        fig.add_trace(black_line([0, 0], [ax_min, ax_max]), **where)  # vertical x = 0
        fig.add_trace(black_line([ax_min, ax_max], [0, 0]), **where)  # horizontal y = 0

    if guidelines:
        half_width = 0.5
        # outer (+/-1) band first, inner (+/-0.5) band drawn on top of it
        for offset, label in ((2. * half_width, '1 kcal/mol margin'),
                              (half_width, '.5 kcal/mol margin')):
            band = go.Scatter(
                x=[ax_min, ax_max, ax_max, ax_min],
                y=[ax_min + offset, ax_max + offset, ax_max - offset, ax_min - offset],
                name=label,
                hoveron='points+fills',
                hoverinfo='name',
                fill='toself',
                mode='lines', line_width=0, fillcolor='rgba(0, 0, 0, 0.2)',
                showlegend=False,
            )
            fig.add_trace(band, **where)

    fig.add_trace(black_line([ax_min, ax_max], [ax_min, ax_max]), **where)  # diagonal y = x

    fig.update_xaxes(range=[ax_min, ax_max], **where)
    fig.update_yaxes(range=[ax_min, ax_max], **where)


In [5]:
def correlation_plot_ddg(x, y, dx, dy, title='', name1='', name2=''):
    """Scatter plot of DDG values ``y`` against ``x`` with error bars and statistics.

    Parameters
    ----------
    x, y : sequences of DDG values in kcal/mol for the x (``name1``) and y (``name2``) axis.
    dx, dy : uncertainties of ``x`` and ``y``, drawn as error bars.
    title : plot title (the sample count and statistics are appended); also used
        as the name of the data trace.
    name1, name2 : labels inserted into the x and y axis titles.

    Returns
    -------
    The plotly figure. Its title lists N and RMSE/MUE from
    ``stats.bootstrap_statistic`` with the bootstrapped interval labelled 95%.
    """
    ax_min, ax_max = -7, 7
    fig = go.Figure()
    prepare_ax(fig, ax_min, ax_max, col=None, row=None, guidelines=True, origins=True)

    fig.add_trace(go.Scatter(x=x, y=y,
                             mode='markers',
                             # Name the trace after the plot title instead of a global
                             # variable, so pooled plots are labelled correctly too.
                             name=f'{title}',
                             marker=dict(
                                 symbol='circle',
                                 size=3,
                             ),
                             error_x=dict(
                                 type='data',  # value of error bar given in data coordinates
                                 array=dx,
                                 visible=True),
                             error_y=dict(
                                 type='data',  # value of error bar given in data coordinates
                                 array=dy,
                                 visible=True),
                             showlegend=False
                             )
                  )

    # statistics for the title
    x_arr, y_arr = np.array(x), np.array(y)
    nsamples = len(x_arr)
    lines = []
    for statistic in ('RMSE', 'MUE'):
        s = stats.bootstrap_statistic(x_arr, y_arr, statistic=statistic)
        lines.append(
            f"{statistic + ':':5s}{s['mle']:5.2f} [95%: {s['low']:5.2f}, {s['high']:5.2f}]")
    statistics_string = '<br>'.join(lines)

    long_title = f'{title} (N = {nsamples})<br>{statistics_string}'

    fig.update_layout(
        template="simple_white",
        title=dict(
            text=long_title,
            xref='paper',
            yref='container',
            xanchor='left',
            yanchor='top',
            x=0.0,
            y=0.9,
            font_size=12,
        ),
        xaxis=dict(
            title=f'DDG {name1} [kcal/mol]',
            titlefont_size=14,
            tickfont_size=12,
            range=(ax_min, ax_max)
        ),
        yaxis=dict(
            title=f'DDG {name2} [kcal/mol]',
            titlefont_size=14,
            tickfont_size=12,
            range=(ax_min, ax_max)
        ),
        width=400,
        height=400
    )

    return fig

### Create correlation plots for DDG

In [6]:
# Export one DDG correlation plot (calculated vs. experimental) per target and
# method, plus one plot pooling all targets per method, as SVG files.

# Display labels for identifiers whose column name differs from the label used in plots.
ff_labels = {
    'GAFF2': 'GAFF2.1x',
    'cGenFF': 'CGenFF/MATCH*',
    'Consensus_OpenFF_GAFF2_cGenFF': 'Consensus',
}
# Every identifier except the experimental reference itself.
calc_identifiers = [ident for ident in identifiers if ident != 'Exp.']

for idx in calc_identifiers:
    ff = ff_labels.get(idx, idx)
    # Column selection and NaN filtering depend only on the method: do them once,
    # not once per target.
    ddf = all_edges[["target", "DDG_Exp.", f"DDG_{idx}", "dDDG_Exp.", f"dDDG_{idx}"]].dropna(axis=0)
    data_x, data_y, data_dx, data_dy = [], [], [], []
    # Keep the loop variable named `target`: the plotting helper may read it as a global.
    for target in targets.target_dict.keys():
        tdf = ddf.loc[ddf["target"] == target]
        tdata_x = tdf["DDG_Exp."]
        tdata_y = tdf[f"DDG_{idx}"]
        tdata_dx = tdf["dDDG_Exp."]
        tdata_dy = tdf[f"dDDG_{idx}"]
        data_x += list(tdata_x)
        data_y += list(tdata_y)
        data_dx += list(tdata_dx)
        data_dy += list(tdata_dy)
        # per-target plots only for targets with more than one edge
        if len(tdata_x) > 1:
            fig = correlation_plot_ddg(
                tdata_x,
                tdata_y,
                tdata_dx,
                tdata_dy,
                title=target,
                name1='Exp.',
                name2=ff,
            )
            fig.write_image(f'04e_corr_ddg_{target}_{idx}.svg')
    # all targets pooled
    fig = correlation_plot_ddg(
        data_x,
        data_y,
        data_dx,
        data_dy,
        title="All targets",
        name1='Exp.',
        name2=ff,
    )
    fig.write_image(f'04e_corr_ddg_all_{idx}.svg')

In [7]:
# Re-display the identifier list (suffixes of the "DDG..." columns) for reference.
identifiers

['Exp.',
 'OpenFF-1.0',
 'OpenFF-1.0_converged',
 'OpenFF-1.2',
 'OpenFF-1.2_converged',
 'OpenFF-2.0',
 'OpenFF-2.0_converged',
 'OPLS3e',
 'GAFF2',
 'cGenFF',
 'Consensus_OpenFF_GAFF2_cGenFF',
 'Consensus_OpenFF_GAFF2',
 'Consensus_all']

In [36]:
# Grid of per-target DDG correlation panels (calculated vs. experimental) for a
# selection of methods; the targets are split over several PNG files with `rows`
# targets each (04e_ddg_ff_vs_exp_<part>.png).
identifiers_to_plot = [
    'OpenFF-2.0',
    'OPLS3e',
    'GAFF2',
    'cGenFF',
    'Consensus_OpenFF_GAFF2_cGenFF',
]
guidelines = True  # shaded +/-0.5 and +/-1 kcal/mol bands around the diagonal
origins = True     # x = 0 and y = 0 lines
plot_type = 'ddg'
ax_min = -6.5
ax_max = +6.5
small_dist = 0.5   # half-width of the inner band; the outer band is twice as wide

all_targets = list(targets.target_dict.keys())
cols = len(identifiers_to_plot)
rows = 6  # targets per figure
ratio = float(cols / rows)
# Ceiling division: as many figures as needed to show every target exactly once.
n_parts = -(-len(all_targets) // rows)

# Column selection and NaN filtering depend only on the method, so do them once.
edges_by_ff = {
    idx: all_edges[["target", "DDG_Exp.", f"DDG_{idx}", "dDDG_Exp.", f"dDDG_{idx}"]].dropna(axis=0)
    for idx in identifiers_to_plot
}

for part in range(n_parts):
    part_targets = all_targets[part * rows:(part + 1) * rows]
    fig = make_subplots(
        cols=cols,
        rows=rows,
        shared_xaxes=True,
        shared_yaxes=True,
        x_title=f'Experimental {plot_type} [kcal mol<sup>-1</sup>]',
        y_title=f'Calculated {plot_type} [kcal mol<sup>-1</sup>]',
        vertical_spacing=0.005 * ratio,
        horizontal_spacing=0.005,
        row_titles=part_targets,
        column_titles=[idx.split('_')[0] for idx in identifiers_to_plot],
    )
    for i, target in enumerate(part_targets):
        for j, idx in enumerate(identifiers_to_plot):
            panel = dict(row=i + 1, col=j + 1)
            ddf = edges_by_ff[idx]
            in_target = ddf["target"] == target
            # plain numpy arrays, as passed to bootstrap_statistic in correlation_plot_ddg
            x = np.array(ddf.loc[in_target, "DDG_Exp."])
            y = np.array(ddf.loc[in_target, f"DDG_{idx}"])

            # RMSE with bootstrapped interval in the top-left corner (skipped for empty panels)
            if len(x) > 0:
                s = stats.bootstrap_statistic(x, y, statistic='RMSE')
                fig.add_trace(go.Scatter(
                    x=[ax_min + .5],
                    y=[ax_max - 1],
                    mode="text",
                    text=[f"RMSE:{s['mle']:.1f}[{s['low']:.1f},{s['high']:.1f}]"],
                    textfont_size=6,
                    textposition="bottom right",
                    showlegend=False,
                ), **panel)

            if origins:
                # x = 0, then y = 0
                for line_x, line_y in (([0, 0], [ax_min, ax_max]), ([ax_min, ax_max], [0, 0])):
                    fig.add_trace(go.Scatter(x=line_x, y=line_y, line_color='black', line_width=.5,
                                             mode='lines', showlegend=False), **panel)

            if guidelines:
                for offset, label in ((2. * small_dist, '1 kcal/mol margin'),
                                      (small_dist, '.5 kcal/mol margin')):
                    fig.add_trace(go.Scatter(x=[ax_min, ax_max, ax_max, ax_min],
                                             y=[ax_min + offset, ax_max + offset,
                                                ax_max - offset, ax_min - offset],
                                             name=label,
                                             hoveron='points+fills',
                                             hoverinfo='name',
                                             fill='toself',
                                             mode='lines', line_width=0, fillcolor='rgba(0, 0, 0, 0.2)',
                                             showlegend=False), **panel)

            # diagonal y = x
            fig.add_trace(go.Scatter(x=[ax_min, ax_max], y=[ax_min, ax_max], line_color='black',
                                     line_width=.5, mode='lines', showlegend=False), **panel)

            # data points (error bars are not drawn in this grid)
            fig.add_trace(go.Scatter(x=x, y=y,
                                     mode='markers',
                                     name=f'{target}',
                                     marker=dict(symbol='circle', color='blue', size=3),
                                     showlegend=False), **panel)

    # Apply range and tick font to *every* subplot axis; the layout's `xaxis`/`yaxis`
    # entries alone only address the top-left panel.
    fig.update_xaxes(range=[ax_min, ax_max], tickfont_size=6)
    fig.update_yaxes(range=[ax_min, ax_max], tickfont_size=6)
    fig.update_layout(
        template="simple_white",
        width=500,
        height=500 / ratio,
    )
    fig.update_annotations(font_size=8)

    fig.write_image(f'04e_ddg_ff_vs_exp_{part}.png', scale=4)

In [37]:
# Re-display the identifier list (suffixes of the "DDG..." columns) for reference.
identifiers

['Exp.',
 'OpenFF-1.0',
 'OpenFF-1.0_converged',
 'OpenFF-1.2',
 'OpenFF-1.2_converged',
 'OpenFF-2.0',
 'OpenFF-2.0_converged',
 'OPLS3e',
 'GAFF2',
 'cGenFF',
 'Consensus_OpenFF_GAFF2_cGenFF',
 'Consensus_OpenFF_GAFF2',
 'Consensus_all']

In [38]:
# Grid of per-target DDG correlation panels (calculated vs. experimental) for the
# OpenFF versions on a subset of targets, split over PNG files with `rows`
# targets each (04e_ddg_openff_vs_exp_<part>.png).
identifiers_to_plot = [
    'OpenFF-1.0',
    'OpenFF-1.2',
    'OpenFF-2.0',
]
targets_to_plot = ['cmet', 'eg5', 'cdk8', 'hif2a', 'pfkfb3', 'shp2', 'syk', 'tnks2']
guidelines = True  # shaded +/-0.5 and +/-1 kcal/mol bands around the diagonal
origins = True     # x = 0 and y = 0 lines
plot_type = 'ddg'
ax_min = -6.5
ax_max = +6.5
small_dist = 0.5   # half-width of the inner band; the outer band is twice as wide

cols = len(identifiers_to_plot)
rows = 6  # targets per figure
ratio = float(cols / rows)
# Ceiling division: only as many figures as the targets need (a fixed 4 would
# write completely empty figures for 8 targets).
n_parts = -(-len(targets_to_plot) // rows)

# Column selection and NaN filtering depend only on the method, so do them once.
edges_by_ff = {
    idx: all_edges[["target", "DDG_Exp.", f"DDG_{idx}", "dDDG_Exp.", f"dDDG_{idx}"]].dropna(axis=0)
    for idx in identifiers_to_plot
}

for part in range(n_parts):
    part_targets = list(targets_to_plot)[part * rows:(part + 1) * rows]
    fig = make_subplots(
        cols=cols,
        rows=rows,
        shared_xaxes=True,
        shared_yaxes=True,
        x_title=f'Experimental {plot_type} [kcal mol<sup>-1</sup>]',
        y_title=f'Calculated {plot_type} [kcal mol<sup>-1</sup>]',
        vertical_spacing=0.005 * ratio,
        horizontal_spacing=0.005,
        row_titles=part_targets,
        column_titles=list(identifiers_to_plot),
    )
    for i, target in enumerate(part_targets):
        for j, idx in enumerate(identifiers_to_plot):
            panel = dict(row=i + 1, col=j + 1)
            ddf = edges_by_ff[idx]
            in_target = ddf["target"] == target
            # plain numpy arrays, as passed to bootstrap_statistic in correlation_plot_ddg
            x = np.array(ddf.loc[in_target, "DDG_Exp."])
            y = np.array(ddf.loc[in_target, f"DDG_{idx}"])

            # RMSE with bootstrapped interval in the top-left corner (skipped for empty panels)
            if len(x) > 0:
                s = stats.bootstrap_statistic(x, y, statistic='RMSE')
                fig.add_trace(go.Scatter(
                    x=[ax_min + .2],
                    y=[ax_max - .2],
                    mode="text",
                    text=[f"RMSE:{s['mle']:.1f}[{s['low']:.1f},{s['high']:.1f}]"],
                    textposition="bottom right",
                    showlegend=False,
                ), **panel)

            if origins:
                # x = 0, then y = 0
                for line_x, line_y in (([0, 0], [ax_min, ax_max]), ([ax_min, ax_max], [0, 0])):
                    fig.add_trace(go.Scatter(x=line_x, y=line_y, line_color='black', line_width=.5,
                                             mode='lines', showlegend=False), **panel)

            if guidelines:
                for offset, label in ((2. * small_dist, '1 kcal/mol margin'),
                                      (small_dist, '.5 kcal/mol margin')):
                    fig.add_trace(go.Scatter(x=[ax_min, ax_max, ax_max, ax_min],
                                             y=[ax_min + offset, ax_max + offset,
                                                ax_max - offset, ax_min - offset],
                                             name=label,
                                             hoveron='points+fills',
                                             hoverinfo='name',
                                             fill='toself',
                                             mode='lines', line_width=0, fillcolor='rgba(0, 0, 0, 0.2)',
                                             showlegend=False), **panel)

            # diagonal y = x
            fig.add_trace(go.Scatter(x=[ax_min, ax_max], y=[ax_min, ax_max], line_color='black',
                                     line_width=.5, mode='lines', showlegend=False), **panel)

            # data points (error bars are not drawn in this grid)
            fig.add_trace(go.Scatter(x=x, y=y,
                                     mode='markers',
                                     name=f'{target}',
                                     marker=dict(symbol='circle', color='blue', size=4),
                                     showlegend=False), **panel)

    # Apply range and tick font to *every* subplot axis; the layout's `xaxis`/`yaxis`
    # entries alone only address the top-left panel.
    fig.update_xaxes(range=[ax_min, ax_max], tickfont_size=10)
    fig.update_yaxes(range=[ax_min, ax_max], tickfont_size=10)
    fig.update_layout(
        template="simple_white",
        width=500,
        height=500 / ratio,
    )

    fig.write_image(f'04e_ddg_openff_vs_exp_{part}.png', scale=4)

In [22]:
# Axis limits for the interactive plot (read by the function below at call time).
ax_min = -5
ax_max = +5


def correlation_plot_ddg(target, idx1='Exp.', idx2='OpenFF-2.0'):
    """Interactive DDG correlation plot of ``idx2`` against ``idx1`` for one target.

    NOTE: this redefines the earlier ``correlation_plot_ddg(x, y, dx, dy, title=..., ...)``
    helper under the same name. Re-running the per-target SVG export cell after this
    cell calls this version with arguments it does not accept.

    Parameters
    ----------
    target : target name, matched against the ``target`` column of ``all_edges``.
    idx1, idx2 : identifiers (``DDG_<idx>`` / ``dDDG_<idx>`` columns) for the x and y axis;
        they may be equal.

    Returns
    -------
    The plotly figure. Its title shows ``idx2``, the number of edges, and RMSE/MUE with
    bootstrapped intervals, or 'no data' when the selection is empty.
    """
    fig = go.Figure()
    prepare_ax(fig, ax_min, ax_max, col=None, row=None, guidelines=True, origins=True)

    # dict.fromkeys keeps order but drops repeated names; with idx1 == idx2 duplicate
    # column labels would make the .loc lookups below return DataFrames.
    columns = list(dict.fromkeys(
        ["target", f"DDG_{idx1}", f"DDG_{idx2}", f"dDDG_{idx1}", f"dDDG_{idx2}"]))
    ddf = all_edges[columns].dropna(axis=0)
    in_target = ddf["target"] == target
    x = np.array(ddf.loc[in_target, f"DDG_{idx1}"])
    y = np.array(ddf.loc[in_target, f"DDG_{idx2}"])
    xerr = np.array(ddf.loc[in_target, f"dDDG_{idx1}"])
    yerr = np.array(ddf.loc[in_target, f"dDDG_{idx2}"])

    fig.add_trace(go.Scatter(x=x, y=y,
                             mode='markers',
                             name=f'{target}',
                             marker=dict(
                                 symbol='circle',
                             ),
                             error_x=dict(
                                 type='data',  # value of error bar given in data coordinates
                                 array=xerr,
                                 visible=True),
                             error_y=dict(
                                 type='data',  # value of error bar given in data coordinates
                                 array=yerr,
                                 visible=True),
                             showlegend=False
                             )
                  )

    # statistics for the title; the bootstrap needs at least one sample
    nsamples = len(x)
    if nsamples > 0:
        lines = []
        for statistic in ('RMSE', 'MUE'):
            s = stats.bootstrap_statistic(x, y, statistic=statistic)
            lines.append(
                f"{statistic + ':':5s}{s['mle']:5.2f} [95%: {s['low']:5.2f}, {s['high']:5.2f}]")
        statistics_string = '<br>'.join(lines)
    else:
        statistics_string = 'no data'

    # Title names the plotted method (idx2), not a leftover global loop variable.
    long_title = f'{idx2} (N = {nsamples})<br>{statistics_string}'

    fig.update_layout(
        template="simple_white",
        title=dict(
            text=long_title,
            xref='paper',
            yref='container',
            xanchor='left',
            yanchor='top',
            x=0.0,
            y=.7,
            font_size=8,
        ),
        xaxis=dict(
            title=f'DDG {idx1} [kcal/mol]',
            titlefont_size=12,
            tickfont_size=10,
            range=(ax_min, ax_max)
        ),
        yaxis=dict(
            title=f'DDG {idx2} [kcal/mol]',
            titlefont_size=12,
            tickfont_size=10,
            range=(ax_min, ax_max)
        ),
        width=400,
        height=400
    )

    return fig

In [18]:
# One dropdown per argument: the target, the x-axis identifier (any) and the
# y-axis identifier (all but the first entry).
out = interact(
    correlation_plot_ddg,
    target=list(targets.target_dict.keys()),
    idx1=identifiers,
    idx2=identifiers[1:],
)

interactive(children=(Dropdown(description='target', options=('jnk1', 'pde2', 'thrombin', 'p38', 'ptp1b', 'gal…