## Entrenchment analysis

In [1]:
import os

import pandas as pd 
import altair as alt 
import numpy as np
import scipy
import theme

# Register the project theme (defined in the local `theme` module) under the
# name 'main_theme', then make it the active Altair theme. Registration must
# come before enabling.
alt.themes.register('main_theme', theme.main_theme)
alt.themes.enable('main_theme')

# Lift Altair's row-count limit on chart data so the full phenotype table can
# be plotted.
alt.data_transformers.disable_max_rows()

DataTransformerRegistry.enable('default')

In [2]:
# Input paths are provided by Snakemake; the commented-out paths are the
# equivalents for running the notebook by hand.
#phenotypes_df = pd.read_csv('results/h3n2_ha_60y_phenotypes_df.csv')
phenotypes_df = pd.read_csv(snakemake.input.phenotypes_and_freqs)

# site_map = pd.read_csv('../data/site_numbering_map.csv')
site_map = pd.read_csv(snakemake.input.site_numbering_map)

# Attach the site-numbering annotations, drop the join keys that are redundant
# after the merge, then derive per-row flags.
#
# The derived columns are computed with callables so they are evaluated on the
# *merged* frame. Passing a Series built from the pre-merge frame would be
# aligned on index, and the merge produces a fresh index: any row dropped,
# duplicated or reordered by the merge would silently get the wrong flag.
phenotypes_df = pd.merge(
    phenotypes_df,
    site_map,
    left_on=['site', 'sequential_site', 'wildtype', 'region'],
    right_on=['reference_site', 'sequential_site', 'sequential_wt', 'region'],
).drop(
    columns=['sequential_site', 'reference_site', 'sequential_wt']
).assign(
    appeared=lambda x: x['most_recent_fix_date'].notna(),
    stability_measured=lambda x: x['pH stability'].notna(),
    # Anything not labelled "outside RBS" is treated as inside the pocket.
    in_rbs=lambda x: x['rbs_region'].apply(
        lambda r: "Inside receptor binding pocket" if r != "outside RBS" else "Outside receptor binding pocket"
    )
)

In [3]:
def plot_phenotype_vs_date(data, phenotype, phenotype_title, colors, show_x=True, show_rbs_label=True):
    """Scatter a phenotype against the most recent fixation date, faceted by `in_rbs`.

    Only rows where `mutant != wildtype` and `most_recent_fix_date` is present
    are plotted. Each facet is a 300x150 layer of filled circles plus a dashed
    black horizontal rule at y = 0; the two facets share x and y scales.

    The input DataFrame is not modified.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain the columns 'site', 'wildtype', 'mutant', 'region',
        'max_frequency', 'pH stability', 'rbs_region',
        'most_recent_fix_date', 'in_rbs' and the column named by `phenotype`.
    phenotype : str
        Name of the column plotted on the y axis.
    phenotype_title : str or list of str
        Y-axis title (a list is passed through to Altair as given).
    colors : sequence
        Two colors: `colors[0]` for "Inside receptor binding pocket" and
        `colors[1]` for "Outside receptor binding pocket".
    show_x : bool, default True
        Whether to draw the x-axis title, labels, ticks and domain line.
    show_rbs_label : bool, default True
        When False, the facet header labels are blanked out.

    Returns
    -------
    The faceted Altair chart.
    """
    rbs_colors = {
        "Inside receptor binding pocket": colors[0],
        "Outside receptor binding pocket": colors[1],
    }

    # The y domain is taken from the full input, including rows that are
    # filtered out of the plot below.
    y_min = data[phenotype].min()
    y_max = data[phenotype].max()

    # Build a new frame rather than writing the converted column back into
    # the caller's DataFrame, then sort to control which points overlay others.
    data = data.assign(
        most_recent_fix_date=pd.to_datetime(data['most_recent_fix_date'])
    ).sort_values('site', ascending=False)

    # Keep only actual mutations whose amino acid has a fixation date.
    plotted = data[
        (data['mutant'] != data['wildtype'])
        & data['most_recent_fix_date'].notna()
    ]

    x_axis = alt.Axis(
        title=None if not show_x else ["Most recent date when", "amino acid was fixed"],
        labels=show_x,
        ticks=show_x,
        tickCount=5,
        domain=show_x,
        grid=True,
        gridColor='white',
        gridWidth=1.5
    )

    # Create base chart without data specification; the data is attached
    # once, at the facet level.
    base = alt.Chart().encode(
        x=alt.X(
            "most_recent_fix_date:T",
            axis=x_axis
        ),
        y=alt.Y(
            phenotype,
            title=phenotype_title,
            scale=alt.Scale(
                domain=[y_min, y_max]
            ),
            axis=alt.Axis(
                grid=True,
                gridColor='white',
                gridWidth=1.5
            )
        ),
        color=alt.Color(
            "in_rbs",
            scale=alt.Scale(domain=list(rbs_colors.keys()), range=list(rbs_colors.values())),
            title=None,
            legend=None
        ),
        tooltip=[
            'site', 'wildtype', 'mutant', 'region', 'max_frequency', phenotype,
            'pH stability', 'rbs_region', 'most_recent_fix_date'
        ]
    )

    # Dashed reference line at y = 0.
    hline = alt.Chart().mark_rule(
        color='black',
        size=1.25,
        opacity=1.0,
        strokeDash=[6,6]
    ).encode(y=alt.Y(datum=0))

    # Layer everything together first
    complete_layer = alt.layer(
        base.mark_circle(size=60, opacity=1, stroke='black', strokeWidth=0.5),  # scatter plot
        hline,
        # Per-group regression line, currently disabled:
        #base.transform_regression(
        #    'most_recent_fix_date',
        #    phenotype,
        #    groupby=['in_rbs'],
        #).mark_line(size=3),
    ).properties(
        width=300,
        height=150
    )

    header_kwargs = {
        'labelFontSize': 16,
        'labelFontWeight': 'bold',
    }

    if not show_rbs_label:
        # An expression that always yields an empty string hides the labels.
        header_kwargs['labelExpr'] = "''"

    # Apply faceting with data specification
    complete_chart = complete_layer.facet(
        data=plotted,
        facet=alt.Facet(
            'in_rbs',
            title=None,
            header=alt.Header(**header_kwargs)
        ),
        columns=2
    ).resolve_scale(
        y='shared',
        x='shared'
    )

    return complete_chart

In [4]:
# Top panel: effect on cell entry. The x axis is hidden here because the
# stability panel underneath carries it.
entry_chart = plot_phenotype_vs_date(
    data=phenotypes_df,
    phenotype='MDCKSIAT1 cell entry',
    phenotype_title=['Effect on cell entry in', 'MA22 background'],
    colors=['#E41A1C', '#FFC1C3'],
    show_x=False,
)

# Bottom panel: effect on stability. Facet headers are blanked because the
# panel above already shows them.
stability_chart = plot_phenotype_vs_date(
    data=phenotypes_df,
    phenotype='pH stability',
    phenotype_title=['Effect on stability in', 'MA22 background'],
    colors=['#377EB8', '#C6DBEF'],
    show_rbs_label=False,
)

# Stack the panels vertically; each keeps its own color scale while the
# x scale is shared between them.
stacked_panels = entry_chart & stability_chart
entrenchment_chart = (
    stacked_panels
    .resolve_scale(color='independent', x='shared')
    .configure_view(fill='#F1F1F1')  # panel background
)

print(f"Saving {snakemake.output.chart_html=}")
entrenchment_chart.save(snakemake.output.chart_html)

entrenchment_chart