### compare predictions from absolute fitness to pairwise competition 

In [None]:


### execute script to load modules here
# The script is exec'd into this namespace. The plotting constants used below
# (FIGWIDTH_TRIPLET, FIGHEIGHT_TRIPLET, DPI, PAD_INCHES) and `sns` are never
# defined in this notebook, so presumably they come from here — TODO confirm.
# The `with` block closes the file handle that `open(...).read()` used to leak.
with open('setup_aesthetics.py') as _setup_script:
    exec(_setup_script.read())

In [None]:
import pandas as pd           

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches

from scipy import integrate
from scipy import stats
import random

from scipy.stats import spearmanr, pearsonr
from latex_format import float2latex


In [None]:
from bulk_simulation_code import run_pairwise_experiment
from bulk_simulation_code import CalcRelativeYield,CalcReferenceFrequency
from bulk_simulation_code import CalcTotalSelectionCoefficientLogit
from m3_model import CalcRelativeSaturationTime as CalcSaturationTimeExact

In [None]:
### Update dependent parameters according to input
import os
import os.path
from os import path

## create export directory if necessary
## foldernames for output plots/lists produced in this notebook
import os
FIG_DIR_STEM = './figures/absolute_fitness/'  # plain literal: nothing to interpolate
os.makedirs(FIG_DIR_STEM, exist_ok=True)  # no error if the folder already exists


In [None]:


### execute script to load modules here
# Runs the aesthetics script a second time (it is already exec'd at the top of
# the notebook). The `with` block closes the file handle that
# `open(...).read()` used to leak.
with open('setup_aesthetics.py') as _setup_script:
    exec(_setup_script.read())

In [None]:
# colour for histograms of the raw growth-curve data (used for the initial-OD histogram)
DATASET_COLOR = 'darkorange'

In [None]:
### set initial frequency of competition
# passed as x0 (initial mutant frequency) to run_pairwise_experiment
INITIAL_FREQ = 0.5


## set which trait distribution to plot
# must be one of the keys of dist2data (built in the "create subsets" cell)
DIST = 'no_lag_variation'

### set which time to use for AUC cutoff
# integration cutoff (t_trim) for the area under the monoculture curve;
# in hours, because lag times are converted to hours further down
AUC_CUTOFF_TIME = 16

See the cell [below](#create-subsets-of-data-only-with-marginals) for the available choices of trait distribution (`DIST`).

In [None]:
# one figure subfolder per (trait distribution, AUC cutoff) combination
SUFFIX_DATASET = f'{DIST}/AUC_CUTOFF_{AUC_CUTOFF_TIME}/'

FIG_DIR = FIG_DIR_STEM + SUFFIX_DATASET
os.makedirs(FIG_DIR, exist_ok=True)


### Load wild-type traits

In [None]:
# the first five CSV columns form the row MultiIndex
INDEX_COL = [0,1,2,3,4]
# string markers in the CSV that are parsed as NaN
list_na_representations = ['not_present', 'failed_to_compute']

In [None]:
PCWS_TRAITS_WARRINGER = './output/df_M3_traits.csv'
# one row per growth curve, indexed by INDEX_COL; marker strings become NaN
df_warringer = pd.read_csv(PCWS_TRAITS_WARRINGER, header = 0, index_col= INDEX_COL,\
                                  float_precision=None, na_values=list_na_representations)


In [None]:
### define default wild_type
df_wildtypes = df_warringer[df_warringer['is_wildtype']==True]

# reference wild-type: per-trait median over all wild-type rows (numeric columns only);
# a Series indexed by trait name
WILDTYPE = df_wildtypes.median(axis = 0, numeric_only = True)

### Load mutant data (averaged)

In [None]:

# averaged trait table; read without index_col, so it gets a default RangeIndex
PCWS_TRAITS_WARRINGER_AVERAGED = './output/df_M3_traits_averaged.csv'
df_averaged = pd.read_csv(PCWS_TRAITS_WARRINGER_AVERAGED, header = 0, float_precision=None)

In [None]:
### assign wild-type label
def is_wildtype(row):
    """Tell whether a row belongs to the BY4741 reference strain.

    Meant for ``DataFrame.apply(..., axis=1)``; *row* only needs to support
    ``row['genotype']``. Always returns a plain Python ``bool``.
    """
    return bool(row['genotype'] == 'BY4741')
    

## quick check on the first averaged row
row = df_averaged.iloc[0]
is_wildtype(row)

In [None]:
# label every averaged row; only the BY4741 genotype counts as wild-type
df_averaged['is_wildtype'] = df_averaged.apply(is_wildtype, axis =1)

In [None]:
### append mutant values (averaged) to set of individual wild-type strains
df_knockouts = df_averaged[~df_averaged['is_wildtype']]
df_knockouts = df_knockouts
df_all_vary = df_wildtypes.reset_index().append(df_knockouts.reset_index())

In [None]:
### restore index
index_col_names = df_warringer.index.names
df_all_vary = df_all_vary.set_index(index_col_names)


In [None]:
# number of wild-type rows and knockout rows in the combined table
n_wildtypes = sum(df_all_vary['is_wildtype'])
n_knockouts = sum(~df_all_vary['is_wildtype'])

### Set units of time

In [None]:
# NOTE: these conversions rescale in place, so re-running this cell rescales again.
# Presumably the raw values are per minute / in minutes — confirm against the trait export.
df_all_vary['gmax'] = df_all_vary['gmax']*60 # change units to growth rate per hour
df_all_vary['lag']  = df_all_vary['lag']/60 # change units to hour



In [None]:
# same conversion for the reference wild-type (also not idempotent)
WILDTYPE['gmax'] = WILDTYPE['gmax']*60 # change units to growth rate per hour
WILDTYPE['lag']  = WILDTYPE['lag']/60 # change units to hour

### create subsets of data only with marginals

In [None]:
# Trait distributions selectable via DIST. Each entry is a deep copy of
# df_all_vary in which some traits are overwritten by the reference
# wild-type value, which removes their variation.
dist2data = {}

## full datadist with all traits
tmp = df_all_vary.copy(deep=True)
dist2data['all_traits_vary'] = tmp
## distribution with no yield variation
tmp = df_all_vary.copy(deep=True)
tmp['yield'] = WILDTYPE['yield']
dist2data['no_yield_variation'] = tmp
## distribution with some yield variation, but only equal or larger than wild-type
# note: NaN yields also become the wild-type yield, since `NaN > x` is False
tmp = df_all_vary.copy(deep=True)
tmp['yield'] = [v if v > WILDTYPE['yield'] else WILDTYPE['yield'] for v in df_all_vary['yield']]
dist2data['no_deleterious_yield'] = tmp
## distribution with no growth rate variation
tmp = df_all_vary.copy(deep=True)
tmp['gmax'] = WILDTYPE['gmax']
dist2data['no_gmax_variation'] = tmp
## distribution with no lag time variation
tmp = df_all_vary.copy(deep=True)
tmp['lag'] = WILDTYPE['lag']
dist2data['no_lag_variation'] = tmp

## marginal distribution in gmax
tmp = df_all_vary.copy(deep=True)
tmp['lag'] = WILDTYPE['lag']
tmp['yield'] = WILDTYPE['yield']
dist2data[ 'only_gmax_varies'] = tmp
## marginal distribution in lag
tmp = df_all_vary.copy(deep=True)
tmp['yield'] = WILDTYPE['yield']
tmp['gmax'] = WILDTYPE['gmax']
dist2data[ 'only_lag_varies'] = tmp
## marginal distribution in yield
tmp = df_all_vary.copy(deep=True)
tmp['lag'] = WILDTYPE['lag']
tmp['gmax'] = WILDTYPE['gmax']
dist2data[ 'only_yield_varies'] = tmp


### Choose subset

In [None]:
df_input = dist2data[DIST]

### Load trait data into the standard form required by Michael's code

In [None]:
n_input = df_input.shape[0]

In [None]:
# Layout of every trait array below: index 0 = reference wild-type (WILDTYPE),
# indices 1..n_input = the rows of df_input in their current order.
### growth rates
gs = np.zeros(n_input+1)
gs[0] = WILDTYPE['gmax']
gs[1:] = df_input['gmax'].values

### lag times
ls = np.zeros(n_input+1)
ls[0] = WILDTYPE['lag']
ls[1:] = df_input['lag'].values

### yield
Ys = np.zeros(n_input+1)
Ys[0] = WILDTYPE['yield']
Ys[1:] = df_input['yield'].values


### Define initial condition for bulk growth cycle

In [None]:
### set initial resource concentrations
# presumably 20 g/L glucose divided by its molar mass (~180 g/mol), times 1e3 -> mM — confirm
CONCENTRATION_GLUCOSE = 20/180 * 1e3 # concentrations are recored  in milliMolar, to match the units of yield
print(CONCENTRATION_GLUCOSE)

In [None]:
### define default initial_OD
# hard-coded initial OD used in all simulations (the median computation is commented out)
OD_START = 0.05  #df_warringer['od_start'].median()

### compare to initial OD in the monoculture cycles
fig, ax = plt.subplots(figsize = (FIGWIDTH_TRIPLET, FIGHEIGHT_TRIPLET))

# Series.hist without ax= draws on the current axes (the one just created) and returns it
ax = df_warringer['od_start'].hist(bins=41, color = DATASET_COLOR, alpha = 0.6, log = True, rasterized = True)


# NOTE(review): the label says "median value", but OD_START is a fixed 0.05 — confirm wording
ax.axvline(OD_START, color = 'tab:red', label = f'median value: $N_0={OD_START:.3f}$')
ax.legend()
ax.set_xlabel('initial OD')
ax.set_ylabel('no. growth curves')

### Calculate effective yield

In [None]:
from bulk_simulation_code import CalcRelativeYield

In [None]:
### calculate effective yields
# Converts the yields Ys (same index layout: 0 = reference wild-type) using the
# initial resource R0 and initial OD N0; the exact definition lives in
# bulk_simulation_code. Presumably one nu per entry of Ys — confirm.
nus = CalcRelativeYield(Ys, R0 = CONCENTRATION_GLUCOSE, N0 = OD_START)


### Simulate pairwise competition growth cycles

In [None]:
from bulk_simulation_code import toPerGeneration

In [None]:
%%time
# Every entry of gs/ls/nus (including index 0, the wild-type against itself)
# competes against the reference wild-type (g1, l1, nu1), starting at mutant
# frequency x0.
xs_pair, xs_pair_final, tsats,fcs_both, fcs_wt, fcs_mut = run_pairwise_experiment(
                                                                gs=gs,   ls=ls,   nus = nus, 
                                                                g1=gs[0],l1=ls[0],nu1=nus[0],
                                                                x0 = INITIAL_FREQ)

In [None]:
# logit selection coefficient over one growth cycle, from initial and final frequencies
s_percycle = CalcTotalSelectionCoefficientLogit(xs_pair,xs_pair_final)
# normalise by the wild-type's natural-log fold-change in the same cycle
s_pergen = np.divide(s_percycle, np.log(fcs_wt))

In [None]:
# sanity-check distributions of saturation times and total fold-changes
fig,axes = plt.subplots(1,2, figsize = (2*FIGWIDTH_TRIPLET,FIGHEIGHT_TRIPLET))
ax = axes[0]
ax.hist(tsats, log = True)
ax.set_xlabel('Saturation Time')
ax = axes[1]
ax.hist(fcs_both, log = True)
ax.set_xlabel('Total Fold-Change')

### store results

In [None]:
df_output = df_input.copy()

In [None]:
## compute true output
# slice off index 0 (the reference wild-type) so the arrays line up with df_input's rows

df_output['logfc_wt'] = np.log(fcs_wt[1:])
df_output['logfc_mut'] = np.log(fcs_mut[1:])

df_output['logit_percycle'] = s_percycle[1:]
df_output['logit_pergen'] = s_pergen[1:]



In [None]:
### compute proxies
# simple trait differences relative to the reference wild-type

df_output['delta_gmax'] = df_output['gmax'] - WILDTYPE['gmax']
df_output['delta_lag'] = df_output['lag'] - WILDTYPE['lag']
df_output['delta_yield'] = df_output['yield'] - WILDTYPE['yield']

### simulate monoculture growth

In [None]:
### To calculate the fold-change in monoculture, 
# we can use the same code but with a 100% mutant frequency

In [None]:
%%time
# NOTE: this rebinds xs_pair, tsats, fcs_* etc., overwriting the competition results above
xs_pair, xs_pair_final, tsats,fcs_both, fcs_wt, fcs_mut = run_pairwise_experiment(
                                                                gs=gs,   ls=ls,   nus = nus, 
                                                                g1=gs[0],l1=ls[0],nu1=nus[0],
                                                                x0 = 1.)


## store the saturation time in monoculture
# df_output is still in df_input's row order here, so tsats[1:] lines up with its rows
WILDTYPE['tsat_mono'] = tsats[0]
df_output['tsat_mono'] = tsats[1:]

## store the fold-changes in monoculture


WILDTYPE['logfc_mono'] = np.log(fcs_mut[0]) # index 0 is the median wild-type
df_output['logfc_mono'] = np.log(fcs_mut[1:])

In [None]:
## check monoculture growth
np.testing.assert_array_equal(xs_pair,1.) # initially the mutant is 100% of biomass
np.testing.assert_array_equal(xs_pair_final,1.) # finally the  mutant is 100% of biomass

# with no wild-type present, its fold-change is 1 and the mutant accounts for all growth
np.testing.assert_array_equal(fcs_wt, 1.)
np.testing.assert_array_equal(fcs_mut, fcs_both)

### Prepare data for plotting

In [None]:
### sort by label prepare for plotting

def row2label(row):
    """Return the plotting category of a row: 'wild-type' or 'knockout'.

    Only an ``is_wildtype`` entry that compares equal to True gives
    'wild-type'; anything else (False, NaN, ...) gives 'knockout'.
    """
    return 'wild-type' if row['is_wildtype'] == True else 'knockout'  # noqa: E712
    

In [None]:
df_output['label']  = df_output.apply(row2label,axis=1)

# NOTE: sorting reorders the rows, so the positional arrays (gs, tsats, fcs_*, ...)
# no longer line up with df_output after this point; they were stored as columns above
df_output = df_output.sort_values('label')

In [None]:
# shared hue -> colour mapping for all seaborn plots below
palette = {'wild-type':'orange', 'knockout': 'dimgrey'}


### Compute delta log-fold-change in monoculture

In [None]:
### compute log-foldchange as an absolute fitness proxy
# difference to the reference wild-type's monoculture log fold-change

df_output['delta_logfc_mono'] = df_output['logfc_mono'] - WILDTYPE['logfc_mono']

In [None]:
# check distribution of logfc values
fig,ax = plt.subplots()
sns.histplot(df_output, x= 'logfc_mono', ax = ax, hue = 'label', palette = palette)
ax.set_yscale('log')
ax.axvline(WILDTYPE['logfc_mono'], color = 'tab:red')


In [None]:
# check correlation of fold-change in mono and co-culture
fig,ax = plt.subplots()
sns.scatterplot(df_output, x= 'logfc_mono', y = 'logfc_mut', ax = ax,
                hue = 'label', palette = palette)

# clear seaborn's legend, then rebuild it without a frame
ax.legend([],[])
ax.legend(frameon = False)

### Compute area under the curve in monoculture

In [None]:
# define the evaluation time for area under the curve
t_trim = AUC_CUTOFF_TIME
print(f"Integrating the area under the curve up to a time t_trim = {t_trim:.2f}.")

# compare the AUC cutoff with the monoculture saturation times
fig,ax = plt.subplots(figsize = (FIGHEIGHT_TRIPLET,FIGHEIGHT_TRIPLET,))
ax.axvline(t_trim, color = 'black', label = 'time window for AUC')
sns.histplot(df_output, x = 'tsat_mono', hue = 'label', palette = palette, legend = True, ax = ax)
ax.legend_.set_title('')

ax.set_yscale('log')

ax.set_xlabel('saturation time in monoculture [hours]')
ax.set_ylabel('number of growth curves')


# savefig's resolution keyword is lowercase `dpi`; `DPI=` is not a valid option
fig.savefig(FIG_DIR + f'choice_of_timewindow_AUC_x0={INITIAL_FREQ:.2f}.pdf', dpi = DPI, bbox_inches = 'tight', pad_inches = PAD_INCHES)


In [None]:
from bulk_simulation_code import CalcAbundanceTimeseries, CalcAreaUnderTheCurve

In [None]:
# time grid on which monoculture growth curves are sampled for the AUC; ends at t_trim
tvec = np.linspace(0,t_trim)

def row2AUC(row, tvec=tvec, t_trim=t_trim, N0=OD_START):
    """Return the area under the monoculture growth curve of one strain.

    The abundance time series is rebuilt from the row's traits with
    ``CalcAbundanceTimeseries`` and integrated up to ``t_trim`` with
    ``CalcAreaUnderTheCurve``.

    Parameters
    ----------
    row : pandas.Series
        Must provide 'gmax', 'lag' and 'tsat_mono'.
    tvec, t_trim, N0 : optional
        Time grid, integration cutoff and initial OD. The defaults are bound
        when this cell runs. Because of that, the later rebinding of the global
        ``tvec`` (to ``np.linspace(0, 25)`` in the outlier cell) cannot
        silently change the AUC if ``df_output.apply(row2AUC, axis=1)`` is
        re-run.
    """
    tsat = row['tsat_mono']
    g, l = row[['gmax', 'lag']]
    t, y = CalcAbundanceTimeseries(tvec, g, l, tsat=tsat, N0=N0)
    return CalcAreaUnderTheCurve(t, y, t_trim=t_trim)


## test
row = df_output.iloc[2]
row2AUC(row)

In [None]:
### apply to all points
# one monoculture AUC value per row of df_output
df_output['AUC_mono'] = df_output.apply(row2AUC, axis = 1)

In [None]:
## calculate for wild-type
# Read the reference traits from WILDTYPE rather than from the last simulation
# arrays: `tsats` is rebound whenever a run_pairwise_experiment cell is re-run.
# Rebuild the 0..t_trim grid here so the reference AUC uses the same sampling
# and cutoff as the per-strain AUCs. (The global `tvec` is rebound later in
# the notebook.)
tsat = WILDTYPE['tsat_mono']
g, l = WILDTYPE['gmax'], WILDTYPE['lag']
t, y = CalcAbundanceTimeseries(np.linspace(0, t_trim), g,l,tsat=tsat, N0 = OD_START)
WILDTYPE['AUC_mono']= CalcAreaUnderTheCurve(t,y, t_trim = t_trim)

In [None]:
# check distribution of AUC values
fig,ax = plt.subplots()
sns.histplot(df_output, x= 'AUC_mono', ax = ax, hue = 'label', palette = palette)
ax.set_yscale('log')
ax.axvline(WILDTYPE['AUC_mono'], color = 'tab:red')

In [None]:
# AUC difference to the reference wild-type, used as an absolute-fitness proxy
df_output['delta_AUC_mono'] = df_output['AUC_mono'] - WILDTYPE['AUC_mono']

### compare impact of different variables

In [None]:
### choose target variable
# column of df_output that every predictor is compared against
target = 'logit_percycle'

### choose dataset
df_subset = df_output



In [None]:
## rich labels
# two-line axis labels; NOTE: the next cell overwrites column2label with short labels

column2label = {'delta_gmax':r'$\Delta$ growth rate:'+'\n'+r'$g_{\mathrm{mut}}-g_{\mathrm{wt}}$',
                'delta_lag':r'$\Delta$ lag time:'+'\n'+r'$\lambda_{\mathrm{mut}}-\lambda_{\mathrm{wt}}$',
                'delta_yield':r'$\Delta$ biomass yield:'+'\n'+r'$Y_{\mathrm{mut}}-Y_{\mathrm{wt}}$',
                'delta_logfc_mono':r'$\Delta$ log fold-change:'+'\n'+r'$\mathrm{LFC}_{\mathrm{mut}}-\mathrm{LFC}_{\mathrm{wt}}$',
                'delta_AUC_mono':r'$\Delta$ area under curve:'+'\n'+r'$\mathrm{AUC}_{\mathrm{mut}}-\mathrm{AUC}_{\mathrm{wt}}$',
               }


In [None]:
## simple labels
# keys are the predictor columns analysed below; their order sets the bar order in the plots

column2label = {'delta_gmax':r'$\Delta\; g$ ',
                'delta_lag':r'$\Delta$ lag',
                'delta_yield':r'$\Delta$ yield',
                'delta_logfc_mono':r'$\Delta\; \mathrm{LFC}$',
                'delta_AUC_mono':r'$\Delta\; \mathrm{AUC}$',
               }


### set up a correlation analysis

In [None]:
from scipy.stats import spearmanr

In [None]:
### set up datastorage for results from the rank-correlation analysis
# Collect one dict per predictor and build the frame once at the end:
# DataFrame.append was deprecated in pandas 1.4 and removed in pandas 2.0.
correlation_rows = []


for predictor  in column2label.keys():

    # Drop rows where the predictor or the target is missing, as the OLS fits
    # below do. Otherwise spearmanr propagates NaN, and n_obs would count rows
    # that did not contribute.
    pair = df_subset[[predictor, target]].dropna()
    x = pair[predictor].values
    y = pair[target].values
    try:
        result = spearmanr(x,y)
    except Exception as e:
        # best effort: report the failure and record NaN for this predictor
        print(predictor)
        print(e)
        result = (np.nan, np.nan)

    correlation_rows.append({'predictor':predictor, 'spearman_r':result[0],
                             'n_obs':len(x), 'pvalue':result[1],
                             'label' :column2label[predictor]})

df_correlation = pd.DataFrame(correlation_rows)

## update index
df_correlation.reset_index(drop = True, inplace = True)

In [None]:
## add absolute value
# bars show |rho| so positive and negative correlations share one scale
df_correlation['spearman_abs'] = np.abs(df_correlation['spearman_r'])

In [None]:
def row2masked(row):
    """Return a deep copy of *row* with its 'spearman_abs' entry set to NaN.

    The row that is passed in is left unchanged.
    """
    masked_row = row.copy(deep=True)
    masked_row['spearman_abs'] = np.nan
    return masked_row

In [None]:
### choose color
cmap = sns.color_palette('Set2', as_cmap=True)
color = cmap(8/8)

In [None]:
fig,ax = plt.subplots(figsize = (FIGHEIGHT_TRIPLET,FIGHEIGHT_TRIPLET,))

set_colors = set()


sns.barplot(df_correlation,x='label', y = 'spearman_abs', ax =ax, color = color, 
            order = df_correlation['label'])


    
ax.set_ylim(0,1)

ax.set_xlabel("")
ax.set_ylabel(r"magnitude of rank correlation $|\rho|$")

#rotate labels
#_ = plt.xticks(rotation=90)

#ax.legend(frameon = False)


fig.savefig(FIG_DIR + f'barplot_spearman_{target}_x0={INITIAL_FREQ:.2f}.pdf', DPI = DPI, bbox_inches = 'tight', pad_inches = PAD_INCHES)


In [None]:
### make a nice print
# markdown table (DataFrame.to_markdown needs the optional `tabulate` package)
print(df_correlation[['predictor', 'spearman_r', 'pvalue', 'n_obs']].to_markdown())

#### load additional modules for linear regression

In [None]:
import statsmodels.api as sm
import statsmodels.formula.api as smf

In [None]:
### we use a custom function to convert floats into LaTeX-formatted strings
### (used for the p-value annotations of the R^2 barplot)
from latex_format import float2latex

## test
float2latex(1.12345e-12, display = '.2e')

### run linear regression

In [None]:
### set up datastorage for results from linear regression
# Collect one dict per predictor and build the frame once at the end:
# DataFrame.append was deprecated in pandas 1.4 and removed in pandas 2.0.
regression_rows = []


for predictor  in column2label.keys():

    # note: the ols function automatically chooses a subset of the data
    # such that rows that are missing a predictor variable are dropped
    # (after the loop, `results` holds the fit for the last predictor only)
    results = smf.ols(f'{target} ~ ' + predictor, data=df_subset).fit()

    regression_rows.append({'predictor':predictor, 'rsquared':results.rsquared,
                            'n_obs':results.nobs, 'df_model':results.df_model, 'pvalue':results.f_pvalue,
                            'label' :column2label[predictor]})

df_results = pd.DataFrame(regression_rows)

## update index
df_results.reset_index(drop = True, inplace = True)

In [None]:
def row2masked(row):
    """Return a deep copy of *row* with its 'rsquared' entry set to NaN.

    The row that is passed in is left unchanged.
    """
    masked_row = row.copy(deep=True)
    masked_row['rsquared'] = np.nan
    return masked_row

In [None]:
### choose color
cmap = sns.color_palette('Set2', as_cmap=True)
color = cmap(8/8)

In [None]:
fig,ax = plt.subplots(figsize = (FIGHEIGHT_TRIPLET,FIGHEIGHT_TRIPLET,))

set_colors = set()


sns.barplot(df_results,x='label', y = 'rsquared', ax =ax, color = color, 
            order = df_results['label'])



### sort 
for i in df_results.index:
    
   
    ## print no datapoints
    x = i - 0.25
    y = df_results.at[i,'rsquared']+0.01
    y = 0.01
    n_obs = df_results.at[i,'n_obs']
    text = f"$n={n_obs:3.0f}$"
    ### add text
    #ax.text(x, y, text,{'fontsize':9}) 
    

    ## print p-value
    p_value = df_results.at[i,'pvalue']
    p_latex = float2latex(p_value, display=".2g")
    text = fr"$p={p_latex}$"


    x = i-0.1
    y = 0.25
    
    ### add text
    #ax.text(x, y, text,{'fontsize':10, 'rotation':90}) 
    

    

    
ax.set_ylim(0,1)

ax.set_xlabel("")
ax.set_ylabel("$R^2$ for linear model fit")

#rotate labels
#_ = plt.xticks(rotation=90)

#ax.legend(frameon = False)


fig.savefig(FIG_DIR + f'barplot_rsquared_{target}_x0={INITIAL_FREQ:.2f}.pdf', DPI = DPI, bbox_inches = 'tight', pad_inches = PAD_INCHES)


In [None]:
### make a nice print
# markdown table (DataFrame.to_markdown needs the optional `tabulate` package)
print(df_results[['predictor', 'rsquared', 'pvalue', 'df_model', 'n_obs']].to_markdown())

### Plot the correlation with Area under the Curve

In [None]:
fig,ax = plt.subplots(figsize = (FIGHEIGHT_TRIPLET,FIGHEIGHT_TRIPLET,))

xvar = 'delta_AUC_mono'

# plot raw datapoints
sns.scatterplot(df_output, x = xvar, y = target, hue = 'label', 
                palette = palette, ax = ax)

# plot regression line
# Fit the model for xvar explicitly. The regression loop leaves `results`
# bound to whichever predictor comes last in column2label, which is xvar only
# by coincidence of dict order. This uses the same data (df_subset) as the
# R^2 barplot.
fit_xvar = smf.ols(f'{target} ~ {xvar}', data=df_subset).fit()
offset = fit_xvar.params['Intercept'] 
slope = fit_xvar.params[xvar]
xmin,xmax = ax.get_xlim()
xvec = np.linspace(xmin,xmax)
ax.plot(xvec, xvec*slope + offset, color = 'tab:red', lw = 2, label = 'regression')

## fix legend
# clear seaborn's legend, then rebuild it (now including the regression line) without a frame
ax.legend([],[])
ax.legend(frameon=False)

ax.set_xlabel(r'area under the curve: $\Delta \mathrm{AUC}$')
ax.set_ylabel('relative fitness per-cycle:' + r'  $s^{\mathrm{logit}}_{\mathrm{cycle}}$')



# savefig's resolution keyword is lowercase `dpi`; `DPI=` is not a valid option
fig.savefig(FIG_DIR + f'scatterplot_{xvar}-vs-{target}_x0={INITIAL_FREQ:.2f}.pdf', dpi = DPI, bbox_inches = 'tight', pad_inches = PAD_INCHES)


### plot outlier with large Area under the curve

In [None]:
# the five rows with the largest monoculture AUC
top_outliers = df_output.sort_values('AUC_mono', ascending = False)[:5]
top_outliers[['gmax', 'lag', 'yield', 'tsat_mono']]

In [None]:
# reference wild-type traits for comparison
WILDTYPE[['gmax', 'lag', 'yield', 'tsat_mono']]

In [None]:
# index label of the row with the largest AUC
select = top_outliers.index[0]

In [None]:
g,l, tsat = top_outliers.loc[select, ['gmax', 'lag', 'tsat_mono']]

# NOTE: this rebinds the notebook-global `tvec` (previously the 0..t_trim AUC grid)
tvec = np.linspace(0,25)
t, y = CalcAbundanceTimeseries(tvec, g,l,tsat=tsat, N0 = OD_START)

In [None]:
fig, ax = plt.subplots()
ax.plot(t,y, marker = 'x', label = 'timeseries')
ax.axvline(l, label = 'lag time', color = 'black')
# legend is built here, so the 't_trim' line added below is not listed in it
ax.legend()

ax.axvline(t_trim, label = 't_trim', color = 'tab:red')
ax.set_yscale('log')
ax.set_ylabel('log absolute abundance')
ax.set_xlabel('time')

In [None]:
fig,ax = plt.subplots(figsize = (FIGHEIGHT_TRIPLET,FIGHEIGHT_TRIPLET,))

xvar = 'AUC_mono'

# plot raw datapoints
sns.scatterplot(df_output, x = xvar, y = 'yield', hue = 'label', 
                palette = palette)
