In [None]:
# IPython autoreload in mode 2: imported modules are reloaded before each cell runs,
# so edits to project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os, sys
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
import h5py
from itertools import product

In [None]:
# Two directories above the working directory the notebook runs from
# (presumably the del_qsar repository root -- confirm against the repo layout).
DELQSAR_ROOT = '{}/../../'.format(os.getcwd())
# Put the directory above DELQSAR_ROOT on the import path so `del_qsar` can be imported.
sys.path.append(DELQSAR_ROOT + '/../')

from del_qsar import splitters, models
from del_qsar.enrichments import R_ranges

import matplotlib
import matplotlib.pyplot as plt
# matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and 3.8
# dropped the old names, so use whichever spelling this installation provides.
_paper_style = 'seaborn-v0_8-paper' if 'seaborn-v0_8-paper' in plt.style.available else 'seaborn-paper'
plt.style.use(_paper_style)

matplotlib.rc('font', family='sans-serif')
# NOTE(review): this sets the *serif* font list while the active family is
# sans-serif, so it probably has no visible effect -- confirm whether the
# 'sans-serif' list was meant instead.
matplotlib.rc('font', serif='Arial')
# Plain text rendering (no LaTeX); pass a real boolean rather than the string 'false'.
matplotlib.rc('text', usetex=False)

In [None]:
# Fingerprint files are expected to sit in the experiments folder.
# DD1S_FINGERPRINTS_FILENAME = 'x_DD1S_CAIX_2048_bits_all_fps.h5'
triazine_FINGERPRINTS_FILENAME = 'x_triazine_2048_bits_all_fps.h5'

# Saved FP-FFNN checkpoints, one per target (random split / seed 0 per the file names).
# CAIX_RANDOM_SPLIT_MODEL_PATH = os.path.join(DELQSAR_ROOT, 'experiments', 'models', 'DD1S_CAIX', 
#                                           'FP-FFNN', 'random_seed_0.torch')
sEH_RANDOM_SPLIT_MODEL_PATH = os.path.join(
    DELQSAR_ROOT, 'experiments', 'models', 'triazine_sEH', 'FP-FFNN', 'random_seed_0.torch',
)
SIRT2_RANDOM_SPLIT_MODEL_PATH = os.path.join(
    DELQSAR_ROOT, 'experiments', 'models', 'triazine_SIRT2', 'FP-FFNN', 'random_seed_0.torch',
)

In [None]:
# Seed torch's global RNG so repeated runs start from the same state.
SEED = 0
torch.manual_seed(SEED)

# First CUDA device when one is available, otherwise None.
DEVICE = 'cuda:0' if torch.cuda.is_available() else None

In [None]:
def get_cpd_indices(num_cyc1_BBs, num_cyc2_BBs, num_cyc3_BBs):
    """Collect, for every disynthon, the row positions of its compounds in `df_data`.

    Disynthons are enumerated in three consecutive groups, each in
    `itertools.product` order with building-block numbers starting at 1:
    (cycle2, cycle3) pairs, then (cycle1, cycle3) pairs, then (cycle1, cycle2)
    pairs.

    Reads the notebook-level DataFrame `df_data`, which must contain the
    columns 'cycle1', 'cycle2' and 'cycle3'.

    Args:
        num_cyc1_BBs: highest building-block number enumerated for 'cycle1'.
        num_cyc2_BBs: highest building-block number enumerated for 'cycle2'.
        num_cyc3_BBs: highest building-block number enumerated for 'cycle3'.

    Returns:
        1-D object array with one entry per disynthon. Each entry is a 1-D
        integer array of row positions (empty when no row matches).
    """
    cycle_pairs = (
        (('cycle2', num_cyc2_BBs), ('cycle3', num_cyc3_BBs)),
        (('cycle1', num_cyc1_BBs), ('cycle3', num_cyc3_BBs)),
        (('cycle1', num_cyc1_BBs), ('cycle2', num_cyc2_BBs)),
    )

    per_disynthon = []
    for (col_a, num_a), (col_b, num_b) in cycle_pairs:
        # One grouping pass per cycle pair instead of a full-table boolean scan
        # for every single disynthon.
        rows_by_pair = df_data.groupby([col_a, col_b]).indices
        pairs = product(range(1, num_a + 1), range(1, num_b + 1))
        for pair in tqdm(pairs, total=num_a * num_b):
            # Always a 1-D array, even for zero or one matching rows.
            per_disynthon.append(np.asarray(rows_by_pair.get(pair, ()), dtype=np.intp))

    # Fill an object array element by element: np.array() on a ragged list
    # raises on NumPy >= 1.24, and would silently build a 2-D array if all
    # disynthons happened to have the same size.
    cpd_indices = np.empty(len(per_disynthon), dtype=object)
    for pos, rows in enumerate(per_disynthon):
        cpd_indices[pos] = rows

    return cpd_indices

In [None]:
def get_counts_for_disynthons(num_cyc1_BBs, num_cyc2_BBs, num_cyc3_BBs, exp_col_name, beads_col_name):  
    """Sum two count columns of `df_data` over the compounds of every disynthon.

    Disynthons are enumerated in three consecutive groups, each in
    `itertools.product` order with building-block numbers starting at 1:
    (cycle2, cycle3) pairs, then (cycle1, cycle3) pairs, then (cycle1, cycle2)
    pairs. A disynthon with no matching row contributes 0 to both outputs.

    Reads the notebook-level DataFrame `df_data`, which must contain the
    columns 'cycle1', 'cycle2', 'cycle3' and the two named count columns.

    Args:
        num_cyc1_BBs: highest building-block number enumerated for 'cycle1'.
        num_cyc2_BBs: highest building-block number enumerated for 'cycle2'.
        num_cyc3_BBs: highest building-block number enumerated for 'cycle3'.
        exp_col_name: column summed into the first returned array.
        beads_col_name: column summed into the second returned array.

    Returns:
        Tuple `(disynthon_exp_counts, disynthon_bead_counts)` of 1-D arrays
        with one entry per disynthon.
    """
    cycle_pairs = (
        (('cycle2', num_cyc2_BBs), ('cycle3', num_cyc3_BBs)),
        (('cycle1', num_cyc1_BBs), ('cycle3', num_cyc3_BBs)),
        (('cycle1', num_cyc1_BBs), ('cycle2', num_cyc2_BBs)),
    )

    disynthon_exp_counts, disynthon_bead_counts = [], []
    for (col_a, num_a), (col_b, num_b) in cycle_pairs:
        # One grouped sum per cycle pair replaces a full-table filter plus a
        # Python-level sum for every disynthon. Note that the grouped sum skips
        # missing values, whereas summing a column containing NaN gives NaN.
        grouped = df_data.groupby([col_a, col_b])
        exp_by_pair = grouped[exp_col_name].sum().to_dict()
        beads_by_pair = grouped[beads_col_name].sum().to_dict()

        pairs = product(range(1, num_a + 1), range(1, num_b + 1))
        for pair in tqdm(pairs, total=num_a * num_b):
            disynthon_exp_counts.append(exp_by_pair.get(pair, 0))
            disynthon_bead_counts.append(beads_by_pair.get(pair, 0))

    disynthon_exp_counts = np.array(disynthon_exp_counts)
    disynthon_bead_counts = np.array(disynthon_bead_counts)

    return disynthon_exp_counts, disynthon_bead_counts

In [None]:
def get_avg_preds(model):
    """Average the model's predictions over the compounds of each disynthon.

    Reads notebook-level state: `cpd_indices` (one collection of row positions
    per disynthon), the feature matrix `x`, `BATCH_SIZE` and `DEVICE`.

    Args:
        model: object exposing `predict_on_x(x, batch_size=..., device=...)`.

    Returns:
        Array with one averaged prediction per entry of `cpd_indices`; NaN for
        an entry that has no compounds.
    """
    avg_preds = []
    for i in range(len(cpd_indices)):
        # An entry can be 0-d when a disynthon has exactly one compound;
        # promote it so x[rows, :] is always a 2-D batch.
        rows = np.atleast_1d(cpd_indices[i])
        if rows.size == 0:
            # Nothing to average; avoids dividing by zero below.
            avg_preds.append(np.nan)
            continue

        test_enrichments = model.predict_on_x(
            x[rows, :], batch_size=BATCH_SIZE, device=DEVICE,
        )
        avg_preds.append(sum(test_enrichments) / len(test_enrichments))

    avg_preds = np.array(avg_preds)

    return avg_preds

In [None]:
def draw_predicted_enrichments_vs_true(model, out='', x_ub=None, legend_loc=None):
    """Plot disynthon-averaged predictions against calculated enrichments.

    Reads notebook-level state: `disynthon_bead_counts`, `disynthon_exp_counts`,
    `bead_tot`, `exp_tot`, `avg_preds`, and the `pathify` helper currently in
    scope. `R_ranges` supplies the x values and their lower/upper bounds, which
    are drawn as horizontal error bars; a parity line spanning the range of
    `avg_preds` is overlaid.

    Args:
        model: not used by this function; kept so existing calls keep working.
        out: file name handed to `pathify` for saving the figure. When empty,
            the figure is shown but not saved.
        x_ub: optional upper x-axis limit; when truthy the x axis is [0, x_ub].
        legend_loc: optional matplotlib legend location; matplotlib's default
            is used when falsy.
    """
    R, R_lb, R_ub = R_ranges(disynthon_bead_counts, bead_tot[0], 
                             disynthon_exp_counts, exp_tot[0])

    fig, ax = plt.subplots(figsize=(3.33, 2), dpi=300)

    ax.errorbar(
        x=R,
        y=avg_preds,
        xerr=[R - R_lb, R_ub - R],  # asymmetric: distance to lower / upper bound
        color='#1f77b4', # blue
        label='disynthon',
        marker='o',
        markersize=3,
        elinewidth=0.75,
        ls='none',
        ecolor='k',
        capsize=1,
        capthick=0.75,
        zorder=2,
    )

    # Compute the parity-line end points once; the nan-aware reductions keep the
    # line well defined if some averaged predictions are NaN.
    parity = np.linspace(np.nanmin(avg_preds), np.nanmax(avg_preds), 100)
    ax.plot(
        parity,
        parity,
        color='#2ca02c', # green
        label='parity',
        linewidth=0.75,
        zorder=3,
    )

    legend_kwargs = {'fontsize': 7}
    if legend_loc:
        legend_kwargs['loc'] = legend_loc
    ax.legend(**legend_kwargs)

    fig.canvas.draw()
    ax.tick_params(labelsize=8)
    if x_ub:
        ax.set_xlim([0, x_ub])

    ax.grid(zorder=1)
    ax.set_xlabel('calculated enrichment', fontsize=8)
    ax.set_ylabel('predicted enrichment', fontsize=8)
    fig.tight_layout()

    # With an empty name, savefig would write a stem-less '.png' into the plot
    # directory; only save when a name was actually given.
    if str(out):
        fig.savefig(pathify(str(out)))

    plt.show()

# (DD1S CAIX)

In [None]:
# df_data = pd.read_csv(os.path.join(DELQSAR_ROOT, 'experiments', 'datasets', 'DD1S_CAIX_QSAR.csv'))

In [None]:
# os.environ["HDF5_USE_FILE_LOCKING"] = 'FALSE'
# hf = h5py.File(os.path.join(DELQSAR_ROOT, 'experiments', DD1S_FINGERPRINTS_FILENAME), 'r')
# x = np.array(hf['all_fps'])
# INPUT_SIZE = x.shape[1]
# hf.close()

In [None]:
# if not os.path.isdir('CAIX_disynthon_plots'):
#     os.mkdir('CAIX_disynthon_plots')
# def pathify(fname):
#     return os.path.join('CAIX_disynthon_plots', fname)

In [None]:
# exp_counts = np.array(df_data[['exp_tot']], dtype='int')
# bead_counts = np.array(df_data[['beads_tot']], dtype='int')
# exp_tot = np.sum(exp_counts, axis=0) # column sums
# bead_tot = np.sum(bead_counts, axis=0)

In [None]:
# CAIX random split model
# BATCH_SIZE = 1024
# LAYER_SIZES = [64, 64, 64]
# DROPOUT = 0.1
# model_CAIX = models.MLP(INPUT_SIZE, [int(size) for size in LAYER_SIZES],
#                     dropout=DROPOUT, torch_seed=SEED)
# model_CAIX.load_state_dict(torch.load(CAIX_RANDOM_SPLIT_MODEL_PATH))
# print(str(model_CAIX))
# if DEVICE:
#     model_CAIX = model_CAIX.to(DEVICE)

In [None]:
# cpd_indices = get_cpd_indices(8, 114, 119)
# np.save('DD1S_disynthon_cpd_indices.npy', cpd_indices)

In [None]:
## cpd_indices = np.load('DD1S_disynthon_cpd_indices.npy', allow_pickle=True)

In [None]:
# disynthon_exp_counts, disynthon_bead_counts = get_counts_for_disynthons(8, 114, 119, 'exp_tot', 'beads_tot')
# np.save('DD1S_CAIX_disynthon_exp_counts.npy', disynthon_exp_counts)
# np.save('DD1S_CAIX_disynthon_bead_counts.npy', disynthon_bead_counts)

In [None]:
## disynthon_exp_counts = np.load('DD1S_CAIX_disynthon_exp_counts.npy')
## disynthon_bead_counts = np.load('DD1S_CAIX_disynthon_bead_counts.npy')

In [None]:
# avg_preds = get_avg_preds(model_CAIX)
# np.save('DD1S_CAIX_disynthon_avg_preds.npy', avg_preds)

In [None]:
## avg_preds = np.load('DD1S_CAIX_disynthon_avg_preds.npy')

In [None]:
# draw_predicted_enrichments_vs_true(model_CAIX, out='DD1S_CAIX_disynthon_parity_plot', legend_loc='lower right')

# Triazine sEH

In [None]:
# Triazine library table; the sEH and SIRT2 sections both read this same CSV.
df_data = pd.read_csv(
    os.path.join(DELQSAR_ROOT, 'experiments', 'datasets', 'triazine_lib_sEH_SIRT2_QSAR.csv')
)

In [None]:
# Load the 'all_fps' dataset into memory; INPUT_SIZE is its number of columns.
os.environ["HDF5_USE_FILE_LOCKING"] = 'FALSE'
# The context manager closes the file even if reading the dataset raises.
with h5py.File(os.path.join(DELQSAR_ROOT, 'experiments', triazine_FINGERPRINTS_FILENAME), 'r') as hf:
    x = np.array(hf['all_fps'])
INPUT_SIZE = x.shape[1]

In [None]:
# Output directory for the sEH figures; created on first use.
os.makedirs('sEH_disynthon_plots', exist_ok=True)

def pathify(fname):
    """Return `fname` placed inside the 'sEH_disynthon_plots' directory."""
    return os.path.join('sEH_disynthon_plots', fname)

In [None]:
# Per-compound counts as single-column integer arrays.
exp_counts = np.array(df_data[['sEH [strep]_tot']], dtype='int')
bead_counts = np.array(df_data[['beads-linker-only [strep]_tot']], dtype='int')
# Column sums: one-element arrays, later read as exp_tot[0] / bead_tot[0].
exp_tot, bead_tot = exp_counts.sum(axis=0), bead_counts.sum(axis=0)

In [None]:
# sEH random split model
BATCH_SIZE = 1024
LAYER_SIZES = [256, 128, 64]
DROPOUT = 0.4
model_sEH = models.MLP(INPUT_SIZE, [int(size) for size in LAYER_SIZES],
                    dropout=DROPOUT, torch_seed=SEED)
# Load the checkpoint tensors onto the CPU regardless of the device they were
# saved from; without map_location a checkpoint saved from a GPU cannot be
# loaded on a machine without CUDA. The model is moved to DEVICE afterwards.
model_sEH.load_state_dict(torch.load(sEH_RANDOM_SPLIT_MODEL_PATH, map_location='cpu'))
# NOTE(review): DROPOUT is non-zero; confirm predict_on_x puts the model in eval mode.
print(str(model_sEH))
if DEVICE:
    model_sEH = model_sEH.to(DEVICE)

In [None]:
# Row positions per disynthon for the triazine library (78 / 290 / 250 building
# blocks for cycles 1 / 2 / 3), saved so later runs can load them instead.
cpd_indices = get_cpd_indices(num_cyc1_BBs=78, num_cyc2_BBs=290, num_cyc3_BBs=250)
np.save(file='triazine_disynthon_cpd_indices.npy', arr=cpd_indices)

In [None]:
# cpd_indices = np.load('triazine_disynthon_cpd_indices.npy', allow_pickle=True)

In [None]:
# Per-disynthon sums of the sEH and beads-only count columns, saved for reuse.
disynthon_exp_counts, disynthon_bead_counts = get_counts_for_disynthons(
    78, 290, 250,
    exp_col_name='sEH [strep]_tot',
    beads_col_name='beads-linker-only [strep]_tot',
)
np.save(file='sEH_disynthon_exp_counts.npy', arr=disynthon_exp_counts)
np.save(file='sEH_disynthon_bead_counts.npy', arr=disynthon_bead_counts)

In [None]:
# disynthon_exp_counts = np.load('sEH_disynthon_exp_counts.npy')
# disynthon_bead_counts = np.load('sEH_disynthon_bead_counts.npy')

In [None]:
# Disynthon-averaged predictions of the sEH model, saved for reuse.
avg_preds = get_avg_preds(model=model_sEH)
np.save(file='sEH_disynthon_avg_preds.npy', arr=avg_preds)

In [None]:
# avg_preds = np.load('sEH_disynthon_avg_preds.npy')

In [None]:
# sEH parity plot over the full x range.
draw_predicted_enrichments_vs_true(model=model_sEH, out='sEH_disynthon_parity_plot')

In [None]:
# Same plot with the x axis limited to [0, 800].
draw_predicted_enrichments_vs_true(
    model=model_sEH,
    out='sEH_disynthon_parity_plot_zoom_in_800',
    x_ub=800,
)

# Triazine SIRT2

In [None]:
# Re-read the triazine table so this section can run without the sEH section.
df_data = pd.read_csv(
    os.path.join(DELQSAR_ROOT, 'experiments', 'datasets', 'triazine_lib_sEH_SIRT2_QSAR.csv')
)

In [None]:
# Load the 'all_fps' dataset into memory; INPUT_SIZE is its number of columns.
os.environ["HDF5_USE_FILE_LOCKING"] = 'FALSE'
# The context manager closes the file even if reading the dataset raises.
with h5py.File(os.path.join(DELQSAR_ROOT, 'experiments', triazine_FINGERPRINTS_FILENAME), 'r') as hf:
    x = np.array(hf['all_fps'])
INPUT_SIZE = x.shape[1]

In [None]:
# Output directory for the SIRT2 figures; created on first use.
os.makedirs('SIRT2_disynthon_plots', exist_ok=True)

# Redefined here on purpose: from this cell on, figures go to the SIRT2 directory.
def pathify(fname):
    """Return `fname` placed inside the 'SIRT2_disynthon_plots' directory."""
    return os.path.join('SIRT2_disynthon_plots', fname)

In [None]:
# Per-compound counts as single-column integer arrays.
exp_counts = np.array(df_data[['SIRT2 [strep]_tot']], dtype='int')
bead_counts = np.array(df_data[['beads-linker-only [strep]_tot']], dtype='int')
# Column sums: one-element arrays, later read as exp_tot[0] / bead_tot[0].
exp_tot, bead_tot = exp_counts.sum(axis=0), bead_counts.sum(axis=0)

In [None]:
# SIRT2 random split model
BATCH_SIZE = 1024
LAYER_SIZES = [256, 128, 64]
DROPOUT = 0.1
model_SIRT2 = models.MLP(INPUT_SIZE, [int(size) for size in LAYER_SIZES],
                    dropout=DROPOUT, torch_seed=SEED)
# Load the checkpoint tensors onto the CPU regardless of the device they were
# saved from; without map_location a checkpoint saved from a GPU cannot be
# loaded on a machine without CUDA. The model is moved to DEVICE afterwards.
model_SIRT2.load_state_dict(torch.load(SIRT2_RANDOM_SPLIT_MODEL_PATH, map_location='cpu'))
# NOTE(review): DROPOUT is non-zero; confirm predict_on_x puts the model in eval mode.
print(str(model_SIRT2))
if DEVICE:
    model_SIRT2 = model_SIRT2.to(DEVICE)

In [None]:
# The index table depends only on the triazine table and the library
# dimensions, and the sEH section writes it to this same file. Reuse that file
# when present instead of repeating the expensive enumeration; delete the file
# to force a recompute (e.g. after the dataset changes).
CPD_INDICES_FILE = 'triazine_disynthon_cpd_indices.npy'
if os.path.exists(CPD_INDICES_FILE):
    # allow_pickle is required for the object array; only load files written
    # by this notebook, since unpickling untrusted files can execute code.
    cpd_indices = np.load(CPD_INDICES_FILE, allow_pickle=True)
else:
    cpd_indices = get_cpd_indices(78, 290, 250)
    np.save(CPD_INDICES_FILE, cpd_indices)

In [None]:
# cpd_indices = np.load('triazine_disynthon_cpd_indices.npy', allow_pickle=True)

In [None]:
# Disynthon-averaged predictions of the SIRT2 model, saved for reuse.
avg_preds = get_avg_preds(model=model_SIRT2)
np.save(file='SIRT2_disynthon_avg_preds.npy', arr=avg_preds)

In [None]:
# avg_preds = np.load('SIRT2_disynthon_avg_preds.npy')

In [None]:
# Per-disynthon sums of the SIRT2 and beads-only count columns, saved for reuse.
disynthon_exp_counts, disynthon_bead_counts = get_counts_for_disynthons(
    78, 290, 250,
    exp_col_name='SIRT2 [strep]_tot',
    beads_col_name='beads-linker-only [strep]_tot',
)
np.save(file='SIRT2_disynthon_exp_counts.npy', arr=disynthon_exp_counts)
np.save(file='SIRT2_disynthon_bead_counts.npy', arr=disynthon_bead_counts)

In [None]:
# disynthon_exp_counts = np.load('SIRT2_disynthon_exp_counts.npy')
# disynthon_bead_counts = np.load('SIRT2_disynthon_bead_counts.npy')

In [None]:
# SIRT2 parity plot over the full x range.
draw_predicted_enrichments_vs_true(model=model_SIRT2, out='SIRT2_disynthon_parity_plot')

In [None]:
# Same plot with the x axis limited to [0, 100].
draw_predicted_enrichments_vs_true(
    model=model_SIRT2,
    out='SIRT2_disynthon_parity_plot_zoom_in_100',
    x_ub=100,
)

In [None]:
# Same plot with the x axis limited to [0, 50].
draw_predicted_enrichments_vs_true(
    model=model_SIRT2,
    out='SIRT2_disynthon_parity_plot_zoom_in_50',
    x_ub=50,
)