In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os, sys
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
import h5py
import copy

In [None]:
# Repository root, expressed relative to the kernel's working directory
# (two directory levels above wherever the notebook is started).
DELQSAR_ROOT = os.getcwd() + '/../../'
# Put the parent of the repository root on the import path -- presumably so that
# the `del_qsar` package imported further down can be found (verify the checkout layout).
sys.path += [DELQSAR_ROOT + '/../']

In [None]:
# HDF5 file holding the precomputed fingerprints (its 'all_fps' dataset is read below).
FINGERPRINTS_FILENAME = 'x_DD1S_CAIX_2048_bits_all_fps.h5' # should be in the experiments folder

# Saved state dict of the FP-FFNN model for the random split (seed 0).
RANDOM_SPLIT_MODEL_PATH = os.path.join(DELQSAR_ROOT, 'experiments', 'models', 
                                       'DD1S_CAIX', 'FP-FFNN','random_seed_0.torch')

# Saved state dict of the FP-FFNN model for the cycle 1+2 split (seed 0).
CYCLE12_SPLIT_MODEL_PATH = os.path.join(DELQSAR_ROOT, 'experiments', 'models', 
                                       'DD1S_CAIX', 'FP-FFNN','cycle12_seed_0.torch')

In [None]:
from del_qsar import models, featurizers, splitters
from del_qsar.enrichments import R_from_z, R_ranges

# Make sure the figure output directory exists. exist_ok=True replaces the
# check-then-create pattern, so re-running the cell (or a concurrent run) cannot fail.
os.makedirs('DD1S_CAIX_histograms_parity_plots', exist_ok=True)
    
def pathify(fname):
    """Return the path of `fname` inside the figure output directory."""
    output_dir = 'DD1S_CAIX_histograms_parity_plots'
    return os.path.join(output_dir, fname)

# Log file path; it is handed to the cycle-based splitter constructed further down.
LOG_FILE = os.path.join(DELQSAR_ROOT, 'experiments', 'visualizations', 
                        'DD1S_CAIX_histograms_parity_plots', 'DD1S_CAIX_histograms_parity_plots.log')

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.patches as patches

# The bundled seaborn style sheets were renamed with a 'seaborn-v0_8-' prefix in
# matplotlib 3.6 and the old names were later removed; use whichever one this
# matplotlib installation provides so the cell runs on old and new versions alike.
_PAPER_STYLE = 'seaborn-v0_8-paper' if 'seaborn-v0_8-paper' in plt.style.available else 'seaborn-paper'
plt.style.use(_PAPER_STYLE)

matplotlib.rc('font', family='sans-serif') 
# NOTE(review): this sets the *serif* font list while the active family is sans-serif,
# so it probably has no visible effect; `**{'sans-serif': 'Arial'}` may have been intended.
matplotlib.rc('font', serif='Arial') 
# Boolean instead of the string 'false' (which matplotlib only accepted via string coercion).
matplotlib.rc('text', usetex=False)

In [None]:
# Load the DD1S CAIX dataset table; the cells below read its 'exp_tot', 'beads_tot'
# and 'cycle3' columns (the splitters may use more -- not visible here).
df_data = pd.read_csv(os.path.join(DELQSAR_ROOT, 'experiments', 'datasets', 'DD1S_CAIX_QSAR.csv'))

In [None]:
# Raw counts as integer column matrices of shape (n_rows, 1): selecting with a
# one-element column list keeps the arrays two-dimensional.
exp_counts = df_data[['exp_tot']].to_numpy().astype('int')
bead_counts = df_data[['beads_tot']].to_numpy().astype('int')
# Column totals (one entry per count column)
exp_tot = exp_counts.sum(axis=0)
bead_tot = bead_counts.sum(axis=0)

In [None]:
# Turn off HDF5 file locking before opening the fingerprint file.
os.environ["HDF5_USE_FILE_LOCKING"] = 'FALSE'
# The context manager closes the file even if reading the dataset raises.
with h5py.File(os.path.join(DELQSAR_ROOT, 'experiments', FINGERPRINTS_FILENAME), 'r') as hf:
    x = np.array(hf['all_fps'])  # copy the whole 'all_fps' dataset into memory
# Number of columns of the fingerprint matrix; used as the input width of the models below.
INPUT_SIZE = x.shape[1]

In [None]:
# Single seed shared by the torch global RNG, the data splitters and the model constructors below.
SEED = 0
torch.manual_seed(SEED)

In [None]:
# Random split, seeded with SEED. The returned slices are used further down
# to index rows of `x`, `df_data` and the count arrays.
splitter_rand = splitters.RandomSplitter()
train_slice_rand, valid_slice_rand, test_slice_rand  = splitter_rand(x, df_data, seed=SEED)

In [None]:
# random split model
# NOTE: BATCH_SIZE / LAYER_SIZES / DROPOUT are notebook-level names that a later
# cell reassigns; BATCH_SIZE is also read by the plotting functions.
BATCH_SIZE = 1024
LAYER_SIZES = [64, 64, 64]
DROPOUT = 0.1
model_rand = models.MLP(INPUT_SIZE, [int(size) for size in LAYER_SIZES],
                        dropout=DROPOUT, torch_seed=SEED)
# map_location='cpu' lets a checkpoint that was saved from a GPU load on a CPU-only
# machine; the model is moved to the GPU in a later cell when one is available.
model_rand.load_state_dict(torch.load(RANDOM_SPLIT_MODEL_PATH, map_location='cpu'))
print(model_rand)

In [None]:
# Split driven by the 'cycle1' and 'cycle2' columns; the splitter is also given LOG_FILE.
# NOTE(review): exact grouping semantics live in splitters.TwoCycleSplitter -- not visible here.
splitter_c12 = splitters.TwoCycleSplitter(['cycle1','cycle2'], LOG_FILE)
train_slice_c12, valid_slice_c12, test_slice_c12 = splitter_c12(x, df_data, seed=SEED)

In [None]:
# cycle 1+2 split model
# NOTE: this reassigns the notebook-level BATCH_SIZE / LAYER_SIZES / DROPOUT set for
# the random-split model; BATCH_SIZE (unchanged value) is read by the plotting functions.
BATCH_SIZE = 1024
LAYER_SIZES = [16, 4]
DROPOUT = 0.45
model_c12 = models.MLP(INPUT_SIZE, [int(size) for size in LAYER_SIZES],
                        dropout=DROPOUT, torch_seed=SEED)
# map_location='cpu' lets a checkpoint that was saved from a GPU load on a CPU-only
# machine; the model is moved to the GPU in a later cell when one is available.
model_c12.load_state_dict(torch.load(CYCLE12_SPLIT_MODEL_PATH, map_location='cpu'))
print(model_c12)

In [None]:
# Use the first CUDA device when one is available; otherwise DEVICE stays None
# and both models remain where they were created.
DEVICE = 'cuda:0' if torch.cuda.is_available() else None
if DEVICE is not None:
    model_rand = model_rand.to(DEVICE)
    model_c12 = model_c12.to(DEVICE)

# 2D histograms

In [None]:
def _draw_2D_histogram_CAIX(calculated, predicted, x_edges, y_edges, cmap, bound_label, save_name):
    """Draw, save and show one 2D histogram of predicted vs. calculated enrichment.

    Args:
        calculated: calculated enrichments (x axis); clipped to [0, x_edges[-1]].
        predicted: predicted enrichments (y axis); clipped to [0.5, y_edges[-1]].
        x_edges: bin edges along the x axis.
        y_edges: bin edges along the y axis.
        cmap: colormap for the log-scaled bin frequencies.
        bound_label: text shown in parentheses on the second line of the x-axis label.
        save_name: PNG file name; resolved through `pathify`.
    """
    fig = plt.figure(figsize=(3.33, 2.82), dpi=300)
    # Clipping piles out-of-range values into the outermost bins instead of dropping them.
    plt.hist2d(
        np.clip(calculated, 0, x_edges[-1]),
        np.clip(predicted, 0.5, y_edges[-1]),
        bins=[x_edges, y_edges],
        density=False,
        norm=colors.LogNorm(),
        cmap=cmap,
    )
    cb = plt.colorbar()
    cb.ax.tick_params(labelsize=7, length=3, pad=0.5)
    cb.ax.set_ylabel('frequency', rotation=270, fontsize=8, labelpad=8)
    fig.canvas.draw()
    ax = plt.gca()
    ax.tick_params(labelsize=9)
    ax.set_xlabel(f'calculated enrichment\n({bound_label})', fontsize=9)
    ax.set_ylabel('predicted enrichment', fontsize=9)
    ax.set_yticks([1, 2, 3, 4])
    plt.tight_layout()
    plt.savefig(pathify(save_name))
    plt.show()


def make_2D_histograms_CAIX_enrichments(eval_slice, model, split, pruneLowRawCounts=False):
    """Plot 2D histograms of predicted vs. calculated enrichment for one data slice.

    Three figures are drawn, saved and shown: predictions against the
    maximum-likelihood enrichment, its lower bound, and its upper bound
    (the three values returned by `R_ranges`).

    Args:
        eval_slice: row indices to evaluate (index into `x`, `df_data` and the count arrays).
        model: model exposing `predict_on_x`; assumed to return a 1-D array aligned
            with `eval_slice` -- TODO confirm against del_qsar.models.
        split: split name, embedded verbatim in the output file names.
        pruneLowRawCounts: if True, keep only rows with `exp_tot + beads_tot >= 3`
            and append `_low_counts_pruned` to the file names. If False, all rows
            are plotted (previously this case silently plotted nothing).
    """
    R, R_lb, R_ub = R_ranges(bead_counts[eval_slice, 0], bead_tot[0], exp_counts[eval_slice, 0], exp_tot[0])
    test_enrichments = model.predict_on_x(
        x[eval_slice, :], batch_size=BATCH_SIZE, device=DEVICE,
    )

    name_suffix = ''
    if pruneLowRawCounts:
        # Vectorised form of the per-row check: total raw counts of at least 3.
        rows = df_data.iloc[eval_slice]
        keep = (rows['exp_tot'] + rows['beads_tot']).to_numpy() >= 3
        R, R_lb, R_ub = R[keep], R_lb[keep], R_ub[keep]
        test_enrichments = test_enrichments[keep]
        name_suffix = '_low_counts_pruned'

    # plt.get_cmap is available across matplotlib versions (matplotlib.cm.get_cmap was removed).
    my_cmap = copy.copy(plt.get_cmap('viridis'))
    my_cmap.set_bad("#CFCFCF") # color zero frequency values as gray

    y_edges = np.arange(0.5, 4.501, 0.06)
    # (calculated values, x bin edges, x-axis label text, file-name tag)
    panels = [
        (R, np.arange(0, 6.001, 0.09), 'maximum likelihood', 'maximum_likelihood'),
        (R_lb, np.arange(0, 3.501, 0.0525), 'lower bound', 'LB'),
        (R_ub, np.arange(0, 10.001, 0.15), 'upper bound', 'UB'),
    ]
    for calculated, x_edges, bound_label, tag in panels:
        _draw_2D_histogram_CAIX(
            calculated, test_enrichments, x_edges, y_edges, my_cmap, bound_label,
            f'2D_histogram_{tag}_DD1S_CAIX_FP-FFNN_{split}_seed_0{name_suffix}.png',
        )

In [None]:
make_2D_histograms_CAIX_enrichments(test_slice_rand, model_rand, 'random', pruneLowRawCounts=True)

In [None]:
make_2D_histograms_CAIX_enrichments(test_slice_c12, model_c12, 'cycle 1+2', pruneLowRawCounts=True)

# 1D histograms

In [None]:
def _draw_1D_histogram_pair_CAIX(values, has_sulfonamide, bins, xlabel, save_name, xlim=None):
    """Draw, save and show overlaid density histograms for the two compound groups.

    Args:
        values: enrichment values, aligned with `has_sulfonamide`.
        has_sulfonamide: boolean mask selecting the benzenesulfonamide group.
        bins: bin edges; values are clipped to [0, bins[-1]] before binning.
        xlabel: x-axis label.
        save_name: PNG file name; resolved through `pathify`.
        xlim: optional x-axis limits; left untouched when None.
    """
    fig = plt.figure(figsize=(3, 1.5), dpi=300)
    # (group values, legend label, zorder, alpha)
    groups = [
        (values[~has_sulfonamide], 'compounds without\nbenzenesulfonamide', 2, 0.7),
        (values[has_sulfonamide], 'compounds with\nbenzenesulfonamide', 3, 0.6),
    ]
    for group_values, label, zorder, alpha in groups:
        plt.hist(
            np.clip(group_values, 0, bins[-1]),
            bins=bins,
            label=label,
            density=True,
            zorder=zorder,
            alpha=alpha,
        )
    plt.legend(fontsize=7)
    fig.canvas.draw()
    ax = plt.gca()
    ax.tick_params(labelsize=7)
    if xlim is not None:
        ax.set_xlim(xlim)
    ax.grid(zorder=1)
    ax.set_xlabel(xlabel, fontsize=8)
    ax.set_ylabel('probability density', fontsize=8)
    plt.tight_layout()
    plt.savefig(pathify(save_name))
    plt.show()


def make_1D_histograms_CAIX_enrichments(eval_slice, model=None, split='random'):
    """Plot 1D density histograms of calculated and predicted enrichment.

    Each figure overlays two groups: rows whose 'cycle3' value is 14 or 74
    (labelled as containing benzenesulfonamide) and all other rows.

    Args:
        eval_slice: row indices to evaluate (index into `x`, `df_data` and the count arrays).
        model: model exposing `predict_on_x`; defaults to the notebook-level `model_rand`.
        split: split name embedded in the output file names (default 'random',
            which reproduces the original file names).
    """
    if model is None:
        model = model_rand

    # Only the maximum-likelihood value is plotted; the bounds are not needed here.
    R, _, _ = R_ranges(bead_counts[eval_slice, 0], bead_tot[0], exp_counts[eval_slice, 0], exp_tot[0])
    # Vectorised membership test instead of one DataFrame row lookup per index.
    has_sulfonamide = df_data['cycle3'].iloc[eval_slice].isin([14, 74]).to_numpy()

    ## First plot: calculated enrichments
    _draw_1D_histogram_pair_CAIX(
        R, has_sulfonamide, np.arange(0, 10, 0.3),
        'calculated enrichment (maximum likelihood)',
        f'1D_histogram_DD1S_CAIX_FP-FFNN_{split}_seed_0_calculated_enrichments.png',
    )

    ## Second plot: predicted enrichments
    test_enrichments = model.predict_on_x(
        x[eval_slice, :], batch_size=BATCH_SIZE, device=DEVICE,
    )
    _draw_1D_histogram_pair_CAIX(
        test_enrichments, has_sulfonamide, np.arange(0, 10, 0.1),
        'predicted enrichment',
        f'1D_histogram_DD1S_CAIX_FP-FFNN_{split}_seed_0_predicted_enrichments.png',
        xlim=[0.5, 4.5],
    )

In [None]:
make_1D_histograms_CAIX_enrichments(test_slice_rand)

# Parity scatter plot

In [None]:
def _draw_parity_layers_CAIX(ax, R, R_lb, R_ub, test_enrichments, has_sulfonamide):
    """Draw both compound groups (with horizontal error bars) and the parity line on `ax`."""
    # (row mask, legend label, marker color, zorder)
    groups = [
        (~has_sulfonamide, 'compounds without\nbenzenesulfonamide', '#1f77b4', 2),  # blue
        (has_sulfonamide, 'compounds with\nbenzenesulfonamide', '#ff7f0e', 3),  # orange
    ]
    for mask, label, color, zorder in groups:
        ax.errorbar(
            x=R[mask],
            y=test_enrichments[mask],
            # asymmetric error bars: distance down to the lower bound, up to the upper bound
            xerr=[R[mask] - R_lb[mask], R_ub[mask] - R[mask]],
            label=label,
            color=color,
            marker='o',
            markersize=3,
            elinewidth=0.75,
            ls='none',
            ecolor='k',
            capsize=1,
            capthick=0.75,
            zorder=zorder,
        )
    # Parity (y = x) line spanning the range of the predictions
    parity = np.linspace(min(test_enrichments), max(test_enrichments), 100)
    ax.plot(
        parity,
        parity,
        color='#2ca02c', # green
        label='parity',
        linewidth=0.75,
        zorder=5,
    )


def draw_predicted_enrichments_vs_true(eval_slice, zoomIn, rectangles=True):
    """Parity scatter plot of `model_rand` predictions vs. calculated enrichment.

    Args:
        eval_slice: row indices to evaluate (index into `x`, `df_data` and the count arrays).
        zoomIn: only used when `rectangles` is False; limits the x axis to [0, 55]
            and selects the "zoomed_in" file name.
        rectangles: if True, draw a two-panel figure with a broken x axis
            (0-30.5 and 49.5-55) and three hard-coded highlight rectangles;
            if False, draw a single-panel plot.

    The figure is saved through `pathify`, the current axis limits are printed,
    and the figure is shown.
    """
    R, R_lb, R_ub = R_ranges(bead_counts[eval_slice, 0], bead_tot[0], exp_counts[eval_slice, 0], exp_tot[0])
    # Vectorised membership test instead of one DataFrame row lookup per index.
    has_sulfonamide = df_data['cycle3'].iloc[eval_slice].isin([14, 74]).to_numpy()
    test_enrichments = model_rand.predict_on_x(
        x[eval_slice, :], batch_size=BATCH_SIZE, device=DEVICE,
    )

    if not rectangles:
        fig, ax = plt.subplots(figsize=(4, 3), dpi=300)
        _draw_parity_layers_CAIX(ax, R, R_lb, R_ub, test_enrichments, has_sulfonamide)

        ax.legend(fontsize=7)
        fig.canvas.draw()
        if zoomIn:
            ax.set_xlim([0, 55])
        ax.tick_params(labelsize=8)
        ax.grid(zorder=1)
        ax.set_xlabel('calculated enrichment', fontsize=8)
        ax.set_ylabel('predicted enrichment', fontsize=8)
        plt.tight_layout()

        if zoomIn:
            save_name = 'Parity_scatter_plot_zoomed_in_DD1S_CAIX_FP-FFNN_random_seed_0.png'
        else:
            save_name = 'Parity_scatter_plot_full_DD1S_CAIX_FP-FFNN_random_seed_0.png'
        plt.savefig(pathify(save_name))

    else:
        # Two panels sharing the y axis emulate a broken x axis.
        fig, (a, a2) = plt.subplots(1, 2, sharey=True, figsize=(4, 3), dpi=300,
                                    gridspec_kw={'width_ratios': [30.5, 5.5]})
        for axis in (a, a2):
            _draw_parity_layers_CAIX(axis, R, R_lb, R_ub, test_enrichments, has_sulfonamide)

        a.legend(fontsize=7, loc='center right')
        fig.canvas.draw()
        # Apply the tick label size to both panels (plt.gca() would only reach the right one).
        for axis in (a, a2):
            axis.tick_params(labelsize=8)

        # add rectangles
        # (axes, lower y, width, upper y, line width, edge color, zorder); all start at x = 0
        boxes = [
            (a, 1.6607487711287636, 5.0, 2.1909911122940877, 2, '#d62728', 6),
            (a, 3.397599537463589, 6.0, 4.033890346861978, 2, '#9467bd', 7),
            (a, 0.33691880404949187, 55.0, 6.169584556868059, 3.9, '#bcbd22', 8),
            (a2, 0.33691880404949187, 55.0, 6.169584556868059, 3.9, '#bcbd22', 8),
        ]
        for axis, y_low, width, y_high, linewidth, edgecolor, zorder in boxes:
            axis.add_patch(patches.Rectangle((0, y_low), width, (y_high-y_low),
                                             linewidth=linewidth, edgecolor=edgecolor,
                                             facecolor='none', fill=False,
                                             zorder=zorder, alpha=0.6))

        a.set_xlim(0, 30.5)
        a.set_ylim([0.33691871762275694, 4.4302008092403415])
        a2.set_ylim([0.33691871762275694, 4.4302008092403415])
        a2.set_xlim(49.5, 55)
        # Hide the facing spines and the right panel's y ticks so the panels read as one axis.
        a.spines['right'].set_visible(False)
        a2.spines['left'].set_visible(False)
        a2.tick_params(left=False, labelleft=False)

        # Diagonal "break" marks, drawn in axes coordinates at the facing edges of the panels.
        d = .015
        kwargs = dict(transform=a.transAxes, color='k', clip_on=False, linewidth=0.75)
        a.plot((1-d/2.7, 1+d/2.7), (-d, +d), **kwargs)
        a.plot((1-d/2.7, 1+d/2.7), (1-d, 1+d), **kwargs)
        kwargs.update(transform=a2.transAxes)
        a2.plot((-2*d, +2*d), (1-d, 1+d), **kwargs)
        a2.plot((-2*d, +2*d), (-d, +d), **kwargs)

        a.grid(zorder=1)
        a2.grid(zorder=1)
        a.set_xlabel('calculated enrichment', fontsize=8)
        a.set_ylabel('predicted enrichment', fontsize=8)
        plt.tight_layout()

        plt.savefig(pathify('Parity_scatter_plot_zoomed_in_with_x_axis_break_and_rectangles_DD1S_CAIX_FP-FFNN_random_seed_0.png'))

    # Axis limits of the current axes (the right-hand panel in the two-panel layout)
    print(plt.axis())
    plt.show()

In [None]:
draw_predicted_enrichments_vs_true(test_slice_rand, zoomIn=False, rectangles=False)

In [None]:
draw_predicted_enrichments_vs_true(test_slice_rand, zoomIn=True, rectangles=False)

In [None]:
draw_predicted_enrichments_vs_true(test_slice_rand, zoomIn=True, rectangles=True)

# Parity scatter plot close-ups

## High confidence point

In [None]:
def draw_close_up_high_conf_pt(eval_slice, slice_label = ''):
    """Close-up of the parity plot around the point labelled 'cpd_id 11676'.

    Args:
        eval_slice: row indices to evaluate (index into `x`, `df_data` and the count arrays).
        slice_label: unused; kept so existing calls keep working.

    The highlighted point is addressed by a hard-coded position (358) within the
    benzenesulfonamide subset of `eval_slice`, so it is only meaningful for the
    slice it was chosen for. NOTE(review): the position-to-cpd_id mapping cannot
    be verified from this notebook.
    """
    R, R_lb, R_ub = R_ranges(bead_counts[eval_slice, 0], bead_tot[0], exp_counts[eval_slice, 0], exp_tot[0])
    # Vectorised membership test instead of one DataFrame row lookup per index.
    has_sulfonamide = df_data['cycle3'].iloc[eval_slice].isin([14, 74]).to_numpy()
    test_enrichments = model_rand.predict_on_x(
        x[eval_slice, :], batch_size=BATCH_SIZE, device=DEVICE,
    )

    fig = plt.figure(figsize=(2.33, 2.33), dpi=300)

    # Background points: (row mask, marker color, zorder)
    for mask, color, zorder in ((~has_sulfonamide, '#1f77b4', 2),   # blue
                                (has_sulfonamide, '#ff7f0e', 3)):   # orange
        plt.errorbar(
            x=R[mask],
            y=test_enrichments[mask],
            xerr=[R[mask] - R_lb[mask], R_ub[mask] - R[mask]],
            color=color,
            marker='o',
            markersize=3,
            elinewidth=0.5,
            ls='none',
            ecolor='k',
            capsize=1,
            capthick=0.5,
            zorder=zorder,
        )

    # Highlighted point, redrawn on top with a red outline and red error bars
    highlight_pos = 358
    R_hl = R[has_sulfonamide][highlight_pos]
    plt.errorbar(
        x=R_hl,
        y=test_enrichments[has_sulfonamide][highlight_pos],
        xerr=[[R_hl - R_lb[has_sulfonamide][highlight_pos]],
              [R_ub[has_sulfonamide][highlight_pos] - R_hl]],
        label='cpd_id 11676',
        color='#ff7f0e',
        marker='o',
        markersize=4,
        markeredgecolor='#d62728',
        markeredgewidth=1.5,
        elinewidth=1.5,
        ls='none',
        ecolor='#d62728',
        capsize=2,
        capthick=1.5,
        zorder=4,
    )

    # Parity (y = x) line spanning the range of the predictions
    parity = np.linspace(min(test_enrichments), max(test_enrichments), 100)
    plt.plot(
        parity,
        parity,
        color='#2ca02c', # green
        linewidth=0.5,
        zorder=5,
    )

    leg = plt.legend(loc='upper center', bbox_to_anchor = (0.5,1.22), numpoints=1, fontsize=8)
    fig.canvas.draw()
    ax = plt.gca()
    # Frame with the same coordinates as the red rectangle in the overview parity plot
    rect1 = patches.Rectangle((0, 1.6607487711287636), 5.0, (2.1909911122940877-1.6607487711287636),
                             linewidth=3, edgecolor='#d62728', facecolor='none', fill=False, zorder=6, alpha=0.6)
    ax.add_patch(rect1)
    ax.tick_params(labelsize=8)
    ax.set_xlim([0, 5])
    ax.set_ylim([1.660748771128763686560259079343048808697663659495720564422, 2.190991112294087913439740920656951191302336340504279435577])
    ax.grid(zorder=1)
    ax.set_xlabel('calculated enrichment', fontsize=8)
    ax.set_ylabel('predicted enrichment', fontsize=8)
    plt.tight_layout()
    print(plt.axis())
    # The legend sits above the axes, so it must be included in the tight bounding box.
    plt.savefig(pathify('high_confidence_point_DD1S_CAIX_FP-FFNN_random_seed_0.png'),
               bbox_extra_artists=(leg,), bbox_inches='tight')
    plt.show()

In [None]:
draw_close_up_high_conf_pt(test_slice_rand)

## Low confidence / high prediction point

In [None]:
def draw_close_up_low_conf_high_pred_pt(eval_slice, slice_label = ''):
    """Close-up of the parity plot around the point labelled 'cpd_id 23814'.

    Args:
        eval_slice: row indices to evaluate (index into `x`, `df_data` and the count arrays).
        slice_label: unused; kept so existing calls keep working.

    The highlighted point is addressed by a hard-coded position (45) within the
    benzenesulfonamide subset of `eval_slice`, so it is only meaningful for the
    slice it was chosen for. NOTE(review): the position-to-cpd_id mapping cannot
    be verified from this notebook.
    """
    R, R_lb, R_ub = R_ranges(bead_counts[eval_slice, 0], bead_tot[0], exp_counts[eval_slice, 0], exp_tot[0])
    # Vectorised membership test instead of one DataFrame row lookup per index.
    has_sulfonamide = df_data['cycle3'].iloc[eval_slice].isin([14, 74]).to_numpy()
    test_enrichments = model_rand.predict_on_x(
        x[eval_slice, :], batch_size=BATCH_SIZE, device=DEVICE,
    )

    fig = plt.figure(figsize=(2.33, 2.33), dpi=300)

    # Background points: (row mask, marker color, zorder)
    for mask, color, zorder in ((~has_sulfonamide, '#1f77b4', 2),   # blue
                                (has_sulfonamide, '#ff7f0e', 3)):   # orange
        plt.errorbar(
            x=R[mask],
            y=test_enrichments[mask],
            xerr=[R[mask] - R_lb[mask], R_ub[mask] - R[mask]],
            color=color,
            marker='o',
            markersize=3,
            elinewidth=0.5,
            ls='none',
            ecolor='k',
            capsize=1,
            capthick=0.5,
            zorder=zorder,
        )

    # Highlighted point, redrawn on top with a purple outline and purple error bars
    highlight_pos = 45
    R_hl = R[has_sulfonamide][highlight_pos]
    plt.errorbar(
        x=R_hl,
        y=test_enrichments[has_sulfonamide][highlight_pos],
        xerr=[[R_hl - R_lb[has_sulfonamide][highlight_pos]],
              [R_ub[has_sulfonamide][highlight_pos] - R_hl]],
        label='cpd_id 23814',
        color='#ff7f0e',
        marker='o',
        markersize=4,
        markeredgecolor='#9467bd',
        markeredgewidth=1.5,
        elinewidth=1.5,
        ls='none',
        ecolor='#9467bd',
        capsize=2,
        capthick=1.5,
        zorder=4,
    )

    # Parity (y = x) line spanning the range of the predictions
    parity = np.linspace(min(test_enrichments), max(test_enrichments), 100)
    plt.plot(
        parity,
        parity,
        color='#2ca02c', # green
        linewidth=0.5,
        zorder=5,
    )

    leg = plt.legend(loc='upper center', bbox_to_anchor = (0.5,1.22), numpoints=1, fontsize=8)
    fig.canvas.draw()
    ax = plt.gca()
    # Frame with the same coordinates as the purple rectangle in the overview parity plot
    rect2 = patches.Rectangle((0, 3.397599537463589), 6.0, (4.033890346861978-3.397599537463589),
                             linewidth=3, edgecolor='#9467bd', facecolor='none', fill=False, zorder=7, alpha=0.6)
    ax.add_patch(rect2)
    ax.tick_params(labelsize=8)
    ax.set_xlim([0, 6])
    ax.set_ylim([3.397599537463588742195697432338653712699514226231783483691, 4.033890346861977814451075641915336571825121443442054129077])
    ax.grid(zorder=1)
    ax.set_xlabel('calculated enrichment', fontsize=8)
    ax.set_ylabel('predicted enrichment', fontsize=8)
    plt.tight_layout()
    print(plt.axis())
    # The legend sits above the axes, so it must be included in the tight bounding box.
    plt.savefig(pathify('low_confidence_high_prediction_point_DD1S_CAIX_FP-FFNN_random_seed_0.png'),
               bbox_extra_artists=(leg,), bbox_inches='tight')
    plt.show()

In [None]:
draw_close_up_low_conf_high_pred_pt(test_slice_rand)

## Low confidence / low prediction point

In [None]:
def draw_close_up_low_conf_low_pred_pt(eval_slice, slice_label = ''):
    """Close-up of the parity plot around the point labelled 'cpd_id 81804'.

    Args:
        eval_slice: row indices to evaluate (index into `x`, `df_data` and the count arrays).
        slice_label: unused; kept so existing calls keep working.

    The highlighted point is addressed by a hard-coded position (605) within the
    subset of `eval_slice` *without* benzenesulfonamide, so it is only meaningful
    for the slice it was chosen for. NOTE(review): the position-to-cpd_id mapping
    cannot be verified from this notebook.
    """
    R, R_lb, R_ub = R_ranges(bead_counts[eval_slice, 0], bead_tot[0], exp_counts[eval_slice, 0], exp_tot[0])
    # Vectorised membership test instead of one DataFrame row lookup per index.
    has_sulfonamide = df_data['cycle3'].iloc[eval_slice].isin([14, 74]).to_numpy()
    test_enrichments = model_rand.predict_on_x(
        x[eval_slice, :], batch_size=BATCH_SIZE, device=DEVICE,
    )

    fig = plt.figure(figsize=(2.33, 2.33), dpi=300)

    # Background points: (row mask, marker color, zorder)
    for mask, color, zorder in ((~has_sulfonamide, '#1f77b4', 2),   # blue
                                (has_sulfonamide, '#ff7f0e', 3)):   # orange
        plt.errorbar(
            x=R[mask],
            y=test_enrichments[mask],
            xerr=[R[mask] - R_lb[mask], R_ub[mask] - R[mask]],
            color=color,
            marker='o',
            markersize=3,
            elinewidth=0.5,
            ls='none',
            ecolor='k',
            capsize=1,
            capthick=0.5,
            zorder=zorder,
        )

    # Highlighted point, redrawn on top with an olive outline and olive error bars
    highlight_pos = 605
    no_sulfonamide = ~has_sulfonamide
    R_hl = R[no_sulfonamide][highlight_pos]
    plt.errorbar(
        x=R_hl,
        y=test_enrichments[no_sulfonamide][highlight_pos],
        xerr=[[R_hl - R_lb[no_sulfonamide][highlight_pos]],
              [R_ub[no_sulfonamide][highlight_pos] - R_hl]],
        label='cpd_id 81804',
        color='#1f77b4',
        marker='o',
        markersize=4,
        markeredgecolor='#bcbd22',
        markeredgewidth=1.5,
        elinewidth=1.5,
        ls='none',
        ecolor='#bcbd22',
        capsize=2,
        capthick=1.5,
        zorder=4,
    )

    # Parity (y = x) line spanning the range of the predictions
    parity = np.linspace(min(test_enrichments), max(test_enrichments), 100)
    plt.plot(
        parity,
        parity,
        color='#2ca02c', # green
        linewidth=0.5,
        zorder=5,
    )

    leg = plt.legend(loc='upper center', bbox_to_anchor = (0.5,1.22), numpoints=1, fontsize=8)
    fig.canvas.draw()
    ax = plt.gca()
    # Frame with the same coordinates as the olive rectangle in the overview parity plot
    rect3 = patches.Rectangle((0, 0.33691880404949187), 55.0, (6.169584556868059-0.33691880404949187),
                             linewidth=3, edgecolor='#bcbd22', facecolor='none', fill=False, zorder=8, alpha=0.6)
    ax.add_patch(rect3)
    ax.tick_params(labelsize=8)
    ax.set_xlim([0, 55])
    ax.set_ylim([0.33691880404949187, 6.16958455686805837])
    ax.grid(zorder=1)
    ax.set_xlabel('calculated enrichment', fontsize=8)
    ax.set_ylabel('predicted enrichment', fontsize=8)
    plt.tight_layout()
    print(plt.axis())
    # The legend sits above the axes, so it must be included in the tight bounding box.
    plt.savefig(pathify('low_confidence_low_prediction_point_DD1S_CAIX_FP-FFNN_random_seed_0.png'),
               bbox_extra_artists=(leg,), bbox_inches='tight')
    plt.show()

In [None]:
draw_close_up_low_conf_low_pred_pt(test_slice_rand)