# Draw Venn plots of successes/outliers compared to experimental values

In [1]:
import os
import sys
sys.path.append(os.path.join(os.getcwd(), '..'))

import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from  plotly import colors
import pandas as pd
import yaml
import pint
unit_registry = pint.UnitRegistry()

from tqdm.notebook import tqdm

from PLBenchmarks import targets, ligands, edges


import benchmarkpl
# Sub-directory of every target directory that holds the per-method result YAML files.
results_dir = '10_results'
# Use the installed benchmarkpl package directory as the root for all target data.
path = benchmarkpl.__path__[0]
targets.set_data_dir(path)





# Read in data

### Data sets to read in: OpenFF 1.0.0 (Parsley) results, reference force fields, and experiment

In [2]:
# (display name, result-set identifier) pairs. The first pair is the experimental
# reference; every later pair is a calculated set that is compared against it.
# Keeping each label next to its identifier guarantees that names[i] labels
# identifiers[i], which venn_plot relies on via names[identifiers.index(idx)].
method_labels = [
    ('experiment',  'experiment_hahn'),
    ('OFF-all',     'pmx_openff-1.0.0.offxml_hahn'),
    ('OFF-conv-I',  'pmx_converged_openff-1.0.0.offxml_hahn'),
    ('OFF-conv-II', 'pmx_repeatfilter_openff-1.0.0.offxml_hahn'),
    ('gaff2',       'pmx_gaff_gapsys'),
    ('cgenff',      'pmx_cgenff_gapsys'),
    ('opls3e-gap',  'fep_opls3e_5_gapsys'),
    ('opls3e-sch',  'fep+_opls3e_schindler'),
    ('null',        'null_null_hahn'),
]
names = [label for label, _ in method_labels]
identifiers = [identifier for _, identifier in method_labels]

In [3]:
data = {}
for target in tqdm(targets.target_dict.keys()):
    # One dict per target, keyed by result-set identifier; missing files are reported and skipped.
    data[target] = target_results = {}
    target_results_dir = os.path.join(path, targets.get_target_dir(target), results_dir)
    for idx in identifiers:
        file_name = os.path.join(target_results_dir, f'{target}_{idx}.yaml')
        if not os.path.exists(file_name):
            print(f"File {file_name} for target {target} not available")
            continue
        with open(file_name, 'r') as file:
            target_results[idx] = yaml.safe_load(file)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=22.0), HTML(value='')))

File /home/dahahn/projects/I_OPENFORCEFIELD/09_benchmarkpl/benchmarkpl/05_venn_diagrams/../benchmarkpl/2019-09-23_jnk1/10_results/jnk1_fep+_opls3e_schindler.yaml for target jnk1 not available
File /home/dahahn/projects/I_OPENFORCEFIELD/09_benchmarkpl/benchmarkpl/05_venn_diagrams/../benchmarkpl/2019-09-23_pde2/10_results/pde2_fep+_opls3e_schindler.yaml for target pde2 not available
File /home/dahahn/projects/I_OPENFORCEFIELD/09_benchmarkpl/benchmarkpl/05_venn_diagrams/../benchmarkpl/2019-09-23_thrombin/10_results/thrombin_fep+_opls3e_schindler.yaml for target thrombin not available
File /home/dahahn/projects/I_OPENFORCEFIELD/09_benchmarkpl/benchmarkpl/05_venn_diagrams/../benchmarkpl/2019-12-09_p38/10_results/p38_fep+_opls3e_schindler.yaml for target p38 not available
File /home/dahahn/projects/I_OPENFORCEFIELD/09_benchmarkpl/benchmarkpl/05_venn_diagrams/../benchmarkpl/2019-12-12_ptp1b/10_results/ptp1b_fep+_opls3e_schindler.yaml for target ptp1b not available
File /home/dahahn/projects/I

In [4]:
# Build one wide table with a row per edge: the edge metadata (ligandA, ligandB, unit,
# target, edge) plus, for each loaded result set, DDG_<identifier> and dDDG_<identifier>
# converted to kcal/mol.
def _to_kcal_per_mol(values, units):
    """Convert paired (value, unit) entries to floats in kcal/mol using pint."""
    return [float(unit_registry.Quantity(value, unit).to('kilocalories/mole').magnitude)
            for value, unit in zip(values, units)]


target_frames = []
for target, tdata in tqdm(data.items()):
    method_frames = []
    for method_id, sdata in tdata.items():
        df = pd.DataFrame(sdata).T
        df['target'] = target
        df['edge'] = [f'edge_{lig_a}_{lig_b}' for lig_a, lig_b in zip(df['ligandA'], df['ligandB'])]
        df[f'DDG_{method_id}'] = _to_kcal_per_mol(df['DDG'], df['unit'])
        df[f'dDDG_{method_id}'] = _to_kcal_per_mol(df['dDDG'], df['unit'])
        method_frames.append(df.drop(columns=['DDG', 'dDDG']))
    if method_frames:
        # Side-by-side join on the edge index; the shared metadata columns appear once
        # per result set, so keep only their first occurrence.
        combined = pd.concat(method_frames, axis=1)
        target_frames.append(combined.loc[:, ~combined.columns.duplicated()])
# Concatenate once at the end: growing a frame with DataFrame.append inside a loop is
# quadratic, and DataFrame.append was removed in pandas 2.0.
all_edges = pd.concat(target_frames) if target_frames else pd.DataFrame()
all_edges.head()

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=22.0), HTML(value='')))




Unnamed: 0,ligandA,ligandB,unit,target,edge,DDG_experiment_hahn,dDDG_experiment_hahn,DDG_pmx_openff-1.0.0.offxml_hahn,dDDG_pmx_openff-1.0.0.offxml_hahn,DDG_pmx_converged_openff-1.0.0.offxml_hahn,...,DDG_pmx_gaff_gapsys,dDDG_pmx_gaff_gapsys,DDG_pmx_cgenff_gapsys,dDDG_pmx_cgenff_gapsys,DDG_fep_opls3e_5_gapsys,dDDG_fep_opls3e_5_gapsys,DDG_null_null_hahn,dDDG_null_null_hahn,DDG_fep+_opls3e_schindler,dDDG_fep+_opls3e_schindler
jnk1_edge_17124-1_18631-1,17124-1,18631-1,kilocalories / mole,jnk1,edge_17124-1_18631-1,0.26,0.37,1.19,0.1,1.19,...,1.331262,0.731358,0.776769,0.100382,1.517686,0.069312,0.0,0.1,,
jnk1_edge_17124-1_18634-1,17124-1,18634-1,kilocalories / mole,jnk1,edge_17124-1_18634-1,-0.33,0.29,0.58,0.12,0.58,...,0.499522,0.160134,0.250956,0.145793,0.583174,0.043021,0.0,0.1,,
jnk1_edge_18626-1_18624-1,18626-1,18624-1,kilocalories / mole,jnk1,edge_18626-1_18624-1,0.38,0.21,0.56,0.1,0.556667,...,1.125717,0.090822,0.114723,0.033461,1.073136,0.040631,0.0,0.1,,
jnk1_edge_18626-1_18625-1,18626-1,18625-1,kilocalories / mole,jnk1,edge_18626-1_18625-1,0.77,0.21,-0.03,0.11,-0.03,...,0.707457,0.112333,0.475621,0.41826,1.445985,0.033461,0.0,0.1,,
jnk1_edge_18626-1_18627-1,18626-1,18627-1,kilocalories / mole,jnk1,edge_18626-1_18627-1,0.39,0.22,0.14,0.05,0.14,...,0.4326,0.076482,0.157744,0.086042,0.39675,0.081262,0.0,0.1,,


In [5]:
# Signed and absolute deviation of every calculated set from experiment (NaN propagates
# wherever either value is missing).
experimental_ddg = all_edges['DDG_experiment_hahn']
for method_id in identifiers[1:]:
    signed_error = all_edges[f'DDG_{method_id}'] - experimental_ddg
    all_edges[f'error_{method_id}'] = signed_error
    all_edges[f'abserror_{method_id}'] = signed_error.abs()

In [6]:
def get_inliers(threshold, edges=None, method_ids=None):
    """Flag edges whose calculated DDG agrees with experiment within `threshold`.

    Parameters
    ----------
    threshold : float
        Largest accepted absolute error |DDG_calc - DDG_exp|, in the unit of the
        `abserror_*` columns (kcal/mol in this notebook). The comparison is inclusive (<=).
    edges : pandas.DataFrame, optional
        Table with one `abserror_<id>` column per result set; defaults to the global
        `all_edges`.
    method_ids : list of str, optional
        Result-set identifiers to classify; defaults to `identifiers[1:]`
        (every calculated set).

    Returns
    -------
    pandas.DataFrame
        Same index as `edges`, with one object-dtype column `inlier_<id>` per result set.
        Each entry is True/False, or NaN where the absolute error is missing.

    Raises
    ------
    KeyError
        If an `abserror_<id>` column is missing from `edges`.
    """
    if edges is None:
        edges = all_edges
    if method_ids is None:
        method_ids = identifiers[1:]
    abserrors = edges[[f'abserror_{idx}' for idx in method_ids]]
    # Cast to object before masking. Writing NaN into a bool column via .loc relies on
    # implicit dtype upcasting, which newer pandas versions deprecate.
    inliers = (abserrors <= threshold).astype(object).where(abserrors.notna().to_numpy(), np.nan)
    inliers.columns = [f'inlier_{idx}' for idx in method_ids]
    return inliers
inliers = get_inliers(3)
n_missing_in, n_inliers, n_not_inliers = inliers.isna().sum(), inliers.eq(1).sum(), inliers.eq(0).sum()
n_missing_in, n_inliers, n_not_inliers

(inlier_pmx_openff-1.0.0.offxml_hahn                 173
 inlier_pmx_converged_openff-1.0.0.offxml_hahn       241
 inlier_pmx_repeatfilter_openff-1.0.0.offxml_hahn    753
 inlier_pmx_gaff_gapsys                              691
 inlier_pmx_cgenff_gapsys                            691
 inlier_fep_opls3e_5_gapsys                          691
 inlier_fep+_opls3e_schindler                        707
 inlier_null_null_hahn                                89
 dtype: int64,
 inlier_pmx_openff-1.0.0.offxml_hahn                  887
 inlier_pmx_converged_openff-1.0.0.offxml_hahn        846
 inlier_pmx_repeatfilter_openff-1.0.0.offxml_hahn     398
 inlier_pmx_gaff_gapsys                               454
 inlier_pmx_cgenff_gapsys                             443
 inlier_fep_opls3e_5_gapsys                           460
 inlier_fep+_opls3e_schindler                         428
 inlier_null_null_hahn                               1024
 dtype: int64,
 inlier_pmx_openff-1.0.0.offxml_hahn              

In [7]:
def get_outliers(threshold, edges=None, method_ids=None):
    """Flag edges whose calculated DDG deviates from experiment by more than `threshold`.

    Parameters
    ----------
    threshold : float
        Absolute-error cutoff |DDG_calc - DDG_exp|, in the unit of the `abserror_*`
        columns (kcal/mol in this notebook). The comparison is strict (>), so this is the
        exact complement of get_inliers for non-missing values.
    edges : pandas.DataFrame, optional
        Table with one `abserror_<id>` column per result set; defaults to the global
        `all_edges`.
    method_ids : list of str, optional
        Result-set identifiers to classify; defaults to `identifiers[1:]`
        (every calculated set).

    Returns
    -------
    pandas.DataFrame
        Same index as `edges`, with one object-dtype column `outlier_<id>` per result set.
        Each entry is True/False, or NaN where the absolute error is missing.

    Raises
    ------
    KeyError
        If an `abserror_<id>` column is missing from `edges`.
    """
    if edges is None:
        edges = all_edges
    if method_ids is None:
        method_ids = identifiers[1:]
    abserrors = edges[[f'abserror_{idx}' for idx in method_ids]]
    # Cast to object before masking. Writing NaN into a bool column via .loc relies on
    # implicit dtype upcasting, which newer pandas versions deprecate.
    outliers = (abserrors > threshold).astype(object).where(abserrors.notna().to_numpy(), np.nan)
    outliers.columns = [f'outlier_{idx}' for idx in method_ids]
    return outliers
outliers = get_outliers(3)
n_missing_out, n_outliers, n_not_outliers = outliers.isna().sum(), outliers.eq(1).sum(), outliers.eq(0).sum()
n_missing_out, n_outliers, n_not_outliers

(outlier_pmx_openff-1.0.0.offxml_hahn                 173
 outlier_pmx_converged_openff-1.0.0.offxml_hahn       241
 outlier_pmx_repeatfilter_openff-1.0.0.offxml_hahn    753
 outlier_pmx_gaff_gapsys                              691
 outlier_pmx_cgenff_gapsys                            691
 outlier_fep_opls3e_5_gapsys                          691
 outlier_fep+_opls3e_schindler                        707
 outlier_null_null_hahn                                89
 dtype: int64,
 outlier_pmx_openff-1.0.0.offxml_hahn                 96
 outlier_pmx_converged_openff-1.0.0.offxml_hahn       69
 outlier_pmx_repeatfilter_openff-1.0.0.offxml_hahn     5
 outlier_pmx_gaff_gapsys                              11
 outlier_pmx_cgenff_gapsys                            22
 outlier_fep_opls3e_5_gapsys                           5
 outlier_fep+_opls3e_schindler                        21
 outlier_null_null_hahn                               43
 dtype: int64,
 outlier_pmx_openff-1.0.0.offxml_hahn             

In [11]:
def get_overlap(dataframe, idx1, idx2, idx3, which='outlier'):
    """Count how the flagged edges of three result sets overlap.

    An edge counts as flagged in a set when its `<which>_<idx>` value equals 1/True.
    NaN (missing value) counts as not flagged.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Table with columns `<which>_<idx1>`, `<which>_<idx2>`, `<which>_<idx3>` holding
        True/False/NaN, e.g. the output of get_inliers / get_outliers.
    idx1, idx2, idx3 : str
        Result-set identifiers of the three sets to compare.
    which : str
        Column prefix, 'outlier' or 'inlier'.

    Returns
    -------
    list of int, length 12
        [0:7]  Region counts in the order (100, 010, 110, 001, 101, 011, 111), where
               e.g. 110 means flagged in sets 1 and 2 but not in set 3.
        [7:9]  The 111 count repeated twice more, kept so the return length is unchanged.
        [9:12] Total number of flagged edges in sets 1, 2 and 3.
    """
    def flagged(idx):
        # True == 1 holds; NaN == 1 is False, so missing entries count as not flagged.
        return dataframe[f'{which}_{idx}'].eq(1).to_numpy(dtype=bool)

    in1, in2, in3 = flagged(idx1), flagged(idx2), flagged(idx3)
    out1, out2, out3 = ~in1, ~in2, ~in3
    regions = [
        in1 & out2 & out3,   # 100
        out1 & in2 & out3,   # 010
        in1 & in2 & out3,    # 110
        out1 & out2 & in3,   # 001
        in1 & out2 & in3,    # 101
        out1 & in2 & in3,    # 011
        in1 & in2 & in3,     # 111
    ]
    region_counts = [int(region.sum()) for region in regions]
    triple = region_counts[6]
    set_sizes = [int(members.sum()) for members in (in1, in2, in3)]
    return region_counts + [triple, triple] + set_sizes

In [12]:
# NOTE: _venn3 is a private module of matplotlib_venn; its helpers may change between versions.
from matplotlib_venn import _venn3


def _text_trace(positions, texts):
    """Return a text-only scatter trace that writes `texts` at the (x, y) rows of `positions`."""
    return go.Scatter(
        x=positions[:, 0],
        y=positions[:, 1],
        text=texts,
        mode="text",
        textfont=dict(
            color="black",
            size=18
        )
    )


def venn_plot(idx1, idx2, idx3, threshold, which='outlier'):
    """Draw a three-set Venn diagram of edges classified against experiment.

    Parameters
    ----------
    idx1, idx2, idx3 : str
        Entries of `identifiers` (calculated result sets) to compare.
    threshold : float
        Cutoff on the absolute error |DDG_calc - DDG_exp| in kcal/mol.
    which : {'outlier', 'inlier'}
        'outlier' counts edges with error > threshold; 'inlier' counts edges with
        error <= threshold.

    Returns
    -------
    plotly.graph_objects.Figure
        Circles are laid out with matplotlib_venn's `_venn3` solver from the region counts.
        Each region shows its edge count. Each circle is labelled
        "<name> (<flagged>/<available>)", where <available> is the number of edges with a
        non-missing absolute error for that set.

    Raises
    ------
    ValueError
        If `which` is neither 'outlier' nor 'inlier'.
    """
    selected = [idx1, idx2, idx3]
    labels = [names[identifiers.index(idx)] for idx in selected]
    if which == 'outlier':
        numbers = get_outliers(threshold)
        color_number = 6
        title = f'Outliers with Δ(ΔΔG) > {threshold} kcal mol<sup>-1</sup>'
    elif which == 'inlier':
        numbers = get_inliers(threshold)
        color_number = 0
        title = f'Successes with Δ(ΔΔG) <= {threshold} kcal mol<sup>-1</sup>'
    else:
        raise ValueError(f'{which} argument not known.')

    overlap = get_overlap(numbers, idx1, idx2, idx3, which=which)
    region_counts = overlap[:7]   # order (100, 010, 110, 001, 101, 011, 111)
    sizes = overlap[-3:]          # flagged edges per set
    sim_sizes = [numbers[f'{which}_{idx}'].notna().sum() for idx in selected]

    areas = _venn3.compute_venn3_areas(region_counts)
    centers, radii = _venn3.solve_venn3_circles(areas)
    regions = _venn3.compute_venn3_regions(centers, radii)
    # Set labels: up-left of circle 1, up-right of circle 2, below circle 3.
    label_positions = np.array([centers[0] + np.array([-radii[0] / 2, radii[0]])*1.1,
                                centers[1] + np.array([radii[1] / 2, radii[1]])*1.1,
                                centers[2] + np.array([0.0, -radii[2] * 1.1])])
    subset_positions = np.array([region.label_position() for region in regions])
    subset_labels = [f'{int(count):d}' for count in region_counts]

    fig = go.Figure()
    fig.add_trace(_text_trace(
        label_positions,
        [f'{label} ({n_flagged}/{n_available})'
         for label, n_flagged, n_available in zip(labels, sizes, sim_sizes)]
    ))
    fig.add_trace(_text_trace(subset_positions, subset_labels))

    # Hide ticks, grid and zero lines: the axes only position the diagram.
    fig.update_xaxes(
        showticklabels=False,
        showgrid=False,
        zeroline=False,
    )
    fig.update_yaxes(
        showticklabels=False,
        showgrid=False,
        zeroline=False,
    )

    # One translucent circle per set, drawn below the text traces.
    for i in range(3):
        fig.add_shape(
            type="circle",
            fillcolor=colors.qualitative.Prism[color_number+i],
            x0=centers[i][0]-radii[i],
            y0=centers[i][1]-radii[i],
            x1=centers[i][0]+radii[i],
            y1=centers[i][1]+radii[i],
            line_color=colors.qualitative.Prism[color_number+i]
        )
    fig.update_shapes(dict(
        opacity=0.5,
        xref="x",
        yref="y",
        layer="below"
    ))
    # Symmetric square axis range that contains every circle with a 20% margin.
    axlim = (np.max(np.fabs(centers)) + np.max(np.fabs(radii)))*1.2
    fig.update_layout(
        title={
            'text': title,
            'y': 0.95,
            'x': 0.5,
            'xanchor': 'center',
            'yanchor': 'top',
            'font': {'size': 24}},
        margin=dict(
            l=30,
            r=30,
            b=30,
            t=30
        ),
        xaxis=dict(range=[-axlim, axlim]),
        yaxis=dict(range=[-axlim, axlim]),
        height=800,
        width=800,
        plot_bgcolor="white",
        showlegend=False
    )
    return fig

In the following interactive cell, a Venn plot is created. Choose the three calculated sets to compare in the dropdown menus `idx1`, `idx2` and `idx3`. Next, choose a `threshold` (in kcal/mol) on the absolute deviation from experiment. Then choose whether to show the successes/inliers (deviation ≤ threshold) or the outliers (deviation > threshold). The comparison is made per edge, i.e. per relative free energy (ΔΔG), between each calculated set and the experimental values. Each region of the diagram shows the number of edges that fall into it. The numbers in brackets after each force-field name give that set's number of successes (or outliers) and the total number of edges that have both a calculated and an experimental value.

In [13]:
from ipywidgets import widgets, interact
calculated_sets = identifiers[1:]
out = interact(
    venn_plot,
    idx1=calculated_sets,
    idx2=calculated_sets,
    idx3=calculated_sets,
    threshold=np.arange(0, 5, 0.5),
    which=['inlier', 'outlier'],
)

interactive(children=(Dropdown(description='idx1', options=('pmx_openff-1.0.0.offxml_hahn', 'pmx_converged_ope…