# Homolog BinaryMap Encoding

## Import `Python` modules

In [157]:
import pandas as pd
import re
import json
import binarymap as bmap
import matplotlib.pyplot as plt
import seaborn as sns
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax.experimental import sparse
from jaxopt import ProximalGradient
import jaxopt
import numpy as onp
from scipy.stats import pearsonr

from collections import defaultdict
from timeit import default_timer as timer
%matplotlib inline

In [158]:
# import sys
# sys.path.append("..")
# from multidms.utils import create_homolog_modeling_data

## Strategy for converting mutations to be relative to the reference homolog

Our strategy is to convert every mutation in every homolog so that it is expressed relative to the amino-acid sequence of the reference homolog. The function defined below performs this conversion.

Note: gaps (`-` in the aligned sequences) will probably require special-purpose handling, which is not yet implemented.

Next, we test the function with a small example. Below, we define variants from two imaginary homologs: "1" (used as the reference) and "2".

In [421]:
# Aligned wildtype sequences of the two toy homologs, keyed by homolog label.
# The "-" in homolog "1" marks an alignment gap at site 2.
homologs = {}
homologs["1"] = "M-G"
homologs["2"] = "MPP"

In [422]:
# Toy variant library: each variant's substitutions are written relative to
# its own homolog's wildtype, using site numbers from the shared alignment.
test_dict = dict(
    homolog=["1"] * 3 + ["2"] * 6,
    variant=['M1E', 'G3R', 'G3P', 'M1E', 'P3R', 'P3G', 'M1E P3G', 'M1E P3R', 'P2T'],
    log2E=[2, -7, -0.5, 2.3, -5, 0.4, 2.7, -2.7, 0.3],
)
test_df = pd.DataFrame(test_dict)
test_df

Unnamed: 0,homolog,variant,log2E
0,1,M1E,2.0
1,1,G3R,-7.0
2,1,G3P,-0.5
3,2,M1E,2.3
4,2,P3R,-5.0
5,2,P3G,0.4
6,2,M1E P3G,2.7
7,2,M1E P3R,-2.7
8,2,P2T,0.3


In [423]:
def create_homolog_modeling_data(
    func_score_df: pd.DataFrame,
    homolog_name_col: str,
    substitution_col: str,
    func_score_col: str,
    reference_homolog: str = "1",
):
    """
    Re-express every variant's substitutions relative to the reference
    homolog and build per-homolog model inputs.

    Each variant's substitutions are given relative to its own homolog's
    wildtype sequence. The wildtype amino acid of each homolog at each site
    is inferred from the wildtype letters observed in the substitution
    strings. Sites that are not observed in every homolog are dropped, along
    with every variant that touches them. Each remaining non-reference
    variant is then rewritten as the differences between its mutated
    sequence and the reference wildtype.

    Parameters
    ----------
    func_score_df : pandas.DataFrame
        Variant data in the format expected by ``binarymap.BinaryMap``.

    homolog_name_col : str
        Column in ``func_score_df`` that identifies the homolog of each
        variant.

    substitution_col : str
        Column in ``func_score_df`` listing each variant's space-delimited
        substitutions relative to its own homolog's wildtype (e.g.
        ``"M1E P3G"``). Site numbers must come from a sequence alignment to
        the reference.

    func_score_col : str
        Column in ``func_score_df`` giving the functional score of each
        variant.

    reference_homolog : str, optional
        Value in ``homolog_name_col`` that labels the reference homolog.
        Defaults to ``"1"``.

    Returns
    -------
    tuple
        ``((X, y), df, all_subs)``, where:

        - ``X`` maps each homolog to a ``jax.experimental.sparse.BCOO``
          matrix built from that homolog's ``BinaryMap.binary_variants``.
        - ``y`` maps each homolog to a ``jax.numpy`` array of its functional
          scores.
        - ``df`` is a copy of the input restricted to the allowed variants,
          with added ``allowed_variant`` and ``var_wrt_ref`` columns.
        - ``all_subs`` is ``BinaryMap.all_subs`` of the last homolog
          processed. Every map is built with the same ``allowed_subs``, so
          this is presumably identical across homologs.

    Raises
    ------
    ValueError
        If a substitution cannot be parsed, if the input is empty, if
        ``reference_homolog`` is absent, or if no variants remain after
        filtering.
    """
    sub_pattern = re.compile(r'(?P<aawt>\w)(?P<site>[\d\w]+)(?P<aamut>[\w\*])')

    def split_sub(sub_string):
        """Return ``(wt, site, mut)`` parsed from a single substitution."""
        match = sub_pattern.search(sub_string)
        if match is None:
            raise ValueError(f"Cannot parse substitution: {sub_string!r}")
        return match.group('aawt'), str(match.group('site')), match.group('aamut')

    def split_subs(subs_string):
        """Parse a space-delimited substitution string into parallel
        lists of wildtypes, sites, and mutants."""
        parsed = [split_sub(sub) for sub in subs_string.split()]
        return (
            [wt for wt, _, _ in parsed],
            [site for _, site, _ in parsed],
            [mut for _, _, mut in parsed],
        )

    if func_score_df.empty:
        raise ValueError("func_score_df contains no variants")

    ret_fs_df = func_score_df.copy()
    ret_fs_df["wts"], ret_fs_df["sites"], ret_fs_df["muts"] = zip(
        *ret_fs_df[substitution_col].map(split_subs)
    )

    # Infer each homolog's wildtype at every site seen in its variants:
    # {homolog: {site: wt}}. Sites are remembered in first-seen order so the
    # re-expressed substitutions are emitted in a stable order.
    hom_wts = {}
    site_order = {}
    for hom, hom_func_df in ret_fs_df.groupby(homolog_name_col):
        wt_at_site = hom_wts.setdefault(hom, {})
        for wts, sites in zip(hom_func_df["wts"], hom_func_df["sites"]):
            for wt, site in zip(wts, sites):
                # TODO do we continue if "-" in mut col? or in split_subs()?
                wt_at_site[site] = wt
                site_order.setdefault(site, None)

    if reference_homolog not in hom_wts:
        raise ValueError(
            f"Reference homolog {reference_homolog!r} not found in "
            f"column {homolog_name_col!r}"
        )

    # A site missing from any homolog cannot be compared across homologs,
    # so it is discarded together with every variant that touches it.
    sites_to_throw = {
        site for site in site_order
        if any(site not in wt_at_site for wt_at_site in hom_wts.values())
    }
    shared_sites = [site for site in site_order if site not in sites_to_throw]

    ret_fs_df["allowed_variant"] = ret_fs_df["sites"].apply(sites_to_throw.isdisjoint)
    ret_fs_df = ret_fs_df[ret_fs_df["allowed_variant"]]
    if ret_fs_df.empty:
        raise ValueError("No variants remain after removing unshared sites")

    ref_wts = hom_wts[reference_homolog]

    def subs_wrt_ref(subs_str, hom_seq):
        """Apply a variant's substitutions to its homolog's wildtype (given
        as ``{site: aa}`` over the shared sites) and return the resulting
        differences from the reference wildtype as a substitution string."""
        variant_seq = dict(hom_seq)
        for sub in subs_str.split():
            _, site, mut = split_sub(sub)
            variant_seq[site] = mut
        return " ".join(
            f"{ref_wts[site]}{site}{variant_seq[site]}"
            for site in shared_sites
            if ref_wts[site] != variant_seq[site]
        )

    # Reference variants already are relative to the reference; every other
    # homolog's variants are rewritten.
    ret_fs_df = ret_fs_df.assign(var_wrt_ref=ret_fs_df[substitution_col].values)
    for hom, hom_func_df in ret_fs_df.groupby(homolog_name_col):
        if hom == reference_homolog:
            continue
        hom_seq = {site: hom_wts[hom][site] for site in shared_sites}
        ret_fs_df.loc[hom_func_df.index.values, "var_wrt_ref"] = hom_func_df[
            substitution_col
        ].apply(lambda subs: subs_wrt_ref(subs, hom_seq))

    # Every substitution (relative to the reference) that any variant carries.
    allowed_subs = {
        s for subs in ret_fs_df["var_wrt_ref"]
        for s in subs.split()
    }

    # One BinaryMap per homolog, all sharing the same substitution columns.
    X, y = {}, {}
    for homolog, homolog_func_score_df in ret_fs_df.groupby(homolog_name_col):
        ref_bmap = bmap.BinaryMap(
            homolog_func_score_df,
            substitutions_col="var_wrt_ref",
            allowed_subs=allowed_subs,
        )
        X[homolog] = sparse.BCOO.fromdense(ref_bmap.binary_variants.toarray())
        y[homolog] = jnp.array(homolog_func_score_df[func_score_col].values)

    return (X, y), ret_fs_df.drop(["wts", "sites", "muts"], axis=1), ref_bmap.all_subs

In [424]:
print(f"{homologs}")  # show the aligned toy homolog sequences

{'1': 'M-G', '2': 'MPP'}


In [425]:
# Build model inputs for the toy data, naming each column argument explicitly.
(X, y), updated_func_score_df, all_subs = create_homolog_modeling_data(
    test_df,
    homolog_name_col="homolog",
    substitution_col="variant",
    func_score_col="log2E",
)
updated_func_score_df

Unnamed: 0,homolog,variant,log2E,allowed_variant,var_wrt_ref
0,1,M1E,2.0,True,M1E
1,1,G3R,-7.0,True,G3R
2,1,G3P,-0.5,True,G3P
3,2,M1E,2.3,True,M1E G3P
4,2,P3R,-5.0,True,G3R
5,2,P3G,0.4,True,
6,2,M1E P3G,2.7,True,M1E
7,2,M1E P3R,-2.7,True,M1E G3R
