In [None]:
# Automatically re-import edited modules (e.g. the local del_qsar package) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import os, sys
from tqdm import tqdm
import numpy as np
import pandas as pd

In [None]:
# Repository root, resolved relative to the notebook's working directory
# (presumably experiments/visualizations/ — LOG_FILE below assumes the same).
DELQSAR_ROOT = os.getcwd() + '/../../'

# Put the root's parent on sys.path so `del_qsar` is importable; guard so that
# re-running this cell does not keep appending duplicate entries.
_DELQSAR_PARENT = DELQSAR_ROOT + '/../'
if _DELQSAR_PARENT not in sys.path:
    sys.path.append(_DELQSAR_PARENT)

# All figures from this notebook are written into this (cwd-relative) folder.
OUT_DIR = 'DD1S_CAIX_cycle2_distrib_shift'
os.makedirs(OUT_DIR, exist_ok=True)


def pathify(fname):
    """Return the path of ``fname`` inside the notebook's output directory."""
    return os.path.join(OUT_DIR, fname)


LOG_FILE = os.path.join(DELQSAR_ROOT, 'experiments', 'visualizations',
                        OUT_DIR, OUT_DIR + '.log')

In [None]:
from del_qsar import splitters
from del_qsar.enrichments import R_ranges

In [None]:
import matplotlib
import matplotlib.pyplot as plt
# 'seaborn-paper' was renamed 'seaborn-v0_8-paper' in Matplotlib 3.6 and the old
# name was removed in 3.8; use whichever name this installation provides.
_PAPER_STYLE = ('seaborn-v0_8-paper' if 'seaborn-v0_8-paper' in plt.style.available
                else 'seaborn-paper')
plt.style.use(_PAPER_STYLE)

# Prefer Arial as the sans-serif face, keeping the existing list as fallback in
# case Arial is not installed. (Setting font.serif has no effect when the
# family is sans-serif.)
matplotlib.rc('font', family='sans-serif')
matplotlib.rcParams['font.sans-serif'] = ['Arial'] + [
    face for face in matplotlib.rcParams['font.sans-serif'] if face != 'Arial']
matplotlib.rc('text', usetex=False)
matplotlib.rcParams.update({'font.size': 9})

In [None]:
# Load the DD1S CAIX dataset; counts are kept as (n, 1) integer columns so they
# can be indexed later as counts[idx, 0].
csv_path = os.path.join(DELQSAR_ROOT, 'experiments', 'datasets', 'DD1S_CAIX_QSAR.csv')
df_data = pd.read_csv(csv_path)

exp_counts = df_data[['exp_tot']].to_numpy(dtype=int)
bead_counts = df_data[['beads_tot']].to_numpy(dtype=int)

# Column totals (shape (1,)).
exp_tot = exp_counts.sum(axis=0)
bead_tot = bead_counts.sum(axis=0)

In [None]:
# Cycle-2-based split (presumably holding out cycle-2 building blocks) for each
# seed; element [2] of each split is kept as the test-set slice.
# NOTE(review): the first argument was previously IPython's `_` (the last cell
# output), which silently changes on re-run. It appears to be an unused
# placeholder, so pass None explicitly — confirm against OneCycleSplitter.__call__.
N_SEEDS = 5
splitter = splitters.OneCycleSplitter(['cycle2'], LOG_FILE)
test_slices = [splitter(None, df_data, seed=seed)[2] for seed in tqdm(range(N_SEEDS))]

In [None]:
# Number of test compounds per seed.
list(map(len, test_slices))

In [None]:
def make_plot_calc_enrichments(eval_slices, seeds, zoomIn):
    fig = plt.figure(figsize=(3.5, 2), dpi=300)
    for seed in seeds:
        R, R_lb, R_ub = R_ranges(bead_counts[eval_slices[seed], 0], bead_tot[0], 
                                       exp_counts[eval_slices[seed], 0], exp_tot[0])
        bins = np.arange(0, max(R_lb)+0.001, 0.06)
        _, bins, patches = plt.hist(
            np.clip(R_lb, 0, bins[-1]), 
            bins=bins, 
            density=True,
            zorder=2,
            alpha=0.4,
            label=f'Seed {seed}', 
        )
    
    plt.legend(fontsize=7)
    fig.canvas.draw() # required to get tick labels
    ax = plt.gca()
    ax.tick_params(labelsize=9)
    ax.set_xlim([0, 2])
    if zoomIn:
        ax.set_ylim([0, 0.4])
    
    ax.grid(zorder=1)
    ax.set_xlabel('Calculated enrichment (lower bound)', fontsize=9)
    ax.set_ylabel('Probability density', fontsize=9)
    plt.tight_layout()
    if zoomIn:
        plt.savefig(pathify(f'DD1S_CAIX_cycle2_calculated_enrichments_LB_seeds_{str(seeds)}_y-axis_zoomed in.png'))
    else:
        plt.savefig(pathify(f'DD1S_CAIX_cycle2_calculated_enrichments_LB_seeds_{str(seeds)}.png'))
    plt.show()

In [None]:
# All five seeds, full y-range.
make_plot_calc_enrichments(test_slices, seeds=[0, 1, 2, 3, 4], zoomIn=False)

In [None]:
# All five seeds, zoomed y-axis.
make_plot_calc_enrichments(test_slices, seeds=[0, 1, 2, 3, 4], zoomIn=True)

In [None]:
# Seed 0 vs seed 4, full y-range.
make_plot_calc_enrichments(test_slices, seeds=[0, 4], zoomIn=False)

In [None]:
# Seed 0 vs seed 4, zoomed y-axis.
make_plot_calc_enrichments(test_slices, seeds=[0, 4], zoomIn=True)

In [None]:
# Seed 1 vs seed 4, full y-range.
make_plot_calc_enrichments(test_slices, seeds=[1, 4], zoomIn=False)

In [None]:
# Seed 1 vs seed 4, zoomed y-axis.
make_plot_calc_enrichments(test_slices, seeds=[1, 4], zoomIn=True)

In [None]:
# Seed 2 vs seed 4, full y-range.
make_plot_calc_enrichments(test_slices, seeds=[2, 4], zoomIn=False)

In [None]:
# Seed 2 vs seed 4, zoomed y-axis.
make_plot_calc_enrichments(test_slices, seeds=[2, 4], zoomIn=True)

In [None]:
# Seed 3 vs seed 4, full y-range.
make_plot_calc_enrichments(test_slices, seeds=[3, 4], zoomIn=False)

In [None]:
# Seed 3 vs seed 4, zoomed y-axis.
make_plot_calc_enrichments(test_slices, seeds=[3, 4], zoomIn=True)