# library_correlations
This notebook calculates summary statistics about the functional selections and the variants present in each library:
- The median of all LibA selections vs the median of all LibB selections
- All selections for a specific condition
- Histogram of variants by # of mutations
- Distribution of functional scores

- written by Brendan Larsen

In [None]:
# this cell is tagged as parameters for `papermill` parameterization
# Every value defaults to None. `histogram_plot` doubles as the sentinel that
# later cells test to decide whether to fill in default inputs and whether to
# save figures.

# Path to a Python file whose contents are exec'd when the configs are read
altair_config = None
# Path to the YAML file loaded into `config`
nipah_config = None

# Path to the codon-variant CSV read with pandas
codon_variants_file = None

# Output path for the LibA-vs-LibB median correlation chart
CHO_corr_plot_save = None
# Output paths for the per-selection correlation grids
CHO_EFNB2_indiv_plot_save = None
CHO_EFNB3_indiv_plot_save = None

# Output paths for the mutation-count histogram, the functional score
# distribution plot, and the per-library barcode/coverage CSV
histogram_plot = None
func_scores_plot = None
uniq_barcodes_per_lib_df = None

In [None]:
import math
import os
import re
import altair as alt

import numpy as np

import pandas as pd

import scipy.stats

import Bio.SeqIO
import yaml
from Bio import AlignIO
from Bio import PDB
from Bio.Align import PairwiseAligner
from collections import Counter

In [None]:
# allow more rows for Altair
_ = alt.data_transformers.disable_max_rows()

# Project root that every relative path in this notebook is resolved against.
project_dir = (
    "/fh/fast/bloom_j/computational_notebooks/blarsen/2023/Nipah_Malaysia_RBP_DMS/"
)

# os.getcwd() does not return a trailing slash, so comparing it to the raw
# string above could never match; normalize both sides before comparing.
if os.path.normpath(os.getcwd()) == os.path.normpath(project_dir):
    print("Already in correct directory")
else:
    os.chdir(project_dir)
    print("Setup in correct directory")

For running interactively:

In [None]:
# `histogram_plot` is still None when no parameters were supplied, so fall
# back to default input paths. The output-path parameters stay None; the save
# calls later in the notebook are guarded by `histogram_plot is not None`.
if histogram_plot is None:
    altair_config = "data/custom_analyses_data/theme.py"
    nipah_config = "nipah_config.yaml"
    codon_variants_file = "results/variants/codon_variants.csv"


Read in config files

In [None]:
# NOTE(review): exec() runs the entire theme file as code in this namespace;
# only point `altair_config` at a trusted file.
if altair_config:
    with open(altair_config, "r") as file:
        exec(file.read())

# Project configuration; `func_times_seen_cutoff` is read from it later
with open(nipah_config) as f:
    config = yaml.safe_load(f)

# Functional-effects configuration; its `avg_func_effects` section lists the
# selections that belong to each condition
with open("data/func_effects_config.yml") as y:
    config_func = yaml.safe_load(y)

In [None]:
def _func_effect_files(selections, library, other_library):
    """Return func-effects CSV names for selections naming only `library`.

    A selection is kept when `library` occurs in its name and `other_library`
    does not.
    """
    return [
        selection + "_func_effects.csv"
        for selection in selections
        if library in selection and other_library not in selection
    ]


avg_func_effects = config_func["avg_func_effects"]

# CHO-EFNB2 selections, split by library
cho_efnb2_low_selections = avg_func_effects["CHO_EFNB2_low"]["selections"]
LibA_CHO_EFNB2 = _func_effect_files(cho_efnb2_low_selections, "LibA", "LibB")
LibB_CHO_EFNB2 = _func_effect_files(cho_efnb2_low_selections, "LibB", "LibA")

# CHO-EFNB3 selections, split by library
cho_efnb3_low_selections = avg_func_effects["CHO_EFNB3_low"]["selections"]
LibA_CHO_EFNB3 = _func_effect_files(cho_efnb3_low_selections, "LibA", "LibB")
LibB_CHO_EFNB3 = _func_effect_files(cho_efnb3_low_selections, "LibB", "LibA")

### Calculate correlations for LibA and LibB for CHO-EFNB2 cell entry selections

In [None]:
# Directory holding the per-selection functional-effect CSV files
path = "results/func_effects/by_selection/"


def process_func_selections(library, library_name):
    """Merge the func-effects CSVs of one library and take per-mutation medians.

    Parameters
    ----------
    library : list of str
        CSV file names, joined onto the module-level `path`.
    library_name : str
        Suffix used in the names of the two median columns that are added.

    Returns
    -------
    pandas.DataFrame
        The outer merge of all files on site/mutant/wildtype, with
        `median_effect_<library_name>` and `median_times_seen_<library_name>`
        added and every column whose name contains `_<digits>` removed.
    """
    merge_keys = ["site", "mutant", "wildtype"]

    # Load each file, numbering its value columns so they stay distinct
    frames = []
    for file_number, file_name in enumerate(library, start=1):
        frame = pd.read_csv(os.path.join(path, file_name))
        frame = frame.rename(
            columns={
                "functional_effect": f"functional_effect_{file_number}",
                "times_seen": f"times_seen_{file_number}",
            }
        )
        frames.append(frame.drop(columns=["latent_phenotype_effect"]))

    # Outer merge keeps mutations that are missing from some of the files
    combined = frames[0]
    for frame in frames[1:]:
        combined = combined.merge(frame, on=merge_keys, how="outer")

    # Row-wise medians across the numbered columns
    effect_columns = [name for name in combined.columns if "functional_effect" in name]
    combined[f"median_effect_{library_name}"] = combined[effect_columns].median(axis=1)
    seen_columns = [name for name in combined.columns if "times_seen" in name]
    combined[f"median_times_seen_{library_name}"] = combined[seen_columns].median(
        axis=1
    )

    # Remove the numbered per-file columns now that the medians exist
    numbered_columns = [name for name in combined.columns if re.search(r"_\d+", name)]
    return combined.drop(columns=numbered_columns)


# Process selections for two libraries and two cell types
# (E2 = CHO-EFNB2 selections, E3 = CHO-EFNB3 selections). Each result has one
# row per site/wildtype/mutant plus that library's two median columns.
A_selections_E2 = process_func_selections(LibA_CHO_EFNB2, "LibA")
B_selections_E2 = process_func_selections(LibB_CHO_EFNB2, "LibB")
A_selections_E3 = process_func_selections(LibA_CHO_EFNB3, "LibA")
B_selections_E3 = process_func_selections(LibB_CHO_EFNB3, "LibB")


# Function to merge selections from two libraries
def merge_selections(A_selections, B_selections):
    """Inner-join two per-library tables and add a combined `times_seen`.

    The join is on wildtype/site/mutant, so only mutations present in both
    inputs are kept. `times_seen` is the row-wise median of every column whose
    name contains "times_seen".
    """
    join_keys = ["wildtype", "site", "mutant"]
    shared = A_selections.merge(B_selections, on=join_keys, how="inner")

    seen_columns = [name for name in shared.columns if "times_seen" in name]
    shared["times_seen"] = shared[seen_columns].median(axis=1)
    return shared


# Merge selections and add cell type information
# (only mutations present in both LibA and LibB survive the inner join)
CHO_EFNB2_merged = merge_selections(A_selections_E2, B_selections_E2)
CHO_EFNB2_merged["cell_type"] = "CHO-EFNB2"
CHO_EFNB3_merged = merge_selections(A_selections_E3, B_selections_E3)
CHO_EFNB3_merged["cell_type"] = "CHO-EFNB3"

# Concatenate merged selections for both cell types into one long table;
# the `cell_type` column tells the rows apart
both_selections = pd.concat([CHO_EFNB2_merged, CHO_EFNB3_merged])


# Function to generate and display a scatter plot comparing median effects from two libraries
def make_chart_median(df, title):
    """Build side-by-side LibA-vs-LibB scatter plots, one per cell type.

    For each of "CHO-EFNB2" and "CHO-EFNB3" the cell type and the Pearson r
    between the two library medians are printed, and a scatter plot is added
    to a horizontally concatenated chart. A slider hides points whose
    `times_seen` is below the chosen value, and hovering highlights a point.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain `median_effect_LibA`, `median_effect_LibB`, `times_seen`,
        `cell_type`, `site`, `wildtype` and `mutant` columns.
    title : str
        Not used; kept so existing calls keep working.

    Returns
    -------
    The horizontally concatenated Altair chart (it is not displayed here).
    """
    slider = alt.binding_range(min=1, max=25, step=1, name="times_seen")
    selector = alt.param(name="SelectorName", value=1, bind=slider)
    variant_selector = alt.selection_point(
        on="mouseover", empty=False, nearest=True, fields=["site", "mutant"], value=1
    )

    # Only mutations measured in both libraries can be compared
    df = df[df["median_effect_LibA"].notna() & df["median_effect_LibB"].notna()]

    charts = []
    for selection in ["CHO-EFNB2", "CHO-EFNB3"]:
        print(selection)
        tmp_df = df[df["cell_type"] == selection]

        # Correlate this cell type's rows only. Using the full `df` here would
        # pool both cell types and print the same r for each of them.
        if len(tmp_df) >= 2:
            regression = scipy.stats.linregress(
                tmp_df["median_effect_LibA"], tmp_df["median_effect_LibB"]
            )
            r_value = float(regression.rvalue)
        else:
            # linregress cannot be fitted to fewer than two points
            r_value = float("nan")
        print(f"{r_value:.2f}")

        chart = (
            alt.Chart(tmp_df, title=f"Entry in {selection} cells")
            .mark_point()
            .encode(
                x=alt.X("median_effect_LibA", title="LibA Cell Entry"),
                y=alt.Y("median_effect_LibB", title="LibB Cell Entry"),
                tooltip=["site", "wildtype", "mutant", "times_seen"],
                size=alt.condition(variant_selector, alt.value(100), alt.value(15)),
                color=alt.condition(
                    alt.datum.times_seen < selector,
                    alt.value("transparent"),
                    alt.value("black"),
                ),
                opacity=alt.condition(variant_selector, alt.value(1), alt.value(0.2)),
            )
        )
        charts.append(chart)

    combined_effect_chart = (
        alt.hconcat(*charts)
        .resolve_scale(y="shared", x="shared", color="independent")
        .add_params(variant_selector, selector)
    )
    return combined_effect_chart


# Despite the "CHO_EFNB2" in its name, this chart holds the panels for both
# cell types; the title argument is not used by make_chart_median.
CHO_EFNB2_corr_plot = make_chart_median(both_selections, "CHO-EFNB2")
CHO_EFNB2_corr_plot.display()
# Only save when output paths were supplied (`histogram_plot` is the sentinel)
if histogram_plot is not None:
    CHO_EFNB2_corr_plot.save(CHO_corr_plot_save)

In [None]:
def plot_corr_heatmap(df):
    """Build binned 2D heatmaps of LibA vs LibB median effects per cell type.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain `cell_type`, `median_effect_LibA` and
        `median_effect_LibB` columns.

    Returns
    -------
    A horizontally concatenated Altair chart with one heatmap per distinct
    `cell_type`, sharing x, y and color scales.
    """
    charts = []
    for cell in df["cell_type"].unique():
        tmp_df = df[df["cell_type"] == cell]
        chart = (
            alt.Chart(tmp_df, title=f"{cell}")
            .mark_rect()
            .encode(
                x=alt.X("median_effect_LibA", title="Library A").bin(maxbins=20),
                y=alt.Y("median_effect_LibB", title="Library B").bin(maxbins=20),
                color=alt.Color("count()", title="Count").scale(scheme="greenblue"),
            )
        )
        charts.append(chart)

    # The axes compare the two libraries, so the title says so (the previous
    # "binding and entry" title did not match the plotted columns).
    combined_chart = alt.hconcat(
        *charts, title=alt.Title("Correlation between LibA and LibB cell entry")
    ).resolve_scale(y="shared", x="shared", color="shared")
    return combined_chart


# Heatmaps of LibA vs LibB median effects for every cell type
entry_binding_corr_heatmap = plot_corr_heatmap(both_selections)
entry_binding_corr_heatmap.display()
# Saving is disabled. NOTE(review): the call below passes the chart itself
# where an output path is expected, so it needs a path before being re-enabled.
# entry_binding_corr_heatmap.save(entry_binding_corr_heatmap)

### Make correlations between individual selections

In [None]:
def process_individ_selections(library):
    """Merge every selection's func-effects CSV, keeping per-file columns.

    Parameters
    ----------
    library : list of str
        CSV file names, joined onto the module-level `path`.

    Returns
    -------
    merged_df : pandas.DataFrame
        Outer merge on site/mutant/wildtype; file k contributes
        `functional_effect_k` and `times_seen_k` (k counts from 1).
    number_list : list of int
        The numbers 1..len(library), used later to address those columns.
    """
    merge_keys = ["site", "mutant", "wildtype"]

    frames = []
    for file_number, file_name in enumerate(library, start=1):
        file_path = os.path.join(path, file_name)
        print(f"Processing file: {file_path}")
        # Number the value columns so each file keeps its own pair
        renamed = pd.read_csv(file_path).rename(
            columns={
                "functional_effect": f"functional_effect_{file_number}",
                "times_seen": f"times_seen_{file_number}",
            }
        )
        frames.append(renamed.drop(columns=["latent_phenotype_effect"]))

    # Outer merge keeps mutations that are missing from some selections
    merged_df = frames[0]
    for frame in frames[1:]:
        merged_df = merged_df.merge(frame, on=merge_keys, how="outer")

    number_list = list(range(1, len(library) + 1))
    return merged_df, number_list


# Pool the LibA and LibB selections of each cell type so every selection can
# be compared against every other. Note that `lib_size_*` holds the list of
# selection numbers (1..n) returned by the function, not a count.
CHO_EFNB2_indiv, lib_size_EFNB2 = process_individ_selections(
    LibA_CHO_EFNB2 + LibB_CHO_EFNB2
)
CHO_EFNB3_indiv, lib_size_EFNB3 = process_individ_selections(
    LibA_CHO_EFNB3 + LibB_CHO_EFNB3
)


def make_chart(df, number_list):
    """Build an all-vs-all grid of scatter plots of per-selection effects.

    Panel (i, j) plots `functional_effect_i` against `functional_effect_j`
    for the rows where both selections have a functional effect and both
    `times_seen` values reach `config["func_times_seen_cutoff"]`.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain `site`, `wildtype`, `mutant`, and `functional_effect_k` /
        `times_seen_k` columns for every k in `number_list`.
    number_list : list of int
        Selection numbers to compare.

    Returns
    -------
    A vertically concatenated Altair chart with one row of panels per
    selection.
    """
    cutoff = config["func_times_seen_cutoff"]

    rows = []
    for i in number_list:
        row_charts = []
        for j in number_list:
            # Filter into a fresh frame for this pair. Reassigning `df` here
            # would carry each pair's filter into every later panel, leaving
            # them with only the rows that passed all earlier pairs.
            pair_df = df[
                (df[f"times_seen_{i}"] >= cutoff)
                & (df[f"times_seen_{j}"] >= cutoff)
                & (df[f"functional_effect_{i}"].notna())
                & (df[f"functional_effect_{j}"].notna())
            ]

            chart = (
                alt.Chart(pair_df)
                .mark_circle(size=10, color="black", opacity=0.2)
                .encode(
                    x=alt.X(f"functional_effect_{i}"),
                    y=alt.Y(f"functional_effect_{j}"),
                    tooltip=["site", "wildtype", "mutant"],
                )
                .properties(height=alt.Step(10), width=alt.Step(10))
            )
            row_charts.append(chart)
        combined_effect_chart = alt.hconcat(*row_charts).resolve_scale(
            y="shared", x="shared", color="independent"
        )
        rows.append(combined_effect_chart)
    final_combined_chart = alt.vconcat(*rows).resolve_scale(
        y="shared", x="shared", color="independent"
    )
    return final_combined_chart


# Build the all-vs-all grid for each cell type. The grids are not displayed
# here; they are only saved, and only when output paths were supplied
# (`histogram_plot` is the sentinel).
CHO_EFNB2_indiv_plot = make_chart(CHO_EFNB2_indiv, lib_size_EFNB2)
if histogram_plot is not None:
    CHO_EFNB2_indiv_plot.save(CHO_EFNB2_indiv_plot_save)
CHO_EFNB3_indiv_plot = make_chart(CHO_EFNB3_indiv, lib_size_EFNB3)
if histogram_plot is not None:
    CHO_EFNB3_indiv_plot.save(CHO_EFNB3_indiv_plot_save)

### Make histogram of variants

In [None]:
# Load the codon-variant table.
# NOTE(review): `display` is not imported in this file; presumably it is the
# IPython/Jupyter builtin — confirm if this is ever run as a plain script.
codon_variants = pd.read_csv(codon_variants_file)
display(codon_variants.head(3))
# Number of distinct barcodes in each library (a Series indexed by library)
unique_barcodes_per_library = codon_variants.groupby("library")["barcode"].nunique()

# Promote to a DataFrame so a coverage column can be added to it later
uniq_barcodes_per_lib = pd.DataFrame(unique_barcodes_per_library)
display(uniq_barcodes_per_lib)

### Find which sites are present, and which are missing

In [None]:
# Module-level tallies filled in by process_cell:
#   aa_counts maps "<site><mutant>" to the number of times it was observed
#   wildtypes maps "<site>" to the first wildtype amino acid seen there
aa_counts = {}
wildtypes = {}


def process_cell(cell):
    """Tally the amino-acid substitutions listed in one table cell.

    `cell` is a whitespace-separated string of substitutions such as "A12G";
    missing values are ignored. Substitutions whose last character is "*" or
    "-", or whose first character is "*", are skipped. Results are
    accumulated in the module-level `aa_counts` and `wildtypes` dicts, and
    nothing is returned.
    """
    if pd.isna(cell):
        return

    for substitution in cell.split():
        wildtype = substitution[0]
        site = substitution[1:-1]
        mutation = substitution[-1]

        if mutation in ("*", "-") or wildtype == "*":
            continue

        # Keep the first wildtype recorded for a site
        wildtypes.setdefault(site, wildtype)
        key = site + mutation
        aa_counts[key] = aa_counts.get(key, 0) + 1


empty_mutants = []
empty_percent = []

# The expected sites and the 20 amino acids do not depend on the library,
# so build them once instead of on every loop iteration.
expected_sites = range(1, 533)
possible_mutations = list("ACDEFGHIKLMNPQRSTVWY")

for lib in ["LibA", "LibB"]:
    # Start every library from an empty tally. Without this, mutations counted
    # for an earlier library would also be credited to this one.
    # NOTE(review): `wildtypes` is deliberately left to accumulate, on the
    # assumption that a site's wildtype residue is the same in both libraries.
    aa_counts.clear()

    # Apply the function to each cell in the 'aa_substitutions' column
    tmp_df = codon_variants[codon_variants["library"] == lib]
    tmp_df["aa_substitutions"].apply(process_cell)

    # Every non-wildtype amino acid at every expected site whose wildtype
    # residue is known
    expected_combinations = {
        str(site) + mutation
        for site in expected_sites
        if str(site) in wildtypes
        for mutation in possible_mutations
        if mutation != wildtypes[str(site)]
    }

    # Extract the actual combinations from the counts
    actual_combinations = set(aa_counts.keys())

    # Find missing combinations
    missing_combinations = expected_combinations - actual_combinations

    # Display results
    print(f"Number of unique site-mutation combinations observed: {len(aa_counts)}")
    print(
        f"Number of missing combinations (excluding wildtypes): {len(missing_combinations)}"
    )
    print(
        f"Total possible combinations excluding wildtypes: {len(expected_combinations)}"
    )
    if expected_combinations:
        empty_percent.append(len(actual_combinations) / len(expected_combinations))
    else:
        # Nothing was expected (no wildtype recorded); avoid dividing by zero
        empty_percent.append(float("nan"))

# The fractions are assigned by position, so the rows of
# `uniq_barcodes_per_lib` must be in the same LibA, LibB order as the loop.
uniq_barcodes_per_lib["percent"] = empty_percent
uniq_barcodes_per_lib["percent"] = uniq_barcodes_per_lib["percent"] * 100
uniq_barcodes_per_lib = uniq_barcodes_per_lib.round(2)
uniq_barcodes_per_lib = uniq_barcodes_per_lib.reset_index()
# Only write the CSV when an output path was supplied; with a path of None,
# to_csv returns the CSV text instead of writing a file.
if uniq_barcodes_per_lib_df is not None:
    uniq_barcodes_per_lib.to_csv(uniq_barcodes_per_lib_df, index=False)

In [None]:
def calculate_fraction(library):
    """Print the share of a library's variants with 0-4 amino-acid substitutions.

    Reads the module-level `codon_variants` table and prints one line for
    each substitution count from 0 to 4; nothing is returned.
    """
    in_library = codon_variants["library"] == library
    total_A = codon_variants[in_library].shape[0]

    for number in range(5):
        has_n_substitutions = codon_variants["n_aa_substitutions"] == number
        fraction = codon_variants[in_library & has_n_substitutions].shape[0]
        print(
            f"For {library}, the fraction of sites with {number} mutations relative to WT are: {fraction/total_A:.2f}"
        )


# Print the substitution-count breakdown for each library
calculate_fraction("LibA")
calculate_fraction("LibB")

In [None]:
def plot_histogram(df):
    """Build a bar chart of variant counts by number of AA substitutions.

    Columns that the chart does not use are dropped first, and the bars are
    split into one column of the chart per `library`.
    """
    unused_columns = [
        "barcode",
        "target",
        "variant_call_support",
        "codon_substitutions",
        "aa_substitutions",
        "n_codon_substitutions",
    ]
    plot_df = df.drop(columns=unused_columns)

    x_encoding = alt.X("n_aa_substitutions:N", title="# of AA Substitutions")
    # count() tallies the rows that fall in each x category
    y_encoding = alt.Y("count()", title="Count", axis=alt.Axis(grid=True))
    library_facet = alt.Column(
        "library", header=alt.Header(title=None, labelFontSize=18)
    )

    return (
        alt.Chart(plot_df)
        .mark_bar(color="black")
        .encode(x_encoding, y_encoding, column=library_facet)
    )


# Histogram of amino-acid substitutions per variant, one panel per library
histogram = plot_histogram(codon_variants)
histogram.display()
# Only save when an output path was supplied
if histogram_plot is not None:
    histogram.save(histogram_plot)

### Find distribution of functional scores

In [None]:
def pull_in_func_scores(df):
    """Load and stack the functional-score CSVs of several selections.

    Parameters
    ----------
    df : iterable of str
        Selection names (not a DataFrame, despite the parameter name). Each
        is read from `results/func_scores/<name>_func_scores.csv`.

    Returns
    -------
    pandas.DataFrame
        Rows of every file stacked together, with a `selection` column
        recording which selection each row came from.

    Raises
    ------
    ValueError
        If no selection names are given.
    """
    frames = []
    for selection in df:
        tmp_df = pd.read_csv(f"results/func_scores/{selection}_func_scores.csv")
        tmp_df["selection"] = selection
        frames.append(tmp_df)

    if not frames:
        raise ValueError("pull_in_func_scores needs at least one selection")

    # Concatenate after the loop has finished. Returning from inside the loop
    # would hand back the first selection only.
    return pd.concat(frames)


# Load the functional scores for each condition and tag them with the cell type
e2_func_scores_df = pull_in_func_scores(cho_efnb2_low_selections)
e2_func_scores_df["cell_type"] = "CHO-EFNB2"
e3_func_scores_df = pull_in_func_scores(cho_efnb3_low_selections)
e3_func_scores_df["cell_type"] = "CHO-EFNB3"

# Make combined dataframe of cell entry data
merged_func_scores = pd.concat([e2_func_scores_df, e3_func_scores_df])


def classify_mutation(row):
    """Return the mutation class label for one variant row.

    `row` is indexed by "aa_substitutions", "n_aa_substitutions" and
    "n_codon_substitutions". A "*" in the substitution string takes
    precedence and gives "stop". Otherwise the label follows the number of
    amino-acid substitutions, with 0 split into "synonymous" (at least one
    codon substitution) and "wildtype". None is returned when no rule matches.
    """
    substitutions = row["aa_substitutions"]
    if isinstance(substitutions, str) and "*" in substitutions:
        return "stop"

    n_aa = row["n_aa_substitutions"]
    if n_aa == 0:
        return "synonymous" if row["n_codon_substitutions"] >= 1 else "wildtype"
    if n_aa == 1:
        return "1 nonsynonymous"
    if n_aa >= 2:
        # The label reads ">2" but is applied to two or more substitutions
        return ">2 nonsynonymous"
    return None


# Apply the function to each row in the dataframe to create the new column
merged_func_scores["mutation_class"] = merged_func_scores.apply(
    classify_mutation, axis=1
)

# One row per (barcode, cell_type): the median func_score over that group's
# rows, and the mutation class of the group's first row
result_df = (
    merged_func_scores.groupby(["barcode", "cell_type"])
    .agg(
        func_score=("func_score", "median"), mutation_class=("mutation_class", "first")
    )
    .reset_index()
)

# Median of those per-barcode scores within each (mutation_class, cell_type)
tmp = (
    result_df.groupby(["mutation_class", "cell_type"])["func_score"]
    .median()
    .reset_index()
)
tmp = tmp.rename(columns={"func_score": "median_func_score"})

# Attach the class-level median to every barcode row
result_df = result_df.merge(tmp, on=["mutation_class", "cell_type"], how="left")
display(result_df)

In [None]:
def plot_func_score_distribution(df):
    """Build a grid of functional-score distributions.

    The grid has one column per cell type ("CHO-EFNB2", "CHO-EFNB3") and one
    row per mutation class. Each panel is an area chart of binned
    `func_score` counts, colored by `median_func_score`.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain `cell_type`, `mutation_class`, `func_score` and
        `median_func_score` columns.

    Returns
    -------
    The configured, horizontally concatenated Altair chart.
    """
    mutation_classes = [
        "wildtype",
        "synonymous",
        "1 nonsynonymous",
        ">2 nonsynonymous",
        "stop",
    ]
    last_row = len(mutation_classes) - 1

    columns = []
    for column_idx, cell_type in enumerate(["CHO-EFNB2", "CHO-EFNB3"]):
        cell_df = df[df["cell_type"] == cell_type]
        # Row titles are only drawn on the first column of panels
        show_row_titles = column_idx == 0

        panels = []
        for row_idx, mutation_class in enumerate(mutation_classes):
            # Only the bottom panel gets x-axis labels and an x-axis title
            show_x_labels = row_idx == last_row
            x_axis = alt.Axis(
                labelAngle=-90,
                titleFontSize=10,
                tickCount=3,
                values=[-10, -5, 0],
                title="Functional Score" if show_x_labels else None,
                labels=show_x_labels,
            )
            y_axis = alt.Axis(
                labelAngle=0,
                titleAngle=0,
                title=mutation_class if show_row_titles else None,
                domain=False,
                ticks=False,
                labels=False,
                titleX=-10,
                titleAlign="right",
            )

            panel_df = cell_df[cell_df["mutation_class"] == mutation_class]
            # Only the top panel of each column carries the cell-type title
            panel_title = cell_type if row_idx == 0 else ""
            panel = (
                alt.Chart(panel_df, title=panel_title)
                .mark_area(color="black")
                .encode(
                    x=alt.X("func_score", bin=alt.Bin(step=0.4), axis=x_axis),
                    y=alt.Y("count()", title=mutation_class, axis=y_axis),
                    color=alt.Color(
                        "median_func_score",
                        title="Median Functional Score",
                        scale=alt.Scale(scheme="greenblue"),
                    ),
                )
                .properties(width=100, height=50)
            )
            panels.append(panel)

        # Stack this cell type's panels vertically
        column_chart = alt.vconcat(*panels, spacing=0).resolve_scale(
            y="independent", x="shared", color="shared"
        )
        columns.append(column_chart)

    # Place the per-cell-type columns side by side
    grid = alt.hconcat(*columns, spacing=0).resolve_scale(
        y="independent", x="shared", color="shared"
    )
    combined_chart = (
        grid.configure_view(stroke=None)
        .configure_axis(grid=False)
        .configure_title(anchor="middle", offset=5, fontSize=16)
    )
    return combined_chart


# Functional score distributions per mutation class and cell type
tmp_img = plot_func_score_distribution(result_df)
tmp_img.display()
# Only save when an output path was supplied
if histogram_plot is not None:
    tmp_img.save(func_scores_plot)