# Homolog BinaryMap Encoding

## Import `Python` modules

In [3]:
import pandas as pd
import re
import json
import binarymap as bmap
import matplotlib.pyplot as plt
import seaborn as sns
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax.experimental import sparse
from jaxopt import ProximalGradient
import jaxopt
import numpy as onp
from scipy.stats import pearsonr

from collections import defaultdict
from timeit import default_timer as timer

from tqdm.auto import tqdm
tqdm.pandas()
%matplotlib inline

## Strategy for converting mutations to be relative to the reference homolog

In [4]:
# Toy aligned wildtype sequences for two homologs, keyed by homolog name.
# "-" appears to mark an alignment gap (homolog "1" has no residue at
# site 2). An earlier draft also noted a reference sequence of "MT".
homologs = dict(zip(("1", "2"), ("M-G", "MPP")))

In [5]:
# Small hand-made dataset: homolog label, space-separated substitutions
# (numbered by alignment site) and a log2 functional score per variant.
test_dict = dict(
    homolog=["1"] * 4 + ["2"] * 6,
    variant=[
        "M1E", "G3R", "G3P", "M1W",
        "M1E", "P3R", "P3G", "M1E P3G", "M1E P3R", "P2T",
    ],
    log2E=[2, -7, -0.5, 2.3, 1, -5, 0.4, 2.7, -2.7, 0.3],
)
test_df = pd.DataFrame(test_dict)
test_df

Unnamed: 0,homolog,variant,log2E
0,1,M1E,2.0
1,1,G3R,-7.0
2,1,G3P,-0.5
3,1,M1W,2.3
4,2,M1E,1.0
5,2,P3R,-5.0
6,2,P3G,0.4
7,2,M1E P3G,2.7
8,2,M1E P3R,-2.7
9,2,P2T,0.3


In [8]:
# A single substitution token such as "M1E" or "P3G": one wildtype
# character, a site label made of word characters, then one mutant
# character (a word character or "*").
_SUB_PATTERN = re.compile(r'(?P<aawt>\w)(?P<site>[\d\w]+)(?P<aamut>[\w\*])')


def _split_subs(subs_string):
    """Parse a whitespace-separated substitution string into parallel lists.

    Parameters
    ----------
    subs_string : str
        Zero or more substitutions separated by whitespace, e.g. "M1E P3G".

    Returns
    -------
    tuple of (list of str, list of str, list of str)
        Wildtype amino acids, site labels (kept as strings) and mutant
        amino acids, one entry per substitution in input order. An empty
        string yields three empty lists.

    Raises
    ------
    ValueError
        If a token contains no match for the substitution pattern.
    """
    wts, sites, muts = [], [], []
    for sub in subs_string.split():
        # `search` is unanchored: any token that *contains* a match is
        # accepted, and the first match is used.
        match = _SUB_PATTERN.search(sub)
        if match is None:
            raise ValueError(f"Could not parse substitution {sub!r}")
        wts.append(match.group("aawt"))
        sites.append(match.group("site"))
        muts.append(match.group("aamut"))
    return wts, sites, muts


def _infer_site_map(fs_df, homolog_name_col):
    """Infer each homolog's wildtype amino acid at every mutated site.

    Parameters
    ----------
    fs_df : pandas.DataFrame
        Must contain ``homolog_name_col`` plus the parsed ``wts`` and
        ``sites`` list columns produced by `_split_subs`.
    homolog_name_col : str
        Column identifying the homolog of each variant.

    Returns
    -------
    pandas.DataFrame
        Indexed by site label, in order of first appearance, with one
        column per homolog in ``groupby`` order. A cell is NaN when no
        variant of that homolog is mutated at that site. A homolog with
        no substituted variants therefore gets an all-NaN column.

    Raises
    ------
    ValueError
        If two variants of the same homolog report different wildtype
        amino acids at the same site.
    """
    wt_by_hom = {}
    site_order = {}  # dict keys used as an insertion-ordered set
    for hom, hom_df in fs_df.groupby(homolog_name_col):
        hom_wts = wt_by_hom.setdefault(hom, {})
        for wts, sites in zip(hom_df["wts"], hom_df["sites"]):
            for wt, site in zip(wts, sites):
                known_wt = hom_wts.setdefault(site, wt)
                if known_wt != wt:
                    # Inconsistent wildtypes would silently corrupt the
                    # conversion to reference-relative substitutions.
                    raise ValueError(
                        f"Homolog {hom!r} has conflicting wildtype amino "
                        f"acids at site {site}: {known_wt!r} vs {wt!r}"
                    )
                site_order.setdefault(site, None)
    return pd.DataFrame(
        {hom: pd.Series(wt_map, dtype=object) for hom, wt_map in wt_by_hom.items()},
        index=list(site_order),
    )


def _subs_wrt_reference(
    fs_df, site_map, homolog_name_col, reference_homolog, substitution_col
):
    """Rewrite every variant's substitutions relative to the reference homolog.

    Variants of the reference homolog keep their original substitution
    string. For any other homolog, the variant's amino acid at each site in
    ``site_map`` (that homolog's wildtype, overridden by the variant's own
    mutations) is compared to the reference wildtype. Every differing site
    becomes "<reference wt><site><variant aa>", in ``site_map`` row order,
    joined by single spaces.

    Returns
    -------
    list of str
        One entry per row of ``fs_df``, aligned by position (not index label).
    """
    converted = fs_df[substitution_col].tolist()
    ref_wt = site_map[reference_homolog].to_dict()
    sites_col = fs_df["sites"].tolist()
    muts_col = fs_df["muts"].tolist()

    # `.indices` maps each homolog to the *positions* of its rows, so the
    # write-back is correct even if fs_df has duplicate index labels.
    for hom, positions in fs_df.groupby(homolog_name_col).indices.items():
        if hom == reference_homolog:
            continue
        hom_wt = site_map[hom].to_dict()
        for pos in tqdm(positions, desc=f"homolog {hom}"):
            # Copying keeps site_map order; update() only overrides
            # existing keys, so that order is preserved.
            variant_seq = dict(hom_wt)
            variant_seq.update(zip(sites_col[pos], muts_col[pos]))
            converted[pos] = " ".join(
                f"{ref_wt[site]}{site}{aa}"
                for site, aa in variant_seq.items()
                if aa != ref_wt[site]
            )
    return converted


def create_homolog_modeling_data(
    func_score_df:pd.DataFrame,
    homolog_name_col: str,
    reference_homolog: str,
    substitution_col: str,
    func_score_col: str
):
    """
    Takes a dataframe for making a `BinaryMap` object, and adds
    a column where each entry is a space-separated string of the
    mutations in a variant relative to the amino-acid sequence of
    the reference homolog.

    Parameters
    ----------

    func_score_df : pandas.DataFrame
        This should be in the same format as described in BinaryMap.

    homolog_name_col : str
        The name of the column in func_score_df that identifies the
        homolog for a given variant.

    reference_homolog : str
        The value in ``homolog_name_col`` naming the homolog that all
        substitutions should be converted to be with respect to.

    substitution_col : str
        The name of the column in func_score_df that lists the mutations
        in each variant relative to its own homolog's wildtype amino-acid
        sequence. Site numbers must come from an alignment to a reference
        sequence, which may or may not be the reference homolog.

    func_score_col : str
        Column in func_score_df giving the functional score for each variant.


    Returns
    -------

    tuple : ((dict, dict), pd.DataFrame, list, pd.DataFrame)

        This function returns a tuple that can be unpacked into the following:

        - (X, y): X maps each homolog to a `jax.experimental.sparse.BCOO`
            matrix built from its BinaryMap's ``binary_variants``, and y maps
            each homolog to a jnp array of its ``func_score_col`` values.

        - A copy of func_score_df that keeps only the variants deemed
            appropriate for training. It adds an ``allowed_variant`` column
            and a ``var_wrt_ref`` column holding the substitutions converted
            to be with respect to the reference homolog.

        - The BinaryMap ``all_subs`` list, giving the substitutions (betas)
            in the column order of the matrices in X. Every BinaryMap is built
            with the same ``allowed_subs``, and the list is taken from the
            last map built.

        - A pandas dataframe (the site map) indexed by alignment site (as a
            string), with one column per homolog giving its wildtype amino acid.

    Raises
    ------
    ValueError
        If a substitution cannot be parsed, if the inferred wildtypes are
        inconsistent within a homolog, if ``reference_homolog`` does not
        occur in ``homolog_name_col``, or if no variants survive filtering.

    Notes
    -----
    Sites at which at least one homolog has no observed data (for example
    because of an alignment gap) are dropped, along with every variant that
    is mutated at such a site.
    """
    # TODO: strip gapped substitutions (insertions) variants
    # from the func_score_df?

    # Parse every variant into wildtype / site / mutant lists.
    ret_fs_df = func_score_df.copy()
    parsed = [_split_subs(subs) for subs in ret_fs_df[substitution_col]]
    ret_fs_df["wts"] = [wts for wts, _, _ in parsed]
    ret_fs_df["sites"] = [sites for _, sites, _ in parsed]
    ret_fs_df["muts"] = [muts for _, _, muts in parsed]

    # Infer each homolog's wildtype sequence from its substitutions.
    site_map = _infer_site_map(ret_fs_df, homolog_name_col)
    if reference_homolog not in site_map.columns:
        raise ValueError(
            f"Reference homolog {reference_homolog!r} not found in column "
            f"{homolog_name_col!r}."
        )

    # Find all sites for which at least one homolog lacks data
    # (this can happen if there is a gap in the alignment).
    na_rows = site_map.isna().any(axis=1)
    print(f"Found {int(na_rows.sum())} site(s) lacking data in at least one homolog.")
    sites_to_throw = set(na_rows[na_rows].index)
    site_map = site_map.dropna()

    # Keep only variants with no mutation at a disallowed site.
    ret_fs_df["allowed_variant"] = [
        sites_to_throw.isdisjoint(sites) for sites in ret_fs_df["sites"]
    ]
    n_var_pre_filter = len(ret_fs_df)
    ret_fs_df = ret_fs_df.loc[ret_fs_df["allowed_variant"].to_numpy(dtype=bool)]
    print(f"{n_var_pre_filter-len(ret_fs_df)} of the {n_var_pre_filter} variants"
          f" were removed because they had mutations at the above sites, leaving"
          f" {len(ret_fs_df)} variants.")
    if ret_fs_df.empty:
        raise ValueError("No variants remain after filtering; cannot build BinaryMaps.")

    # Convert substitutions to be with respect to the reference homolog.
    ret_fs_df = ret_fs_df.assign(
        var_wrt_ref=_subs_wrt_reference(
            ret_fs_df, site_map, homolog_name_col, reference_homolog, substitution_col
        )
    )

    # All substitutions for which we will tune beta parameters.
    allowed_subs = {
        sub for subs in ret_fs_df["var_wrt_ref"] for sub in subs.split()
    }

    # Make BinaryMap representations for each homolog.
    X, y = {}, {}
    for homolog, homolog_func_score_df in ret_fs_df.groupby(homolog_name_col):
        ref_bmap = bmap.BinaryMap(
            homolog_func_score_df,
            substitutions_col="var_wrt_ref",
            allowed_subs=allowed_subs
        )
        # Convert the scipy sparse binary map into a JAX sparse array for the model.
        X[homolog] = sparse.BCOO.from_scipy_sparse(ref_bmap.binary_variants)
        # Functional score targets as a JAX array.
        y[homolog] = jnp.array(homolog_func_score_df[func_score_col].to_numpy())

    ret_fs_df = ret_fs_df.drop(columns=["wts", "sites", "muts"])

    return (X, y), ret_fs_df, ref_bmap.all_subs, site_map

In [9]:
# Build model inputs for the toy dataset, using homolog "1" as the reference.
(X, y), updated_func_score_df, all_subs, site_map = create_homolog_modeling_data(
    func_score_df=test_df,
    homolog_name_col="homolog",
    reference_homolog="1",
    substitution_col="variant",
    func_score_col="log2E",
)
updated_func_score_df

Found 1 site(s) lacking data in at least one homolog.
1 of the 10 variants were removed because they had mutations at the above sites, leaving 9 variants.


  0%|          | 0/5 [00:00<?, ?it/s]

Unnamed: 0,homolog,variant,log2E,allowed_variant,var_wrt_ref
0,1,M1E,2.0,True,M1E
1,1,G3R,-7.0,True,G3R
2,1,G3P,-0.5,True,G3P
3,1,M1W,2.3,True,M1W
4,2,M1E,1.0,True,M1E G3P
5,2,P3R,-5.0,True,G3R
6,2,P3G,0.4,True,
7,2,M1E P3G,2.7,True,M1E
8,2,M1E P3R,-2.7,True,M1E G3R


In [25]:
# Show the site map returned by create_homolog_modeling_data: retained
# alignment sites as rows, each homolog's wildtype amino acid as columns.
site_map

Unnamed: 0,1,2
1,M,M
3,G,P


In [14]:
# Show the substitutions (beta parameters) in the column order of the
# matrices in X.
all_subs

['M1E', 'M1W', 'G3P', 'G3R']