
### Plot results from Warringer fit

In [None]:
import pandas as pd

In [None]:


### execute script to load modules here
# the script runs in this notebook's namespace; the `with` block closes the
# file handle, which a bare open() call would leave open
with open('setup_aesthetics.py') as setup_file:
    exec(setup_file.read())

In [None]:
## cell flagged with tag parameters
### name of the dataset to process; the figure directory, data directories,
### setup script and plot colour defined further down are all derived from it

DATASET = 'warringer2003'  

In [None]:
### Update dependent parameters according to input

import os.path
from os import path

# figures are grouped per dataset, new datafiles share a single folder
FIG_DIR = f'./figures/{DATASET}/'
OUTPUT_DIR = './output/'

# create each folder if it is missing and report where results will go
for directory, announcement in (
    (FIG_DIR, "All  plots will be stored in: \n"),
    (OUTPUT_DIR, "All  newly created datafiles will be stored in: \n"),
):
    os.makedirs(directory, exist_ok=True)
    print(announcement + directory)

In [None]:
# folder with the quantified results of the piecewise fit
DATA_DIR = f'./data/{DATASET}/quantified/'


## paths to the tables with results from the piecewise fit
PCWS_DATA_PLATEAUS = DATA_DIR + 'df_plateaus.csv'
PCWS_DATA_TRANSITIONS = DATA_DIR + 'df_transitions.csv'
PCWS_DATA_SHOULDERS = DATA_DIR + 'df_shoulders.csv'
PCWS_DATA_STATS = DATA_DIR + 'stats_by_curve.csv'

# create a new dataframe for the transition phases
PCWS_DATA_TRANSITION_PHASES = DATA_DIR + 'df_transitions_by_phase.csv'

# NOTE: DATA_DIR is rebound here; from this line on it points to the
# piecewise_fit folder, not to the quantified folder used above
DATA_DIR = f'./data/{DATASET}/piecewise_fit/'
# read some timecourses from the piecewise fit
PCWS_DATA_DLOGOD_TIMEPOINTS = DATA_DIR + 'dlogod_timepoints.csv'
PCWS_DATA_DLOGOD_VALUES = DATA_DIR + 'dlogod_values.csv'
PCWS_DATA_LOGOD_TIMEPOINTS = DATA_DIR + 'logod_timepoints.csv'
PCWS_DATA_LOGOD_VALUES = DATA_DIR + 'logod_values.csv'



# dataset-specific setup script, executed in a later cell
SETUP_SCRIPT = f'setup_plateau_finder_{DATASET}.py'

# fail early, with a readable message, if the setup script is missing
assert path.isfile(SETUP_SCRIPT), f"Setup script: {SETUP_SCRIPT} does not exist."

In [None]:
# one fixed plot colour per dataset
dataset2colors = {
    'campos2018': 'navy',
    'chevereau2015': 'firebrick',
    'warringer2003': 'darkorange',
}
DATASET_COLOR = dataset2colors[DATASET]

In [None]:
# run the dataset-specific setup script in this notebook's namespace; the
# `with` block closes the file handle, which a bare open() call would leave open
with open(SETUP_SCRIPT) as setup_file:
    exec(setup_file.read())

In [None]:
## using the same interpolate function as for the plateau finding
def interpolate(x, xp, fp):
    """Linearly interpolate the points (xp, fp) at x.

    Thin wrapper around np.interp that returns NaN, instead of the edge
    values, for any x outside the range spanned by xp.
    """
    outside = np.nan
    return np.interp(x, xp, fp, left=outside, right=outside)

In [None]:
INDEX_COL = [0,1,2,3,4]
list_na_representations = ['not_present', 'failed_to_compute']


def _read_pcws_table(filename):
    """Read one CSV table using the index columns and NaN markers shared by all tables."""
    return pd.read_csv(filename, header=0, index_col=INDEX_COL,
                       float_precision=None, na_values=list_na_representations)


# quantified results per curve
df_pcws_plateaus = _read_pcws_table(PCWS_DATA_PLATEAUS)
df_pcws_transitions = _read_pcws_table(PCWS_DATA_TRANSITIONS)

df_pcws_stats = _read_pcws_table(PCWS_DATA_STATS)

# timecourses of the derivative of log OD
df_pcws_dlogod_timepoints = _read_pcws_table(PCWS_DATA_DLOGOD_TIMEPOINTS)
df_pcws_dlogod_values = _read_pcws_table(PCWS_DATA_DLOGOD_VALUES)

# timecourses of log OD
df_pcws_logod_timepoints = _read_pcws_table(PCWS_DATA_LOGOD_TIMEPOINTS)
df_pcws_logod_values = _read_pcws_table(PCWS_DATA_LOGOD_VALUES)



def get_piecewise_logod_timeseries(name):
    """Return (timepoints, log-OD values) of the piecewise fit for curve `name`.

    Both rows are looked up by `name` and each one is stripped of its own
    NaN entries independently.
    """
    timepoints = df_pcws_logod_timepoints.loc[name]
    values = df_pcws_logod_values.loc[name]
    return timepoints.dropna(), values.dropna()

def get_piecewise_deriv_timsseries(name):
    """Return (timepoints, derivative-of-log-OD values) of the piecewise fit for curve `name`.

    Both rows are looked up by `name` and each one is stripped of its own
    NaN entries independently.
    """
    timepoints = df_pcws_dlogod_timepoints.loc[name]
    derivative = df_pcws_dlogod_values.loc[name]
    return timepoints.dropna(), derivative.dropna()



    

### Estimate number of replicates for each genotype

In [None]:
# one entry per curve, so a genotype appears once for each of its replicates
list_genes = df_pcws_stats.reset_index()['genotype']

# genotype -> number of replicate curves
gene2n = dict()

# visit every genotype only once; looping over list_genes directly would repeat
# the same lookup for each replicate and produce the same dictionary
for v in list_genes.unique():
    replicates = df_pcws_stats.loc[v]
    gene2n[v] = replicates.shape[0]

In [None]:
set(gene2n.values())

In [None]:
# invert the mapping: replicate count -> list of genotypes with that count
n2gene = {}

for genotype, n_replicates in gene2n.items():
    n2gene.setdefault(n_replicates, []).append(genotype)

In [None]:
# summary: how many genotypes share each replicate count
for n_replicates, genotypes in n2gene.items():
    print(f"number of genotypes with {n_replicates} replicates: {len(genotypes)}")

## Choose subset

In [None]:
### assign wild-type label
def is_wildtype(name):
    """Return True when the curve index `name` belongs to the wild-type strain.

    The genotype is the first entry of `name`; the wild type is the
    genotype labelled 'BY4741'.
    """
    return name[0] == 'BY4741'
    

# label every curve. NOTE: later cells rebind the name `is_wildtype` to a
# boolean mask over df_subset, so this cell cannot be re-run after them
# without first re-defining the function
df_pcws_stats['is_wildtype'] = [is_wildtype(v) for v in df_pcws_stats.index]

In [None]:
### pick subset of curves

# only include curves with 2 plateaus; work on a copy so that the full
# stats table stays untouched

is_two_plateaus = df_pcws_stats['no_plateaus'].eq(2)
df_subset = df_pcws_stats[is_two_plateaus].copy()
print(f"Only include curves with 2 plateaus. Number of curves left: {df_subset.shape[0]}")

In [None]:
# define minimum growth rate to be considered real growth plateau (from histogram)
MIN_RATE_FOR_GROWTH = 0.0011
fig, ax = plt.subplots(figsize = (FIGWIDTH_TRIPLET, FIGHEIGHT_TRIPLET))
ax = df_pcws_plateaus.loc[df_subset.index]['mean_value'].hist(bins = 41, color = DATASET_COLOR, log = True, alpha = 0.7,ax =ax)
ax.axvline(MIN_RATE_FOR_GROWTH, color= 'red', label = 'min rate for growth')
ax.legend(loc = 'upper right')
ax.set_xlabel('mean growth rate in plateau')
ax.set_ylabel('#plateaus')

ax.set_title(f"n = {df_subset.shape[0]} curves", loc = 'right') 
# savefig expects the lowercase keyword `dpi`; an uppercase `DPI` keyword is
# not a savefig option and never sets the resolution
fig.savefig(FIG_DIR + "histogram_for_mean_growth_rate_across_plateaus.pdf", dpi = DPI, bbox_inches = 'tight', pad_inches = PAD_INCHES)

In [None]:
# first plateau should be growth, second plateau should be no growth
def name2flag_as_M3(name):
    """Flag whether curve `name` grows in its first plateau and is stationary in its second.

    A plateau counts as growth when its 'mean_value' is at least
    MIN_RATE_FOR_GROWTH and as stationary when it lies below that threshold.
    """
    plateaus = df_pcws_plateaus.loc[name]
    rate_first = plateaus.iloc[0]['mean_value']
    rate_second = plateaus.iloc[1]['mean_value']

    grows_first = np.array(rate_first >= MIN_RATE_FOR_GROWTH)
    stalls_second = np.array(rate_second < MIN_RATE_FOR_GROWTH)
    return grows_first & stalls_second

## test
name = df_subset.index[0]
name2flag_as_M3(name)


In [None]:
is_M3_shape = [ name2flag_as_M3(v) for v in df_subset.index]

In [None]:
df_subset = df_subset.loc[is_M3_shape]

In [None]:

# first plateau should be a max type

In [None]:
# NOTE(review): presumably 20 g/L glucose divided by a molar mass of ~180 g/mol,
# scaled by 1e3 from molar to millimolar — confirm against the growth medium
CONCENTRATION_GLUCOSE = 20/180 * 1e3 # concentration in milliMolar

In [None]:
print(CONCENTRATION_GLUCOSE)

In [None]:
# define the minimum initial OD that is marked in the histogram below
MIN_INITIAL_OD = 0.001
fig, ax = plt.subplots(figsize = (FIGWIDTH_TRIPLET, FIGHEIGHT_TRIPLET))
bins = np.arange(-0.01,0.5,step=0.01)
# NOTE(review): column '0' of df_raw is presumably the first OD reading of each curve — confirm
ax = df_raw.loc[df_subset.index]['0'].hist(bins = bins, color = DATASET_COLOR, log = True, alpha = 0.7,ax =ax)
# the x-axis is an OD value, so mark the OD threshold defined in this cell
# rather than the growth-rate threshold
ax.axvline(MIN_INITIAL_OD, color= 'red', label = 'min initial OD')
ax.legend(loc = 'upper right')
ax.set_xlabel('OD value after background subtraction')
# one value per curve enters this histogram
ax.set_ylabel('#curves')
ax.set_xlim(-0.01,0.01)
ax.set_title(f"n = {df_subset.shape[0]} curves", loc = 'right') 
# savefig expects the lowercase keyword `dpi`
fig.savefig(FIG_DIR + "histogram_for_initial_OD.pdf", dpi = DPI, bbox_inches = 'tight', pad_inches = PAD_INCHES)

In [None]:
# identify point of maximum rate in the instantaneous growth rate




def name2yield(name):
    """Estimate initial OD, saturation OD and biomass yield of curve `name`.

    Returns the tuple (od_start, od_end, Y) where
      od_start -- exp of the mean log of the first three readings returned by
                  get_excess_od_timeseries_before_trim
      od_end   -- OD interpolated at the midpoint of the second plateau
      Y        -- (od_end - od_start) / CONCENTRATION_GLUCOSE, OD per milliMolar
    """
    ## timeseries with all timepoints
    t, od_excess = get_excess_od_timeseries_before_trim(name)

    # the midpoint of the second plateau serves as the timepoint of saturation
    saturation = df_pcws_plateaus.loc[name].iloc[1]
    t_saturation = (saturation['t_end'] + saturation['t_start'])/2
    od_end = np.interp(t_saturation, t, od_excess)

    # initial OD as the average, in log space, of the first three readings;
    # the readings are not filtered for positivity before the log is taken
    od_start = np.exp(np.log(od_excess[:3]).mean())

    return od_start, od_end, (od_end - od_start)/CONCENTRATION_GLUCOSE

### test
name = df_subset.index[0]
od_start,od_end, Y = name2yield(name)

def name2max_growth_moment(name):
    """Return (t_gmax, gmax) for curve `name`.

    gmax is the 'mean_value' of the first plateau; t_gmax is the midpoint
    between that plateau's 't_start' and 't_end'.
    """
    first_plateau = df_pcws_plateaus.loc[name].iloc[0]
    midpoint = (first_plateau['t_end'] + first_plateau['t_start'])/2
    return midpoint, first_plateau['mean_value']
## test
t_gmax, gmax = name2max_growth_moment(name)


def name2lag_time(name):
    """Estimate the lag time of curve `name`.

    A straight line with slope gmax through the point (t_gmax, log OD at
    t_gmax) is traced back to the initial log OD; the time at which the
    line reaches that level is returned.
    """
    t, od = get_excess_od_timeseries_before_trim(name)

    ## log OD at the initial point
    od_start = name2yield(name)[0]
    log_od_start = np.log(od_start)

    # log OD at the moment of maximum growth; interpolating the log keeps
    # the interpolation linear
    t_gmax, gmax = name2max_growth_moment(name)
    log_od_gmax = np.interp(t_gmax, t, np.log(od))

    # time at which the line of slope gmax reaches the initial log OD
    return t_gmax + (log_od_start - log_od_gmax)/gmax

## test
lag_time = name2lag_time(name)
lag_time





# plot the M3 like fit to the growth curves

In [None]:
%%time

for name in df_subset.index:
    # compute all traits of this curve first ...
    gmax = name2max_growth_moment(name)[1]
    lag_time = name2lag_time(name)
    od_start, _, Y = name2yield(name)

    # ... then store them; the order of assignment fixes the column order
    df_subset.at[name, 'gmax'] = gmax
    df_subset.at[name, 'lag']  = lag_time
    df_subset.at[name, 'yield'] = Y
    df_subset.at[name, 'od_start'] = od_start
    
    

### Exclude curves where initial OD is negative

In [None]:
## flag curves where any of the three traits is NaN

is_nan = df_subset[['lag', 'gmax', 'yield']].isna().any(axis=1)

print(is_nan.sum())

### keep only the curves with all three traits available

df_subset = df_subset[~is_nan]

### Exclude curves with negative lag time

In [None]:
# boolean mask (numpy array) of the curves with negative lag time, and their count
is_negative_lag = (df_subset['lag'] < 0).values
sum(is_negative_lag)

In [None]:
list_negative_lag = df_subset.loc[is_negative_lag].sort_values('lag', ascending = True).index

In [None]:
### exclude

df_subset = df_subset.loc[~is_negative_lag]

In [None]:
## update wild-type index
is_wildtype = df_subset['is_wildtype']==True

In [None]:
## calculate coverage: fraction of curves that remain in the subset
# NOTE(review): 10200 is presumably the total number of curves in the dataset — confirm

df_subset.shape[0]/10200

In [None]:
fig, axes = plt.subplots(1,2, figsize = (2*FIGHEIGHT_TRIPLET, FIGHEIGHT_TRIPLET))

n_datapoints = df_subset.shape[0]
# boolean mask over df_subset (this rebinds the name `is_wildtype`)
is_wildtype = df_subset['is_wildtype']==True
title = f"n={n_datapoints} growth curves"

# both panels plot a trait against gmax and differ only in the trait shown.
# The yield is OD divided by CONCENTRATION_GLUCOSE, which is given in
# milliMolar, so its unit is OD per mM.
panels = [('lag', 'lag time [min]'),
          ('yield', 'biomass yield [OD/mM glucose]')]

for ax, (trait, ylabel) in zip(axes, panels):
    ### plot mutant dataset
    x = df_subset.loc[~is_wildtype]['gmax'].values
    y = df_subset.loc[~is_wildtype][trait].values
    ax.scatter(x,y, color = 'silver', label = 'knockout', alpha = 0.6, rasterized = True)
    ### plot wild-type on top
    x = df_subset.loc[is_wildtype]['gmax'].values
    y = df_subset.loc[is_wildtype][trait].values
    ax.scatter(x,y, color = 'tab:green', alpha = 1,  label = 'wild-type', rasterized = True)

    ax.set_ylabel(ylabel)
    ax.set_xlabel('realized growth rate [per min]')
    ax.legend(loc =  'upper left')
    ax.set_title(title, loc = 'right')

fig.tight_layout()

# savefig expects the lowercase keyword `dpi`
fig.savefig(FIG_DIR + "correlations_using_M3_traits.pdf", dpi = DPI, bbox_inches = 'tight', pad_inches = PAD_INCHES)

In [None]:
from selection_coefficient import Problem_M3, sol_exact_M3

In [None]:


def name2Problem_M3(name):
    """Build the Problem_M3 instance parametrised by the stored traits of curve `name`."""
    curve = df_subset.loc[name]

    # each trait is passed as a pair whose first entry is the value of this curve
    # NOTE(review): the second entry presumably describes a second strain — confirm against Problem_M3
    return Problem_M3(
        lam=[curve['lag'], 0.],
        g=[curve['gmax'], 1.],
        Y=[curve['yield'], 1.],
        R_0=CONCENTRATION_GLUCOSE,
        N_0=curve['od_start'],
        x=0.0,
    )

##
name = df_subset.index[0]
problem = name2Problem_M3(name)
problem.params()

In [None]:
### plot growth curve


def name2plot(name, ax = None):
    """Plot the measured growth curve `name` together with its M3 solution.

    The OD readings before and after trimming are drawn as scatter points
    and the M3 solution, evaluated at the timepoints of the readings before
    trimming, as a black line. The y-axis is log-scaled.

    Parameters
    ----------
    name : index label of the curve in df_subset
    ax : axes to draw on; a new figure is created when it is None

    Returns
    -------
    The axes that were drawn on.
    """
    # identity test for the missing-argument case (`== None` would go
    # through the equality operator of whatever object is passed in)
    if ax is None:
        fig, ax = plt.subplots(figsize = (FIGWIDTH_TRIPLET, FIGHEIGHT_TRIPLET))

    t_full, od_excess = get_excess_od_timeseries_before_trim(name)
    ax.scatter(t_full, od_excess, marker = 'o', color = 'navy', label = 'before trim')

    t_trimmed, od_trimmed = get_excess_od_timeseries(name)
    ax.scatter(t_trimmed, od_trimmed, marker = 'o', color = 'tab:orange', label = 'after trim')

    # optional overlay of the piecewise fit, currently disabled:
    #t, logod_pcws = get_piecewise_logod_timeseries(name)
    #ax.plot(t, np.exp(logod_pcws), color = 'tab:blue', label = 'pcws fit', ls= '--')

    problem = name2Problem_M3(name)
    fit = [sol_exact_M3(t=v, problem=problem) for v in t_full]
    ax.plot(t_full,fit, color = 'black', label = 'M3 fit')

    ax.set_yscale('log')
    ax.set_xlabel('time t')
    ax.set_ylabel('population size [OD]')
    ax.legend(loc = 'lower right')

    # fixed axis limits, so that panels of different curves are comparable
    ax.set_ylim(ymin = 0.0005, ymax = 10)
    title = "curve id: " + str(name)
    ax.set_title(title, loc = 'right')
    ax.set_xlim(0,3000)
    return ax

## test

name = df_subset.index[is_wildtype][0]
ax = name2plot(name)



In [None]:
fig, ax = plt.subplots(figsize = (FIGWIDTH_TRIPLET, FIGHEIGHT_TRIPLET))

# derivative of log OD from the piecewise fit, for the curve currently bound to `name`
t, growth_rate = get_piecewise_deriv_timsseries(name)

ax.plot(t,growth_rate)
ax.set_ylabel('growth rate')
# start the time axis at zero
ax.set_xlim(xmin=0)

In [None]:
### calculate rsquared to growth curve



def name2rsquared(name):
    """Return R^2 between the readings of curve `name` and its M3 solution, computed on log OD."""
    t, od_excess = get_excess_od_timeseries_before_trim(name)
    problem = name2Problem_M3(name)
    fit = [sol_exact_M3(t=timepoint, problem=problem) for timepoint in t]

    ## compare measurement and model on a log scale
    log_observed = np.log(od_excess)
    log_model = np.log(fit)

    # residual sum of squares against the model, total sum of squares against the mean
    ss_residual = np.power(log_observed - log_model, 2).sum()
    ss_total = np.power(log_observed - log_observed.mean(), 2).sum()

    return 1 - (ss_residual/ss_total)

## test
name = df_subset.index[0]
name2rsquared(name)

In [None]:
for name in df_subset.index:
    df_subset.at[name, 'rsquared'] = name2rsquared(name) 

In [None]:
# minimum R^2 a fit needs to be kept, marked in the histogram below
MIN_RSQUARED_TO_INCLUDE = 0.95
fig, ax = plt.subplots(figsize = (FIGWIDTH_TRIPLET, FIGHEIGHT_TRIPLET))
ax = df_subset['rsquared'].hist(bins = 41, color = DATASET_COLOR, log = True, alpha = 0.7, ax =ax)
ax.axvline(MIN_RSQUARED_TO_INCLUDE, color = 'tab:red', label = 'minimum quality of fit required')
ax.set_xlabel('quality of fit: $R^2$')
ax.set_ylabel('#curves')
ax.legend()
ax.set_title(f"n = {df_subset.shape[0]} curves", loc = 'right') 

# savefig expects the lowercase keyword `dpi`
fig.savefig(FIG_DIR + "histogram_for_rsquared_after_fit.pdf", dpi = DPI, bbox_inches = 'tight', pad_inches = PAD_INCHES)

In [None]:
# boolean mask (numpy array) of the fits below the required R^2, and their count
is_low_quality = (df_subset['rsquared'] < MIN_RSQUARED_TO_INCLUDE).values
sum(is_low_quality)

In [None]:
list_low_quality = df_subset.loc[is_low_quality].sort_values('rsquared', ascending = True).index

In [None]:
fig, axes = plt.subplots(4,3, figsize = (3*FIGWIDTH_TRIPLET, 4*FIGHEIGHT_TRIPLET))

axes = axes.flatten()

# one panel per low-quality curve, in the order of list_low_quality. zip stops
# at the shorter sequence, so having fewer curves than panels leaves the
# remaining panels empty instead of raising an IndexError.
for ax, name in zip(axes, list_low_quality):
    name2plot(name, ax = ax)
    rsquared = df_subset.at[name,'rsquared']
    title = f"$R^2 = {rsquared:.2f}$, {name}"
    ax.set_title(title, loc = 'right')

fig.tight_layout()


# savefig expects the lowercase keyword `dpi`
fig.savefig(FIG_DIR + "growthcurves_for_outliers_with_low_rsquared.pdf", dpi = DPI, bbox_inches = 'tight', pad_inches = PAD_INCHES)

In [None]:
### exclude low quality fits

df_subset = df_subset.loc[~is_low_quality]

In [None]:
## drop columns
list_cols_to_drop = [
    # plateau settings
    'plat_threshold', 'plat_duration', 'plat_distance', 'plat_atol', 'plat_rtol',
    # transition settings
    'tran_threshold', 'tran_duration', 'tran_distance', 'tran_atol', 'tran_rtol',
    # remaining per-curve columns
    'no_plateaus', 'no_mono_violations', 'curve_rsquared_fd', 'final_gap_logod',
]

df_subset = df_subset.drop(columns=list_cols_to_drop)

In [None]:
### store dataset
# floats are written with 7 significant digits ('%.6e') and NaN entries as 'removed'.
# NOTE: 'removed' is not part of list_na_representations, so a reader that
# relies on that list alone will not map it back to NaN
PCWS_OUTPUT_TRAITS = OUTPUT_DIR + "df_M3_traits.csv"
df_subset.to_csv(PCWS_OUTPUT_TRAITS, index = True, float_format= '%.6e', na_rep= 'removed')

In [None]:
## reread and test

df = df_subset
filename = PCWS_OUTPUT_TRAITS

print('#####################################')
print('\nTesting the data stored in ' + filename)
# the file is written with na_rep 'removed', so that marker has to be parsed
# back to NaN in addition to the markers of the input tables
df_reread = pd.read_csv(filename, header = 0, index_col= INDEX_COL,\
                                  float_precision=None, na_values=list_na_representations + ['removed'])
print("Testing stored float values.")
float_columns = df.dtypes == 'float64'

x = df_reread.loc[:,float_columns].values
y = df.loc[:,float_columns].values

# the file is written with float_format '%.6e' (7 significant digits), so the
# reread values can only agree up to that rounding; an exact comparison would
# report a mismatch for every freshly computed value
try:
    np.testing.assert_allclose(x, y, rtol = 1e-6, atol = 0)
except AssertionError as e:
    print(e)
else:
    print("Success. All float values stored correctly up to the written precision.")

print("\nTesting stored values of other type, mostly strings.")
other_columns = ~float_columns
x = df_reread.loc[:,other_columns]
y = df.loc[:,other_columns]


if x.equals(y):
    print("Success. All values of other type stored correctly.")
else:
    print("Fail. Check true datatypes for columns marked as other in dataframe.")
    

### manual outlier inspection

In [None]:
def _first_ten_by(trait, ascending):
    """Index of the first ten curves of df_subset after sorting by `trait`."""
    return df_subset.sort_values(trait, ascending = ascending).index[:10]


# curves at both ends of each trait, for manual inspection
list_gmax_high = _first_ten_by('gmax', ascending = False)
list_gmax_low = _first_ten_by('gmax', ascending = True)
list_lag_high = _first_ten_by('lag', ascending = False)
list_lag_low = _first_ten_by('lag', ascending = True)
list_yield_high = _first_ten_by('yield', ascending = False)
list_yield_low = _first_ten_by('yield', ascending = True)

In [None]:
fig, axes = plt.subplots(2,3, figsize = (3*FIGWIDTH_TRIPLET, 2*FIGHEIGHT_TRIPLET))

axes = axes.flatten()


# one panel per curve in list_yield_low, in order. zip stops at the shorter
# sequence, so fewer curves than panels leaves the remaining panels empty
# instead of raising an IndexError.
for ax, name in zip(axes, list_yield_low):

    name2plot(name, ax = ax)
    
fig.tight_layout()