## Analyze stability effects

In [34]:
import pandas as pd
import numpy as np
import altair as alt
import polyclonal
import theme

# Register `theme.main_theme` with Altair under a name, then switch it on.
theme_name = 'main_theme'
alt.themes.register(theme_name, theme.main_theme)
alt.themes.enable(theme_name)

# Remove Altair's limit on the number of data rows a chart may carry.
alt.data_transformers.disable_max_rows()

DataTransformerRegistry.enable('default')

In [25]:
# Load the site-numbering table (sequential/reference site numbers, wild-type
# residue and region labels) and preview the first rows.
site_map_csv = '../data/site_numbering_map.csv'
site_map = pd.read_csv(site_map_csv)
site_map.head()

Unnamed: 0,sequential_site,reference_site,sequential_wt,region,rbs_region
0,1,1,Q,HA1,outside RBS
1,2,2,K,HA1,outside RBS
2,3,3,I,HA1,outside RBS
3,4,4,P,HA1,outside RBS
4,5,5,G,HA1,outside RBS


In [26]:
# Load the averaged per-mutation stability effects and report the row count.
stability_csv = '../results/stability/averages/stability_mut_effect.csv'
stability_data = pd.read_csv(stability_csv)
n_stability = len(stability_data)
print(f'There are {n_stability} stability measurements.')

stability_data.head()

There are 7373 stability measurements.


Unnamed: 0,epitope,site,wildtype,mutant,mutation,stability_mean,stability_median,stability_std,n_models,times_seen,frac_models,LibA-240928-pH,LibB-240928-pH
0,1,1,Q,A,Q1A,0.004237,0.004237,0.04109,2,5.5,1.0,0.03329,-0.02481
1,1,1,Q,C,Q1C,-0.0143,-0.0143,0.01123,2,4.5,1.0,-0.006359,-0.02224
2,1,1,Q,D,Q1D,-0.0219,-0.0219,0.007839,2,5.0,1.0,-0.02744,-0.01636
3,1,1,Q,E,Q1E,0.00689,0.00689,0.01096,2,7.0,1.0,0.01464,-0.000862
4,1,1,Q,F,Q1F,-0.001402,-0.001402,0.006532,2,6.5,1.0,0.003217,-0.006021


In [27]:
# Load the averaged per-mutation cell-entry effects and report the row count.
func_csv = '../results/func_effects/averages/MDCKSIAT1_entry_func_effects.csv'
func_data = pd.read_csv(func_csv)
n_func = len(func_data)
print(f'There are {n_func} cell entry measurements.')

func_data.head()

There are 10401 cell entry measurements.


Unnamed: 0,site,wildtype,mutant,effect,effect_std,times_seen,n_selections
0,1,Q,*,-4.945,0.0,16.25,4
1,1,Q,A,-0.1226,0.2296,7.5,4
2,1,Q,C,-0.5732,0.5667,5.75,4
3,1,Q,D,0.255,0.3448,6.5,4
4,1,Q,E,0.2941,0.0502,9.0,4


In [28]:
# Outer-join the two effect tables on the mutation identity so that mutations
# present in only one table are kept (with NaN for the other table's columns).
# Both inputs have a `times_seen` column; the suffixes tell them apart.
merge_keys = ['site', 'wildtype', 'mutant']
combined_data = stability_data.merge(
    func_data,
    how='outer',
    on=merge_keys,
    suffixes=('_stability', '_func'),
)

combined_data.head()

Unnamed: 0,epitope,site,wildtype,mutant,mutation,stability_mean,stability_median,stability_std,n_models,times_seen_stability,frac_models,LibA-240928-pH,LibB-240928-pH,effect,effect_std,times_seen_func,n_selections
0,1.0,1,Q,A,Q1A,0.004237,0.004237,0.04109,2.0,5.5,1.0,0.03329,-0.02481,-0.1226,0.2296,7.5,4
1,1.0,1,Q,C,Q1C,-0.0143,-0.0143,0.01123,2.0,4.5,1.0,-0.006359,-0.02224,-0.5732,0.5667,5.75,4
2,1.0,1,Q,D,Q1D,-0.0219,-0.0219,0.007839,2.0,5.0,1.0,-0.02744,-0.01636,0.255,0.3448,6.5,4
3,1.0,1,Q,E,Q1E,0.00689,0.00689,0.01096,2.0,7.0,1.0,0.01464,-0.000862,0.2941,0.0502,9.0,4
4,1.0,1,Q,F,Q1F,-0.001402,-0.001402,0.006532,2.0,6.5,1.0,0.003217,-0.006021,-0.7141,0.6042,7.0,4


In [30]:
# Flag every row by whether it has a stability value, then drop rows where the
# mutant residue equals the wild-type residue.
boxplot_df = combined_data.assign(
    stability_measured=lambda df: df['stability_mean'].notna()
).query('mutant != wildtype')

# Axis definitions, kept separate so the chart expression stays short.
entry_effect_y = alt.Y(
    "effect",
    title=(["Effect on cell entry in", "MA22 background"]),
    axis=alt.Axis(tickCount=3),
)
stability_measured_x = alt.X(
    "stability_measured",
    title=(["Stability measured"]),
    axis=alt.Axis(grid=False),
)

# Cell-entry effect distribution, split by presence/absence of a stability value;
# whiskers span the full min-max range.
boxplot = (
    alt.Chart(boxplot_df)
    .mark_boxplot(extent='min-max', color='#b3b3b3', size=40)
    .encode(y=entry_effect_y, x=stability_measured_x)
    .properties(width=150, height=150)
)

boxplot

In [31]:
# Keep rows whose `times_seen_stability` and `n_models` are both at least 2, then
# attach the region annotations from `site_map` (default inner join on the
# reference site number and wild-type residue). The site-map key/numbering
# columns are dropped afterwards because they are no longer needed.
stability_filter = 'times_seen_stability >= 2 and n_models >= 2'
site_map_only_cols = ['sequential_site', 'reference_site', 'sequential_wt']

combined_data_ann = (
    combined_data
    .query(stability_filter)
    .merge(
        site_map,
        left_on=['site', 'wildtype'],
        right_on=['reference_site', 'sequential_wt'],
    )
    .drop(columns=site_map_only_cols)
)
combined_data_ann.head()

Unnamed: 0,epitope,site,wildtype,mutant,mutation,stability_mean,stability_median,stability_std,n_models,times_seen_stability,frac_models,LibA-240928-pH,LibB-240928-pH,effect,effect_std,times_seen_func,n_selections,region,rbs_region
0,1.0,1,Q,A,Q1A,0.004237,0.004237,0.04109,2.0,5.5,1.0,0.03329,-0.02481,-0.1226,0.2296,7.5,4,HA1,outside RBS
1,1.0,1,Q,C,Q1C,-0.0143,-0.0143,0.01123,2.0,4.5,1.0,-0.006359,-0.02224,-0.5732,0.5667,5.75,4,HA1,outside RBS
2,1.0,1,Q,D,Q1D,-0.0219,-0.0219,0.007839,2.0,5.0,1.0,-0.02744,-0.01636,0.255,0.3448,6.5,4,HA1,outside RBS
3,1.0,1,Q,E,Q1E,0.00689,0.00689,0.01096,2.0,7.0,1.0,0.01464,-0.000862,0.2941,0.0502,9.0,4,HA1,outside RBS
4,1.0,1,Q,F,Q1F,-0.001402,-0.001402,0.006532,2.0,6.5,1.0,0.003217,-0.006021,-0.7141,0.6042,7.0,4,HA1,outside RBS


### Mean stability effects across HA

In [32]:
# Only mutations with a cell-entry effect above -3 contribute to the site means.
entry_effect_filter = 'effect > -3'
effect_filtered_data = combined_data_ann.query(entry_effect_filter)

# Per-site mean of `stability_mean`, broadcast back to every row of that site and
# then collapsed to one row per unique combination of the selected columns.
site_level_cols = ['site', 'wildtype', 'mean_stability', 'region', 'rbs_region']
per_site_mean = effect_filtered_data.groupby('site')['stability_mean'].transform('mean')
mean_df = (
    effect_filtered_data
    .assign(mean_stability=per_site_mean)
    [site_level_cols]
    .drop_duplicates()
)

# Rank sites by first appearance; the rank is used to sort the ordinal x axis.
site_to_i = {site: rank for rank, site in enumerate(pd.unique(mean_df['site']))}
mean_df = mean_df.assign(_stat_site_order=mean_df['site'].map(site_to_i))

site_x = alt.X(
    "site:O",
    sort=alt.EncodingSortField(field="_stat_site_order", order="ascending"),
    title='Site',
    axis=alt.Axis(
        labelAngle=0,
        values=[100, 200, 300, 400, 500],
        tickCount=5,
        grid=True,
    ),
)
mean_stability_y = alt.Y(
    "mean_stability:Q",
    title=["Mean effect on", "acid stability"],
    scale=alt.Scale(domain=[-1, 0.25]),
    axis=alt.Axis(grid=False, values=[-1, -0.5, 0]),
)

# Line trace of the per-site mean along the sequence.
chart = (
    alt.Chart(mean_df)
    .mark_line(opacity=1, stroke='#586F7C', size=1)
    .encode(
        site_x,
        mean_stability_y,
        tooltip=['wildtype', "site", "mean_stability", "region"],
    )
    .properties(width=400, height=125)
)

# Dashed reference line at y = 0.
zero_df = pd.DataFrame({'y': [0]})
hline = (
    alt.Chart(zero_df)
    .mark_rule(color='#CC6677', size=1.25, opacity=1, strokeDash=[6,6])
    .encode(y='y:Q')
)

hline + chart

In [35]:
# Sites compared in this panel. The same list drives the row filter and the
# colour-scale domain so the two cannot drift apart; the colour `range` below
# must list one colour per site.
jitter_sites = [165, 167]

# Seed NumPy's global RNG so the horizontal jitter is identical on every run.
np.random.seed(99)
# Add the jitter column via `assign` (which returns a new frame) instead of
# item assignment, so the frame object shown by earlier cells is not modified.
combined_data_ann = combined_data_ann.assign(
    jitter=np.random.normal(0, 0.1, size=len(combined_data_ann))
)

# Rows at the selected sites, excluding wild-type rows and mutations with a
# cell-entry effect of -3 or lower.
jitter_df = combined_data_ann.query(
    'site in @jitter_sites'
).query(
    'effect > -3 and wildtype != mutant'
)

scatter = alt.Chart(
    jitter_df
).mark_circle(size=50, opacity=1, stroke='black', strokeWidth=0.4).encode(
    x=alt.X(
        "site:O",
        title='Site',
        axis=alt.Axis(labelAngle=0)
    ),
    y=alt.Y(
        "stability_mean",
        scale=alt.Scale(domain=[-1, 0.3]),
        title=['Mutation effect on', 'acid stability'],
    ),
    # The xOffset encoding takes the XOffset channel class; alt.X is the class
    # for the positional x channel.
    xOffset=alt.XOffset('jitter:Q'),
    color=alt.Color(
        'site:O',
        legend=None,
        scale=alt.Scale(
            domain=jitter_sites,
            range=['#8DA0CB', '#FFD92F']
        )
    ),
    tooltip=['stability_mean', 'site', 'mutant', 'wildtype']
).properties(
    height=175,
    width=150
)

# Dashed reference line at y = 0, drawn on top of the points.
hline = alt.Chart().mark_rule(
        color='black',
        size=1.5,
        opacity=1.0,
        strokeDash=[6,6]
    ).encode(y=alt.Y(datum=0))

scatter + hline

In [36]:
# Amino-acid charge groups; any residue in neither set is labelled 'Other'.
pos = {'K', 'R', 'H'}
neg = {'D', 'E'}

# Sites shown in this panel.
salt_bridge_sites = [89, 109, 269, 396]

# Classify each row by the last character of its `mutant` value.
combined_data_ann = combined_data_ann.assign(
    charge_class=lambda df: df['mutant'].str[-1].map(
        lambda aa: 'Positive' if aa in pos else 'Negative' if aa in neg else 'Other'
    )
)

# Charged, non-wild-type mutations at the selected sites with a cell-entry
# effect above -3.
salt_bridge_df = combined_data_ann.query(
    'site in @salt_bridge_sites'
).query(
    'effect > -3 and wildtype != mutant'
).query('charge_class != "Other"')

scatter = alt.Chart(
    salt_bridge_df
).mark_circle(size=50, opacity=1, stroke='black', strokeWidth=0.4).encode(
    x=alt.X(
        "site:O",
        title='Site',
        axis=alt.Axis(labelAngle=0)
    ),
    y=alt.Y(
        "stability_mean",
        scale=alt.Scale(domain=[-1, 0.3]),
        title=['Mutation effect on', 'acid stability'],
    ),
    # `charge_class` is an unordered category, so it is encoded as nominal (:N)
    # rather than ordinal; colours are fixed by the explicit domain/range.
    color=alt.Color(
        'charge_class:N',
        legend=None,
        scale=alt.Scale(
            domain=['Positive', 'Negative'],
            range=['#A6D854', '#E78AC3']
        )
    ),
    tooltip=['stability_mean', 'site', 'mutant', 'wildtype', 'charge_class']
).properties(
    height=175,
    width=150
)

# Dashed reference line at y = 0, drawn underneath the points.
hline = alt.Chart().mark_rule(
        color='black',
        size=1.5,
        opacity=1.0,
        strokeDash=[6,6]
    ).encode(y=alt.Y(datum=0))


hline + scatter

### Correlation between cell entry and stability effects

In [37]:
# Pearson correlation between the stability effect and the cell-entry effect,
# computed over rows where the mutant differs from the wild type.
non_wt_rows = combined_data_ann.query('mutant != wildtype')
r_value = non_wt_rows['stability_mean'].corr(non_wt_rows['effect'], method='pearson')

# Formatted label; printed here and reused as a text annotation on the scatter plot.
r_text = f"r = {r_value:.2f}"
print(r_text)

r = 0.37


In [38]:
# All three panels plot the same rows: mutations that differ from the wild type.
non_wt_mutations = combined_data_ann.query('mutant != wildtype')

# Shared styling for the dashed reference lines and for the histogram bars.
zero_rule_style = dict(color='#CC6677', size=1.5, opacity=1.0, strokeDash=[6,6])
hist_bar_style = dict(opacity=1, color='#586F7C')

# Central panel: stability effect (x) against cell-entry effect (y), drawn as
# faint black points so that dense regions appear darker.
base = (
    alt.Chart(non_wt_mutations)
    .mark_circle(size=50, opacity=0.1, stroke=None, strokeWidth=0)
    .encode(
        x=alt.X("stability_mean", title=(["Effect on acid stability"])),
        y=alt.Y("effect", title=(["Effect on cell entry"])),
        color=alt.value('black'),
        tooltip=[
            'site', 'wildtype', 'mutant', 'region', 'rbs_region', 'stability_mean', 'effect'
        ],
    )
    .properties(width=200, height=200)
)

# Reference lines through the origin (y = 0 and x = 0).
hline = alt.Chart().mark_rule(**zero_rule_style).encode(y=alt.Y(datum=0))
vline = alt.Chart().mark_rule(**zero_rule_style).encode(x=alt.X(datum=0))

# Correlation label pinned near the top-left corner using pixel coordinates.
r_label_df = pd.DataFrame({'text': [r_text]})
r_label = (
    alt.Chart(r_label_df)
    .mark_text(
        align='left',
        baseline='top',
        fontSize=16,
        fontWeight='normal',
        color='black',
    )
    .encode(
        text='text:N',
        x=alt.value(5),
        y=alt.value(5),
    )
)

# Points, reference lines and label stacked into one layered panel.
scatter_plot = alt.layer(base, hline, vline, r_label)

# Top margin: histogram of the stability effect (no axis labels or ticks).
hidden_axis = alt.Axis(labels=False, ticks=False)
hist_x = (
    alt.Chart(non_wt_mutations)
    .mark_bar(**hist_bar_style)
    .encode(
        x=alt.X('stability_mean:Q', bin=alt.Bin(maxbins=50), title='', axis=hidden_axis),
        y=alt.Y('count()', title='Count'),
        tooltip=[],
    )
    .properties(width=200, height=50)
)

# Right margin: histogram of the cell-entry effect (no axis labels or ticks).
hist_y = (
    alt.Chart(non_wt_mutations)
    .mark_bar(**hist_bar_style)
    .encode(
        x=alt.X('count()', title='Count'),
        y=alt.Y('effect:Q', bin=alt.Bin(maxbins=50), title='', axis=hidden_axis),
        tooltip=[],
    )
    .properties(width=50, height=200)
)

# Assemble: top histogram above a row holding the scatter and the side histogram.
# NOTE(review): the histograms and the scatter do not set a common scale domain,
# so their axes may not line up exactly - confirm against the rendered figure.
scatter_row = alt.hconcat(scatter_plot, hist_y)
marginal_plot = alt.vconcat(hist_x, scatter_row)

# Display the chart
marginal_plot

### Conservation of sites with destabilizing mutations

In [39]:
# Read the per-site entropy tables for HA1 and HA2 and stack them into a single
# table keyed by `site`.

# Offset added to HA2 positions so that they continue after the HA1 positions
# (presumably the HA1 length - TODO confirm).
ha2_site_offset = 329

# NOTE(review): these paths are relative to the notebook's own directory
# ('data/...'), while the other inputs are read from '../data' and '../results'
# - confirm this location is intended.
ha1_entropy = pd.read_csv(
    'data/nextstrain_groups_blab_flu_seasonal_h3n2_ha1_60y_diversity.tsv', sep = '\t'
)
ha2_entropy = pd.read_csv(
    'data/nextstrain_groups_blab_flu_seasonal_h3n2_ha2_60y_diversity.tsv', sep = '\t'
).assign(position=lambda x: x['position'] + ha2_site_offset)

# `ignore_index=True` gives the stacked table a fresh 0..n-1 index; without it
# the HA2 rows would repeat the index labels of the HA1 rows.
entropy_df = pd.concat(
    [ha1_entropy, ha2_entropy],
    ignore_index=True,
).rename(
    columns={'position': 'site'}
).drop(columns=['gene'])

entropy_df.head()

Unnamed: 0,site,entropy
0,1,0.034
1,2,0.271
2,3,0.719
3,4,0.032
4,5,0.169


In [46]:
# Attach each site's entropy to the effect-filtered mutation table. The left
# join keeps every mutation row; sites missing from `entropy_df` get NaN.
mean_stability_and_entropy = effect_filtered_data.merge(
    entropy_df,
    how='left',
    on='site',
)

mean_stability_and_entropy.head()

Unnamed: 0,epitope,site,wildtype,mutant,mutation,stability_mean,stability_median,stability_std,n_models,times_seen_stability,frac_models,LibA-240928-pH,LibB-240928-pH,effect,effect_std,times_seen_func,n_selections,region,rbs_region,entropy
0,1.0,1,Q,A,Q1A,0.004237,0.004237,0.04109,2.0,5.5,1.0,0.03329,-0.02481,-0.1226,0.2296,7.5,4,HA1,outside RBS,0.034
1,1.0,1,Q,C,Q1C,-0.0143,-0.0143,0.01123,2.0,4.5,1.0,-0.006359,-0.02224,-0.5732,0.5667,5.75,4,HA1,outside RBS,0.034
2,1.0,1,Q,D,Q1D,-0.0219,-0.0219,0.007839,2.0,5.0,1.0,-0.02744,-0.01636,0.255,0.3448,6.5,4,HA1,outside RBS,0.034
3,1.0,1,Q,E,Q1E,0.00689,0.00689,0.01096,2.0,7.0,1.0,0.01464,-0.000862,0.2941,0.0502,9.0,4,HA1,outside RBS,0.034
4,1.0,1,Q,F,Q1F,-0.001402,-0.001402,0.006532,2.0,6.5,1.0,0.003217,-0.006021,-0.7141,0.6042,7.0,4,HA1,outside RBS,0.034


In [49]:
# Per-site mean of `stability_mean`, broadcast to every row of the site and then
# reduced to one row per unique (site, wildtype, mean_stability, entropy).
per_site_stability = (
    mean_stability_and_entropy.groupby('site')['stability_mean'].transform('mean')
)
mean_df_with_wt = (
    mean_stability_and_entropy
    .assign(mean_stability=per_site_stability)
    .loc[:, ['site', 'wildtype', 'mean_stability', 'entropy']]
    .drop_duplicates()
)

# One point per site: mean stability effect (x) against site entropy (y).
base = (
    alt.Chart(mean_df_with_wt)
    .mark_circle(
        size=40,
        opacity=1,
        color='#DADAEB',
        stroke='black',
        strokeWidth=0.2,
    )
    .encode(
        x=alt.X(
            "mean_stability",
            title=(["Mean effect on acid stability"]),
        ),
        y=alt.Y(
            "entropy",
            scale=alt.Scale(domain=[-0.1, 1.7]),
            title=(["Site entropy", "(in natural sequences)"]),
        ),
        tooltip=['site', 'wildtype', 'mean_stability', 'entropy'],
    )
    .properties(width=200, height=200)
)

# Dashed guide lines at entropy = 0.2 (horizontal) and mean stability = -0.1
# (vertical).
guide_style = dict(color='black', size=1, opacity=1.0, strokeDash=[6,6])
hline = alt.Chart().mark_rule(**guide_style).encode(y=alt.Y(datum=0.2))
vline = alt.Chart().mark_rule(**guide_style).encode(x=alt.X(datum=-0.1))

scatter_plot = alt.layer(base, hline, vline)
scatter_plot.display()

In [50]:
# List the sites with entropy above 0.2 and a mean stability effect below -0.1.
high_entropy = mean_df_with_wt['entropy'] > 0.2
low_mean_stability = mean_df_with_wt['mean_stability'] < -0.1
mean_df_with_wt.loc[high_entropy & low_mean_stability, 'site'].tolist()

[196, 202, 219, 223, 227, 450, 452]