# Visualizing Epistatic Shifts in VOCs

Here, I make an interactive plot to visualize epistatic shifts in the receptor-binding domain (RBD) of several SARS-CoV-2 variants of concern (VOCs) and the ancestral Wuhan-Hu-1 sequence.

In [1]:
import itertools 
import pandas as pd
import numpy as np
import altair as alt

# Remove Altair's default limit of ~5000 rows per dataset; the pairwise
# comparison table built below can exceed it.
alt.data_transformers.disable_max_rows()

DataTransformerRegistry.enable('default')

## Process the data 

I'll load the raw data from Tyler, rename columns where appropriate, map the variant backgrounds to their display names, and pair up the binding scores for every combination of variant backgrounds.

In [2]:
# ## == Filepaths == ##  

# # Input path to variant scores 
# variant_scores_path = "results/final_variant_scores/final_variant_scores.csv"
# # Input path to epistatic shifts (JS distance) 
# epistatic_shifts_path = "results/epistatic_shifts/JSD_versus_Wuhan1_by_target.csv"

# # Output path to wide-form variant scores 
# variant_scores_wide_path = "results/final_variant_scores/final_variant_scores_wide.csv"
# # Output path to HTML plot file
# output_html = "docs/_includes/epistasis.html"

In [None]:
## == Snakemake Filepaths == ##

# When the notebook is executed by Snakemake, the injected `snakemake` object
# supplies the rule's input and output paths.
variant_scores_path = snakemake.input.scores  # input: variant scores (CSV)
epistatic_shifts_path = snakemake.input.jsd  # input: epistatic shifts, JS distance (CSV)
output_html = snakemake.output.html  # output: HTML plot file

In [3]:
# Display names for each variant background, keyed by the raw names used in the
# input files. Dict order is the order in which the variants appear in the legend.
variants_names = dict(
    [
        ("Wuhan-Hu-1", "Wuhan-Hu-1"),
        ("Omicron_BA1", "Omicron BA.1"),
        ("Omicron_BA2", "Omicron BA.2"),
    ]
)

# Jensen-Shannon distance for every pair of variant backgrounds -- for the line plot.
# bg_1/bg_2 become target/background, raw names become display names, and the
# JSD columns that are not plotted are dropped.
jsd_df = pd.read_csv(epistatic_shifts_path)
jsd_df = jsd_df.rename(columns={"bg_1": "target", "bg_2": "background"})
jsd_df = jsd_df.replace({column: variants_names for column in ("target", "background")})
jsd_df = jsd_df.drop(columns=["JSD", "JSD_min5bc"])


# Binding scores for every mutation in each variant background -- for the scatter plot.
# Only the identifier and binding columns are kept; every other column is dropped.
scores_df = (
    pd.read_csv(variant_scores_path)
    .rename(columns={"position": "site"})
    .replace({"target": variants_names})
)
scores_df = scores_df.drop(
    columns=scores_df.columns.difference(
        ['target', 'wildtype', 'site', 'mutant',
         'mutation', 'bind', 'delta_bind', 'n_bc_bind']
    )
)

## == Re-arrange the scores dataframe for all combinations of variant backgrounds == ##

# Ordered (target, background) pairs of distinct variant backgrounds.
# `unique()` keeps first-appearance order, so the pairs -- and therefore the row
# order of the data embedded in the saved HTML -- are identical on every run.
# (Iterating a set of strings is not: string hashing is randomized per process.)
background_comparisons = list(itertools.permutations(scores_df["target"].unique(), 2))

# Build one wide table per comparison: each row holds a mutation's scores in the
# target background next to its scores in the background it is compared against.
merged_comparisons = list()
for target_name, background_name in background_comparisons:

    # Select each side with a boolean mask rather than a string-built query, and
    # rename the comparator's "target" column to "background".
    target_df = scores_df.loc[scores_df["target"] == target_name]
    background_df = (
        scores_df.loc[scores_df["target"] == background_name]
        .rename(columns={"target": "background"})
    )

    # Suffix every column except the side's identifier and the join keys, so the
    # two sides' wildtype / mutation / binding columns stay distinguishable.
    target_df = target_df.rename(
        columns=lambda column: column if column in ("target", "site", "mutant") else column + "_target"
    )
    background_df = background_df.rename(
        columns=lambda column: column if column in ("background", "site", "mutant") else column + "_background"
    )

    # Left-join on "site" and "mutant" -- the only columns the two sides share,
    # since "wildtype" was suffixed above -- keeping every mutation scored in the
    # target background even if the comparator has no score for it.
    merged_df = pd.merge(target_df, background_df, how="left", on=["site", "mutant"])

    merged_comparisons.append(merged_df)

# pd.concat([]) would fail with an opaque message; say what is actually wrong.
if not merged_comparisons:
    raise ValueError(
        "Need scores for at least two variant backgrounds to build comparisons; "
        f"found {list(scores_df['target'].unique())}"
    )

background_comparison_df = pd.concat(merged_comparisons, ignore_index=True)

# # Save this dataframe and track with github 
# background_comparison_df.to_csv(variant_scores_wide_path, index=False)

## Define plot-wide parameters

Define the parameters that carry through to multiple plots for easy adjustment. 

In [4]:
# Width of the line plot and zoom bar
width = 1000

# Height of the line plot and scatter plot
height = 300

# Range of RBD sites in the data, and spacing between x-axis tick labels on the
# line plot and zoom bar. Cast to int so `range()` gets plain Python ints.
min_site = int(jsd_df["site"].min())
max_site = int(jsd_df["site"].max())
x_axis_spacing = 5

# Legend order and colors for the variant backgrounds. Colors are matched to
# names by position, so the two lists must have the same length.
variant_names = list(variants_names.values())
variant_colors = ["#999999", "#E69F00", "#CC79A7"]
if len(variant_colors) != len(variant_names):
    raise ValueError(
        f"{len(variant_names)} variant backgrounds but {len(variant_colors)} colors; "
        "add or remove a color in `variant_colors`"
    )

# Line-plot tooltips: the site and its epistatic shift.
line_tooltip = [
    alt.Tooltip("site:O", title="RBD Site"),
    alt.Tooltip("JSD_min3bc:Q", title="Epistatic Shift", format=".2f"),
]

# Scatter-plot tooltips: the same four fields, first for the variant (target)
# and then for the comparator (background).
scatter_tooltip = [
    alt.Tooltip("target:N", title="Variant"),
    alt.Tooltip("mutation_target:N", title="Variant Mutation"),
    alt.Tooltip("delta_bind_target:Q", title="Variant Binding Difference", format=".2f"),
    alt.Tooltip("n_bc_bind_target:Q", title="Variant Barcodes"),
] + [
    alt.Tooltip("background:N", title="Comparator"),
    alt.Tooltip("mutation_background:N", title="Comparator Mutation"),
    alt.Tooltip("delta_bind_background:Q", title="Comparator Binding Difference", format=".2f"),
    alt.Tooltip("n_bc_bind_background:Q", title="Comparator Barcodes"),
]



## Define selections for plots 

Define the selection objects that control interaction. Many of these are shared between plots and datasets, so it's helpful to define them up front.

In [5]:
# Dropdown menu to select the comparator background. The options are sorted so
# the menu order is the same on every run (iterating a set of strings is not).
background_dropdown = alt.binding_select(
    options=sorted(jsd_df["background"].unique()),
    name="Select Comparator Background: ",
)
background_selection = alt.selection_single(fields=['background'], bind=background_dropdown, init={"background":"Wuhan-Hu-1"})

# Interactive legend to select the variant background (Omicron BA.1 initially).
# empty='none': when the selection is cleared, nothing matches it.
target_selection = alt.selection_single(fields=["target"], bind="legend", init={"target":"Omicron BA.1"}, empty='none')

# Clicking a point in the line plot selects a (site, target, background) to
# investigate in the scatter plot.
site_selection = alt.selection_single(fields=["site", "target", "background"], empty='none', on="click")

# Brush on the x axis of the zoom bar; the line plot is filtered to the brushed sites.
zoom_selection = alt.selection_interval(encodings=['x'], mark=alt.BrushConfig(stroke='black', strokeWidth=2))


## Define the plot objects

The final plot consists of three main components: a **Zoom Bar**, a **Line Plot** of the epistatic shift at each site, and a **Scatterplot** comparing the binding affinity ($-\log_{10} K_d$) of each amino acid at the selected site in the variant and comparator backgrounds.

### Zoom Bar

In [6]:
## == Zoom bar for the line plot == ##
# A thin strip with one gray cell per site; brushing across it sets zoom_selection.
zoom_bar = (
    alt.Chart(jsd_df[["site"]].drop_duplicates())
    .mark_rect(color="lightgray")
    .encode(
        x=alt.X(
            "site:O",
            title="site zoom bar",
            axis=alt.Axis(values=list(range(min_site, max_site, x_axis_spacing))),
        )
    )
    .add_selection(zoom_selection)
    .properties(width=width, height=15)
)


### Line Plot

In [7]:
## == Interactive line plot for Jensen Shannon Distance (a.k.a. Epistatic Shift) == ##

# Base chart shared by the line and point layers, so the dataframe is not
# stored once per layer.
lineplotbase = alt.Chart(jsd_df).encode(
    x=alt.X(
        "site:O",
        title="RBD site",
        axis=alt.Axis(values=list(range(min_site, max_site, x_axis_spacing))),
    ),
    y=alt.Y("JSD_min3bc:Q", title="Epistatic shift in mutational effects"),
    color=alt.Color(
        "target:N",
        scale=alt.Scale(domain=variant_names, range=variant_colors),
        legend=alt.Legend(orient="top", title="Variant Background: "),
    ),
    # Full opacity for the variant picked in the legend, faded for the rest.
    opacity=alt.condition(target_selection, alt.value(1.0), alt.value(0.1)),
)


# Line layer, with a point drawn at each site.
line = lineplotbase.mark_line(point=True)

# Point layer: carries the tooltips, is limited to the variant picked in the
# legend, and enlarges the clicked site (site_selection drives the scatter plot).
point = (
    lineplotbase
    .mark_point()
    .encode(
        tooltip=line_tooltip,
        size=alt.condition(~site_selection, alt.value(50), alt.value(400)),
    )
    .transform_filter(target_selection)
    .add_selection(site_selection)
)

# Overlay the two layers, then apply what they share: keep only the comparator
# chosen in the dropdown and the sites inside the zoom-bar brush, and register
# the selections.
line_point = (
    alt.layer(line, point)
    .transform_filter(background_selection)
    .transform_filter(zoom_selection)
    .add_selection(target_selection, background_selection, zoom_selection)
    .properties(width=width, height=height)
)


### Scatterplot

In [8]:
## == Scatter Plot for binding affinity between Background and Comparator == ##

# Base chart shared by every scatter-plot layer, so the dataframe is stored once.
# It keeps only the rows matching the clicked site, the variant picked in the
# legend, and the comparator picked in the dropdown.
scatterbase = (
    alt.Chart(background_comparison_df)
    .transform_filter(site_selection)
    .transform_filter(target_selection)
    .transform_filter(background_selection)
)

# One letter per amino acid, placed at its binding affinity in the comparator (x)
# against the variant (y), with both axes fixed to the same range.
aminoacids = scatterbase.mark_text(size=16).encode(
    x=alt.X(
        "bind_background:Q",
        title="Binding Affinity (-log10 Kd) Comparator",
        scale=alt.Scale(domain=[4.5, 12]),
    ),
    y=alt.Y(
        "bind_target:Q",
        title="Binding Affinity (-log10 Kd) Variant",
        scale=alt.Scale(domain=[4.5, 12]),
    ),
    text="mutant:N",
    tooltip=scatter_tooltip,
)

# A dynamic 'title' drawn as a text mark (Altair has no dynamic axes or titles):
# one label per distinct (target, background, site) remaining after the filters,
# anchored at the top-left corner of the plot.
conditions = (
    scatterbase
    .mark_text(align="left", baseline="bottom", fontSize=13, fontWeight="bold")
    .encode(x=alt.value(0.0), y=alt.value(-2), text="_label:N")
    .transform_aggregate(groupby=["target", "background", "site"])
    .transform_calculate(
        _label='"Comparing " + datum.target + " against " + datum.background + " at site " + datum.site'
    )
)

# Dashed red reference rules at the wild-type binding affinity in each background.

# Vertical rule: binding of the comparator's wild-type residue (x axis).
xrule = (
    scatterbase
    .transform_filter("datum.wildtype_background == datum.mutant")
    .mark_rule(color="red", strokeDash=[12, 6], size=2)
    .encode(x="bind_background:Q")
)

# Horizontal rule: binding of the variant's (target's) wild-type residue (y axis).
yrule = (
    scatterbase
    .transform_filter("datum.wildtype_target == datum.mutant")
    .mark_rule(color="red", strokeDash=[12, 6], size=2)
    .encode(y="bind_target:Q")
)
    

# Stack the scatter-plot layers, register the selections they depend on, and
# size the panel as a height x height square.
scatterplot_labeled = (
    alt.layer(aminoacids, conditions, xrule, yrule)
    .add_selection(target_selection, background_selection, site_selection)
    .properties(width=height, height=height)
)
    

## Combine the plots 

Merge the plots into a single interactive plot for publishing on GitHub Pages.

In [9]:
# Pad the rendered chart (mostly on the right), otherwise the scatterplot gets cut off.
alt.renderers.set_embed_options(
    padding=dict(left=25, right=100, bottom=25, top=25)
)

RendererRegistry.enable('default')

In [10]:
# Left column: the line plot stacked above its zoom bar. Right: the scatter plot.
# Shared point size and font sizes are configured once on the top-level chart.
combined_plot = (
    alt.hconcat(
        alt.hconcat(alt.vconcat(line_point, zoom_bar), scatterplot_labeled)
    )
    .configure_point(size=50)
    .configure_axis(labelFontSize=13, titleFontSize=16)
    .configure_legend(labelFontSize=13, titleFontSize=16)
)

combined_plot

In [11]:
# Write the interactive plot to HTML, padding the right side so the
# scatterplot is not clipped in the saved page.
combined_plot.save(
    output_html,
    embed_options=dict(padding=dict(left=25, right=100, bottom=25, top=25)),
)


# END