In [None]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os, sys
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
import h5py

In [None]:
# Project root, expressed as two directory levels above the current working directory.
DELQSAR_ROOT = f"{os.getcwd()}/../../"
# Make the directory above the project root importable
# (presumably so `del_qsar` resolves as a package - confirm the checkout layout).
sys.path.append(f"{DELQSAR_ROOT}/../")

In [None]:
FINGERPRINTS_FILENAME = 'x_triazine_2048_bits_all_fps.h5' # should be in the experiments folder


def _seed_checkpoint_paths(dataset_dir, model_dir, n_seeds):
    """Return the checkpoint paths `cycle123_seed_<i>.torch` for seeds 0..n_seeds-1."""
    return [
        os.path.join(DELQSAR_ROOT, 'experiments', 'models', dataset_dir,
                     model_dir, f'cycle123_seed_{i}.torch')
        for i in range(n_seeds)
    ]


# Five checkpoints per feed-forward model type, three per D-MPNN.
sEH_FP_FFNN_MODEL_PATHS = _seed_checkpoint_paths('triazine_sEH', 'FP-FFNN', 5)
sEH_OH_FFNN_MODEL_PATHS = _seed_checkpoint_paths('triazine_sEH', 'OH-FFNN', 5)
sEH_D_MPNN_MODEL_PATHS = _seed_checkpoint_paths('triazine_sEH', 'D-MPNN', 3)

SIRT2_FP_FFNN_MODEL_PATHS = _seed_checkpoint_paths('triazine_SIRT2', 'FP-FFNN', 5)
SIRT2_OH_FFNN_MODEL_PATHS = _seed_checkpoint_paths('triazine_SIRT2', 'OH-FFNN', 5)
SIRT2_D_MPNN_MODEL_PATHS = _seed_checkpoint_paths('triazine_SIRT2', 'D-MPNN', 3)

# for D-MPNNs
NUM_WORKERS = 20

In [None]:
from del_qsar import models, featurizers, splitters
from del_qsar.enrichments import R_from_z, R_ranges

# Folder (relative to the working directory) that receives every saved figure.
_PLOT_DIR = 'triazine_generalization_parity_plots'
os.makedirs(_PLOT_DIR, exist_ok=True)


def pathify(fname):
    """Return `fname` placed inside the parity-plot output folder."""
    return os.path.join(_PLOT_DIR, fname)


# Log file handed to the splitter; it lives under the project's visualizations folder.
LOG_FILE = os.path.join(
    DELQSAR_ROOT,
    'experiments',
    'visualizations',
    'triazine_generalization_parity_plots',
    'triazine_generalization_parity_plots.log',
)

In [None]:
import matplotlib
import matplotlib.pyplot as plt

# Matplotlib 3.6 renamed the bundled seaborn style sheets to 'seaborn-v0_8-*' and
# later releases dropped the old names; prefer the new name when it is available
# and fall back to the original one on older installations.
_PAPER_STYLE = ('seaborn-v0_8-paper'
                if 'seaborn-v0_8-paper' in plt.style.available
                else 'seaborn-paper')
plt.style.use(_PAPER_STYLE)

matplotlib.rc('font', family='sans-serif')
# NOTE(review): this sets the *serif* font list while the active family is
# sans-serif, so it probably does not select Arial for the figures - confirm
# whether `font.sans-serif` was intended.
matplotlib.rc('font', serif='Arial')
matplotlib.rc('text', usetex=False)  # real boolean instead of the string 'false'
matplotlib.rcParams.update({'font.size': 9})

In [None]:
# Table read from the experiments/datasets folder; later cells take the count
# columns and the `smiles` column from it.
_dataset_csv = os.path.join(
    DELQSAR_ROOT, 'experiments', 'datasets', 'triazine_lib_sEH_SIRT2_QSAR.csv'
)
df_data = pd.read_csv(_dataset_csv)

In [None]:
# Collect, for every row of the symmetric building-block table that has a
# cycle-2 ID, the paired cycle-3 and cycle-2 building-block IDs as integers.
sym_BBs = pd.read_csv(os.path.join(DELQSAR_ROOT, 'experiments', 'visualizations', 'triazine_symmetric_BBs.csv'))
_cyc2_column = sym_BBs['cycle 2 BB ID']
_cyc3_column = sym_BBs['cycle 3 BB ID']
_has_cyc2_id = [not np.isnan(value) for value in _cyc2_column]
cyc3_dup_ids = [int(value) for value, keep in zip(_cyc3_column, _has_cyc2_id) if keep]
cyc2_dup_ids = [int(value) for value, keep in zip(_cyc2_column, _has_cyc2_id) if keep]

In [None]:
# Splitter built from the three cycle column names and the log file path.
_cycle_columns = ['cycle1', 'cycle2', 'cycle3']
splitter = splitters.ThreeCycleSplitter(_cycle_columns, LOG_FILE)

In [None]:
# One evaluation slice per split seed (0-4).
# NOTE(review): the first positional argument is IPython's `_` (last cell output),
# which looks like a placeholder for an input the splitter ignores - confirm and
# replace with an explicit value so the cell does not depend on kernel history.
new_test_slices = []
for split_seed in tqdm(range(5)):
    new_test_slices.append(
        splitter(
            _,
            df_data,
            seed=split_seed,
            getAllNewTestSlice=True,
            cyc2_dup_ids=cyc2_dup_ids,
            cyc3_dup_ids=cyc3_dup_ids,
        )
    )

# Size of each slice, displayed as the cell output.
list(map(len, new_test_slices))

In [None]:
# Use the first CUDA device when one is available; None otherwise.
DEVICE = 'cuda:0' if torch.cuda.is_available() else None

# sEH

In [None]:
# sEH counts and bead-only control counts, each kept as a single-column integer array.
exp_counts, bead_counts = (
    np.array(df_data[[column]], dtype='int')
    for column in ('sEH [strep]_tot', 'beads-linker-only [strep]_tot')
)
# Column sums (one total per count column).
exp_tot = exp_counts.sum(axis=0)
bead_tot = bead_counts.sum(axis=0)

## Function for plots

In [None]:
def draw_predicted_enrichments_vs_true(eval_slices, modelType, zoomIn1=False, zoomIn2=False, legend=False):
    """Draw and save the sEH parity plot (predicted vs. calculated enrichment).

    One error-bar series is drawn per seed, plus a parity line spanning that
    seed's range of predictions. The function reads notebook globals set by
    earlier cells: `x` (featurized inputs), `model_0` ... `model_4` (one loaded
    model per seed), `BATCH_SIZE`, `DEVICE`, `NUM_WORKERS`, and the count arrays
    `bead_counts` / `exp_counts` with their totals `bead_tot` / `exp_tot`.

    Args:
        eval_slices: indexable by seed; `eval_slices[i]` holds the row indices
            evaluated for seed `i`.
        modelType: 'FP-FFNN', 'OH-FFNN' or 'D-MPNN'. Selects the seeds drawn,
            their drawing order, the zoom windows and the output file name.
        zoomIn1: restrict the x-axis to the first zoom window.
        zoomIn2: restrict the x-axis to the second zoom window (ignored when
            `zoomIn1` is also set).
        legend: draw on a wide figure with a legend below the axes and save it
            under the shared legend file name instead of the per-model name.

    Raises:
        ValueError: if `modelType` is not one of the three supported names.
    """
    # Per model type: (order in which seeds are drawn, zoomIn1 x-limits, zoomIn2 x-limits).
    plot_settings = {
        'FP-FFNN': ([4, 1, 2, 0, 3], [0, 40], [0.12, 2.51]),
        'OH-FFNN': ([4, 3, 0, 1, 2], [0, 40], [0.79, 1.37]),
        'D-MPNN': ([1, 2, 0], [-0.05, 60], [-0.05, 6.68]),
    }
    if modelType not in plot_settings:
        raise ValueError(
            f"modelType must be one of {sorted(plot_settings)}, got {modelType!r}"
        )
    seed_order, xlim_zoom1, xlim_zoom2 = plot_settings[modelType]

    fig = plt.figure(figsize=(7, 2) if legend else (2.33, 2), dpi=300)
    colors = ["#4878D0", "#6ACC64", "#D65F5F",
            "#956CB4", "#D5BB67", "#82C6E2"]

    def plot_seed(seed):
        """Draw the error-bar series and parity line for one seed."""
        rows = eval_slices[seed]
        R, R_lb, R_ub = R_ranges(bead_counts[rows, 0], bead_tot[0],
                                 exp_counts[rows, 0], exp_tot[0])
        # Looked up lazily by name so only the seeds actually drawn need a loaded model.
        model = globals()[f'model_{seed}']
        # NOTE(review): NUM_WORKERS is described as "for D-MPNNs" where it is defined,
        # yet it is forwarded only on the non-D-MPNN branch - confirm this is intended.
        if modelType != 'D-MPNN':
            predicted = model.predict_on_x(
                x[rows, :], batch_size=BATCH_SIZE,
                device=DEVICE, num_workers=NUM_WORKERS,
            )
        else:
            predicted = model.predict_on_x(
                [x[i] for i in rows], batch_size=BATCH_SIZE,
                device=DEVICE,
            )

        plt.errorbar(
            x=R,
            y=predicted,
            # Asymmetric horizontal bars: distance to the lower and upper bound.
            xerr=[R - R_lb, R_ub - R],
            color=colors[seed],
            marker='o',
            markersize=3,
            elinewidth=0.75,
            ls='none',
            ecolor='k',
            capsize=1,
            capthick=0.75,
            zorder=2,
            label=f'Seed {seed}',
        )
        lowest, highest = min(predicted), max(predicted)
        # Only the seed-0 parity line is labelled, so the legend shows a single "Parity" entry.
        parity_label = {'label': 'Parity'} if seed == 0 else {}
        plt.plot(
            np.linspace(lowest, highest, 100),
            np.linspace(lowest, highest, 100),
            color="#64B5CD",
            linewidth=0.75,
            zorder=3,
            **parity_label,
        )

    for seed in seed_order:
        plot_seed(seed)

    fig.canvas.draw() # required to get tick labels
    ax = plt.gca()

    if legend:
        handles, labels = ax.get_legend_handles_labels()
        # Order entries from their labels ("Parity" first, then seeds ascending)
        # rather than a hard-coded permutation tied to the drawing order.
        entry_order = sorted(range(len(labels)),
                             key=lambda i: (labels[i] != 'Parity', labels[i]))
        leg = plt.legend([handles[i] for i in entry_order],
                         [labels[i] for i in entry_order],
                         loc='lower center', bbox_to_anchor=(0.5, -1.2),
                         numpoints=1, fontsize=7, ncol=6)

    if zoomIn1:
        ax.set_xlim(xlim_zoom1)
    elif zoomIn2:
        ax.set_xlim(xlim_zoom2)
    ax.tick_params(labelsize=8)

    ax.grid(zorder=1)
    ax.set_xlabel('Calculated enrichment', fontsize=8)
    ax.set_ylabel('Predicted enrichment', fontsize=8)
    plt.tight_layout()
    print(plt.axis())

    if legend:
        n_replicates = 3 if modelType == 'D-MPNN' else 5
        plt.savefig(pathify(f'Parity_plot_{n_replicates}_replicates_legend.png'),
                    bbox_extra_artists=(leg,), bbox_inches='tight')
    else:
        if zoomIn1:
            view = 'zoomed_in_1'
        elif zoomIn2:
            view = 'zoomed_in_2'
        else:
            view = 'full'
        plt.savefig(pathify(f'Parity_plot_{view}_sEH_{modelType}_seed_0.png'))
    plt.show()

## FP-FFNN

In [None]:
# Fingerprint matrix for the FP-FFNN models, read from the 'all_fps' dataset.
os.environ["HDF5_USE_FILE_LOCKING"] = 'FALSE'
# The context manager closes the HDF5 file even if reading the dataset raises.
with h5py.File(os.path.join(DELQSAR_ROOT, 'experiments', FINGERPRINTS_FILENAME), 'r') as hf:
    x = np.array(hf['all_fps'])
# Number of input features = number of columns of the fingerprint matrix.
INPUT_SIZE = x.shape[1]

In [None]:
BATCH_SIZE = 1024
# Hidden-layer sizes and dropout rate per seed; entry i belongs to the seed-i checkpoint.
LAYER_SIZES = [
    [1024, 256, 64],
    [1024, 512, 256],
    [512, 512, 512],
    [256, 128, 64],
    [1024, 512],
]
DROPOUT = [0.45, 0.15, 0.05, 0.1, 0.3]

# Rebuild each seed's network and restore its trained weights.
loaded_models = []
for seed_idx, (hidden_sizes, drop_rate, ckpt_path) in enumerate(
        zip(LAYER_SIZES, DROPOUT, sEH_FP_FFNN_MODEL_PATHS)):
    net = models.MLP(INPUT_SIZE, [int(size) for size in hidden_sizes],
                     dropout=drop_rate, torch_seed=seed_idx)
    # map_location='cpu' lets checkpoints saved from a GPU load on CPU-only
    # machines; the weights are moved to DEVICE below when one is set.
    net.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
    loaded_models.append(net)

if DEVICE:
    loaded_models = [net.to(DEVICE) for net in loaded_models]

# The plotting function looks the models up under these names.
model_0, model_1, model_2, model_3, model_4 = loaded_models

In [None]:
# sEH / FP-FFNN: parity plot over the full axis range.
draw_predicted_enrichments_vs_true(eval_slices=new_test_slices, modelType='FP-FFNN')

In [None]:
# sEH / FP-FFNN: x-axis limited to the first zoom window.
draw_predicted_enrichments_vs_true(eval_slices=new_test_slices, modelType='FP-FFNN', zoomIn1=True)

In [None]:
# sEH / FP-FFNN: x-axis limited to the second, tighter zoom window.
draw_predicted_enrichments_vs_true(eval_slices=new_test_slices, modelType='FP-FFNN', zoomIn2=True)

In [None]:
# Wide figure with the legend drawn below the axes; saved as the 5-replicate legend image.
draw_predicted_enrichments_vs_true(eval_slices=new_test_slices, modelType='FP-FFNN', legend=True)

## OH-FFNN

In [None]:
# Replace the fingerprint matrix in `x` with the one-hot featurization of the
# same table; INPUT_SIZE follows the new number of columns.
featurizer = featurizers.OneHotFeaturizer(df_data)
x = featurizer.prepare_x(df_data)
INPUT_SIZE = x.shape[1]

In [None]:
BATCH_SIZE = 1024
# Hidden-layer sizes and dropout rate per seed; entry i belongs to the seed-i checkpoint.
LAYER_SIZES = [
    [1024, 512, 256],
    [512, 512, 512],
    [1024, 512, 256],
    [1024, 1024, 1024],
    [512, 512, 512],
]
DROPOUT = [0.1, 0.2, 0.1, 0.3, 0.5]

# Rebuild each seed's network and restore its trained weights.
loaded_models = []
for seed_idx, (hidden_sizes, drop_rate, ckpt_path) in enumerate(
        zip(LAYER_SIZES, DROPOUT, sEH_OH_FFNN_MODEL_PATHS)):
    net = models.MLP(INPUT_SIZE, [int(size) for size in hidden_sizes],
                     dropout=drop_rate, torch_seed=seed_idx)
    # map_location='cpu' lets checkpoints saved from a GPU load on CPU-only
    # machines; the weights are moved to DEVICE below when one is set.
    net.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
    loaded_models.append(net)

if DEVICE:
    loaded_models = [net.to(DEVICE) for net in loaded_models]

# The plotting function looks the models up under these names.
model_0, model_1, model_2, model_3, model_4 = loaded_models

In [None]:
# sEH / OH-FFNN: parity plot over the full axis range.
draw_predicted_enrichments_vs_true(eval_slices=new_test_slices, modelType='OH-FFNN')

In [None]:
# sEH / OH-FFNN: x-axis limited to the first zoom window.
draw_predicted_enrichments_vs_true(eval_slices=new_test_slices, modelType='OH-FFNN', zoomIn1=True)

In [None]:
# sEH / OH-FFNN: x-axis limited to the second, tighter zoom window.
draw_predicted_enrichments_vs_true(eval_slices=new_test_slices, modelType='OH-FFNN', zoomIn2=True)

## D-MPNN

In [None]:
# Inputs for the D-MPNN: the graph featurizer is built from the SMILES column and
# a per-row target list, and `x` is replaced by what `prepare_x()` returns.
smis = df_data['smiles']
# Targets come from R_from_z with a final argument of 0
# (presumably the enrichment at z = 0 - confirm against del_qsar.enrichments).
targets = R_from_z(bead_counts, bead_tot, exp_counts, exp_tot, 0).tolist()
featurizer = featurizers.GraphFeaturizer(smis, targets)
x = featurizer.prepare_x()

In [None]:
BATCH_SIZE = 50
# Architecture settings shared by all three D-MPNN seeds.
DEPTH = 6
HIDDEN_SIZE = 1500
FFN_NUM_LAYERS = 3
DROPOUT = 0.05

# Rebuild each seed's network and restore its trained weights.
loaded_models = []
for seed_idx, ckpt_path in enumerate(sEH_D_MPNN_MODEL_PATHS):
    net = models.MoleculeModel(depth=DEPTH,
                               hidden_size=HIDDEN_SIZE,
                               ffn_num_layers=FFN_NUM_LAYERS,
                               dropout=DROPOUT, device=DEVICE,
                               torch_seed=seed_idx)
    # map_location='cpu' lets checkpoints saved from a GPU load on CPU-only
    # machines; load_state_dict copies the values into the model's own parameters.
    net.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
    loaded_models.append(net)

if DEVICE:
    loaded_models = [net.to(DEVICE) for net in loaded_models]

# The plotting function looks the models up under these names.
model_0, model_1, model_2 = loaded_models

In [None]:
# sEH / D-MPNN: parity plot over the full axis range.
draw_predicted_enrichments_vs_true(eval_slices=new_test_slices, modelType='D-MPNN')

In [None]:
# sEH / D-MPNN: x-axis limited to the first zoom window.
draw_predicted_enrichments_vs_true(eval_slices=new_test_slices, modelType='D-MPNN', zoomIn1=True)

In [None]:
# sEH / D-MPNN: x-axis limited to the second, tighter zoom window.
draw_predicted_enrichments_vs_true(eval_slices=new_test_slices, modelType='D-MPNN', zoomIn2=True)

In [None]:
# Wide figure with the legend drawn below the axes; saved as the 3-replicate legend image.
draw_predicted_enrichments_vs_true(eval_slices=new_test_slices, modelType='D-MPNN', legend=True)

# SIRT2

In [None]:
# SIRT2 counts and bead-only control counts, each kept as a single-column integer array.
# These replace the sEH arrays of the same names used earlier in the notebook.
exp_counts, bead_counts = (
    np.array(df_data[[column]], dtype='int')
    for column in ('SIRT2 [strep]_tot', 'beads-linker-only [strep]_tot')
)
# Column sums (one total per count column).
exp_tot = exp_counts.sum(axis=0)
bead_tot = bead_counts.sum(axis=0)

## Function for plots

In [None]:
def draw_predicted_enrichments_vs_true(eval_slices, modelType, zoomIn1=False, zoomIn2=False, legend=False):
    """Draw and save the SIRT2 parity plot (predicted vs. calculated enrichment).

    This redefinition replaces the function of the same name defined earlier in
    the notebook; it uses SIRT2-specific drawing orders, zoom windows and output
    file names. One error-bar series is drawn per seed, plus a parity line
    spanning that seed's range of predictions. The function reads notebook
    globals set by earlier cells: `x` (featurized inputs), `model_0` ...
    `model_4` (one loaded model per seed), `BATCH_SIZE`, `DEVICE`,
    `NUM_WORKERS`, and the count arrays `bead_counts` / `exp_counts` with their
    totals `bead_tot` / `exp_tot`.

    Args:
        eval_slices: indexable by seed; `eval_slices[i]` holds the row indices
            evaluated for seed `i`.
        modelType: 'FP-FFNN', 'OH-FFNN' or 'D-MPNN'. Selects the seeds drawn,
            their drawing order, the zoom windows and the output file name.
        zoomIn1: restrict the x-axis to the first zoom window.
        zoomIn2: restrict the x-axis to the second zoom window (ignored when
            `zoomIn1` is also set).
        legend: draw on a wide figure with a legend below the axes and save it
            under the shared legend file name instead of the per-model name.

    Raises:
        ValueError: if `modelType` is not one of the three supported names.
    """
    # Per model type: (order in which seeds are drawn, zoomIn1 x-limits, zoomIn2 x-limits).
    plot_settings = {
        'FP-FFNN': ([4, 1, 3, 0, 2], [0, 75], [0.36, 2.50]),
        'OH-FFNN': ([4, 1, 2, 3, 0], [0, 75], [1.19, 1.56]),
        'D-MPNN': ([1, 2, 0], [0, 75], [0.18, 5.15]),
    }
    if modelType not in plot_settings:
        raise ValueError(
            f"modelType must be one of {sorted(plot_settings)}, got {modelType!r}"
        )
    seed_order, xlim_zoom1, xlim_zoom2 = plot_settings[modelType]

    fig = plt.figure(figsize=(7, 2) if legend else (2.33, 2), dpi=300)
    colors = ["#4878D0", "#6ACC64", "#D65F5F",
            "#956CB4", "#D5BB67", "#82C6E2"]

    def plot_seed(seed):
        """Draw the error-bar series and parity line for one seed."""
        rows = eval_slices[seed]
        R, R_lb, R_ub = R_ranges(bead_counts[rows, 0], bead_tot[0],
                                 exp_counts[rows, 0], exp_tot[0])
        # Looked up lazily by name so only the seeds actually drawn need a loaded model.
        model = globals()[f'model_{seed}']
        # NOTE(review): NUM_WORKERS is described as "for D-MPNNs" where it is defined,
        # yet it is forwarded only on the non-D-MPNN branch - confirm this is intended.
        if modelType != 'D-MPNN':
            predicted = model.predict_on_x(
                x[rows, :], batch_size=BATCH_SIZE,
                device=DEVICE, num_workers=NUM_WORKERS,
            )
        else:
            predicted = model.predict_on_x(
                [x[i] for i in rows], batch_size=BATCH_SIZE,
                device=DEVICE,
            )

        plt.errorbar(
            x=R,
            y=predicted,
            # Asymmetric horizontal bars: distance to the lower and upper bound.
            xerr=[R - R_lb, R_ub - R],
            color=colors[seed],
            marker='o',
            markersize=3,
            elinewidth=0.75,
            ls='none',
            ecolor='k',
            capsize=1,
            capthick=0.75,
            zorder=2,
            label=f'Seed {seed}',
        )
        lowest, highest = min(predicted), max(predicted)
        # Only the seed-0 parity line is labelled, so the legend shows a single "Parity" entry.
        parity_label = {'label': 'Parity'} if seed == 0 else {}
        plt.plot(
            np.linspace(lowest, highest, 100),
            np.linspace(lowest, highest, 100),
            color="#64B5CD",
            linewidth=0.75,
            zorder=3,
            **parity_label,
        )

    for seed in seed_order:
        plot_seed(seed)

    fig.canvas.draw() # required to get tick labels
    ax = plt.gca()

    if legend:
        handles, labels = ax.get_legend_handles_labels()
        # Order entries from their labels ("Parity" first, then seeds ascending)
        # rather than a hard-coded permutation tied to the drawing order.
        entry_order = sorted(range(len(labels)),
                             key=lambda i: (labels[i] != 'Parity', labels[i]))
        leg = plt.legend([handles[i] for i in entry_order],
                         [labels[i] for i in entry_order],
                         loc='lower center', bbox_to_anchor=(0.5, -1.2),
                         numpoints=1, fontsize=7, ncol=6)

    if zoomIn1:
        ax.set_xlim(xlim_zoom1)
    elif zoomIn2:
        ax.set_xlim(xlim_zoom2)
    ax.tick_params(labelsize=8)

    ax.grid(zorder=1)
    ax.set_xlabel('Calculated enrichment', fontsize=8)
    ax.set_ylabel('Predicted enrichment', fontsize=8)
    plt.tight_layout()
    print(plt.axis())

    if legend:
        n_replicates = 3 if modelType == 'D-MPNN' else 5
        plt.savefig(pathify(f'Parity_plot_{n_replicates}_replicates_legend.png'),
                    bbox_extra_artists=(leg,), bbox_inches='tight')
    else:
        if zoomIn1:
            view = 'zoomed_in_1'
        elif zoomIn2:
            view = 'zoomed_in_2'
        else:
            view = 'full'
        plt.savefig(pathify(f'Parity_plot_{view}_SIRT2_{modelType}_seed_0.png'))
    plt.show()

## FP-FFNN

In [None]:
# Fingerprint matrix for the FP-FFNN models, read from the 'all_fps' dataset.
os.environ["HDF5_USE_FILE_LOCKING"] = 'FALSE'
# The context manager closes the HDF5 file even if reading the dataset raises.
with h5py.File(os.path.join(DELQSAR_ROOT, 'experiments', FINGERPRINTS_FILENAME), 'r') as hf:
    x = np.array(hf['all_fps'])
# Number of input features = number of columns of the fingerprint matrix.
INPUT_SIZE = x.shape[1]

In [None]:
BATCH_SIZE = 1024
# Hidden-layer sizes and dropout rate per seed; entry i belongs to the seed-i checkpoint.
LAYER_SIZES = [
    [1024, 256, 64],
    [1024, 1024, 1024],
    [64, 64, 64],
    [128, 128, 128],
    [128],
]
# The seed-4 entry is None and is passed to the model constructor as-is.
DROPOUT = [0.1, 0.15, 0.15, 0.2, None]

# Rebuild each seed's network and restore its trained weights.
loaded_models = []
for seed_idx, (hidden_sizes, drop_rate, ckpt_path) in enumerate(
        zip(LAYER_SIZES, DROPOUT, SIRT2_FP_FFNN_MODEL_PATHS)):
    net = models.MLP(INPUT_SIZE, [int(size) for size in hidden_sizes],
                     dropout=drop_rate, torch_seed=seed_idx)
    # map_location='cpu' lets checkpoints saved from a GPU load on CPU-only
    # machines; the weights are moved to DEVICE below when one is set.
    net.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
    loaded_models.append(net)

if DEVICE:
    loaded_models = [net.to(DEVICE) for net in loaded_models]

# The plotting function looks the models up under these names.
model_0, model_1, model_2, model_3, model_4 = loaded_models

In [None]:
# SIRT2 / FP-FFNN: parity plot over the full axis range.
draw_predicted_enrichments_vs_true(eval_slices=new_test_slices, modelType='FP-FFNN')

In [None]:
# SIRT2 / FP-FFNN: x-axis limited to the first zoom window.
draw_predicted_enrichments_vs_true(eval_slices=new_test_slices, modelType='FP-FFNN', zoomIn1=True)

In [None]:
# SIRT2 / FP-FFNN: x-axis limited to the second, tighter zoom window.
draw_predicted_enrichments_vs_true(eval_slices=new_test_slices, modelType='FP-FFNN', zoomIn2=True)

## OH-FFNN

In [None]:
# Replace the fingerprint matrix in `x` with the one-hot featurization of the
# same table; INPUT_SIZE follows the new number of columns.
featurizer = featurizers.OneHotFeaturizer(df_data)
x = featurizer.prepare_x(df_data)
INPUT_SIZE = x.shape[1]

In [None]:
BATCH_SIZE = 1024
# Hidden-layer sizes and dropout rate per seed; entry i belongs to the seed-i checkpoint.
LAYER_SIZES = [
    [1024, 512, 256],
    [512, 256],
    [1024, 256, 64],
    [512, 512, 512],
    [512, 256, 128],
]
DROPOUT = [0.25, 0.35, 0.5, 0.3, 0.5]

# Rebuild each seed's network and restore its trained weights.
loaded_models = []
for seed_idx, (hidden_sizes, drop_rate, ckpt_path) in enumerate(
        zip(LAYER_SIZES, DROPOUT, SIRT2_OH_FFNN_MODEL_PATHS)):
    net = models.MLP(INPUT_SIZE, [int(size) for size in hidden_sizes],
                     dropout=drop_rate, torch_seed=seed_idx)
    # map_location='cpu' lets checkpoints saved from a GPU load on CPU-only
    # machines; the weights are moved to DEVICE below when one is set.
    net.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
    loaded_models.append(net)

if DEVICE:
    loaded_models = [net.to(DEVICE) for net in loaded_models]

# The plotting function looks the models up under these names.
model_0, model_1, model_2, model_3, model_4 = loaded_models

In [None]:
# SIRT2 / OH-FFNN: parity plot over the full axis range.
draw_predicted_enrichments_vs_true(eval_slices=new_test_slices, modelType='OH-FFNN')

In [None]:
# SIRT2 / OH-FFNN: x-axis limited to the first zoom window.
draw_predicted_enrichments_vs_true(eval_slices=new_test_slices, modelType='OH-FFNN', zoomIn1=True)

In [None]:
# SIRT2 / OH-FFNN: x-axis limited to the second, tighter zoom window.
draw_predicted_enrichments_vs_true(eval_slices=new_test_slices, modelType='OH-FFNN', zoomIn2=True)

## D-MPNN

In [None]:
# Inputs for the D-MPNN: the graph featurizer is built from the SMILES column and
# a per-row target list, and `x` is replaced by what `prepare_x()` returns.
smis = df_data['smiles']
# Targets come from R_from_z with a final argument of 0
# (presumably the enrichment at z = 0 - confirm against del_qsar.enrichments).
targets = R_from_z(bead_counts, bead_tot, exp_counts, exp_tot, 0).tolist()
featurizer = featurizers.GraphFeaturizer(smis, targets)
x = featurizer.prepare_x()

In [None]:
BATCH_SIZE = 50
# Architecture settings shared by all three D-MPNN seeds.
DEPTH = 6
HIDDEN_SIZE = 1500
FFN_NUM_LAYERS = 3
DROPOUT = 0.05

# Rebuild each seed's network and restore its trained weights.
loaded_models = []
for seed_idx, ckpt_path in enumerate(SIRT2_D_MPNN_MODEL_PATHS):
    net = models.MoleculeModel(depth=DEPTH,
                               hidden_size=HIDDEN_SIZE,
                               ffn_num_layers=FFN_NUM_LAYERS,
                               dropout=DROPOUT, device=DEVICE,
                               torch_seed=seed_idx)
    # map_location='cpu' lets checkpoints saved from a GPU load on CPU-only
    # machines; load_state_dict copies the values into the model's own parameters.
    net.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
    loaded_models.append(net)

if DEVICE:
    loaded_models = [net.to(DEVICE) for net in loaded_models]

# The plotting function looks the models up under these names.
model_0, model_1, model_2 = loaded_models

In [None]:
# SIRT2 / D-MPNN: parity plot over the full axis range.
draw_predicted_enrichments_vs_true(eval_slices=new_test_slices, modelType='D-MPNN')

In [None]:
# SIRT2 / D-MPNN: x-axis limited to the first zoom window.
draw_predicted_enrichments_vs_true(eval_slices=new_test_slices, modelType='D-MPNN', zoomIn1=True)

In [None]:
# SIRT2 / D-MPNN: x-axis limited to the second, tighter zoom window.
draw_predicted_enrichments_vs_true(eval_slices=new_test_slices, modelType='D-MPNN', zoomIn2=True)