# Disentangling contact and ensemble epistasis in a riboswitch
## Daria R. Wonderlick, Julia R. Widom, Michael J. Harms
## Fig 3, 4, 5, 6

In [None]:
FIG_TO_PLOT = "3"  # set to "3", "4" or "5" to draw the panels of figure 3, 4 or 5

In [None]:
%matplotlib inline
from matplotlib import pyplot as plt
from matplotlib import patches

import pandas as pd
import numpy as np
import corner
from scipy import stats

import glob
import copy
import inspect
import pickle
import re


# Font sizes shared by every figure drawn in this notebook.
SMALL_SIZE = 16
MEDIUM_SIZE = 18
BIGGER_SIZE = 20

# (rc group, keyword settings) pairs, applied in the order listed.
_RC_SETTINGS = (
    ("font", {"size": SMALL_SIZE}),           # default text sizes
    ("axes", {"titlesize": SMALL_SIZE}),      # axes title
    ("axes", {"labelsize": MEDIUM_SIZE}),     # x and y labels
    ("xtick", {"labelsize": SMALL_SIZE}),     # x tick labels
    ("ytick", {"labelsize": SMALL_SIZE}),     # y tick labels
    ("legend", {"fontsize": SMALL_SIZE}),     # legend text
    ("figure", {"titlesize": BIGGER_SIZE}),   # figure title
)
for _rc_group, _rc_kwargs in _RC_SETTINGS:
    plt.rc(_rc_group, **_rc_kwargs)

def plot_curve(df,
               x_column,
               y_column,
               rep_column="replicate",
               area_cutoff=0.95,
               contour_levels=[0.682],
               color="black",
               median_fmt={"linestyle":"-","linewidth":3},
               area_fmt={"linestyle":"none","alpha":0.4},
               contour_fmt={"linestyle":"--","linewidth":1},
               fig=None,
               ax=None,
               median_only=False):
    """
    Plot a curve with uncertainty as areas/dashes.

    Parameters
    ----------
    df : pandas.DataFrame
        dataframe with x_column, y_column, and rep_column. Every replicate
        must have the same number of rows; the x values of the last
        replicate are used for all of them.
    x_column : str
        column with data for x on plot
    y_column : str
        column with data for y on plot
    rep_column : str
        column holding replicate number. function will calculate confidence
        given distribution of y_column values over all unique replicates
    area_cutoff : float
        fraction of the replicates enclosed by the shaded area (two-tailed).
        Default is 0.95 (edges at the 0.025 and 0.975 positions of the
        sorted y values).
    contour_levels : list or None
        draw lines at the specified contour levels (two-tailed). Default is
        0.682 (line at 1 standard deviation). None draws no contour lines.
    color : str
        color to use for plotting series
    median_fmt : dict
        pass these matplotlib keyword arguments to ax.plot(x,y) call for
        the median of the distribution.
    area_fmt : dict
        pass these matplotlib keyword arguments to matplotlib.patches.Polygon
        for the uncertainty area
    contour_fmt : dict
        pass these matplotlib keyword arguments to ax.plot(x,y) call for the
        contour lines.
    fig : matplotlib.Figure, optional
        fig on which to do plot. if both fig and ax are None, create new.
    ax : matplotlib.Axis, optional
        ax on which to do plot. if None, use the current axes of fig (or
        create a new figure when fig is also None).
    median_only : bool
        if True, draw only the median line (no area, no contour lines).

    Returns
    -------
    fig, ax : matplotlib.Figure, matplotlib.ax
        matplotlib objects used for plotting.

    Raises
    ------
    ValueError
        if df holds no replicates.
    """

    # The dict defaults in the signature are only ever read, never mutated:
    # each format dict is merged into a private copy. A partial dict passed
    # by the caller is completed with the defaults from the signature, and
    # None is treated as "use all defaults".
    signature_params = inspect.signature(plot_curve).parameters
    merged_formats = {}
    for name, passed in (("median_fmt",median_fmt),
                         ("area_fmt",area_fmt),
                         ("contour_fmt",contour_fmt)):
        merged = {} if passed is None else copy.deepcopy(passed)
        for key, value in signature_params[name].default.items():
            merged.setdefault(key,value)
        merged_formats[name] = merged

    median_fmt = merged_formats["median_fmt"]
    area_fmt = merged_formats["area_fmt"]
    contour_fmt = merged_formats["contour_fmt"]

    # If user hasn't specified unique colors for each element, use the main
    # color
    median_fmt.setdefault("color",color)
    area_fmt.setdefault("facecolor",color)
    contour_fmt.setdefault("color",color)

    # Resolve the figure/axes pair without discarding whatever was passed in
    if ax is None:
        if fig is None:
            fig, ax = plt.subplots(1,figsize=(6,6))
        else:
            ax = fig.gca()
    elif fig is None:
        fig = ax.figure

    # Make a stack of the y values vs. x (one row per replicate)
    rep_values = df.loc[:,rep_column]
    reps = np.unique(rep_values)
    if len(reps) == 0:
        raise ValueError("df contains no replicates to plot")

    y_stack = []
    for r in reps:
        rows = rep_values == r
        x = np.array(df.loc[rows,x_column])
        y_stack.append(np.array(df.loc[rows,y_column]))
    y_stack = np.array(y_stack)

    # Tail fraction for the area first, then one per contour level
    cutoffs = [(1-area_cutoff)/2]
    if contour_levels is not None:
        cutoffs.extend((1-c)/2 for c in contour_levels)

    # Sort the replicates independently at every x position
    sorted_y = np.sort(y_stack,axis=0)
    num_reps = sorted_y.shape[0]

    def _edge_index(fraction):
        # Rounded position in the sorted replicates, clamped to a valid
        # index: round((1-c)*num_reps) equals num_reps when there are few
        # replicates, which would otherwise run off the end of the array.
        position = int(round(fraction*num_reps,0))
        return min(max(position,0),num_reps - 1)

    lows = [sorted_y[_edge_index(c),:] for c in cutoffs]
    highs = [sorted_y[_edge_index(1-c),:] for c in cutoffs]
    median_y = sorted_y[num_reps//2,:]

    # Plot median line
    ax.plot(x,median_y,**median_fmt)

    if not median_only:

        # Polygon for the main area: lower edge left-to-right, then the
        # upper edge right-to-left.
        patch_x = np.concatenate([x,x[::-1]])
        patch_y = np.concatenate([lows[0],highs[0][::-1]])
        patch_xy = np.array([patch_x,patch_y]).T

        area = patches.Polygon(patch_xy,**area_fmt)
        ax.add_patch(area)

        # Draw contour lines
        for i in range(1,len(cutoffs)):
            ax.plot(x,lows[i],**contour_fmt)
            ax.plot(x,highs[i],**contour_fmt)

    return fig, ax


def get_dGobs(param,m):
    """
    Calculate dGobs at the given value(s) of m.

    Parameters
    ----------
    param : array-like
        at least three values, read by position. param[0] is exponentiated
        and scales the result, param[1] is added to log(m), and param[2]
        multiplies that sum. (Callers in this file pass rows of the sample
        csv files; presumably logK_2AP, logK_mg, n_mg in that order --
        confirm against the file columns.)
    m : float or numpy.ndarray
        value(s) at which to evaluate. Must be positive because log(m) is
        taken.

    Returns
    -------
    float or numpy.ndarray
        -0.001987*298*log(pct), with the same shape as m.
    """

    # Read parameters strictly by position. Rows taken from a dataframe are
    # label-indexed Series, and integer keys on those rely on a positional
    # fallback that pandas has deprecated.
    param = np.asarray(param)

    X = np.exp(param[2]*(param[1] + np.log(m)))
    pct = np.exp(param[0])*X/(1 + X)
    return -0.001987*298*np.log(pct)

def load_dGobs_from_file(some_file,m,num_replicates=5000):
    """
    Calculate a dGobs curve for each row of a parameter table.

    Parameters
    ----------
    some_file : str or pandas.DataFrame
        csv file to read, or an already-loaded dataframe. Each row is passed
        to get_dGobs as one parameter set.
    m : numpy.ndarray
        values passed to get_dGobs for every row.
    num_replicates : int
        use the first num_replicates rows. If the table has fewer rows, all
        of them are used. Default is 5000.

    Returns
    -------
    pandas.DataFrame
        long-format dataframe with columns "replicate" (row number in the
        input), "Mt" (m*1e-9) and "dGobs", with len(m) rows per replicate.
    """

    if isinstance(some_file,str):
        df = pd.read_csv(some_file)
    else:
        df = some_file

    # Never index past the end of a short sample file
    num_replicates = min(num_replicates,len(df))

    m = np.asarray(m)
    Mt = m*1e-9

    out = {"replicate":[],
           "Mt":[],
           "dGobs":[]}

    for i in range(num_replicates):

        # Hand the row over as a plain array so that get_dGobs reads the
        # parameters by position rather than by column label.
        row = df.iloc[i,:].to_numpy()

        out["replicate"].extend([i]*len(m))
        out["Mt"].extend(Mt)
        out["dGobs"].extend(get_dGobs(row,m))

    return pd.DataFrame(out)

def get_additive(csv_00,csv_10,csv_01):
    """
    Build an additive "11" parameter table from three csv files.

    Parameters
    ----------
    csv_00, csv_10, csv_01 : str
        csv files holding logK_mg, logK_2AP and n_mg columns.

    Returns
    -------
    pandas.DataFrame
        copy of the csv_00 table whose logK_mg, logK_2AP and n_mg columns
        are replaced, row by row, with (10 + 01 - 00).
    """

    base, first, second = (pd.read_csv(path) for path in (csv_00,csv_10,csv_01))

    additive = base.copy()
    additive.loc[:,"logK_mg"] = first.logK_mg + second.logK_mg - base.logK_mg
    additive.loc[:,"logK_2AP"] = first.logK_2AP + second.logK_2AP - base.logK_2AP
    additive.loc[:,"n_mg"] = first.n_mg + second.n_mg - base.n_mg

    return additive



def get_minimum_epistasis_values(cycle,num_attempts=20):
    """
    Find one parameter set per genotype such that the four sets together
    have the smallest total epistasis.

    On each attempt the sample tables of the four genotypes are shuffled
    independently and paired row by row; the pairing with the smallest
    summed absolute epistasis (over all columns) is kept. The best pairing
    over all attempts is returned.

    Parameters
    ----------
    cycle : list
        labels of the 00, 10, 01 and 11 genotypes (in that order); each is
        looked up in label_to_file.
    num_attempts : int
        number of independent shufflings to try.

    Returns
    -------
    tuple
        (me_00, me_10, me_01, me_11) numpy arrays, one row of parameters
        from each genotype's sample file.

    Raises
    ------
    ValueError
        if num_attempts is less than 1.
    """

    if num_attempts < 1:
        raise ValueError("num_attempts must be at least 1")

    # The files do not change between attempts, so read them once
    all_dfs = [pd.read_csv(label_to_file[c]) for c in cycle]
    take_only = min(len(df) for df in all_dfs)

    best_total = None
    best_rows = None
    for _ in range(num_attempts):

        # Work on plain arrays. Dataframe arithmetic aligns on the (now
        # shuffled) index labels, which would undo the shuffle and make the
        # argmin position refer to different rows than the ones looked up
        # by position below.
        samples = [df.sample(n=take_only).to_numpy() for df in all_dfs]

        ep = (samples[3] - samples[2]) - (samples[1] - samples[0])
        total_ep = np.sum(np.abs(ep),axis=1)
        idx = np.argmin(total_ep)

        # Compare on the scalar total only; arrays cannot be ordered
        if best_total is None or total_ep[idx] < best_total:
            best_total = total_ep[idx]
            best_rows = tuple(samples[i][idx].copy() for i in range(4))

    return best_rows

def get_specific_epistasis(cycle,target_epistasis):
    """
    Find one parameter set per genotype whose epistasis is closest to a
    target.

    The sample tables of the four genotypes are shuffled independently and
    paired row by row; the pairing whose epistasis has the smallest summed
    absolute difference from target_epistasis is returned.

    Parameters
    ----------
    cycle : list
        labels of the 00, 10, 01 and 11 genotypes (in that order); each is
        looked up in label_to_file.
    target_epistasis : float or array-like
        target value(s), broadcast against the columns of the sample files.

    Returns
    -------
    tuple
        (ep, me_00, me_10, me_01, me_11): the epistasis of the chosen
        pairing followed by the row of parameters from each genotype.
    """

    all_dfs = [pd.read_csv(label_to_file[c]) for c in cycle]
    take_only = min(len(df) for df in all_dfs)

    # Work on plain arrays. Dataframe arithmetic aligns on the (shuffled)
    # index labels, so the argmin position of an aligned result would not
    # match the rows looked up by position below.
    samples = [df.sample(n=take_only).to_numpy() for df in all_dfs]

    ep = (samples[3] - samples[2]) - (samples[1] - samples[0])
    diffs = np.sum(np.abs(ep - target_epistasis),axis=1)
    idx = np.argmin(diffs)

    me_00 = samples[0][idx].copy()
    me_10 = samples[1][idx].copy()
    me_01 = samples[2][idx].copy()
    me_11 = samples[3][idx].copy()

    return ep[idx],me_00,me_10,me_01,me_11


def get_median_values(samples_file,cutoff_scalar=0.1):
    """
    Pick one random row whose values all lie near the column medians.

    Columns are processed from last to first. For each column the rows are
    narrowed to those whose value falls in a window around the middle of
    the sorted values of the rows still remaining; one surviving row is
    then drawn at random.

    Parameters
    ----------
    samples_file : str
        csv file of samples.
    cutoff_scalar : float
        width of the window, as a fraction of the number of remaining rows.

    Returns
    -------
    numpy.ndarray
        the values of the chosen row.
    """

    samples = pd.read_csv(samples_file)
    for column in samples.columns[::-1]:

        ordered = np.sort(np.array(samples.loc[:,column]))
        count = len(ordered)
        center = count//2
        half_width = cutoff_scalar/2*count

        lower = ordered[int(round(center - half_width,0))]
        upper = ordered[int(round(center + half_width,0))]

        # Keep rows with lower < value <= upper
        in_window = (samples[column] > lower) & (samples[column] <= upper)
        samples = samples.loc[in_window,:]

    chosen = np.random.choice(samples.index)

    return np.array(samples.loc[chosen,:])

# Genotype code (used in the sample file names) -> mutant name
label_dict = {
    "xxxxx": "wt",
    "Uxxxx": "a35u",
    "xCxxx": "g38c",
    "xxAxx": "c50a",
    "xxxGx": "c60g",
    "UxxGx": "a35u-c60g",
    "xCxGx": "g38c-c60g",
    "xxAGx": "c50a-c60g",
}

# Mutant name -> genotype code (inverse of label_dict)
mut_to_label_dict = {name: code for code, name in label_dict.items()}

# Mutant name -> csv file holding the samples for that genotype
label_to_file = {
    name: f"all-samples/3.3/bayesian/{code}.csv"
    for code, name in label_dict.items()
}

# Sample files on disk, skipping any whose path contains an underscore
files = [f for f in glob.glob("all-samples/3.3/bayesian/*.csv") if "_" not in f]

# Mg2+ concentration span for calculations
m = 10**np.linspace(2,5,100)

# See if epistasis is bigger than this magnitude when looking for significance
epistasis_cutoff = 0.75

# Plot only median of dG obs curves
median_only = False

## Select figure to plot. 

In [None]:
# Settings for each figure, keyed by the value of FIG_TO_PLOT.
#   cycle         : genotype labels used for the figure
#   ylims         : y limits for the three violin plots
#   xmin..ymax    : window of the 2-D density plots
#   kcal_to_pixel : grid points per unit along each axis of that window
#   bad_index     : position in cycle of the genotype whose samples are
#                   drawn in the single-genotype density plot
_FIGURE_SETTINGS = {

    # FIG 3, 35/60
    "3": {"cycle": ["wt","c60g","a35u","a35u-c60g"],
          "ylims": [[13,15],[15.5,17.5],[0,4]],
          "xmin": 12, "xmax": 15,
          "ymin": 15, "ymax": 18,
          "kcal_to_pixel": 100,
          "bad_index": 3},

    # FIG 4, 38/60
    "4": {"cycle": ["wt","c60g","g38c","g38c-c60g"],
          "ylims": [[13,15],[15,17],[0,4]],
          "xmin": -10, "xmax": 30,
          "ymin": -55, "ymax": 15,
          "kcal_to_pixel": 100/80,
          "bad_index": 2},

    # FIG 5, 50/60
    "5": {"cycle": ["wt","c60g","c50a","c50a-c60g"],
          "ylims": [[13,15],[15,17],[0,4]],
          "xmin": -10, "xmax": 20,
          "ymin": -50, "ymax": 25,
          "kcal_to_pixel": 100/100,
          "bad_index": 3},
}

_chosen = _FIGURE_SETTINGS.get(FIG_TO_PLOT)
if _chosen is None:

    print("FIG_TO_PLOT should be '3', '4' or '5'")

else:

    cycle = _chosen["cycle"]
    ylims = _chosen["ylims"]
    xmin = _chosen["xmin"]
    xmax = _chosen["xmax"]
    ymin = _chosen["ymin"]
    ymax = _chosen["ymax"]
    kcal_to_pixel = _chosen["kcal_to_pixel"]
    bad_cycle_file = label_to_file[cycle[_chosen["bad_index"]]]


## Fig 3A-C, 4A-C, 5A-C

In [None]:
# Violin plots of each fitted parameter for the genotypes in `cycle`, with
# printed credible intervals for the individual effects and their epistasis.
# One figure is drawn per parameter (logK_2AP, logK_mg, n_mg).

#cycle = ["wt","a35u","g38c","c50a","c60g","a35u-c60g","g38c-c60g","c50a-c60g"]

# parameter name -> {genotype label -> array of sampled values}
out_dfs = {"logK_2AP":{},
           "logK_mg":{},
           "n_mg":{}}

# Draw 10000 rows from each genotype's sample file (the file must hold at
# least that many rows). logK columns are shifted by log(1e9) and scaled by
# -0.001987*298; n_mg is stored as read.
for i in range(len(cycle)):
    df = pd.read_csv(label_to_file[cycle[i]]).sample(n=10000)
    for k in out_dfs:
        if k.startswith("logK"):
            out_dfs[k][cycle[i]] = -0.001987*298*np.array(df.loc[:,k] - np.log(1e9))
        else:
            out_dfs[k][cycle[i]] = np.array(df.loc[:,k])


# One pass (and one figure) per parameter; i selects the matching ylims entry
for i, k in enumerate(out_dfs):

    # Replace the dict of arrays with a dataframe: one column per genotype
    out_dfs[k] = pd.DataFrame(out_dfs[k])

    df = out_dfs[k]
    medians = []
    lowers = []
    uppers = []
    all_values = []
    for c in df.columns:

        # if c == 'g38c' and k != "n_mg":
        #     df[c] = np.nan

        # NOTE(review): every column is divided by -0.001987*298 here. For
        # the logK columns this undoes the scaling applied above; n_mg was
        # never scaled, so it ends up divided by that factor -- confirm
        # this is intended.
        values = np.array(df.loc[:,c])/(-0.001987*298)
        values.sort()
        medians.append(values[len(values)//2])

        # NOTE(review): these are positions in the sorted array, not the
        # values at those positions, and lowers/uppers are not read again
        # in this cell.
        lower = np.round(len(values)*0.025,0)
        upper = np.round(len(values)*0.975,0)

        lowers.append(lower)
        uppers.append(upper)

        # Shuffle so that values can be paired at random between genotypes
        np.random.shuffle(values)
        all_values.append(values)


    fig, ax = plt.subplots(1,figsize=(6,6))

    # Violins are drawn from df itself, with lines at the median and at the
    # 0.025 and 0.975 quantiles of each column
    vp = ax.violinplot(df,
                       showextrema=False,
                       showmedians=True,
                       quantiles=[[0.025,0.975] for _ in range(len(df.columns))])

    vp["cmedians"].set_linewidth(2)
    vp["cmedians"].set_edgecolor("black")
    vp["cquantiles"].set_linewidth(1)
    vp["cquantiles"].set_edgecolor("black")

    xlabels = cycle

    # Genotypes are handled in consecutive pairs: (cycle[0], cycle[1]) and
    # (cycle[2], cycle[3])
    all_effects = []
    for j in range(len(cycle)//2):

        # Arrow from the median of the first genotype of the pair to the
        # median of the second.
        # NOTE(review): medians come from `values` (df divided by
        # -0.001987*298) while the violins are drawn from df -- confirm the
        # arrows land where intended.
        x = 2*j + 1
        y = medians[2*j]
        dx = 1
        dy = medians[2*j + 1] - y

        ax.arrow(x,y,dx,dy,length_includes_head=True)

        # Get distribution overlaps
        v2 = all_values[2*j].copy()
        v1 = all_values[2*j + 1].copy()

        np.random.shuffle(v1)
        np.random.shuffle(v2)

        # Randomly paired differences: first genotype minus second
        eff = v2 - v1
        eff.sort()

        higher = np.sum(eff > 0)
        lower = np.sum(eff < 0)

        # p_value is one minus the fraction of differences on the more
        # populated side of zero
        if higher > lower:
            p_value = 1 - higher/len(eff)
        else:
            p_value = 1 - lower/len(eff)

        if p_value < 0.0001:
            code = "***"
        elif p_value < 0.001:
            code = "**"
        elif p_value < 0.01:
            code = "*"
        else:
            code = ""

        # No difference fell on the other side of zero: report a bound of
        # one over the number of differences instead of zero
        if p_value == 0.0:
            out = f"<{1/len(eff):.3e},{code}"
        else:
            out = f"{p_value:.3e},{code}"

        # Median and the 0.025/0.975 positions of the sorted differences
        cred_bottom = eff[int(np.round(0.025*len(eff),0))]
        cred_top = eff[int(np.round(0.975*len(eff),0))]
        cred_median = eff[len(eff)//2]

        print(f"Individual effect in background {cycle[j*2]}:",out)
        print(f"{cred_median:.2f} [{cred_bottom:.2f},{cred_top:.2f}], p: {p_value:.3e}")

        all_effects.append(eff)


    # Epistasis: randomly paired difference between the two effects
    m1 = all_effects[0].copy()
    m2 = all_effects[1].copy()

    np.random.shuffle(m1)
    np.random.shuffle(m2)


    all_ep = m2 - m1
    all_ep.sort()
    cred_bottom = all_ep[int(np.round(0.025*len(all_ep),0))]
    cred_top = all_ep[int(np.round(0.975*len(all_ep),0))]
    cred_median = all_ep[len(all_ep)//2]

    higher = np.sum(all_ep > 0)
    lower = np.sum(all_ep < 0)

    if higher > lower:
        p_value = 1 - higher/len(all_ep)
    else:
        p_value = 1 - lower/len(all_ep)


    print("Epistasis")
    print(f"{cred_median:.2f} [{cred_bottom:.2f},{cred_top:.2f}], p: {p_value:.3e}")




    ax.set_title(f"{k}")

    ax.set_xticks(np.arange(len(xlabels))+1)
    ax.set_xticklabels(xlabels,rotation=90)
    ax.set_ylabel("kcal/mol")
    ax.set_ylim(ylims[i])
    plt.show()



## Fig 4D, 5D

In [None]:
# 2-D density of (logK_2AP, logK_mg) for the genotype in bad_cycle_file,
# after shifting every column by log(1e9) and scaling by -0.001987*298.
df = -0.001987*298*(pd.read_csv(bad_cycle_file) - np.log(1e9))

# Number of grid points along each axis of the plotting window
x_span, y_span = (int(np.round((high - low)*kcal_to_pixel,0))
                  for low, high in ((xmin,xmax),(ymin,ymax)))

# A complex step makes mgrid return that many evenly spaced points,
# endpoints included
X, Y = np.mgrid[xmin:xmax:complex(0,x_span),ymin:ymax:complex(0,y_span)]

# Kernel density estimate evaluated at every grid point
values = np.vstack([df.logK_2AP, df.logK_mg])
kernel = stats.gaussian_kde(values)

positions = np.vstack([X.ravel(), Y.ravel()])
Z = kernel(positions).reshape(X.shape)

fig, ax = plt.subplots(1,figsize=(6,6))
ax.imshow(np.rot90(Z),extent=[xmin, xmax, ymin, ymax],cmap="cividis")

# Contours of the same density on top of the image
x_pos = np.linspace(xmin,xmax,Z.shape[0])
y_pos = np.linspace(ymin,ymax,Z.shape[1])
ax.contour(x_pos,y_pos,Z.T,origin="lower",colors=["white"],linewidths=[0.5])

ax.set_xlabel("dG_2AP")
ax.set_ylabel("dG_MG")

# Tick only the edges of the window
x_edges = (xmin,xmax)
y_edges = (ymin,ymax)
ax.set_xticks(x_edges)
ax.set_xticklabels(x_edges)
ax.set_yticks(y_edges)
ax.set_yticklabels(y_edges)




## Fig 4E, 5E

In [None]:
# 2-D density of the epistasis in logK_2AP and logK_mg for the current cycle.

# parameter name -> {genotype label -> array of scaled sampled values}
out_dfs = {k: {} for k in ("logK_2AP","logK_mg","n_mg")}

for label in cycle:
    df = pd.read_csv(label_to_file[label]).sample(n=10000)
    for k, per_genotype in out_dfs.items():
        # logK columns are shifted by log(1e9), n_mg by 1, before scaling
        offset = np.log(1e9) if k.startswith("logK") else 1
        per_genotype[label] = -0.001987*298*(np.array(df.loc[:,k]) - offset)


def _cycle_epistasis(per_genotype):
    """Return (cycle[3] - cycle[2]) - (cycle[1] - cycle[0]) for one parameter."""
    return ((per_genotype[cycle[3]] - per_genotype[cycle[2]])
            - (per_genotype[cycle[1]] - per_genotype[cycle[0]]))


n_mg_ep = _cycle_epistasis(out_dfs["n_mg"])
logK_2AP_ep = _cycle_epistasis(out_dfs["logK_2AP"])
logK_mg_ep = _cycle_epistasis(out_dfs["logK_mg"])

ep_df = pd.DataFrame({"logK_2AP_ep":logK_2AP_ep,
                      "logK_mg_ep":logK_mg_ep,
                      "n_mg_ep":n_mg_ep})

# Kernel density estimate of the two logK epistasis terms
values = np.vstack([ep_df.logK_2AP_ep, ep_df.logK_mg_ep])
kernel = stats.gaussian_kde(values)

# Number of grid points along each axis of the plotting window
x_span, y_span = (int(np.round((high - low)*kcal_to_pixel,0))
                  for low, high in ((xmin,xmax),(ymin,ymax)))

X, Y = np.mgrid[xmin:xmax:complex(0,x_span),ymin:ymax:complex(0,y_span)]
positions = np.vstack([X.ravel(), Y.ravel()])
Z = kernel(positions).reshape(X.shape)

# Normalize so the grid sums to one
Z = Z/np.sum(Z)

fig, ax = plt.subplots(1,figsize=(6,6))
ax.imshow(np.rot90(Z),extent=[xmin, xmax, ymin, ymax],cmap="cividis")

# Square with corners at (+/-1, +/-1)
box = np.rot90(np.array(((-1,-1,1,1),(-1,1,1,-1))))
p = patches.Polygon(box,facecolor="none",edgecolor="white")

# Dashed lines through zero on both axes
ax.plot((0,0),(ymin,ymax),'--',color='white',lw=0.5)
ax.plot((xmin,xmax),(0,0),'--',color='white',lw=0.5)

x_pos = np.linspace(xmin,xmax,Z.shape[0])
y_pos = np.linspace(ymin,ymax,Z.shape[1])
ax.contour(x_pos,y_pos,Z.T,origin="lower",colors=["white"],linewidths=[0.5])
ax.add_patch(p)

ax.set_xlabel("ddG_2AP")
ax.set_ylabel("ddG_MG")

# Tick only the edges of the window
x_edges = (xmin,xmax)
y_edges = (ymin,ymax)
ax.set_xticks(x_edges)
ax.set_xticklabels(x_edges)
ax.set_yticks(y_edges)
ax.set_yticklabels(y_edges)



## Fig 3D, 4F, 5F

In [None]:

# dGobs versus Mt for the four genotypes of the cycle, drawn on one axes.

def _load_curve(mutant):
    """Load the dGobs curves for one mutant, keeping 1e-8 <= Mt <= 1e-4."""
    curve = load_dGobs_from_file(f"all-samples/3.3/bayesian/{mut_to_label_dict[mutant]}.csv",m)
    in_range = np.logical_and(curve.Mt >= 1e-8,curve.Mt <= 1e-4)
    return curve.loc[in_range,:]


# The first call creates the figure; the later calls draw on it
fig = None
ax = None
_curves = []
for _mutant, _color in zip(cycle,("black","red","blue","purple")):
    _curve = _load_curve(_mutant)
    _drawn = plot_curve(_curve,x_column="Mt",y_column="dGobs",color=_color,
                        fig=fig,ax=ax,median_only=median_only)
    if fig is None:
        fig, ax = _drawn
    _curves.append(_curve)

# Keep the per-genotype tables under their usual names
df_00, df_10, df_01, df_11 = _curves

ax.set_xscale("log")

ax.set_xlabel("[Mg] (M)")
ax.set_ylabel("dGobs")
ax.set_ylim(1,8)



## Fig 3E, 4G, 5G

In [None]:
# Epistasis in dGobs versus Mt: per-replicate curves (median plus
# uncertainty) and, in green, the curve for an additive parameter set.

df = df_00.copy()
df = df.loc[np.logical_and(df.Mt >= 1e-8,df.Mt <= 1e-4),:]

# Row-by-row epistasis between the four dGobs tables (matched on index)
df["ddGobs"] = (df_11["dGobs"] - df_10["dGobs"]) - (df_01["dGobs"] - df_00["dGobs"])

fig, ax = plot_curve(df,x_column="Mt",y_column="ddGobs")
ax.set_xscale("log")
ax.plot((1*1e-7,1e-4),(0,0),'--',color="gray")

# One row from near the median of each genotype's samples
ml_00 = get_median_values(label_to_file[cycle[0]])
ml_10 = get_median_values(label_to_file[cycle[1]])
ml_01 = get_median_values(label_to_file[cycle[2]])
ml_11 = get_median_values(label_to_file[cycle[3]])

dGobs_ml_00 = get_dGobs(ml_00,m)
dGobs_ml_10 = get_dGobs(ml_10,m)
dGobs_ml_01 = get_dGobs(ml_01,m)
dGobs_ml_11 = get_dGobs(ml_11,m)
ddGobs_ml = (dGobs_ml_11 - dGobs_ml_10) - (dGobs_ml_01 - dGobs_ml_00)

# What to use as ref state for additive calculation? g38c/c60g uses 11; others use 00
if cycle[3] == "g38c-c60g":
    # Replace the 01 genotype with its additive expectation
    ml_additive_01 = ml_11 - (ml_10 - ml_00) 
    dGobs_ml_additive = get_dGobs(ml_additive_01,m)
    ddGobs_ml_additive = (dGobs_ml_11 - dGobs_ml_10) - (dGobs_ml_additive - dGobs_ml_00)
else:
    # Replace the 11 genotype with its additive expectation
    ml_additive_11 = ml_10 + ml_01 - ml_00 
    dGobs_ml_additive = get_dGobs(ml_additive_11,m) 
    ddGobs_ml_additive = (dGobs_ml_additive - dGobs_ml_10) - (dGobs_ml_01 - dGobs_ml_00)

# Same x values as the Mt column plotted above (m*1e-9)
ax.plot(m*1e-9,ddGobs_ml_additive,color="green",lw=3)

# Parameter sets with minimal epistasis. The quantities below are computed
# but not drawn in this cell.
me_00, me_10, me_01, me_11 = get_minimum_epistasis_values(cycle,num_attempts=1)

dGobs_me_00 = get_dGobs(me_00,m)
dGobs_me_10 = get_dGobs(me_10,m)
dGobs_me_01 = get_dGobs(me_01,m)
dGobs_me_11 = get_dGobs(me_11,m)

ddGobs_me = (dGobs_me_11 - dGobs_me_10) - (dGobs_me_01 - dGobs_me_00)

me_additive_11 = me_01 + me_10 - me_00 
dGobs_additive = get_dGobs(me_additive_11,m) 
ddGobs_additive = (dGobs_additive - dGobs_me_10) - (dGobs_me_01- dGobs_me_00)


# The x values are Mt = m*1e-9, the same quantity the dGobs-vs-Mt plot
# labels "[Mg] (M)", so the unit here is M rather than nM.
ax.set_xlabel("[Mg] (M)")
ax.set_ylabel("epistasis in dGobs")


## Fig 5F

In [None]:
# dG calculated from the measurements in 2AP_corrected.csv, drawn against
# Mg for each selected genotype.
df = pd.read_csv("2AP_corrected.csv")
fig, ax = plt.subplots(1,figsize=(6,6))

# Genotype code -> line color
colors = {
    "xxxxx": "black",
    "Uxxxx": "gray",
    "xCxxx": "orange",
    "xxAxx": "blue",
    "xxxGx": "red",
    "UxxGx": "pink",
    "xCxGx": "green",
    "xxAGx": "purple",
}

#geno_list = np.unique(df.Geno)
geno_list = ["xxxxx","xxAxx","xxxGx","xxAGx"]

for geno in geno_list:

    df_geno = df.loc[df.Geno == geno,:]

    # Distinct Mg values for this genotype, in increasing order
    mg_list = np.sort(np.array(np.unique(df_geno.Mg)))

    dG_list = []
    for mg in mg_list:
        this_df = df_geno.loc[df_geno.Mg == mg,:]

        # Skip rows whose FS_mean is exactly zero
        nonzero = this_df.FS_mean != 0
        da = this_df.FS_mean[nonzero]*5e-8
        rna = this_df.Rna[nonzero]*1e-6
        pct = da/(rna - da)

        dG = -0.001987*298*np.log(pct)
        mean_dG = np.mean(dG)
        dG_list.append(mean_dG)

        # Point and error bar (standard deviation) at this Mg value
        ax.scatter( mg*1e-3,mean_dG,s=30,color=colors[geno])
        ax.errorbar(mg*1e-3,mean_dG,yerr=np.std(dG),color=colors[geno],capsize=5)

    # Line through the means for this genotype
    ax.plot(mg_list*1e-3,dG_list,"-",lw=2,color=colors[geno],label=geno)

fig.legend()

ax.set_xscale("log")





## Fig 6

In [None]:
# Epistasis in dGobs for model parameter sets in which the coupling is
# placed, in turn, in each of the three parameters (one panel each).
RT = 0.001987*298
mg = 10**np.linspace(-16,-7,100)
epistasis_scalar = 1/RT
n_epistasis_scalar = 0.5

fig, ax = plt.subplots(3,1,figsize=(6,18),sharex=True)

# Dashed zero line on every panel
for _panel in ax:
    _panel.plot((1e-7,100),(0,0),'--',color='gray')

style = {"lw":3}

# Reference parameter set and its dGobs curve
g00 = np.array([23,26,1.5]) 
dGobs_00 = get_dGobs(g00,mg)
m1 = 1/RT
m2 = 1/RT

for i in range(-3,4):

    # Red shades for negative i, black for zero, blue shades for positive i
    if i == 0:
        color = [0.0,0.0,0.0,1]
    elif i < 0:
        color = [1,1+i/3,1+i/3,1]
    else:
        color = [1-i/3,1-i/3,1,1]

    # For each panel: effect of the first mutation, effect of the second
    # mutation, and the coupling added to the double mutant
    _panel_terms = (
        (np.array([m1,0,0]),  np.array([m2,0,0]),  np.array((i*epistasis_scalar,0,0))),
        (np.array([0,m1,0]),  np.array([0,m2,0]),  np.array((0,-i*epistasis_scalar,0))),
        (np.array([0,0,0.5]), np.array([0,0,0.5]), np.array((0,0,-i*n_epistasis_scalar))),
    )

    for _panel, (_effect_10, _effect_01, _coupling) in zip(ax,_panel_terms):

        g10 = g00 + _effect_10
        g01 = g00 + _effect_01
        g11 = g10 + g01 - g00 + _coupling

        dGobs_10 = get_dGobs(g10,mg)
        dGobs_01 = get_dGobs(g01,mg)
        dGobs_11 = get_dGobs(g11,mg)

        _panel.plot(mg*1e9,(dGobs_11 - dGobs_10) - (dGobs_01 - dGobs_00),color=color,**style)

for _panel in ax:
    _panel.set_ylim(-5,5)
    _panel.set_ylabel("epistasis in dGobs")
    _panel.set_yticks(np.arange(-5,6))

ax[0].set_xscale('log')
ax[2].set_xscale('log')
ax[2].set_xlabel("Mg")

