# Analysis of [Exp Name]
## Performed [Date]

* Evernote experimental note:
* Evernote analysis note with graphs from this notebook:
* The goal of this notebook is to analyze shift experiments. 

### Load modules

In [None]:
from __future__ import division

# import modules
import sys
import os
# from copy import copy
from pprint import pprint # for human readable file output
import cPickle as pickle
import numpy as np
import scipy.stats as sps
import scipy
from scipy.optimize import least_squares, curve_fit
import pandas as pd
pd.options.display.float_format = '{:,.3f}'.format
from IPython.display import display, HTML

from random import sample

# load mm3 modules, always reload the plotting modules
%load_ext autoreload
%autoreload 2
sys.path.insert(0, '../../mm3/') # path to mm3 folder
import mm3_helpers as mm3
%aimport mm3_plots

# plotting modules and settings. 
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
# %matplotlib
import seaborn as sns
sns.set(style="ticks", color_codes=True, font_scale=1.25)

## Functions for fitting and plotting

### Load experiment specific files
* Parameters, specs, and cell data

In [None]:
# Initialise the mm3 helpers from the experiment parameter file. The returned
# dict is used further down for the pixel-to-micron factor (params['pxl2um']).
param_file_path = './params.yaml'
params = mm3.init_mm3_helpers(param_file_path)

# Pickle files are binary data and must be opened with 'rb'. Text mode ('r')
# corrupts binary-protocol pickles on platforms that translate line endings.
# NOTE(review): pickle.load can execute arbitrary code, so only load files
# that were produced by your own pipeline.

# load specs file
with open('../specs.pkl', 'rb') as specs_file:
    specs = pickle.load(specs_file)

# load cell data dict
with open('./complete_cells.pkl', 'rb') as cell_file:
    Complete_Cells = pickle.load(cell_file)

In [None]:
# Experiment-specific values. The right-hand sides are intentionally blank in
# this template: fill in all three before running the cells below.

# picture taking interval in minutes; later cells multiply timepoints by this
# value to express them in minutes
time_int = 

# shift time, given as the timepoint, not the actual minute
# (cells that plot in minutes use shift_t * time_int)
shift_t = 

# used for coloring single traces; passed to mm3_plots.saw_tooth_plot
tif_width = 

### Create directory for plots

In [None]:
# Switch read by the cells that only write figures on request.
save_plots = False

# Every figure is written under this folder; create it on the first run.
plot_dir = './plots/'
if not os.path.exists(plot_dir):
    os.makedirs(plot_dir)

## Filter cells
* We want to limit the cells to those with continuous lineages
* This also means we will only be looking at mother cells
* Hopefully there is no need for additional filtering, but we should ensure that the segmentations are good (we can look at the lineage maps)

In [None]:
# Time window handed to find_continuous_lineages: only lineages that are
# continuous between t1 and t2 are kept.
t1 = 100
t2 = 1000

# Start from the cells carrying birth label 1, then group them by channel and
# keep the continuous lineages. Optional extra filters are left commented out.
Cells = mm3_plots.find_cells_of_birth_label(Complete_Cells, label_num=1)
# Cells = mm3_plots.find_cells_born_after(Cells, born_after=75)
# Cells = mm3_plots.find_cells_born_before(Cells, born_before=t2)
# Cells = mm3_plots.filter_by_stat(Cells, center_stat='mean', std_distance=3)
Lineages = mm3_plots.find_continuous_lineages(
    mm3_plots.organize_cells_by_channel(Cells, specs), t1=t1, t2=t2)

# size of the label-filtered cell dict (before the lineage filter)
n = len(Cells)
print('There are {} Cells'.format(n))

# Lineages is nested as fov -> peak -> {cell_id: cell}; count each level
n_fovs = len(Lineages)
n_peaks = sum(len(fov_peaks) for fov_peaks in Lineages.values())
n_cells = sum(len(lineage)
              for fov_peaks in Lineages.values()
              for lineage in fov_peaks.values())

print('There are {} FOVs, {} channels, and {} mother cells.'.format(n_fovs, n_peaks, n_cells))

### Save filtered cells pickle

In [None]:
# with open(os.path.join(params['cell_dir'],'mother_cells.pkl'), 'wb') as cell_file:
#     pickle.dump(Cells, cell_file, protocol=pickle.HIGHEST_PROTOCOL)
    
# with open(os.path.join(params['cell_dir'],'continuous_lineages.pkl'), 'wb') as cell_file:
#     pickle.dump(Lineages, cell_file, protocol=pickle.HIGHEST_PROTOCOL)

## Plot individual continuous lineages

In [None]:
# Draw every continuous lineage as a saw-tooth trace; saw_tooth_plot returns
# the figure and an iterable of axes.
fig, ax = mm3_plots.saw_tooth_plot(Lineages, tif_width=tif_width, mothers=True)

# mark the shift time on every axis and clip to the lineage time window
for axis in ax:
    axis.axvline(x=shift_t, linewidth=2, color='g', ls='--', alpha=0.5, label='Shift-up time')
    axis.set_xlim([t1, t2])

fig.suptitle('All Continuous Lineage Traces', size=26)
fig.subplots_adjust(top=0.96)

plt.show()

# Follow the notebook-wide save_plots switch rather than a hard-coded
# `if False:`, which made this save unreachable.
if save_plots:
    fig.savefig(os.path.join(plot_dir, 'single_traces'), dpi=200)

### Plot individual lineages with single fit lines

In [None]:
mothers = True

sns.set(style="ticks", palette="pastel", color_codes=True, font_scale=1.25)

# One figure per FOV with one row per channel (peak). Each row shows the length
# traces of a lineage on a log y axis with the single-line fit of every cell
# overlaid, and the fit R^2 of each cell on a twin y axis.
for fov, peaks in Lineages.iteritems():

    fig, axes = plt.subplots(ncols=1, nrows=len(peaks), figsize=(10, len(peaks)),
                             sharex=True)

    # plt.subplots returns a bare Axes for one row and an array otherwise;
    # np.ravel gives an array in both cases that can be indexed and iterated
    # more than once.
    ax = np.ravel(axes)
    ax_second = []  # handles of the twin (R^2) axes

    # running extremes over all cells of this FOV
    max_div_length = 0
    min_birth_length = 10

    for trace_ax, (peak, lin) in zip(ax, peaks.iteritems()):
        # remembered from the previous cell so mother and daughter can be joined
        last_div_time = None
        last_length = None

        # per-cell values for the twin axis
        lin_lambdas = []
        lin_r_squared = []
        lin_plot_times = []  # x positions for the values above

        # walk the cells of this lineage in order of birth time
        for cell_id, cell in sorted(lin.iteritems(), key=lambda x: x[1].birth_time):
            trace_ax.semilogy(cell.times_w_div, cell.lengths_w_div,
                              color='b', lw=3, alpha=0.9)

            if mothers:
                # draw a connecting line between mother and daughter
                if cell.birth_time == last_div_time:
                    trace_ax.semilogy([last_div_time, cell.birth_time],
                                      [last_length, cell.sb],
                                      color='b', lw=3, alpha=0.9)

                # record the division time for the next cell
                last_div_time = cell.division_time

            # track the largest division length (used for the y limit)
            last_length = cell.sd
            max_div_length = max(max_div_length, last_length)

            # track the smallest birth length
            last_birth_length = cell.sb
            min_birth_length = min(min_birth_length, last_birth_length)

            # overlay the fit line for this cell
            y_fit, r_squared = mm3_plots.produce_fit(cell)
            trace_ax.semilogy(cell.times_w_div, y_fit,
                              color='blue', ls='--', lw=1, alpha=1)

            # record values for the twin axis, placed at the cell's mid-point time
            lin_lambdas.append(cell.elong_rate)
            lin_r_squared.append(r_squared)
            lin_plot_times.append(cell.birth_time + int(cell.tau / 2))

        # R^2 per cell on a second y axis
        second_axis = trace_ax.twinx()
        second_axis.plot(lin_plot_times, lin_r_squared, lw=0.5, ls='-', color='red', alpha=0.75)
        ax_second.append(second_axis)

    # hide ticks and labels on the cell size axes
    for axis in ax:
        axis.set_ylim([1, max_div_length])
        axis.axes.get_xaxis().set_visible(False)
        axis.axes.get_yaxis().set_visible(False)

    # hide ticks and labels on the R^2 axes
    for axis in ax_second:
        axis.set_ylim([0.95, 1])
        axis.axes.get_yaxis().set_visible(False)

    # only the bottom row keeps an R^2 scale and the time axis
    ax_second[-1].axes.get_yaxis().set_visible(True)
    ax_second[-1].set_yticks([0.95, 0.975, 1.0])

    ax[-1].axes.get_xaxis().set_visible(True)
    ax[-1].set_xlabel('Time [min]', size=16)

    sns.despine(left=True)

    # tight_layout is best effort: report the FOV it failed on and carry on.
    # Catch Exception rather than everything, so Ctrl-C still interrupts.
    try:
        plt.tight_layout()
    except Exception:
        print(fov)
    plt.subplots_adjust(hspace=0)
    fig.suptitle('FOV %d' % fov, size=20)

    plt.show()
    if save_plots:
        fig.savefig(os.path.join(plot_dir, 'shift_traces_fov_%d.png' % fov), dpi=100)

## Distribution of linear fit r_squared values
* Want to find r_squared values of all cells around this shift up time
* Then we can create a threshold for fits that are particularly bad

In [None]:
# Gather the R^2 of the single-line fit for every cell of every lineage.
all_linear_r2 = []
for fov_peaks in Lineages.itervalues():
    for lineage in fov_peaks.itervalues():
        for lineage_cell in lineage.itervalues():
            # produce_fit returns (fit values, R^2); only the R^2 is kept
            all_linear_r2.append(mm3_plots.produce_fit(lineage_cell)[1])

sns.set(style="ticks", palette="pastel", color_codes=True, font_scale=1.25)
kde_options = {'lw' : 2, 'linestyle' : '--', 'color' : 'b'}
hist_options = {'histtype' : 'step', 'lw' : 2, 'color' : 'b'}

# single panel, wrapped in an array so it is addressed as ax[0]
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=[10, 5])
ax = np.ravel(axes)

sns.distplot(all_linear_r2, ax=ax[0], bins=100, norm_hist=True,
             hist_kws=hist_options, kde_kws=kde_options)

# vertical markers for the centre of the distribution
r2_mean = np.mean(all_linear_r2)
r2_median = np.median(all_linear_r2)
ax[0].axvline(x=r2_mean, linewidth=1, color='blue',
              label='Mean = {:.3f}'.format(r2_mean))
ax[0].axvline(x=r2_median, linewidth=1, color='red',
              label='Median = {:.3f}'.format(r2_median))

ax[0].legend(fontsize=16)

plt.tight_layout()
sns.despine()
plt.subplots_adjust(top=0.9, bottom=0.15)
ax[0].set_xlim([0.9, 1.05])
ax[0].set_title('Distribution of Linear Fit $R^2$ Values', fontsize=24)
ax[0].set_xlabel('$R^2$', size=20)

plt.show()
fig.savefig(plot_dir + 'Lin_R_squared_distributions', dpi=100)

### Plot individual lineages with bilinear fit lines

In [None]:
mothers = True
# Not read in this cell; kept because it sets the notebook-level tif_width
# that the saw_tooth_plot cell passes on.
tif_width = 2560

sns.set(style="ticks", palette="pastel", color_codes=True, font_scale=1.25)

# Bilinear shift times of all cells. Initialised once, outside the FOV loop,
# so that it really accumulates over every FOV instead of only the last one.
t_shift_times = []

# One figure per FOV with one row per channel (peak). Each row shows the length
# traces of a lineage on a log y axis with the bilinear fit of every cell
# overlaid, and the linear and bilinear R^2 of each cell on a twin y axis.
for fov, peaks in Lineages.iteritems():

    fig, axes = plt.subplots(ncols=1, nrows=len(peaks), figsize=(10, 1.5*len(peaks)),
                             sharex=True)

    # plt.subplots returns a bare Axes for one row and an array otherwise;
    # np.ravel gives an array in both cases that can be indexed and iterated
    # more than once.
    ax = np.ravel(axes)
    ax_second = []  # handles of the twin (R^2) axes

    # running extremes over all cells of this FOV (used for the y limits)
    max_div_length = 0
    min_birth_length = 10

    for trace_ax, (peak, lin) in zip(ax, peaks.iteritems()):
        # remembered from the previous cell so mother and daughter can be joined
        last_div_time = None
        last_length = None

        # per-cell values for the twin axis
        lin_lambdas = []
        lin_r_squared = []
        lin_plot_times = []  # x positions for the values above
        bilin_r_squared = []

        # walk the cells of this lineage in order of birth time
        for cell_id, cell in sorted(lin.iteritems(), key=lambda x: x[1].birth_time):
            trace_ax.semilogy(cell.times_w_div, cell.lengths_w_div,
                              color='b', lw=3, alpha=0.9)

            if mothers:
                # draw a connecting line between mother and daughter
                if cell.birth_time == last_div_time:
                    trace_ax.semilogy([last_div_time, cell.birth_time],
                                      [last_length, cell.sb],
                                      color='b', lw=3, alpha=0.9)

                # record the division time for the next cell
                last_div_time = cell.division_time

            # track the largest division length
            last_length = cell.sd
            max_div_length = max(max_div_length, last_length)

            # track the smallest birth length
            last_birth_length = cell.sb
            min_birth_length = min(min_birth_length, last_birth_length)

            # single-line fit: only its R^2 is used here, for comparison
            linear_fit, r_squared = mm3_plots.produce_fit(cell)

            # bilinear fit line, its R^2 and the fitted break point
            y_fit, r_squared_bilin, t_shift, len_at_shift = mm3_plots.produce_bilin_fit(cell)
            bilin_r_squared.append(r_squared_bilin)
            t_shift_times.append(t_shift)

            # draw the bilinear fit
            trace_ax.semilogy(cell.times_w_div, y_fit,
                              color='blue', ls='--', lw=1, alpha=1)
            # mark the fitted shift time with a circle
            trace_ax.plot(t_shift, len_at_shift, 'o', color='blue', ms=5)

            # record values for the twin axis, placed at the cell's mid-point time
            lin_lambdas.append(cell.elong_rate)
            lin_r_squared.append(r_squared)
            lin_plot_times.append(cell.birth_time + int(cell.tau / 2))

        # both R^2 series per cell on a second y axis
        second_axis = trace_ax.twinx()
        second_axis.plot(lin_plot_times, lin_r_squared, lw=0.5, ls='-', color='red', alpha=0.75,
                         label='Linear $R^2$')
        second_axis.plot(lin_plot_times, bilin_r_squared, lw=0.5, ls='-', color='blue', alpha=0.75,
                         label='Bilinear $R^2$')
        ax_second.append(second_axis)

    # hide ticks and labels on the cell size axes
    for axis in ax:
        axis.set_ylim([min_birth_length, max_div_length])
        axis.axes.get_xaxis().set_visible(False)
        axis.axes.get_yaxis().set_visible(False)

    # hide ticks and labels on the R^2 axes
    for axis in ax_second:
        axis.set_ylim([0.95, 1])
        axis.axes.get_yaxis().set_visible(False)

    # only the bottom row keeps an R^2 scale, the legend and the time axis
    ax_second[-1].axes.get_yaxis().set_visible(True)
    ax_second[-1].set_yticks([0.95, 0.975, 1.0])

    ax_second[-1].legend(loc='lower right')
    ax[-1].axes.get_xaxis().set_visible(True)
    ax[-1].set_xlabel('Time [min]', size=16)

    sns.despine(left=True)

    plt.tight_layout()
    plt.subplots_adjust(hspace=0)
    fig.suptitle('FOV %d' % fov, size=24)

    plt.show()
    fig.savefig(plot_dir + 'shift_traces_bilin_fov_%d.png' % fov, dpi=100)

## Plot distribution of t shift times
* This is for all bilinear fits on all cells
* Maybe special times will come out of the noise

### Calculate bilinear fit times

In [None]:
# Recompute the bilinear fit for every cell of every continuous lineage and
# collect the fitted shift times and R^2 values for the distribution plots.
# Both lists start empty here. Without initialising bilin_r_squared, the
# appends below would extend whatever list an earlier cell left behind (the
# last lineage plotted) and count those cells twice.
t_shift_times = []
bilin_r_squared = []
for fov, peaks in Lineages.iteritems():
    for peak, lin in peaks.iteritems():
        for cell_id, cell in lin.iteritems():

            # calculate bilinear fit line
            y_fit, r_squared_bilin, t_shift, len_at_shift = mm3_plots.produce_bilin_fit(cell)
            bilin_r_squared.append(r_squared_bilin)
            t_shift_times.append(t_shift)

### Plot it

In [None]:
sns.set(style="ticks", palette="pastel", color_codes=True, font_scale=1.25)
hist_options = {'histtype' : 'step', 'lw' : 2, 'color' : 'b'}

# single panel, wrapped in an array so it is addressed as ax[0]
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=[10, 5])
ax = np.ravel(axes)

# bin edges two units apart, sitting on half-integers across the observed range
first_edge = int(np.min(t_shift_times))
last_edge = int(np.max(t_shift_times)) + 1
bin_edges = np.arange(first_edge, last_edge, 2) + 0.5

# raw counts, no density estimate
sns.distplot(t_shift_times, ax=ax[0], bins=bin_edges, norm_hist=False,
             hist_kws=hist_options, kde=False)

# reference line at the shift time
ax[0].axvline(x=shift_t, linewidth=1, color='g', ls='--')
# ax[0].axvline(x=1005, linewidth=1, color='r') # nominal shift up time

plt.tight_layout()
sns.despine()
plt.subplots_adjust(top=0.9, bottom=0.15, left=0.1)
# ax[0].set_xlim([800, 1250])
ax[0].set_title('Distribution of Bilinear Fit Inflections', fontsize=24)
ax[0].set_ylabel('Count', size=20)
ax[0].set_xlabel('Time [min]', size=20)

plt.show()
fig.savefig(plot_dir + 'bilinear_fit_inflections', dpi=100)

## Distribution of bilinear fit R^2 values

In [None]:
sns.set(style="ticks", palette="pastel", color_codes=True, font_scale=1.25)
kde_options = {'lw' : 2, 'linestyle' : '--', 'color' : 'b'}
hist_options = {'histtype' : 'step', 'lw' : 2, 'color' : 'b'}

# single panel, wrapped in an array so it is addressed as ax[0]
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=[10, 5])
ax = np.ravel(axes)

# normalised histogram plus density estimate of the bilinear-fit R^2 values
sns.distplot(bilin_r_squared, ax=ax[0], bins=100, norm_hist=True,
             hist_kws=hist_options, kde_kws=kde_options)

plt.tight_layout()
sns.despine()
plt.subplots_adjust(top=0.9, bottom=0.15)
ax[0].set_xlim([0.9, 1.05])
ax[0].set_title('Distribution of Bilinear Fit $R^2$ Values', fontsize=24)
ax[0].set_xlabel('$R^2$', size=20)

plt.show()
# fig.savefig(plot_dir + 'bilin_R_squared_distributions', dpi=100)

## Distribution of both linear and bilinear R^2 values

In [None]:
sns.set(style="ticks", palette="pastel", color_codes=True, font_scale=1.25)
kde_options = {'lw' : 2, 'linestyle' : '--', 'color' : 'b'}
hist_options = {'histtype' : 'step', 'lw' : 2, 'color' : 'b'}

# two stacked panels: linear fit on top, bilinear fit below
fig, axes = plt.subplots(nrows=2, ncols=1, figsize=[10, 10])
ax = np.ravel(axes)

# each panel gets the same treatment, only the data and the title differ
r2_panels = [(all_linear_r2, 'Linear Fit $R^2$ Values'),
             (bilin_r_squared, 'Bilinear Fit $R^2$ Values')]
for panel_ax, (r2_values, panel_title) in zip(ax, r2_panels):
    sns.distplot(r2_values, ax=panel_ax, bins=100, norm_hist=True,
                 hist_kws=hist_options, kde_kws=kde_options)

    panel_ax.set_xlim([0.9, 1.05])
    panel_ax.set_title(panel_title, fontsize=20)
    panel_ax.set_xlabel('$R^2$', size=20)

plt.tight_layout()
sns.despine()
plt.subplots_adjust(top=0.9, bottom=0.1)
fig.suptitle('$R^2$ Distributions', size=24)

plt.show()
fig.savefig(plot_dir + 'R_squared_distributions', dpi=100)

## Log-linear Cumulative growth over time

In [None]:
sns.set(style="ticks", palette="pastel", color_codes=True, font_scale=1.25)
fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(12, 12))

ax = np.ravel(axes)

# Maps each time point to the cumulative log-lengths of all lineages at that time.
all_lin_lengths = {}

# Top panel: for every continuous lineage, stitch the log-length traces of its
# cells end to end so that growth accumulates across divisions.
for fov, peaks in Lineages.iteritems():
    for peak, lin in peaks.iteritems():
        # cells of this lineage in order of birth time
        cells_in_order = sorted(lin.itervalues(), key=lambda c: c.birth_time)
        if not cells_in_order:
            continue  # nothing to stitch for an empty lineage

        lin_times = np.array([])
        lin_lengths = np.array([])

        # The running offset lives in log space like everything it is added to,
        # so the trace starts at log(first length). Seeding it with the linear
        # length would shift each lineage by a different constant.
        offset = np.log(cells_in_order[0].lengths[0] * params['pxl2um'])

        for cell in cells_in_order:
            lin_times = np.concatenate((lin_times, cell.times))

            # lengths of this cell in um
            current_lengths = np.array(cell.lengths) * params['pxl2um']

            # remove this cell's own birth length so its trace starts where
            # the previous cell ended
            offset -= np.log(current_lengths[0])
            new_lengths = np.log(current_lengths) + offset

            lin_lengths = np.concatenate((lin_lengths, new_lengths))

            # the next cell continues from the last cumulative value
            offset = lin_lengths[-1]

        # add line to graph
        ax[0].plot(lin_times, lin_lengths, c='b', alpha=0.1)

        # file every point under its time for the average below
        for t, length in zip(lin_times, lin_lengths):
            all_lin_lengths.setdefault(t, []).append(length)

# average over lineages at each time point, in time order
mean_times = sorted(all_lin_lengths)
mean_lens = [np.mean(all_lin_lengths[t]) for t in mean_times]
ax[0].plot(mean_times, mean_lens, color='blue',
           label='Average')

ax[0].set_title('Log-linear Cumulative Growth Over Time', size=22)
ax[0].set_ylabel('Log(Cumulative Length [um])', size=20)
ax[0].legend(loc='upper left', fontsize=16)


# Bottom panel: step-to-step difference of the average trace, plus its mean
# within consecutive bins of width time_window.
mean_lens_d = np.diff(mean_lens)
ax[1].plot(mean_times[:-1], mean_lens_d, alpha=0.9)

time_window = 5
xlims = [min(mean_times[:-1]), max(mean_times[:-1])]
bin_mean, bin_edges, bin_n = sps.binned_statistic(mean_times[:-1], mean_lens_d,
                statistic='mean', bins=np.arange(xlims[0]-1, xlims[1]+1, time_window))
bin_centers = bin_edges[:-1] + np.diff(bin_edges) / 2
ax[1].plot(bin_centers, bin_mean, color='blue', lw=2,
           label='Rolling average %d min window' % time_window)

ax[1].set_ylim([-0.01, 0.1])
ax[1].set_title('Derivative of Mean Cumulative Length', size=22)
ax[1].set_xlabel('Time [min]', size=20)
ax[1].set_ylabel('$d$Log(L)/$dt$', size=20)
ax[1].legend(loc='upper left', fontsize=16)

sns.despine()
plt.show()

fig.savefig(plot_dir + 'cumulative_length', dpi=100)

# Individual cumulative length differentiation

In [None]:
# Holds all cumulative log-lengths by time point and all step-to-step changes
# by time point; the summary lists are added at the end of this cell.
stats_by_time = {'lengths_by_time' : {},
                 'diffs_by_time' : {}}

# Holds times, lengths, differences and the bilinear fit per lineage,
# nested as fov -> peak -> dict.
lineage_lengths = {}

# step used for the numerical difference (loop invariant)
n_diff = 1

for fov, peaks in Lineages.iteritems():

    lineage_lengths[fov] = {}

    for peak, lin in peaks.iteritems():
        # cells of this lineage in order of birth time
        cells_in_order = sorted(lin.itervalues(), key=lambda c: c.birth_time)
        if not cells_in_order:
            continue  # nothing to stitch for an empty lineage

        lin_times = np.array([])
        lin_lengths = np.array([])

        # The running offset lives in log space like everything it is added to,
        # so the trace starts at log(first length). Seeding it with the linear
        # length would shift each lineage by a different constant.
        offset = np.log(cells_in_order[0].lengths[0] * params['pxl2um'])

        # stitch the log-length traces of the cells end to end
        for cell in cells_in_order:
            lin_times = np.concatenate((lin_times, cell.times))

            # lengths of this cell in um
            current_lengths = np.array(cell.lengths) * params['pxl2um']

            # remove this cell's own birth length so its trace starts where
            # the previous cell ended
            offset -= np.log(current_lengths[0])
            new_lengths = np.log(current_lengths) + offset

            lin_lengths = np.concatenate((lin_lengths, new_lengths))

            # the next cell continues from the last cumulative value
            offset = lin_lengths[-1]

        # bilinear fit over the whole stitched trace
        y_fit, r_squared_bilin, t_shift, len_at_shift = mm3_plots.produce_bilin_fit3(lin_times, lin_lengths)

        # n_diff-step difference and the time points it belongs to
        lin_lengths_diff = np.diff(lin_lengths[::n_diff])
        diff_times = lin_times[n_diff::n_diff]

        # time-like entries are multiplied by time_int; the rest is stored as is
        lineage_lengths[fov][peak] = {'times' : lin_times * time_int,
                                      'lengths' : lin_lengths,
                                      'diff' : lin_lengths_diff,
                                      'diff_times' : diff_times * time_int,
                                      'bi_fit' : y_fit,
                                      'bi_tshift' : t_shift * time_int,
                                      'bi_tshift_len' : len_at_shift}

        # file every value under its (unscaled) time point
        for t, length in zip(lin_times, lin_lengths):
            stats_by_time['lengths_by_time'].setdefault(t, []).append(length)

        for t, diff in zip(diff_times, lin_lengths_diff):
            stats_by_time['diffs_by_time'].setdefault(t, []).append(diff)

# Per-timepoint summary statistics. Iterate in time order: a dict has no
# guaranteed order, and these lists are drawn as a line later on.
stats_by_time['all_diff_times'] = []
stats_by_time['diff_means'] = []
stats_by_time['diff_stds'] = []
stats_by_time['diff_SE'] = []
stats_by_time['diff_n'] = []
for t in sorted(stats_by_time['diffs_by_time']):
    values = stats_by_time['diffs_by_time'][t]
    values_std = np.std(values)
    stats_by_time['all_diff_times'].append(t * time_int)
    stats_by_time['diff_means'].append(np.mean(values))
    stats_by_time['diff_stds'].append(values_std)
    stats_by_time['diff_SE'].append(values_std / np.sqrt(len(values)))
    stats_by_time['diff_n'].append(len(values))


## Plot the cumulative growths

In [None]:
sns.set(style="ticks", palette="pastel", color_codes=True, font_scale=1.25)
fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(12, 12))

ax = axes.flat

# Top panel, first pass: cumulative length and its bilinear fit per lineage.
for fov_data in lineage_lengths.values():
    for data in fov_data.values():
        ax[0].plot(data['times'], data['lengths'], c='b', lw=3, alpha=0.4,
                   label='Cumulative length')
        ax[0].plot(data['times'], data['bi_fit'],
                   color='blue', ls='--', lw=2, alpha=0.2,
                   label='Bilinear fit')

# Top panel, second pass: the fitted shift points go on after all the lines.
for fov_data in lineage_lengths.values():
    for data in fov_data.values():
        ax[0].plot(data['bi_tshift'], data['bi_tshift_len'], 'o', color='blue', ms=4)

# Bottom panel: average rate of change with a band of one standard error
# on either side.
rate_times = stats_by_time['all_diff_times']
rate_means = np.array(stats_by_time['diff_means'])
rate_se = np.array(stats_by_time['diff_SE'])
ax[1].plot(rate_times, stats_by_time['diff_means'], c='r', lw=2, alpha=1,
           label='Average rate of change')
ax[1].fill_between(rate_times, rate_means - rate_se, rate_means + rate_se,
                   facecolor='r', alpha=0.25)

# vertical line for the shift-up time on both panels
for panel in (ax[0], ax[1]):
    panel.axvline(x=shift_t * time_int, linewidth=2, color='g', ls='--', alpha=0.5,
                  label='Shift-up time')

# format the top panel
# ax[0].set_xlim([240, 330])
# ax[0].set_ylim([3, 8])
ax[0].set_title('Log-linear Cumulative Growth Over Time (All Lineages)', size=22)
ax[0].set_ylabel('Log(Cumulative Length [um])', size=18)

# format the bottom panel
# ax[1].set_xlim([1000, 1300])
ax[1].set_ylim([0, 0.06])
ax[1].set_title('Average Derivative of Cumulative Growth Across Lineages, with SE', size=22)
ax[1].set_ylabel('Average Numerical Derivative', size=18)
ax[1].set_xlabel('Time [min]', size=18)
ax[1].legend(loc='upper left', fontsize=16)

plt.tight_layout()
fig.suptitle('Exp. 20170824 - BS15 - Shift 001, 40min > 25min', size=24)
plt.subplots_adjust(top=0.92)

plt.show()

# this figure is always written
fig.savefig(plot_dir + 'cumulative_growth_and_rate_of_change_cont_lins', dpi=100)

## Find all instant rate of length changes for all cells
* As opposed to just those cells from continuous lineages

In [None]:
# Holds the step-to-step change in log length by time point, plus the
# per-timepoint summary lists filled in at the end of this cell.
stats_by_time = {'diffs_by_time' : {},
                 'all_diff_times' : [],
                 'diff_means' : [],
                 'diff_stds' : [],
                 'diff_SE' : [],
                 'diff_n' : []}

# step used for the numerical difference; also read by the plotting cell
n_diff = 1

# Loop over every cell in `Cells`, the dict built in the filter cell before
# the continuous-lineage selection. The original referenced `Mother_Cells`,
# a name this notebook never defines.
# NOTE(review): confirm `Cells` is the intended population.
for cell_id, cell in Cells.items():

    # convert lengths to um from pixels and take log
    log_lengths = np.log(np.array(cell.lengths) * params['pxl2um'])

    # numerical n-step difference and the time points it belongs to
    lengths_diff = np.diff(log_lengths[::n_diff])
    diff_times = cell.times[n_diff::n_diff]

    # file every difference under its time point
    for t, diff in zip(diff_times, lengths_diff):
        stats_by_time['diffs_by_time'].setdefault(t, []).append(diff)

# Per-timepoint summary statistics. Iterate in time order: a dict has no
# guaranteed order, and these lists are drawn as a line in the next cell.
for t in sorted(stats_by_time['diffs_by_time']):
    values = stats_by_time['diffs_by_time'][t]
    values_std = np.std(values)
    stats_by_time['all_diff_times'].append(t * time_int)
    stats_by_time['diff_means'].append(np.mean(values))
    stats_by_time['diff_stds'].append(values_std)
    stats_by_time['diff_SE'].append(values_std / np.sqrt(len(values)))
    stats_by_time['diff_n'].append(len(values))


### Plot rate of change over time for all cells

In [None]:
sns.set(style="ticks", palette="pastel", color_codes=True, font_scale=1.25)
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(12, 6))

# The x values in stats_by_time are timepoints multiplied by time_int, so the
# shift time has to be converted the same way before it is used on this axis.
shift_minute = shift_t * time_int

# average rate of change with a band of one standard error on either side
rate_means = np.array(stats_by_time['diff_means'])
rate_se = np.array(stats_by_time['diff_SE'])
ax.plot(stats_by_time['all_diff_times'], stats_by_time['diff_means'], c='r', lw=2, alpha=0.5,
        label='Average rate of change')
ax.fill_between(stats_by_time['all_diff_times'],
                rate_means - rate_se,
                rate_means + rate_se,
                facecolor='r')

# vertical line for the shift-up time
ax.axvline(x=shift_minute, linewidth=2, color='g', ls='--', alpha=0.5, label='Shift-up time')

# Window around the shift, in the same units as the data. The original used
# the raw timepoint here, which moves the window off the shift line whenever
# time_int != 1.
ax.set_xlim([shift_minute - 200, shift_minute + 600])
ax.set_ylim([0.01, 0.04])
ax.set_title('Average Derivative All Cells, with SE', size=22)
# n_diff is the difference step set in the cell that computed stats_by_time
ax.set_ylabel('Average Numerical Derivative (Time Step = {})'.format(n_diff), size=20)
ax.set_xlabel('Time [min]', size=20)
ax.legend(loc='lower right', fontsize=16)

plt.tight_layout()
fig.suptitle('Exp. 20170824 - BS15 - Shift 001, 40min > 25min', size=24)
plt.subplots_adjust(top=0.85)

plt.show()

fig.savefig(plot_dir + 'avg_deriv_BS15_shift_001.png', dpi=100)