# Visualizing binding affinity and expression from RBD DMS of VOCs

This notebook contains code to make interactive heatmaps for binding and expression measured in the RBD DMS libraries for four SARS-CoV-2 variants of concern and the ancestral reference sequence `[Wuhan-Hu-1, Alpha, Beta, Delta, and Eta]`. 

Most of this code was inspired by, and in some cases repurposed from, Sarah Hilton's work found [here](https://github.com/jbloomlab/SARS-CoV-2-RBD_DMS/blob/master/interactive_heatmap.ipynb). 

In [1]:
import itertools 
import pandas as pd
import numpy as np
import altair as alt

# Remove Altair's default limit of ~5000 rows on the data passed to a chart
alt.data_transformers.disable_max_rows()

DataTransformerRegistry.enable('default')

## Process the data

Import and format the data for the interactive `Altair` heat maps. 

In [2]:
## == Filepaths == ##  

# Input: CSV of variant scores, read with `pd.read_csv` to build `scores_df`
variant_scores_path = "results/final_variant_scores/final_variant_scores.csv"
# Input: CSV of RBD site annotations, the source of the ACE2 contact column
rbd_annotations_path = "data/RBD_sites.csv"

# Output: HTML file that the finished interactive heatmap is saved to
output_html = "docs/_includes/heatmap.html"

In [3]:
# Variant names as a dictionary - ordered as they will appear in the legend.
# Keys are the labels found in the `target` column of the scores file,
# values are the display names they are replaced with.
variants_names = {
    "Wuhan-Hu-1": "Wuhan-Hu-1",
    "N501Y": "Alpha", 
    "Beta": "Beta",
    "Delta": "Delta",
    "E484K": "Eta",
}

# The only columns the plots need; everything else is dropped because a
# narrower frame shrinks the size of the altair plot
columns_to_keep = ['target', 'wildtype', 'site', 'mutant',
                   'RBD expression', 'mutation', 'ACE2 binding', 'n_bc_bind']

# Annotations for RBD sites - we really only care about ACE2 contact sites
# from Sarah Hilton - https://github.com/jbloomlab/SARS-CoV-2-RBD_DMS/blob/4f506dcbfade2a5efb9beb324a7f0b2f675ab9fb/data/RBD_sites.csv
site_level_annotations_df = (
    pd.read_csv(rbd_annotations_path)
        .rename(columns={"site_SARS2": "site",
                         "SARS2_ACE2_contact": "ACE2_contact"})
)

# Expression and Binding scores per variant background -- for the heatmap plots
scores_df = (
    pd.read_csv(variant_scores_path)
        .rename(columns={"position": "site",
                         "delta_expr": "RBD expression",  # rename for the selection menus
                         "delta_bind": "ACE2 binding"})   # rename for the selection menus
        .replace({"target": variants_names})
        # Keep only the needed columns, in their original order
        .loc[:, lambda df: df.columns.isin(columns_to_keep)]
        # Add the ACE2 contacts; a left merge keeps every score row even when
        # a site has no annotation
        .merge(site_level_annotations_df[["site", "ACE2_contact"]], how='left', on='site')
        .replace({'ACE2_contact': {True: 'Yes', False: 'No'}})
        # Set a character, `x`, to appear in the wildtype sites. A vectorized
        # column comparison replaces the row-wise `apply` that indexed each row
        # by integer position (deprecated on label-indexed Series).
        .assign(wildtype_code=lambda df: np.where(df['wildtype'] == df['mutant'], 'x', ''))
)



## Define plot-wide parameters

Define the parameters that carry through to multiple plots for easy adjustment. 

In [4]:
# Figure dimensions: `width` is the width of the zoom bar,
# `height` is the height of each heatmap
width, height = 1000, 300

# Range of sites and the spacing used for the x-axis ticks on the zoom bar
min_site = min(scores_df.site.tolist())
max_site = max(scores_df.site.tolist())
x_axis_spacing = 5

# Top-to-bottom order of the amino acids on the y-axis (stop codon `*` last)
aa_order = list('RKHDEQNSTYWFAILMVGPC*')

# Tooltip fields paired with their display options (title and number format)
_tooltip_specs = [
    ('target:N', {'title': "Variant"}),
    ('mutation:N', {'title': "Mutation"}),
    ('ACE2 binding:Q', {'title': "Change in ACE2 Binding", 'format': ".2f"}),
    ('RBD expression:Q', {'title': "Change in RBD Expression", 'format': ".2f"}),
    ('n_bc_bind:Q', {'title': "Barcode Count"}),
    ('ACE2_contact:N', {'title': "ACE2 Contact"}),
]
heatmap_tooltips = [alt.Tooltip(shorthand, **options) for shorthand, options in _tooltip_specs]

# Color scale limits for the heatmap - one pair of limits is shared by *both*
# expression and binding
minimum_domain, maximum_domain = -4.0, 1.5


## Define selections for plots 

Define the selection objects that control interaction. Many of these are shared between plots and datasets, so it's helpful to define them at the top. 

In [5]:
# Zoom bar brush to look closer at a region of sites.
zoom_selection = alt.selection_interval(encodings=['x'], mark=alt.BrushConfig(stroke='black', strokeWidth=2))

# Cell selector for highlighting the cell you're currently mousing over
amino_acid_selection = alt.selection_single(encodings=['x', 'y'], on='mouseover', empty='none')

# Options for the variant dropdowns. Building them from a `set` gave an order
# that could change from one kernel session to the next, so they are ordered
# explicitly instead: first by their position in `variants_names`, then any
# variant missing from that dictionary in order of first appearance
# (`sorted` is stable).
_legend_order = list(variants_names.values())
variant_options = sorted(
    scores_df['target'].unique().tolist(),
    key=lambda name: _legend_order.index(name) if name in _legend_order else len(_legend_order),
)

# Options for the metric dropdowns ( (delta) Expression or Binding )
metric_options = ['ACE2 binding', 'RBD expression']

## === Selections for the TOP heatmap === ##
# Drop down to select the variant to display in the top heatmap
top_dropdown = alt.binding_select(options=variant_options, name="Select Top Variant: ")
top_selection = alt.selection_single(fields=['target'], bind=top_dropdown, init={"target":"Wuhan-Hu-1"})

# Drop down to select the metric displayed in the top heatmap
top_metric_dropdown = alt.binding_select(options=metric_options, name="Select Top Metric: ")
top_metric_selection = alt.selection_single(fields=['metric'], bind=top_metric_dropdown, init={'metric': 'ACE2 binding'})

## === Selections for the BOTTOM heatmap === ##
# Drop down to select the variant to display in the bottom heatmap
bottom_dropdown = alt.binding_select(options=variant_options, name="Select Bottom Variant: ")
bottom_selection = alt.selection_single(fields=['target'], bind=bottom_dropdown, init={"target":"Beta"})

# Drop down to select the metric displayed in the bottom heatmap
bottom_metric_dropdown = alt.binding_select(options=metric_options, name="Select Bottom Metric: ")
bottom_metric_selection = alt.selection_single(fields=['metric'], bind=bottom_metric_dropdown, init={'metric': 'ACE2 binding'})




## Define the plot objects

The final plot will be comprised of multiple **Heatmaps** that display the binding and expression for the RBD DMS from the four variants of concern and the ancestral sequence. 

### Zoom Bar

In [6]:
## == Zoom bar for the heatmap plot == ## 
# One gray rect per unique site; brushing over the bar drives `zoom_selection`.
# The tick range uses `max_site + 1` so it is inclusive: `range` stops before
# its end value, which dropped the tick for the last site whenever that site
# fell exactly on the `x_axis_spacing` grid.
zoom_bar = alt.Chart(scores_df[['site']].drop_duplicates()
    ).mark_rect(
        color='lightgray'
    ).encode(
        x=alt.X('site:O',
                title=None,
                axis=alt.Axis(values=list(range(min_site, max_site + 1, x_axis_spacing)))
               )
    ).add_selection(
        zoom_selection
    ).properties(
        width=width,
        height=15,
        title="site zoom bar"
)

zoom_bar


### Heatmaps

In [7]:
## == Heatmaps plots with annotations == ## 
def heatmap(data, variant_selection, metric_selection):
    """
    Build one annotated heatmap of mutant amino acid (y) against site (x).

    Wrapping the chart definition in a function avoids repeating it for each
    heatmap that ends up concatenated together in the final plot; only the
    selections differ between calls.

    Parameters
    ----------
    data
        Dataset passed to `alt.Chart`. The chart reads the fields `site`,
        `mutant`, `target`, `wildtype_code`, `ACE2 binding` and
        `RBD expression`, plus the fields named in `heatmap_tooltips`.
    variant_selection
        Altair selection combined into the row filter. In this notebook the
        callers pass a dropdown-bound selection on `target`.
    metric_selection
        Altair selection combined into the row filter. In this notebook the
        callers pass a dropdown-bound selection on the folded `metric` field.

    Returns
    -------
    A layered Altair chart made of the colored cells, the gray cells for
    missing measurements, the wildtype `x` marks and a text label used as
    the title.

    Notes
    -----
    Also uses the notebook-level names `aa_order`, `minimum_domain`,
    `maximum_domain`, `heatmap_tooltips`, `amino_acid_selection`,
    `zoom_selection` and `height`.
    """

    # Shared base for every layer: fold the two score columns into long form
    # (`metric`, `measurement`), then keep only the rows picked by the two
    # selections.
    base = alt.Chart(
        data
    ).transform_fold(
        ['ACE2 binding', 'RBD expression'],
        as_=['metric', 'measurement']
    ).transform_filter(
        metric_selection & variant_selection
    ).encode(
        x=alt.X('site:O', axis=alt.Axis(titleFontSize=15)),
        y=alt.Y('mutant:O',
                sort=aa_order,
                axis=alt.Axis(labelFontSize=12, titleFontSize=15))
    )

    # Diverging color scale centered on zero; values outside the domain are clamped
    diverging_scale = alt.Scale(scheme='redblue',
                                domain=[minimum_domain, maximum_domain],
                                domainMid=0,
                                clamp=True)
    measurement_color = alt.Color('measurement:Q',
                                  scale=diverging_scale,
                                  legend=alt.Legend(orient='left',
                                                    title='grey is n.d.',
                                                    gradientLength=100))

    # Cells colored by the selected metric; the cell under the mouse gets a
    # black outline (stroke width 2 instead of 0)
    cells = base.mark_rect().encode(
        color=measurement_color,
        stroke=alt.value('black'),
        strokeWidth=alt.condition(amino_acid_selection, alt.value(2), alt.value(0)),
        tooltip=heatmap_tooltips
    )

    # Gray, half-transparent cells wherever the measurement is not valid
    missing_cells = base.transform_filter(
        "!isValid(datum.measurement)"
    ).mark_rect(
        opacity=0.5
    ).encode(
        alt.Color('measurement:N',
                  scale=alt.Scale(scheme='greys'),
                  legend=None)
    )

    # Black text from `wildtype_code` - an 'x' on the wildtype amino acids
    wildtype_marks = base.mark_text(color='black').encode(
        text=alt.Text('wildtype_code:N')
    )

    # Stand-in for a title - altair doesn't support dynamic axes or titles,
    # so the "<target> <metric>" label is drawn as a text mark instead.
    panel_title = base.mark_text(
        align='left', baseline='bottom', fontSize=16, fontWeight='bold'
    ).encode(
        x=alt.value(0.0), y=alt.value(-1), text='_label:N'
    ).transform_aggregate(
        groupby=["target", "metric"]
    ).transform_calculate(
        _label='datum.target + " " + datum.metric'
    )

    # Layer everything, attach the shared selections, and restrict the sites
    # shown to those brushed on the zoom bar
    layered = cells + missing_cells + wildtype_marks + panel_title
    return (
        layered
        .interactive()
        .add_selection(amino_acid_selection, zoom_selection)
        .transform_filter(zoom_selection)
        .properties(height=height)
    )



    
    
    

In [8]:
# Build the two heatmap panels; each one gets its own variant and metric selections
top_heatmap = heatmap(data=scores_df,
                      variant_selection=top_selection,
                      metric_selection=top_metric_selection)
bottom_heatmap = heatmap(data=scores_df,
                         variant_selection=bottom_selection,
                         metric_selection=bottom_metric_selection)

# Stack the zoom bar above the top and bottom heatmaps
stacked_panels = zoom_bar & top_heatmap & bottom_heatmap

# Attach the dropdown selections. They are deliberately listed backwards
# (bottom first): that is what gives the right order. The reason for this
# is not established here.
dropdown_selections = [
    bottom_metric_selection,
    bottom_selection,
    top_metric_selection,
    top_selection,
]
final_heatmap = stacked_panels.add_selection(*dropdown_selections)

final_heatmap


In [9]:
# Save the combined plot (zoom bar + both heatmaps) to the HTML file at `output_html` for the website
final_heatmap.save(output_html)


# END