# Imports

In [None]:
import importlib
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import lines
import seaborn as sns

import os
import sys

In [None]:
# Absolute path of the shared `_pythoncode` folder two directory levels up.
pythoncodepath = os.path.abspath(os.path.join('..', '..', '_pythoncode'))
# Prepend it so it is searched before all existing sys.path entries.
sys.path = [pythoncodepath] + sys.path
# Must come after the sys.path change (importhelper presumably lives in _pythoncode).
import importhelper
# NOTE(review): presumably adds the subfolders of pythoncodepath to sys.path -- confirm in importhelper.
importhelper.addfolders2path(pythoncodepath)

In [None]:
import data_utils
import plot_utils

In [None]:
# Reload so that edits to plot_utils are picked up without restarting the kernel.
importlib.reload(plot_utils)
# NOTE(review): presumably applies the project's matplotlib rcParams -- confirm in plot_utils.
plot_utils.set_rcParams()

In [None]:
# Figure number = characters 3-4 of the working directory's name.
# os.path.basename is used instead of splitting on '/' so that the folder
# name is also found when the path uses Windows separators.
fig_num = os.path.basename(os.getcwd())[3:5]
print(fig_num)

# Load data

In [None]:
# Relative locations of the optimization data of the two pipeline steps.
cone_optim_folder, cbc_optim_folder = (
    os.path.join('..', '..', step_folder, 'optim_data')
    for step_folder in ('step1a_optimize_cones', 'step2a_optimize_cbc')
)

In [None]:
# Display the entries of the cone optimization folder.
os.listdir(cone_optim_folder)

In [None]:
# Display the entries of the CBC optimization folder, sorted by name.
sorted(os.listdir(cbc_optim_folder))

In [None]:
# Folder of the optimization run used for each cell type.
cell2folder = {
    cell_name: os.path.join(optim_folder, run_name)
    for cell_name, optim_folder, run_name in (
        ('Cone', cone_optim_folder, 'optimize_cone_submission2'),
        ('OFF', cbc_optim_folder, 'optimize_OFF_submission2'),
        ('ON', cbc_optim_folder, 'optimize_ON_submission2'),
    )
}

In [None]:
# Per cell type: one loss entry per file in the run's 'samples' folder,
# in sorted file-name order.
round_losses = {}

for cell, folder in cell2folder.items():
    samples_dir = os.path.join(folder, 'samples')
    round_losses[cell] = [
        data_utils.load_var(os.path.join(samples_dir, sample_file))['loss']
        for sample_file in sorted(os.listdir(samples_dir))
    ]

# Get losses from posterior samples

In [None]:
# NOTE(review): optim_funcs is presumably a project module made importable by the path setup above.
import optim_funcs
# Reload so that edits to optim_funcs are picked up without restarting the kernel.
importlib.reload(optim_funcs)
# Optimizer object; in this notebook only its stack_model_output_list method is used.
optim = optim_funcs.EmptyOptimizer()

In [None]:
# Display the posterior data files of every cell type. Iterating over
# cell2folder explicitly avoids relying on a `folder` variable left over
# from an earlier loop, which only covered the last cell type and depended
# on the order in which cells were executed.
{
    cell_name: sorted(os.listdir(os.path.join(cell_folder, 'post_data')))
    for cell_name, cell_folder in cell2folder.items()
}

In [None]:
# Losses of the posterior samples, per cell type.
post_losses = {}
for cell, folder in cell2folder.items():
    post_file = os.path.join(folder, 'post_data', 'post_model_output_list.pkl')
    rec_data_list = data_utils.load_var(post_file)
    post_losses[cell] = optim.stack_model_output_list(rec_data_list)['loss']

# Same for the 'optimize_cpl' posterior outputs; cells whose name contains
# 'Cone' are skipped.
post_losses_opt_cpl = {}
for cell, folder in cell2folder.items():
    if 'Cone' in cell:
        continue
    cpl_file = os.path.join(folder, 'post_data', 'post_model_output_list_optimize_cpl.pkl')
    rec_data_list = data_utils.load_var(cpl_file)
    post_losses_opt_cpl[cell] = optim.stack_model_output_list(rec_data_list)['loss']

In [None]:
# Losses of the samples stored under 'marginal_post_data', per cell type.
marg_losses = {}
for cell, folder in cell2folder.items():
    marginals_file = os.path.join(folder, 'marginal_post_data', 'rec_data_list_from_marginals.pkl')
    rec_data_list = data_utils.load_var(marginals_file)
    stacked_output = optim.stack_model_output_list(rec_data_list)
    marg_losses[cell] = stacked_output['loss']

# Export data

In [None]:
# Output folder for the CSV files written below.
# NOTE(review): make_dir presumably creates the folder if it does not exist -- confirm in data_utils.
data_utils.make_dir('source_data')

In [None]:
# Loss names (and their order) are taken from the cone posterior losses and
# used as the column layout for all cell types.
loss_names = [*post_losses['Cone'].keys()]
loss_names

In [None]:
def get_round_losses(round_i_losses):
    """Arrange the losses of one sample set as a 2-D float array.

    Parameters
    ----------
    round_i_losses : mapping
        Maps loss names to per-sample loss values. Must contain 'total'
        (its size sets the number of rows) and every name listed in the
        module-level ``loss_names``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_samples, len(loss_names))``; column order
        follows ``loss_names``.
    """
    n_samples = round_i_losses['total'].size
    table = np.full((n_samples, len(loss_names)), np.nan)

    # Iterating over table.T yields one writable view per column, so the
    # slice assignment below fills `table` in place.
    for column, name in zip(table.T, loss_names):
        column[:] = round_i_losses[name]

    return table

In [None]:
def check_output(cell):
    """Check that the exported CSV of `cell` matches the in-memory losses.

    Reads 'source_data/Sample_discrepancies_<cell>.csv' and compares its
    'Round_<i>', 'Posterior' and 'Marginals' rows against the module-level
    ``round_losses``, ``post_losses`` and ``marg_losses``. Entries that are
    NaN in memory are excluded from the comparison.

    Parameters
    ----------
    cell : str
        Cell type key, e.g. a key of ``round_losses``.

    Raises
    ------
    AssertionError
        If any compared value deviates by at least the tolerance
        (1e-3 for rounds, 1e-4 for posterior and marginals).
    """
    test_data = pd.read_csv(
        'source_data/Sample_discrepancies_' + cell + '.csv', index_col=0)

    def assert_rows_match(row_label, losses, tol):
        # Compare every loss of one sample set with the rows carrying its label.
        for l_name, losses_i in losses.items():
            valid = ~np.isnan(losses_i)
            exported = test_data.loc[row_label][l_name].values
            absdiff = np.abs(losses_i[valid] - exported[valid])
            assert np.all(absdiff < tol), (row_label, l_name, np.max(absdiff))

    # Iterate over the rounds that actually exist instead of a fixed count,
    # so that no round is skipped and a shorter list does not raise.
    for r, round_i_losses in enumerate(round_losses[cell]):
        assert_rows_match('Round_' + str(r+1), round_i_losses, tol=1e-3)

    assert_rows_match('Posterior', post_losses[cell], tol=1e-4)
    assert_rows_match('Marginals', marg_losses[cell], tol=1e-4)

In [None]:
for cell, round_losses_cell in round_losses.items():

    # Sample sets in export order: all rounds, then posterior, then marginals.
    labelled_losses = [
        ("Round_" + str(r_idx + 1), round_i_losses)
        for r_idx, round_i_losses in enumerate(round_losses_cell)
    ]
    labelled_losses.append(("Posterior", post_losses[cell]))
    labelled_losses.append(("Marginals", marg_losses[cell]))

    # Start from an empty float block so the stacked result always has
    # len(loss_names) columns.
    blocks = [np.full((0, len(loss_names)), np.nan)]
    row_labels = []
    for label, losses in labelled_losses:
        block = get_round_losses(losses)
        blocks.append(block)
        row_labels.extend([label] * block.shape[0])

    cell_table = pd.DataFrame(np.vstack(blocks), columns=loss_names, index=row_labels)
    cell_table.to_csv(
        'source_data/Sample_discrepancies_' + cell + '.csv', float_format="%.6f")

    #### TEST ####
    check_output(cell)

# Create text output


Summarize mean, median, standard deviation and 95th percentile as LaTeX commands for the paper text.

## Statistics

In [None]:
from datetime import datetime


def _latex_stats(prefix, values):
    """Return four LaTeX command lines (mean, median, std, 95th percentile).

    `prefix` is the command text up to the statistic name; NaN entries of
    `values` are ignored.
    """
    return [
        prefix + "Mean{"   + "{:.2f}".format(np.nanmean(values))   + "}\n",
        prefix + "Median{" + "{:.2f}".format(np.nanmedian(values)) + "}\n",
        prefix + "Std{"    + "{:.2f}".format(np.nanstd(values))    + "}\n",
        prefix + "qNiFi{"  + "{:.2f}".format(np.nanpercentile(values, q=95)) + "}\n",
    ]


# First line of the .tex file is a LaTeX comment with the creation time.
text_data = ['%' + str(datetime.now()) + '\n']

for cell, post_losses_cell in post_losses.items():

    print('\n', cell)
    if not np.all(np.isfinite(post_losses_cell['total'])):
        print(np.sum(~np.isfinite(post_losses_cell['total'])), ' post samples are nans')
    if not np.all(np.isfinite(marg_losses[cell]['total'])):
        print(np.sum(~np.isfinite(marg_losses[cell]['total'])), ' marg samples are nans')

    for loss_key, post_losses_key in post_losses_cell.items():
        # Statistics are computed on absolute values. np.abs returns a new
        # array, so the stored losses are not modified.
        post_losses_key = np.abs(post_losses_key)
        marg_losses_key = np.abs(marg_losses[cell][loss_key])

        # Underscores are removed from the loss name to build the command name.
        command_suffix = cell + loss_key.replace('_', '')

        # Posterior.
        text_data.extend(_latex_stats("\\newcommand\\postLoss" + command_suffix, post_losses_key))

        # NaN-aware median, consistent with the statistics written above;
        # a plain median would be NaN (and the check silently false) as soon
        # as one sample is NaN.
        if np.nanmedian(post_losses_key) > 0.0 and loss_key not in ['total', 'iGluSnFR']:
            print(loss_key.ljust(20), 'has non-zero median for cell\t', cell)

        # Marginal posterior.
        text_data.extend(_latex_stats("\\newcommand\\margPostLoss" + command_suffix, marg_losses_key))

# Share of the iGluSnFR loss in the total loss, in percent. NaN-aware means
# are used so that NaN samples do not turn the exported value into "nan".
OFF_iGlu_frac = 100*np.nanmean(post_losses['OFF']['iGluSnFR']) / np.nanmean(post_losses['OFF']['total'])
ON_iGlu_frac  = 100*np.nanmean(post_losses['ON']['iGluSnFR'])  / np.nanmean(post_losses['ON']['total'])

text_data.append("\\newcommand\\postLossIGluFracOFF{" + "{:.0f}".format(OFF_iGlu_frac)+ "}\n")
text_data.append("\\newcommand\\postLossIGluFracON{" + "{:.0f}".format(ON_iGlu_frac)  + "}\n")

# FIXME(review): the two commands below are named "FracBetterPostThanMarg"
# but are filled with the iGluSnFR fraction from above, which looks like a
# copy-paste error. The intended quantity cannot be derived from this
# notebook, so the emitted values are left unchanged -- confirm the
# definition and replace them.
text_data.append("\\newcommand\\postLossFracBetterPostThanMargOFF{" + "{:.0f}".format(OFF_iGlu_frac)+ "}\n")
text_data.append("\\newcommand\\postLossFracBetterPostThanMargON{" + "{:.0f}".format(ON_iGlu_frac)  + "}\n")

data_utils.make_dir('text_data')
with open('text_data/post_loss_data.tex', 'w') as f:
    f.writelines(text_data)

In [None]:
def _n_nan_total(losses):
    """Number of NaN entries in the 'total' loss of one sample set."""
    return np.sum(np.isnan(losses['total']))


# One line per cell type: NaN counts per round, then posterior and marginals.
for cell, round_losses_i in round_losses.items():
    print(cell)
    for round_loss in round_losses_i:
        print(_n_nan_total(round_loss), end='\t')

    for label, losses_by_cell in (('post:', post_losses), ('marg:', marg_losses)):
        print(label, _n_nan_total(losses_by_cell[cell]), end='\t')

    print()

In [None]:
# Mean absolute posterior loss per loss name, grouped by cell type.
for cell, cell_post_losses in post_losses.items():
    print(cell)
    for loss_key, losses in cell_post_losses.items():
        mean_abs_loss = np.mean(np.abs(losses))
        print(loss_key.ljust(20), format(mean_abs_loss, '.4f'))

    print()

## Loss of final model outputs

In [None]:
# Final model output of every cell type, loaded from its 'post_data' folder.
cell2final_model_outputs = {}
for cell, folder in cell2folder.items():
    final_output_file = os.path.join(folder, 'post_data', 'final_model_output.pkl')
    cell2final_model_outputs[cell] = data_utils.load_var(final_output_file)

In [None]:
# Print every loss of the final model output that is not exactly zero.
for cell, final_model_output in cell2final_model_outputs.items():
    print(cell)
    for loss_key, loss_value in final_model_output['loss'].items():
        if loss_value == 0.0:
            continue
        print(loss_key.ljust(20), f"{loss_value:.6f}")
    print()

In [None]:
from datetime import datetime

# LaTeX comment with the creation time, followed by one command per cell
# type holding the total loss of its final model output.
text_output = ['%' + str(datetime.now()) + '\n']
text_output.extend(
    "\\newcommand\\optimized" + cell_name + "TotalLoss{"
    + "{:.2f}".format(cell_output['loss']['total']) + "}\n"
    for cell_name, cell_output in cell2final_model_outputs.items()
)

with open('text_data/optimizedCellsLoss.tex', 'w') as f:
    f.write("".join(text_output))

text_output

In [None]:
for cell, cell_post_losses_opt_cpl in post_losses_opt_cpl.items():
    print(cell)
    total_post = post_losses[cell]['total']
    total_opt_cpl = cell_post_losses_opt_cpl['total']

    # Sample-wise comparison: True where the posterior total loss is lower.
    post_is_lower = total_post < total_opt_cpl

    # Percentage over all samples.
    print(100*np.sum(post_is_lower) / total_post.size)

    # Count among the 40 samples with the lowest 'optimize_cpl' total loss.
    lowest_opt_cpl_idx = np.argsort(total_opt_cpl)[:40]
    print(np.sum(post_is_lower[lowest_opt_cpl_idx]))

# Plot loss over rounds

In [None]:
def plot_round_loss(ax, round_loss, bottom, loss_min, loss_max, weight=1, squeeze=1):
    """Draw one sample set on `ax`: summary bands first, then the histogram.

    All arguments are forwarded to ``plot_areas`` and ``plot_hist``.
    """
    plot_areas(ax=ax, round_loss=round_loss, bottom=bottom, weight=weight, squeeze=squeeze)
    plot_hist(
        ax=ax, round_loss=round_loss, bottom=bottom,
        loss_min=loss_min, loss_max=loss_max, weight=weight,
    )

In [None]:
def plot_areas(ax, round_loss, bottom, weight, squeeze):
    """Draw median line and percentile bands of `round_loss` on `ax`.

    The elements span x from `bottom` to
    `bottom + round_loss.size*weight*squeeze`. A red line marks the median,
    a light band the 5th-95th and a darker band the 25th-75th percentile.
    NaN values are ignored.
    """
    losses = round_loss.copy()
    x_span = [bottom, bottom + losses.size*weight*squeeze]

    def flat(level):
        # Constant y-value repeated for both ends of the x-span.
        return np.tile(level, 2)

    ax.plot(
        x_span, flat(np.nanmedian(losses)),
        color='r', lw=0.6, alpha=1, zorder=1, clip_on=True
    )

    percentile_bands = (
        ((5, 95), dict(alpha=0.2, color='k')),
        ((25, 75), dict(alpha=0.5, color='steelblue')),
    )
    for (q_low, q_high), band_kw in percentile_bands:
        ax.fill_between(
            x_span,
            flat(np.nanpercentile(losses, q_low)),
            flat(np.nanpercentile(losses, q_high)),
            label='_', lw=0, **band_kw
        )

In [None]:
n_bins = 31

def plot_hist(ax, round_loss, bottom, loss_min, loss_max, weight=1):
    """Draw a horizontal histogram of `round_loss` on `ax`, starting at x=`bottom`.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on.
    round_loss : numpy.ndarray or list
        Loss values; not modified. NaN entries are counted at `loss_max`.
    bottom : number
        x-position at which the bars start.
    loss_min : number
        Not used by this function; kept for interface compatibility.
    loss_max : float or int
        Replacement value for NaNs and basis of the upper bin edge.
    weight : number, optional
        Weight of every sample, i.e. the bar length contributed per sample.
    """
    assert isinstance(loss_max, (float,int))
    assert isinstance(round_loss, (np.ndarray, list)), type(round_loss)

    # Work on a float ndarray copy. Lists pass the check above but support
    # neither the boolean-mask assignment nor `.size` used below.
    losses = np.array(round_loss, dtype=float)
    losses[np.isnan(losses)] = loss_max

    # n_bins-1 edges from the smallest value to slightly above loss_max.
    bin_edges = np.linspace(np.min(losses), loss_max+0.5*loss_max/n_bins, n_bins-1)

    # No `range` argument: it has no effect when explicit bin edges are given.
    ax.hist(
        losses, orientation="horizontal", bottom=bottom,
        bins=bin_edges,
        zorder=1, alpha=0.5, facecolor='k', align='mid', lw=0.0,
        weights=np.ones(losses.size)*weight
    )

In [None]:
n_rounds = 4
n_samples_per_round = 2000

# Plotted loss columns: y-axis limits and column title per loss key.
plot_cols = {
    'total': {'min': 0, 'max': 7, 'title': r'$\delta_{total}$'},
    'iGluSnFR': {'min': 0, 'max': 1, 'title': r'$\delta_{iGluSnFR}$'},
}

def plot_loss(axs, squeeze_areas=1.0):
    """Plot the loss distributions of all rounds, posterior and marginals.

    Reads the module-level ``round_losses``, ``post_losses`` and
    ``marg_losses``. Every sample set occupies an x-slot that is
    ``n_samples_per_round`` wide: first the rounds, then the posterior
    ("p.") and the marginal ("m.") samples.

    Parameters
    ----------
    axs : 2-D array of matplotlib axes
        One row per cell type (in the iteration order of ``round_losses``)
        and one column per entry of ``plot_cols``.
    squeeze_areas : float, optional
        Passed as `squeeze` to ``plot_round_loss``; scales the width of the
        median line and percentile bands within a slot.
    """
    # Make axis nice.
    for ax, plot_params in zip(axs[0,:], plot_cols.values()):
        ax.set_title(plot_params['title'])

    for ax_row in axs:
        for ax, plot_params in zip(ax_row, plot_cols.values()):
            ax.set_ylim(plot_params['min'], plot_params['max'])
            ax.set_yticks([plot_params['min'], plot_params['max']])
        ax_row[0].set_ylabel('Discrepancy')

    # Rounds plus one slot each for posterior and marginals.
    n_slots = n_rounds + 2
    for ax in axs.flatten():
        ax.set_xticks([n_samples_per_round*i for i in range(n_slots)])
        ax.set_xticklabels([])
        ax.set_xlim(0, n_slots*n_samples_per_round)

        ax.spines["left"].set_position(("axes", -0.03))
        ax.spines['bottom'].set_bounds(0, n_rounds*n_samples_per_round)

    for ax in axs[-1,:]:
        ax.set_xticklabels(list(np.arange(n_rounds)+1) + ["p.", "m."])
        ax.set_xlabel(r'Round')

    # Posterior and marginal slots differ only in data source and in the
    # style of the vertical line marking the start of the slot.
    summary_slots = (
        (post_losses, dict(c='k', linestyle='--')),
        (marg_losses, dict(c='darkblue', linestyle='-')),
    )

    # Plot data.
    for ax_row, (cell, round_losses_cell) in zip(axs, round_losses.items()):
        bottom = 0

        # Plot round loss.
        for round_loss in round_losses_cell:
            for ax, (loss_key, plot_params) in zip(ax_row, plot_cols.items()):
                assert n_samples_per_round == round_loss[loss_key].size
                plot_round_loss(
                    ax, round_loss[loss_key], bottom,
                    loss_min=plot_params['min'], loss_max=plot_params['max'],
                    squeeze=squeeze_areas,
                )
                ax.axvline(bottom, c='k', linestyle='-', zorder=10, alpha=0.4, linewidth=1, clip_on=False)
            bottom += n_samples_per_round

        # Plot posterior loss, then marginal loss.
        for losses_by_cell, line_kw in summary_slots:
            cell_losses = losses_by_cell[cell]
            for ax, (loss_key, plot_params) in zip(ax_row, plot_cols.items()):
                # Rescale so that this sample set fills one round-sized slot.
                w = float(n_samples_per_round / cell_losses[loss_key].size)
                plot_round_loss(
                    ax, cell_losses[loss_key], bottom,
                    loss_min=plot_params['min'], loss_max=plot_params['max'], weight=w,
                    squeeze=squeeze_areas,
                )
                ax.axvline(bottom, zorder=10, alpha=0.4, linewidth=1, **line_kw)
            bottom += n_samples_per_round

    sns.despine()

In [None]:
# Stand-alone preview of the loss panels: 3 rows x 2 columns.
fig, axs = plt.subplots(
    nrows=3, ncols=2, figsize=(12, 10),
    sharex=True, sharey=False, squeeze=False,
)
plot_loss(axs, squeeze_areas=0.9)

# Plot comparison

In [None]:
# Precomputed table of iGluSnFR losses; accessed below via .index, .columns,
# .loc and .values (presumably a pandas DataFrame -- confirm).
all_iGluSnFR_losses = data_utils.load_var('source_data/all_iGluSnFR_losses.pkl')

In [None]:
# Display label for every column of the loss table.
renaming = {
    'OFF model_output': 'Model',
    'OFF strychnine': 'Target',
    'OFF no_drug': 'No drug',
    'OFF similar_strychnine': 'BC4',
    'ON model_output': 'Model',
    'ON strychnine': 'Target',
    'ON no_drug': 'No drug',
    'ON similar_strychnine': 'BC7',
}

annotatations = [renaming[col] for col in all_iGluSnFR_losses.columns]

# plot_comparison labels entries 0-3 as OFF and entries 4-7 as ON, so the
# column order is verified here. These checks were bare expressions before
# (no effect, and always looking at column 0).
column_names = list(all_iGluSnFR_losses.columns)
for col in column_names[:4]:
    assert col.startswith('OFF'), col
for col in column_names[4:8]:
    assert col.startswith('ON'), col

In [None]:
import matplotlib.patheffects as path_effects
from mpl_toolkits.axes_grid1 import make_axes_locatable

def plot_comparison(ax):
    """Draw the iGluSnFR loss table as an annotated heat map on `ax`.

    Reads the module-level ``all_iGluSnFR_losses`` (values, index, columns)
    and ``annotatations`` (tick labels). The positions of the OFF/ON group
    labels and brackets are hard-coded for a table with eight entries.
    A colorbar axis is appended to the right of `ax`.
    """
    im = ax.imshow(all_iGluSnFR_losses.values, cmap='Blues', vmin=0, vmax=0.6)

    ax.set_xticks(np.arange(len(annotatations)))
    ax.set_xticklabels(annotatations, rotation=90)
    ax.set_yticks(np.arange(len(annotatations)))
    ax.set_yticklabels(annotatations, rotation=0)

    # NOTE(review): presumably fixes the current limits so that the lines
    # drawn outside the image below do not trigger autoscaling -- confirm.
    ax.set_xlim(ax.get_xlim())
    ax.set_ylim(ax.get_ylim())

    ax.tick_params(length=0)

    # Group labels and brackets left of the image.
    ax.text(-3, len(annotatations)/4 - 0.5,   'OFF', ha='right', va='center', rotation=90)
    ax.text(-3, len(annotatations)*3/4 - 0.5, 'ON',  ha='right', va='center', rotation=90)

    ax.plot([-2.9, -2.9], [-0.4,3.4], clip_on=False, c='dimgray')
    ax.plot([-2.9, -2.9], [3.6, 7.4], clip_on=False, c='dimgray')

    # Group labels and brackets below the image.
    ax.text(len(annotatations)/4 - 0.5,   10.2,  'OFF', ha='center', va='top')
    ax.text(len(annotatations)*3/4 - 0.5, 10.2,  'ON', ha='center',  va='top')

    ax.plot([-0.4,3.4], [10., 10], clip_on=False, c='dimgray')
    ax.plot([3.6, 7.4], [10., 10], clip_on=False, c='dimgray')

    for row_idx, row_key in enumerate(all_iGluSnFR_losses.index):
        for col_idx, col_key in enumerate(all_iGluSnFR_losses.columns):
            value = all_iGluSnFR_losses.loc[row_key, col_key]
            formatted = "{:.2f}".format(value)

            # Values that round to 1.00 are shown as "1"; without this,
            # stripping the leading digit would display them as ".00".
            if value >= 0.999 or formatted == "1.00":
                st = "1"
            elif value < 1e-4:
                st = "0"
            else:
                # Drop the leading "0" of e.g. "0.25".
                st = formatted[1:]

            # imshow draws rows along y and columns along x, so the text
            # for entry (row, column) belongs at x=column, y=row.
            ax.text(
                col_idx, row_idx, st, ha="center", va="center", color='w',
                path_effects=[path_effects.Stroke(linewidth=1.5, foreground='k'),  path_effects.Normal()],
                fontsize=7
            )

    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)

    # The colorbar is drawn into `cax`; use the figure that owns `ax`
    # instead of whatever figure happens to be current.
    cb = ax.figure.colorbar(im, cax=cax)

    # The color scale is capped at 0.6 (vmax), hence the ">0.6" label.
    cb.set_ticks([0.0, 0.2, 0.4, 0.6])
    cb.set_ticklabels([0.0, 0.2, 0.4, ">0.6"])
    cb.set_label(r'$\delta_{iGluSnFR}$', rotation=90, ha='center', labelpad=-5)

In [None]:
# Stand-alone preview of the comparison panel on a single axis.
plot_comparison(ax=plt.subplot(111))

# Make figure

In [None]:
# Grid of the final figure: 3 columns x 3 rows, third column twice as wide.
sbnx = 3
sbny = 3

fig, axs = plt.subplots(ncols=sbnx, nrows=sbny, figsize=(5.6,2.5), gridspec_kw=dict(width_ratios=[1,1,2.]))
#fig, axs = plt.subplots(ncols=sbnx, nrows=sbny, figsize=(7.4,2.3), gridspec_kw=dict(width_ratios=[1,1,2.3]))

# Replace the three axes of the last column by one axis spanning all rows.
gs = axs[0, 0].get_gridspec()
for ax in axs[:, 2]: ax.remove()
axbig = fig.add_subplot(gs[:, 2])

# Panel letters; the trailing spaces together with ha='right' shift the
# letter to the left of the axis.
for ax, ABC in zip(axs[:,0], ['A', 'B', 'C']):
    ax.set_title(ABC+'          ', loc='left', fontweight="bold", ha='right')
axbig.set_title('D'+'            ', loc='left', fontweight="bold", ha='right')

# Panels A-C: loss distributions in the first two columns.
plot_loss(axs[:,:2], squeeze_areas=0.8)

# NOTE(review): presumably called before plot_comparison so that the manual
# position changes of axbig below are not overridden -- confirm.
plt.tight_layout(w_pad=-4, h_pad=0)

# Panel D: comparison heat map.
plot_comparison(ax=axbig)
# bounds are (x0, y0, width, height): move the left edge right by 0.08 and
# the bottom edge up by 0.14, shrinking width and height by the same amounts.
box = np.asarray(axbig.get_position().bounds)
box[0] += 0.08
box[2] -= 0.08

box[1] += 0.14
box[3] -= 0.14
axbig.set_position(box)

# Show the full frame around the heat map (top/right spines were presumably
# hidden by the sns.despine() call in plot_loss).
axbig.spines['top'].set_visible(True)
axbig.spines['right'].set_visible(True)

# Final nudge of the heat map to the left.
axbig.set_position(np.array(axbig.get_position().bounds) + [-0.005, 0,0,0])

# Save next to the other figures, named after the figure number derived
# from the working directory.
plt.savefig(f'../_figures/fig{fig_num}_sample_loss.pdf')
plt.show()