### Compare the fitness distribution using selection coefficients with different timescales

In [None]:


### execute the shared setup script in this namespace; it presumably defines the plotting
### names used below but not defined in this notebook (sns, FIGHEIGHT_TRIPLET, DPI, PAD_INCHES, ...)
## the context manager closes the file handle deterministically instead of leaving it to the GC
with open('setup_aesthetics.py') as aesthetics_file:
    exec(aesthetics_file.read())

In [None]:
import pandas as pd           

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches

from scipy import integrate
from scipy import stats
import random

from scipy.stats import spearmanr, pearsonr
from latex_format import float2latex


In [None]:
from bulk_simulation_code import run_pairwise_experiment
from bulk_simulation_code import CalcRelativeYield,CalcReferenceFrequency
from bulk_simulation_code import CalcTotalSelectionCoefficientLogit
from m3_model import CalcRelativeSaturationTime as CalcSaturationTimeExact

In [None]:
### Update dependent parameters according to input
import os
import os.path
from os import path

## root folder for all output plots/lists produced in this notebook;
## it is created here if it does not exist yet
import os
FIG_DIR_STEM = './figures/selection_timescales/'
os.makedirs(name=FIG_DIR_STEM, exist_ok=True)


In [None]:


### re-execute the shared setup script (already run at the top of the notebook)
## NOTE(review): this second execution looks redundant — confirm nothing after the imports
## needs it re-applied before removing the cell.
## the context manager closes the file handle deterministically instead of leaving it to the GC
with open('setup_aesthetics.py') as aesthetics_file:
    exec(aesthetics_file.read())

In [None]:
## colour for histograms of the measured monoculture dataset
DATASET_COLOR = "darkorange"

In [None]:
### initial frequency x0 passed to each pairwise competition (presumably of the mutant)
INITIAL_FREQ = 1e-2

## trait distribution to analyse; must be a key of `dist2data` (built further below):
## 'all_traits_vary', 'no_yield_variation', 'no_deleterious_yield', 'no_gmax_variation',
## 'no_lag_variation', 'only_gmax_varies', 'only_lag_varies', 'only_yield_varies'
DIST = "all_traits_vary"

See cell [below](#create-subsets-of-data-only-with-marginals) for a choice of trait distributions.

In [None]:
## figures for the chosen trait distribution go into their own sub-folder
SUFFIX_DATASET = DIST + '/'
FIG_DIR = ''.join((FIG_DIR_STEM, SUFFIX_DATASET))
os.makedirs(FIG_DIR, exist_ok=True)


### Load wild-type traits

In [None]:
## the first five CSV columns form the index of the per-curve trait table
INDEX_COL = list(range(5))
## strings in the CSVs that mark missing / failed trait values (parsed as NaN)
list_na_representations = ['not_present', 'failed_to_compute']

In [None]:
PCWS_TRAITS_WARRINGER = './output/df_M3_traits.csv'
## trait table indexed by the INDEX_COL columns; NA marker strings become NaN
df_warringer = pd.read_csv(
    PCWS_TRAITS_WARRINGER,
    header=0,
    index_col=INDEX_COL,
    na_values=list_na_representations,
    float_precision=None,
)


In [None]:
### default wild type: median of every numeric trait over all wild-type entries
is_wt_row = df_warringer['is_wildtype'] == True
df_wildtypes = df_warringer.loc[is_wt_row]

WILDTYPE = df_wildtypes.median(numeric_only=True, axis=0)

### Load mutant data (averaged)

In [None]:

## averaged mutant traits; read without index_col, so the frame has a default RangeIndex
PCWS_TRAITS_WARRINGER_AVERAGED = './output/df_M3_traits_averaged.csv'
df_averaged = pd.read_csv(PCWS_TRAITS_WARRINGER_AVERAGED, float_precision=None, header=0)

In [None]:
### wild-type label for a single row
def is_wildtype(row):
    """Return True iff the row's 'genotype' entry equals the reference strain 'BY4741'.

    `row` is anything indexable by 'genotype', e.g. a DataFrame row as passed by
    `DataFrame.apply(..., axis=1)`.
    """
    return bool(row['genotype'] == 'BY4741')
    

## spot-check the labelling function on the first averaged entry
row = df_averaged.iloc[0, :]
is_wildtype(row=row)

In [None]:
## boolean wild-type flag for every averaged entry, evaluated row by row
df_averaged['is_wildtype'] = df_averaged.apply(is_wildtype, axis='columns')

In [None]:
### append mutant values (averaged) to set of individual wild-type strains
## `DataFrame.append` was removed in pandas 2.0; `pd.concat` is the supported equivalent
## (same defaults: original indices kept, columns not sorted)
df_knockouts = df_averaged[~df_averaged['is_wildtype']]
df_all_vary = pd.concat([df_wildtypes.reset_index(), df_knockouts.reset_index()])

In [None]:
### restore the (multi-column) index of the original trait table
index_col_names = df_warringer.index.names
df_all_vary = df_all_vary.set_index(keys=index_col_names)


In [None]:
## number of wild-type entries and of knockout entries in the combined table
n_wildtypes = df_all_vary['is_wildtype'].sum()
n_knockouts = (~df_all_vary['is_wildtype']).sum()

### Set units of time

In [None]:
## convert from per-minute to per-hour units
## NOTE: not idempotent — re-running this cell rescales the columns again
df_all_vary['gmax'] = 60 * df_all_vary['gmax']   # growth rate per hour
df_all_vary['lag'] = df_all_vary['lag'].div(60)  # lag time in hours



In [None]:
## same per-minute -> per-hour conversion for the wild-type reference
## NOTE: not idempotent — re-running this cell rescales again
WILDTYPE.loc['gmax'] = 60 * WILDTYPE.loc['gmax']  # growth rate per hour
WILDTYPE.loc['lag'] = WILDTYPE.loc['lag'] / 60    # lag time in hours

### create subsets of data only with marginals

In [None]:
## Each entry of `dist2data` is a deep copy of `df_all_vary` in which some traits are
## pinned to their wild-type value, so that only the remaining traits vary across strains.

def _pin_traits_to_wildtype(df, traits):
    """Return a deep copy of `df` with every column in `traits` set to WILDTYPE[column]."""
    pinned = df.copy(deep=True)
    for trait in traits:
        pinned[trait] = WILDTYPE[trait]
    return pinned


def _raise_deleterious_yield(df):
    """Return a deep copy of `df` whose yields are at least the wild-type yield.

    Yields not strictly above the wild-type value are replaced by it. NaN yields compare
    False and are therefore also replaced by the wild-type yield.
    """
    raised = df.copy(deep=True)
    wt_yield = WILDTYPE['yield']
    raised['yield'] = raised['yield'].where(raised['yield'] > wt_yield, wt_yield)
    return raised


dist2data = {
    ## full distribution with all traits
    'all_traits_vary': _pin_traits_to_wildtype(df_all_vary, []),
    ## no yield variation
    'no_yield_variation': _pin_traits_to_wildtype(df_all_vary, ['yield']),
    ## some yield variation, but only equal or larger than wild-type
    'no_deleterious_yield': _raise_deleterious_yield(df_all_vary),
    ## no growth rate variation
    'no_gmax_variation': _pin_traits_to_wildtype(df_all_vary, ['gmax']),
    ## no lag time variation
    'no_lag_variation': _pin_traits_to_wildtype(df_all_vary, ['lag']),
    ## marginal distribution in gmax
    'only_gmax_varies': _pin_traits_to_wildtype(df_all_vary, ['lag', 'yield']),
    ## marginal distribution in lag
    'only_lag_varies': _pin_traits_to_wildtype(df_all_vary, ['yield', 'gmax']),
    ## marginal distribution in yield
    'only_yield_varies': _pin_traits_to_wildtype(df_all_vary, ['lag', 'gmax']),
}


### Choose subset

In [None]:
## pick the trait distribution chosen by DIST; fail early with the list of valid choices
if DIST not in dist2data:
    raise KeyError(f"unknown trait distribution DIST={DIST!r}; choose one of {list(dist2data)}")
df_input = dist2data[DIST]

### Load trait data into the standard form required by Michael's code

In [None]:
n_input = len(df_input)  # number of strains (rows) in the chosen distribution

In [None]:
def _wildtype_first(wildtype_value, strain_values):
    """Return a float array: `wildtype_value` at index 0, then `strain_values` in order.

    Index 0 is the reference strain (g1/l1/nu1 in the pairwise competitions);
    indices 1..n follow the row order of `df_input`.
    """
    traits = np.zeros(len(strain_values) + 1)
    traits[0] = wildtype_value
    traits[1:] = strain_values
    return traits


gs = _wildtype_first(WILDTYPE['gmax'], df_input['gmax'].values)    # growth rates
ls = _wildtype_first(WILDTYPE['lag'], df_input['lag'].values)      # lag times
Ys = _wildtype_first(WILDTYPE['yield'], df_input['yield'].values)  # yields


### Define initial condition for bulk growth cycle

In [None]:
### set initial resource concentrations

## presumably 20 g/L glucose divided by a molar mass of ~180 g/mol, times 1e3 -> milliMolar
CONCENTRATION_GLUCOSE = 20/180 * 1e3 # concentrations are recorded in milliMolar, to match the units of yield
print(CONCENTRATION_GLUCOSE)

In [None]:
### default initial OD of the competition cycles
## NOTE(review): fixed value; it was originally taken from df_warringer['od_start'].median().
## The legend below still calls it the "median value" — confirm the two agree.
OD_START = 0.05

### compare to initial OD in the monoculture cycles
fig, ax = plt.subplots(figsize=(FIGWIDTH_TRIPLET, FIGHEIGHT_TRIPLET))
ax = df_warringer['od_start'].hist(
    ax=ax, bins=41, color=DATASET_COLOR, alpha=0.6, log=True, rasterized=True,
)

ax.axvline(OD_START, color='tab:red', label=f'median value: $N_0={OD_START:.3f}$')
ax.legend()
ax.set_xlabel('initial OD')
ax.set_ylabel('no. growth curves')

### Calculate effective yield

In [None]:
from bulk_simulation_code import CalcRelativeYield

In [None]:
### calculate effective yields (index 0 is the wild-type reference)
nus = CalcRelativeYield(Ys, N0=OD_START, R0=CONCENTRATION_GLUCOSE)


### Simulate pairwise competition growth cycles (scenario A)

In [None]:
from bulk_simulation_code import toPerGeneration

In [None]:
%%time
xs_pair, xs_pair_final, tsats,fcs_both, fcs_wt, fcs_mut = run_pairwise_experiment(
                                                                gs=gs,   ls=ls,   nus = nus, 
                                                                g1=gs[0],l1=ls[0],nu1=nus[0],
                                                                x0 = INITIAL_FREQ)

In [None]:
## logit selection coefficient per growth cycle ...
s_percycle = CalcTotalSelectionCoefficientLogit(xs_pair, xs_pair_final)
## ... and per generation: divided by the wild-type log fold-change over the cycle
log_fc_wt = np.log(fcs_wt)
s_pergen = np.divide(s_percycle, log_fc_wt)

### store results

In [None]:
df_output = df_input.copy(deep=True)  # results table; df_input itself stays untouched

In [None]:
## per-strain simulation results; entry 0 (the reference strain) is dropped
simulation_columns = {
    'logfc_wt': np.log(fcs_wt[1:]),
    'logfc_mut': np.log(fcs_mut[1:]),
    'logit_percycle': s_percycle[1:],
    'logit_pergen': s_pergen[1:],
}
for column_name, column_values in simulation_columns.items():
    df_output[column_name] = column_values

## rank 1 = lowest fitness; tied values share the smallest rank
for fitness_column in ('logit_percycle', 'logit_pergen'):
    df_output[fitness_column + '_rank'] = df_output[fitness_column].rank(method='min', ascending=True)

## positive when a strain ranks higher per generation than per cycle
df_output['deltarank'] = df_output['logit_pergen_rank'] - df_output['logit_percycle_rank']


In [None]:
df_output.rank(axis=0)  # inspect column-wise ranks of the results table

In [None]:
### consistency check

In [None]:
## per-cycle logit fitness recomputed from the fold-changes: LFC_mut - LFC_wt
df_output['logit_percycle_test'] = df_output['logfc_mut'].sub(df_output['logfc_wt'])

In [None]:
## should be ~0 if the two per-cycle fitness computations agree
df_output['logit_percycle_residuals'] = df_output['logit_percycle_test'].sub(df_output['logit_percycle'])

In [None]:
sns.scatterplot(data=df_output, x='logit_percycle', y='logit_percycle_test')

In [None]:
sns.scatterplot(data=df_output, x='logit_percycle', y='logit_percycle_residuals')

### Prepare data for plotting

In [None]:
### categorical label used to sort and colour the plots

def row2label(row):
    """Return 'wild-type' if row['is_wildtype'] == True, otherwise 'knockout'."""
    is_wt = row['is_wildtype'] == True
    return 'wild-type' if is_wt else 'knockout'
    

In [None]:
df_output['label'] = df_output.apply(row2label, axis='columns')
## alphabetical: all 'knockout' rows come before the 'wild-type' rows
df_output = df_output.sort_values(by='label')

In [None]:
## sort by misranking: strains whose per-cycle and per-generation ranks disagree most come first
df_output['deltarank_abs'] = df_output['deltarank'].abs()
df_sorted = df_output.sort_values(by='deltarank_abs', ascending=False)
select = df_sorted.index[[0]]  # index of the single most misranked strain

In [None]:
df_sorted.loc[select, :]

### plot misranking

In [None]:
### plot residuals: rank difference (per-generation minus per-cycle) vs. per-cycle fitness

palette = {'wild-type':'orange', 'knockout': 'dimgrey'}


fig, ax = plt.subplots(figsize = (FIGHEIGHT_TRIPLET, FIGHEIGHT_TRIPLET))

x_var = 'logit_percycle'
y_var = 'deltarank'
data = df_output
sns.scatterplot(data = data, x = x_var, y = y_var, rasterized = True, ax = ax,
                hue = 'label', palette = palette)

### highlight select points
## is_labeled starts as True, so the 'max. disagreement' legend entry is suppressed and
## `label` is always None; start it at False to show the entry once.
is_labeled = True
for i in select:
    A, B = float(data.loc[i, x_var]), float(data.loc[i, y_var])
    if is_labeled == False:
        label = 'max. disagreement'
        is_labeled = True
    else:
        label = None
    ax.scatter(A,B,s=200,color ='tab:blue', zorder = -1,label = label, alpha = 0.25)


### plot horizontal line for orientation
ax.axhline(0,ls = '--', color = 'black')

### annotate
ax.set_xlabel('relative fitness per-cycle:' + r'  $s^{\mathrm{logit}}_{\mathrm{cycle}}$')
ax.set_ylabel('rank difference between\nfitness ' + r'$s^{\mathrm{logit}}_{\mathrm{gen}}$ '
              + 'and fitness ' + r'$s^{\mathrm{logit}}_{\mathrm{cycle}}$')
ax.legend(loc = 'upper left', bbox_to_anchor = (-0.05,1.0), frameon=False) #inside

title = f"n = {sum(~data['is_wildtype'])} knockouts"
ax.set_title(title, loc = 'left')

## matplotlib's savefig resolution keyword is lowercase `dpi`; an uppercase `DPI` keyword is
## not recognised (ignored or rejected depending on the version), so rasterized points
## would not get the intended resolution
fig.savefig(FIG_DIR + f"residuals_{x_var}_vs_{y_var}_x0={INITIAL_FREQ:.2f}.pdf",
            dpi=DPI, bbox_inches='tight', pad_inches=PAD_INCHES)


In [None]:
data.loc[:, y_var].max()  # largest value of the last plotted y-variable

### plot on fold-change phase diagram

In [None]:
def eval_isocline_percycle(logfc_wt, level):
    """Mutant log fold-change on the per-cycle isocline s_cycle = level.

    Since s_cycle = LFC_mut - LFC_wt, the isocline is LFC_mut = LFC_wt + level.
    """
    return level + logfc_wt

def eval_isocline_pergen(logfc_wt, level):
    """Mutant log fold-change on the per-generation isocline s_gen = level.

    Since s_gen = (LFC_mut - LFC_wt) / LFC_wt, the isocline is LFC_mut = (1 + level) * LFC_wt.
    """
    scale = level + 1
    return np.multiply(logfc_wt, scale)

In [None]:

palette = dict([('wild-type', 'orange'), ('knockout', 'dimgrey')])  # hue colours per label

In [None]:
### plot cloud of points: mutant vs. wild-type log fold-change over one growth cycle

## set to True to overlay per-cycle (slope-1) and per-generation (through the origin) isoclines
PLOT_ISOCLINES = False

fig, ax = plt.subplots(figsize = (FIGHEIGHT_TRIPLET, FIGHEIGHT_TRIPLET))

x_var = 'logfc_wt'
y_var = 'logfc_mut'
data = df_output
sns.scatterplot(data = data, x = x_var, y = y_var, rasterized = True, ax = ax,
                hue = 'label', palette = palette)

## find value limits; all log fold-changes must be positive (the cone below divides by A)
fcmax = np.max([data[x_var].max(),data[y_var].max()])
fcmin = np.min([data[x_var].min(),data[y_var].min()])
assert fcmin > 0

### x-grid for the isoclines and the cone, mirrored to negative values
xmin, xmax = ax.get_xlim()
fcwt_vec = np.linspace(xmin,xmax, num = 100)
fcwt_vec = np.concatenate((-fcwt_vec,fcwt_vec))
color_percycle = 'tab:grey'
color_pergen = 'navy'

if PLOT_ISOCLINES:
    ### per-cycle isoclines: LFC_mut = LFC_wt + level
    levels = np.outer([-1,1],np.linspace(0.01,8,num = 6)).flatten()
    levels.sort()
    for level in levels:
        ax.plot(fcwt_vec, eval_isocline_percycle(fcwt_vec, level = level), color = color_percycle)

    ## per-generation isoclines: LFC_mut = (1 + level) * LFC_wt, evenly spaced in angle
    angles = np.linspace(0,np.pi/2 - 0.001, num = 6)
    levels = np.outer([-1,1],np.tan(angles)).flatten()
    for level in levels:
        ax.plot(fcwt_vec, eval_isocline_pergen(fcwt_vec, level = level), color = color_pergen)

## plot diagonal (neutral line LFC_mut = LFC_wt) without changing the current axis limits
xmin, xmax = ax.get_xlim()
ymin, ymax = ax.get_ylim()
ax.plot([-xmin,xmax],[-xmin,xmax], color = 'black', ls = '--')
ax.set_xlim(xmin,xmax)
ax.set_ylim(ymin,ymax)

if PLOT_ISOCLINES:
    ## legend items for the isoclines
    ax.plot([],[], color = color_percycle, label = 'per-cycle $s$ isocline')
    ax.plot([],[], color = color_pergen, label = 'per-generation $Q$ isocline')

## for each selected point, shade the cone between the per-cycle isocline (slope 1) and the
## per-generation isocline (through the origin) that both pass through it
for i in select:
    A, B = float(data.loc[i, x_var]), float(data.loc[i, y_var])
    ## explicit label=None: no legend entry, and no dependence on a `label` from another cell
    ax.scatter(A,B,s=200,color ='tab:blue', zorder = -1,label = None, alpha = 0.25)

    x_fill = np.linspace(fcwt_vec[0],fcwt_vec[-1])
    y_fill = B/A*x_fill

    ax.fill_between(x_fill, (x_fill - A) + B, y_fill, color='tab:red', alpha=0.5)

### annotate
ax.legend(loc = 'upper left', bbox_to_anchor = (-0.05,0.25), frameon=False) #inside

ax.set_xlabel(r"wild-type log fold-change: $\mathrm{LFC}_{\mathrm{wt}}$")
ax.set_ylabel(r"mutant log fold-change: $\mathrm{LFC}_{\mathrm{mut}}$")

title = f"n = {sum(~data['is_wildtype'])} knockouts"
ax.set_title(title, loc = 'left')

## matplotlib's savefig resolution keyword is lowercase `dpi` (uppercase `DPI` is not recognised)
fig.savefig(FIG_DIR+ f'scatterplot_logfc_wt_vs_logfc_mut_x0={INITIAL_FREQ:.2f}.pdf',
            dpi=DPI, bbox_inches='tight', pad_inches=PAD_INCHES)
              

### plot DFEs

In [None]:
import warnings

In [None]:
## hue colours per label; 'wild-type median' is not among the labels produced by row2label
palette = dict([('wild-type', 'orange'), ('knockout', 'dimgrey'), ('wild-type median', 'navy')])

In [None]:
### distribution (KDE) of per-cycle fitness, wild-type entries vs. knockouts
fig, ax = plt.subplots(figsize = (FIGHEIGHT_TRIPLET, FIGHEIGHT_TRIPLET))

data = df_output
x_var = 'logit_percycle'

## need a context wrapper, else pandas throws a FutureWarning
## see https://stackoverflow.com/questions/15777951/how-to-suppress-pandas-future-warning
with warnings.catch_warnings():
    warnings.simplefilter(action='ignore', category=FutureWarning)
    sns.kdeplot(data=data, x=x_var, hue="label",common_norm = True,
            palette = palette, multiple="layer", ax = ax, fill = True, legend = True)

### plot selection coefficient zero for orientation
ax.axvline(0, ls = '--', color = 'black')
ax.legend_.set_title('')

title = f"n = {sum(~data['is_wildtype'])} knockouts"
ax.set_title(title, loc = 'left')

ax.set_xlabel('relative fitness per-cycle:' + r'  $s^{\mathrm{logit}}_{\mathrm{cycle}}$')
ax.set_ylabel('sample density')

## matplotlib's savefig resolution keyword is lowercase `dpi` (uppercase `DPI` is not recognised)
fig.savefig(FIG_DIR+ f'dfeplot_logit_percycle_x0={INITIAL_FREQ:.2f}.pdf',
            dpi=DPI, bbox_inches='tight', pad_inches=PAD_INCHES)
    

In [None]:
### distribution (KDE) of per-generation fitness, wild-type entries vs. knockouts
fig, ax = plt.subplots(figsize = (FIGHEIGHT_TRIPLET, FIGHEIGHT_TRIPLET))

data = df_output
x_var = 'logit_pergen'

## need a context wrapper, else pandas throws a FutureWarning
## see https://stackoverflow.com/questions/15777951/how-to-suppress-pandas-future-warning
with warnings.catch_warnings():
    warnings.simplefilter(action='ignore', category=FutureWarning)
    sns.kdeplot(data=data, x=x_var, hue="label",common_norm = True,
            palette = palette, multiple="layer", ax = ax, fill = True, legend = True)

### plot selection coefficient zero for orientation
ax.axvline(0, ls = '--', color = 'black')
ax.legend_.set_title('')

title = f"n = {sum(~data['is_wildtype'])} knockouts"
ax.set_title(title, loc = 'left')

ax.set_xlabel('relative fitness per-generation:' + r'  $s^{\mathrm{logit}}_{\mathrm{gen}}$')
ax.set_ylabel('sample density')

## matplotlib's savefig resolution keyword is lowercase `dpi` (uppercase `DPI` is not recognised)
fig.savefig(FIG_DIR+ f'dfeplot_logit_pergen_x0={INITIAL_FREQ:.2f}.pdf',
            dpi=DPI, bbox_inches='tight', pad_inches=PAD_INCHES)
    

### plot global correlation in ranks

In [None]:
data.keys()  # column labels of the current results table

In [None]:
### plot correlation between per-cycle and per-generation fitness ranks

fig, ax = plt.subplots(figsize = (FIGHEIGHT_TRIPLET, FIGHEIGHT_TRIPLET))

x_var = 'logit_percycle_rank'
y_var = 'logit_pergen_rank'
data = df_output
sns.scatterplot(data = data, x = x_var, y = y_var, rasterized = True, ax = ax,
                hue = 'label', palette = palette)

### annotate
ax.legend(loc = 'upper left', bbox_to_anchor = (-0.05,1.0), frameon=False) #inside

ax.set_ylabel('relative fitness per-generation [rank]')
ax.set_xlabel('relative fitness per-cycle [rank]')

title = f"n = {sum(~data['is_wildtype'])} knockouts"
ax.set_title(title, loc = 'left')

## matplotlib's savefig resolution keyword is lowercase `dpi` (uppercase `DPI` is not recognised)
fig.savefig(FIG_DIR + f"scatterplot_{x_var}_vs_{y_var}_x0={INITIAL_FREQ:.2f}.pdf",
            dpi=DPI, bbox_inches='tight', pad_inches=PAD_INCHES)
