In [None]:
# Compare functional scores with the functional effects obtained from global epistasis models.
# Also generate the times_seen distribution chart and the correlation between LibA and LibB effects.

In [None]:
import pandas as pd

import altair as alt

import numpy as np

import scipy.stats

import httpimport

# Turn off Altair's maximum-row check for chart datasets; the return value is
# bound to `_` so the cell does not echo it.
_ = alt.data_transformers.disable_max_rows()

In [None]:
# Import the custom Altair theme from a remote GitHub repository using the httpimport module.
# NOTE(review): this runs code fetched from the remote repository at execution time,
# so results depend on that repository's current state — consider pinning a commit.
def import_theme_new():
    """Import ``main_theme`` via httpimport and register it as an enabled Altair theme.

    The theme is registered under the name ``"custom_theme"``; the registered
    function simply returns ``main_theme.main_theme()``. Nothing is returned.
    """
    # "main" is presumably the branch/ref to fetch from — confirm against the
    # installed httpimport version's `github_repo` signature.
    with httpimport.github_repo("bblarsen-sci", "altair_themes", "main"):
        import main_theme

        # Register and immediately enable (enable=True) the theme.
        @alt.theme.register("custom_theme", enable=True)
        def custom_theme():
            return main_theme.main_theme()


import_theme_new()

### Define lists of selections

In [None]:
# Selection identifiers per library, one per line. `.split()` breaks the block
# on whitespace, yielding a plain list of names in the order written.
LibA_selections = """
    LibA-250311-CHO-bEFNB3-1
    LibA-250311-CHO-bEFNB3-2
    LibA-250409-CHO-bEFNB3-1
    LibA-250409-CHO-bEFNB3-2
""".split()

LibB_selections = """
    LibB-250303-CHO-bEFNB3-1
    LibB-250303-CHO-bEFNB3-2
    LibB-250414-CHO-bEFNB3-1
    LibB-250414-CHO-bEFNB3-2
""".split()


### Import functional scores and effects for each selection

In [None]:
def get_func_scores_and_effects(
    selections, library_name, results_dir="../../../results"
):
    """
    Load and concatenate functional scores and effects for the given selections.

    For each selection ``sel`` this reads
    ``{results_dir}/func_scores/{sel}_func_scores.csv`` and
    ``{results_dir}/func_effects/by_selection/{sel}_func_effects.csv``
    (both with ``na_filter=False``, so empty fields stay empty strings rather
    than becoming NaN) and tags every row with ``sample`` and ``library``.
    Functional-effect rows where ``wildtype == mutant`` are dropped and
    ``times_seen`` is cast to ``int``.

    Args:
        selections: Iterable of selection identifiers
        library_name: Name of the library (e.g., 'LibA', 'LibB')
        results_dir: Root directory holding the ``func_scores`` and
            ``func_effects`` folders. Defaults to the location this notebook
            has always used.

    Returns:
        tuple: (func_scores DataFrame, func_effect DataFrame), each the
        row-wise concatenation over all selections with a fresh index.

    Raises:
        ValueError: If ``selections`` is empty.
    """
    selections = list(selections)
    if not selections:
        # pd.concat would fail anyway; say what is actually wrong.
        raise ValueError(
            f"No selections provided for library {library_name!r}; nothing to load."
        )

    func_scores_list = []
    func_effect_list = []

    for sel in selections:
        # Functional scores
        func_scores_list.append(
            pd.read_csv(
                f"{results_dir}/func_scores/{sel}_func_scores.csv", na_filter=False
            ).assign(sample=sel, library=library_name)
        )

        # Functional effects: keep only actual mutations (wildtype != mutant)
        # and store times_seen as an integer. Each step returns a new frame,
        # so no SettingWithCopyWarning can arise.
        func_effect_list.append(
            pd.read_csv(
                f"{results_dir}/func_effects/by_selection/{sel}_func_effects.csv",
                na_filter=False,
            )
            .assign(sample=sel, library=library_name)
            .query("wildtype != mutant")
            .astype({"times_seen": int})
        )

    func_scores = pd.concat(func_scores_list, ignore_index=True)
    func_effect = pd.concat(func_effect_list, ignore_index=True)

    return func_scores, func_effect


# Load every selection of each library (LibA first, then LibB).
LibA_func_scores, LibA_func_effect = get_func_scores_and_effects(
    selections=LibA_selections, library_name="LibA"
)
LibB_func_scores, LibB_func_effect = get_func_scores_and_effects(
    selections=LibB_selections, library_name="LibB"
)

# Stack the two libraries: one combined frame for scores, one for effects.
all_func_scores, all_func_effect = (
    pd.concat(library_frames, ignore_index=True)
    for library_frames in (
        (LibA_func_scores, LibB_func_scores),
        (LibA_func_effect, LibB_func_effect),
    )
)

# Show both combined tables.
display(all_func_scores, all_func_effect)


### Get distribution of times_seen for each library after filtering


In [None]:
# Exclude "*" and "-" mutants (presumably stop codons and gaps — confirm).
all_func_effects_filtered = all_func_effect.loc[
    ~all_func_effect["mutant"].isin(["*", "-"])
]

# Mean times_seen of every (library, site, mutant) across that library's selections.
all_effects_grouped = all_func_effects_filtered.groupby(
    ["library", "site", "mutant"], as_index=False
)["times_seen"].mean()

display(all_effects_grouped)

In [None]:
# Overlaid (unstacked) histograms of mean times_seen, one colour per library,
# using unit-width bins.
times_seen_distribution_chart = (
    alt.Chart(all_effects_grouped)
    .transform_bin("times_seen_binned", field="times_seen", bin=alt.Bin(step=1))
    .mark_bar(opacity=0.6)
    .encode(
        x=alt.X("times_seen_binned:Q", title="Barcode Coverage of Mutations"),
        # stack=None draws the libraries on top of each other instead of stacking them.
        y=alt.Y("count():Q", title="Count", stack=None),
        color=alt.Color("library:N", title="Library"),
    )
    .properties(width=alt.Step(20), height=200)
)
display(times_seen_distribution_chart)


#times_seen_distribution_chart.save(
#    "../../results/figures/mutation_distribution/times_seen_distribution_chart.png",
#    ppi=300,
#)
#times_seen_distribution_chart.save(
#    "../../results/figures/mutation_distribution/times_seen_distribution_chart.svg"
#)
#

### Filter functional scores to single mutations, then add wildtype, mutant, and site information

In [None]:
def process_func_scores(func_scores):
    """Keep single-substitution rows and split the substitution into its parts.

    Rows with ``n_aa_substitutions == 1`` are retained. From the
    ``aa_substitutions`` string, the first character becomes ``wildtype``, the
    last character becomes ``mutant`` and the first run of digits becomes the
    integer ``site``. Rows whose mutant is ``"*"`` or ``"-"`` are then removed.
    The input frame is not modified; a new frame is returned.
    """
    is_single = func_scores["n_aa_substitutions"] == 1
    singles = func_scores[is_single].copy()

    substitution = singles["aa_substitutions"].astype(str)
    singles["aa_substitutions"] = substitution
    singles["wildtype"] = substitution.str[0]
    singles["mutant"] = substitution.str[-1]
    singles["site"] = substitution.str.extract(r"([0-9]+)", expand=False).astype(int)

    keep = ~singles["mutant"].isin(["*", "-"])
    return singles[keep].copy()


all_func_scores_single = process_func_scores(all_func_scores)

mutation_keys = ["library", "sample", "wildtype", "mutant", "site"]

# Per library/sample/mutation: mean functional score and the number of
# single-mutant rows (times_seen). The mean is then floored at -4 to match the
# clipping that is applied to the functional effects in multiDMS.
func_scores_grouped = (
    all_func_scores_single.groupby(mutation_keys, as_index=False)
    .agg(
        mean_func_score=("func_score", "mean"),
        times_seen=("n_aa_substitutions", "count"),
    )
    .assign(
        mean_func_score=lambda grouped: grouped["mean_func_score"].clip(lower=-4)
    )
)

# Attach the functional effects to the grouped scores. Left join: every grouped
# score row is kept, so mutations without a functional effect carry NaN in the
# effect columns. Both tables have a times_seen column, hence the suffixes.
merged_func_scores_effects = func_scores_grouped.merge(
    all_func_effects_filtered,
    how="left",
    on=mutation_keys,
    suffixes=("_func_score", "_func_effect"),
)
display(merged_func_scores_effects.head(3))

In [None]:
def plot_corr_w_pearson_r(
    df, x, y, x_axis_title, y_axis_title, x_rvalue, y_rvalue, tooltip_list
):
    """
    Plot one x-vs-y scatter per sample, each annotated with its Pearson r.

    ``df`` must contain a ``sample`` column and the ``x`` and ``y`` columns.
    Points are coloured by the ``times_seen_func_score`` field: red when it is
    below 2, grey otherwise. Rows missing either coordinate (e.g. unmatched
    rows from a left merge) are dropped before plotting and before computing
    r. The r value is computed on the unrounded data; values are rounded to
    two decimals only for display. The sample list and each r value are
    printed as a side effect.

    Args:
        df (pd.DataFrame): DataFrame containing the data to plot
        x (str): Column name for x-axis
        y (str): Column name for y-axis
        x_axis_title (str): Title for x-axis
        y_axis_title (str): Title for y-axis
        x_rvalue (float): X coordinate for r value text
        y_rvalue (float): Y coordinate for r value text
        tooltip_list (list): List of columns to display in tooltip
    Returns:
        Altair chart: the per-sample layered charts concatenated vertically
        with a shared y scale
    """
    samples = df["sample"].unique().tolist()
    print(f"Samples: {samples}")

    sample_charts = []
    for sample in samples:
        # Boolean mask rather than an interpolated query string, so sample
        # names containing quotes cannot break the filter.
        sample_df = df[df["sample"] == sample].dropna(subset=[x, y])

        # Pearson r on the unrounded values; linregress needs >= 2 points.
        if len(sample_df) >= 2:
            r_value = float(
                scipy.stats.linregress(sample_df[x], sample_df[y]).rvalue
            )
        else:
            r_value = float("nan")
        print(f"r_value: {r_value:.2f}")

        # Round only the copy that is embedded in the chart (tidier tooltips).
        plot_df = sample_df.round(2)
        effect_count_col = "times_seen_func_effect"
        if effect_count_col in plot_df.columns and plot_df[effect_count_col].notna().all():
            # Show counts as integers; skipped if the column is absent or has
            # gaps, where an int cast would raise.
            plot_df[effect_count_col] = plot_df[effect_count_col].astype(int)

        # Highlight points backed by fewer than two functional-score observations.
        predicate = alt.datum.times_seen_func_score < 2
        points = (
            alt.Chart(plot_df, title=sample)
            .mark_point(opacity=0.5)
            .encode(
                x=alt.X(x, title=x_axis_title),
                y=alt.Y(y, title=y_axis_title),
                color=alt.when(predicate)
                .then(alt.value("#d1615d"))
                .otherwise(alt.value("#b8b0ac")),
                tooltip=tooltip_list,
            )
        )
        # Single-row inline dataset placing the r label at (x_rvalue, y_rvalue).
        r_label = (
            alt.Chart(
                {
                    "values": [
                        {
                            "x": x_rvalue,
                            "y": y_rvalue,
                            "text": f"r = {r_value:.2f}",
                        }
                    ]
                }
            )
            .mark_text(
                dx=10,
                dy=0,
                align="left",
            )
            .encode(x=alt.X("x:Q"), y=alt.Y("y:Q"), text="text:N")
        )
        sample_charts.append(points + r_label)

    # Stack the per-sample panels vertically on a common y scale.
    chart = alt.vconcat(*sample_charts).resolve_scale(y='shared')
    return chart




In [None]:
# Columns and axis labels for the functional-effect vs functional-score comparison.
x, y = "functional_effect", "mean_func_score"
x_axis_title = "Functional effect"
y_axis_title = ["Mean single mutation", "functional score"]

# Where to place the "r = ..." annotation, in data coordinates.
x_rvalue, y_rvalue = -4, 1.5

tooltip_list = ["site", "wildtype", "mutant", x, y]
tooltip_list += ["times_seen_func_score", "times_seen_func_effect"]

corr_chart = plot_corr_w_pearson_r(
    df=merged_func_scores_effects,
    x=x,
    y=y,
    x_axis_title=x_axis_title,
    y_axis_title=y_axis_title,
    x_rvalue=x_rvalue,
    y_rvalue=y_rvalue,
    tooltip_list=tooltip_list,
)
corr_chart.display()

In [None]:
# Write the per-sample correlation figure as SVG and as a 300-ppi PNG.
# NOTE(review): these paths climb two directories ("../../results") while the
# CSV loaders in this notebook read from "../../../results" — confirm which
# depth is intended.
for corr_figure_path, corr_save_options in (
    (
        "../../results/figures/library_correlations/LibA_and_B_func_score_vs_func_effect.svg",
        {},
    ),
    (
        "../../results/figures/library_correlations/LibA_and_B_func_score_vs_func_effect.png",
        {"ppi": 300},
    ),
):
    corr_chart.save(corr_figure_path, **corr_save_options)

### Now plot correlation between libraries A and B

In [None]:
# Average each mutation's functional effect and times_seen across a library's
# selections, keeping only mutations whose mean times_seen is at least 2.
func_effects_aggregated = (
    all_func_effects_filtered
    .groupby(["library", "site", "mutant", "wildtype"], as_index=False)
    .agg(
        mean_func_effect=("functional_effect", "mean"),
        times_seen=("times_seen", "mean"),
    )
    .query("times_seen >= 2")
)

# One table per library.
LibA_func_effects_aggregated = func_effects_aggregated.loc[
    func_effects_aggregated["library"] == "LibA"
].copy()
LibB_func_effects_aggregated = func_effects_aggregated.loc[
    func_effects_aggregated["library"] == "LibB"
].copy()

# Inner join: only mutations that passed the filter in both libraries remain.
merged_effects = (
    LibA_func_effects_aggregated.merge(
        LibB_func_effects_aggregated,
        how="inner",
        on=["site", "mutant", "wildtype"],
        suffixes=("_LibA", "_LibB"),
    )
    .drop(columns=["library_LibA", "library_LibB"])
)
display(merged_effects.head())

In [None]:
# Pearson r between the two libraries' mean functional effects.
libA_mean_effects = merged_effects["mean_func_effect_LibA"]
libB_mean_effects = merged_effects["mean_func_effect_LibB"]
slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(
    libA_mean_effects, libB_mean_effects
)
r_value = float(r_value)
print(f"r_value: {r_value:.2f}")

In [None]:
# Scatter of per-mutation mean functional effects: LibA on x, LibB on y.
library_tooltip = [
    "wildtype",
    "mutant",
    "site",
    "mean_func_effect_LibA",
    "mean_func_effect_LibB",
    "times_seen_LibA",
    "times_seen_LibB",
]
chart = (
    alt.Chart(merged_effects)
    .mark_circle(size=20, opacity=0.1, stroke="black", strokeWidth=0.5, color="gray")
    .encode(
        x=alt.X("mean_func_effect_LibA", title="Mean functional effect LibA"),
        y=alt.Y("mean_func_effect_LibB", title="Mean functional effect LibB"),
        tooltip=library_tooltip,
    )
    .properties(width=200, height=200)
)

# Single-row inline dataset that places the "r = ..." label in data coordinates.
r_label_data = {"values": [{"x": -4, "y": 0.75, "text": f"r = {r_value:.2f}"}]}
text = (
    alt.Chart(r_label_data)
    .mark_text(dx=10, dy=0, align="left")
    .encode(x=alt.X("x:Q"), y=alt.Y("y:Q"), text="text:N")
)
combined_chart = chart + text

display(combined_chart)

# Save as a 300-ppi PNG, then as SVG.
for library_figure_path, library_save_options in (
    (
        "../../results/figures/library_correlations/LibA_and_B_func_effects_correlation.png",
        {"ppi": 300},
    ),
    (
        "../../results/figures/library_correlations/LibA_and_B_func_effects_correlation.svg",
        {},
    ),
):
    combined_chart.save(library_figure_path, **library_save_options)