In [None]:
# Automatically reload edited project modules (e.g. del_qsar) before each cell executes.
%load_ext autoreload
%autoreload 2

In [1]:
import os, sys
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
import h5py

In [2]:
# Repository root: two directory levels above the notebook's working directory.
DELQSAR_ROOT = f'{os.getcwd()}/../../'
# Put the directory above the repository root on the import path
# (presumably so that `del_qsar` can be imported below).
sys.path.append(DELQSAR_ROOT + '/../')

In [None]:
# HDF5 fingerprint file; it is expected to live inside the experiments folder.
FINGERPRINTS_FILENAME = 'x_triazine_2048_bits_all_fps.h5'

# FP-FFNN checkpoints for the random split with seed 0 (per the file names), one per target:
# first sEH, then SIRT2.
sEH_RANDOM_SPLIT_MODEL_PATH, SIRT2_RANDOM_SPLIT_MODEL_PATH = (
    os.path.join(DELQSAR_ROOT, 'experiments', 'models', target_dir, 'FP-FFNN', 'random_seed_0.torch')
    for target_dir in ('triazine_sEH', 'triazine_SIRT2')
)

In [3]:
from del_qsar import models, featurizers, splitters
from del_qsar.enrichments import R_from_z, R_ranges

# Directory (relative to the working directory) that every parity plot is saved into.
PLOT_DIR = 'triazine_parity_plots'

# Create the same directory that pathify() targets, so plt.savefig() cannot fail on a
# missing folder; exist_ok keeps this cell safe to re-run.
os.makedirs(PLOT_DIR, exist_ok=True)


def pathify(fname):
    """Return the path of ``fname`` inside the parity-plot output directory (PLOT_DIR)."""
    return os.path.join(PLOT_DIR, fname)

# Log file for this notebook, under experiments/visualizations/triazine_parity_plots.
LOG_FILE = os.path.join(
    DELQSAR_ROOT,
    'experiments',
    'visualizations',
    'triazine_parity_plots',
    'triazine_parity_plots.log',
)

In [None]:
import matplotlib
import matplotlib.pyplot as plt

# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and 3.8 removed the
# old names, so use whichever spelling this installation provides.
_PAPER_STYLE = 'seaborn-v0_8-paper' if 'seaborn-v0_8-paper' in plt.style.available else 'seaborn-paper'
plt.style.use(_PAPER_STYLE)

# Render text in Arial. Because the family is 'sans-serif', Arial has to be put at the front
# of the *sans-serif* list (font.serif is never consulted); the remaining entries stay as
# fallbacks in case Arial is not installed.
matplotlib.rc('font', family='sans-serif')
matplotlib.rcParams['font.sans-serif'] = ['Arial'] + [
    name for name in matplotlib.rcParams['font.sans-serif'] if name != 'Arial'
]
matplotlib.rc('text', usetex=False)
matplotlib.rcParams.update({'font.size': 9})

In [None]:
# Compound table with per-target read-count columns (used below for both the sEH and SIRT2 plots).
df_data = pd.read_csv(
    os.path.join(
        DELQSAR_ROOT, 'experiments', 'datasets', 'triazine_lib_sEH_SIRT2_QSAR.csv',
    )
)

In [None]:
# Disable HDF5 file locking for the open below.
os.environ["HDF5_USE_FILE_LOCKING"] = 'FALSE'
# Read the whole 'all_fps' dataset into memory. The context manager closes the file even if
# reading fails, instead of leaking the handle as a manual close() would.
with h5py.File(os.path.join(DELQSAR_ROOT, 'experiments', FINGERPRINTS_FILENAME), 'r') as hf:
    x = hf['all_fps'][()]
# Number of columns of the fingerprint matrix; used as the MLP input width below.
INPUT_SIZE = x.shape[1]

In [None]:
# Single seed shared by torch's global RNG, the data splitter and both model constructors.
SEED = 0
torch.random.manual_seed(SEED)

In [None]:
# Re-create the seeded random train/valid/test split; only test_slice is used for plotting below.
splitter = splitters.RandomSplitter()
(train_slice, valid_slice, test_slice) = splitter(x, df_data, seed=SEED)

In [None]:
# sEH random split model
BATCH_SIZE = 1024
# Layer sizes must match the architecture stored in the checkpoint, otherwise
# load_state_dict() raises on mismatched keys/shapes.
LAYER_SIZES = [256, 128, 64]
DROPOUT = 0.4
model_sEH = models.MLP(INPUT_SIZE, [int(size) for size in LAYER_SIZES],
                    dropout=DROPOUT, torch_seed=SEED)
# map_location='cpu' lets a checkpoint that was saved from a GPU load on a CPU-only machine;
# a later cell moves the model to the GPU when CUDA is available.
model_sEH.load_state_dict(torch.load(sEH_RANDOM_SPLIT_MODEL_PATH, map_location='cpu'))
print(model_sEH)

In [None]:
# SIRT2 random split model
BATCH_SIZE = 1024
# Layer sizes must match the architecture stored in the checkpoint, otherwise
# load_state_dict() raises on mismatched keys/shapes.
LAYER_SIZES = [256, 128, 64]
DROPOUT = 0.1
model_SIRT2 = models.MLP(INPUT_SIZE, [int(size) for size in LAYER_SIZES],
                    dropout=DROPOUT, torch_seed=SEED)
# map_location='cpu' lets a checkpoint that was saved from a GPU load on a CPU-only machine;
# a later cell moves the model to the GPU when CUDA is available.
model_SIRT2.load_state_dict(torch.load(SIRT2_RANDOM_SPLIT_MODEL_PATH, map_location='cpu'))
print(model_SIRT2)

In [None]:
# Use the first GPU when one is present; otherwise DEVICE stays None and the models stay on CPU.
DEVICE = 'cuda:0' if torch.cuda.is_available() else None
if DEVICE is not None:
    model_SIRT2 = model_SIRT2.to(DEVICE)
    model_sEH = model_sEH.to(DEVICE)

In [None]:
def draw_predicted_enrichments_vs_true(model, eval_slice, out='', x_ub=None):
    """Parity plot of model-predicted enrichments vs. enrichments calculated from read counts.

    Besides its arguments, this reads the notebook globals ``bead_counts``, ``bead_tot``,
    ``exp_counts``, ``exp_tot`` (first column of each), the fingerprint matrix ``x``,
    ``BATCH_SIZE`` and ``DEVICE``. Run the count cell for the target of interest before
    calling it.

    Args:
        model: trained model providing ``predict_on_x(x, batch_size=..., device=...)``.
        eval_slice: row indices (list or integer array) of the compounds to plot.
        out: file name for the saved figure, placed in the parity-plot directory via
            ``pathify``. When empty, the figure is shown but not saved.
        x_ub: optional upper x-axis limit. When given, the x axis is clipped to
            ``[0, x_ub]`` and the legend is placed in the lower right.
    """
    # Calculated enrichment R with lower/upper bounds R_lb/R_ub from R_ranges
    # (presumably an uncertainty interval -- confirm in del_qsar.enrichments).
    R, R_lb, R_ub = R_ranges(
        bead_counts[eval_slice, 0], bead_tot[0], exp_counts[eval_slice, 0], exp_tot[0]
    )
    test_enrichments = np.asarray(
        model.predict_on_x(x[eval_slice, :], batch_size=BATCH_SIZE, device=DEVICE)
    )

    fig, ax = plt.subplots(figsize=(3.33, 2), dpi=300)

    # One point per compound, with asymmetric horizontal error bars spanning [R_lb, R_ub].
    ax.errorbar(
        x=R,
        y=test_enrichments,
        xerr=[R - R_lb, R_ub - R],
        color='#1f77b4',  # blue
        marker='o',
        markersize=3,
        elinewidth=0.75,
        ls='none',
        ecolor='k',
        capsize=1,
        capthick=0.75,
        zorder=2,
    )

    # y = x reference line across the range of predicted enrichments.
    pred_min, pred_max = np.min(test_enrichments), np.max(test_enrichments)
    ax.plot(
        [pred_min, pred_max],
        [pred_min, pred_max],
        color='#2ca02c',  # green
        label='Parity',
        linewidth=0.75,
        zorder=3,
    )

    # Compare against None so that a supplied limit is never mistaken for "no limit".
    if x_ub is not None:
        ax.legend(fontsize=7, loc='lower right')
        ax.set_xlim([0, x_ub])
    else:
        ax.legend(fontsize=7)

    ax.tick_params(labelsize=8)
    ax.grid(zorder=1)
    ax.set_xlabel('Calculated enrichment', fontsize=8)
    ax.set_ylabel('Predicted enrichment', fontsize=8)
    fig.tight_layout()

    # An empty name would resolve to the output directory itself and make savefig fail.
    if out:
        fig.savefig(pathify(str(out)))

    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

# sEH FP-FFNN

In [None]:
# Read counts for the sEH selection and for the beads-linker-only control, each as a
# one-column 2-D int array, plus the per-column totals.
exp_counts = df_data[['sEH [strep]_tot']].to_numpy(dtype='int', copy=True)
bead_counts = df_data[['beads-linker-only [strep]_tot']].to_numpy(dtype='int', copy=True)
exp_tot = exp_counts.sum(axis=0)  # column sums
bead_tot = bead_counts.sum(axis=0)

In [None]:
# Plot only a reproducible random subset of the test split:
# the first 20,000 positions of a seed-5 shuffle.
np.random.seed(5)
test_slice_indices = np.arange(len(test_slice))
np.random.shuffle(test_slice_indices)
test_slice_subset = [test_slice[pos] for pos in test_slice_indices[:20000]]

In [None]:
# Full-range sEH parity plot.
draw_predicted_enrichments_vs_true(
    model_sEH,
    test_slice_subset,
    out='Parity scatter plot_full_triazine_sEH_FP-FFNN_random_seed_0.png',
)

In [None]:
# Zoomed sEH parity plot.
draw_predicted_enrichments_vs_true(
    model_sEH,
    test_slice_subset,
    out='Parity scatter plot_zoomed in_triazine_sEH_FP-FFNN_random_seed_0.png',
    x_ub=400,  # clip the x axis to [0, 400]
)

# SIRT2 FP-FFNN

In [None]:
# Read counts for the SIRT2 selection and for the beads-linker-only control, each as a
# one-column 2-D int array, plus the per-column totals.
exp_counts = df_data[['SIRT2 [strep]_tot']].to_numpy(dtype='int', copy=True)
bead_counts = df_data[['beads-linker-only [strep]_tot']].to_numpy(dtype='int', copy=True)
exp_tot = exp_counts.sum(axis=0)  # column sums
bead_tot = bead_counts.sum(axis=0)

In [None]:
# Plot only a reproducible random subset of the test split:
# the first 20,000 positions of a seed-5 shuffle (same subset as for sEH).
np.random.seed(5)
test_slice_indices = np.arange(len(test_slice))
np.random.shuffle(test_slice_indices)
test_slice_subset = [test_slice[pos] for pos in test_slice_indices[:20000]]

In [None]:
# Full-range SIRT2 parity plot.
draw_predicted_enrichments_vs_true(
    model_SIRT2,
    test_slice_subset,
    out='Parity scatter plot_full_triazine_SIRT2_FP-FFNN_random_seed_0.png',
)

In [None]:
# Zoomed SIRT2 parity plot.
draw_predicted_enrichments_vs_true(
    model_SIRT2,
    test_slice_subset,
    out='Parity scatter plot_zoomed in_triazine_SIRT2_FP-FFNN_random_seed_0.png',
    x_ub=90,  # clip the x axis to [0, 90]
)