In [None]:
# Enable IPython's autoreload extension in mode 2, so edits to imported modules
# (e.g. the local del_qsar package) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os, sys
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
import h5py
import copy

In [None]:
# Repository root, two levels above the directory the kernel was started in.
DELQSAR_ROOT = f"{os.getcwd()}/../../"
# Expose the directory above the repository root on the import path
# (presumably so that `del_qsar` is importable as a package -- confirm layout).
sys.path.extend([f"{DELQSAR_ROOT}/../"])

In [None]:
FINGERPRINTS_FILENAME = 'x_triazine_2048_bits_all_fps.h5' # should be in the experiments folder


def _fp_ffnn_checkpoint(dataset_dir, checkpoint_name):
    """Return the path of an FP-FFNN checkpoint under experiments/models/<dataset_dir>."""
    return os.path.join(DELQSAR_ROOT, 'experiments', 'models', dataset_dir,
                        'FP-FFNN', checkpoint_name)


# One saved FP-FFNN (seed 0) per target (sEH, SIRT2) and per split (random, cycle 1+2+3).
sEH_RANDOM_SPLIT_MODEL_PATH = _fp_ffnn_checkpoint('triazine_sEH', 'random_seed_0.torch')
sEH_CYCLE123_SPLIT_MODEL_PATH = _fp_ffnn_checkpoint('triazine_sEH', 'cycle123_seed_0.torch')
SIRT2_RANDOM_SPLIT_MODEL_PATH = _fp_ffnn_checkpoint('triazine_SIRT2', 'random_seed_0.torch')
SIRT2_CYCLE123_SPLIT_MODEL_PATH = _fp_ffnn_checkpoint('triazine_SIRT2', 'cycle123_seed_0.torch')

In [None]:
from del_qsar import models, featurizers, splitters
from del_qsar.enrichments import R_from_z, R_ranges

# The figure output directory and the `pathify` helper are set up once, further down,
# from DELQSAR_ROOT. An earlier cwd-relative copy of both used to live here; it was
# shadowed before any use, so it has been dropped to leave a single definition.

# Log file passed to the cycle-based splitter; it sits in the figure output directory.
LOG_FILE = os.path.join(DELQSAR_ROOT, 'experiments', 'visualizations',
                        'triazine_2D_histograms', 'triazine_2D_histograms.log')

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as colors

# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and newer
# releases no longer ship the old names; use whichever this installation provides.
_PAPER_STYLE = ('seaborn-v0_8-paper' if 'seaborn-v0_8-paper' in plt.style.available
                else 'seaborn-paper')
plt.style.use(_PAPER_STYLE)

matplotlib.rc('font', family='sans-serif')
# NOTE(review): this sets the *serif* font list while the active family is sans-serif,
# so Arial is presumably not the font that ends up rendered -- confirm whether
# `sans-serif` was intended. Left unchanged so existing figures keep their look.
matplotlib.rc('font', serif='Arial')
# A real boolean instead of the string 'false' (which matplotlib also parses as False).
matplotlib.rc('text', usetex=False)

In [None]:
# Directory that receives every figure written by this notebook.
HISTOGRAM_DIR = os.path.join(DELQSAR_ROOT, 'experiments', 'visualizations',
                             'triazine_2D_histograms')
# makedirs(exist_ok=True) creates missing parent folders as well and is a no-op when
# the directory is already there, so the cell is safe to re-run.
os.makedirs(HISTOGRAM_DIR, exist_ok=True)


def pathify(fname):
    """Return the path of `fname` inside the figure output directory."""
    return os.path.join(HISTOGRAM_DIR, fname)

In [None]:
# Per-compound table with the raw read-count columns used below.
df_data = pd.read_csv(os.path.join(DELQSAR_ROOT, 'experiments', 'datasets',
                                   'triazine_lib_sEH_SIRT2_QSAR.csv'))

# Turn off HDF5 file locking before the fingerprint file is opened.
os.environ["HDF5_USE_FILE_LOCKING"] = 'FALSE'
# The context manager closes the file even if reading the dataset raises.
with h5py.File(os.path.join(DELQSAR_ROOT, 'experiments', FINGERPRINTS_FILENAME), 'r') as hf:
    # np.array copies the dataset into memory so `x` stays usable after the file closes.
    x = np.array(hf['all_fps'])
# Number of fingerprint features per row; used as the input width of every MLP below.
INPUT_SIZE = x.shape[1]

In [None]:
# Batch size handed to model.predict_on_x when the histograms are computed.
BATCH_SIZE = 1024
# Single seed shared by torch, the data splitters and the models' torch_seed argument.
SEED = 0
torch.manual_seed(SEED)

In [None]:
# Random split of the data set; only the test slice is used for the figures below.
# The slices are used later as positional row indices into both `x` and `df_data`.
splitter_rand = splitters.RandomSplitter()
train_slice_rand, valid_slice_rand, test_slice_rand  = splitter_rand(x, df_data, seed=SEED)

In [None]:
# Hyper-parameters of the sEH / random-split FP-FFNN. LAYER_SIZES and DROPOUT are
# re-assigned by each model cell, so they only describe the most recently built model.
LAYER_SIZES = [256, 128, 64]
DROPOUT = 0.4
model_sEH_rand = models.MLP(INPUT_SIZE, [int(size) for size in LAYER_SIZES],
                    dropout=DROPOUT, torch_seed=SEED)
# map_location='cpu' lets a checkpoint written from a GPU be restored on a CPU-only
# machine; the model is moved to the GPU in a later cell when CUDA is available.
model_sEH_rand.load_state_dict(torch.load(sEH_RANDOM_SPLIT_MODEL_PATH, map_location='cpu'))
print(model_sEH_rand)

In [None]:
# Hyper-parameters of the SIRT2 / random-split FP-FFNN. LAYER_SIZES and DROPOUT are
# re-assigned by each model cell, so they only describe the most recently built model.
LAYER_SIZES = [256, 128, 64]
DROPOUT = 0.1
model_SIRT2_rand = models.MLP(INPUT_SIZE, [int(size) for size in LAYER_SIZES],
                    dropout=DROPOUT, torch_seed=SEED)
# map_location='cpu' lets a checkpoint written from a GPU be restored on a CPU-only
# machine; the model is moved to the GPU in a later cell when CUDA is available.
model_SIRT2_rand.load_state_dict(torch.load(SIRT2_RANDOM_SPLIT_MODEL_PATH, map_location='cpu'))
print(model_SIRT2_rand)

In [None]:
# Split built from the 'cycle1', 'cycle2' and 'cycle3' columns; LOG_FILE is passed as the
# splitter's second argument (presumably where it writes its log -- TODO confirm, and
# confirm that the log file's directory must already exist at this point).
splitter_c123 = splitters.ThreeCycleSplitter(['cycle1','cycle2','cycle3'], LOG_FILE)
train_slice_c123, valid_slice_c123, test_slice_c123  = splitter_c123(x, df_data, seed=SEED)

In [None]:
# Hyper-parameters of the sEH / cycle-1+2+3-split FP-FFNN. LAYER_SIZES and DROPOUT are
# re-assigned by each model cell, so they only describe the most recently built model.
LAYER_SIZES = [1024, 256, 64]
DROPOUT = 0.45
model_sEH_c123 = models.MLP(INPUT_SIZE, [int(size) for size in LAYER_SIZES],
                    dropout=DROPOUT, torch_seed=SEED)
# map_location='cpu' lets a checkpoint written from a GPU be restored on a CPU-only
# machine; the model is moved to the GPU in a later cell when CUDA is available.
model_sEH_c123.load_state_dict(torch.load(sEH_CYCLE123_SPLIT_MODEL_PATH, map_location='cpu'))
print(model_sEH_c123)

In [None]:
# Hyper-parameters of the SIRT2 / cycle-1+2+3-split FP-FFNN. LAYER_SIZES and DROPOUT are
# re-assigned by each model cell, so they only describe the most recently built model.
LAYER_SIZES = [1024, 256, 64]
DROPOUT = 0.1
model_SIRT2_c123 = models.MLP(INPUT_SIZE, [int(size) for size in LAYER_SIZES],
                    dropout=DROPOUT, torch_seed=SEED)
# map_location='cpu' lets a checkpoint written from a GPU be restored on a CPU-only
# machine; the model is moved to the GPU in a later cell when CUDA is available.
model_SIRT2_c123.load_state_dict(torch.load(SIRT2_CYCLE123_SPLIT_MODEL_PATH, map_location='cpu'))
print(model_SIRT2_c123)

In [None]:
# Run on the first CUDA device when one is present; DEVICE stays None otherwise and
# the models are left where they were created.
DEVICE = 'cuda:0' if torch.cuda.is_available() else None
if DEVICE is not None:
    model_sEH_rand = model_sEH_rand.to(DEVICE)
    model_sEH_c123 = model_sEH_c123.to(DEVICE)
    model_SIRT2_rand = model_SIRT2_rand.to(DEVICE)
    model_SIRT2_c123 = model_SIRT2_c123.to(DEVICE)

# sEH FP-FFNN 

In [None]:
# Raw read counts for the sEH selection and for the beads-only control. The
# double-bracket selection keeps each one as a single-column 2-D integer array.
# NOTE: the same four names are re-bound to the SIRT2 data further down the notebook.
exp_counts = np.asarray(df_data[['sEH [strep]_tot']], dtype='int')
bead_counts = np.asarray(df_data[['beads-linker-only [strep]_tot']], dtype='int')
# Library-wide totals (sum over all compounds), one entry per column.
exp_tot = exp_counts.sum(axis=0)
bead_tot = bead_counts.sum(axis=0)

In [None]:
def make_2D_histograms_sEH(eval_slice, model, split, zoomIn, pruneLowRawCounts=False):
    """Plot 2D histograms of predicted vs. calculated sEH enrichment.

    Three figures are drawn and shown in turn. The x-axis is the calculated
    enrichment returned by ``R_ranges`` (maximum likelihood, then lower bound, then
    upper bound); the y-axis is always the output of ``model.predict_on_x``. Values
    outside the bin range are clipped into the outermost bins.

    Args:
        eval_slice: positional row indices into ``x`` / ``df_data`` to evaluate.
        model: object exposing ``predict_on_x(x, batch_size=..., device=...)``.
        split: label of the data split; inserted verbatim into the output file names.
        zoomIn: True for the narrow bin ranges, False for the full ranges.
        pruneLowRawCounts: when True, keep only compounds whose sEH count plus
            beads-only count is at least 3.

    Raises:
        ValueError: if pruning leaves no compound to plot.

    Figures are saved through ``pathify`` for every combination except
    ``zoomIn=False`` with ``pruneLowRawCounts=True``, which is only displayed.
    """
    exp_col = 'sEH [strep]_tot'
    bead_col = 'beads-linker-only [strep]_tot'
    min_total_count = 3  # pruning threshold on (target + beads) raw counts

    # Read the counts straight from df_data instead of the module-level
    # exp_counts / bead_counts arrays: those names are re-bound to the SIRT2 columns
    # further down the notebook, so relying on them would silently mix the two
    # targets whenever cells are re-run out of order.
    exp_all = np.array(df_data[exp_col], dtype='int')
    bead_all = np.array(df_data[bead_col], dtype='int')
    exp_sel, bead_sel = exp_all[eval_slice], bead_all[eval_slice]

    R, R_lb, R_ub = R_ranges(bead_sel, bead_all.sum(), exp_sel, exp_all.sum())
    predicted = model.predict_on_x(
        x[eval_slice, :], batch_size=BATCH_SIZE, device=DEVICE,
    )

    if pruneLowRawCounts:
        # One boolean mask replaces the per-row DataFrame lookups.
        keep = (exp_sel + bead_sel) >= min_total_count
        if not keep.any():
            raise ValueError(
                f'no compound in eval_slice has at least {min_total_count} raw counts'
            )
        R, R_lb, R_ub, predicted = (
            np.asarray(values)[keep] for values in (R, R_lb, R_ub, predicted)
        )
        print(predicted.min(), predicted.max())

    # plt.get_cmap exists in every matplotlib release (matplotlib.cm.get_cmap was
    # removed in 3.9); copy it so set_bad does not touch the shared colormap.
    my_cmap = copy.copy(plt.get_cmap('viridis'))
    my_cmap.set_bad("#CFCFCF") # color zero frequency values as gray

    # Bin edges for the predicted enrichment are shared by the three figures.
    y_edges = np.arange(0, 14.001, 0.21) if zoomIn else np.arange(0, 523, 7.845)

    panels = (
        # (file tag, x-axis caption, calculated values, zoomed x range, full x range)
        ('maximum_likelihood', 'maximum likelihood', R, (0, 10.001, 0.15), (0, 1543, 23.145)),
        ('LB', 'lower bound', R_lb, (0, 4.001, 0.06), (0, 222, 1.665)),
        ('UB', 'upper bound', R_ub, (0, 10.001, 0.15), (0, 33955, 509.325)),
    )
    for tag, caption, calculated, zoom_range, full_range in panels:
        x_edges = np.arange(*(zoom_range if zoomIn else full_range))

        fig = plt.figure(figsize=(3.33, 2.82), dpi=300)
        plt.hist2d(
            np.clip(calculated, 0, x_edges[-1]),
            np.clip(predicted, 0, y_edges[-1]),
            bins=[x_edges, y_edges],
            density=False,
            norm=colors.LogNorm(),
            cmap=my_cmap,
        )
        cb = plt.colorbar()
        cb.ax.tick_params(labelsize=7, length=3, pad=0.5)
        cb.ax.set_ylabel('frequency', rotation=270, fontsize=8, labelpad=8)
        fig.canvas.draw()
        ax = plt.gca()
        ax.tick_params(labelsize=9)
        ax.set_xlabel(f'calculated enrichment\n({caption})', fontsize=9)
        ax.set_ylabel('predicted enrichment', fontsize=9)
        plt.tight_layout()

        if zoomIn:
            suffix = '_low_counts_pruned' if pruneLowRawCounts else ''
            plt.savefig(pathify(
                f'2D_histogram_{tag}_zoomed_in_sEH_FP-FFNN_{split}_seed_0{suffix}.png'))
        elif not pruneLowRawCounts:
            plt.savefig(pathify(f'2D_histogram_{tag}_full_sEH_FP-FFNN_{split}_seed_0.png'))
        # full range + pruning: shown but never written to disk (original behaviour)
        plt.show()

In [None]:
# sEH model trained on the random split, evaluated on that split's test set:
# zoomed-in bin ranges, compounds with fewer than 3 total raw counts removed.
make_2D_histograms_sEH(test_slice_rand, model_sEH_rand, 'random', zoomIn=True, pruneLowRawCounts=True)

In [None]:
# sEH model trained on the cycle-1+2+3 split, evaluated on that split's test set:
# zoomed-in bin ranges, compounds with fewer than 3 total raw counts removed.
# The split label (spaces and '+' included) ends up verbatim in the PNG file names.
make_2D_histograms_sEH(test_slice_c123, model_sEH_c123, 'cycle 1+2+3', zoomIn=True, pruneLowRawCounts=True)

# SIRT2 FP-FFNN

In [None]:
# Raw read counts for the SIRT2 selection and for the beads-only control. The
# double-bracket selection keeps each one as a single-column 2-D integer array.
# NOTE: this re-binds the names that held the sEH data earlier in the notebook.
exp_counts = np.asarray(df_data[['SIRT2 [strep]_tot']], dtype='int')
bead_counts = np.asarray(df_data[['beads-linker-only [strep]_tot']], dtype='int')
# Library-wide totals (sum over all compounds), one entry per column.
exp_tot = exp_counts.sum(axis=0)
bead_tot = bead_counts.sum(axis=0)

In [None]:
# NOTE(review): neither list is read or appended to anywhere in the visible notebook;
# they look like leftovers -- confirm nothing else relies on them before removing.
SIRT2_random_indices = []
SIRT2_c123_indices = []
def make_2D_histograms_SIRT2(eval_slice, model, split, zoomIn, pruneLowRawCounts=False):
    """Plot 2D histograms of predicted vs. calculated SIRT2 enrichment.

    Three figures are drawn and shown in turn. The x-axis is the calculated
    enrichment returned by ``R_ranges`` (maximum likelihood, then lower bound, then
    upper bound); the y-axis is always the output of ``model.predict_on_x``. Values
    outside the bin range are clipped into the outermost bins.

    Args:
        eval_slice: positional row indices into ``x`` / ``df_data`` to evaluate.
        model: object exposing ``predict_on_x(x, batch_size=..., device=...)``.
        split: label of the data split; inserted verbatim into the output file names.
        zoomIn: True for the narrow bin ranges, False for the full ranges.
        pruneLowRawCounts: when True, keep only compounds whose SIRT2 count plus
            beads-only count is at least 3.

    Raises:
        ValueError: if pruning leaves no compound to plot.

    Figures are saved through ``pathify`` for every combination except
    ``zoomIn=False`` with ``pruneLowRawCounts=True``, which is only displayed.
    """
    exp_col = 'SIRT2 [strep]_tot'
    bead_col = 'beads-linker-only [strep]_tot'
    min_total_count = 3  # pruning threshold on (target + beads) raw counts

    # Read the counts straight from df_data instead of the module-level
    # exp_counts / bead_counts arrays: earlier in the notebook those names hold the
    # sEH columns, so relying on them would silently mix the two targets whenever
    # cells are re-run out of order.
    exp_all = np.array(df_data[exp_col], dtype='int')
    bead_all = np.array(df_data[bead_col], dtype='int')
    exp_sel, bead_sel = exp_all[eval_slice], bead_all[eval_slice]

    R, R_lb, R_ub = R_ranges(bead_sel, bead_all.sum(), exp_sel, exp_all.sum())
    predicted = model.predict_on_x(
        x[eval_slice, :], batch_size=BATCH_SIZE, device=DEVICE,
    )

    if pruneLowRawCounts:
        # One boolean mask replaces the per-row DataFrame lookups.
        keep = (exp_sel + bead_sel) >= min_total_count
        if not keep.any():
            raise ValueError(
                f'no compound in eval_slice has at least {min_total_count} raw counts'
            )
        R, R_lb, R_ub, predicted = (
            np.asarray(values)[keep] for values in (R, R_lb, R_ub, predicted)
        )
        print(predicted.min(), predicted.max())

    # plt.get_cmap exists in every matplotlib release (matplotlib.cm.get_cmap was
    # removed in 3.9); copy it so set_bad does not touch the shared colormap.
    my_cmap = copy.copy(plt.get_cmap('viridis'))
    my_cmap.set_bad("#CFCFCF") # color zero frequency values as gray

    # Bin edges for the predicted enrichment are shared by the three figures.
    y_edges = np.arange(0, 20, 0.3) if zoomIn else np.arange(0, 20, 0.6)

    panels = (
        # (file tag, x-axis caption, calculated values, zoomed x range, full x range)
        ('maximum_likelihood', 'maximum likelihood', R, (0, 10.001, 0.15), (0, 146, 4.38)),
        ('LB', 'lower bound', R_lb, (0, 4.001, 0.06), (0, 20.04, 0.6012)),
        ('UB', 'upper bound', R_ub, (0, 10.001, 0.15), (0, 1841, 55.23)),
    )
    for tag, caption, calculated, zoom_range, full_range in panels:
        x_edges = np.arange(*(zoom_range if zoomIn else full_range))

        fig = plt.figure(figsize=(3.33, 2.82), dpi=300)
        plt.hist2d(
            np.clip(calculated, 0, x_edges[-1]),
            np.clip(predicted, 0, y_edges[-1]),
            bins=[x_edges, y_edges],
            density=False,
            norm=colors.LogNorm(),
            cmap=my_cmap,
        )
        cb = plt.colorbar()
        cb.ax.tick_params(labelsize=7, length=3, pad=0.5)
        cb.ax.set_ylabel('frequency', rotation=270, fontsize=8, labelpad=8)
        fig.canvas.draw()
        ax = plt.gca()
        ax.tick_params(labelsize=9)
        ax.set_xlabel(f'calculated enrichment\n({caption})', fontsize=9)
        ax.set_ylabel('predicted enrichment', fontsize=9)
        ax.set_yticks([0, 5, 10, 15])
        plt.tight_layout()

        if zoomIn:
            suffix = '_low_counts_pruned' if pruneLowRawCounts else ''
            plt.savefig(pathify(
                f'2D_histogram_{tag}_zoomed_in_SIRT2_FP-FFNN_{split}_seed_0{suffix}.png'))
        elif not pruneLowRawCounts:
            plt.savefig(pathify(f'2D_histogram_{tag}_full_SIRT2_FP-FFNN_{split}_seed_0.png'))
        # full range + pruning: shown but never written to disk (original behaviour)
        plt.show()

In [None]:
# SIRT2 model trained on the random split, evaluated on that split's test set:
# zoomed-in bin ranges, compounds with fewer than 3 total raw counts removed.
make_2D_histograms_SIRT2(test_slice_rand, model_SIRT2_rand, 'random', zoomIn=True, pruneLowRawCounts=True)

In [None]:
# SIRT2 model trained on the cycle-1+2+3 split, evaluated on that split's test set:
# zoomed-in bin ranges, compounds with fewer than 3 total raw counts removed.
# The split label (spaces and '+' included) ends up verbatim in the PNG file names.
make_2D_histograms_SIRT2(test_slice_c123, model_SIRT2_c123, 'cycle 1+2+3', zoomIn=True, pruneLowRawCounts=True)