In [None]:
# Automatically reload edited modules (e.g. the del_qsar package imported below) before each cell runs
%load_ext autoreload
%autoreload 2

In [None]:
import os, sys
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
import h5py
from itertools import product

In [None]:
DELQSAR_ROOT = os.getcwd() + '/../../'
# Put the parent of DELQSAR_ROOT on sys.path so the del_qsar package can be imported
sys.path.append(DELQSAR_ROOT + '/../')

from del_qsar import splitters, models
from del_qsar.enrichments import R_ranges

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as colors

# 'seaborn-paper' was renamed 'seaborn-v0_8-paper' in Matplotlib 3.6 and the old
# name was removed in 3.8; use whichever name this installation provides.
plt.style.use('seaborn-v0_8-paper' if 'seaborn-v0_8-paper' in plt.style.available
              else 'seaborn-paper')

# Sans-serif text, preferring Arial. Arial is a sans-serif face, so it must lead
# font.sans-serif (setting font.serif has no effect while the family is
# sans-serif); the remaining entries stay as fallbacks.
matplotlib.rc('font', family='sans-serif')
matplotlib.rcParams['font.sans-serif'] = ['Arial'] + [
    name for name in matplotlib.rcParams['font.sans-serif'] if name != 'Arial'
]
matplotlib.rc('text', usetex=False)

In [None]:
# Fingerprint HDF5 file (the name indicates 2048-bit fingerprints for all DD1S CAIX compounds);
# expected to sit directly in the experiments folder.
DD1S_FINGERPRINTS_FILENAME = 'x_DD1S_CAIX_2048_bits_all_fps.h5'

# Saved FP-FFNN weights for the random split, seed 0
CAIX_RANDOM_SPLIT_MODEL_PATH = os.path.join(
    DELQSAR_ROOT,
    'experiments',
    'models',
    'DD1S_CAIX',
    'FP-FFNN',
    'random_seed_0.torch',
)

In [None]:
SEED = 0
torch.manual_seed(SEED)

# First GPU when CUDA is available; None otherwise (callers treat None as "stay on CPU")
DEVICE = 'cuda:0' if torch.cuda.is_available() else None

In [None]:
def get_cpd_indices(num_cyc1_BBs, num_cyc2_BBs, num_cyc3_BBs, df=None):
    """Return the row positions of the compounds belonging to every disynthon.

    Disynthons are enumerated in three consecutive groups: all (cycle2, cycle3)
    pairs, then all (cycle1, cycle3) pairs, then all (cycle1, cycle2) pairs.
    Within a group the first building block varies slowest, and building-block
    indices run from 1 to the given count.

    Args:
        num_cyc1_BBs, num_cyc2_BBs, num_cyc3_BBs: number of building blocks in
            cycles 1, 2 and 3.
        df: compound table with 'cycle1', 'cycle2' and 'cycle3' columns.
            Defaults to the notebook-level ``df_data``.

    Returns:
        1-D object array with one entry per disynthon. Each entry is a 1-D
        integer array of positional (``iloc``) row indices into ``df``; it is
        empty when no compound carries that building-block pair.
    """
    if df is None:
        df = df_data

    # (first column, its BB count, second column, its BB count) for each group,
    # in the order the disynthons are emitted.
    pair_specs = [
        ('cycle2', num_cyc2_BBs, 'cycle3', num_cyc3_BBs),
        ('cycle1', num_cyc1_BBs, 'cycle3', num_cyc3_BBs),
        ('cycle1', num_cyc1_BBs, 'cycle2', num_cyc2_BBs),
    ]
    no_rows = np.array([], dtype=np.intp)

    groups = []
    for col_a, n_a, col_b, n_b in pair_specs:
        # One groupby per pair type instead of a full-table scan per disynthon;
        # .indices maps (a, b) -> ascending positional row indices.
        positions_by_pair = df.groupby([col_a, col_b]).indices
        for pair in product(range(1, n_a + 1), range(1, n_b + 1)):
            # Always 1-D (even for a single compound) so every entry is iterable
            groups.append(np.asarray(positions_by_pair.get(pair, no_rows), dtype=np.intp))

    # Groups differ in length, so build the object array explicitly: np.array()
    # on a ragged list raises ValueError on NumPy >= 1.24, and on equal-length
    # groups it would silently produce a 2-D int array instead.
    cpd_indices = np.empty(len(groups), dtype=object)
    for idx, group in enumerate(groups):
        cpd_indices[idx] = group

    return cpd_indices

In [None]:
def get_counts_for_disynthons(num_cyc1_BBs, num_cyc2_BBs, num_cyc3_BBs, exp_col_name, beads_col_name, df=None):
    """Sum experiment and bead counts over the compounds of every disynthon.

    Disynthons are enumerated in three consecutive groups: all (cycle2, cycle3)
    pairs, then all (cycle1, cycle3) pairs, then all (cycle1, cycle2) pairs,
    with the first building block varying slowest and indices running from 1
    to the given count.

    Args:
        num_cyc1_BBs, num_cyc2_BBs, num_cyc3_BBs: number of building blocks in
            cycles 1, 2 and 3.
        exp_col_name: column holding the experiment counts to sum.
        beads_col_name: column holding the bead (control) counts to sum.
        df: compound table with 'cycle1', 'cycle2', 'cycle3' and the two count
            columns. Defaults to the notebook-level ``df_data``.

    Returns:
        (disynthon_exp_counts, disynthon_bead_counts): two 1-D numpy arrays with
        one summed count per disynthon; 0 when no compound has that pair.
    """
    if df is None:
        df = df_data

    # (first column, its BB count, second column, its BB count) for each group,
    # in the order the disynthons are emitted.
    pair_specs = [
        ('cycle2', num_cyc2_BBs, 'cycle3', num_cyc3_BBs),
        ('cycle1', num_cyc1_BBs, 'cycle3', num_cyc3_BBs),
        ('cycle1', num_cyc1_BBs, 'cycle2', num_cyc2_BBs),
    ]

    disynthon_exp_counts, disynthon_bead_counts = [], []
    for col_a, n_a, col_b, n_b in pair_specs:
        # One groupby per pair type instead of a full-table scan per disynthon;
        # to_dict() yields {(a, b): summed count}.
        grouped = df.groupby([col_a, col_b])
        exp_by_pair = grouped[exp_col_name].sum().to_dict()
        beads_by_pair = grouped[beads_col_name].sum().to_dict()
        for pair in product(range(1, n_a + 1), range(1, n_b + 1)):
            disynthon_exp_counts.append(exp_by_pair.get(pair, 0))
            disynthon_bead_counts.append(beads_by_pair.get(pair, 0))

    return np.array(disynthon_exp_counts), np.array(disynthon_bead_counts)

In [None]:
def get_avg_preds(model, indices=None, features=None, batch_size=None, device=None):
    """Mean model prediction over the member compounds of each disynthon.

    Args:
        model: object exposing ``predict_on_x(x, batch_size=..., device=...)``.
        indices: sequence of row-index arrays, one per disynthon (e.g. the
            output of ``get_cpd_indices``). Defaults to the global ``cpd_indices``.
        features: 2-D feature matrix indexed by those rows. Defaults to the
            global ``x``.
        batch_size: forwarded to ``predict_on_x``. Defaults to ``BATCH_SIZE``.
        device: forwarded to ``predict_on_x``. Defaults to ``DEVICE``.

    Returns:
        numpy array with one averaged prediction per disynthon; NaN for a
        disynthon that has no member compounds.
    """
    if indices is None:
        indices = cpd_indices
    if features is None:
        features = x
    if batch_size is None:
        batch_size = BATCH_SIZE
    if device is None:
        device = DEVICE

    avg_preds = []
    for group in indices:
        # atleast_1d guards against 0-d index entries (single-compound groups
        # saved as scalars), which would otherwise select a 1-D row.
        rows = np.atleast_1d(np.asarray(group, dtype=np.intp))
        if rows.size == 0:
            # Nothing to average; avoid a division by zero
            avg_preds.append(np.nan)
            continue
        test_enrichments = model.predict_on_x(
            features[rows, :], batch_size=batch_size, device=device,
        )
        # axis=0 keeps the same per-row averaging as a plain sum / len
        avg_preds.append(np.mean(test_enrichments, axis=0))

    return np.array(avg_preds)

In [None]:
def draw_predicted_enrichments_vs_true(model, out='', x_ub=None, legend_loc=None):
    """Parity plot of mean predicted vs. calculated disynthon enrichments.

    Each disynthon is drawn at (calculated enrichment ``R``, ``avg_preds``), with
    horizontal error bars spanning ``R_lb``..``R_ub`` from ``R_ranges``.
    Disynthons flagged in ``disynthon_has_sulfonamide`` are orange and drawn on
    top of the others (blue); a green y = x line spans the prediction range.

    Reads the notebook globals ``disynthon_bead_counts``, ``disynthon_exp_counts``,
    ``bead_tot``, ``exp_tot``, ``disynthon_has_sulfonamide``, ``avg_preds`` and
    ``pathify``.

    Args:
        model: unused; kept for call compatibility (predictions come from the
            global ``avg_preds``).
        out: file name, inside the plots folder, to save the figure under.
            When empty, the figure is shown but not saved.
        x_ub: optional upper x-axis limit (the lower limit is then 0).
        legend_loc: optional Matplotlib legend location; falsy means default.
    """
    R, R_lb, R_ub = R_ranges(disynthon_bead_counts, bead_tot[0],
                             disynthon_exp_counts, exp_tot[0])

    fig, ax = plt.subplots(figsize=(4, 3), dpi=300)

    # Marker / error-bar styling shared by both compound classes
    errorbar_style = dict(marker='o', markersize=3, elinewidth=0.75, ls='none',
                          ecolor='k', capsize=1, capthick=0.75)
    # Benzenesulfonamides get the higher zorder so they are drawn on top
    series = (
        (~disynthon_has_sulfonamide, 'non-benzenesulfonamides', '#1f77b4', 2),  # blue
        (disynthon_has_sulfonamide, 'benzenesulfonamides', '#ff7f0e', 3),  # orange
    )
    for mask, label, color, zorder in series:
        ax.errorbar(
            x=R[mask],
            y=avg_preds[mask],
            xerr=[R[mask] - R_lb[mask], R_ub[mask] - R[mask]],
            label=label,
            color=color,
            zorder=zorder,
            **errorbar_style,
        )

    # y = x reference line across the range of the predictions (NaNs ignored)
    parity = np.linspace(np.nanmin(avg_preds), np.nanmax(avg_preds), 100)
    ax.plot(parity, parity, color='#2ca02c', label='parity',  # green
            linewidth=0.75, zorder=4)

    # A falsy legend_loc falls back to Matplotlib's default placement
    ax.legend(fontsize=7, loc=legend_loc or None)
    ax.tick_params(labelsize=8)
    if x_ub:
        ax.set_xlim([0, x_ub])

    ax.grid(zorder=1)
    ax.set_xlabel('calculated enrichment', fontsize=8)
    ax.set_ylabel('predicted enrichment', fontsize=8)
    fig.tight_layout()
    # Only save when a file name was given; an empty name would otherwise
    # produce a hidden file named ".png" in the plots folder.
    if out:
        fig.savefig(pathify(str(out)))

    plt.show()

In [None]:
def make_1D_histograms():
    """Overlaid 1D histograms of disynthon enrichments, split by the sulfonamide flag.

    Draws, saves (into the plots folder) and shows two figures:
      1. calculated enrichments ``R`` from ``R_ranges`` (axis label: maximum
         likelihood), 0.2-wide bins starting at 0;
      2. mean predicted enrichments ``avg_preds``, 0.075-wide bins starting at
         0, with the x axis limited to [0.5, 4.5].
    Values are clipped to [0, last bin edge] so outliers land in the end bins.
    Both histograms are density-normalised per class.

    Reads the notebook globals ``disynthon_bead_counts``, ``disynthon_exp_counts``,
    ``bead_tot``, ``exp_tot``, ``disynthon_has_sulfonamide``, ``avg_preds`` and
    ``pathify``.
    """
    def _overlay(values, bin_width, xlabel, filename, xlim=None):
        # One figure with both compound classes as semi-transparent histograms
        fig = plt.figure(figsize=(3, 1.5), dpi=300)
        edges = np.arange(0, 10, bin_width)
        classes = (
            (~disynthon_has_sulfonamide, 'non-benzenesulfonamides', 2, 0.7),
            (disynthon_has_sulfonamide, 'benzenesulfonamides', 3, 0.6),
        )
        for mask, label, zorder, alpha in classes:
            plt.hist(
                np.clip(values[mask], 0, edges[-1]),
                bins=edges,
                label=label,
                density=True,
                zorder=zorder,
                alpha=alpha,
            )
        plt.legend(fontsize=7)
        fig.canvas.draw()
        ax = plt.gca()
        ax.tick_params(labelsize=7)
        if xlim is not None:
            ax.set_xlim(xlim)
        ax.grid(zorder=1)
        ax.set_xlabel(xlabel, fontsize=8)
        ax.set_ylabel('probability density', fontsize=8)
        plt.tight_layout()
        plt.savefig(pathify(filename))
        plt.show()

    R, _R_lb, _R_ub = R_ranges(disynthon_bead_counts, bead_tot[0],
                               disynthon_exp_counts, exp_tot[0])

    ## First plot: calculated enrichments
    _overlay(
        R, 0.2, 'calculated enrichment (maximum likelihood)',
        'disynthon_1D_histogram_DD1S_CAIX_FP-FFNN_random_seed_0_calculated_enrichments.png',
    )

    ## Second plot: predicted enrichments
    _overlay(
        avg_preds, 0.075, 'predicted enrichment',
        'disynthon_1D_histogram_DD1S_CAIX_FP-FFNN_random_seed_0_predicted_enrichments.png',
        xlim=[0.5, 4.5],
    )

In [None]:
# DD1S CAIX compound table; this notebook uses its cycle1/cycle2/cycle3 and exp_tot/beads_tot columns
df_data = pd.read_csv(
    os.path.join(DELQSAR_ROOT, 'experiments', 'datasets', 'DD1S_CAIX_QSAR.csv')
)

In [None]:
# Turn off HDF5 file locking before opening the file (commonly needed on
# network/shared filesystems — kept from the original setup).
os.environ["HDF5_USE_FILE_LOCKING"] = 'FALSE'

# Read the whole 'all_fps' dataset into memory; the context manager closes
# the file even if reading fails.
with h5py.File(os.path.join(DELQSAR_ROOT, 'experiments', DD1S_FINGERPRINTS_FILENAME), 'r') as hf:
    x = np.array(hf['all_fps'])

# Feature width fed to the model (number of fingerprint columns)
INPUT_SIZE = x.shape[1]

In [None]:
def pathify(fname):
    """Return ``fname`` placed inside the CAIX_disynthon_plots output folder."""
    return os.path.join('CAIX_disynthon_plots', fname)

# Make sure the figure output folder exists (no-op when it already does)
os.makedirs('CAIX_disynthon_plots', exist_ok=True)

In [None]:
# Per-compound counts as (n_compounds, 1) int arrays; their column sums are
# 1-element arrays holding the library-wide totals.
exp_counts = df_data[['exp_tot']].to_numpy(dtype='int')
bead_counts = df_data[['beads_tot']].to_numpy(dtype='int')
exp_tot, bead_tot = exp_counts.sum(axis=0), bead_counts.sum(axis=0)

In [None]:
# CAIX random split model
BATCH_SIZE = 1024
LAYER_SIZES = [64, 64, 64]
DROPOUT = 0.1
model = models.MLP(INPUT_SIZE, [int(size) for size in LAYER_SIZES],
                    dropout=DROPOUT, torch_seed=SEED)
# map_location='cpu' lets weights saved from a GPU load on CPU-only machines;
# the model is moved to DEVICE below when a GPU is available.
model.load_state_dict(torch.load(CAIX_RANDOM_SPLIT_MODEL_PATH, map_location='cpu'))
# NOTE(review): DROPOUT > 0 — confirm predict_on_x switches the model to eval mode.
print(model)
if DEVICE:
    model = model.to(DEVICE)

In [None]:
# Member-compound row indices of every disynthon, given 8 / 114 / 119 building blocks in cycles 1 / 2 / 3
cpd_indices = get_cpd_indices(num_cyc1_BBs=8, num_cyc2_BBs=114, num_cyc3_BBs=119)
np.save('DD1S_disynthon_cpd_indices.npy', cpd_indices)

In [None]:
# cpd_indices = np.load('DD1S_disynthon_cpd_indices.npy', allow_pickle=True)

In [None]:
# A compound is flagged when its cycle-3 building block is 14 or 74
# (presumably the benzenesulfonamide BBs, per the plot labels — confirm).
cpd_is_sulfonamide = df_data['cycle3'].isin([14, 74]).to_numpy()

# A disynthon is flagged when any of its member compounds is. Iterate over
# cpd_indices itself rather than a hard-coded disynthon count, and treat each
# entry as 1-D so single-compound (0-d) entries do not break iteration.
disynthon_has_sulfonamide = np.array([
    bool(cpd_is_sulfonamide[np.atleast_1d(np.asarray(group, dtype=np.intp))].any())
    for group in cpd_indices
], dtype=bool)

print(int(disynthon_has_sulfonamide.sum()))
np.save('DD1S_disynthon_has_sulfonamide.npy', disynthon_has_sulfonamide)

In [None]:
# disynthon_has_sulfonamide = np.load('DD1S_disynthon_has_sulfonamide.npy')

In [None]:
# Summed exp_tot / beads_tot counts per disynthon, in the same order as cpd_indices
disynthon_exp_counts, disynthon_bead_counts = get_counts_for_disynthons(
    8, 114, 119, exp_col_name='exp_tot', beads_col_name='beads_tot',
)
np.save('DD1S_CAIX_disynthon_exp_counts.npy', disynthon_exp_counts)
np.save('DD1S_CAIX_disynthon_bead_counts.npy', disynthon_bead_counts)

In [None]:
# disynthon_exp_counts = np.load('DD1S_CAIX_disynthon_exp_counts.npy')
# disynthon_bead_counts = np.load('DD1S_CAIX_disynthon_bead_counts.npy')

In [None]:
# Mean model prediction over the member compounds of each disynthon
avg_preds = get_avg_preds(model=model)
np.save('DD1S_CAIX_disynthon_avg_preds.npy', avg_preds)

In [None]:
# avg_preds = np.load('DD1S_CAIX_disynthon_avg_preds.npy')

In [None]:
# Parity plot coloured by the benzenesulfonamide flag
draw_predicted_enrichments_vs_true(
    model,
    out='DD1S_CAIX_disynthon_parity_plot_with_sulf_coloring',
    legend_loc='center right',
)

In [None]:
# Draw and save the calculated- and predicted-enrichment histograms
make_1D_histograms()