In [1]:
import altair as alt
import pandas as pd
import httpimport
# Allow Altair to embed datasets larger than its built-in row limit (the DMS
# heatmap data can exceed it). The return value is bound to `_`, presumably just
# to keep the notebook cell from echoing it.
_ = alt.data_transformers.disable_max_rows()

### Import custom altair theme

In [2]:
# Import custom altair theme from remote github using httpimport module
def import_theme_new():
    """
    Load the ``heatmap_theme`` module from the ``bblarsen-sci/altair_themes``
    GitHub repository (``main`` branch) and register it with Altair.

    The theme is registered under the name ``"custom_theme"`` and enabled
    immediately. The import is performed while the ``httpimport`` repository
    context is active, since that context is what makes the remote module
    importable. Requires network access.
    """
    remote_repo = httpimport.github_repo("bblarsen-sci", "altair_themes", "main")
    with remote_repo:
        import heatmap_theme as remote_theme

        def custom_theme():
            # Delegate to the theme definition shipped in the remote module.
            return remote_theme.heatmap_theme()

        # Equivalent to decorating `custom_theme` with `alt.theme.register(...)`.
        alt.theme.register("custom_theme", enable=True)(custom_theme)


import_theme_new()

In [3]:
def make_empty_df(
    df, 
    amino_acids = ["R","K","H","D","E","Q","N","S","T","Y","W","F","A","I","L","M","V","G","P","C"]
    ):
    """
    Build a complete (site x amino acid) grid and left-join the DMS data onto it.

    DMS data can be missing measurements for some mutations or whole sites. To
    draw a full heatmap, every site from the smallest to the largest site in
    ``df`` (inclusive) is paired with every amino acid in ``amino_acids``. ``df``
    is then left-merged onto that grid on ``site`` and ``mutant``.

    Parameters
    ----------
    df : pandas.DataFrame
        DMS data. It must contain integer-valued ``site`` and ``mutant`` columns.
    amino_acids : list
        Amino acids to enumerate at every site. Rows of ``df`` whose mutant is
        not in this list do not survive the left merge.

    Returns
    -------
    merged_df : pandas.DataFrame
        Rows follow the grid order: by site, then by ``amino_acids`` order. Every
        missing value in the result is replaced by the empty string ``''``. That
        covers grid cells with no matching DMS row as well as NaNs already in
        ``df``, so later plotting filters can test for missing data with string
        comparisons.
    """
    site_values = df['site']
    first_site, last_site = min(site_values), max(site_values)

    # One record per (site, amino acid) pair. Sites vary slowest, so rows are
    # grouped by site and follow the amino_acids order within each site.
    grid_records = [
        {"site": site, "mutant": aa}
        for site in range(first_site, last_site + 1)
        for aa in amino_acids
    ]

    return (
        pd.DataFrame(grid_records)
        .merge(df, on=['site', 'mutant'], how='left')
        .fillna('')
    )

In [4]:
def find_min_max(df):
    """
    Return the smallest and largest site number in ``df``.

    Despite the generic name, only the ``site`` column is inspected. Values in
    any other column are ignored.

    Parameters
    ----------
    df : pandas.DataFrame
        DataFrame with a ``site`` column.

    Returns
    -------
    min_val : scalar
        Smallest value in ``df['site']``.
    max_val : scalar
        Largest value in ``df['site']``.
    """
    site_column = df['site']
    lowest_site = min(site_column)
    highest_site = max(site_column)
    return lowest_site, highest_site


In [5]:
def prepare_data(
    merged_df,
    amino_acids = ["R","K","H","D","E","Q","N","S","T","Y","W","F","A","I","L","M","V","G","P","C"]
):
    """
    Add a ``mutant_rank`` column giving each row's amino acid position for plotting.

    The rank of an amino acid is its index in ``amino_acids``. The heatmap's y
    axis is sorted by this rank. Mutants that are not in ``amino_acids`` get NaN.

    Parameters
    ----------
    merged_df : pandas.DataFrame
        DMS data merged with an empty DataFrame containing all possible
        combinations of sites and amino acids. It must have a ``mutant``
        column. It is not modified.
    amino_acids : list
        Amino acids in the order they should appear on the heatmap.

    Returns
    -------
    prepared_df : pandas.DataFrame
        A new DataFrame equal to ``merged_df`` plus a ``mutant_rank`` column.
    """
    sort_order = {mutant: rank for rank, mutant in enumerate(amino_acids)}
    # `assign` returns a new DataFrame, so the caller's `merged_df` is left untouched.
    prepared_data = merged_df.assign(mutant_rank=merged_df["mutant"].map(sort_order))
    return prepared_data

In [6]:
def prepare_wt_data(
    prepared_df,
    amino_acids = ["R","K","H","D","E","Q","N","S","T","Y","W","F","A","I","L","M","V","G","P","C"]
):
    """
    Build a table with one row per (site, wildtype) pair for drawing wildtype markers.

    The first row for each unique ``(site, wildtype)`` combination is kept.
    ``mutant_rank`` is then overwritten with the rank of the *wildtype* residue
    in ``amino_acids``, so the markers line up with the mutant rows on the same
    y axis. Rows whose ``wildtype`` is the empty string (e.g. placeholder grid
    rows without DMS data) are dropped.

    Parameters
    ----------
    prepared_df : pandas.DataFrame
        DMS data merged with an empty DataFrame containing all possible
        combinations of sites and amino acids. It needs ``site`` and
        ``wildtype`` columns. It is not modified.
    amino_acids : list
        Amino acids in heatmap order. A wildtype's index in this list becomes
        its ``mutant_rank``. Wildtypes not in the list get NaN.

    Returns
    -------
    wildtype_df : pandas.DataFrame
        DataFrame with only the unique wildtypes for each site.
    """
    rank_of = {aa: rank for rank, aa in enumerate(amino_acids)}

    # Copy after de-duplicating so the column assignment below writes to an
    # independent frame rather than a slice of the caller's data.
    deduped = prepared_df.drop_duplicates(subset=["site", "wildtype"]).copy()
    deduped["mutant_rank"] = deduped["wildtype"].map(rank_of)

    has_wildtype = deduped["wildtype"] != ''
    return deduped[has_wildtype]

In [7]:
def split_range(start, end, parts, last_interval_reduction=20):
    """
    Split ``[start, end)`` into ``parts`` consecutive ``(lo, hi)`` intervals.

    The first ``parts - 1`` intervals share one width,
    ``(end - start + last_interval_reduction) // parts``. As a result the final
    interval, which takes whatever remains up to ``end``, is roughly
    ``last_interval_reduction`` shorter than the others. Bounds are clamped to
    ``end``, so when the regular width overshoots, trailing intervals can be
    empty ``(end, end)`` pairs.

    Parameters
    ----------
    start : int
        First position.
    end : int
        Upper bound of the last interval.
    parts : int
        Number of intervals to return. Must be at least 1.
    last_interval_reduction : int
        How much shorter (approximately) the final interval should be.

    Returns
    -------
    intervals : list of tuple
        ``parts`` pairs. Each pair starts where the previous one ended, and the
        last pair always ends at ``end``.

    Raises
    ------
    ValueError
        If ``parts`` is less than 1.
    """
    if parts < 1:
        raise ValueError(f"parts must be at least 1, got {parts}")

    regular_interval = (end - start + last_interval_reduction) // parts

    intervals = []
    current = start
    for _ in range(parts - 1):
        next_point = min(current + regular_interval, end)
        intervals.append((current, next_point))
        current = next_point

    # The last interval absorbs whatever remains up to `end`.
    intervals.append((current, end))

    return intervals

def create_full_ranges(start, end, parts, last_interval_reduction=20):
    """
    Build the per-row lists of site numbers used to wrap the heatmap.

    Every integer site from ``start`` through ``end`` *inclusive* appears in
    exactly one list. The row boundaries come from ``split_range``.

    Parameters
    ----------
    start : int
        First site to plot.
    end : int
        Last site to plot (included).
    parts : int
        Number of heatmap rows. Must be at least 1.
    last_interval_reduction : int
        Passed through to ``split_range``.

    Returns
    -------
    list of list of int
        One list of consecutive sites per heatmap row.
    """
    intervals = split_range(start, end, parts, last_interval_reduction)
    last_idx = len(intervals) - 1
    # split_range yields half-open (lo, hi) pairs whose final bound is `end`
    # itself. Extend only the last one by one so the maximum site is plotted.
    return [
        list(range(lo, hi + 1 if idx == last_idx else hi))
        for idx, (lo, hi) in enumerate(intervals)
    ]

In [8]:
def plot_heatmap(
    prepared_df, 
    wildtype_df,
    full_ranges,
    legend_title = '',
    plot_title = '',
    subtitle_str = '',
    stroke_width = 0.5,
    stroke_color = 'black',
    effect_color_scheme = 'redblue',
    domain = [-4, 2],
    null_color = '#d1d3d4',
):
    """
    Build a wrapped Altair heatmap of DMS effects, one row per entry of ``full_ranges``.

    Each row layers four charts that share the site (x) and amino acid (y)
    encodings:

    1. ``chart_empty``: cells in ``null_color`` for grid entries without a
       measurement (``times_seen == ''``).
    2. ``chart_effect``: cells colored by ``effect`` for measured entries
       (``times_seen != ''``).
    3. ``chart_wildtype_box``: boxes at each site's wildtype residue.
    4. ``chart_wildtype_text``: an "X" drawn on the wildtype boxes.

    The rows are stacked vertically with a shared color scale. The x-axis title
    and the color legend are only drawn on the last row.

    Parameters
    ----------
    prepared_df : pandas.DataFrame
        DMS data prepared for plotting. It needs the columns ``site``,
        ``mutant``, ``mutant_rank``, ``wildtype``, ``effect`` and
        ``times_seen``. Missing values must be encoded as ``''``.
    wildtype_df : pandas.DataFrame
        DataFrame with only the unique wildtypes for each site.
    full_ranges : list
        List of lists of sites. Each inner list becomes one heatmap row.
    legend_title : str
        Title for the legend.
    plot_title : str
        Title for the plot.
    subtitle_str : str
        Subtitle for the plot.
    stroke_width : float
        Width of the cell stroke for the effect and wildtype charts.
    stroke_color : str
        Color of the cell stroke for the effect and wildtype charts.
    effect_color_scheme : str
        Vega color scheme for the effect, e.g. 'blueorange', 'brownbluegreen',
        'purplegreen', 'pinkyellowgreen', 'purpleorange', 'redblue', 'redgrey'
        or 'redyellowblue'.
    domain : list
        Two values ``[min, max]`` for the domain of the color scale. The scale
        midpoint is fixed at 0.
    null_color : str
        Color for the mutants which have no associated DMS measurement.

    Returns
    -------
    combined_chart : altair.VConcatChart
        Vertically concatenated heatmap rows.
    """
    # Loop-invariant pieces shared by every row.
    sites = sorted(prepared_df["site"].unique(), key=float)
    color_scale_effect = alt.Scale(
        scheme=effect_color_scheme,
        domainMid=0,
        # Copy into a fresh list: avoids sharing the default argument object and
        # also accepts tuples.
        domain=list(domain),
    )
    last_idx = len(full_ranges) - 1

    charts = []
    for idx, subset in enumerate(full_ranges):
        subset_df = prepared_df[prepared_df['site'].isin(subset)]
        subset_wt_df = wildtype_df[wildtype_df['site'].isin(subset)]

        # Axis title and legend are only shown on the last (bottom) row.
        is_last = idx == last_idx

        x_axis = alt.Axis(
            labelAngle=-90,
            # Label only sites divisible by 10 to keep the axis readable.
            labelExpr="datum.value % 10 === 0 ? datum.value : ''",
            title="Site" if is_last else None,
            labels=True,
        )

        effect_legend = (
            alt.Legend(
                title=legend_title,
                direction="horizontal",
                gradientLength=100,
                titleAnchor="middle",
                tickCount=4,
                labelAlign="left",
                titleAlign='center',
            )
            if is_last
            else None
        )

        chart_base = alt.Chart(subset_df).encode(
            alt.X("site:O")
                .title("Site")
                .sort(sites)
                .axis(x_axis),
            alt.Y("mutant")
                .title("Amino Acid")
                .sort(alt.EncodingSortField(field="mutant_rank", order="ascending"))
                .axis(alt.Axis(grid=False)),
        ).properties(
            width=alt.Step(10),
            height=alt.Step(10),
        )

        # Background cells for grid entries with no measurement. This is the
        # exact complement of the effect layer's filter below. Python's `&`
        # binds tighter than `==`, so any compound predicate added here must be
        # parenthesised: `(alt.datum.a == x) & (alt.datum.b == y)`.
        chart_empty = chart_base.mark_rect(
            color=null_color,
            strokeWidth=0,
        ).encode(
            tooltip=["site", "mutant"],
        ).transform_filter(
            alt.datum.times_seen == ''
        )

        chart_effect = chart_base.mark_rect(
            strokeWidth=stroke_width,
            stroke=stroke_color,
        ).encode(
            alt.Color("effect:Q")
                .scale(color_scale_effect)
                .legend(effect_legend),
            tooltip=["site", "mutant", "wildtype", 'effect', 'times_seen'],
        ).transform_filter(
            alt.datum.times_seen != ''
        )

        chart_wildtype_box = alt.Chart(subset_wt_df).mark_rect(
            color="white",
            stroke=stroke_color,
            strokeWidth=stroke_width,
        ).encode(
            alt.X("site:O", sort=sites),
            alt.Y("wildtype")
                .sort(alt.EncodingSortField(field="mutant_rank", order="ascending")),
        )

        chart_wildtype_text = alt.Chart(subset_wt_df).mark_text(
            color="black",
            text="X",
            size=8,
            align="center",
            baseline="middle",
            dy=1,
        ).encode(
            alt.X("site:O", sort=sites),
            alt.Y("wildtype")
                .sort(alt.EncodingSortField(field="mutant_rank", order="ascending")),
            tooltip=['site', 'wildtype'],
        )

        chart = alt.layer(
            chart_empty, chart_effect, chart_wildtype_box, chart_wildtype_text
        ).resolve_scale(y='shared', x='shared', color="independent")
        charts.append(chart)

    combined_chart = alt.vconcat(
        *charts,
        spacing=30,
        title=alt.Title(str(plot_title), subtitle=subtitle_str),
    ).resolve_scale(y='shared', x='independent', color='shared')

    return combined_chart



In [None]:
def main():
    """
    Build the Nipah F cell-entry heatmap and save it as SVG, PNG (300 ppi) and HTML.

    Input and output paths come from the ``snakemake`` object: the input is
    ``snakemake.input.entry_df``; the outputs are ``entry_heatmap_svg``,
    ``entry_heatmap_png`` and ``entry_heatmap_html``. ``snakemake`` is not
    defined anywhere in this notebook. Presumably Snakemake injects it when it
    runs the notebook as a rule.
    """
    # user parameters
    data_path = snakemake.input.entry_df
    heatmap_rows = 5

    # Load the DMS table, fill out the site x amino-acid grid, and rank rows.
    dms_df = pd.read_csv(data_path).round(2)
    prepared_df = prepare_data(make_empty_df(dms_df))
    wildtype_df = prepare_wt_data(prepared_df)

    # Split the site span into one list of sites per heatmap row.
    first_site, last_site = find_min_max(prepared_df)
    full_ranges = create_full_ranges(first_site, last_site, heatmap_rows)
    print(full_ranges)

    heatmap = plot_heatmap(
        prepared_df,
        wildtype_df,
        full_ranges,
        legend_title="CHO-bEFNB3 Entry",
        plot_title="Nipah F",
        subtitle_str="Effects of mutations on cell entry relative to the unmutated reference sequence",
        stroke_width=1,
        stroke_color="white",
        domain=[-4, 2],
    )

    outputs = snakemake.output
    heatmap.save(outputs.entry_heatmap_svg)
    heatmap.save(outputs.entry_heatmap_png, ppi=300)
    heatmap.save(outputs.entry_heatmap_html)


if __name__ == "__main__":
    main()