In [None]:

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy import stats

import seaborn as sns

In [None]:


### execute script to load modules here
## NOTE(review): setup_aesthetics.py presumably defines the plotting constants used below
## (FIGHEIGHT_TRIPLET, DPI, PAD_INCHES) -- they are not defined anywhere in this notebook.
## A context manager makes sure the file handle is closed after reading.
with open('setup_aesthetics.py') as setup_file:
    exec(setup_file.read())

In [None]:
## output folder for every plot/list exported by this notebook;
## created here if it does not exist yet (no error if it already does)
import os

FIG_DIR = './figures/encodings_examples/'
os.makedirs(name=FIG_DIR, exist_ok=True)
print(f"All  plots will be stored in: \n{FIG_DIR}")

### Define encodings

In [None]:
### define helper function
def logit(x):
    """Log-odds encoding log(x / (1 - x)), applied elementwise.

    Accepts scalars or numpy arrays. x = 0 maps to -inf and x = 1 to +inf
    (numpy emits a divide-by-zero RuntimeWarning in those cases).
    """
    odds = np.divide(x, 1 - x)
    return np.log(odds)
## sanity check: an even frequency has log-odds zero
logit(0.5)

In [None]:
def eval_statistic(xf,x0, phi = logit):
    """Change of the encoded frequency between an initial and a final time point.

    xf  : final frequency
    x0  : initial frequency
    phi : encoding function applied to both frequencies (default: logit)
    Returns phi(xf) - phi(x0).
    """
    encoded_final = phi(xf)
    encoded_initial = phi(x0)
    return encoded_final - encoded_initial

In [None]:
## example: change of the (default, logit) encoded frequency from 45% to 55%
eval_statistic(0.55, 0.45)

In [None]:
def eval_statistic_s(xf,x0):
    """Encoded change under the logit encoding: logit(xf) - logit(x0)."""
    return eval_statistic(xf=xf, x0=x0, phi=logit)

def eval_statistic_deltalog(xf,x0):
    """Encoded change under the log encoding: log(xf) - log(x0)."""
    return eval_statistic(xf=xf, x0=x0, phi=np.log)

In [None]:
## qualitative colour set; entries 0/1/2 colour the raw / log / logit encodings below
palette = sns.color_palette(palette="Set2")

### simulate trajectory to fixation

Formulating the ODE problem in Python is not obvious: SciPy's `solve_ivp` requires the derivative as a function `fun(t, y)` that returns `dy/dt`. With `vectorized=True`, `y` may hold several state vectors as columns.

See an example for ODE solving in Python here: https://pundit.pratt.duke.edu/wiki/Python:Ordinary_Differential_Equations/Examples

In [None]:
from scipy.integrate import solve_ivp

### Logistic Equation of growth

In [None]:

### define logistic equation  dy/dt = r*y*(1-y)
def fun(t, y, r = 5):
    """Right-hand side of the logistic ODE in the fun(t, y) form used by solve_ivp.

    np.heaviside(y, 0) is 0 for y <= 0 and 1 for y > 0, so the rate is zeroed
    for non-positive y.
    """
    growth = np.multiply(y, 1 - y)
    return r * growth * np.heaviside(y, 0)

## time window of the solution
tspan = [0, 5]

### initial condition: a rare focal type
y0 = [1e-6]

sol = solve_ivp(fun, tspan, y0=y0, vectorized=True, rtol=1e-8, atol=1e-10)
yraw = sol.y[0]

In [None]:
fig, axes = plt.subplots(3,1, figsize = (FIGHEIGHT_TRIPLET, 1*FIGHEIGHT_TRIPLET), sharex = True)

## one row per encoding of the same logistic trajectory: raw frequency, log, logit
ax = axes[0]
ax.plot(sol.t, yraw, color = palette[0], lw = 3)

ax = axes[1]
ax.plot(sol.t, np.log(yraw), color = palette[1], lw = 3)

ax = axes[2]
ax.plot(sol.t, logit(yraw), color = palette[2], lw = 3)

for ax in axes:
    ax.set_xlim(sol.t[0], sol.t[-1])
    ax.tick_params(left = True, labelleft = True)

## x-axis label on the bottom panel only; its tick labels are hidden
ax.set_xlabel('time', labelpad = 10)
ax.tick_params(labelbottom = False)

## savefig's resolution keyword is lowercase `dpi`; `DPI=` is not a recognised option
## and is rejected by recent matplotlib versions.
fig.savefig(FIG_DIR + 'example_trajectories_logistic.pdf', dpi = DPI, bbox_inches = 'tight', pad_inches = PAD_INCHES)
  

### Gompertz Equation of growth

In [None]:

### define gompertz equation  dy/dt = r*y*log(1/y)
def fun(t, y, r = 5):
    """Right-hand side of the Gompertz ODE in the fun(t, y) form used by solve_ivp.

    The rate is set to zero for y <= 0. This uses np.where instead of multiplying by
    np.heaviside(y, 0): for y <= 0 the term y*log(1/y) is nan (0*inf, or the log of a
    negative number), and nan*0 is still nan, so the Heaviside factor did not guard it.
    """
    y = np.asarray(y, dtype = float)
    positive = y > 0
    ## substitute a harmless value where y <= 0 so log() is never evaluated there
    y_safe = np.where(positive, y, 1.0)
    return np.where(positive, r*np.multiply(y_safe, np.log(1/y_safe)), 0.0)

## define timewindow of solution
tspan = [0,1.5]

### initial condition (same rare starting frequency as in the logistic example)
y0 = [1e-6]

sol = solve_ivp(fun, tspan, y0 = y0, atol = 1e-10, rtol = 1e-8, vectorized = True)
yraw = sol.y[0]

In [None]:
fig, axes = plt.subplots(3,1, figsize = (FIGHEIGHT_TRIPLET, 1*FIGHEIGHT_TRIPLET), sharex = True)

## one row per encoding of the same Gompertz trajectory: raw frequency, log, logit
ax = axes[0]
ax.plot(sol.t, yraw, color = palette[0], lw = 3)

ax = axes[1]
ax.plot(sol.t, np.log(yraw), color = palette[1], lw = 3)

ax = axes[2]
ax.plot(sol.t, logit(yraw), color = palette[2], lw = 3)

for ax in axes:
    ax.set_xlim(sol.t[0], sol.t[-1])
    ax.tick_params(left = True, labelleft = True)

## x-axis label on the bottom panel only; its tick labels are hidden
ax.set_xlabel('time', labelpad = 10)
ax.tick_params(labelbottom = False)

## savefig's resolution keyword is lowercase `dpi`; `DPI=` is not a recognised option
## and is rejected by recent matplotlib versions.
fig.savefig(FIG_DIR + 'example_trajectories_gompertz.pdf', dpi = DPI, bbox_inches = 'tight', pad_inches = PAD_INCHES)
  

### Plot heteroscedasticity for regression on raw variable

In [None]:
from scipy.stats import linregress

In [None]:

### logistic equation again (the Gompertz cell above redefined `fun`, `tspan` and `sol`)
def fun(t, y, r = 5):
    """Logistic growth rate r*y*(1-y), switched off for y <= 0 via np.heaviside."""
    logistic_term = np.multiply(y, 1 - y)
    switch = np.heaviside(y, 0)
    return r * logistic_term * switch

## time window of the solution and initial frequency of the focal type
tspan = [0, 5]
y0 = [1e-6]

sol = solve_ivp(fun, tspan, y0=y0, atol=1e-10, rtol=1e-8, vectorized=True)


In [None]:
## the "true" logistic trajectory, used as ground truth for the sampling below
tvec = sol.t
yraw = sol.y[0]

def time2freq(t):
    """Frequency on the true trajectory at time(s) t, by linear interpolation of (tvec, yraw)."""
    return np.interp(t, xp=tvec, fp=yraw)

In [None]:
## seeded generator so the sampling below is reproducible
rng = np.random.default_rng(304576)

In [None]:
### demonstration: adapted from the binomial example in the numpy documentation

n_trials, p_success = 10, .5  # number of trials, probability of each trial

# flip a fair coin n_trials = 10 times, repeat this experiment 10 times, and
# report the observed fraction of successes in each repeat
rng.binomial(n_trials, p_success, 10)/n_trials


In [None]:
## demonstration: how this sampling works for a larger number of extracted cells

n_trials = 10000  # number of cells that are extracted (0.1 * 1e5)
size = 5  # number of extraction attempts
draws = rng.binomial(n_trials, p_success, size=size)
draws/n_trials

In [None]:
## sample binomial frequency estimates along the true trajectory at a few time points
n_timepoints = 5
### choose the time window for sampling (keep exactly one line active)
t_start, t_end = 2.2,3.2 # middle window
#t_start, t_end = 1.,2. # early window
#t_start, t_end = 2.5, 3.5  #late window
timepoints = np.linspace(t_start, t_end, num=n_timepoints)

### parameters of the cell count and frequency estimation
n_cells_counted = 100  # number of cells counted per extraction
n_extractions = 50  # number of extraction attempts per time point

### simulate the extraction process: one row of estimated frequencies per time point
freq_samples = np.ones((n_timepoints, n_extractions))
for i, t_sample in enumerate(timepoints):
    freq_true = time2freq(t_sample)
    counts = rng.binomial(n_cells_counted, freq_true, size=n_extractions)
    freq_samples[i] = counts/n_cells_counted

## matching matrix of sampling times (row i holds timepoints[i])
timepoint_samples = np.outer(timepoints, np.ones(n_extractions))

## flatten both matrices row by row so that entries stay paired
freq_samples = freq_samples.flatten()
timepoint_samples = timepoint_samples.flatten()

In [None]:
### regression models for the different encodings

## no encoding: least-squares line of raw frequency against time
fit_raw = linregress(x=timepoint_samples, y=freq_samples)
slope_raw, intercept_raw = fit_raw.slope, fit_raw.intercept
def time2freq_fit(t):
    """Fitted straight line (raw frequency vs time) evaluated at t."""
    return intercept_raw + slope_raw*t

In [None]:
## fitted value for every sample: the regression line at that sample's time,
## repeated n_extractions times per time point (same order as the flattened samples)
yregression = time2freq_fit(timepoints)
freq_fitted = np.repeat(yregression, n_extractions)
freq_true = time2freq(timepoint_samples)

freq_residuals = freq_samples - freq_fitted

## mean residual at each sampling time
freq_residual_mean = np.array([freq_residuals[timepoint_samples == t].mean() for t in timepoints])

In [None]:
## display the fitted raw frequencies at the sampling time points
yregression

In [None]:
## log encoding
## zero frequencies map to log(0) = -inf, so drop them before fitting
is_finite = freq_samples != 0
log_timepoints = timepoint_samples[is_finite]
log_samples = np.log(freq_samples[is_finite])

log_fit = linregress(x=log_timepoints, y=log_samples)
slope_log, intercept_log = log_fit.slope, log_fit.intercept
def time2log_fit(t):
    """Fitted straight line (log frequency vs time) evaluated at t."""
    return intercept_log + slope_log*t

yregression = time2log_fit(timepoints)
## fitted value per sample, restricted to the retained (nonzero) samples
log_fitted = np.repeat(yregression, n_extractions)[is_finite]
log_true = np.log(freq_true)[is_finite]

log_residuals = log_samples - log_fitted


In [None]:

## mean log-residual at each sampling time (over the nonzero samples only)
log_residual_mean = np.zeros_like(timepoints)
for idx, t in enumerate(timepoints):
    at_t = log_timepoints == t
    log_residual_mean[idx] = log_residuals[at_t].mean()

In [None]:
## logit encoding
## frequencies of exactly 0 or 1 map to -inf / +inf, so drop them before fitting
is_finite = np.logical_and(freq_samples != 0, freq_samples != 1)
logit_timepoints = timepoint_samples[is_finite]
logit_samples = logit(freq_samples[is_finite])

logit_fit = linregress(x=logit_timepoints, y=logit_samples)
slope_logit, intercept_logit = logit_fit.slope, logit_fit.intercept
def time2logit_fit(t):
    """Fitted straight line (logit frequency vs time) evaluated at t."""
    return intercept_logit + slope_logit*t

yregression = time2logit_fit(timepoints)
## fitted value per sample, restricted to the retained samples
logit_fitted = np.repeat(yregression, n_extractions)[is_finite]
logit_true = logit(freq_true)[is_finite]

logit_residuals = logit_samples - logit_fitted

In [None]:
## mean logit-residual at each sampling time (over samples strictly between 0 and 1)
logit_residual_mean = np.array([
    logit_residuals[logit_timepoints == t].mean() for t in timepoints
])

In [None]:
fig, axes = plt.subplots(3,2, figsize = (2.5*FIGHEIGHT_TRIPLET, 1*FIGHEIGHT_TRIPLET), sharex=True)

## One row per encoding (raw frequency, log, logit).
## Left column: true trajectory (coloured), sampled values (grey) and the linear fit (red).
## Right column: residuals of the fit (grey) and their mean per sampling time (red).
## Only samples with a finite encoded value are plotted -- the same ones used in each fit --
## which avoids log(0) / logit(0 or 1) warnings from plotting +-inf.
## Raw strings keep the LaTeX backslashes literal (\l, \; and \m are invalid escape sequences).
panels = [
    # (y label, colour, true curve, sample times, encoded samples, fit, residuals, mean residual)
    ('frequency $x$', palette[0], yraw, timepoint_samples, freq_samples,
     time2freq_fit, freq_residuals, freq_residual_mean),
    (r'$\log\; x$', palette[1], np.log(yraw), log_timepoints, log_samples,
     time2log_fit, log_residuals, log_residual_mean),
    (r'$\mathrm{logit}\; x$', palette[2], logit(yraw), logit_timepoints, logit_samples,
     time2logit_fit, logit_residuals, logit_residual_mean),
]

for row, (ylabel, color, y_true, t_samples, y_samples, fit, residuals, residual_mean) in enumerate(panels):
    ## trajectory, samples and regression line
    ax = axes[row, 0]
    ax.set_ylabel(ylabel)
    ax.plot(tvec, y_true, color = color, lw = 3)
    ax.scatter(t_samples, y_samples, color = 'tab:grey')
    ax.plot(timepoints, fit(timepoints), color = 'tab:red')

    ## residuals and their mean per sampling time
    ax = axes[row, 1]
    ax.scatter(t_samples, residuals, color = 'tab:grey')
    ax.scatter(timepoints, residual_mean, marker = 'x', color = 'tab:red')
    ax.plot(timepoints, residual_mean, color = 'tab:red', lw = 3, zorder = 3)
    ax.axhline(0, ls = '--', color = 'black')

for ax in axes[:,0]:
    ax.tick_params(left = False, labelleft= False)
    ax.set_xlim(tvec[0],tvec[-1])

for ax in axes[:,1]:
    ## symmetric y-range around zero so positive and negative residuals compare directly
    ymin,ymax = ax.get_ylim()
    yabs = np.max(np.abs([ymin,ymax]))*1.1
    ax.set_ylim(-yabs,yabs)

    ax.tick_params(left = False, labelleft= False)
    ax.set_ylabel('residuals')

axes[2,0].set_xlabel('time')
axes[2,1].set_xlabel('time')

## savefig's resolution keyword is lowercase `dpi` (`DPI=` is rejected by recent matplotlib)
fig.savefig(FIG_DIR + 'heteroscedasticity_comparison.pdf', dpi = DPI, bbox_inches = 'tight', pad_inches = PAD_INCHES)
  

### calculate standard deviation from binomial sampling

To calculate the error of sampling, we need to specify
- the true frequency `p`
- the number of samples drawn from the urn `n`
- the size of the population `N` in the urn. 

The variance of a set of samples `x_i`, where each `x_i` is a realization of a Bernoulli random variable (either `x_i = 0`, heads, or `x_i = 1`, tails), is given by

    mean = sum x_i/n
    var = sum (x_i - mean)^2/n
    
For the binomial count `k = sum x_i` (the number of tails in `n` draws), we know that

    var(k) = n*p*(1-p)

so the estimated frequency `k/n` has variance `p*(1-p)/n`.
    

What makes this potentially confusing is that in many experiments, the number of samples collected varies with the true frequency `p`. For example, if the frequency of a focal strain is low, the experimentalist keeps counting colonies until at least `100` colonies of the rare type have been counted. This effectively fixes `n*p` (or `n*(1-p)` when the focal strain is the common type).

As a simplification, here we only plot the variance of a single draw (`n=1`), i.e. `var = p*(1-p)`.

In [None]:
## grid of true frequencies, kept away from 0 and 1 where log/logit diverge
xvec = np.linspace(start=0.01, stop=0.99, num=500)

In [None]:
## variance of a single draw (n = 1) at true frequency x: x*(1-x)
var = xvec * (1 - xvec)

In [None]:
fig, axes = plt.subplots(3,1, figsize = (1.5*FIGHEIGHT_TRIPLET, 1.33*FIGHEIGHT_TRIPLET), sharex = True)

## single-draw sampling variance var = x(1-x), shown raw and under the log and logit transforms.
## NOTE(review): the lower two panels show log(var) and logit(var), i.e. the encoding applied to
## the variance, not the variance of the encoded frequency (delta method: (1-x)/x for log,
## 1/(x(1-x)) for logit) -- confirm this is the intended quantity.
## Raw strings keep the LaTeX backslashes literal (\l, \; and \m are invalid escape sequences).
ax = axes[0]
ax.set_ylabel('frequency $x$')
ax.plot(xvec, var, color = palette[0], lw = 3)

ax = axes[1]
ax.set_ylabel(r'$\log\; x$')
ax.plot(xvec, np.log(var), color = palette[1], lw = 3)

ax = axes[2]
ax.set_ylabel(r'$\mathrm{logit}\; x$')
ax.plot(xvec, logit(var), color = palette[2], lw = 3)

## the horizontal axis is the true frequency xvec, not time
ax.set_xlabel('frequency $x$')
ax.tick_params(labelbottom = False)

## savefig's resolution keyword is lowercase `dpi` (`DPI=` is rejected by recent matplotlib)
fig.savefig(FIG_DIR + 'example_error_trajectories.pdf', dpi = DPI, bbox_inches = 'tight', pad_inches = PAD_INCHES)
  

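Under an encoding `f`, we need to compute the variance of the encoded value `f(x)` as follows. 

First we compute the expected value, then the variance:

    E[f(x)] = f(0)*(1-p) + f(1)*p
    var[f(x)] = (f(0) - E[f(x)])^2*(1-p) + (f(1) - E[f(x)])^2*p
    
The problem: for our transforms `f=log` and `f=logit`, `f(0)` is infinite (and, for the logit, so is `f(1)`), so this simple calculation diverges.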
    

### plot a binomial distribution

In [None]:
### sample frequency estimates by binomial sampling, raw and encoded
rs = np.random.RandomState(15031998)

### size of the random vector
### basically, number of replicate experiments
size = 100000

## counting continues until about n_success colonies of the rarer type are seen,
## i.e. n_sampled * min(x, 1-x) = n_success
n_success  = 100

def _add_centered_columns(df):
    """Add 'rescaled' (z-score) and 'residual' (value minus mean) columns computed from 'value'.

    pandas' mean/std skip NaN, so non-finite encoded values must already be NaN.
    Modifies df in place and returns it.
    """
    mean, std = df['value'].mean(), df['value'].std()
    df['rescaled'] = (df['value'] - mean)/std
    df['residual'] = df['value'] - mean
    return df

frames = []
for xtrue in [0.99, 0.5, 0.01]:
    ## number of balls drawn from the urn at each replicate experiment: 10000, 200, 10000.
    ## round() absorbs float error (100/(1-0.99) = 9999.99...); binomial expects an integer n.
    n_sampled = int(round(n_success/min(xtrue, 1 - xtrue)))

    dist = rs.binomial(n=n_sampled, p=xtrue, size=size)/n_sampled

    ## raw sampled frequencies
    df_raw = pd.DataFrame({'true frequency': xtrue*np.ones_like(dist), 'value': dist})
    df_raw = _add_centered_columns(df_raw)
    df_raw['type'] = 'no encoding'
    frames.append(df_raw)

    ## logit encoding; frequencies of 0 or 1 give +-inf, which are replaced by NaN
    df_logit = df_raw.copy(deep=True)
    df_logit['value'] = logit(df_raw['value'].to_numpy())
    df_logit['type'] = 'encoded with logit'
    df_logit = df_logit.replace([np.inf, -np.inf], np.nan)
    frames.append(_add_centered_columns(df_logit))

    ## log encoding; a frequency of 0 gives -inf, which is replaced by NaN
    df_log = df_raw.copy(deep=True)
    df_log['value'] = np.log(df_raw['value'].to_numpy())
    df_log['type'] = 'encoded with log'
    df_log = df_log.replace([np.inf, -np.inf], np.nan)
    frames.append(_add_centered_columns(df_log))

## DataFrame.append was removed in pandas 2.0; concatenate all frames once instead
df_data = pd.concat(frames)


    
    

In [None]:


# mark every remaining +-inf value as missing (NaN)
df_data = df_data.replace(to_replace=[np.inf, -np.inf], value=np.nan)

In [None]:
## number of values that are missing (non-finite after encoding)
df_data['value'].isna().sum()

In [None]:
## same qualitative colour set as above (raw / log / logit = entries 0 / 1 / 2)
palette = sns.color_palette(palette="Set2")

In [None]:
## sort by encoding type, then by true frequency (both descending)
df_data = df_data.sort_values(by=['type', 'true frequency'], ascending=False)

In [None]:
fig, axes = plt.subplots(1, 3, figsize = (FIGHEIGHT_TRIPLET,FIGHEIGHT_TRIPLET), sharey = True)

## one horizontal violin per true frequency, showing the residuals (value minus mean)
## of the sampled frequencies under each encoding
panels = [
    # (value of the 'type' column, x-axis label, colour)
    ('no encoding', 'no encoding', palette[0]),
    ('encoded with log', 'log', palette[1]),
    ('encoded with logit', 'logit', palette[2]),
]

for ax, (label, xlabel, color) in zip(axes, panels):
    data_to_plot = df_data[df_data['type'] == label]
    ## NOTE(review): seaborn >= 0.13 renames `scale` to `density_norm`
    sns.violinplot(x='residual', y='true frequency', data=data_to_plot, ax=ax, orient='h',
                   label=label, color=color, scale='count', rasterized=True,
                   inner=None, cut=0)
    ax.set_xlabel(xlabel)

## keep the y axis (true frequency) only on the leftmost panel
sns.despine(ax=axes[0])
for ax in axes[1:]:
    sns.despine(ax=ax, left = True)
    ax.tick_params(left=False)
    ax.set_ylabel("")

for ax in axes:
    ### symmetrize the residual axis around zero
    xmin,xmax = ax.get_xlim()
    max_abs = np.abs([xmin,xmax]).max()
    ax.set_xlim(-max_abs,max_abs)

## savefig's resolution keyword is lowercase `dpi`; it also controls the rasterized violins
fig.savefig(FIG_DIR + 'example_distributions_after_encoding.pdf', dpi = DPI, bbox_inches = 'tight', pad_inches = PAD_INCHES)
              

In [None]:
print("Standard deviations")
## standard deviation of the (encoded) sampled values, per encoding and per true frequency
for label in ['no encoding', 'encoded with logit', 'encoded with log']:
    print(label)
    by_type = df_data.loc[df_data['type'] == label]
    for xtrue in [0.99, 0.5, 0.01]:
        values = by_type.loc[by_type['true frequency'] == xtrue, 'value']
        print(values.std())

### plot overview of encodings

In [None]:
## Set2 colours once more for the overview figure (raw / log / logit = entries 0 / 1 / 2)
palette = sns.color_palette(palette="Set2")

In [None]:
fig, ax = plt.subplots(figsize = (FIGHEIGHT_TRIPLET,FIGHEIGHT_TRIPLET))

## frequencies strictly inside (0, 1): log and logit diverge at the boundaries
x0_vec = np.linspace(0.0001,0.9999, num = 100)

## Raw strings keep the LaTeX backslashes literal (\l is an invalid escape sequence).
## The logit label needs parentheses: x/1-x would read as (x/1) - x.
ax.plot(x0_vec, x0_vec, color = palette[0], label = '$m=x$', lw = 3)
ax.plot(x0_vec, np.log(x0_vec), color = palette[1], ls = '-', label = r'$m=\log(x)$', lw = 3)
ax.plot(x0_vec, logit(x0_vec), color = palette[2], ls = '-', label = r'$m=\log(x/(1-x))$', lw = 3)

ax.axhline(0, color = 'black', ls = 'dotted')
ax.set_ylim(-3,3)
ax.set_xlim(0,1)

ax.set_xlabel('input strain frequency $x$')
ax.set_ylabel('output $m$ from encoding function')

ax.legend(loc = 'upper left')

## savefig's resolution keyword is lowercase `dpi` (`DPI=` is rejected by recent matplotlib)
fig.savefig(FIG_DIR + 'example_encodings.pdf', dpi = DPI, bbox_inches = 'tight', pad_inches = PAD_INCHES)
              