### This notebook makes heatmaps of entry and binding data from filtered data

In [42]:
# this cell is tagged as parameters for `papermill` parameterization
# Papermill overrides these defaults with injected values; keep this cell to
# plain `name = value` lines so the parameters stay parseable.
# When nipah_config is still None, the fallback cell below fills in
# hard-coded input paths (it does not set the output image paths).
#input configs
altair_config = None  # path to an Altair theme .py file, exec'd when reading configs
nipah_config = None  # path to the YAML config; None triggers the hard-coded fallback

#input files
entropy_file = None  # CSV read by prepare_entropy (site, henipavirus_entropy)

func_scores_E2_file = None  # filtered EFNB2 entry scores
binding_E2_file = None  # filtered EFNB2 binding scores
e2_low_func_file = None  # EFNB2 binding mutants flagged for low cell entry

func_scores_E3_file = None  # filtered EFNB3 entry scores
binding_E3_file = None  # filtered EFNB3 binding scores
e3_low_func_file = None  # EFNB3 binding mutants flagged for low cell entry

#output images
# paths passed to Chart.save in the plotting cells
E2_binding_heatmap = None
E3_binding_heatmap = None
E2_entry_heatmap = None
E3_entry_heatmap = None
combined_entry_contact_heatmaps = None
combined_contact_binding_plot = None

In [43]:
import math
import os
import re
import altair as alt
import numpy as np
import pandas as pd
import scipy.stats
import yaml

In [44]:
# allow more rows for Altair
_ = alt.data_transformers.disable_max_rows()

# Project root that all relative paths in this notebook are resolved against.
# Set NIPAH_DMS_ROOT to run the notebook from a different checkout.
PROJECT_ROOT = os.environ.get(
    "NIPAH_DMS_ROOT",
    "/fh/fast/bloom_j/computational_notebooks/blarsen/2023/Nipah_Malaysia_RBP_DMS/",
)

# os.getcwd() never ends with a slash, so normalize both sides before comparing;
# a raw string comparison against the slash-terminated root can never match.
if os.path.normpath(os.getcwd()) == os.path.normpath(PROJECT_ROOT):
    print("Already in correct directory")
else:
    os.chdir(PROJECT_ROOT)
    print("Setup in correct directory")

Setup in correct directory


In [45]:
# Fallback for running the notebook by hand instead of through snakemake:
# nothing was injected into the parameters cell, so use fixed relative paths.
if nipah_config is None:
    print('loading hard paths')
    altair_config, nipah_config = "data/custom_analyses_data/theme.py", "nipah_config.yaml"
    entropy_file = 'results/entropy/entropy.csv'

    # EFNB2 tables: entry scores, binding scores, binding mutants with low entry
    func_scores_E2_file, binding_E2_file, e2_low_func_file = (
        "results/filtered_data/E2_entry_filtered.csv",
        "results/filtered_data/E2_binding_filtered.csv",
        "results/filtered_data/E2_binding_low_effect_filter.csv",
    )

    # EFNB3 tables, same layout as EFNB2
    func_scores_E3_file, binding_E3_file, e3_low_func_file = (
        "results/filtered_data/E3_entry_filtered.csv",
        "results/filtered_data/E3_binding_filtered.csv",
        "results/filtered_data/E3_binding_low_effect_filter.csv",
    )

loading hard paths


#### Read configs

In [46]:
# Optional Altair theme. The theme file is plain Python executed in this
# namespace, so only point altair_config at trusted, repository-local code.
if altair_config:
    with open(altair_config, 'r') as theme_file:
        exec(theme_file.read())

# Project-wide settings (colors, stroke width, contact sites, ...).
with open(nipah_config) as config_file:
    config = yaml.safe_load(config_file)

In [47]:
# Filtered per-mutant tables for each receptor: binding scores, entry scores,
# and the binding mutants flagged for low cell entry.
e2_binding, e2_func, e2_low_func = (
    pd.read_csv(path) for path in (binding_E2_file, func_scores_E2_file, e2_low_func_file)
)
e3_binding, e3_func, e3_low_func = (
    pd.read_csv(path) for path in (binding_E3_file, func_scores_E3_file, e3_low_func_file)
)

In [None]:
def prepare_entropy(entropy_path=None):
    """Load per-site henipavirus entropy and reshape it into heatmap rows.

    Parameters:
    - entropy_path (str, optional): CSV with 'site' and 'henipavirus_entropy'
      columns (calculated in a different notebook). Defaults to the
      notebook-level ``entropy_file``.

    Returns:
    DataFrame with columns 'site' (nullable Int64), 'value' (entropy rounded to
    2 decimals), 'mutant' ("entropy"), 'wildtype' ("") and 'type' ("entropy").
    Rows with a missing site are dropped, and duplicate (site, value) rows are
    collapsed.
    """
    path = entropy_file if entropy_path is None else entropy_path
    entropy = pd.read_csv(path)
    # Build the result as a chain of fresh frames. Assigning columns into a
    # column-slice of `entropy` can raise pandas' SettingWithCopyWarning.
    return (
        entropy[["site", "henipavirus_entropy"]]
        .dropna(subset=["site"])
        .astype({"site": "Int64"})
        .rename(columns={"henipavirus_entropy": "value"})
        .assign(value=lambda frame: frame["value"].round(2))
        .drop_duplicates()
        .assign(mutant="entropy", wildtype="", type="entropy")
    )

def make_contact(contact_sites=None):
    """Build the heatmap rows that mark contact sites.

    Parameters:
    - contact_sites (list, optional): Sites to mark. Defaults to
      ``config["contact_sites"]``.

    Returns:
    DataFrame with one row per site and columns 'site', 'value' (0.0),
    'mutant' ("contact"), 'wildtype' ("") and 'type' ("contact").
    """
    sites = config["contact_sites"] if contact_sites is None else list(contact_sites)
    contact_df = pd.DataFrame({"site": sites, "value": [0.0] * len(sites)})
    return contact_df.assign(mutant="contact", wildtype="", type="contact")

def make_empty_df(df, contact_df=None, entropy_df=None, contact_flag=None, entropy_flag=None,low_entry_df=None,binding_flag=None):
    """Build the long-format table behind one heatmap (called by plot_entry_heatmap).

    Every (site, amino acid) pair for sites 71-602 is left-joined to ``df``.
    Unmeasured mutants are therefore kept as rows with a missing value. The
    score column is then melted into 'type'/'value' columns.

    Parameters:
    - df (DataFrame): per-mutant scores with 'site', 'mutant', 'wildtype' and
      either 'effect' (entry) or 'binding_median' (when ``binding_flag`` is set).
    - contact_df, entropy_df (DataFrame, optional): annotation rows. Each is
      appended once when ``contact_flag`` / ``entropy_flag`` is exactly True.
    - low_entry_df (DataFrame, optional): mutants with 'site', 'mutant',
      'wildtype' and 'effect'. In binding mode they are added as
      ``type == "low_effect"`` rows.
    - binding_flag (bool, optional): melt 'binding_median' instead of 'effect'.

    Returns:
    DataFrame with columns 'site', 'mutant', 'wildtype', 'type', 'value' and a
    fresh RangeIndex.
    """
    import warnings

    sites = range(71, 603)
    amino_acids = ["R", "K", "H", "D", "E", "Q", "N", "S", "T", "Y", "W", "F", "A", "I", "L", "M", "V", "G", "P", "C"]
    # Every site x amino-acid combination, so missing mutants still get a cell.
    empty_df = pd.DataFrame([{"site": site, "mutant": aa} for site in sites for aa in amino_acids])
    all_sites_df = pd.merge(empty_df, df, on=["site", "mutant"], how="left")

    score_column = "binding_median" if binding_flag else "effect"
    frames = [
        all_sites_df.melt(
            id_vars=["site", "mutant", "wildtype"],
            value_vars=[score_column],
            var_name="type",
            value_name='value',
        )
    ]

    if binding_flag:
        if low_entry_df is None:
            # The binding heatmap can still be drawn without the low-entry mask,
            # so warn instead of failing on low_entry_df.rename.
            warnings.warn('You indicated binding but did not provide a low_entry_df')
        else:
            frames.append(
                low_entry_df.rename(columns={'effect': 'low_effect'}).melt(
                    id_vars=["site", "mutant", "wildtype"],
                    value_vars=["low_effect"],
                    var_name="type",
                    value_name='value',
                )
            )

    # Append each annotation block exactly once, even when both flags are set.
    if contact_flag is True:
        frames.append(contact_df)
    if entropy_flag is True:
        frames.append(entropy_df)

    # pd.concat skips None entries, so a flag without its frame adds nothing.
    return pd.concat(frames, ignore_index=True)

In [None]:
def make_base_heatmap(df, heatmap_sites, x_axis):
    """Return the site-by-amino-acid base chart that every heatmap layer builds on.

    ``heatmap_sites`` fixes the x-axis order and ``x_axis`` supplies the axis
    formatting. Rows on the y-axis are ordered by the ``mutant_rank`` column.
    """
    site_encoding = alt.X("site:O", title="Site", sort=heatmap_sites, axis=x_axis)
    mutant_encoding = alt.Y(
        "mutant",
        title="Amino Acid",
        sort=alt.EncodingSortField(field='mutant_rank', order='ascending'),
        axis=alt.Axis(grid=False),
    )
    encoded = alt.Chart(df).encode(x=site_encoding, y=mutant_encoding)
    return encoded.properties(width=alt.Step(10), height=alt.Step(11))

def make_empty_heatmap(base, background_color):
    """Draw unmeasured mutants as solid ``background_color`` cells.

    An unmeasured mutant is a score row (type 'effect' or 'binding_median')
    whose value is missing.
    """
    is_score_row = (alt.datum.type == "effect") | (alt.datum.type == 'binding_median')
    # `== None` (not `is None`) is required: Altair turns it into a Vega `== null` test.
    is_missing = alt.datum.value == None  # noqa: E711
    return (
        base.mark_rect(color=background_color)
        .encode(tooltip=['site', 'mutant'])
        .transform_filter(is_score_row & is_missing)
    )
def make_wildtype_heatmap(unique_wildtypes_df, strokewidth_size, heatmap_sites):
    """Build the two wildtype layers: a white outlined box and an "X" on top.

    Both layers place a mark at (site, wildtype) for score rows that have a
    wildtype and a value.

    Returns:
    (box_layer, x_layer)
    """
    def _place_at_wildtype(chart):
        # Shared encoding and filter for both layers. `!= None` is required
        # because Altair turns it into a Vega `!= null` test.
        return (
            chart.encode(
                x=alt.X("site:O", sort=heatmap_sites),
                y=alt.Y("wildtype", sort=alt.EncodingSortField(field="mutant_rank", order="ascending")),
                tooltip=["site", "wildtype"],
            )
            .transform_filter(
                ((alt.datum.type == "effect") | (alt.datum.type == 'binding_median'))
                & (alt.datum.wildtype != None)  # noqa: E711
                & (alt.datum.value != None)  # noqa: E711
            )
        )

    box_layer = _place_at_wildtype(
        alt.Chart(unique_wildtypes_df).mark_rect(color="white", stroke="black", strokeWidth=strokewidth_size)
    )
    x_layer = _place_at_wildtype(
        alt.Chart(unique_wildtypes_df).mark_text(color="black", text="X", size=8, align="center", baseline="middle")
    )
    return box_layer, x_layer

def create_effect_chart(base, color_scale_effect, strokewidth_size, legend_title=None, effect_legend_added=None):
    """Color measured score cells ('effect' or 'binding_median' rows) by value.

    Parameters:
    - base: chart from make_base_heatmap.
    - color_scale_effect (alt.Scale): color scale for the scores.
    - strokewidth_size (float): cell outline width.
    - legend_title (str, optional): legend title. Only used when the legend is shown.
    - effect_legend_added (bool, optional): show the color legend only when
      exactly True. The caller sets it on the first chart of a stacked layout.

    Rows with an empty or missing wildtype are filtered out. These are the
    unmeasured mutants and the contact/entropy rows.
    """
    legend = alt.Legend(title=legend_title) if effect_legend_added is True else None
    # Vega expressions use JavaScript operators: `||` is logical OR. A single
    # `|` is bitwise OR and only works because booleans coerce to 0/1.
    is_score_row = 'datum.type == "effect" || datum.type == "binding_median"'
    return (
        base.mark_rect(stroke="black", strokeWidth=strokewidth_size)
        .encode(
            color=alt.condition(
                is_score_row,
                alt.Color('value:Q', scale=color_scale_effect, legend=legend),
                alt.value("transparent"),
            ),
            tooltip=['site', 'mutant', 'wildtype', 'value'],
        )
        .transform_filter(
            (alt.datum.wildtype != '') & (alt.datum.wildtype != None)  # noqa: E711
        )
    )
def create_entropy_chart(base, color_scale_entropy, strokewidth_size, legend_title=None, entropy_legend_added=None):
    """Fill the 'entropy' row by entropy value; every other row stays transparent.

    Every row still gets a black-outlined rect, because no filter is applied.
    ``legend_title`` is accepted for parity with create_effect_chart but is not
    used. The legend title is always 'Henipavirus Entropy', and the legend is
    shown only when ``entropy_legend_added`` is exactly True.
    """
    legend = None
    if entropy_legend_added is True:
        legend = alt.Legend(title='Henipavirus Entropy')
    entropy_fill = alt.condition(
        'datum.mutant == "entropy"',
        alt.Color('value:Q', scale=color_scale_entropy, legend=legend),
        alt.value("transparent"),
    )
    outlined = base.mark_rect(stroke="black", strokeWidth=strokewidth_size)
    return outlined.encode(color=entropy_fill, tooltip=['site', 'mutant', 'wildtype', 'value'])
def create_contact_chart(base):
    """Paint the 'contact' row black at every contact site."""
    is_contact_row = alt.datum.mutant == "contact"
    black_cells = base.mark_rect(color="black").encode(tooltip=['site'])
    return black_cells.transform_filter(is_contact_row)

def make_low_effect_heatmap(base,strokewidth_size,heatmap_sites):
    """Draw the 'low_effect' rows (mutants masked for low cell entry) in #939598.

    ``heatmap_sites`` is accepted but unused; the x order comes from ``base``.
    """
    is_low_effect = alt.datum.type == 'low_effect'
    masked_cells = base.mark_rect(color='#939598', stroke='black', strokeWidth=strokewidth_size)
    return masked_cells.encode().transform_filter(is_low_effect)

def compile_chart(df, heatmap_sites, unique_wildtypes_df, x_axis, background_color, color_scale_effect, color_scale_entropy, strokewidth_size=None, legend_title=None, effect_legend_added=None, entropy_legend_added=None,binding_flag=None):
    """Layer all heatmap components into one chart.

    Layer order, bottom to top:
    1. unmeasured cells
    2. score cells
    3. low-entry mask (binding only)
    4. entropy
    5. contact
    6. wildtype box
    7. wildtype X

    x and y scales are shared across layers; color scales are independent.
    """
    base = make_base_heatmap(df, heatmap_sites, x_axis)
    effect_layer = create_effect_chart(base, color_scale_effect, strokewidth_size, legend_title, effect_legend_added)
    entropy_layer = create_entropy_chart(base, color_scale_entropy, strokewidth_size, legend_title, entropy_legend_added)
    wildtype_box, wildtype_x = make_wildtype_heatmap(unique_wildtypes_df, strokewidth_size, heatmap_sites)

    layers = [make_empty_heatmap(base, background_color), effect_layer]
    if binding_flag:
        # The low-entry mask sits directly above the score cells.
        layers.append(make_low_effect_heatmap(base, strokewidth_size, heatmap_sites))
    layers.extend([entropy_layer, create_contact_chart(base), wildtype_box, wildtype_x])

    return alt.layer(*layers).resolve_scale(y="shared", x="shared", color="independent")

In [None]:
def _wildtype_markers(subset_df, sort_order):
    """Return one row per (site, wildtype) for the wildtype box/X layers.

    Only measured score rows are kept before de-duplicating: type 'effect' or
    'binding_median', with a non-missing wildtype and value. These are the same
    rows the wildtype layer's filter accepts. Otherwise a 'low_effect' row
    could be the one kept and the marker would then be filtered away.
    'mutant_rank' is re-derived from the wildtype letter, so the wildtype
    layer's y sort uses the same ranks as the mutant layers.
    """
    is_measured_score = (
        subset_df["type"].isin(["effect", "binding_median"])
        & subset_df["wildtype"].notna()
        & subset_df["value"].notna()
    )
    markers = (
        subset_df[is_measured_score]
        .drop_duplicates(subset=["site", "wildtype"])
        .sort_values("site")
    )
    markers = markers.assign(mutant_rank=markers["wildtype"].map(sort_order))
    return markers.sort_values("mutant_rank")


def plot_entry_heatmap(
    df, 
    legend_title, 
    null_color=None, 
    ranges=None, 
    effect_color=None, 
    entropy_color=None,
    strokewidth_size=None,
    custom_y_axis_order=None,
    entropy_flag=None,
    contact_flag=None,
    specific_sites=None,
    specific_sites_name=None,
    low_entry_df = None,
    binding_flag=None,
    custom_domain=None):
    """
    Generate a customizable heatmap of deep mutational scanning (DMS) scores.

    Parameters:
    - df (DataFrame): Per-mutant scores with columns 'site', 'mutant' and
      'wildtype', plus the score column: 'effect' for entry heatmaps, or
      'binding_median' when binding_flag is set.
    - legend_title (str): Title of the score legend. It is also the overall
      title of the wrapped (all-sites) layout.
    - null_color (str, optional): Color for mutants with no observations. Default '#d1d3d4'.
    - ranges (list of site lists, optional): Sites shown in each wrapped row.
      Default: four rows covering sites 71-602.
    - effect_color (str, optional): Vega color scheme for scores. Default 'redblue'.
    - entropy_color (str, optional): Vega color scheme for entropy. Default
      'purples', reversed, with domain [0, 2].
    - strokewidth_size (float, optional): Cell outline width. Default 0.25.
    - custom_y_axis_order (list, optional): Amino-acid order for the y-axis,
      replacing the default order.
    - entropy_flag (bool, optional): If True, add a henipavirus entropy row
      (via prepare_entropy). Default None (off).
    - contact_flag (bool, optional): If True, add a row marking
      config['contact_sites']. Default None (off).
    - specific_sites (list, optional): Plot only these sites as one un-wrapped
      chart. Default None, which plots all sites with wrapping.
    - specific_sites_name (str, optional): Title for the specific-sites chart. Default None (no title).
    - low_entry_df (DataFrame, optional): Mutants with low cell entry
      ('site', 'mutant', 'wildtype', 'effect'). They are drawn with a separate
      fill on binding heatmaps.
    - binding_flag (bool, optional): If True, plot 'binding_median' instead of
      'effect'. Pair it with low_entry_df to mask low-entry mutants.
    - custom_domain (list, optional): [min, max] of the score color scale. Default [-4, 2.5].

    Returns:
    An Altair vconcat chart. It can be further customized or displayed in Jupyter.
    """
    contact_df = make_contact() if contact_flag else None
    entropy_df = prepare_entropy() if entropy_flag is True else None

    # Long-format frame: one row per (site, mutant, type), plus optional annotation rows.
    empty_df = make_empty_df(
        df,
        contact_df=contact_df,
        entropy_df=entropy_df,
        contact_flag=contact_flag,
        entropy_flag=entropy_flag,
        low_entry_df=low_entry_df,
        binding_flag=binding_flag,
    )

    # y-axis order: annotation rows (contact, then entropy) above the amino acids.
    base_order = ["R", "K", "H", "D", "E", "Q", "N", "S", "T", "Y", "W", "F", "A", "I", "L", "M", "V", "G", "P", "C"]
    # list() so a tuple (or any sequence) can be passed as a custom order.
    amino_acid_order = list(custom_y_axis_order) if custom_y_axis_order is not None else base_order
    annotation_rows = []
    if contact_flag:
        annotation_rows.append("contact")
    if entropy_flag:
        annotation_rows.append("entropy")
    custom_order = annotation_rows + amino_acid_order
    sort_order = {mutant: rank for rank, mutant in enumerate(custom_order)}

    background_color = "#d1d3d4" if null_color is None else null_color

    # Sites for wrapping the heatmap into rows.
    if ranges is None:
        full_ranges = [
            list(range(start, end))
            for start, end in [(71, 204), (204, 337), (337, 470), (470, 603)]
        ]
    else:
        full_ranges = ranges

    color_scale_effect = alt.Scale(
        scheme="redblue" if effect_color is None else effect_color,
        domainMid=0,
        domain=custom_domain if custom_domain else [-4, 2.5],
    )
    color_scale_entropy = alt.Scale(
        scheme="purples" if entropy_color is None else entropy_color,
        domain=[0, 2],
        reverse=True,
    )
    if strokewidth_size is None:
        strokewidth_size = 0.25

    # Legends are drawn on the first chart only. An explicit entropy_flag=False
    # must not enable the entropy legend.
    effect_legend_added = True
    entropy_legend_added = True if entropy_flag else None

    # Sort by site, then by y-axis rank; collect the ordered site list for the x-axis.
    heatmap_df = empty_df.sort_values("site")
    heatmap_df = heatmap_df.assign(mutant_rank=heatmap_df["mutant"].map(sort_order)).sort_values("mutant_rank")
    heatmap_sites = sorted(heatmap_df["site"].unique(), key=float)

    shared_chart_kwargs = dict(
        color_scale_effect=color_scale_effect,
        color_scale_entropy=color_scale_entropy,
        strokewidth_size=strokewidth_size,
        legend_title=legend_title,
        binding_flag=binding_flag,
    )

    if specific_sites:
        # A single un-wrapped chart restricted to the requested sites.
        subset_df = heatmap_df[heatmap_df["site"].isin(specific_sites)]
        x_axis = alt.Axis(labelAngle=-90, title="Site", labels=True)
        chart = compile_chart(
            subset_df,
            heatmap_sites,
            _wildtype_markers(subset_df, sort_order),
            x_axis,
            background_color,
            effect_legend_added=effect_legend_added,
            entropy_legend_added=entropy_legend_added,
            **shared_chart_kwargs,
        )
        # Wrapping the single layered chart in vconcat attaches the title and scale resolution.
        return alt.vconcat(chart, title=specific_sites_name or '').resolve_scale(y="shared", x="shared", color="shared")

    charts = []
    for idx, site_range in enumerate(full_ranges):
        subset_df = heatmap_df[heatmap_df["site"].isin(site_range)]
        # Only the bottom row gets the "Site" axis title.
        is_last_plot = idx == len(full_ranges) - 1
        x_axis = alt.Axis(
            labelAngle=-90,
            labelExpr="datum.value % 10 === 0 ? datum.value : ''",
            title="Site" if is_last_plot else None,
            labels=True,
        )
        charts.append(
            compile_chart(
                subset_df,
                heatmap_sites,
                _wildtype_markers(subset_df, sort_order),
                x_axis,
                background_color,
                effect_legend_added=effect_legend_added,
                entropy_legend_added=entropy_legend_added,
                **shared_chart_kwargs,
            )
        )
        effect_legend_added = None
        entropy_legend_added = None

    return alt.vconcat(
        *charts, spacing=3, title=f"{legend_title}"
    ).resolve_scale(y="shared", x="independent", color="shared")

### Now that the heatmap functions are set up, make the heatmaps:

#### First, make the binding heatmaps:

In [None]:
E2_binding_heatmap_full = plot_entry_heatmap(
    df=e2_binding,
    legend_title='EFNB2 Binding',
    null_color=config['background_color'],
    effect_color=config['effect_color'],
    entropy_color=config['entropy_color'],
    strokewidth_size=config['strokewidth_size'],
    contact_flag=True,
    low_entry_df=e2_low_func,
    binding_flag=True,
    custom_domain=[-6, 2.5],
)
E2_binding_heatmap_full.display()
# Output paths are only injected by papermill. The hard-coded fallback leaves
# them as None, and Chart.save(None) fails.
if E2_binding_heatmap is not None:
    E2_binding_heatmap_full.save(E2_binding_heatmap)

In [None]:
E3_binding_heatmap_full = plot_entry_heatmap(
    df=e3_binding,
    legend_title='EFNB3 Binding',
    null_color=config['background_color'],
    effect_color=config['effect_color'],
    entropy_color=config['entropy_color'],
    strokewidth_size=config['strokewidth_size'],
    contact_flag=True,
    low_entry_df=e3_low_func,
    binding_flag=True,
    custom_domain=[-2, 2],
)
E3_binding_heatmap_full.display()
# Output paths are only injected by papermill. The hard-coded fallback leaves
# them as None, and Chart.save(None) fails.
if E3_binding_heatmap is not None:
    E3_binding_heatmap_full.save(E3_binding_heatmap)

#### Then, make entry heatmaps:

In [None]:
E2_entry_heatmap_full = plot_entry_heatmap(
    df=e2_func,
    legend_title="CHO-EFNB2 Entry",
    null_color=config['background_color'],
    effect_color=config['effect_color'],
    entropy_color=config['entropy_color'],
    strokewidth_size=config['strokewidth_size'],
    contact_flag=True,
    entropy_flag=True,
)
E2_entry_heatmap_full.display()
# Output paths are only injected by papermill. The hard-coded fallback leaves
# them as None, and Chart.save(None) fails.
if E2_entry_heatmap is not None:
    E2_entry_heatmap_full.save(E2_entry_heatmap)

In [None]:
E3_entry_heatmap_full = plot_entry_heatmap(
    df=e3_func,
    legend_title="CHO-EFNB3 Entry",
    null_color=config['background_color'],
    effect_color=config['effect_color'],
    entropy_color=config['entropy_color'],
    strokewidth_size=config['strokewidth_size'],
    contact_flag=True,
    entropy_flag=True,
)
E3_entry_heatmap_full.display()
# Output paths are only injected by papermill. The hard-coded fallback leaves
# them as None, and Chart.save(None) fails.
if E3_entry_heatmap is not None:
    E3_entry_heatmap_full.save(E3_entry_heatmap)

#### Then, make entry and binding heatmaps for contact sites:

In [None]:
# CHO-EFNB2 entry scores at config['contact_sites'] only, without annotation rows.
E2_entry_heatmap_contact = plot_entry_heatmap(
    df=e2_func,
    legend_title="CHO-EFNB2 Entry",
    specific_sites=config['contact_sites'],
    null_color=config['background_color'],
    effect_color=config['effect_color'],
    entropy_color=config['entropy_color'],
    strokewidth_size=config['strokewidth_size'],
)
E2_entry_heatmap_contact.display()

In [None]:
# CHO-EFNB3 entry scores at config['contact_sites'] only, without annotation rows.
E3_entry_heatmap_contact = plot_entry_heatmap(
    df=e3_func,
    legend_title="CHO-EFNB3 Entry",
    specific_sites=config['contact_sites'],
    null_color=config['background_color'],
    effect_color=config['effect_color'],
    entropy_color=config['entropy_color'],
    strokewidth_size=config['strokewidth_size'],
)
E3_entry_heatmap_contact.display()

In [None]:
# Side-by-side EFNB2 / EFNB3 entry scores at the contact sites.
combined_contact = alt.hconcat(E2_entry_heatmap_contact, E3_entry_heatmap_contact, title='Contact Sites')
combined_contact.display()
# Output paths are only injected by papermill. The hard-coded fallback leaves
# them as None, and Chart.save(None) fails.
if combined_entry_contact_heatmaps is not None:
    combined_contact.save(combined_entry_contact_heatmaps)

In [None]:
E2_binding_heatmap_contact = plot_entry_heatmap(
    df=e2_binding,
    legend_title='EFNB2 Binding',
    null_color=config['background_color'],
    effect_color=config['effect_color'],
    entropy_color=config['entropy_color'],
    strokewidth_size=config['strokewidth_size'],
    specific_sites=config['contact_sites'],
    low_entry_df=e2_low_func,
    binding_flag=True,
    custom_domain=[-6, 2.5],
)

E3_binding_heatmap_contact = plot_entry_heatmap(
    df=e3_binding,
    legend_title='EFNB3 Binding',
    null_color=config['background_color'],
    effect_color=config['effect_color'],
    entropy_color=config['entropy_color'],
    strokewidth_size=config['strokewidth_size'],
    specific_sites=config['contact_sites'],
    low_entry_df=e3_low_func,
    binding_flag=True,
    custom_domain=[-2, 2],
)

# Side-by-side EFNB2 / EFNB3 binding scores at the contact sites.
combined_contact_binding = alt.hconcat(E2_binding_heatmap_contact, E3_binding_heatmap_contact, title='Contact Sites')
combined_contact_binding.display()
# Output paths are only injected by papermill. The hard-coded fallback leaves
# them as None, and Chart.save(None) fails.
if combined_contact_binding_plot is not None:
    combined_contact_binding.save(combined_contact_binding_plot)

In [None]:
# Entry (top) stacked over binding (bottom) for the contact sites.
entry_binding_combined_heatmap = alt.vconcat(
    combined_contact,
    combined_contact_binding,
)
entry_binding_combined_heatmap.display()

In [None]:
# `E3_binding_heatmap_contact` is already built, with identical arguments, in
# the contact-site binding cell above. Show it here instead of rebuilding it.
E3_binding_heatmap_contact.display()

In [None]:
# Amino-acid classes used to group sites by their wildtype residue.
hydrophobic_AA = list('AVLIM')
aromatic_AA = list('YWF')
positive_AA = list('KRH')
negative_AA = list('ED')
hydrophilic_AA = list('STNQ')

def find_aa_wildtype_sites(df,aa_type):
    """Return the distinct sites whose wildtype residue is in ``aa_type``.

    Sites are listed in order of first appearance in ``df``.
    """
    has_matching_wildtype = df['wildtype'].isin(aa_type)
    return list(df.loc[has_matching_wildtype, 'site'].unique())

def make_AA_lists(df):
    """Return one list of wildtype sites per amino-acid class.

    The classes come in this order: hydrophobic, aromatic, positive,
    negative, hydrophilic.
    """
    aa_classes = [hydrophobic_AA, aromatic_AA, positive_AA, negative_AA, hydrophilic_AA]
    return [find_aa_wildtype_sites(df, aa_class) for aa_class in aa_classes]

# Display names, in the same order as the groups returned by make_AA_lists.
AA_names = ['Hydrophobic', 'Aromatic', 'Positive', 'Negative', 'Hydrophilic']
all_AA = make_AA_lists(e2_binding)

empty_charts = []

def make_aa_property_charts(df,df_low,sites_list,sites_name,custom_domain,legend_name):
    """Stack binding heatmaps, one per group of sites sharing a wildtype amino-acid class.

    Parameters:
    - df (DataFrame): binding scores passed to plot_entry_heatmap.
    - df_low (DataFrame): low-entry mutants used to mask the binding heatmap.
    - sites_list (list of lists): site groups, e.g. the output of make_AA_lists.
    - sites_name (list of str): title for each group, in the same order.
    - custom_domain (list): [min, max] of the binding color scale.
    - legend_name (str): legend title.

    Returns:
    An Altair vconcat of one chart per non-empty group.
    """
    # A fresh list on every call. Appending to a module-level list would make
    # each later call re-plot all charts from earlier calls.
    property_charts = []
    for group_sites, name in zip(sites_list, sites_name):
        if not group_sites:
            # An empty specific_sites list would fall through to the full
            # wrapped heatmap of every site, so skip the group instead.
            continue
        property_charts.append(
            plot_entry_heatmap(
                df=df,
                legend_title=legend_name,
                null_color=config['background_color'],
                effect_color=config['effect_color'],
                entropy_color=config['entropy_color'],
                strokewidth_size=config['strokewidth_size'],
                specific_sites=group_sites,
                specific_sites_name=name,
                low_entry_df=df_low,
                binding_flag=True,
                custom_domain=custom_domain,
            )
        )
    return alt.vconcat(*property_charts, spacing=3)

E2_binding_AA_prop = make_aa_property_charts(
    e2_binding,
    e2_low_func,
    all_AA,
    AA_names,
    [-6, 2.5],
    'EFNB2 Binding',
)
E2_binding_AA_prop.display()

In [None]:
# NOTE(review): all_AA was derived from e2_binding's wildtype sites. Confirm
# the same site groups are intended for EFNB3.
E3_binding_AA_prop = make_aa_property_charts(
    e3_binding,
    e3_low_func,
    all_AA,
    AA_names,
    [-2, 2],
    'EFNB3 Binding',
)
E3_binding_AA_prop.display()