# In [1]:
import numpy as np
import matplotlib.pyplot as plt


# In [2]:
# Baseline parameter set. The notes describe how each key is consumed by the
# Runner / NormalFormRunner classes defined below in this file.
parameters_ = dict(
    SCALE=75,           # simulation time units per year (tmax_years * SCALE)
    NSS=5000,           # enters the quorum-sensing term 1 - NSS / mean(V)
    CLONE_SIZE=1,       # number of cells added per arrival event
    ARRIVAL_RATE=1500,  # arrival probability per step is DT * ARRIVAL_RATE / SCALE
    NOISE_SIGMA=0.04,   # noise amplitude passed to euler_step
    DT=0.01,            # Euler time step
    ID_RANGE=100000,    # clone ids are drawn from range(ID_RANGE)
    MU_STD=0.017,       # not read by the code in this section
    X0=-2,              # initial x of every newly loaded cell
    EXIT_THRESH=0.1,    # cells with x above this count as exited
    MU_STD_POP=0.022,   # default mu_std used by run_competition
)

# Mouse variant: same keys, with a longer time scale, a finer time step and
# its own arrival rate and NSS.
MOUSE_SCALE = 5
parameters_mouse = dict(parameters_)
parameters_mouse.update(
    SCALE=parameters_["SCALE"] * MOUSE_SCALE,
    DT=parameters_["DT"] / 10,
    ARRIVAL_RATE=12000,
    NSS=2000,
)


class Runner(object):
    """Time-stepping driver shared by the population simulations.

    ``Runner`` keeps a private copy of the parameter dictionary and implements
    the two simulation loops. It does not define the population model itself:
    a subclass must provide ``load_population``, ``euler_step``,
    ``get_exit_idx`` and ``load_hist_inst``.
    """

    def __init__(self,parameters):
        """Store a copy of ``parameters`` so the caller's dict is not shared."""
        self.parameters = parameters.copy()

    @staticmethod
    def _hits_sampling_grid(t,period,tol):
        """Return True when ``t`` is within ``tol`` of a multiple of ``period``.

        Testing ``t % period < tol`` alone misses grid points: a float that
        should be an exact multiple can land just *below* it
        (``0.3 % 0.1`` evaluates to ``0.0999...``), so the distance to the
        next multiple is checked as well.
        """
        remainder = t % period
        return bool(remainder < tol or period - remainder < tol)

    def get_distirbution_of_exit_times_sampled(self,mu_mean,mu_std,N = 10000,tmax_years = 20,SAMPLE_RATE = .1):
        """Estimate a survival curve, resampling exited cells from survivors.

        ``N`` cells are loaded with ``load_population(N, mu_mean, mu_std)`` and
        stepped with ``euler_step``. Every ``SAMPLE_RATE`` simulation time
        units the exited cells are replaced by copies of randomly chosen
        survivors and the running survival factor is multiplied by the
        fraction of cells that had not exited.

        Args:
            mu_mean, mu_std: forwarded to ``load_population``.
            N: number of cells simulated.
            tmax_years: duration; multiplied by ``SCALE`` to get simulation time.
            SAMPLE_RATE: sampling period in simulation time units.

        Returns:
            ``(t_out, surv)``: float arrays holding the sample times divided
            by ``SCALE`` and the cumulative survival factor at those times.

        Raises:
            ValueError: from ``np.random.choice`` when every cell has exited
                at a sampling point, since no survivor is left to copy.
        """
        scale = self.parameters["SCALE"]
        dt = self.parameters["DT"]
        sigma = self.parameters["NOISE_SIGMA"]
        tmax = tmax_years * scale
        state = self.load_population(N,mu_mean,mu_std)
        # Plain lists: np.append inside the loop would copy the whole array
        # on every sample.
        times = []
        survival = []
        surv_fact = 1
        for t in np.arange(0,tmax,dt):
            state = self.euler_step(dt,state,sigma)
            if not self._hits_sampling_grid(t,SAMPLE_RATE,0.0001):
                continue
            idx = self.get_exit_idx(state)
            if idx.any():
                # Refill the exited slots with copies of survivors so the
                # population size stays at N.
                state[idx] = np.random.choice(state[~idx],idx.sum())
                surv_fact = surv_fact*(1-idx.mean())
            times.append(t/scale)
            survival.append(surv_fact)
        return np.array(times,dtype=float),np.array(survival,dtype=float)

    def run_competition(self,N_INIT=10,tmax_years = 100,
                        mu_std=None,first_run=True):
        """Simulate a population with clone arrivals and removal on exit.

        Each step the state is advanced with ``euler_step(...,
        quorum_sensing=True)``, exited cells are dropped, the ``age`` field is
        advanced by ``DT / SCALE`` and, with probability
        ``DT * ARRIVAL_RATE / SCALE``, ``CLONE_SIZE`` new cells are appended
        (at most one arrival event per step). A history row from
        ``load_hist_inst`` is recorded whenever the simulation time is an
        integer.

        Args:
            N_INIT: initial number of cells (used only when ``first_run``).
            tmax_years: duration of this call; multiplied by ``SCALE``.
            mu_std: spread passed to ``load_population``; ``None`` selects
                ``parameters["MU_STD_POP"]``.
            first_run: when False, resume from the ``state``, ``hist`` and
                ``t_`` saved on the instance by a previous call.

        Returns:
            ``(t_, 0, hist, state)``: the time grid, a constant ``0``, the
            accumulated history rows and the final population.
        """
        if mu_std is None:
            mu_std=self.parameters["MU_STD_POP"]
        NSS=self.parameters["NSS"]
        dt = self.parameters["DT"]
        scale = self.parameters["SCALE"]
        sigma = self.parameters["NOISE_SIGMA"]
        clone_size = self.parameters["CLONE_SIZE"]
        age_step = dt/scale
        arrival_prob = dt*self.parameters["ARRIVAL_RATE"]/scale
        tmax = tmax_years*scale
        if first_run:
            state = self.load_population(N_INIT,0,mu_std)
            hist = np.array([])
            t_=np.arange(0,tmax,dt)
            t_start = 0
        else:
            state = self.state
            hist = self.hist
            t_start = self.t_.max()+dt
            t_=np.arange(0,self.t_.max()+tmax,dt)

        # Half-step tolerance: t_start is a rounded float sum, so an exact
        # ">=" could drop or repeat the first step of a resumed run.
        pending = t_[t_>=t_start-0.5*dt]

        new_rows = []
        for t in pending:
            if self._hits_sampling_grid(t,1,1e-9):
                new_rows.append(self.load_hist_inst(t,state,NSS))

            state = self.euler_step(dt,state,sigma,quorum_sensing=True)
            idx = self.get_exit_idx(state)
            if idx.any():
                state = state[~idx]

            state["age"]+=age_step
            if np.random.rand() < arrival_prob:
                state = np.append(state,self.load_population(clone_size,0,mu_std,V=state["V"].mean()))

        if new_rows:
            # Join once at the end instead of re-copying the history per row.
            chunks = new_rows if hist.shape[0]==0 else [hist]+new_rows
            hist = np.concatenate([np.ravel(chunk) for chunk in chunks])

        # Saved so a later call with first_run=False can continue this run.
        self.state = state.copy()
        self.hist = hist.copy()
        self.t_ = t_.copy()

        return t_,0,hist,state

class NormalFormRunner(Runner):
    """Concrete ``Runner`` whose cells follow ``dx/dt = x**2 + mu`` plus noise.

    The population is a 1-D NumPy structured array with the fields ``x``,
    ``V``, ``age``, ``mu`` and ``id``.
    """

    def euler_step(self,dt,state,sigma,quorum_sensing=False):
        """Advance ``state`` in place by one Euler-Maruyama step and return it.

        With ``quorum_sensing`` the drift uses the population-adjusted ``mu``
        from ``quorum_sensing_mu``; otherwise each cell's own ``mu`` is used.
        ``V`` relaxes towards the number of cells that had not exited before
        the step.
        """
        if quorum_sensing:
            drift_mu = self.quorum_sensing_mu(state, self.parameters["NSS"])
        else:
            drift_mu = state["mu"]
        # Counted before x is updated, so it reflects the previous step.
        n_active = (~self.get_exit_idx(state)).sum()
        noise = np.random.normal(size=state.shape[0])
        state["x"] += dt*(np.power(state["x"],2)+drift_mu) + np.sqrt(dt)*sigma*noise
        state["V"] += dt*(n_active-state["V"])
        # Park exited cells at a fixed value just past the threshold
        # (presumably to keep x**2 from growing without bound).
        exited = self.get_exit_idx(state)
        state["x"][exited] = self.parameters["EXIT_THRESH"]+1
        return state

    def load_population(self,N,mu_mean,mu_std,V=1.0):
        """Create ``N`` cells of a single clone.

        All cells start at ``x = X0`` with age 0 and the given ``V``; they
        share one id drawn from ``range(ID_RANGE)`` and each gets its own
        ``mu`` drawn from a normal distribution ``(mu_mean, mu_std)``.
        """
        x_init = self.parameters["X0"]
        # The id is drawn before the mu values, one draw per clone.
        clone_id = np.random.choice(self.parameters["ID_RANGE"])
        cells = np.zeros(N,dtype=[('x','float'),('V','float'),('age','float'),('mu','float'),("id",'int')])
        cells["x"] = x_init
        cells["V"] = V
        cells["mu"] = [np.random.normal(mu_mean,mu_std) for _ in range(N)]
        cells["id"] = clone_id
        return cells

    def get_exit_idx(self,state):
        """Return a boolean mask of the cells whose ``x`` exceeds ``EXIT_THRESH``."""
        threshold = self.parameters["EXIT_THRESH"]
        return np.greater(state["x"],threshold)

    def load_hist_inst(self,t,state,NSS):
        """Build a one-row structured array summarising ``state`` at time ``t``.

        The row holds the time divided by ``SCALE``, mean ``V``, population
        size, number of distinct ids, mean and std of the quorum-adjusted
        ``mu``, the quorum term alone, and the mean of the raw ``mu``.
        """
        row_dtype = [('t','float'),('V','float'),('pop_size','int'),('repertoire_size','int'),
                     ('mu_with_adj','float'),('mu_no_adj','float'),('std_mu','float'),('pop_mu','float')]
        adjusted_mu = self.quorum_sensing_mu(state,NSS)
        row = (
            t/self.parameters["SCALE"],
            state["V"].mean(),
            state.shape[0],
            np.unique(state["id"]).size,
            adjusted_mu.mean(),
            self.quorum_sensing_no_adj(state,NSS),
            adjusted_mu.std(),
            state["mu"].mean(),
        )
        return np.array([row],dtype=row_dtype)

    def quorum_sensing_no_adj(self,state,NSS):
        """Return the population-level term ``1 - NSS / mean(V)``."""
        mean_v = state["V"].mean()
        return 1 - NSS/mean_v

    def quorum_sensing_mu(self,state,NSS):
        """Return each cell's ``mu`` shifted by the population-level term."""
        shift = self.quorum_sensing_no_adj(state,NSS)
        return shift+state["mu"]
    
def plotExitTimesSampled(ax,t_out,surv,tmin=0,tmax=100,c='r',lw=3,norm=True,label=None,ls='solid'):
    """Plot a sampled survival curve on ``ax``.

    The samples ``(t_out, surv)`` are linearly interpolated onto 1000 evenly
    spaced points spanning ``t_out``, restricted to ``[tmin, tmax]``.

    Args:
        ax: object with a matplotlib-style ``plot`` method.
        t_out: sample times.
        surv: survival values at ``t_out``.
        tmin, tmax: window of times to draw.
        c, lw, ls, label: colour, line width, line style and legend label,
            forwarded to ``ax.plot``.
        norm: when True, the log of the curve is rescaled to ``[0, 1]`` over
            the window and the plotted value is ``exp(scaled) / e``, so the
            curve runs from ``1/e`` to ``1``. A curve that is constant over
            the window makes this rescaling divide by zero.
    """
    t_out_scaled = np.asarray(t_out,dtype=float)
    surv = np.asarray(surv,dtype=float)
    x = np.linspace(t_out_scaled.min(),t_out_scaled.max(),1000)
    x=x[(x>=tmin) & (x<=tmax)]
    # np.interp needs increasing sample points; sorting keeps unsorted input
    # working without requiring scipy's interp1d.
    order = np.argsort(t_out_scaled,kind="stable")
    y = np.interp(x,t_out_scaled[order],surv[order])
    if norm:
        log_y = np.log(y)
        y = np.exp((log_y-log_y.min())/(log_y.max()-log_y.min()))/np.exp(1)
    ax.plot(x,y,c=c,lw=lw,ls=ls,label=label)

    
def get_timescale_monostable(parameters,N=1000):
    """Return the mean of ``get_exit_times_unsampled(1, 0, N)`` for ``parameters``.

    NOTE(review): ``get_exit_times_unsampled`` is not defined in this section
    of the file; the arguments are presumably ``mu_mean=1``, ``mu_std=0`` and
    ``N`` cells -- confirm against its definition.
    """
    runner = NormalFormRunner(parameters)
    exit_times = runner.get_exit_times_unsampled(1,0,N)
    return exit_times.mean()

def get_timescale_bistable(parameters,N=1000):
    """Return the mean of ``get_exit_times_unsampled(0, 0, N)`` for ``parameters``.

    NOTE(review): ``get_exit_times_unsampled`` is not defined in this section
    of the file; the arguments are presumably ``mu_mean=0``, ``mu_std=0`` and
    ``N`` cells -- confirm against its definition.
    """
    runner = NormalFormRunner(parameters)
    exit_times = runner.get_exit_times_unsampled(0,0,N)
    return exit_times.mean()