In [1]:
import pandas as pd
from matplotlib import pyplot as plt
import datetime
%matplotlib inline
import seaborn as sns
sns.set()
from matplotlib.backends.backend_pdf import PdfPages
#from mpl_toolkits.basemap import Basemap
import matplotlib.cm as cm
import matplotlib.colors as col
import matplotlib as mpl
import numpy as np #import for transposing
import math
import pylab as pl
import math
try:  # SciPy >= 0.19
    from scipy.special import comb, logsumexp
except ImportError:
    from scipy.misc import comb, logsumexp  # noqa 
from sklearn.metrics import mean_squared_error

<b>Function names:</b>
- add_tdate()
- get_week()
- density_interp()
- extract_date()
- nearestDate() <b><i>deprecated</i></b>
- nearestValue() <b><i>deprecated</i></b>
- nearest_df() <b><i>deprecated</i></b>
- qtr_comp()
- new_compare()
- IC_run_sim()
- run_sim()
- stoc_eqs()
- Stoch_Iteration()
- stoch_model()


In [2]:
def smape(A, F):
    """Symmetric mean absolute percentage error between actuals ``A`` and forecasts ``F``.

    Computed as ``100/n * sum(|F - A| / (|A| + |F|))``, so the result lies in [0, 100].

    Pairs where both the actual and the forecast are 0 count as a perfect prediction
    (error 0) instead of producing 0/0 = NaN, which would otherwise poison the whole sum.

    Parameters
    ----------
    A, F : array-like of numbers, same length
        Observed and predicted values (lists, numpy arrays or pandas Series).

    Returns
    -------
    float
        The sMAPE in percent, or NaN for empty input.
    """
    actual = np.asarray(A, dtype=float)
    forecast = np.asarray(F, dtype=float)
    if actual.size == 0:
        return float('NaN')

    numerator = np.abs(forecast - actual)
    denominator = np.abs(actual) + np.abs(forecast)
    ##0/0 (both values zero) -> 0 error; all other pairs divide normally
    ratio = np.divide(numerator, denominator,
                      out=np.zeros_like(numerator), where=denominator != 0)
    return (100/len(actual)) * np.sum(ratio)

In [3]:
def find(condition):
    """Return the flat (raveled) indices where ``condition`` is truthy, as a 1-D integer array."""
    return np.flatnonzero(condition)

In [4]:
##Adds a 'tdate' column that expresses each Date on a fractional "Nov-start year" scale
##(Nov 1 -> 0, Oct 31 -> ~0.99), which matches the time scale used by the model.
##NOTE: tdint 0 is the EARLIEST Nov-start year observed for EACH troop separately,
##so tdint 0 does NOT refer to the same calendar year for every troop.


def add_tdate(df):
    """Add 'tdate', 'tdint' and 'quarter' columns to ``df`` in place and return it.

    - quarter: 1 = Nov-Jan, 2 = Feb-Apr, 3 = May-Jul, 4 = Aug-Oct (month of Date).
    - tdate: (day_of_year - 305)/365 for Nov/Dec dates and (day_of_year - 305 + 365)/365
      for Jan-Oct dates. Day 305 is Nov 1 only in non-leap years; in leap years Nov 1 is
      day 306, so values are offset by 1/365.
    - tdint: index (0, 1, 2, ...) of the Nov-start year a row belongs to. Rows dated before
      Nov 1 of the first year processed get -1.

    Two code paths:

    1. Multi-troop: uses the Troop (string), Date and ObsYear columns. Troop names containing
       '/' are not processed as troops of their own. Rows are selected with
       ``str.contains(troop)``, so fused-troop rows such as 'A/B' are picked up under each
       matching name.
       NOTE(review): str.contains is a regex substring match, so a troop name that is a
       substring of another name also selects that troop's rows.
    2. Single-troop fallback: runs whenever the multi-troop path raises ANY exception, e.g.
       because there is no Troop or ObsYear column. This includes the explicit "error
       updating" exception below, which the outer bare ``except`` swallows. Years are the
       calendar years present in Date, in order of appearance.
       NOTE(review): df may already have been partially updated by the multi-troop path
       when the fallback runs.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a datetime-like 'Date' column.

    Returns
    -------
    pandas.DataFrame
        The same, mutated, df.
    """
    ##ADDS TDATE, TDINT, AND QUARTER COLUMNS
    
    
    ##Functionality for full (multitroop) dataframes
    try:
        ##Raises AttributeError when there is no Troop column -> single-troop fallback below
        troop_names = df.loc[~df.Troop.str.contains('/')].Troop.unique()

        df.loc[:,'tdate'] = 0
        df.loc[:,'tdint'] = 0
        
        def add_qtr(x):
            ##Nov-Jan -> 1, Feb-Apr -> 2, May-Jul -> 3, Aug-Oct -> 4
            if x.month in [11,12,1]:
                return 1
            elif x.month in [2,3,4]:
                return 2
            elif x.month in [5,6,7]:
                return 3
            else:
                return 4

        df.loc[:,'quarter'] = df.loc[:,'Date'].apply(lambda x: add_qtr(x))
        
        

        for troop in troop_names:

            td = df.loc[df.Troop.str.contains(troop)]
            
            #td.loc[:,'tdint'] = td.Date.apply(lambda x: x.year)
            #td.loc[:,'tdint'] -= td.loc[:,'tdint'].iloc[0]
            
            ###ecount identifies first incomplete year
            ecount = 0
            ###Counter assigns tdint value
            counter = 0

            ##NOTE(review): unique() keeps order of appearance, so tdint follows row order,
            ##not necessarily chronological order, if td is unsorted
            for year in td.ObsYear.unique():
                ##Finds full single Nov-start year for a troop
                wd = td.loc[(td.Date >= datetime.datetime(int(year),11,1))&(td.Date < datetime.datetime(int(year)+1,11,1))]

                ##Special handling for data that doesn't have a full Nov-start year at beginning of troop data
                ##(rows before Nov 1 of the first year get the Jan-Oct formula and tdint -1)
                if ecount == 0:
                    ed = td.loc[td.Date<datetime.datetime(int(year),11,1)]
                    ed.loc[:,'tdate'] = ed.loc[:,'Date'].apply(lambda x: (int(x.strftime('%j'))-305+365)/365)
                    ed.loc[:,'tdint'] = -1
                    ecount+=1

                try:
                    ##Converts date to a percentage number (0,1) based on Nov-start year Nov 1 = 0 and Oct 31 = 0.99 
                    ##NOTE(review): a Dec 31 timestamp with a non-midnight time fails the <= test and gets the Jan-Oct formula
                    wd.loc[:,'tdate'] = wd.loc[:,'Date'].apply(lambda x: (int(x.strftime('%j'))-305)/365 if x <= datetime.datetime(int(year),12,31) else (int(x.strftime('%j'))-305+365)/365)
                    wd.loc[:,'tdint'] = counter
                    counter += 1
                    ##Write the per-year slices back into df (ed is re-applied on every year)
                    df.update(wd)
                    df.update(ed)
                except:
                    ##The original error is discarded here, and this exception is itself caught by
                    ##the outer bare except, which then runs the single-troop path
                    raise Exception('There is an error updating input dataframe in add_tdate function.')
                    
                    
                    
    ##Functionality for single troop dataframes
    ##MUST HAVE DATE COLUMN
    ##tdint is counted over the calendar years found in the Date column
    except:
        
        df.loc[:,'tdate'] = 0
        df.loc[:,'tdint'] = 0
        
        def add_qtr(x):
            ##Nov-Jan -> 1, Feb-Apr -> 2, May-Jul -> 3, Aug-Oct -> 4
            if x.month in [11,12,1]:
                return 1
            elif x.month in [2,3,4]:
                return 2
            elif x.month in [5,6,7]:
                return 3
            else:
                return 4

        df.loc[:,'quarter'] = df.loc[:,'Date'].apply(lambda x: add_qtr(x))
        
        

        td = df.copy()
        ###ecount identifies first incomplete year
        ecount = 0
        ###Counter assigns tdint value
        counter = 0
    
        for year in td.loc[:,'Date'].apply(lambda x: x.year).unique():
            ##Finds full single Nov-start year for a troop
            wd = td.loc[(td.Date >= datetime.datetime(int(year),11,1))&(td.Date < datetime.datetime(int(year)+1,11,1))]

            ##Special handling for data that doesn't have a full Nov-start year at beginning of troop data
            if ecount == 0:
                ed = td.loc[td.loc[:,'Date']<datetime.datetime(int(year),11,1)]
                ed.loc[:,'tdate'] = ed.loc[:,'Date'].apply(lambda x: (int(x.strftime('%j'))-305+365)/365)
                ed.loc[:,'tdint'] = -1
                ecount+=1

            try:
                ##Converts date to a percentage number (0,1) based on Nov-start year Nov 1 = 0 and Oct 31 = 0.99 
                wd.loc[:,'tdate'] = wd.loc[:,'Date'].apply(lambda x: (int(x.strftime('%j'))-305)/365 if x <= datetime.datetime(int(year),12,31) else (int(x.strftime('%j'))-305+365)/365)
                ##counter advances once per calendar year in Date, even for years with no rows in the window
                wd.loc[:,'tdint'] = counter
                counter += 1
                df.update(wd)
                df.update(ed)
            except:
                raise Exception('There is an error updating input dataframe in add_tdate function.')

    
    
    return df

In [5]:
def get_week(date,obs_year=float('nan')):
    """Return the 1-based week number of ``date`` within its Nov-1-start observation year.

    The year starts on Nov 1 of ``obs_year``. It starts on Nov 1 of the previous year when
    ``date`` falls before that. The week is ``round(days_since_start / 7) + 1``, i.e. rounded
    to the nearest whole week (Python's round-half-to-even), not floored.

    Parameters
    ----------
    date : datetime.datetime or pandas.Timestamp
    obs_year : int or float, optional
        Observation year. NaN/None means "use ``date.year``". Float years such as 2017.0
        (e.g. read from an ObsYear column) are accepted.

    Returns
    -------
    int
        Week number, 1 for the week beginning Nov 1.
    """
    if obs_year is None or np.isnan(obs_year):
        year = date.year
    else:
        ##datetime requires an int year; ObsYear values may arrive as floats
        year = int(obs_year)

    start = datetime.datetime(year,11,1)
    if date < start:
        ##date belongs to the Nov-start year that began in the previous calendar year
        start = datetime.datetime(year-1,11,1)

    return round((date - start).days/7) + 1

In [6]:
##Averages the records of each calendar week (weekly resample) and interpolates T, J, S, A
##counts for missing weeks.

##df must contain Date, T, J, S, and A columns



def _weekly_interp(frame, name, method, limit):
    """Resample one troop's records to weekly means, fill short gaps and rebuild the time columns."""
    ##Weekly mean of the numeric columns (T,J,S,A,...). Non-numeric columns such as Troop are
    ##dropped explicitly, because newer pandas raises instead of silently dropping them.
    avg = (frame.set_index('Date')
                .select_dtypes(include=['number', 'bool'])
                .resample('W')
                .mean())

    ##Interpolate missing T,J,S,A values (at most `limit` consecutive weeks), rounded to whole counts
    counts = ['T','J','S','A']
    avg.loc[:,counts] = round(avg.loc[:,counts].interpolate(method=method,limit=limit))

    ##Fix tdate, tdint and quarter columns to match the weekly dates
    avg = add_tdate(avg.reset_index())

    ##Week is recomputed from Date, so interpolating the averaged week column is unnecessary
    avg.loc[:,'week'] = avg.loc[:,'Date'].apply(get_week)

    avg.loc[:,'Troop'] = name
    return avg


def density_interp(df,method='linear',limit=4,fusion=0):
    """Resample census/simulation records to weekly values and interpolate missing weeks.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs Date, T, J, S and A columns. A 'week' column is (re)computed and written into
        ``df`` itself, so the caller's frame is modified.
    method : str
        Interpolation method passed to ``DataFrame.interpolate``.
    limit : int
        Maximum number of consecutive missing weeks to fill.
    fusion : int
        For multi-troop frames: 1 selects every row whose Troop name contains the troop name
        (literal substring, so fused groups like 'A/B' are included). Any other value selects
        exact matches only.

    Returns
    -------
    pandas.DataFrame
        Weekly rows with Date, T, J, S, A, tdate, tdint, quarter, week and Troop. Multi-troop
        input yields the per-troop frames concatenated. A df without a Troop column is
        labelled 'simulated'.
    """
    ##Add week column (also visible to the caller, as before)
    df.loc[:,'week'] = df.loc[:,'Date'].apply(get_week)

    ##No Troop column, e.g. raw simulation output
    if 'Troop' not in df.columns:
        return _weekly_interp(df, 'simulated', method, limit)

    names = df.Troop.unique()

    ##For multitroop df
    if len(names) > 1:
        frames = []
        for name in names:
            if fusion == 1:
                troop = df.loc[df.Troop.str.contains(name, regex=False)]
            else:
                troop = df.loc[df.Troop == name]
            frames.append(_weekly_interp(troop, name, method, limit))
        return pd.concat(frames)

    ##For single troop df
    name = names[0] if len(names) else 'simulated'
    return _weekly_interp(df, name, method, limit)




        
    
#df.head()            

In [7]:
def extract_date(tdate,obs_year=float('nan')):
    """Convert a fractional Nov-start-year value ``tdate`` back into a ``datetime``.

    This is roughly the inverse of the tdate scale built in add_tdate (Nov 1 -> 0).

    - Only the fractional part is used when tdate >= 1.
    - tdate == 0 maps to Nov 1 of ``year``. ``year`` is obs_year, or 2017 when obs_year is NaN.
    - NOTE(review): ``year`` must be an int; a float obs_year such as 2017.0 makes
      ``datetime.datetime`` raise TypeError.
    - Both Nov-Dec and Jan-Oct tdates are placed in calendar year ``year``. Jan-Oct is NOT
      moved to year+1. run_sim later shifts quarter-1 Nov-Dec dates back by 365 days.
    - The day of year is parsed with ``strptime('%j')``, which yields a date in 1900. That
      date is then shifted by ``365.25 * (year - 1900)`` days, so the result can carry a
      time-of-day offset (e.g. +6 h when year is 2017).
    - If the day-of-year parse fails (e.g. day 0, reachable in leap years just after the
      Nov/Dec threshold), the fallback returns Nov 1 of ``year`` + tdate*365 days. Unlike the
      main path, this lands in year+1.
    - Negative tdate values are not specially handled.

    Raises
    ------
    Exception
        If tdate is NaN.
    """
    from datetime import timedelta 
    
    ##This step gets the dates to have matching years with observation data (from which ICs were initialized)
    if np.isnan(obs_year)==True:
        year=2017
        
    else:
        year=obs_year
        
        
    if np.isnan(tdate) == True:
        raise Exception('tdate is NaN')
    
    
    if tdate >= 1:
        tdate = tdate%1
        
    if tdate == 0:
        date = datetime.datetime(year,11,1)
        return date
    
    ##Nov-Dec part of the year (threshold uses Dec 31's day of year in `year`, 366 in leap years)
    if tdate <= (int(datetime.datetime(year,12,31).strftime('%j'))-305)/365:
        ##unreachable: tdate == 0 already returned above
        if tdate == 0: 
            date = datetime.datetime(year,11,1)
        else:
            jul = int((tdate*365)+305)
            ##strptime with only %j gives a date in 1900; shift it forward to `year`
            jul = datetime.datetime.strptime(str(jul),'%j')
            diff = year - jul.year
            date = jul + timedelta(days=(float(format(365.25*diff,'.3f'))))

    ##Jan-Oct part of the year
    else:
        ##unreachable: tdate == 0 already returned above
        if tdate == 0:
            date = datetime.datetime(year,11,1)
        else:
            try:
                jul = int((tdate*365)+305-365)
                jul = datetime.datetime.strptime(str(jul),'%j')
                diff = year - jul.year
                date = jul + datetime.timedelta(days=(float(format(365.25*diff,'.3f'))))
            except:
                ##Fallback when the day of year cannot be parsed (e.g. 0)
                jul = tdate*365
                date = datetime.datetime(year,11,1)+datetime.timedelta(days = jul)
            
        
    return date

In [8]:
'''DEPRECIATED'''
def nearestDate(tdate, df):
    """DEPRECATED. Return the value in ``df['tdate']`` closest to ``tdate``.

    On equal distances the later value wins. If anything goes wrong (missing column, empty
    column, non-numeric values), ``tdate`` itself is returned.
    """
    try:
        by_gap = {}
        for candidate in df['tdate']:
            by_gap[abs(tdate - candidate)] = candidate
        return by_gap[min(by_gap.keys())]
    except BaseException:
        return tdate

In [9]:
'''DEPRECIATED'''
def nearestValue(tdate,troop_df,age):
    """DEPRECATED. Return ``troop_df[age]`` as an int at the row whose tdate is nearest to ``tdate``.

    The result is NaN when the lookup fails, e.g. when zero or several rows match or the
    value cannot be converted to int.
    """
    try:
        closest_rows = troop_df.loc[troop_df.tdate == nearestDate(tdate, troop_df)]
        return int(closest_rows[age])
    except BaseException:
        return float('NaN')

In [10]:
'''DEPRECIATED'''
def nearest_df(sim,troop_df):
    """DEPRECATED. For every sim tdate, find the nearest observed tdate and count in troop_df.

    Give ONLY ONE (Troop, tdint) group at a time: the result's Troop and tdint columns are
    copied from the last row of troop_df. They are read positionally, from the last and
    third-to-last columns.
    NOTE(review): confirm that column order matches the troop_df actually passed in.

    Parameters
    ----------
    sim : pandas.DataFrame
        Needs a 'tdate' column.
    troop_df : pandas.DataFrame
        Needs 'tdate' and 'counts' columns.

    Returns
    -------
    pandas.DataFrame
        Columns Date, tdate_near, count_near, Troop, tdint, week and quarter.
        count_near is NaN where no single usable count exists.
    """
    diff_df = pd.DataFrame(columns = ['Date','tdate_near','count_near','Troop','tdint'])

    def nearestDate(tdate, df):
        ##Closest df['tdate'] value; falls back to tdate itself (e.g. for an empty df)
        try:
            nearness = { abs(tdate - date) : date for date in df['tdate'] }
            return nearness[min(nearness.keys())]
        except Exception:
            return tdate

    def nearestValue(tdate,troop_df):
        ##Count at the nearest tdate. NaN when there is not exactly one matching row or its
        ##count is not convertible to int (e.g. NaN), instead of ending with an unbound result.
        matches = troop_df.loc[troop_df.tdate == nearestDate(tdate,troop_df)]['counts']
        if len(matches) != 1:
            return float('NaN')
        try:
            return int(matches.iloc[0])
        except (TypeError, ValueError):
            return float('NaN')

    def add_qtr(x):
        ##Quarter from the fractional tdate: (0,0.25] -> 1, (0.25,0.5] -> 2, (0.5,0.75] -> 3, else 4
        if x <= 0.25:
            return 1
        elif (x>0.25)&(x<=0.5):
            return 2
        elif (x>0.5)&(x<=0.75):
            return 3
        else:
            return 4

    diff_df.loc[:,'tdate_near'] = sim.loc[:,'tdate'].apply(lambda x: nearestDate(x,troop_df))
    diff_df.loc[:,'count_near'] = sim.loc[:,'tdate'].apply(lambda x: nearestValue(x,troop_df))
    diff_df.loc[:,'Date'] = diff_df.loc[:,'tdate_near'].apply(lambda x: extract_date(x))
    diff_df.loc[:,'week'] = diff_df.loc[:,'Date'].apply(lambda x: get_week(x))
    diff_df.loc[:,'quarter'] = diff_df.loc[:,'tdate_near'].apply(lambda x: add_qtr(x))
    diff_df.loc[:,'Troop'] = troop_df.iloc[-1,-1]
    diff_df.loc[:,'tdint'] = troop_df.iloc[-1,-3]

    
    return diff_df

In [11]:
##This function is designed to be used AFTER a simulation run (on the frame built by IC_run_sim).

##It summarises, per quarter, how well simulated counts match field counts: descriptive stats
##of both, plus RMSE / NRMSE / sMAPE over the weeks where both counts exist.

##Built to only handle one troop's year of data at a time


def qtr_comp(grp):
    """Summarise one group of paired observed/simulated counts as a single-row DataFrame.

    Meant to be used as ``sim_iters.groupby([...]).apply(qtr_comp)``. The group keys are
    ['obs_troop','obs_tdint','obs_quarter','age_class','iteration']; drop 'iteration' from
    the keys to pool all iterations.

    Parameters
    ----------
    grp : pandas.DataFrame
        Must contain obs_counts, sim_counts, obs_troop, obs_tdint, obs_quarter, age_class,
        iteration and IC columns.

    Returns
    -------
    pandas.DataFrame
        One row with:
        - min/max/mean/variance of the observed and simulated counts;
        - the absolute and relative difference of the two variances;
        - RMSE, NRMSE and sMAPE over the rows where both counts are present.
        RMSE is always the raw value. NRMSE is RMSE divided by the mean observed count, or by
        the observed range when that mean is 0. Metrics that cannot be computed (e.g. no
        paired rows) are NaN. rel_err is inf/NaN when the observed variance is 0 (e.g. a
        quarter with a single entry).
    """
    columns = ['obs_troop','obs_tdint','quarter','obs_min','obs_max','obs_mean','obs_var',
               'sim_min','sim_max','sim_mean','sim_var','abs_err','rel_err','age_class',
               'NRMSE','RMSE','sMAPE','iteration','IC']

    obs_all = grp.obs_counts
    sim_all = grp.sim_counts

    ##Only score rows where both the observed and the simulated count are present
    paired = grp.loc[obs_all.notna() & sim_all.notna()]
    obs = paired.obs_counts
    pred = paired.sim_counts

    ##Calculating RMSE and NRMSE
    try:
        RMSE = np.sqrt(mean_squared_error(obs,pred))
        obs_mean = np.mean(obs_all)
        if obs_mean != 0:
            NRMSE = RMSE/obs_mean
        else:
            ##Zero mean: normalise by the observed range instead; RMSE itself stays un-normalised.
            ##NOTE(review): for non-negative counts a zero mean implies a zero range -> inf/NaN
            NRMSE = np.divide(RMSE, abs(np.max(obs_all)-np.min(obs_all)))
    except Exception:
        ##e.g. no paired rows (mean_squared_error rejects empty input)
        NRMSE = float('NaN')
        RMSE = float('NaN')

    ##Calculating sMAPE value
    try:
        sMAPE = smape(obs,pred)
    except Exception:
        sMAPE = float('NaN')

    ##VARIANCE (population variance, NaNs skipped)
    obs_var = np.var(obs_all)
    sim_var = np.var(sim_all)

    record = {'obs_troop':grp.obs_troop.unique()[0],
              'obs_tdint':grp.obs_tdint.unique()[0],
              'quarter':grp.obs_quarter.unique()[0],
              'age_class':grp.age_class.unique()[0],
              'obs_min':np.min(obs_all),
              'obs_max':np.max(obs_all),
              'obs_var':obs_var,
              'obs_mean':np.mean(obs_all),
              'sim_min':np.min(sim_all),
              'sim_max':np.max(sim_all),
              'sim_mean':np.mean(sim_all),
              'sim_var':sim_var,
              'abs_err':abs(obs_var-sim_var),
              ##np.divide yields inf/NaN (not an exception) for a zero observed variance
              'rel_err':np.divide(abs(sim_var-obs_var), obs_var),
              'NRMSE':NRMSE,
              'RMSE':RMSE,
              'sMAPE':sMAPE,
              'iteration':grp.iteration.unique()[0],
              'IC':grp.IC.unique()[0]}

    return pd.DataFrame([record], columns=columns)
    

In [12]:
def _extract_ic(trp_grp):
    """Return the initial condition [J, S1, S2, A, year] for one (Troop, tdint) group, or None."""
    ##Only quarter-1 records can seed a simulation
    q1 = trp_grp.loc[trp_grp.quarter==1]

    ##Drop rows with a missing total count: mask sets those whole rows to 999, then they are filtered out
    cond = (q1.age_class=='T') & (np.isnan(q1.counts)==True)
    q1 = q1.mask(cond,999)
    q1 = q1.loc[q1.counts!=999]

    ##ICs come from the earliest quarter-1 observation; no quarter-1 data means nothing to simulate
    try:
        min_date = q1.loc[q1.tdate.idxmin()]['tdate']
        IC_df = q1.loc[q1.tdate==min_date]
    except Exception:
        return None

    def count_for(age):
        return round(IC_df.loc[IC_df.age_class==age]['counts'].iloc[0])

    ##A total count is required
    try:
        T = count_for('T')
    except Exception:
        return None

    ##Missing (or unroundable) J/S/A counts default to 0
    def count_or_zero(age):
        try:
            return count_for(age)
        except Exception:
            return 0

    J = count_or_zero('J')
    S = count_or_zero('S')
    A = count_or_zero('A')

    if np.isnan(J) and np.isnan(S) and np.isnan(A):
        return [0,0,0,T,1]
    ##Clerical errors lead to A==1 with T>1; in that case (or A==0) just use the total count
    if A == 1 or A == 0:
        return [0,0,0,T,1]
    return list(pd.Series([J,0,S,A,1]).fillna(0))


def _run_sim_with_retries(IC, max_attempts, **params):
    """Call run_sim until it succeeds; raise RuntimeError after max_attempts failed attempts."""
    last_error = None
    for _ in range(max_attempts):
        try:
            return run_sim(IC, **params)
        except Exception as err:
            ##individual stochastic runs may fail, hence the retry
            last_error = err
    raise RuntimeError('run_sim failed {} times in a row for IC {}'.format(max_attempts, IC)) from last_error


def IC_run_sim(trp_grp,iters,fusion= 0.1,dispersal=0.1,b_j = 0.5,d_j = 0.345,d_s = 0.214,d_a = 0.143,t_p = 0.3,por = 0.2,fsr = 0.4,ae_thresh = 5,fis_thresh = 45,lit_size = 4,dfe='off',max_attempts=1000):
    """Simulate ``iters`` runs from the initial conditions of one troop-year and pair them with the observations.

    Meant to be used in a ``.apply()`` on the melted census grouped by (Troop, tdint).

    The IC comes from the earliest quarter-1 record of ``trp_grp`` and requires a total (T)
    count. Missing J/S/A counts become 0. When A is 0 or 1, only the total count is used,
    placed in the adult slot.

    Parameters
    ----------
    trp_grp : pandas.DataFrame
        A SINGLE (Troop, tdint) group of the melted census. Columns: Troop, Date, week,
        tdate, tdint, quarter, age_class, counts.
    iters : int
        Number of simulations to run.
    fusion ... lit_size, dfe :
        Model parameters forwarded to run_sim (dfe included).
    max_attempts : int
        Failed run_sim calls tolerated per iteration before giving up with RuntimeError.

    Returns
    -------
    pandas.DataFrame or None
        None when no IC can be extracted. Otherwise the observed rows merged (right join on
        week and age_class) with each simulation. Columns are obs_*/sim_* prefixed, plus
        'iteration' and 'IC' (the IC as a string). The result is an empty DataFrame when
        iters is 0.
    """
    IC = _extract_ic(trp_grp)
    if IC is None:
        return None

    params = dict(fusion = fusion,dispersal = dispersal,b_j = b_j,d_j = d_j,d_s = d_s,d_a = d_a,t_p = t_p,por = por,fsr = fsr,ae_thresh = ae_thresh,fis_thresh = fis_thresh,lit_size = lit_size,dfe = dfe)

    frames = []
    for i in range(iters):
        sim = _run_sim_with_retries(IC, max_attempts, **params)

        sim = sim.melt(id_vars=['Troop','Date','week','tdate','tdint','quarter'],value_vars=['T','J','S','A']).rename(columns={'variable':'age_class','value':'counts'})
        sim.loc[:,'iteration'] = i
        ##Keep only the first simulated year
        sim = sim[sim.tdate<1].copy()

        ##Shift simulated dates so their last year lines up with the observations' last year
        diff = abs(list(sim.Date)[-1].year-list(trp_grp.Date)[-1].year)
        sim.loc[:,'Date'] = sim.loc[:,'Date'].apply(lambda x: x-datetime.timedelta(days=365.25*diff))

        merged = trp_grp.merge(sim,how='right',on=['week','age_class'])
        temp = merged.loc[merged.quarter_x.isna()].copy()
        if len(temp) != 0:
            ##Simulated weeks with no observation: borrow tdint from the row just above the first gap
            ##NOTE(review): if the first gap is row 0, iloc[-1] wraps around to the last row
            temp.loc[:,'tdint_x'] = merged.iloc[temp.index[0]-1]['tdint_x']

            ##Have missing obs data match sim data (all except for count => counts still NaN)
            temp.loc[:,'tdate_x'] = temp.tdate_y
            temp.loc[:,'Date_x'] = temp.Date_y
            temp.loc[:,'quarter_x'] = temp.quarter_y
            temp.loc[:,'Troop_x'] = list(merged.Troop_x)[0]

            merged.update(temp)

        merged = merged.rename(columns = {'Troop_x':'obs_troop',
                                 'Date_x':'obs_date',
                                 'tdate_x':'obs_tdate',
                                 'tdint_x':'obs_tdint',
                                 'quarter_x':'obs_quarter',
                                 'counts_x':'obs_counts',
                                 'Troop_y':'sim_troop',
                                 'Date_y':'sim_date',
                                 'tdate_y':'sim_tdate',
                                 'tdint_y':'sim_tdint',
                                 'quarter_y':'sim_quarter',
                                 'counts_y':'sim_counts',
                                })
        frames.append(merged)

    if not frames:
        return pd.DataFrame()

    ##Single concat instead of growing a frame inside the loop
    sim_iters = pd.concat(frames).reset_index(drop=True)
    sim_iters.loc[:,'IC'] = str(IC)
    return sim_iters

In [13]:
def new_compare(full_df,iters = 1000,years = 1,fusion= 0.1,dispersal=0.1,b_j = 0.5,d_j = 0.345,d_s = 0.214,d_a = 0.143,t_p = 0.3,por = 0.2,fsr = 0.4,ae_thresh = 5,fis_thresh = 45,lit_size = 4,dfe='off'):
    """Simulate every (Troop, tdint) year of ``full_df`` from its own ICs and score the runs.

    Parameters
    ----------
    full_df : pandas.DataFrame
        Weekly census data (e.g. new_census) with Troop, Date, week, tdate, tdint, quarter,
        T, J, S and A columns.
    iters : int
        Simulations per (Troop, tdint) group.
    years : int
        Currently unused; kept for backward compatibility.
    fusion ... lit_size, dfe :
        Model parameters forwarded to IC_run_sim. ae_thresh, fis_thresh and lit_size are
        rounded before forwarding. The unrounded values are the ones logged.

    Returns
    -------
    pandas.DataFrame
        qtr_comp rows, one per (obs_troop, obs_tdint, obs_quarter, age_class, iteration),
        plus one column per parameter value.
    """
    ##Melt T/J/S/A into (age_class, counts) rows; group so IC_run_sim sees one troop-year at a time
    melted = full_df.melt(id_vars=['Troop','Date','week','tdate','tdint','quarter'],value_vars=['T','J','S','A']).rename(columns={'variable':'age_class','value':'counts'})
    by_troop_year = melted.groupby(['Troop','tdint'])

    ##Extract ICs and run simulations
    sim_iters = by_troop_year.apply(lambda grp: IC_run_sim(grp,iters=iters,fusion = fusion,dispersal=dispersal,b_j = b_j,d_j = d_j,d_s = d_s,d_a = d_a,t_p = t_p,por = por,fsr = fsr,ae_thresh = round(ae_thresh),fis_thresh = round(fis_thresh),lit_size = round(lit_size),dfe=dfe)).reset_index(drop=True)

    ##Calculating error and descriptive stats
    grp_iters = sim_iters.groupby(['obs_troop','obs_tdint','obs_quarter','age_class','iteration'])
    full_err = grp_iters.apply(qtr_comp)

    print(str(datetime.datetime.now())+' calculated new_compare')

    ##Logging each param value for later analysis
    params = {'fusion':fusion,
              'dispersal':dispersal,
              'b_j':b_j,
              'd_j':d_j,
              'd_s':d_s,
              'd_a':d_a,
              't_p':t_p,
              'por':por,
              'fsr':fsr,
              'ae_thresh':ae_thresh,
              'fis_thresh':fis_thresh,
              'lit_size':lit_size,
              'dfe':dfe}
    for name, value in params.items():
        full_err.loc[:,name] = value

    return full_err

        
        

In [14]:
def run_sim(IC,obs_year=2017,log=0,plot = 0,method='linear',limit=4,fusion= 0.1,dispersal=0.1,b_j = 0.5,d_j = 0.345,d_s = 0.214,d_a = 0.143,t_p = 0.3,por = 0.2,fsr = 0.4,ae_thresh = 5,fis_thresh = 45,lit_size = 4,dfe='off'):
    """Run one stochastic simulation from ``IC`` and return it as a weekly, date-aligned DataFrame.

    Parameters
    ----------
    IC : sequence
        (J, S1, S2, A, year), forwarded positionally to stoch_model.
    obs_year : int
        Calendar year used to turn tdate values back into dates.
    log : int
        When non-zero, stoch_model also returns a log frame and ``(sim, log_df)`` is returned.
    plot : int
        When 1, draw the T/J/S/A trajectories against tdate in four stacked panels, with
        dashed lines at the quarter boundaries.
    method, limit :
        Interpolation settings for density_interp.
    fusion ... lit_size, dfe :
        Model parameters forwarded to stoch_model (dfe included).

    Returns
    -------
    pandas.DataFrame, or (pandas.DataFrame, log) when ``log`` is non-zero.
    """
    import datetime

    model_params = dict(fusion = fusion,dispersal=dispersal,b_j = b_j,d_j = d_j,d_s = d_s,d_a = d_a,t_p = t_p,por = por,fsr = fsr,ae_thresh = ae_thresh,fis_thresh = fis_thresh,lit_size = lit_size,dfe = dfe)
    if log == 0:
        sim = stoch_model(IC[0],IC[1],IC[2],IC[3],IC[4],log=log,**model_params)
    else:
        sim,log_df = stoch_model(IC[0],IC[1],IC[2],IC[3],IC[4],log=log,**model_params)

    sim.loc[:,'Date'] = sim.loc[:,'tdate'].apply(lambda x: extract_date(x,obs_year=obs_year))

    ##extract_date puts Nov-Dec dates in obs_year; move quarter-1 Nov-Dec dates back one year so
    ##they precede Jan-Oct. The upper bound is "before Jan 1" rather than "<= Dec 31 00:00",
    ##because extract_date can add a time-of-day offset.
    in_nov_dec = (sim.quarter==1)&(sim.Date >= datetime.datetime(obs_year,11,1))&(sim.Date < datetime.datetime(obs_year+1,1,1))
    temp = sim.loc[in_nov_dec].copy()
    temp.loc[:,'Date'] = temp.Date - datetime.timedelta(days=365)
    sim.update(temp)

    sim = density_interp(sim,method=method,limit=limit)

    ##Rows of the second Nov-start year (tdint == 1) get tdate > 1 so tdate keeps increasing
    temp = sim.loc[sim.tdint == 1].copy()
    temp.loc[:,'tdate'] += 1
    sim.update(temp)

    if plot == 1:
        panels = [('T','g','Total'),('J','r','Juveniles'),('S','y','Subs'),('A','k','Adults')]
        for row, (col, color, label) in enumerate(panels, start=1):
            ax = pl.subplot(4, 1, row)
            if row < len(panels):
                ##only the bottom panel shows time-axis tick labels
                ax.tick_params(labelbottom=False)
            pl.ylim(0, max(sim[col])+2)
            pl.plot(sim['tdate'], sim[col], color)
            pl.ylabel(label)
            ##Quarter boundaries
            for boundary in (0.25, 0.5, 0.75, 1):
                pl.axvline(x = boundary,color='b',linestyle='--')
        pl.xlabel('Time (years)')
        pl.show()

    if log == 0:
        return sim
    else:
        return sim,log_df

In [15]:
##Model notes:
##- Higher group size increases pup survival (escorts).
##- Pups need to be gone by the dry season (either transitioned or dead).
##- Total end population should not be ... (note is incomplete in the original — TODO: complete it)
##Stochastic step function

def stoc_eqs(INP,ts,qtr,time,pulse,dfe='off',fusion= 0.1,dispersal=0.1,b_j = 0.5,d_j = 0.345,d_s = 0.214,d_a = 0.143,t_p = 0.3,por = 0.2,fsr = 0.4,ae_thresh = 5,fis_thresh = 45,lit_size = 4): 
    """Advance the age-structured troop model by one Gillespie-style event.

    State: V[0] = juveniles, V[1] = sub-adults 1, V[2] = sub-adults 2, V[3] = adults.
    Each row of Change is the state increment [dJ, dS1, dS2, dA] for one reaction. The
    increments are fractions of a class (e.g. -s_s*V[2]), not single individuals.

    Above the Allee/extinction threshold (sum(INP) >= ae_thresh):
    - the rates depend on the quarter ``qtr`` and on ``time % 1``;
    - one reaction m is chosen with probability proportional to Rate[m];
    - the time step is ts = -log(R2)/sum(Rate).

    Below the threshold:
    - With probability 1/2 a fusion adds adults via ``pl.randint(10,30)``. If pylab exposes
      numpy.random.randint here, that is 10-29 inclusive.
    - Otherwise the troop goes extinct and ``[0,0,0,0]`` (a new list) is returned.
    - NOTE(review): Rate is never filled in this branch, so sum(Rate) == 0. ts then becomes
      inf (or NaN) via numpy division, and the reaction chosen is row 0, whose Change is all
      zeros.

    Parameters
    ----------
    INP : array
        Current state; presumably a numpy array of length 4 (``V[range(4)]`` indexing
        requires an array). NOTE: V is INP itself, so the above-threshold update and the
        fusion update modify the caller's array in place.
    ts : float
        Ignored; overwritten by the newly drawn time step.
    qtr : int
        Quarter 1-4 of the model year.
    time : float
        Current model time (years since Nov 1).
    pulse : int
        Counter incremented when a birth reaction with positive change is set up (qtr 2/3).
    dfe : str
        'off' adds reaction 8 (infection deaths, dry season only). Any other value uses
        8 reactions without infection.
    fusion, dispersal, fis_thresh :
        Accepted but unused in this function.

    Returns
    -------
    list
        [V, ts, pulse, temp_log]. temp_log records infection events (reaction 8).
        NOTE(review): temp_log uses DataFrame.append, which was removed in pandas 2.0.
    """
    V = INP
    ##Reaction table: 9 reactions with infection (dfe == 'off'), 8 without
    if dfe == 'off':
        Rate=np.zeros((9))
        Change=np.zeros((9,4))
    else:
        Rate=np.zeros((8))
        Change=np.zeros((8,4))
    N=np.sum(V[range(4)])  ##unused
    
    temp_log = pd.DataFrame(columns=['Message','Time']) ##a temporary log to track infection events

    ##Per-class survival probabilities (s_a is unused below)
    s_j = 1 - d_j
    s_s = 1 - d_s
    s_a = 1 - d_a
    ##V[0] = Juv | V[1] = Sub1 | V[2] = Sub2 | V[3] = Adult

    if np.sum(INP) >= ae_thresh: #above AE/extinction thresh

        ##1st qtr during wet season (above thresh)
        if qtr == 1: ##qtr = 1 => We assume we've found a cohort with all of these age classes present

            Rate[0] = 0; Change[0,:]=([0, 0, 0, 0]);##no additions to juv until after 1st qtr
            Rate[1] = 0; Change[1,:] = ([0, 0, 0, 0]); ##no juvs present until after 1st qtr
            Rate[2] = 0; Change[2,:] = ([0, 0, 0, 0]); ##no juvs present until after 1st qtr
            Rate[3] = 0; Change[3,:] = ([0, 0, 0, 0]); ##no sub1 present until after 2nd qtr
            Rate[4] = 0; Change[4,:] = ([0, 0, 0, 0]); ##no sub1 present until after 2nd qtr
            Rate[5] = s_s*V[2]; Change[5,:] = ([0, 0, -s_s*V[2], +s_s*V[2]]); ##sub2 -> adult
            Rate[6] = d_s*V[2]; Change[6,:] = ([0, 0, -d_s*V[2], 0]); ##sub2 death
            Rate[7] = d_a*V[3]; Change[7,:] = ([0, 0, 0, -d_a*V[3]]); ##adult death
            if dfe == 'off':
                ##Reaction for infection (mortality only occurs during dry season)
                Rate[8] = 0; Change[8,:] = ([0, 0, 0, 0]);     

        ##2nd qtr during wet season (above thresh)
        ##pups survived 1st qtr and juvs can be added
        elif qtr == 2: 
            ##NOTE(review): the birth rate uses a literal 4 while the birth increment uses lit_size;
            ##they only agree for the default lit_size = 4
            Rate[0] = b_j*(4*fsr*V[3]); Change[0,:]=([+lit_size*fsr*V[3], 0, 0, 0]);
            Rate[1] = s_j*V[0]; Change[1,:] = ([-s_j*V[0], +s_j*V[0], 0, 0]); ##juv -> sub1
            Rate[2] = d_j*V[0]; Change[2,:] = ([-d_j*V[0], 0, 0, 0]); ##juv death
            Rate[3] = 0; Change[3,:] = ([0, 0, 0, 0]); ##no sub1 present until after 2nd qtr
            Rate[4] = 0; Change[4,:] = ([0, 0, 0, 0]); ##no sub1 present until after 2nd qtr
            Rate[5] = 0; Change[5,:] = ([0, 0, 0, 0]); ##no sub2 again until 4th qtr
            Rate[6] = 0; Change[6,:] = ([0, 0, 0, 0]); ##no sub2 again until 4th qtr
            Rate[7] = d_a*V[3]; Change[7,:] = ([0, 0, 0, -d_a*V[3]]);
            if dfe == 'off':
                ##Reaction for infection  (mortality only occurs during dry season)
                Rate[8] = 0; Change[8,:] = ([0, 0, 0, 0]);
            if np.sum(Change[0,:]) > 0:
                pulse += 1 ##Pulse tracks whether or not a birth pulse has occurred

        ##3rd qtr during dry season (above thresh)
        elif qtr == 3:
        #print("dry")  
            if time%1 > 0.5: #is time past April?
                Rate[0] = 0; Change[0,:]=([0, 0, 0, 0]); ##no more juv additions after April
            else:
                Rate[0] = b_j*(4*fsr*V[3]); Change[0,:]=([+lit_size*fsr*V[3], 0, 0, 0]) ##juv additions still allowed up to April

            if time%1 > 8./12: #is time past June?
                Rate[1] = 0; Change[1,:] = ([0, 0, 0, 0]); ##no more transitioning after June
                Rate[2] = 0; Change[2,:] = ([0, 0, 0, 0]); ##no juv death since all juvs transitioned
            else:
                Rate[1] = s_j*V[0]; Change[1,:] = ([-s_j*V[0], +s_j*V[0], 0, 0]); ##juv transitioning occurs before June
                Rate[2] = d_j*V[0]; Change[2,:] = ([-d_j*V[0], 0, 0, 0]);
                ##will need to include logic transition b/c of this conditional

            Rate[3] = s_s*V[1]; Change[3,:] = ([0, -s_s*V[1], +s_s*V[1], 0]); ##sub1 -> sub2
            Rate[4] = d_s*V[1]; Change[4,:] = ([0, -d_s*V[1], 0, 0]); ##sub1 death
            Rate[5] = 0; Change[5,:] = ([0, 0, 0, 0]); ##no sub2 again until 4th qtr
            Rate[6] = 0; Change[6,:] = ([0, 0, 0, 0]); ##no sub2 again until 4th qtr
            Rate[7] = d_a*V[3]; Change[7,:] = ([0, 0, 0, -d_a*V[3]]);
            if dfe == 'off':
                ##Reaction for infection (kills a fraction por of adults)
                Rate[8] = t_p*np.sum(V); Change[8,:] = ([0, 0, 0, -por*V[3]]);
            if np.sum(Change[0,:]) > 0:
                pulse += 1

        ##4th qtr during dry season (above thresh)
        elif qtr == 4:
            Rate[0] = 0; Change[0,:]=([0, 0, 0, 0]);   ##No juvenile transition during 4th quarter
            Rate[1] = 0; Change[1,:] = ([0, 0, 0, 0]); ##no more juvs
            Rate[2] = 0; Change[2,:] = ([0, 0, 0, 0]); ##no more juvs

            if time%1 > 11./12: #is time past September?
                Rate[3] = 0; Change[3,:] = ([0, 0, 0, 0]); ##no sub1 after Sept.
                Rate[4] = 0; Change[4,:] = ([0, 0, 0, 0]); ##no sub1 after Sept.
            else:
                Rate[3] = s_s*V[1]; Change[3,:] = ([0, -s_s*V[1], +s_s*V[1], 0]);
                Rate[4] = d_s*V[1]; Change[4,:] = ([0, -d_s*V[1], 0, 0]);
                ##will need to include logic transition b/c of this conditional

            Rate[5] = s_s*V[2]; Change[5,:] = ([0, 0, -s_s*V[2], +s_s*V[2]]);
            Rate[6] = d_s*V[2]; Change[6,:] = ([0, 0, -d_s*V[2], 0]);
            Rate[7] = d_a*V[3]; Change[7,:] = ([0, 0, 0, -d_a*V[3]]);
            if dfe == 'off':
                ##Reaction for infection (kills a fraction por of sub2 and adults)
                Rate[8] = t_p*np.sum(V); Change[8,:] = ([0, 0, -por*V[2], -por*V[3]]);
                
        ##R1 selects the reaction, R2 draws the waiting time
        R1=pl.rand();
        R2=pl.rand();
        ts = -np.log(R2)/(np.sum(Rate)); ##Gillespie SSA
        #################################################################################################
        ##Implementing Changes
        ##(local copy of the module-level find helper)
        def find(condition):
            res, = np.nonzero(np.ravel(condition))
            return res

        ##m = first reaction whose cumulative rate reaches R1 * total rate
        m=min(find(pl.cumsum(Rate)>=R1*pl.sum(Rate)));
        V[range(4)]=V[range(4)]+Change[m,:]
        
        ##Create log of infection events (if occurred)
        if m ==8:
            temp_log = temp_log.append(pd.DataFrame({'Message':['Deaths from infection occurred'], 'Time': [time]}))
            temp_log = temp_log.append(pd.DataFrame({'Message':['Sub-Adults | Adults died from infection: {}|{}'.format(abs(round(Change[8,2])),abs(round(Change[8,3])))], 'Time': time}))
        
        return [V,ts,pulse,temp_log]

    else: #below thresh
        ##Calculate next timestep (Rate is still all zeros here, so sum(Rate) == 0)
        R1=pl.rand();
        R2=pl.rand();
        ts = -np.log(R2)/(np.sum(Rate)); ##Gillespie SSA
        
        ##Determine Allee Effect outcome
        ##An even probability of extinction or fusion occurring to offset AE
        if R1 <= 0.5: 
            ##A fusion occurs to offset Allee Effect
            ##(uses the module-level find; with all-zero rates m is 0 and Change[0] is all zeros)
            m=min(find(pl.cumsum(Rate)>=R1*pl.sum(Rate)));
            V[range(4)]=V[range(4)]+Change[m,:]
            V[3] += pl.randint(10,30) ##add on new adults
            
            ##Create log of infection events (if occurred)
            if m ==8:
                temp_log = temp_log.append(pd.DataFrame({'Message':['Deaths from infection occurred'], 'Time': [time]}))
                temp_log = temp_log.append(pd.DataFrame({'Message':['Sub-Adults | Adults died from infection: {}|{}'.format(abs(round(Change[8,2])),abs(round(Change[8,3])))], 'Time': time}))

            return [V,ts,pulse,temp_log]
        
        else:
            #Extinction occurs as a result of Allee Effect
            V = [0,0,0,0]
            return [V,ts,pulse,temp_log]
       

    


In [16]:
#########################################
##Model Evaluation
def Stoch_Iteration(INPUT,log=0,dfe='off',ND=1,fusion= 0.1,dispersal = 0.1,b_j = 0.5,d_j = 0.345,d_s = 0.214,d_a = 0.143,t_p = 0.3,por = 0.2,fsr = 0.4,ae_thresh = 5,fis_thresh = 45,lit_size = 4):
    """Simulate one stochastic troop trajectory over ``ND`` years.

    Each iteration advances time by the step ``ts`` produced by ``stoc_eqs``
    (Gillespie SSA), then applies the forced-birth, fission, dispersal and
    fusion rules and records J, S1, S2 and A for the new time point.

    Parameters
    ----------
    INPUT : numpy.ndarray
        Initial ``(J, S1, S2, A)`` counts. The same object is passed to
        ``stoc_eqs`` on every step, and the returned ``res`` is modified in
        place below. NOTE(review): state only carries over between steps if
        ``stoc_eqs`` returns/updates this same array -- confirm.
    log : int, default 0
        When > 0, event messages are collected into ``log_df``.
    dfe : str, default 'off'
        Forwarded to ``stoc_eqs``.
    ND : float, default 1
        Number of years to simulate; the loop stops once time reaches ``ND``.
    fusion : float, default 0.1
        Per-step probability that ``pl.randint(10, 30)`` adults join the
        troop (only while the troop is below ``fis_thresh``).
    dispersal : float, default 0.1
        Per-step probability of a dispersal/eviction event (adults scaled by 0.83).
    b_j, d_j, d_s, d_a, t_p, por, ae_thresh
        Forwarded unchanged to ``stoc_eqs``.
    fsr : float, default 0.4
        Female sex ratio; used for forced births and multiplied by 0.88 on
        female-only evictions.
    fis_thresh : int, default 45
        Total troop size at or above which fission removes 24% of adults.
    lit_size : float, default 4
        Litter size used for the forced birth pulse.

    Returns
    -------
    list
        ``[T, J, S1, S2, A]`` when ``log == 0``; otherwise
        ``[T, J, S1, S2, A, log_df]`` where ``log_df`` has columns
        ``'Message'`` and ``'Time'``.
    """
    lop = 0
    ts = -np.log(pl.rand())/np.sum(INPUT)  # first timestep using Gillespie SSA
    T = [0]  # simulation time
    J = [INPUT[0]]
    S1 = [INPUT[1]]
    S2 = [INPUT[2]]
    A = [INPUT[3]]
    qt = 0
    pulse = 0

    qtr = np.array((0,0.25,0.5,0.75,1))  # quarter boundaries within a year

    # Event messages are collected here and concatenated once at the end
    # (DataFrame.append was removed in pandas 2.0).
    log_frames = []

    while T[lop] < ND:
        lop += 1
        T.append(T[lop-1]+ts)
        now = T[lop]

        ## Determine the quarter of the year. Single-year runs use the raw
        ## time, multi-year runs use the fraction of the current year. If no
        ## branch matches, qt keeps its previous value.
        t_year = now if ND == 1 else now % 1
        if qtr[0] <= t_year <= qtr[1]:
            qt = 1
        elif qtr[1] < t_year <= qtr[2]:
            qt = 2
        elif qtr[2] < t_year <= qtr[3]:
            qt = 3
        elif qtr[3] < t_year <= qtr[4]:
            qt = 4
            pulse = 0  # reset so the Q3 forced birth below can fire again

        ## Running and receiving data for this time step
        [res,ts,pulse,temp_log] = stoc_eqs(INPUT,ts,qt,now,pulse,dfe=dfe,fusion = fusion,dispersal = dispersal,b_j = b_j,d_j = d_j,d_s = d_s,d_a = d_a,t_p = t_p,por = por,fsr = fsr,ae_thresh = ae_thresh,fis_thresh = fis_thresh,lit_size = lit_size)

        ## Keep the infection log returned by stoc_eqs (if log is recording)
        if log > 0 and len(temp_log) > 0:
            log_frames.append(temp_log)

        ## Force birth to occur at least once a year (if it hasn't already)
        if qt == 3 and pulse == 0 and np.sum(res) > 0:
            J[-1] = lit_size*fsr*A[-1]
            res[0] = J[-1]
            pulse = 1
            if log > 0:
                log_frames.append(pd.DataFrame({'Message':['force birth occurred'],'Time':[now]}))

        ## Random draws for fusion and dispersal (drawn in this order every step)
        ev = pl.rand()
        disp_ev = pl.rand()

        ## Fission: troop too large, a portion of the adults leaves
        if np.sum(res) >= fis_thresh:
            if log > 0:
                log_frames.append(pd.DataFrame({'Message':['Fission occurred: Troop too large'], 'Time': [now]}))
                log_frames.append(pd.DataFrame({'Message':['Adult Pop: Before:{} | After:{}'.format(res[3],int(round(0.76*res[3])))], 'Time': [now]}))
            res[3] = 0.76*res[3]  # remove 24% of adults without changing sex ratio

        ## Dispersal/eviction event (sex ratio may change)
        if disp_ev <= dispersal:
            if log > 0:
                log_frames.append(pd.DataFrame({'Message':['Dispersal occurred'], 'Time': [now]}))
                # Logged before the reduction so "Before" is the pre-dispersal count
                log_frames.append(pd.DataFrame({'Message':['Adult Pop: Before:{} | After:{}'.format(res[3],int(round(0.83*res[3])))], 'Time': [now]}))
            res[3] = 0.83*res[3]  # mean 6 dispersed out of mean troop size 35 => ~17%

            ## Sex-ratio adjustment:
            ## Cant 2016 - females are evicted more than males, males are never
            ## evicted without females, and 53% of evictions are female only.
            ## Cant 2015 - 12% of females are evicted.
            choose = pl.rand()
            if choose <= 0.53:
                fsr = 0.88*fsr  # remove 12% of females
                if log > 0:
                    log_frames.append(pd.DataFrame({'Message':['Only females evicted'],'Time':[now]}))
            elif log > 0:  # other 47% of events don't change the sex ratio
                log_frames.append(pd.DataFrame({'Message':['Males and females evicted'],'Time':[now]}))

        ## Fusion onto the existing group, only if it is not already too big
        if ev <= fusion and np.sum(res) < fis_thresh:
            add = pl.randint(10,30)
            res[3] += add  # no sex-ratio change for fusion events
            if log > 0:
                log_frames.append(pd.DataFrame({'Message':['Fusion occurred: {} adults added'.format(add)], 'Time':[now]}))

        ## Record class counts, applying age-class transitions based on the
        ## time of the previous step. Exactly one value is appended to each of
        ## J, S1, S2 and A per step so they stay aligned with T.
        S2_logic = 0  # set to 1 once S2 has been appended for this step
        prev_frac = T[lop-1] % 1
        if prev_frac >= 8./12:  # if time is past June
            if len(J) > 1:
                if prev_frac >= 11./12:  # if time is past September: move all S1 to S2
                    S2.append(round(res[2] + res[1]))
                    S1.append(0)
                    S2_logic = 1
                    J.append(round(res[0]))
                else:  # past June, before September: move all J to S1
                    S1.append(round(res[1] + res[0]))
                    J.append(0)
        else:
            J.append(round(res[0]))
            S1.append(round(res[1]))

        if qt == 1 and T[lop-1] >= 1:  # past first year, in first quarter: move all S2 to A
            A.append(round(res[2] + res[3]))
            if S2_logic:
                # S2 was already recorded above on this year-boundary step;
                # overwrite it instead of appending a second entry, which would
                # make S2 longer than T/J/S1/A.
                # NOTE(review): the S1 cohort promoted this step is not counted
                # in this single record -- confirm intended.
                S2[-1] = 0
            else:
                S2.append(0)
        elif S2_logic == 0:  # S2 hasn't been appended yet
            S2.append(round(res[2]))
            A.append(round(res[3]))
        else:
            A.append(round(res[3]))

    if log == 0:
        return [T,J,S1,S2,A]

    log_df = pd.concat(log_frames) if log_frames else pd.DataFrame(columns=['Message','Time'])
    return [T,J,S1,S2,A,log_df]

###############################################

In [17]:

def stoch_model(J0=0,S10=0,S20=6,A0=15,years=1,log = 0,fusion= 0.1,dispersal=0.1,b_j = 0.5,d_j = 0.345,d_s = 0.214,d_a = 0.143,t_p = 0.3,por = 0.2,fsr = 0.4,ae_thresh = 5,fis_thresh = 45,lit_size = 4,dfe='off'):
    """Run one stochastic troop simulation and tabulate it as a DataFrame.

    Parameters
    ----------
    J0, S10, S20, A0 : int
        Initial juvenile, sub-adult (S1, S2) and adult counts.
    years : float, default 1
        Simulation length, passed to ``Stoch_Iteration`` as ``ND``.
    log : int, default 0
        Non-zero also returns the event log DataFrame.
    fusion : probability that a fusion will occur onto the existing group
    dispersal : probability that a dispersal/eviction event occurs
    b_j : recruitment/birth rate of juveniles
    d_j, d_s, d_a : death rates of juveniles, subadults and adults
    t_p : probability of transmission
    por : portion of population killed by infection
    fsr : sex ratio for females to males
    ae_thresh : threshold where the Allee effect takes effect
    fis_thresh : threshold for fission to occur
    lit_size : average litter size
    dfe : forwarded to the simulation

    Returns
    -------
    pandas.DataFrame, or (pandas.DataFrame, pandas.DataFrame) when ``log != 0``
        Columns ``T`` (total J + S + A), ``J``, ``S`` (S1 + S2), ``A``,
        ``tdate`` (simulation time) and ``quarter``; plus the event log.

    Raises
    ------
    ValueError
        If the simulated series cannot be combined (e.g. unequal lengths);
        the message lists the length of each series.
    """
    INPUT = np.array((J0,S10,S20,A0))

    ## Model run: every tunable parameter, including ae_thresh, is forwarded
    result = Stoch_Iteration(INPUT,log=log,dfe=dfe,ND=years,fusion=fusion,dispersal=dispersal,
                             b_j=b_j,d_j=d_j,d_s=d_s,d_a=d_a,t_p=t_p,por=por,fsr=fsr,
                             ae_thresh=ae_thresh,fis_thresh=fis_thresh,lit_size=lit_size)
    if log == 0:
        T,J,S1,S2,A = result
    else:
        T,J,S1,S2,A,log_df = result

    def add_qtr(x):
        """Map a simulation time to quarter 1-4 using 0.25 boundaries."""
        # NOTE(review): x is not reduced modulo 1, so every time point after the
        # first year maps to quarter 4 -- confirm this is intended for years > 1.
        if x <= 0.25:
            return 1
        elif x <= 0.5:
            return 2
        elif x <= 0.75:
            return 3
        return 4

    try:
        sim = pd.DataFrame(data={'tdate':T,'J':J,'S1':S1,'S2':S2,'A':A})
    except ValueError as err:
        raise ValueError('T:{}, J:{}, S1:{}, S2:{}, A:{}'.format(len(T),len(J),len(S1),len(S2),len(A))) from err

    sim['S'] = sim[['S1','S2']].sum(axis=1)
    sim['T'] = sim[['J','S','A']].sum(axis=1)
    # .copy() so adding 'quarter' writes to a real frame, not a view of the column subset
    sim = sim[['T','J','S','A','tdate']].copy()
    sim['quarter'] = sim['tdate'].apply(add_qtr)

    if log == 0:
        return sim
    return sim,log_df
    
###############################################



