This Python Jupyter notebook plots escape profiles for antibodies and sera.
Import Python modules:
import multiprocessing
import os
import Bio.SeqIO
import dms_variants
import dms_variants.constants
import dms_variants.utils
import dmslogo
from dmslogo.colorschemes import CBPALETTE
import dmslogo.utils
from IPython.display import display, HTML, Image
import matplotlib.cm
import matplotlib.colors
import matplotlib.pyplot as plt
import numpy
import pandas as pd
from plotnine import *
import pdb_prot_align.colorschemes
import yaml
Versions of key software:
print(f"Using `dmslogo` version {dmslogo.__version__}")
print(f"Using `dms_variants` version {dms_variants.__version__}")
Using `dmslogo` version 0.6.2
Using `dms_variants` version 0.8.10
Read the configuration file:
with open('config.yaml') as f:
    config = yaml.safe_load(f)
Create output directory:
os.makedirs(config['escape_profiles_dir'], exist_ok=True)
Extract from configuration what we will use as the site- and mutation-level metrics:
site_metric = config['site_metric']
mut_metric = config['mut_metric']
print(f"At site level, quantifying selection by {site_metric}")
print(f"At mutation level, quantify selection by {mut_metric}")
At site level, quantifying selection by site_total_escape_frac_single_mut
At mutation level, quantify selection by mut_escape_frac_single_mut
Read the sites of "strong escape" for each antibody or serum. These auto-identified sites are what we plot by default for each escape profile:
print(f"Reading sites of strong escape from {config['strong_escape_sites']}")
strong_escape_sites = pd.read_csv(config['strong_escape_sites'])
Reading sites of strong escape from results/escape_profiles/strong_escape_sites.csv
Read the escape fractions.
We only retain the average of the libraries for plotting here, not the individual libraries.
Also, we work in the full-Spike rather than RBD numbering, which means we use label_site as site (and so rename as such below):
print(f"Reading escape fractions from {config['escape_fracs']}")
escape_fracs = (pd.read_csv(config['escape_fracs'])
                .query('library == "average"')
                .drop(columns=['site', 'selection', 'library'])
                .rename(columns={'label_site': 'site'})
                )
print('First few lines of escape-fraction data frame with sample-information added:')
display(HTML(escape_fracs.head().to_html(index=False)))
Reading escape fractions from results/escape_scores/escape_fracs.csv
First few lines of escape-fraction data frame with sample-information added:
condition | site | wildtype | mutation | protein_chain | protein_site | mut_escape_frac_single_mut | site_total_escape_frac_single_mut | site_avg_escape_frac_single_mut | nlibs |
---|---|---|---|---|---|---|---|---|---|
267C_200 | 331 | N | A | E | 331 | 0.003015 | 0.05336 | 0.002808 | 2 |
267C_200 | 331 | N | C | E | 331 | 0.003249 | 0.05336 | 0.002808 | 2 |
267C_200 | 331 | N | D | E | 331 | 0.001829 | 0.05336 | 0.002808 | 2 |
267C_200 | 331 | N | E | E | 331 | 0.004172 | 0.05336 | 0.002808 | 2 |
267C_200 | 331 | N | F | E | 331 | 0.001581 | 0.05336 | 0.002808 | 2 |
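As a minimal sketch of what the filtering and renaming above does, the following uses a small made-up data frame (the `toy` name and all values are invented for illustration and are not taken from the real escape-fraction file). It keeps only the `library == "average"` rows, drops the RBD-numbered `site` column, and promotes `label_site` to `site`:

```python
import pandas as pd

# toy stand-in for the escape-fraction CSV; all values are invented
toy = pd.DataFrame({
    'selection': ['sel1', 'sel1', 'sel2'],
    'library': ['lib1', 'average', 'average'],
    'condition': ['serumA', 'serumA', 'serumB'],
    'site': [1, 1, 2],               # RBD numbering
    'label_site': [331, 331, 332],   # full-Spike numbering
    'mutation': ['A', 'A', 'C'],
})

toy_avg = (toy
           .query('library == "average"')                   # keep only library averages
           .drop(columns=['site', 'selection', 'library'])  # drop RBD-numbered site
           .rename(columns={'label_site': 'site'})          # Spike numbering becomes `site`
           )
print(toy_avg.to_string(index=False))
```

After this step the `site` column holds Spike numbering throughout, which is why later cells can use `site` directly.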
Some sites / mutations are totally missing in the escape_fracs data frame.
For plotting with dmslogo, we need to pad these missing sites to be zero:
# make "padding" data frame covering all conditions, sites, and mutations
first_site = escape_fracs['site'].min()
last_site = escape_fracs['site'].max()
mutations = escape_fracs['mutation'].unique()
pad_df = pd.concat([pd.DataFrame({'condition': condition,
                                  'site': site,
                                  'mutation': mutations})
                    for condition in escape_fracs['condition'].unique()
                    for site in range(first_site, last_site + 1)])
# need to read in wildtype and map site to wildtype
wt_prot = str(Bio.SeqIO.read(config['wildtype_sequence'], 'fasta').seq.translate())
assert len(wt_prot) == last_site - first_site + 1
site_to_wt = {site: wt_prot[site - first_site] for site in range(first_site, last_site + 1)}
for site, wt in escape_fracs.set_index('site')['wildtype'].to_dict().items():
    if wt != site_to_wt[site]:
        raise ValueError(site, wt, site_to_wt[site])
# pad escape fracs data frame
escape_fracs_padded = (
    escape_fracs
    [['condition', 'site', 'mutation', site_metric, mut_metric]]
    .merge(pad_df, how='outer')
    .fillna(0)
    .assign(wildtype=lambda x: x['site'].map(site_to_wt),
            wt_site=lambda x: x['wildtype'] + x['site'].astype(str))
    .assign(**{site_metric: lambda x: x.groupby(['condition', 'site'])[site_metric].transform('max')})
    )
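To make the padding step concrete, here is a small sketch on invented data (the `measured`, `pad`, and `padded` names and the `mut_escape` column are placeholders for this example only). An outer merge against a frame listing every condition, site, and mutation combination adds the missing rows as NaN, which are then filled with zero:

```python
import pandas as pd

# toy measurements: site 2 is missing entirely, and site 3 lacks mutation 'C'
measured = pd.DataFrame({
    'condition': ['serumA', 'serumA', 'serumA'],
    'site': [1, 1, 3],
    'mutation': ['A', 'C', 'A'],
    'mut_escape': [0.1, 0.2, 0.4],
})

# every (condition, site, mutation) combination that should be present
pad = pd.concat([pd.DataFrame({'condition': condition,
                               'site': site,
                               'mutation': measured['mutation'].unique()})
                 for condition in measured['condition'].unique()
                 for site in range(measured['site'].min(), measured['site'].max() + 1)])

# outer merge on the shared columns adds missing rows as NaN; fill those with 0
padded = (measured
          .merge(pad, how='outer')
          .fillna(0)
          .sort_values(['condition', 'site', 'mutation'])
          .reset_index(drop=True)
          )
print(padded)
```

The result has one row per site and mutation, so a logo plot drawn from it shows an empty (zero-height) stack at sites that had no measurements rather than skipping them.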
We add columns to the data frame that enable coloring of the logo plots by the deep mutational scanning measurements of binding or expression. We choose a color scheme that spans the min and maximum values for all letters with non-zero height (the mutation-level metric > 0). We also write out the scale bar for this coloring.
Importantly, note that as long as clip_vals_gt_0 is set to True, all DMS values greater than 0 (beneficial mutations) are set to zero. For publications, the scale bar should be manually edited to change "0" to ">0" to reflect this fact.
mut_bind_expr_file = config['final_variant_scores_mut_file']
clip_vals_gt_0 = True  # plot DMS values > 0 as 0 (grouping beneficial and neutral)
print(f"Reading DMS data from {mut_bind_expr_file}")
# read DMS data flagging mutations with escape > 0
mut_bind_expr = (
    pd.read_csv(mut_bind_expr_file)
    .query('target==@config["primary_target"]')
    [['position', 'mutant', 'delta_bind', 'delta_expr']]
    .rename(columns={'position': 'site',
                     'mutant': 'mutation',
                     'delta_bind': 'bind',
                     'delta_expr': 'expr'})
    # flag mutations with mutation escape > 0
    .merge(escape_fracs_padded, how='right', validate='one_to_many', on=['site', 'mutation'])
    .assign(escape_gt_0=lambda x: x[mut_metric] > 0)
    .groupby(['site', 'mutation', 'bind', 'expr'])
    .aggregate(escape_gt_0=pd.NamedAgg('escape_gt_0', 'any'))
    .reset_index()
    .drop_duplicates()
    )
# add color for each mutation, coloring those without escape > 0 as white
for prop in ['bind', 'expr']:
    # set up color scale and draw scale bar
    min_prop = mut_bind_expr.query('escape_gt_0')[prop].min()
    if clip_vals_gt_0:
        mut_bind_expr[prop] = numpy.clip(mut_bind_expr[prop], None, 0)
    max_prop = mut_bind_expr.query('escape_gt_0')[prop].max()
    # get YlOrBr color map (https://matplotlib.org/3.1.0/tutorials/colors/colormaps.html),
    # but start 20% in so not too faint: https://stackoverflow.com/a/18926541
    nsegments = 256
    cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
        name='trunc_YlOrBr',
        colors=matplotlib.cm.get_cmap('YlOrBr', lut=nsegments)(numpy.linspace(0.2, 1, nsegments))
        )
    colormap = pdb_prot_align.colorschemes.ValueToColorMap(
        minvalue=min_prop,
        maxvalue=max_prop,
        cmap=cmap,
        )
    for orientation in ['horizontal', 'vertical']:
        scalebar_file = os.path.join(config['escape_profiles_dir'], f"{prop}_scalebar_{orientation}.pdf")
        print(f"\n{prop} ranges from {min_prop} to {max_prop}, here is the scale bar, which is being saved to {scalebar_file}")
        fig, _ = colormap.scale_bar(orientation=orientation,
                                    label={'bind': 'ACE2 binding',
                                           'expr': 'RBD expression'}[prop])
        fig.savefig(scalebar_file, bbox_inches='tight')
        fig.savefig(os.path.splitext(scalebar_file)[0] + '.svg', bbox_inches='tight')
        display(fig)
        plt.close(fig)
    # add color to data frame of DMS data
    mut_bind_expr[f"{prop}_color"] = mut_bind_expr.apply(
        lambda r: colormap.val_to_color(r[prop]) if r['escape_gt_0'] else 'white',
        axis=1)
    # save to file the color scheme
    print(f"Saving DMS color scheme to {config['escape_profiles_dms_colors']}")
    (mut_bind_expr
     .query('escape_gt_0')
     .drop(columns='escape_gt_0')
     .to_csv(config['escape_profiles_dms_colors'], index=False)
     )
# add DMS coloring to escape fractions data frame
escape_fracs_padded = (
    escape_fracs_padded
    .drop(columns=['bind_color', 'expr_color'], errors='ignore')
    .merge(mut_bind_expr[['site', 'mutation', 'bind_color', 'expr_color']],
           how='left',
           validate='many_to_one',
           on=['site', 'mutation'])
    # assign colors that are NaN and have 0 height to be white
    .assign(bind_color=lambda x: x['bind_color'].where((x[mut_metric] != 0) | x['bind_color'].notnull(),
                                                       'white'),
            expr_color=lambda x: x['expr_color'].where((x[mut_metric] != 0) | x['expr_color'].notnull(),
                                                       'white'),
            )
    )
# check no letters have NaN colors
nan_color = (
    escape_fracs_padded
    .query('bind_color.isnull() | expr_color.isnull()')
    )
if len(nan_color):
    raise ValueError(f"The following entries lack colors:\n{nan_color}")
Reading DMS data from results/final_variant_scores/final_variant_scores.csv
bind ranges from -1.83171 to 0.0, here is the scale bar, which is being saved to results/escape_profiles/bind_scalebar_horizontal.pdf
bind ranges from -1.83171 to 0.0, here is the scale bar, which is being saved to results/escape_profiles/bind_scalebar_vertical.pdf
Saving DMS color scheme to results/escape_profiles/escape_profiles_dms_colors.csv
expr ranges from -0.74994 to 0.0, here is the scale bar, which is being saved to results/escape_profiles/expr_scalebar_horizontal.pdf
expr ranges from -0.74994 to 0.0, here is the scale bar, which is being saved to results/escape_profiles/expr_scalebar_vertical.pdf
Saving DMS color scheme to results/escape_profiles/escape_profiles_dms_colors.csv
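Both scale bars above top out at 0.0 because of the clipping. The sketch below illustrates that effect on invented numbers: `numpy.clip` with an upper bound of 0 and no lower bound, followed by a plain linear rescaling to the range 0 to 1. The rescaling is only an illustration of where each value would sit on a color scale; it is not the `ValueToColorMap` implementation.

```python
import numpy

# invented DMS effects: negative = deleterious, positive = beneficial
delta_bind = numpy.array([-1.8, -0.4, 0.0, 0.15, 0.6])

# upper bound of 0, no lower bound: beneficial mutations are grouped with neutral ones
clipped = numpy.clip(delta_bind, None, 0)

# position of each clipped value between the minimum and maximum, as a fraction from 0 to 1
# (a plain linear rescaling, used here only to show the effect of clipping)
frac = (clipped - clipped.min()) / (clipped.max() - clipped.min())

print(clipped)
print(frac)
```

All values that were above zero land at the top of the scale together with the neutral value, which is why the scale-bar label "0" really means ">= 0" when clipping is on.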
We have manually specified configurations for escape profiles in a YAML file:
print(f"Reading escape-profile configuration from {config['escape_profiles_config']}")
with open(config['escape_profiles_config']) as f:
    escape_profiles_config = yaml.safe_load(f)
print(f"Reading the site color schemes from {config['site_color_schemes']}")
site_color_schemes = pd.read_csv(config['site_color_schemes'])
Reading escape-profile configuration from data/escape_profiles_config.yaml
Reading the site color schemes from data/site_color_schemes.csv
Now start making these plots. We will do this with a multiprocessing pool, so first define the function that makes each actual plot:
def draw_profile(tup):
    """Takes single argument `tup` which is a tuple with the following elements:
      - df: pandas.DataFrame with data of interest.
      - site_metric: site-level metric
      - show_color: highlighting color in site-level plots
      - draw_line_plot: do we draw line plot?
      - mut_metric: mutation-level metric
      - color: column with logo plot color
      - dmslogo_facet_plot_kwargs: other kwargs
      - dmslogo_draw_logo_kwargs: any kwargs for draw_logo
      - dmslogo_draw_line_kwargs: any kwargs for draw_line
      - pdffile: name of created PDF
      - pngfile: name of created PNG
      - svgfile: name of created SVG
    Returns 2-tuple `(pdffile, pngfile)`
    """
    (df, site_metric, show_color, draw_line_plot, mut_metric, color, dmslogo_facet_plot_kwargs,
     dmslogo_draw_logo_kwargs, dmslogo_draw_line_kwargs, pdffile, pngfile, svgfile) = tup
    fig, ax = dmslogo.facet_plot(
        data=df,
        x_col='site',
        show_col='to_show',
        gridrow_col='condition',
        share_ylim_across_rows=False,
        draw_line_kwargs=(dict({'height_col': site_metric,
                                'ylabel': 'escape fraction',
                                'widthscale': 1.3,
                                'show_color': show_color
                                },
                               **dmslogo_draw_line_kwargs)
                          if draw_line_plot else None
                          ),
        draw_logo_kwargs=dict({'letter_height_col': mut_metric,
                               'letter_col': 'mutation',
                               'ylabel': 'escape fraction',
                               'color_col': color,
                               'xtick_col': 'wt_site',
                               'xlabel': 'site',
                               'shade_color_col': 'shade_color' if 'shade_color' in df.columns else None,
                               'shade_alpha_col': 'shade_alpha' if 'shade_alpha' in df.columns else None
                               },
                              **dmslogo_draw_logo_kwargs,
                              ),
        share_xlabel=True,
        share_ylabel=True,
        **dmslogo_facet_plot_kwargs,
        )
    fig.savefig(pdffile, dpi=300, transparent=True, bbox_inches='tight')
    fig.savefig(pngfile, dpi=300, transparent=True, bbox_inches='tight')
    fig.savefig(svgfile, transparent=True, bbox_inches='tight')
    plt.close(fig)
    return pdffile, pngfile
Now make a list of the input tuples we pass to the draw_profile function above for everything we want to draw:
draw_profile_tups = []
for name, specs in escape_profiles_config.items():
    # get data frame with just the conditions we want to plot, also re-naming them
    conditions_to_plot = list(specs['conditions'].keys())
    assert len(conditions_to_plot) == len(set(specs['conditions'].values()))
    assert set(conditions_to_plot).issubset(set(escape_fracs_padded['condition']))
    df = (escape_fracs_padded
          .query('condition in @conditions_to_plot')
          .assign(condition=lambda x: x['condition'].map(specs['conditions']))
          )
    # see if we are only plotting single-nt available mutations
    if 'single_nt_only' in specs:
        if not numpy.allclose(df[site_metric],
                              df.groupby(['condition', 'site'])[mut_metric].transform('sum'),
                              atol=1e-3,
                              rtol=1e-3):
            raise ValueError('single_nt_only plotting requires re-computation of the site metric, so it must be the '
                             'simple sum of the mutation metric')
        nt_seq = str(Bio.SeqIO.read(specs['single_nt_only'], 'genbank').seq)
        df = (
            df
            .assign(codon=lambda x: x['site'].map(lambda r: nt_seq[3 * (r - 1): 3 * r]),
                    codon_aa=lambda x: x['codon'].map(dms_variants.constants.CODON_TO_AA),
                    single_nt_accessible=lambda x: x.apply(
                        lambda row: dms_variants.utils.single_nt_accessible(row['codon'],
                                                                            row['mutation'],
                                                                            'false'),
                        axis=1)
                    )
            )
        # check nucleotide sequence encodes correct amino-acid sequence
        if any(df['wildtype'] != df['codon_aa']):
            raise ValueError('nucleotide sequence differs from DMS RBD amino-acid sequence')
        # subset on mutations of interest, re-compute site metric
        df = df.query('single_nt_accessible')
        df = df.assign(**{site_metric: lambda x: x.groupby(['condition', 'site'])[mut_metric].transform('sum')})
    # specify order to plot
    df = df.assign(condition=lambda x: pd.Categorical(x['condition'], specs['conditions'].values(), ordered=True))
    # get the sites we want to show in logo plots
    sites_to_show = []
    if specs['plot_auto_identified_sites']:
        threshold = specs['plot_auto_identified_sites']
        if threshold not in strong_escape_sites['threshold'].unique():
            raise ValueError(f"invalid `plot_auto_identified_sites` of {threshold}\n"
                             f"valid values are: {strong_escape_sites['threshold'].unique()}")
        sites_to_show += (strong_escape_sites
                          .query('condition in @conditions_to_plot')
                          .query('threshold == @threshold')
                          ['site']
                          .unique()
                          .tolist()
                          )
    sites_to_show += specs['add_sites']
    sites_to_show = set(sites_to_show) - set(specs['exclude_sites'])
    df = df.assign(to_show=lambda x: x['site'].isin(sites_to_show))
    # is there a site color scheme?
    color_col = None
    if 'site_color_scheme' in specs:
        color_col = 'color'
        if specs['site_color_scheme'] in site_color_schemes.columns:
            # color scheme specified by site
            site_colors = site_color_schemes.set_index('site')[specs['site_color_scheme']].to_dict()
            df = df.assign(color=lambda x: x['site'].map(site_colors))
        else:
            # color specified for all sites
            df = df.assign(color=specs['site_color_scheme'])
    # is there a mutation color specification?
    if 'mutation_colors' in specs:
        assert 'site_color_scheme' in specs, 'must specify site-color scheme for mutation colors'
        def mut_color(row):
            key = f"{row['mutation']}{row['site']}"
            if key in specs['mutation_colors']:
                return specs['mutation_colors'][key]
            else:
                return row[color_col]
        df = df.assign(color=lambda x: x.apply(mut_color, axis=1))
    # shade any sites
    if 'shade_sites' in specs:
        shade_records = []
        for condition, shade_sites in specs['shade_sites'].items():
            for site, (shade_color, shade_alpha) in shade_sites.items():
                shade_records.append((condition, site, shade_color, shade_alpha))
        shade_df = pd.DataFrame.from_records(shade_records,
                                             columns=['condition', 'site', 'shade_color', 'shade_alpha'])
        df = df.merge(shade_df,
                      on=['condition', 'site'],
                      how='left',
                      validate='many_to_one',
                      )
    # get any additional logo plot arguments
    if 'dmslogo_facet_plot_kwargs' in specs:
        dmslogo_facet_plot_kwargs = specs['dmslogo_facet_plot_kwargs']
    else:
        dmslogo_facet_plot_kwargs = {}
    # get y-axis limits, see here: https://jbloomlab.github.io/dmslogo/set_ylims.html
    if 'escape_profile_ymax' in specs:  # specific y-max set for this plot
        escape_profile_ymax_quantile = specs['escape_profile_ymax']['quantile']
        escape_profile_ymax_frac = specs['escape_profile_ymax']['frac']
        if 'min_ymax' in specs['escape_profile_ymax']:
            escape_profile_min_ymax = specs['escape_profile_ymax']['min_ymax']
        else:
            escape_profile_min_ymax = None
    else:  # use default in config
        escape_profile_ymax_quantile = config['escape_profile_ymax']['quantile']
        escape_profile_ymax_frac = config['escape_profile_ymax']['frac']
        if 'min_ymax' in config['escape_profile_ymax']:
            escape_profile_min_ymax = config['escape_profile_ymax']['min_ymax']
        else:
            escape_profile_min_ymax = None
    ylim_setter = dmslogo.utils.AxLimSetter(max_from_quantile=(escape_profile_ymax_quantile,
                                                               escape_profile_ymax_frac),
                                            datalim_pad=0.06,
                                            min_upperlim=escape_profile_min_ymax)
    ylim_setter_nopad = dmslogo.utils.AxLimSetter(max_from_quantile=(escape_profile_ymax_quantile,
                                                                     escape_profile_ymax_frac),
                                                  datalim_pad=0,
                                                  min_upperlim=escape_profile_min_ymax)
    ylims = {}
    ylims_nopad = {}  # unpadded are written to file giving min / max of logos
    for condition, condition_df in df.groupby('condition'):
        ylims[condition] = ylim_setter.get_lims(condition_df
                                                [['site', site_metric]]
                                                .drop_duplicates()
                                                [site_metric]
                                                )
        ylims_nopad[condition] = ylim_setter_nopad.get_lims(condition_df
                                                            [['site', site_metric]]
                                                            .drop_duplicates()
                                                            [site_metric]
                                                            )
    if 'set_ylims' not in dmslogo_facet_plot_kwargs:  # do not overwrite manual y-limits
        dmslogo_facet_plot_kwargs['set_ylims'] = ylims
    else:
        ylims_nopad = dmslogo_facet_plot_kwargs['set_ylims']
    # write the ylimits
    ylims_csv = os.path.join(config['escape_profiles_dir'], f"{name}_stackedlogo_ylims.csv")
    (
        pd.DataFrame.from_dict(ylims_nopad, orient='index', columns=['minimum', 'maximum'])
        .rename_axis('condition')
        .to_csv(ylims_csv, float_format='%.3f')
        )
    # draw plot for each color scheme
    colors_plotfiles = [(color_col, os.path.join(config['escape_profiles_dir'], f"{name}_stackedlogo.pdf"))]
    if 'color_by_dms' in specs and specs['color_by_dms']:
        colors_plotfiles += [('bind_color', os.path.join(config['escape_profiles_dir'], f"{name}_color_by_bind_stackedlogo.pdf")),
                             ('expr_color', os.path.join(config['escape_profiles_dir'], f"{name}_color_by_expr_stackedlogo.pdf"))]
    for color, pdffile in colors_plotfiles:
        pngfile = os.path.splitext(pdffile)[0] + '.png'
        svgfile = os.path.splitext(pdffile)[0] + '.svg'
        draw_profile_tups.append((
            df, site_metric, CBPALETTE[-1],
            'draw_line_plot' not in specs or specs['draw_line_plot'],
            mut_metric, color,
            dmslogo_facet_plot_kwargs,
            specs['dmslogo_draw_logo_kwargs'] if 'dmslogo_draw_logo_kwargs' in specs else {},
            specs['dmslogo_draw_line_kwargs'] if 'dmslogo_draw_line_kwargs' in specs else {},
            pdffile, pngfile, svgfile
            ))
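The single_nt_only branch above relies on `dms_variants.utils.single_nt_accessible` to decide whether an amino-acid mutation can be reached by changing one nucleotide of the wildtype codon. As a rough illustration of that idea only, here is a standalone sketch. The helper `toy_single_nt_accessible`, its codon table, and the two-codon sequence are written for this example and are not the library's implementation; the choice to return `False` when the codon already encodes the amino acid is meant to mirror the `'false'` argument used above, as far as I understand that option.

```python
import itertools

BASES = 'TCAG'
# standard genetic code, with codons enumerated in TCAG order
AAS = 'FFLLSSSSYY**CC*WLLLLPPPPHHQQRRRRIIIMTTTTNNKKSSRRVVVVAAAADDEEGGGG'
CODON_TO_AA = {''.join(c): aa
               for c, aa in zip(itertools.product(BASES, repeat=3), AAS)}


def toy_single_nt_accessible(codon, aa):
    """Hypothetical helper: can `aa` be reached from `codon` by one nucleotide change?"""
    if CODON_TO_AA[codon] == aa:
        return False  # codon already encodes `aa`, so there is no mutation to make
    for i, nt in itertools.product(range(3), BASES):
        if nt != codon[i]:
            neighbor = codon[:i] + nt + codon[i + 1:]
            if CODON_TO_AA[neighbor] == aa:
                return True
    return False


# sites are 1-based, so site r spans nucleotides 3 * (r - 1) up to (not including) 3 * r
nt_seq = 'AATGGT'  # invented two-codon sequence: AAT (N), GGT (G)
for site in [1, 2]:
    codon = nt_seq[3 * (site - 1): 3 * site]
    print(site, codon, CODON_TO_AA[codon],
          toy_single_nt_accessible(codon, 'K'),
          toy_single_nt_accessible(codon, 'D'))
```

Because filtering to accessible mutations removes letters from each stack, the site-level metric has to be recomputed as the sum of the remaining mutation-level values, which is what the loop above checks for and then does.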
Now use a multiprocessing pool to map the plotting function over the list of tuple arguments:
ncpus = min(config['max_cpus'], multiprocessing.cpu_count())
print(f"Drawing {len(draw_profile_tups)} profiles using {ncpus} CPUs...")
with multiprocessing.Pool(ncpus) as pool:
    for i, (tup, (pdffile, pngfile)) in enumerate(zip(draw_profile_tups,
                                                      pool.imap(draw_profile, draw_profile_tups)),
                                                  start=1):
        print(f"\nPlotted profile {i} to:\n {pdffile}\n {pngfile}.")
        draw_line_plot = tup[3]
        assert isinstance(draw_line_plot, bool)
        display(Image(pngfile, width=500 * (1 + int(draw_line_plot))))
Drawing 12 profiles using 8 CPUs...
Plotted profile 1 to:
results/escape_profiles/Delta_breakthrough_stackedlogo.pdf
results/escape_profiles/Delta_breakthrough_stackedlogo.png.
Plotted profile 2 to:
results/escape_profiles/Delta_breakthrough_color_by_bind_stackedlogo.pdf
results/escape_profiles/Delta_breakthrough_color_by_bind_stackedlogo.png.
Plotted profile 3 to:
results/escape_profiles/Delta_breakthrough_color_by_expr_stackedlogo.pdf
results/escape_profiles/Delta_breakthrough_color_by_expr_stackedlogo.png.
Plotted profile 4 to:
results/escape_profiles/Pfizer_stackedlogo.pdf
results/escape_profiles/Pfizer_stackedlogo.png.
Plotted profile 5 to:
results/escape_profiles/Pfizer_color_by_bind_stackedlogo.pdf
results/escape_profiles/Pfizer_color_by_bind_stackedlogo.png.
Plotted profile 6 to:
results/escape_profiles/Pfizer_color_by_expr_stackedlogo.pdf
results/escape_profiles/Pfizer_color_by_expr_stackedlogo.png.
Plotted profile 7 to:
results/escape_profiles/primary_Delta_stackedlogo.pdf
results/escape_profiles/primary_Delta_stackedlogo.png.
Plotted profile 8 to:
results/escape_profiles/primary_Delta_color_by_bind_stackedlogo.pdf
results/escape_profiles/primary_Delta_color_by_bind_stackedlogo.png.
Plotted profile 9 to:
results/escape_profiles/primary_Delta_color_by_expr_stackedlogo.pdf
results/escape_profiles/primary_Delta_color_by_expr_stackedlogo.png.
Plotted profile 10 to:
results/escape_profiles/all_Delta_lib_stackedlogo.pdf
results/escape_profiles/all_Delta_lib_stackedlogo.png.
Plotted profile 11 to:
results/escape_profiles/all_Delta_lib_color_by_bind_stackedlogo.pdf
results/escape_profiles/all_Delta_lib_color_by_bind_stackedlogo.png.
Plotted profile 12 to:
results/escape_profiles/all_Delta_lib_color_by_expr_stackedlogo.pdf
results/escape_profiles/all_Delta_lib_color_by_expr_stackedlogo.png.
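The plotting loop pairs each input tuple with its result by zipping the inputs against `pool.imap`, which works because `imap` yields results in input order. The sketch below shows that pattern in isolation. It is an illustration under stated assumptions: `toy_draw` and `toy_tups` are invented stand-ins, and it uses the thread-based `multiprocessing.dummy.Pool` (same interface as `multiprocessing.Pool`) so it runs anywhere without worker-process setup, whereas the notebook uses real processes.

```python
import multiprocessing.dummy  # thread-based pool with the same interface as multiprocessing.Pool


def toy_draw(tup):
    """Hypothetical stand-in for a plotting function: returns the output file names."""
    name, draw_line_plot = tup
    return f"{name}.pdf", f"{name}.png"


toy_tups = [('profileA', True), ('profileB', False), ('profileC', True)]

results = []
with multiprocessing.dummy.Pool(2) as pool:
    # imap yields results in input order, so zip pairs each input with its own output
    for i, (tup, (pdffile, pngfile)) in enumerate(zip(toy_tups,
                                                      pool.imap(toy_draw, toy_tups)),
                                                  start=1):
        results.append((i, tup[0], pngfile))
        print(f"Plotted profile {i} ({tup[0]}) to {pdffile} and {pngfile}")
```

Packing all arguments into one tuple, as draw_profile does, is what lets a single-argument mapping call like `imap` drive a function that needs many inputs.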