# General Formatting Tool

The goal of this notebook is to test different approaches to building a generalized tool that formats the input data for the `dms-viz` tool as a JSON file.

I think the approach that will dovetail best into the lab's current workflow will be to have two main functions: 

1. A function that takes the components and makes a single JSON file for one experiment.
2. A function that combines JSON files into a single file with the same sitemap.

In [55]:
import pandas as pd
import json

In [101]:
def format_input_json(mut_metric_df,
                      metric_col,
                      sitemap_df,
                      mut_effect_df=None,
                      filter_cols=None,
                      structure=None,
                      included_chains="polymer",
                      excluded_chains="none",
                      alphabet="RKHDEQNSTYWFAILMVGPC-*",
                      colors=('#0072B2', '#CC79A7', '#4C3549', '#009E73')
                     ):
    """
    Take site-level and mutation-level measurements and format them into
    a JSON-serializable dictionary for interactive visualization with `dms-viz`.

    None of the dataframes passed in are modified; the function works on
    copies.

    Parameters
    ----------
    mut_metric_df: pandas.DataFrame
        A dataframe containing site- and mutation-level data for visualization.
        Must contain the columns 'epitope', 'site', 'wildtype', 'mutant',
        'mutation' and `metric_col`.
    metric_col: str
        The name of the column that contains the metric for visualization.
    sitemap_df: pandas.DataFrame
        A dataframe mapping data sites to reference sites to protein sites.
        Must contain the columns 'sequential_site' and 'reference_site'. An
        optional 'protein_site' column may hold numeric strings or "".
    mut_effect_df: pandas.DataFrame or None
        A dataframe of functional effects to join to the main dataframe by mutation.
        Must contain the columns 'wildtype', 'reference_site', 'mutant' and
        'effect'. The join is an inner join, so mutations without a
        functional effect are dropped from the output.
    filter_cols: list or None
        A list of column names to designate as filters in the visualization.
    structure: str or None
        An RCSB PDB ID (i.e. 6UDJ) if not using a custom structure.
    included_chains: str or None
        If not mapping data to every chain, a space separated list of chain names (i.e. "C F M G J P").
        None is treated like the default, "polymer".
    excluded_chains: str or None
        A space separated string of chains that should not be shown on the protein structure (i.e. "B L R").
        None is treated like the default, "none".
    alphabet: str
        The amino acid labels in the order they should be displayed on the heatmap.
    colors: list
        A sequence of colors that will be used for each epitope in the experiment.
        Colors are assigned to epitopes in order of first appearance in
        `mut_metric_df`.

    Returns
    -------
    dict
        A dictionary with the keys 'mut_escape_df', 'sitemap', 'alphabet',
        'pdb', 'dataChains', 'excludeChains', 'epitopes', 'epitope_colors'
        and 'filter_cols'.

    Raises
    ------
    ValueError
        If required columns are missing from any dataframe, if the
        'protein_site' column contains invalid values, if there are more
        epitopes than colors, or if a filter column is not in the data.
    """

    # The docstring allows None for the chain arguments; fall back to the
    # defaults so the later `.split` calls do not fail.
    if included_chains is None:
        included_chains = "polymer"
    if excluded_chains is None:
        excluded_chains = "none"

    metric_columns = set(mut_metric_df.columns)

    # Check that there are reference sites in the mutation data
    if not metric_columns & {'site', 'reference_site'}:
        raise ValueError("The mutation dataframe is missing either the site or reference_site column.")

    # Check that required columns are present in the mutation data
    missing_mutation_columns = {'epitope', 'site', 'wildtype', 'mutant', 'mutation', metric_col} - metric_columns
    if missing_mutation_columns:
        raise ValueError(f"The following columns do not exist in the mutation metric data: {sorted(missing_mutation_columns, key=str)}")

    # Check that required columns are present in the sitemap data
    missing_sitemap_columns = {'sequential_site', 'reference_site'} - set(sitemap_df.columns)
    if missing_sitemap_columns:
        raise ValueError(f"The following columns do not exist in the sitemap: {sorted(missing_sitemap_columns)}")

    # Work on a copy so the caller's sitemap does not gain extra columns
    sitemap_df = sitemap_df.copy()

    # If the protein site isn't specified, assume that it's the same as the reference site
    if 'protein_site' not in sitemap_df.columns:
        sitemap_df['protein_site'] = sitemap_df['reference_site'].apply(lambda y: y if y.isnumeric() else "")
    else:
        # Make sure that the provided column has no invalid values
        if not sitemap_df['protein_site'].apply(lambda y: y == "" or y.isnumeric()).all():
            raise ValueError("The protein_site column of the sitemap contains invalid values.")

    # Add the included chains to the sitemap data if there are any
    sitemap_df['chains'] = sitemap_df['protein_site'].apply(lambda y: included_chains if y.isnumeric() else "")

    # Get the epitopes in order of first appearance, so the epitope -> color
    # assignment is the same on every run (set ordering is not stable for strings).
    epitopes = mut_metric_df['epitope'].unique().tolist()
    if len(epitopes) > len(colors):
        raise ValueError(f"There are {len(epitopes)} epitopes, but only {len(colors)} color(s) specified. Please specify more colors.")
    epitope_colors = dict(zip(epitopes, colors))

    # Join optional columns to the mutation metric data
    if mut_effect_df is not None:
        # Check that the necessary columns are present
        missing_effect_columns = {'wildtype', 'reference_site', 'mutant', 'effect'} - set(mut_effect_df.columns)
        if missing_effect_columns:
            raise ValueError(f"The following columns do not exist in the functional data: {sorted(missing_effect_columns)}")
        # Build the mutation label in a separate frame so the caller's
        # dataframe is left untouched; `astype(str)` also allows integer sites.
        effects = pd.DataFrame({
            'mutation': (mut_effect_df['wildtype']
                         + mut_effect_df['reference_site'].astype(str)
                         + mut_effect_df['mutant']),
            'effect': mut_effect_df['effect'],
        }).drop_duplicates()
        # Inner join: mutations with no functional effect are dropped
        mut_metric_df = pd.merge(mut_metric_df, effects, on='mutation')
        mut_metric_df = mut_metric_df.rename(columns={"effect": "func_effect"})

    # If there are columns to filter by, make sure that these are present in the data.
    # NOTE: the columns are not checked for being numeric.
    if filter_cols:
        missing_filter_columns = set(filter_cols) - set(mut_metric_df.columns)
        if missing_filter_columns:
            raise ValueError(f"The filter column(s): {missing_filter_columns} are not present in the data.")

    # Make a dictionary holding the experiment data
    experiment_dict = {
        # Round-trip through JSON so the records hold only plain Python types
        'mut_escape_df': json.loads(mut_metric_df.to_json(orient='records')),
        'sitemap': sitemap_df.set_index('reference_site').to_dict(orient='index'),
        'alphabet': list(alphabet),
        'pdb': structure,
        'dataChains': included_chains.split(" "),
        'excludeChains': excluded_chains.split(" "),
        'epitopes': epitopes,
        'epitope_colors': epitope_colors,
        'filter_cols': filter_cols
    }

    return experiment_dict

## Test the function

In [102]:
# == Inputs == #

# Load, in order: the escape dataframe, the sitemap dataframe,
# and the functional effect dataframe.
mut_escape_df, sitemap, funceffects = (
    pd.read_csv(csv_path)
    for csv_path in (
        "data/hiv/escape/IDC508_avg.csv",
        "data/hiv/site_numbering_map.csv",
        "data/hiv/muteffects_observed.csv",
    )
)

# Show the distinct epitope labels present in the escape data
set(mut_escape_df["epitope"])

{1, 2}

In [108]:
# Build the experiment dictionary for the loaded inputs, joining the
# functional effects and exposing two columns as filters.
test = format_input_json(
    mut_metric_df=mut_escape_df,
    metric_col="escape_mean",
    sitemap_df=sitemap,
    mut_effect_df=funceffects,
    filter_cols=['times_seen', 'func_effect'],
    structure="6UDJ",
    included_chains="C F M G J P",
    excluded_chains="B L R A Q K",
    alphabet="RKHDEQNSTYWFAILMVGPC-*",
    colors=['#0072B2', '#CC79A7', '#4C3549', '#009E73'],
)

In [99]:
# Key the experiment dictionary by its name
selection_dict = dict(testing=test)

In [100]:
# Serialize the experiments to disk as indented JSON
with open("./test.json", "w") as out:
    out.write(json.dumps(selection_dict, indent=4))