### Cell entry validation with luciferase assay
This notebook reads luciferase measurements from a CSV file, calculates their correlation with DMS cell-entry effects, and plots the comparison with Altair.

In [None]:
import re
import altair as alt
import numpy as np
import pandas as pd
import scipy.stats
import httpimport

# allow more rows for Altair: lift the data transformer's max-rows limit so the
# DataFrames below can be embedded in chart specs without a row-count error.
# The return value is bound to `_`, presumably so the notebook does not echo it.
_ = alt.data_transformers.disable_max_rows()

In [None]:
# Import custom altair theme from remote github using httpimport module
def import_theme_new():
    """Register and enable the ``custom_theme`` Altair theme.

    The theme module ``main_theme`` is imported over HTTP from the ``main``
    branch of the ``bblarsen-sci/altair_themes`` GitHub repository (via
    ``httpimport``), and its ``main_theme()`` function supplies the config.
    """
    # NOTE(review): this executes remote code from an unpinned branch; pinning a
    # commit instead of "main" would make the figure styling reproducible.
    with httpimport.github_repo("bblarsen-sci", "altair_themes", "main"):
        import main_theme as remote_theme

        def custom_theme():
            return remote_theme.main_theme()

        alt.theme.register("custom_theme", enable=True)(custom_theme)


import_theme_new()

### Load experimental luciferase data and DMS functional effects

In [None]:
# Experimental luciferase readings used to validate the DMS measurements.
validation_df = pd.read_csv(snakemake.input.validation_df)

# DMS cell-entry effects, with a combined mutation label (wildtype + site + mutant,
# e.g. "Q530F") so the table can be joined to the validation data by mutation.
entry_dms_df = pd.read_csv(snakemake.input.entry_df).assign(
    mutation=lambda dms: dms["wildtype"] + dms["site"].astype(str) + dms["mutant"]
)

In [None]:
# Per-mutation mean and standard deviation of the raw luciferase readings
# (rounded to one decimal), plus mean ± std bounds for error bars.
luciferase_by_mutation = validation_df.groupby("mutation")["mean_luciferase"]
avg_RLU_df = (
    luciferase_by_mutation.agg(mean_RLU="mean", std_RLU="std")
    .round(1)
    .reset_index()
    .assign(
        upper=lambda summary: summary["mean_RLU"] + summary["std_RLU"],
        lower=lambda summary: summary["mean_RLU"] - summary["std_RLU"],
    )
)
display(avg_RLU_df)

In [None]:
# Express each replicate's luciferase reading relative to the mean Unmutated RLU
# (taken from avg_RLU_df, i.e. after its one-decimal rounding).
unmutated_mask = avg_RLU_df["mutation"] == "Unmutated"
if not unmutated_mask.any():
    # Fail with a clear message rather than an opaque IndexError from .values[0].
    raise ValueError(
        "No 'Unmutated' row in the validation data; cannot compute relative RLUs."
    )
unmutated_value = avg_RLU_df.loc[unmutated_mask, "mean_RLU"].values[0]
print(f'unmutated_value: {unmutated_value}')

validation_df["relative_luciferase"] = (
    validation_df["mean_luciferase"] / unmutated_value
).round(3)

# Each validated mutation must match at most one DMS row; otherwise the left merge
# silently duplicates replicate rows and distorts the per-mutation mean/std later.
dms_duplicates = entry_dms_df["mutation"].duplicated(keep=False) & entry_dms_df[
    "mutation"
].isin(validation_df["mutation"])
if dms_duplicates.any():
    raise ValueError(
        "Entry DMS data has multiple rows for validated mutation(s): "
        f"{sorted(entry_dms_df.loc[dms_duplicates, 'mutation'].unique())}"
    )

# merge with entry DMS data; the left join keeps validation rows with no DMS match
# (presumably including Unmutated) with NaN in the DMS columns.
avg_merged_df = validation_df.merge(entry_dms_df, on=["mutation"], how="left")

# log2 of relative luciferase, clipped at -4 (1/16 of Unmutated); a relative value
# of 0 gives -inf from log2, which the clip also maps to -4.
avg_merged_df["log2_RLU"] = np.log2(avg_merged_df["relative_luciferase"]).clip(lower=-4)
display(avg_merged_df)

In [None]:
# One row per mutation: its DMS effect plus mean/std (rounded to 3 decimals) of
# log2 relative luciferase across replicates.
summary_spec = {
    "effect": ("effect", "first"),
    "mean_log2_RLU": ("log2_RLU", "mean"),
    "std_log2_RLU": ("log2_RLU", "std"),
}
avg_merged_df_summary = (
    avg_merged_df.groupby("mutation").agg(**summary_spec).round(3).reset_index()
)

# Error-bar bounds: mean ± one standard deviation.
avg_merged_df_summary = avg_merged_df_summary.assign(
    upper=lambda summary: summary["mean_log2_RLU"] + summary["std_log2_RLU"],
    lower=lambda summary: summary["mean_log2_RLU"] - summary["std_log2_RLU"],
)

# Pin the Unmutated control to the origin with no error bar.
# NOTE(review): this also discards the control's own replicate spread.
avg_merged_df_summary.loc[
    avg_merged_df_summary["mutation"] == "Unmutated",
    ["effect", "mean_log2_RLU", "std_log2_RLU", "upper", "lower"],
] = 0
display(avg_merged_df_summary)


In [None]:
##### calculate R value:
# Pearson r (via linear regression) between DMS cell-entry effect and mean log2
# relative RLU. Mutations with no DMS effect (NaN after the left merge) would turn
# r into NaN, and they are not drawn on the scatter either, so exclude and report them.
has_values = avg_merged_df_summary[["effect", "mean_log2_RLU"]].notna().all(axis=1)
if not has_values.all():
    excluded = avg_merged_df_summary.loc[~has_values, "mutation"].tolist()
    print(f"Excluded from correlation (missing effect or RLU): {excluded}")
regression_df = avg_merged_df_summary[has_values]

slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(
    regression_df["effect"], regression_df["mean_log2_RLU"]
)
r_value = float(r_value)
print(f'r_value: {r_value:.2f}')

In [None]:
def custom_sort_order(array):
    """Return mutation labels ordered for the legend.

    Labels are sorted by the first run of digits they contain (the site, e.g.
    530 in 'Q530F'); labels with no digits sort as site 0. The sort is stable,
    so labels at the same site keep their input order. If 'Unmutated' is
    present, it is moved to the front.

    Returns a new list; the input is not modified.
    """

    def site_of(label):
        found = re.search(r"\d+", label)
        return 0 if found is None else int(found.group())

    ordered = sorted(array, key=site_of)
    if "Unmutated" not in ordered:
        return ordered
    ordered.remove("Unmutated")
    return ["Unmutated", *ordered]


# Ten-color categorical palette listed explicitly (muted Tableau-style hues; note
# these are not Vega's "category10" values despite the variable name).
category10_colors = [
    "#5778a4", "#e49444", "#d1615d", "#85b6b2", "#6a9f58",
    "#e7ca60", "#a87c9f", "#f1a2a9", "#967662", "#b8b0ac",
]

# Black for the first legend entry (Unmutated, which custom_sort_order puts
# first), then one palette color for each remaining mutation.
n_mutant_colors = len(avg_merged_df_summary["mutation"].unique()) - 1
colors = ["black", *category10_colors[:n_mutant_colors]]

In [None]:
# Legend order ('Unmutated' first, then by site) and the y-axis title, shared by
# the error-bar and scatter layers so both color and label mutations identically.
legend_order = custom_sort_order(avg_merged_df_summary["mutation"].unique())
y_axis_title = ["Relative RLUs", "Compared to Unmutated F"]

# Error bars spanning lower..upper (mean log2 RLU ± std) at each DMS effect.
error = (
    alt.Chart(avg_merged_df_summary)
    .mark_errorbar()
    .encode(
        x=alt.X("effect:Q"),
        y=alt.Y("lower", title=y_axis_title),
        y2="upper",
        color=alt.Color(
            "mutation",
            scale=alt.Scale(domain=legend_order, range=colors),
        ),
    )
)

# Scatter of mean log2 relative RLU against DMS cell-entry effect (x fixed to -4..1).
corr_chart_relative = (
    alt.Chart(avg_merged_df_summary)
    .mark_circle(size=100, stroke="black", strokeWidth=1, opacity=1)
    .encode(
        x=alt.X(
            "effect:Q",
            title="Cell Entry in DMS",
            scale=alt.Scale(domain=[-4, 1]),
            axis=alt.Axis(values=list(range(-4, 2)), tickCount=6),
        ),
        y=alt.Y("mean_log2_RLU", title=y_axis_title),
        color=alt.Color(
            "mutation",
            title="Virus",
            scale=alt.Scale(domain=legend_order, range=colors),
        ),
        tooltip=["mutation", "effect"],
    )
)

# "r = ..." annotation anchored at (min effect, max mean log2 RLU), each truncated
# toward zero by int(), then nudged with dx/dy.
min_effect = int(avg_merged_df_summary["effect"].min())
max_mean_luciferase = int(avg_merged_df_summary["mean_log2_RLU"].max())
r_label = {"x": min_effect, "y": max_mean_luciferase, "text": f"r = {r_value:.2f}"}
text = (
    alt.Chart({"values": [r_label]})
    .mark_text(align="left", baseline="top", dx=-20, dy=-25)
    .encode(x=alt.X("x:Q"), y=alt.Y("y:Q"), text="text:N")
)

# Layer in this order: error bars, then points, then the r label.
final_chart_relative = (error + corr_chart_relative + text).properties(
    height=200, width=200
)
final_chart_relative.display()


In [None]:
# Save the layered chart twice: a 300-ppi PNG, then the second snakemake output
# (its format presumably chosen by Altair from the file extension).
png_output = snakemake.output.cell_entry_validation_correlation_plot_png
final_chart_relative.save(png_output, ppi=300)

other_output = snakemake.output.cell_entry_validation_correlation_plot
final_chart_relative.save(other_output)
