# PDE Modelling Mimic of 190413 experiment

## Objectives 
* A microscope movie recorded on 190413 observed the growth of a constitutively fluorescent strain on agar pads of varying size and cell occupancy
* Simulate each of the experimental setups executed
* eventually build into a fitting routine

## Model considerations 
* Species
    1. Cell density (sender and pulse cells)
    1. Nutrient density 
    1. mScarlet
* Reactions 
    1. cell growth and diffusion 
        * Cells diffuse very slowly
        * nutrient-dependent growth (from Liu et al 2011, Science) 
        $$  $$ 
    1. Constitutive fluorescence
        * Basal protein expression 
        * initial protein concentration set to fixed point of max nutrient
    1. Dilution and degradation 
        * Assume that all proteins are degradation tagged
    1. Diffusion 
        * Diffusion is computed by convolving each species with the discrete diffusion (Laplacian) kernel
        * Diffusion in/out of cell is considered faster than spatial diffusion at these scales
    1. Parameters
        * We are also assuming, for the moment, that each time point is 6 minutes. Parameters with time dimensions shown below may use different units than the parameter from the cited paper.
        * dx: Length modification of diffusion terms. In the compartmental model, diffusion is calculated via Fick's first law: the amount exchanged between two adjacent compartments per unit time equals the flux multiplied by the area $A$ of the interface between them. Dividing by the compartment volume $V$ gives the rate of change of concentration:  
        $$ \frac{\mathrm{d} C}{\mathrm{d} t} = D \frac{A}{V} \frac{\Delta C}{\Delta x} = D \frac{2.25 \cdot 5 \cdot \mathrm{scale}^2 \, \mathrm{mm}^2}{\mathrm{scale} \cdot 2.25^2 \cdot 5 \, \mathrm{mm}^3} \frac{\Delta C \cdot \mathrm{scale}}{2.25 \, \mathrm{mm}} = \frac{D \, \Delta C \, \mathrm{scale}^2}{2.25^2 \, \mathrm{mm}^2} $$  
        The dx parameter below is the resulting geometric factor $\mathrm{scale}^2 / (2.25 \, \mathrm{mm})^2$ that multiplies $D \, \Delta C$ (the current `prep_pad_0` uses $(\mathrm{scale}/4.5)^2$ instead).
        * Dc : Diffusion rate for cells. $7\frac{mm^2}{min}$
        * rc : Division rate of cells. $\frac{1.14}{min}$
        * Kn : Half-point of nutrient availability. 75
        * Dn : Diffusion rate of nutrient. $28\frac{mm^2}{min}$
        * kn : Consumption rate of nutrient by cells
        * Da : Diffusion rate of nutrient. $28\frac{mm^2}{min}$
        * xs : Expression rate of protein. 


In [None]:
# imports
from __future__ import division, print_function
import numpy as np
import pandas as pd
import os
import sys
import string
import selenium
import scipy.integrate as itg
import scipy.optimize as opt
import scipy.interpolate as itp
import scipy.ndimage as ndi

import matplotlib as mpl
mpl.use("Agg")
import seaborn as sns
import itertools

import matplotlib.pyplot as plt 
plt.rcParams['animation.ffmpeg_path'] = '/usr/bin/ffmpeg' # Add the path of ffmpeg here!!

import matplotlib.animation as anm
import skimage.measure
import skimage.filters
import numba
import gc

from multiprocessing import Pool, Process

%load_ext line_profiler

import bokeh
from bokeh.plotting import figure, output_file, save
from bokeh.io import output_notebook, show
from bokeh import palettes, transform
from bokeh.models import LogColorMapper, LogTicker, ColorBar, LinearColorMapper, Ticker
output_notebook()

from IPython.display import HTML

# Seaborn "paper" context with enlarged fonts.
# NOTE(review): seaborn's set_context only applies keys that belong to the
# *context* (sizes/widths); 'axes.facecolor' is a style key and is ignored here
# (and would also need a leading '#' to be a valid hex colour) -- confirm intent.
rc = dict([
    ('lines.linewidth', 2),
    ('axes.labelsize', 18),
    ('axes.titlesize', 24),
    ('xtick.labelsize', 18),
    ('ytick.labelsize', 18),
    ('legend.fontsize', 18),
    ('axes.facecolor', 'DFDFE5'),
])
sns.set_context('paper', rc=rc)

%matplotlib inline

## 2D Discrete Laplacian

In continuous form : 
$$ U_t = \nabla^2 U - \lambda U $$

In discrete form, for point $i$, summing over its neighbours $j$ : 
$$ \frac{\mathrm{d} U_i}{\mathrm{d} t} = \sum_{j \,:\, w(i,j) = 1} \omega(i,j)\,(U_j - U_i) - \lambda U_i $$

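Use the discrete Laplacian approximation without diagonal neighbours (the 5-point stencil below). This makes zero-flux boundary conditions simple to impose: at edges and corners, the missing neighbours are simply dropped.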

$$ L = 
 \begin{pmatrix}
  0 & 1 & 0 \\
  1 & -4 & 1 \\
  0 & 1 & 0 
 \end{pmatrix} $$

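The diffusion terms are computed by applying this kernel with array slicing in `calc_diffusion`. This is equivalent to a convolution, but uses reduced stencils at edges and corners to enforce zero flux.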

# Helper functions used to define the arenas 
### Needs
* read excel or csv files 
* rescaling arrays and contents 
* convert row/col to array index


* `disk` function: projects circular areas onto an input grid

In [None]:
# Universal constants
species = 3  # number of state layers in the simulation array
c_i, n_i, s_i = np.arange(species)  # layer indices: cells, nutrients, synthase (mScarlet)
scale_s = 0.5  # module-level default; the prep/sim functions take their own scale_s from dims
rtol = 1e-3  # relative tolerance handed to solve_ivp by sim_omnitray
col_thresh = 0.01  # cell density above which calc_f switches on synthase expression

def disk(A, center, radius):
    """Return a boolean mask of the pixels of ``A`` lying strictly inside a circle.

    Parameters
    ----------
    A : 2-D array
        Only its shape ``(h, w)`` is used; its values are ignored.
    center : pair ``(cx, cy)``
        Circle centre; ``cx`` is the column coordinate, ``cy`` the row coordinate.
    radius : float
        A pixel is included when its squared distance to the centre is
        strictly less than ``radius**2``.

    Returns
    -------
    ndarray of bool, shape ``(h, w)``
    """
    n_rows, n_cols = A.shape
    col_c, row_c = center
    # Squared offsets along each axis, broadcast into an (h, w) distance map.
    dx2 = np.power(np.arange(n_cols, dtype=float) - col_c, 2)
    dy2 = np.power(np.arange(n_rows, dtype=float) - row_c, 2)
    dist2 = dy2[:, np.newaxis] + dx2[np.newaxis, :]
    return dist2 < radius ** 2

# units : L = mm, T = minutes, concentration in nM = moles / mm^3
# Da = 6 - 1.2 E-2
#LEGACY
# Params :    dx,                          Dc,    rc,   Kn,   Dn,   kn,  Da,  xa,  xs,  ha,  ka, 
#p0 = np.array([np.power((scale/2.25),2),   1e-4, 6e-3, 75,  8e-3,  2, 8e-2, 1e3, 2e-0, 2.3, 40,    
          # pa,   leak   od0
#             5e-5, 1e-8, 0.5], dtype=np.float32)



#@numba.jit('void(float32[:,:,:],float32[:,:,:])', nopython=True, cache=True)
@numba.jit(nopython=True, cache=True)
def calc_diffusion(A, D):
    """Write the zero-flux 5-point discrete Laplacian of every layer of A into D.

    A and D are indexed as [layer, row, col] and each layer is processed
    independently. D is overwritten in place; nothing is returned.

    Interior pixels use the full kernel (4 neighbours, centre weight -4).
    Edge pixels drop the missing neighbour (centre weight -3), and corners drop
    two (centre weight -2), so no mass leaves through the boundary.
    NOTE(review): presumably requires each layer to be at least 2x2 (the edge
    slices index row/col 1 and -2) -- confirm.
    """
    # Middle: sum of the 4 neighbours minus 4x the centre
    D[:,1:-1,1:-1] = A[:,1:-1, 2:] + A[:,1:-1, :-2] + A[:,:-2, 1:-1] + A[:,2:, 1:-1] - 4*A[:,1:-1, 1:-1]
    # Edges (top, bottom, left, right): only 3 neighbours exist
    D[:,0,1:-1] = A[:,0, 2:] + A[:,0, :-2] + A[:,1, 1:-1] - 3*A[:,0, 1:-1]
    D[:,-1,1:-1] = A[:,-1, 2:] + A[:,-1, :-2] + A[:,-2, 1:-1] - 3*A[:,-1, 1:-1]
    D[:,1:-1,0] = A[:,2:,0] + A[:,:-2,0] + A[:,1:-1,1] - 3*A[:,1:-1,0]
    D[:,1:-1,-1] = A[:,2:,-1] + A[:,:-2,-1] + A[:,1:-1,-2] - 3*A[:,1:-1,-1]
    # Corners: only 2 neighbours exist
    D[:,0,0] = A[:,0,1] + A[:,1,0] - 2*A[:,0,0]
    D[:,-1,0] = A[:,-1,1] + A[:,-2,0] - 2*A[:,-1,0]
    D[:,0,-1] = A[:,0,-2] + A[:,1,-1] - 2*A[:,0,-1]
    D[:,-1,-1] = A[:,-1,-2] + A[:,-2,-1] - 2*A[:,-1,-1]

@numba.jit(nopython=True, cache=True)
def hill(a, n, k):
    """Activating Hill function ``1 - 1 / (1 + (a/k)**n)``.

    For ``a >= 0`` and ``n > 0`` it rises from 0 at ``a = 0`` through 0.5 at
    ``a = k`` towards 1; ``n`` controls the steepness. Operates element-wise
    on arrays as well as on scalars.
    """
    ratio = (a / k) ** n
    return 1 - (1 / (1 + ratio))

@numba.jit(nopython=True, cache=True)
def hillN(a, n, k):
    """Repressing Hill function ``1 / (1 + (a/k)**n)``, i.e. ``1 - hill(a, n, k)``.

    Equals 1 at ``a = 0`` and 0.5 at ``a = k``. Operates element-wise on
    arrays as well as on scalars.
    """
    ratio = (a / k) ** n
    return 1 / (1 + ratio)

# NOTE(review): @numba.jit without nopython relied on object-mode fallback in
# older numba releases; confirm this compiles under the installed numba.
@numba.jit(cache=True)
def calc_f(y, d_y, diff_terms, nut_avail, p0):
    """Evaluate the model right-hand side on the grid, writing dy/dt into ``d_y``.

    Parameters
    ----------
    y : ndarray, shape (species, n_h, n_w)
        Current state; layers are indexed by the module-level ``c_i`` (cells),
        ``n_i`` (nutrient) and ``s_i`` (synthase).
    d_y : ndarray, same shape as ``y``
        Output buffer for the time derivatives (overwritten).
    diff_terms : ndarray, same shape as ``y``
        Scratch buffer that receives the discrete Laplacian of every layer.
    nut_avail : ndarray, shape (n_h, n_w)
        Scratch buffer that receives the nutrient-limitation Hill term.
    p0 : sequence of 11 parameters
        ``(dx, Dc, Dn, rc, Kn, Hn, pn, xs, ps, leak, od)``; ``leak`` and ``od``
        are unpacked but not used here.
    """
    dx, Dc, Dn, rc, Kn, Hn, pn, xs, ps, leak, od = p0
    calc_diffusion(y, diff_terms)

    # Nutrient-limited growth factor. The adaptive solver can overshoot nutrient
    # slightly below zero; clip it so a fractional Hill coefficient does not turn
    # (a/k)**n into NaN, which would then spread across the grid via diffusion.
    nut_avail[:] = hill(np.maximum(y[n_i, :, :], 0), Hn, Kn)

    # Cell growth and diffusion
    d_y[c_i, :, :] = dx * Dc * diff_terms[c_i, :, :] + rc * nut_avail * y[c_i, :, :]

    # Nutrient diffusion and consumption by the cells at each pixel.
    # y[c_i] is already a single 2-D layer, so it is used directly (no reduction
    # over axis 0, which would mix together all rows of a column).
    d_y[n_i, :, :] = dx * Dn * diff_terms[n_i, :, :] - pn * nut_avail * y[c_i, :, :]

    # Synthase: expressed where cells exceed col_thresh, diluted by growth,
    # and degraded at rate ps.
    d_y[s_i, :, :] = (xs * np.greater(y[c_i, :, :], col_thresh) - rc * y[s_i, :, :]) * nut_avail - ps * y[s_i, :, :]
    

@numba.jit
def f_ivp(t, y, d_y, diff_terms, nut_avail, p0, dims):
    """Right-hand side handed (through a lambda) to ``solve_ivp``.

    Views the flat state ``y`` as ``(species, n_h, n_w)``, evaluates
    ``calc_f`` into the scratch buffer ``d_y`` and returns a flat copy of it.
    ``t`` is unused because ``calc_f`` has no explicit time dependence.
    NOTE(review): relies on numba's object-mode fallback (plain @numba.jit);
    confirm it compiles under the installed numba.
    """
    species, n_h, n_w, scale, scale_s = dims
    # Reshape into a new array object instead of assigning y.shape, so the
    # caller's (possibly solver-owned) array is never modified.
    y = np.reshape(y, (species, n_h, n_w))
    calc_f(y, d_y, diff_terms, nut_avail, p0)

    # flatten() copies: d_y is a buffer reused by every call, so returning a
    # view would let the next evaluation overwrite this result.
    return d_y.flatten()

def prep_initial_condition(cell_spots, dims, p0, A):
    """Add circular cell inoculations to the cell layer of ``A``.

    Parameters
    ----------
    cell_spots : iterable of ``(cx, cy)``
        Spot centres passed to ``disk`` (``cx`` = column, ``cy`` = row).
    dims : sequence ``(species, n_h, n_w, scale, scale_s)``
        ``scale_s`` is used as the spot radius.
    p0 : sequence of 11 parameters
        Only the last entry, ``od0`` (cell density inside each spot), is used.
    A : ndarray, shape (species, n_h, n_w)
        State array; its cell layer is incremented in place.

    Returns
    -------
    The same array ``A`` after the cells have been added. Overlapping spots
    add their densities.
    """
    (_dx, _Dc, _Dn, _rc, _Kn, _Hn, _pn, _xs, _ps, _leak, od0) = p0
    (_species, n_h, n_w, _scale, spot_radius) = dims

    inoculum = np.zeros((n_h, n_w), dtype=np.float32)
    for spot in cell_spots:
        inoculum += disk(inoculum, np.array(spot), spot_radius) * od0

    A[c_i, :, :] += inoculum
    return A

def sim_omnitray(dims, p0, tmax, initial_array, atol):
    """Integrate the model from t = 0 to ``tmax`` with scipy's RK23 solver.

    Parameters
    ----------
    dims : sequence
        ``(species, n_h, n_w, scale, scale_s)``.
    p0 : array
        The 11 parameters in the order unpacked by ``calc_f``.
    tmax : float
        End of the integration interval.
    initial_array : ndarray
        Initial state with ``species * n_h * n_w`` elements in
        ``(species, n_h, n_w)`` layout. It is NOT modified (the previous
        version flattened the caller's array in place).
    atol : float or ndarray
        Absolute tolerance(s) passed to ``solve_ivp``; ``rtol`` is the
        module-level value.

    Returns
    -------
    The ``solve_ivp`` result; ``out.y`` has shape ``(species * n_h * n_w, len(out.t))``.
    """
    species, n_h, n_w, scale, scale_s = dims
    grid_shape = (species, n_h, n_w)

    # Scratch buffers reused by every RHS evaluation: d_y, diff_terms, nut_avail.
    args = (np.zeros(grid_shape, dtype=np.float32, order='C'),
            np.zeros(grid_shape, dtype=np.float32, order='C'),
            np.zeros(grid_shape[1:], dtype=np.float32, order='C'),
            p0, dims)

    # Flat initial state without touching the caller's array.
    y0 = np.reshape(initial_array, species * n_h * n_w)

    f_lambda = lambda t, y: f_ivp(t, y, *args)
    # Keep vectorized=True: per scipy's OdeSolver, single-point evaluations are
    # then passed a fresh y[:, None] view rather than the solver's own state array.
    out = itg.solve_ivp(f_lambda, [0, tmax], y0, vectorized=True, method='RK23',
                        atol=atol, rtol=rtol)
    return out


def wrapper(dims, p0, initial_array, tmax, atol):
    """Run ``sim_omnitray`` and reshape its output into per-time-point grids.

    Parameters are forwarded unchanged to ``sim_omnitray``.

    Returns
    -------
    exp_y : ndarray, shape (len(exp_t), species, n_h, n_w)
        State at every time point chosen by the adaptive solver.
    exp_t : ndarray
        The corresponding solver time points.

    ``solve_ivp`` does not raise when integration fails; it returns the partial
    trajectory. A ``RuntimeWarning`` is now issued in that case instead of
    failing silently, and the partial trajectory is still returned.
    """
    import warnings

    species, n_h, n_w, scale, scale_s = dims
    out = sim_omnitray(dims, p0, initial_array=initial_array, tmax=tmax, atol=atol)
    if not out.success:
        warnings.warn('solve_ivp failed: {}'.format(out.message), RuntimeWarning)

    exp_t = out.t
    # out.y is (n_states, n_times): put time first, then split the state axis.
    exp_y = out.y.T.reshape(len(exp_t), species, n_h, n_w)
    return exp_y, exp_t

def prep_pad_0(p0, scale,
               mask_path='../../../microscope-movie-analysis/worker_outputs/masks/pad_0_frame_10.tif'):
    """Build the simulation inputs for pad 0 from its segmented cell mask.

    Parameters
    ----------
    p0 : sequence of 10 parameters
        ``(Dc, Dn, rc, Kn, Hn, pn, xs, ps, leak, od0)``.
    scale : int
        Grid resolution. Sets the mask rescaling factor, the diffusion
        factor ``dx = (scale / 4.5)**2`` and ``scale_s = scale // 8``.
    mask_path : str
        Path to the segmented mask image (defaults to pad 0, frame 10).

    Returns
    -------
    dims : list ``[species, n_h, n_w, scale, scale_s]``
    p0 : ndarray of 11 parameters (``dx`` prepended), in ``calc_f`` order
    initial_array : float32 ndarray, shape (species, n_h, n_w)
    tmax : int, default integration end time (1000)
    atol : float32 ndarray, flat, per-species absolute tolerances
    """
    # Only skimage.measure / skimage.filters are imported at file level, so
    # bring the submodules used here into scope explicitly.
    import skimage.io
    import skimage.transform

    Dc, Dn, rc, Kn, Hn, pn, xs, ps, leak, od0 = p0

    # Pad size definition
    scale_s = scale // 8

    # Diffusion geometry factor; prepend it to the parameter vector
    dx = np.power((scale / 4.5), 2)
    p0 = np.array([dx, Dc, Dn, rc, Kn, Hn, pn, xs, ps, leak, od0])

    init_cells = skimage.io.imread(mask_path)
    # NOTE(review): the constants presumably convert the mask's pixel size to
    # grid pixels (4500 ~ pad width, 2.475/4 ~ image pixel size) -- confirm.
    scale_factor = (scale / 4500) / (2.475 / 4)
    # np.float was removed in NumPy 1.24 (builtin float is the same float64).
    # multichannel= is deprecated since scikit-image 0.19; for a 2-D mask the
    # default (single channel) is the same behaviour.
    # NOTE(review): mask labels > 1 are treated as cells -- confirm what label 1 denotes.
    scaled_init = skimage.transform.rescale((init_cells > 1).astype(float),
                                            scale_factor,
                                            order=1,
                                            mode='constant',
                                            cval=0) > 0

    n_h, n_w = scaled_init.shape
    tmax = 1000
    species = 3  # cells, nutrients, mscarlet
    dims = [species, n_h, n_w, scale, scale_s]
    c_i, n_i, s_i = np.arange(species)
    # NOTE(review): calc_f reads the module-level col_thresh; a local
    # assignment here (previously col_thresh = 0.05) has no effect.

    # Per-species absolute tolerances, flattened to match the solver state
    atol = np.zeros((species, n_h, n_w), dtype=np.float32, order='C')
    atol[c_i, :, :] = 1e-5
    atol[n_i, :, :] = 1e-2
    atol[s_i, :, :] = 1e-4
    atol.shape = species * n_h * n_w

    # Initial conditions: nutrients everywhere at 100, cells where the mask is set.
    # NOTE(review): cells start at density 1.0; od0 is passed on in p0 but not
    # used for seeding here -- confirm this is intended.
    initial_array = np.zeros((species, n_h, n_w), dtype=np.float32, order='C')
    initial_array[n_i, :, :] = 100
    initial_array[c_i, :, :] = scaled_init

    return dims, p0, initial_array, tmax, atol

def sim_pad(p0, prep_fn, tmax=None):
    """Prepare and run a pad simulation.

    Parameters
    ----------
    p0 : sequence
        Raw parameter vector, passed to ``prep_fn``.
    prep_fn : callable
        ``p0 -> (dims, p0, initial_array, tmax, atol)``.
    tmax : float or None
        End time of the integration. ``None`` uses the ``tmax`` returned by
        ``prep_fn``, which is what every call effectively got before. An
        explicit value was previously overwritten by ``prep_fn``'s value and
        is now honoured.

    Returns
    -------
    ``(exp_y, exp_t)`` as returned by ``wrapper``.
    """
    dims, p0, initial_array, prep_tmax, atol = prep_fn(p0)
    if tmax is None:
        tmax = prep_tmax
    return wrapper(dims, p0, initial_array, tmax, atol)

    

In [None]:
# Simulation grid resolution passed to prep_pad_0.
# (np.int was removed in NumPy 1.24; a plain Python int is all that is needed.)
scale = 2 ** 8
# Bind the current scale as a default argument. A plain closure would read the
# global `scale` at call time, so a later cell that reassigns `scale` would
# silently change the simulated grid.
fn_prep0 = lambda p0, scale=scale: prep_pad_0(p0, scale)

# Params : Dc, Dn, rc, Kn, Hn, pn, xs, ps, leak, od0
p0 = np.array([1e-5, 3e-2, 8e-3, 15, 2.5, 1, 1e-8, 8e-3, 5e2, 0.1], dtype=np.float32)
out = sim_pad(p0, fn_prep0, tmax=60)


In [None]:

    
def write_movie(im_arr, t_vec, skip=1, n_frames=200):
    """Animate the cell, nutrient and synthase layers of a simulation side by side.

    Parameters
    ----------
    im_arr : ndarray, shape (T, species, h, w)
        Simulation states, e.g. the first element returned by ``wrapper``.
    t_vec : ndarray, shape (T,)
        Solver output times for the first axis of ``im_arr``; assumed increasing.
    skip : int
        Unused; kept so existing calls keep working.
    n_frames : int
        Number of animation frames, sampled uniformly in time over
        ``[0, t_vec.max()]`` (previously 200 was hard-coded here, so
        ``n_frames > 200`` raised IndexError).

    Returns
    -------
    matplotlib.animation.FuncAnimation
        The figure is closed before returning; render it with e.g.
        ``anim.to_html5_video()``.
    """
    plt.close('all')
    _, _, h, w = im_arr.shape
    blank_array = np.zeros([h, w])
    fig, axs = plt.subplots(1, 3, figsize=(10, 7))

    # Fixed colour limits for the whole movie. Cell limits use the final frame only.
    cells_last = im_arr[-1, c_i, :, :]
    nutrients = im_arr[:, n_i, :, :]
    # NOTE(review): synthase limits only use the right half (w//2:) of the grid;
    # the left half may saturate -- confirm this is intended.
    synthase_right = im_arr[:, s_i, :, w // 2:]
    panels = [
        (c_i, cells_last.min(), cells_last.max(), 'cell densities'),
        (n_i, nutrients.min(), nutrients.max(), 'nutrient'),
        (s_i, synthase_right.min(), synthase_right.max(), 'synthases'),
    ]

    im_list = []
    for ax, (_, vmin, vmax, title) in zip(axs, panels):
        im = ax.imshow(blank_array, animated=True, vmax=vmax, vmin=vmin,
                       interpolation='none', aspect=1)
        fig.colorbar(im, ax=ax, ticks=[vmin, vmax])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(title)
        im_list.append(im)
    layer_inds = [panel[0] for panel in panels]

    # For each animation time, show the last solver frame at or before it.
    # Clip at 0 so the first frame is never index -1 (which would show the final state).
    t_points = np.linspace(0, t_vec.max(), n_frames)
    f_inds = np.clip(np.searchsorted(t_vec, t_points, side='right') - 1,
                     0, len(t_vec) - 1)

    def animate(frame):
        """Update every panel to the solver frame mapped to animation frame ``frame``."""
        i = f_inds[frame]
        for im, layer in zip(im_list, layer_inds):
            im.set_array(im_arr[i, layer, :, :])

    anim = anm.FuncAnimation(fig, animate, interval=50, frames=n_frames)
    fig.tight_layout()
    plt.close('all')
    return anim
#     HTML(anim.to_html5_video())


In [None]:
# Render the (exp_y, exp_t) pair returned by sim_pad as an inline HTML5 video.
anim = write_movie(*out)
HTML(anim.to_html5_video())

In [None]:
# Binary map of where cells are present at the ninth stored solver time (index 8).
plt.imshow(np.greater(out[0][8, c_i], 0))

In [None]:
# Inspect how the pad-0 mask survives rescaling to a 2**9 simulation grid.
# Only skimage.measure / skimage.filters are imported at file level.
import skimage.io
import skimage.transform

init_cells = skimage.io.imread('../../../microscope-movie-analysis/worker_outputs/masks/pad_0_frame_10.tif') > 1
scale = np.power(2, 9)
scale_factor = (scale / 4500) / (2.475 / 4)
# np.float was removed in NumPy 1.24 (builtin float is the same float64);
# multichannel= is deprecated since scikit-image 0.19 and unnecessary for a 2-D mask.
scaled_init = skimage.transform.rescale(init_cells.astype(float),
                                        scale_factor,
                                        order=1,
                                        mode='constant',
                                        preserve_range=True,
                                        clip=True,
                                        cval=0)

plt.figure(figsize=(10, 10))
plt.imshow(init_cells)
plt.figure(figsize=(10, 10))
plt.imshow(scaled_init > 0)
plt.figure()
_ = plt.hist(scaled_init.flatten(), bins=20)
