In [59]:
import functools
import math
import os
import time as time

import numpy as np
import matplotlib.pyplot as plt 
import scipy.linalg as sclin

from numba import njit
from numba.typed import Dict
from numba.core import types

from pocs.File_manage import read_write as rw 

In [60]:
def measure_time(func):
    """Decorator that times one call of ``func``.

    The returned wrapper takes a single mapping ``param``; its items are
    forwarded to ``func`` as keyword arguments. The elapsed time in seconds
    is printed and the wrapper returns ``(elapsed_seconds, result)``.
    """
    # wraps() keeps func's __name__/__doc__ on the wrapper (useful in tracebacks/help()).
    @functools.wraps(func)
    def decorated(param):
        # perf_counter is monotonic and high-resolution; time.time() can jump with clock changes.
        start = time.perf_counter()
        res = func(**param)
        dt = time.perf_counter() - start
        print(dt)
        return dt, res

    return decorated
# 

In [None]:
plot_dir    = "plots/"
data_dir    = "data/"
# savefig does not create missing directories, so make sure both output folders exist.
os.makedirs(plot_dir, exist_ok=True)
os.makedirs(data_dir, exist_ok=True)

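## Chemical master equation

Time evolution of the probability $P(m,p)$ of having $m$ mRNA and $p$ protein molecules:

$$
\frac{dP(m,p)}{dt} =
  \alpha_R \left[ P(m-1,p) - P(m,p) \right]
+ \alpha_P \left[ P(m,p-1) - P(m,p) \right]
+ d_R \left[ (m+1)\,P(m+1,p) - m\,P(m,p) \right]
+ d_P \left[ (p+1)\,P(m,p+1) - p\,P(m,p) \right]
$$

In the code below, $\alpha_R = \alpha$, $\alpha_P = \beta\gamma m$ (protein production proportional to the mRNA count), $d_R = \gamma$ and $d_P = 1$, i.e. all rates are expressed relative to the protein decay rate.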




In [61]:
@njit
def delta(a, b):
    """Kronecker delta: return 1 when ``a == b`` and 0 otherwise."""
    return 1 if a == b else 0


In [62]:
@njit
def element_fun(i, j, pmax, mmax, alpha, beta, gamma):
    """Return entry (i, j) of the truncated master-equation matrix.

    Flat indices encode states as ``index = m * (pmax + 1) + p``. Row ``i``
    is the state (m, p) whose probability changes; column ``j`` is the state
    (m', p') the contribution comes from. Four processes contribute:

    * mRNA production at rate ``alpha`` (m-1 -> m),
    * protein production at rate ``beta * gamma * m`` (p-1 -> p),
    * mRNA decay at rate ``gamma`` per molecule (m+1 -> m),
    * protein decay at rate 1 per molecule (p+1 -> p).

    ``mmax`` is accepted for interface symmetry but not used here.
    """
    n_protein = pmax + 1
    # decode target (row) and source (column) states
    m, p = i // n_protein, i % n_protein
    m_src, p_src = j // n_protein, j % n_protein

    # 1 only on the diagonal (same state), used for the outflow terms
    same_state = delta(m, m_src) * delta(p, p_src)

    transcription = alpha * (delta(m - 1, m_src) * delta(p, p_src) - same_state)
    translation = beta * gamma * m * (delta(m, m_src) * delta(p - 1, p_src) - same_state)
    mrna_decay = gamma * ((m + 1) * delta(m + 1, m_src) * delta(p, p_src) - m * same_state)
    protein_decay = (p + 1) * delta(m, m_src) * delta(p + 1, p_src) - p * same_state

    # summed left-to-right in the same order as the individual rate terms
    return transcription + translation + mrna_decay + protein_decay

@njit
def fill_matrix_FB(mat_len, pmax, mmax, mat, alpha, beta, gamma):
    """Overwrite every entry of the square array ``mat`` with ``element_fun``.

    ``mat`` must be at least ``mat_len x mat_len``; it is modified in place
    and also returned. ``pmax``, ``mmax``, ``alpha``, ``beta`` and ``gamma``
    are forwarded unchanged to ``element_fun``.
    """
    for row in range(mat_len):
        for col in range(mat_len):
            mat[row, col] = element_fun(row, col, pmax, mmax, alpha, beta, gamma)
    return mat

@njit
def matrixNormalization(mat_len, mat):
    """Shift each diagonal entry so that its column sums to zero.

    For every column ``k < mat_len`` the current column total is subtracted
    from ``mat[k, k]``. Each column only touches its own diagonal entry, so
    the processing order does not matter. ``mat`` is modified in place and
    returned.
    """
    for col in range(mat_len):
        # total includes the current diagonal value, computed before the update
        column_total = sum(mat[:, col])
        mat[col, col] -= column_total
    return mat

In [63]:
# Model parameters, unpacked in this order as (alpha, beta, gamma) into fill_matrix_FB.
para_list = [20, 2.5, 10]

pmax = 199  # protein copy numbers 0..pmax are kept in the truncated state space
mmax = 9    # mRNA copy numbers 0..mmax are kept
mat_len = (pmax + 1) * (mmax + 1)  # one row/column per (m, p) state
M = np.zeros((mat_len, mat_len), dtype=float)
I = np.eye(mat_len)  # identity matrix (not used in this cell)

M = fill_matrix_FB(mat_len, pmax, mmax, M, *para_list)
M = matrixNormalization(mat_len, M)
# After normalization every column of M sums to zero (total probability is conserved).
assert np.allclose(M.sum(axis=0), 0.0)

eig_val, eig_vec = np.linalg.eig(M)
# Sort descending (complex values sort by real part first). Columns of M sum to
# zero, so the leading eigenvalue is ~0 and its eigenvector is the stationary mode.
arg_sort_eig = np.flip(eig_val.argsort())

N_MODES = 20  # number of leading modes to plot
title_params = f"alpha {para_list[0]} beta {para_list[1]} gamma {para_list[2]}"


def _plot_marginal(values, title, xlabel, path):
    """Line-plot a 1-D marginal distribution and save it to ``path``."""
    fig, ax = plt.subplots()
    ax.plot(values, marker="o")
    ax.grid()
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    fig.savefig(path, dpi=300)
    plt.close(fig)  # avoid accumulating open figures inside the loop


prob_eq = np.zeros((mmax + 1, pmax + 1))
for idx, eig_idx in enumerate(arg_sort_eig[:N_MODES]):
    max_vec = eig_vec[:, eig_idx]  # eigenvectors are the columns of eig_vec
    eig_value = eig_val[eig_idx]

    if idx == 0:
        # Stationary mode: scale entries to sum to 1 so prob_eq is a probability distribution.
        vec_norm = max_vec.sum()
    else:
        # Since the columns of M sum to zero, every mode with a non-zero eigenvalue
        # sums to ~0; dividing by that sum only amplifies round-off. Scale by the
        # largest-magnitude entry instead.
        vec_norm = max_vec[np.argmax(np.abs(max_vec))]
    # State index i = m * (pmax + 1) + p, so a C-order reshape yields prob_eq[m, p].
    # np.real makes the (otherwise silent) drop of any imaginary part explicit.
    prob_eq = np.real(max_vec / vec_norm).reshape(mmax + 1, pmax + 1)

    # np.round also handles complex eigenvalues, which the builtin round may not.
    lam_label = np.round(-eig_value, 4)

    fig, ax = plt.subplots()
    image = ax.imshow(prob_eq, aspect=10, origin="lower")
    fig.colorbar(image, ax=ax)
    ax.set_title(f"Lambda = {lam_label}, \n {title_params}")
    ax.set_xlabel("Protein #")
    ax.set_ylabel("mRNA #")
    fig.savefig(plot_dir + f"Sol_P_lambda{idx:03d}.png", dpi=300)
    plt.close(fig)

    # prob_eq is indexed [m, p]: axis=1 sums out proteins, axis=0 sums out mRNA.
    _plot_marginal(prob_eq.sum(axis=1),
                   f"mRNA Lambda = {lam_label}, \n {title_params}",
                   "mRNA #",
                   plot_dir + f"Sol_P_lambda{idx:03d}_mRNA.png")
    _plot_marginal(prob_eq.sum(axis=0),
                   f"Protein Lambda = {lam_label}, \n {title_params}",
                   "Protein #",
                   plot_dir + f"Sol_P_lambda{idx:03d}_Protein.png")

# np.linalg.eig returns eigenpairs in no particular order and eigenvectors as
# *columns*, so select the leading (stationary) pair through arg_sort_eig.
lead_idx = arg_sort_eig[0]
file_stem = f"p_{pmax}_m_{mmax}_para_{para_list[0]}_{para_list[1]}_{para_list[2]}"
rw.pickle_dump(data_dir, f"{file_stem}_eig.bin", eig_val[lead_idx])
rw.pickle_dump(data_dir, f"{file_stem}_vec.bin", eig_vec[:, lead_idx])




<Figure size 432x288 with 0 Axes>