In [1]:
import os
import sys
sys.path.append('../../../src')
# Project modules (need the sys.path entry above); kept before the stdlib/scipy
# imports below so those names are not shadowed by the star imports.
from dimer_model_fit import *
from helpers import *
import random
import scipy.stats as st
%matplotlib notebook

### Import data

In [2]:
# Main dataset plus two additional sets (s1, s2); df_s3 / df_s4 are not loaded.
df_main, df_s1, df_s2 = [pd.read_csv(csv_name) for csv_name in ('df_main.csv', 'df_s1.csv', 'df_s2.csv')]

### Run analysis

In [3]:
# Starting values / fixing flags for the paired fit.
wd_wt = 15.5303089860562  # <- AUC measurement
wd_mut = 15  # <- initial guess
fix_wt = True
fix_mut = False
fit_D = False

# Every dataset is fitted with exactly the same options.
# NOTE(review): the third p0 entry (5) is presumably the w_m starting guess — confirm.
shared_fit_options = dict(log=True, p0=(wd_wt, wd_mut, 5), fix_wt=fix_wt, fix_mut=fix_mut, fit_D=fit_D)

analysis_main = EnergiesConfidenceIntervalPaired(df_main, **shared_fit_options)
analysis_s1 = EnergiesConfidenceIntervalPaired(df_s1, **shared_fit_options)
analysis_s2 = EnergiesConfidenceIntervalPaired(df_s2, **shared_fit_options)
# analysis_s3 / analysis_s4 are disabled (df_s3 / df_s4 are not loaded).

### Plotting functions

In [4]:
def energies_scatter(analysis):
    """Joint plot of the sampled energy estimates with marginal densities.

    Main panel: each sample's w_m (x) against w_d for WT (blue) and L109R
    (orange); the full-data best fit is marked with black points and grey
    dashed guide lines. Top panel: density of w_m. Right panel: densities of
    the WT and L109R w_d samples.

    Parameters
    ----------
    analysis : EnergiesConfidenceIntervalPaired
        Reads ``wms`` (w_m samples), ``wds`` (pair of w_d sample arrays, WT
        first), ``wm_full`` and ``wd_full`` (best-fit values, WT first).
        NOTE(review): samples are presumably the resamples used for the
        confidence interval — confirm against dimer_model_fit.

    Returns
    -------
    matplotlib.figure.Figure
    """
    # Initiate figure
    fig = plt.figure(figsize=(4, 4))
    gs = fig.add_gridspec(2, 2, width_ratios=(7, 2), height_ratios=(2, 7),
                          left=0.2, right=0.8, bottom=0.2, top=0.8,
                          wspace=0.1, hspace=0.1)
    ax = fig.add_subplot(gs[1, 0])
    ax_histx = fig.add_subplot(gs[0, 0])
    ax_histy = fig.add_subplot(gs[1, 1])

    wd_wt_samples, wd_mut_samples = analysis.wds[0], analysis.wds[1]

    # Axis limits (auto), padded around the sampled values
    xmin = np.min(analysis.wms) - 0.1
    xmax = np.max(analysis.wms) + 0.1
    ymin = min(np.min(wd_wt_samples), np.min(wd_mut_samples)) - 0.5
    ymax = max(np.max(wd_wt_samples), np.max(wd_mut_samples)) + 0.5

    # Scatter
    ax.scatter(analysis.wms, wd_wt_samples, s=0.1, c='tab:blue')
    ax.scatter(analysis.wms, wd_mut_samples, s=0.1, c='tab:orange')
    ax.set_xlabel(r'$w_m \: / \: RT$')
    ax.set_ylabel(r'$w_d \: / \: RT$')
    ax.tick_params(axis='both', labelsize=8)
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)

    # Best fit parameters. The vertical guide runs up to the higher of the two
    # best-fit points so it reaches both, whichever genotype has the larger w_d.
    wm_best = analysis.wm_full
    wd_best_wt, wd_best_mut = analysis.wd_full[0], analysis.wd_full[1]
    ax.plot([xmin, wm_best], [wd_best_wt, wd_best_wt], c='0.8', linestyle='--')
    ax.plot([xmin, wm_best], [wd_best_mut, wd_best_mut], c='0.8', linestyle='--')
    ax.plot([wm_best, wm_best], [ymin, max(wd_best_wt, wd_best_mut)], c='0.8', linestyle='--')
    ax.scatter(wm_best, wd_best_wt, c='k', edgecolors='w', zorder=10)
    ax.scatter(wm_best, wd_best_mut, c='k', edgecolors='w', zorder=10)

    # Scatter - legend (empty scatters give larger legend markers than s=0.1)
    ax.scatter([], [], c='tab:blue', s=10, label='PAR-2 (WT)')
    ax.scatter([], [], c='tab:orange', s=10, label='PAR-2 (L109R)')
    ax.legend(fontsize=6, frameon=False)

    # wm density
    _fill_kde(ax_histx, analysis.wms, 'tab:gray')
    ax_histx.tick_params(axis='both', labelsize=8)
    ax_histx.set_xlim(xmin, xmax)
    ax_histx.set_xticks([])
    ax_histx.set_yticks([])
    ax_histx.set_ylabel('Density', fontsize=8)
    ax_histx.set_ylim(bottom=0)

    # wd densities (values along the y-axis)
    _fill_kde(ax_histy, wd_wt_samples, 'tab:blue', vertical=True)
    _fill_kde(ax_histy, wd_mut_samples, 'tab:orange', vertical=True)
    ax_histy.tick_params(axis='both', labelsize=8)
    ax_histy.set_ylim(ymin, ymax)
    ax_histy.set_xticks([])
    ax_histy.set_yticks([])
    ax_histy.set_xlabel('Density', fontsize=8)
    ax_histy.set_xlim(left=0)
    return fig


def _fill_kde(ax, samples, color, vertical=False, npoints=1000):
    """Shade a Gaussian KDE of `samples` on `ax` over the samples' range.

    With ``vertical=True`` the sample values run along the y-axis
    (``fill_betweenx``); otherwise along the x-axis (``fill_between``).
    Degenerate samples (all values equal, including a single sample) are drawn
    as a single line at that value instead: ``gaussian_kde`` cannot handle
    zero-variance data (singular covariance), which can happen when a parameter
    is held fixed in every fit.
    """
    values = np.asarray(samples, dtype=float)
    lo, hi = values.min(), values.max()
    if hi == lo:
        if vertical:
            ax.axhline(lo, color=color, alpha=0.5)
        else:
            ax.axvline(lo, color=color, alpha=0.5)
        return
    grid = np.linspace(lo, hi, npoints)
    density = st.gaussian_kde(values)(grid)
    if vertical:
        ax.fill_betweenx(grid, 0, density, color=color, alpha=0.5)
    else:
        ax.fill_between(grid, 0, density, color=color, alpha=0.5)


class random_grouped_scatter:
    """Collect coloured groups of points and draw them in a random interleaved order.

    Shuffling the points before drawing mixes the groups, so no group is
    systematically drawn on top of the others.

    Parameters
    ----------
    linewidth, edgecolors, s
        Forwarded to ``ax.scatter`` for the points.
    seed : int or None, optional
        Seed for a private ``random.Random`` used to shuffle, for reproducible
        figures. ``None`` (default) uses the module-level ``random`` state, as
        before, so a global ``random.seed(...)`` still controls the order.
    """

    def __init__(self, linewidth=0.1, edgecolors='k', s=20, seed=None):
        self.points = []
        self.linewidth = linewidth
        self.edgecolors = edgecolors
        self.s = s
        # The `random` module and `random.Random` instances share `shuffle`.
        self._rng = random if seed is None else random.Random(seed)

    def add(self, x, y, color):
        """Queue the points (x[i], y[i]), all to be drawn with `color`."""
        self.points.extend({'x': _x, 'y': _y, 'color': color} for _x, _y in zip(x, y))

    def plot(self, ax):
        """Shuffle the queued points in place and draw them on `ax`.

        All points go into a single scatter call. A single collection is drawn
        in array order, so the shuffled layering is preserved without creating
        one artist per point. Does nothing when no points were added.
        """
        if not self.points:
            return
        self._rng.shuffle(self.points)
        ax.scatter([p['x'] for p in self.points], [p['y'] for p in self.points],
                   linewidth=self.linewidth, edgecolors=self.edgecolors, s=self.s,
                   color=[p['color'] for p in self.points])
    
    
def rundown_plot(analysis, wd_wt_source='AUC', wd_mut_source='fit', wm_source='fit'):
    """Membrane vs cytoplasmic concentration scatter with the fitted dimer model.

    Main panel:
      * measured (cytoplasmic, membrane) points, WT in blue and L109R in orange;
      * points whose ``unipol`` label is not 'Pol' use a lightened shade;
      * the two model curves (``res_x`` / ``res_y``) in black.
    Top panel: model % dimer in the cytoplasm. Right panel: model % dimer on
    the membrane.
    Axes use the ``fake_log`` formatter. NOTE(review): the data are presumably
    log10 concentrations — confirm.

    Parameters
    ----------
    analysis : EnergiesConfidenceIntervalPaired
        Reads cyts, mems, l109r (0 = WT, 1 = L109R), unipol, res_x, res_y,
        cyt_dim, mem_dim, wd_full and wm_full.
    wd_wt_source, wd_mut_source, wm_source : str, optional
        Provenance tags shown in the text box. The defaults ('AUC', 'fit',
        'fit') correspond to fix_wt=True, fix_mut=False; pass other values when
        the fit is configured differently.

    Returns
    -------
    matplotlib.figure.Figure
    """
    # Initiate figure: main panel plus top/right % dimer panels
    fig = plt.figure(figsize=(4, 4))
    gs = fig.add_gridspec(2, 2, width_ratios=(9, 2), height_ratios=(2, 9),
                          left=0.2, right=0.8, bottom=0.2, top=0.8,
                          wspace=0.1, hspace=0.1)
    ax = fig.add_subplot(gs[1, 0])
    ax_dimx = fig.add_subplot(gs[0, 0])
    ax_dimy = fig.add_subplot(gs[1, 1])

    # Scatter: one group per (genotype, 'Pol' label), queued in a fixed order
    # and then drawn shuffled so no group hides another.
    is_pol = np.array([label == 'Pol' for label in analysis.unipol])
    scatter = random_grouped_scatter()
    for genotype, color in ((0, 'tab:blue'), (1, 'tab:orange')):
        is_genotype = analysis.l109r == genotype
        for mask, shade in ((is_genotype & ~is_pol, lighten(color)),
                            (is_genotype & is_pol, color)):
            scatter.add(analysis.cyts[mask], analysis.mems[mask], shade)
    scatter.plot(ax)

    # Plot model (confidence bands from analysis.all_fits_lower/upper were
    # previously drawn here with fill_between; re-add if needed)
    ax.plot(analysis.res_x[0], analysis.res_y[0], c='k', linewidth=1)
    ax.plot(analysis.res_x[1], analysis.res_y[1], c='k', linewidth=1)
    ax.tick_params(axis='both', labelsize=8)

    # Light diagonal (slope 1) guide lines and ticks
    for offset in np.arange(-10, 10, 0.5):
        ax.plot([-10, 10], [offset - 10, offset + 10], c='0.9', zorder=-100, linewidth=1)
    ax.set_xticks(np.arange(-10, 0))
    ax.set_yticks(np.arange(-10, 0))
    minor_ticks(ax, [-10, 0], [-10, 0])
    ax.xaxis.set_major_formatter(fake_log)
    ax.yaxis.set_major_formatter(fake_log)
    ax.set_xlabel('Cytoplasmic concentration (M)', fontsize=8)
    ax.set_ylabel('Membrane concentration (M)', fontsize=8)

    # Set axis limits to the data range. This must come after the guide lines
    # and ticks above, which expand the view.
    ax.set_ylim(min(analysis.mems) - 0.1, max(analysis.mems) + 0.2)
    ax.set_xlim(min(analysis.cyts) - 0.1, max(analysis.cyts) + 0.15)
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()

    # Scatter - legend
    ax.scatter([], [], c='tab:blue', edgecolors='k', s=20, label='PAR-2 (WT)', linewidth=0.1)
    ax.scatter([], [], c='tab:orange', edgecolors='k', s=20, label='PAR-2 (L109R)', linewidth=0.1)
    ax.legend(fontsize=6, frameon=False, loc='upper left')

    # Cytoplasmic dimer fraction (shares the main panel's x-range)
    ax_dimx.plot(analysis.res_x[0], analysis.cyt_dim[0], c='tab:blue')
    ax_dimx.plot(analysis.res_x[1], analysis.cyt_dim[1], c='tab:orange')
    ax_dimx.set_ylim(-5, 105)
    ax_dimx.set_xlim(*xlim)
    ax_dimx.set_xticks([])
    ax_dimx.set_ylabel('% Dimer\n(cyt)', fontsize=8)
    ax_dimx.tick_params(axis='both', labelsize=8)

    # Membrane dimer fraction (shares the main panel's y-range)
    ax_dimy.plot(analysis.mem_dim[0], analysis.res_y[0], c='tab:blue')
    ax_dimy.plot(analysis.mem_dim[1], analysis.res_y[1], c='tab:orange')
    ax_dimy.set_xlim(-5, 105)
    ax_dimy.set_ylim(*ylim)
    ax_dimy.set_yticks([])
    ax_dimy.set_xlabel('% Dimer\n(mem)', fontsize=8)
    ax_dimy.tick_params(axis='both', labelsize=8)

    # Text box with the best-fit energies and where each value came from
    s1 = r'$w_d$' + f' (wt) = {analysis.wd_full[0]:.2f} ({wd_wt_source})'
    s2 = r'$w_d$' + f' (L109R) = {analysis.wd_full[1]:.2f} ({wd_mut_source})'
    s3 = r'$w_m$' + f' = {analysis.wm_full:.2f} ({wm_source})'
    s = s1 + '\n' + s2 + '\n' + s3
    ax.text(x=0.4, y=0.05, s=s, transform=ax.transAxes, horizontalalignment='left', verticalalignment='bottom',
            fontsize=6, bbox=dict(facecolor='w', edgecolor='k', linewidth=0.5))

    return fig

In [5]:
# Optional diagnostic (disabled): joint scatter of the w_m / w_d estimates with
# marginal densities, one figure per dataset. Uncomment lines to render.
# analysis_s3 / analysis_s4 are not computed in the analysis cell (those calls are
# commented out there), so their lines would raise NameError if uncommented as-is.
# NOTE(review): with fix_wt=True the WT w_d estimates may all be identical, which
# scipy's gaussian_kde cannot fit — check energies_scatter handles that first.
# fig = energies_scatter(analysis_main)
# fig = energies_scatter(analysis_s1)
# fig = energies_scatter(analysis_s2)
# fig = energies_scatter(analysis_s3)
# fig = energies_scatter(analysis_s4)

In [6]:
# rundown_plot shuffles the scatter points with the global `random` state; seed
# it so the saved figures are identical from run to run.
PLOT_SEED = 0
random.seed(PLOT_SEED)

# savefig does not create missing directories.
os.makedirs('Figs', exist_ok=True)

for analysis, out_path in ((analysis_main, 'Figs/model_fit_main.png'),
                           (analysis_s1, 'Figs/model_fit_s1.png'),
                           (analysis_s2, 'Figs/model_fit_s2.png')):
    fig = rundown_plot(analysis)
    fig.savefig(out_path, dpi=600, transparent=True)

# analysis_s3 / analysis_s4 are not computed (see the analysis cell).

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>