In [1]:
import os
import sys
sys.path.append(os.path.join(os.getcwd(), '..'))

import yaml

import numpy as np
# import matplotlib
from matplotlib import pyplot as plt
import plotly
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from  plotly import colors
import pandas as pd

from arsenic import plotting, stats

from svgutils import transform as sg

from IPython.core.display import HTML
from scipy.stats import norm

from PLBenchmarks import targets, ligands, edges

from tqdm.notebook import tqdm
import pint
unit_registry = pint.UnitRegistry()

from ipywidgets import widgets, interact

import benchmarkpl
# Use the first entry of the installed `benchmarkpl` package's search path as
# the PLBenchmarks data directory (presumably the benchmark data set ships
# inside that package -- TODO confirm).
path = benchmarkpl.__path__[0]
targets.set_data_dir(path)



In [2]:
# Load the per-ligand results table; the first CSV column becomes the index.
all_ligands = pd.read_csv(
    '../03_comparison_experiment/03b_all_ligands_all_ffs.csv',
    index_col=0,
)
all_ligands.head()

Unnamed: 0,ligand,target,DG_Exp.,dDG_Exp.,DG_OpenFF-1.0,DG_OpenFF-1.0_converged,DG_OpenFF-1.2,DG_OpenFF-1.2_converged,DG_OpenFF-2.0,DG_OpenFF-2.0_converged,...,dDG_OpenFF-1.2,dDG_OpenFF-1.2_converged,dDG_OpenFF-2.0,dDG_OpenFF-2.0_converged,dDG_OPLS3e,dDG_GAFF2,dDG_cGenFF,dDG_Consensus_OpenFF_GAFF2_cGenFF,dDG_Consensus_OpenFF_GAFF2,dDG_Consensus_all
jnk1_lig_17124-1,17124-1,jnk1,-9.76,0.29,-10.498849,-10.41024,,,-11.647484,-11.641639,...,,,0.126579,0.127663,0.05215,0.193056,0.174543,0.245633,0.299317,0.103591
jnk1_lig_18631-1,18631-1,jnk1,-9.5,0.23,-9.208652,-9.114652,,,-9.480523,-9.474986,...,,,0.094876,0.096155,0.049214,0.402302,0.16823,0.325155,0.320194,0.127863
jnk1_lig_18634-1,18634-1,jnk1,-10.09,0.04,-10.098436,-10.020989,,,-11.048733,-11.048163,...,,,0.126496,0.127282,0.044415,0.132499,0.185769,0.223787,0.261344,0.093516
jnk1_lig_18626-1,18626-1,jnk1,-8.95,0.14,-8.185239,-8.133063,,,-8.077112,-8.077298,...,,,0.080243,0.081041,0.032254,0.107774,0.141012,0.165191,0.146609,0.066499
jnk1_lig_18624-1,18624-1,jnk1,-8.57,0.15,-7.995236,-7.936688,,,-7.067734,-7.068933,...,,,0.07894,0.079947,0.03482,0.1169,0.140377,0.178736,0.151406,0.068561


In [3]:
# Every column named "DG_<identifier>" contributes "<identifier>".
# The filter and the slice now use the same prefix, so a column that merely
# starts with "DG" (without the underscore) can no longer be picked up and
# have its first real character cut off.
dg_prefix = "DG_"
identifiers = [
    column[len(dg_prefix):]
    for column in all_ligands.columns
    if column.startswith(dg_prefix)
]
identifiers

['Exp.',
 'OpenFF-1.0',
 'OpenFF-1.0_converged',
 'OpenFF-1.2',
 'OpenFF-1.2_converged',
 'OpenFF-2.0',
 'OpenFF-2.0_converged',
 'OPLS3e',
 'GAFF2',
 'cGenFF',
 'Consensus_OpenFF_GAFF2_cGenFF',
 'Consensus_OpenFF_GAFF2',
 'Consensus_all']

### Function to create correlation plots

In [4]:
def prepare_ax(fig, ax_min, ax_max, col=None, row=None, guidelines=True, origins=True):
    """Draw the static background of a correlation plot onto ``fig``.

    Traces are added in this order: the optional x = 0 and y = 0 lines, the
    optional shaded bands around the diagonal, and finally the diagonal
    itself. Afterwards both axes are fixed to ``[ax_min, ax_max]``.

    Parameters
    ----------
    fig : figure object
        Must provide ``add_trace``, ``update_xaxes`` and ``update_yaxes``.
    ax_min, ax_max : number
        Lower and upper limit used for both axes.
    col, row : int or None
        Forwarded unchanged to every ``add_trace`` / ``update_*axes`` call.
    guidelines : bool
        If true, add the two filled bands named '1 kcal/mol margin' and
        '.5 kcal/mol margin'.
    origins : bool
        If true, add the black lines through x = 0 and y = 0.
    """
    placement = dict(row=row, col=col)

    def add_black_line(xs, ys):
        # Plain black line segment that is kept out of the legend.
        fig.add_trace(
            go.Scatter(
                x=list(xs),
                y=list(ys),
                line_color='black',
                mode='lines',
                showlegend=False,
            ),
            **placement
        )

    def add_band(offset, label):
        # Filled parallelogram spanning diagonal - offset .. diagonal + offset.
        fig.add_trace(
            go.Scatter(
                x=[ax_min, ax_max, ax_max, ax_min],
                y=[ax_min + offset, ax_max + offset, ax_max - offset, ax_min - offset],
                name=label,
                hoveron='points+fills',
                hoverinfo='name',
                fill='toself',
                mode='lines',
                line_width=0,
                fillcolor='rgba(0, 0, 0, 0.2)',
                showlegend=False,
            ),
            **placement
        )

    if origins:
        add_black_line([0, 0], [ax_min, ax_max])   # x = 0
        add_black_line([ax_min, ax_max], [0, 0])   # y = 0

    if guidelines:
        small_dist = 0.5
        add_band(2. * small_dist, '1 kcal/mol margin')
        add_band(small_dist, '.5 kcal/mol margin')

    # diagonal (y = x)
    add_black_line([ax_min, ax_max], [ax_min, ax_max])

    fig.update_xaxes(range=[ax_min, ax_max], **placement)
    fig.update_yaxes(range=[ax_min, ax_max], **placement)


In [5]:
def correlation_plot_dg(x, y, dx, dy, title='', name1='', name2=''):
    """Scatter ``y`` against ``x`` with error bars and bootstrapped statistics.

    Parameters
    ----------
    x, y : sequence of float
        Values for the x and y axis. They are also converted to NumPy arrays
        and passed to ``stats.bootstrap_statistic``.
    dx, dy : sequence of float
        Error-bar sizes for ``x`` and ``y`` in data coordinates.
    title : str
        Leading part of the figure title; also used as the name of the
        marker trace.
    name1, name2 : str
        Labels appended to "Delta G" in the x-axis and y-axis title.

    Returns
    -------
    plotly.graph_objects.Figure
        A 400 x 400 px figure whose axes are both fixed to [-17, 0]. The
        title lists the sample count plus RMSE, MUE, R2 and KTAU with their
        bootstrapped 95% interval.
    """
    ax_min = -17
    ax_max = 0

    def axis_title(label):
        # Raw strings keep the LaTeX backslashes literal without relying on
        # Python passing unknown escape sequences through (deprecated).
        return r'$\Delta G' + label + r'\,[\mathrm{kcal\,mol^{-1}]}$'

    fig = go.Figure()
    prepare_ax(fig, ax_min, ax_max, col=None, row=None, guidelines=True, origins=True)

    # Data points. The trace is named after `title` rather than a global
    # `target`, so the function no longer depends on notebook state.
    fig.add_trace(
        go.Scatter(
            x=x,
            y=y,
            mode='markers',
            name=f'{title}',
            marker=dict(
                symbol='circle',
                size=3,
            ),
            error_x=dict(
                type='data',  # value of error bar given in data coordinates
                array=dx,
                visible=True),
            error_y=dict(
                type='data',  # value of error bar given in data coordinates
                array=dy,
                visible=True),
            showlegend=False,
        )
    )

    # Statistics and title
    nsamples = len(x)
    x_values = np.array(x)
    y_values = np.array(y)
    lines = []
    for statistic in ['RMSE', 'MUE', 'R2', 'KTAU']:
        s = stats.bootstrap_statistic(x_values, y_values, statistic=statistic)
        lines.append(
            f"{statistic + ':':5s}{s['mle']:.1f} [95%: {s['low']:.1f}, {s['high']:.1f}]")
    statistics_string = '<br>'.join(lines)

    long_title = f'{title} (N = {nsamples})<br>{statistics_string}'

    fig.update_layout(
        template="simple_white",
        title=dict(
            text=long_title,
            xref='paper',
            yref='container',
            xanchor='left',
            yanchor='top',
            x=0.0,
            y=0.7,
            font_size=12,
        ),
        xaxis=dict(
            title=axis_title(name1),
            titlefont_size=14,
            tickfont_size=12,
            range=(ax_min, ax_max)
        ),
        yaxis=dict(
            # The y axis shows the second data set, so it is labelled with
            # name2 (it used to repeat name1).
            title=axis_title(name2),
            titlefont_size=14,
            tickfont_size=12,
            range=(ax_min, ax_max)
        ),
        width=400,
        height=400
    )

    return fig

### Create correlation plots for DG

In [6]:
# Axis-label names for result sets whose column identifier differs from the
# name shown in the figures; every other identifier is used as-is.
FF_LABELS = {
    'GAFF2': "GAFF2.1x",
    'cGenFF': "CGenFF/MATCH*",
    'Consensus_OpenFF_GAFF2_cGenFF': "Consensus",
}

# For every calculated result set: one SVG per target that has more than one
# ligand with complete data, plus one SVG pooling the ligands of all targets.
for idx in identifiers[1:]:
    ff = FF_LABELS.get(idx, idx)

    # The NaN-free table depends only on the result set, so it is built once
    # per `idx` instead of once per target.
    ddf = all_ligands[["target", "DG_Exp.", f"DG_{idx}", "dDG_Exp.", f"dDG_{idx}"]]
    ddf = ddf.dropna(axis=0)

    data_x, data_y, data_dx, data_dy = [], [], [], []
    for target in targets.target_dict.keys():
        target_rows = ddf.loc[ddf["target"] == target]
        tdata_x = target_rows["DG_Exp."]
        tdata_y = target_rows[f"DG_{idx}"]
        tdata_dx = target_rows["dDG_Exp."]
        tdata_dy = target_rows[f"dDG_{idx}"]

        # Pool the ligands for the "All targets" figure.
        data_x += list(tdata_x)
        data_y += list(tdata_y)
        data_dx += list(tdata_dx)
        data_dy += list(tdata_dy)

        # A per-target figure is only written when there are at least two points.
        if len(tdata_x) > 1:
            fig = correlation_plot_dg(
                tdata_x,
                tdata_y,
                tdata_dx,
                tdata_dy,
                title=target,
                name1='Exp.',
                name2=f'{ff}'
            )
            fig.write_image(f'04f_corr_dg_{target}_{idx}.svg')

    fig = correlation_plot_dg(
        data_x,
        data_y,
        data_dx,
        data_dy,
        title="All targets",
        name1='Exp.',
        name2=f'{ff}'
    )
    fig.write_image(f'04f_corr_dg_all_{idx}.svg')

In [7]:
# Grid of calculated-vs-experimental DG scatter plots: one row per target,
# one column per result set, split over several PNG pages.

# Result sets (columns of `all_ligands`) shown as subplot columns.
identifiers_to_plot = [
    'OpenFF-2.0',
    'OPLS3e',
    'GAFF2',
    'cGenFF',
    'Consensus_OpenFF_GAFF2_cGenFF',
]
guidelines = True   # shade the 0.5 and 1 kcal/mol bands around the diagonal
origins = True      # draw the x = 0 and y = 0 lines
ax_min = -17
ax_max = 0

cols = len(identifiers_to_plot)
rows = 6            # targets per page
ratio = float(cols / rows)

all_targets = list(targets.target_dict.keys())
# Pages needed to show every target (ceiling division). This replaces a
# hard-coded page count that had to be kept in sync with `rows` by hand.
n_parts = -(-len(all_targets) // rows)

# NaN-free table per result set; it does not depend on the target, so it is
# built once here instead of inside the innermost loop.
ligand_tables = {
    ff_id: all_ligands[["target", "DG_Exp.", f"DG_{ff_id}", "dDG_Exp.", f"dDG_{ff_id}"]].dropna(axis=0)
    for ff_id in identifiers_to_plot
}

# Styling shared by all helper lines / shaded bands.
thin_black_line = dict(line_color='black', line_width=.5, mode='lines', showlegend=False)
grey_band = dict(hoveron='points+fills', hoverinfo='name', fill='toself',
                 mode='lines', line_width=0, fillcolor='rgba(0, 0, 0, 0.2)',
                 showlegend=False)
small_dist = 0.5

for part in range(n_parts):
    part_targets = all_targets[part * rows:(part + 1) * rows]
    fig = make_subplots(
        cols=cols,
        rows=rows,
        shared_xaxes=True,
        shared_yaxes=True,
        x_title='Experimental ΔG [kcal mol<sup>-1</sup>]',
        y_title='Calculated ΔG [kcal mol<sup>-1</sup>]',
        vertical_spacing=0.005 * ratio,
        horizontal_spacing=0.005,
        row_titles=part_targets,
        column_titles=identifiers_to_plot[:-1] + ["Consensus"],
    )
    for i, target in enumerate(part_targets):
        for j, idx in enumerate(identifiers_to_plot):
            cell = dict(row=i + 1, col=j + 1)
            ddf = ligand_tables[idx]
            on_target = ddf["target"] == target
            x = ddf.loc[on_target, "DG_Exp."]
            y = ddf.loc[on_target, f"DG_{idx}"]

            # Statistics, written as text into the upper-left corner of the panel.
            s = stats.bootstrap_statistic(x, y, statistic='RMSE')
            string = f"{'RMSE':5s}:{s['mle']:.1f} [{s['low']:.1f},{s['high']:.1f}]<br>"
            s = stats.bootstrap_statistic(x, y, statistic='KTAU')
            string += f"{'tau':5s}:{s['mle']:.1f} [{s['low']:.1f},{s['high']:.1f}]"
            fig.add_trace(
                go.Scatter(
                    x=[ax_min + .2],
                    y=[ax_max - .2],
                    mode="text",
                    text=[string],
                    textposition="bottom right",
                    showlegend=False,
                ),
                **cell
            )

            if origins:
                # x = 0
                fig.add_trace(go.Scatter(x=[0, 0], y=[ax_min, ax_max], **thin_black_line), **cell)
                # y = 0
                fig.add_trace(go.Scatter(x=[ax_min, ax_max], y=[0, 0], **thin_black_line), **cell)

            if guidelines:
                bands = (
                    (2. * small_dist, '1 kcal/mol margin'),
                    (small_dist, '.5 kcal/mol margin'),
                )
                for margin, label in bands:
                    fig.add_trace(
                        go.Scatter(
                            x=[ax_min, ax_max, ax_max, ax_min],
                            y=[ax_min + margin, ax_max + margin, ax_max - margin, ax_min - margin],
                            name=label,
                            **grey_band
                        ),
                        **cell
                    )

            # diagonal
            fig.add_trace(
                go.Scatter(x=[ax_min, ax_max], y=[ax_min, ax_max], **thin_black_line),
                **cell
            )

            # data points (added last so they are drawn on top)
            fig.add_trace(
                go.Scatter(
                    x=x,
                    y=y,
                    mode='markers',
                    name=f'{target}',
                    marker=dict(
                        symbol='circle',
                        color='black',
                        size=6,
                    ),
                    showlegend=False,
                ),
                **cell
            )

            fig.update_xaxes(range=[ax_min, ax_max], **cell)
            fig.update_yaxes(range=[ax_min, ax_max], **cell)

    fig.update_layout(
        template="simple_white",
        xaxis=dict(
            titlefont_size=14,
            tickfont_size=12,
            range=(ax_min, ax_max)
        ),
        yaxis=dict(
            titlefont_size=14,
            tickfont_size=12,
            range=(ax_min, ax_max)
        ),
        width=1000,
        height=1000 / ratio
    )

    fig.write_image(f'04f_dg_ff_vs_exp_{part}.png')

In [8]:
# Grid of calculated-vs-experimental DG scatter plots comparing the OpenFF
# versions: one row per selected target, one column per version.

identifiers_to_plot = [
    'OpenFF-1.0',
    'OpenFF-1.2',
    'OpenFF-2.0',
]
targets_to_plot = ['cmet', 'eg5', 'cdk8', 'hif2a', 'pfkfb3', 'shp2', 'syk', 'tnks2']
guidelines = True   # shade the 0.5 and 1 kcal/mol bands around the diagonal
origins = True      # draw the x = 0 and y = 0 lines
ax_min = -17.0
ax_max = 0.0

cols = len(identifiers_to_plot)
rows = 6            # targets per page
ratio = float(cols / rows)

# Pages needed to show every selected target (ceiling division). A fixed
# page count of 4 wrote two empty figures for the 8 targets selected here.
n_parts = -(-len(targets_to_plot) // rows)

# NaN-free table per result set; it does not depend on the target, so it is
# built once here instead of inside the innermost loop.
ligand_tables = {
    ff_id: all_ligands[["target", "DG_Exp.", f"DG_{ff_id}", "dDG_Exp.", f"dDG_{ff_id}"]].dropna(axis=0)
    for ff_id in identifiers_to_plot
}

# Styling shared by all helper lines / shaded bands.
thin_black_line = dict(line_color='black', line_width=.5, mode='lines', showlegend=False)
grey_band = dict(hoveron='points+fills', hoverinfo='name', fill='toself',
                 mode='lines', line_width=0, fillcolor='rgba(0, 0, 0, 0.2)',
                 showlegend=False)
small_dist = 0.5

for part in range(n_parts):
    part_targets = targets_to_plot[part * rows:(part + 1) * rows]
    fig = make_subplots(
        cols=cols,
        rows=rows,
        shared_xaxes=True,
        shared_yaxes=True,
        x_title='Experimental ΔG [kcal mol<sup>-1</sup>]',
        y_title='Calculated ΔG [kcal mol<sup>-1</sup>]',
        vertical_spacing=0.005 * ratio,
        horizontal_spacing=0.005,
        row_titles=list(part_targets),
        column_titles=list(identifiers_to_plot),
    )
    for i, target in enumerate(part_targets):
        for j, idx in enumerate(identifiers_to_plot):
            cell = dict(row=i + 1, col=j + 1)
            ddf = ligand_tables[idx]
            on_target = ddf["target"] == target
            x = ddf.loc[on_target, "DG_Exp."]
            y = ddf.loc[on_target, f"DG_{idx}"]

            # Statistics, written as text into the upper-left corner of the panel.
            s = stats.bootstrap_statistic(x, y, statistic='RMSE')
            string = f"{'RMSE':5s}: {s['mle']:.1f} [{s['low']:.1f},{s['high']:.1f}]<br>"
            s = stats.bootstrap_statistic(x, y, statistic='KTAU')
            string += f"{'tau':5s}: {s['mle']:.1f} [{s['low']:.1f},{s['high']:.1f}]"
            fig.add_trace(
                go.Scatter(
                    x=[ax_min + .2],
                    y=[ax_max - .2],
                    mode="text",
                    text=[string],
                    textposition="bottom right",
                    showlegend=False,
                ),
                **cell
            )

            if origins:
                # x = 0
                fig.add_trace(go.Scatter(x=[0, 0], y=[ax_min, ax_max], **thin_black_line), **cell)
                # y = 0
                fig.add_trace(go.Scatter(x=[ax_min, ax_max], y=[0, 0], **thin_black_line), **cell)

            if guidelines:
                bands = (
                    (2. * small_dist, '1 kcal/mol margin'),
                    (small_dist, '.5 kcal/mol margin'),
                )
                for margin, label in bands:
                    fig.add_trace(
                        go.Scatter(
                            x=[ax_min, ax_max, ax_max, ax_min],
                            y=[ax_min + margin, ax_max + margin, ax_max - margin, ax_min - margin],
                            name=label,
                            **grey_band
                        ),
                        **cell
                    )

            # diagonal
            fig.add_trace(
                go.Scatter(x=[ax_min, ax_max], y=[ax_min, ax_max], **thin_black_line),
                **cell
            )

            # data points (added last so they are drawn on top)
            fig.add_trace(
                go.Scatter(
                    x=x,
                    y=y,
                    mode='markers',
                    name=f'{target}',
                    marker=dict(
                        symbol='circle',
                        color='black',
                        size=6,
                    ),
                    showlegend=False,
                ),
                **cell
            )

            fig.update_xaxes(range=[ax_min, ax_max], **cell)
            fig.update_yaxes(range=[ax_min, ax_max], **cell)

    fig.update_layout(
        template="simple_white",
        xaxis=dict(
            titlefont_size=14,
            tickfont_size=12,
            range=(ax_min, ax_max)
        ),
        yaxis=dict(
            titlefont_size=14,
            tickfont_size=12,
            range=(ax_min, ax_max)
        ),
        width=1000,
        height=1000 / ratio
    )

    fig.write_image(f'04f_dg_openff_vs_exp_{part}.png')

In [9]:
# Axis window for the interactive plot. Absolute DG values are negative, so
# the window matches the [-17, 0] range used by the other DG figures in this
# notebook; a window of [-5, 5] left most points outside the visible area.
ax_min = -17
ax_max = 0


# NOTE(review): this re-uses the name of the earlier
# correlation_plot_dg(x, y, dx, dy, ...) with a different signature, so after
# this cell has run the earlier export loop cannot be re-run without
# re-executing the first definition.
def correlation_plot_dg(target, idx1='Exp.', idx2='OpenFF-2.0'):
    """Correlation plot of two result sets of ``all_ligands`` for one target.

    Parameters
    ----------
    target : str
        Value of the ``target`` column selecting the ligands to plot.
    idx1, idx2 : str
        Result-set identifiers; ``DG_<idx>`` supplies the values and
        ``dDG_<idx>`` the error bars of the x axis (``idx1``) and y axis
        (``idx2``).

    Returns
    -------
    plotly.graph_objects.Figure
        A 400 x 400 px figure with both axes fixed to the module-level
        ``[ax_min, ax_max]``. The title lists the target, the number of
        ligands and RMSE, MUE, R2 and KTAU with their bootstrapped 95%
        interval.
    """
    fig = go.Figure()
    prepare_ax(fig, ax_min, ax_max, col=None, row=None, guidelines=True, origins=True)

    # De-duplicate the column list (order preserved): with idx1 == idx2 a
    # repeated label would make every `.loc[..., label]` below return a
    # two-column DataFrame instead of a Series.
    columns = list(dict.fromkeys(
        ["target", f"DG_{idx1}", f"DG_{idx2}", f"dDG_{idx1}", f"dDG_{idx2}"]
    ))
    ddf = all_ligands[columns].dropna(axis=0)
    selected = ddf.loc[ddf["target"] == target]
    x = selected[f"DG_{idx1}"]
    y = selected[f"DG_{idx2}"]
    xerr = selected[f"dDG_{idx1}"]
    yerr = selected[f"dDG_{idx2}"]

    fig.add_trace(
        go.Scatter(
            x=x,
            y=y,
            mode='markers',
            name=f'{target}',
            marker=dict(
                symbol='circle',
            ),
            error_x=dict(
                type='data',  # value of error bar given in data coordinates
                array=xerr,
                visible=True),
            error_y=dict(
                type='data',  # value of error bar given in data coordinates
                array=yerr,
                visible=True),
            showlegend=False,
        )
    )

    # Statistics and title
    nsamples = len(x)
    lines = []
    for statistic in ['RMSE', 'MUE', 'R2', 'KTAU']:
        s = stats.bootstrap_statistic(x, y, statistic=statistic)
        lines.append(
            f"{statistic + ':':5s}{s['mle']:5.2f} [95%: {s['low']:5.2f}, {s['high']:5.2f}]")
    statistics_string = '<br>'.join(lines)

    # Title is built from the function's own argument; it used to read the
    # global `idx` left behind by an earlier cell's loop.
    long_title = f'{target} (N = {nsamples})<br>{statistics_string}'

    fig.update_layout(
        template="simple_white",
        title=dict(
            text=long_title,
            xref='paper',
            yref='container',
            xanchor='left',
            yanchor='top',
            x=0.0,
            y=.8,
            font_size=12,
        ),
        xaxis=dict(
            title=f'DG {idx1} [kcal/mol]',
            titlefont_size=14,
            tickfont_size=12,
            range=(ax_min, ax_max)
        ),
        yaxis=dict(
            title=f'DG {idx2} [kcal/mol]',
            titlefont_size=14,
            tickfont_size=12,
            range=(ax_min, ax_max)
        ),
        width=400,
        height=400
    )

    return fig

In [10]:
# Interactive browser: one dropdown each for the target, the x-axis result
# set (any identifier) and the y-axis result set (every identifier but the first).
out = interact(
    correlation_plot_dg,
    target=list(targets.target_dict.keys()),
    idx1=identifiers,
    idx2=identifiers[1:],
)

interactive(children=(Dropdown(description='target', options=('jnk1', 'pde2', 'thrombin', 'p38', 'ptp1b', 'gal…