In [3]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.inspection import permutation_importance

In [4]:
# One-hot lookup table for the 20 one-letter residue codes used by aa_to_one_hot.
# Each code maps to a 20-element list of 0/1 ints holding a single 1 at the
# code's position in the string below (A -> slot 0, R -> slot 1, ..., V -> slot 19).
char2NumDict = {
    residue: [1 if slot == position else 0 for slot in range(20)]
    for position, residue in enumerate("ARNDCQEGHILKMFPSTWYV")
}
                

def aa_to_one_hot(peptide_string):
    """Concatenate the one-hot vectors of every character in ``peptide_string``.

    Each character is looked up in ``char2NumDict`` and the resulting lists are
    joined in sequence order into one flat vector.

    Returns:
        A 1-D numpy array containing the looked-up vectors back to back
        (an empty array for an empty string).

    Raises:
        TypeError: if a character is not a key of ``char2NumDict`` (the lookup
            then yields ``None``, which cannot be iterated).
    """
    flat_bits = [bit for residue in peptide_string for bit in char2NumDict.get(residue)]
    return np.array(flat_bits)


In [47]:
def calculateImportanceColors(df):
    """Colour-code every one-hot feature by its effect on enrichment.

    For each one-hot column (one amino acid at one position) the peptides are split
    into those that have the feature and those that do not. For both groups the
    enrichment ratio ``mean(num_after) / mean(num_before)`` is computed, and the
    difference "have minus have-not" decides the colour:

    * ``"palegreen"`` - difference > 0 (feature goes along with higher enrichment)
    * ``"tomato"``    - difference < 0
    * ``"gray"``      - anything else (zero, or NaN e.g. when one group is empty)

    Args:
        df: DataFrame with the columns ``seq``, ``num_before`` and ``num_after``.
            It is not modified; a re-indexed copy is used internally.

    Returns:
        ``(colors, delta_means)``: two lists with one entry per one-hot column,
        in column order.
    """
    # Re-encode from the sequences so the features are guaranteed to match this frame.
    one_hot = pd.DataFrame(np.array([aa_to_one_hot(seq) for seq in df["seq"]]))
    n_features = one_hot.shape[1]

    # Give the copy the positional index of the encoded frame so the count columns
    # line up row by row when they are attached.
    counts = df.copy()
    counts.index = one_hot.index
    one_hot['num_before'] = counts['num_before']
    one_hot['num_after'] = counts['num_after']

    def enrichment(group):
        # Ratio of the mean counts after vs. before for one group of peptides.
        return group['num_after'].mean()/group['num_before'].mean()

    colors = []
    delta_means = []
    for feature in range(n_features):
        with_feature = one_hot[one_hot[feature]==1]
        without_feature = one_hot[one_hot[feature]==0]
        delta = enrichment(with_feature) - enrichment(without_feature)

        delta_means.append(delta)
        # NaN fails both comparisons and therefore ends up gray.
        colors.append("palegreen" if delta>0 else ("tomato" if delta<0 else "gray"))

    return colors, delta_means



def plotPermutationImportance(result, filename_suffix="", savefigure=False):
    """Draw permutation importances as two stacked bar charts and optionally save them.

    The upper panel shows the features in the row order of ``result``; the lower
    panel shows the same bars sorted by descending importance. Both panels share
    the y axis and start at 0.

    Args:
        result: DataFrame with the columns ``"Amino Acid"`` (tick labels),
            ``"Importance"`` (bar heights), ``"Error"`` (error bars) and
            ``"color"`` (one bar colour per row).
        filename_suffix: text appended to the base file name ``PermutationImportance``.
        savefigure: if true, the figure is written as PNG, JPG (both 200 dpi) and PDF
            into ``graphs/png``, ``graphs/jpg`` and ``graphs/pdf``. Missing
            directories are created.
    """
    import os  # local import: only needed to create the output directories

    plotparam = {"legend":None, "width":0.9, "edgecolor":'black', 'capsize':5}

    # 0.4 inch of width per bar keeps the tick labels readable.
    fig_width = result.shape[0]*0.4
    fig, ax = plt.subplots(2,1, figsize=(fig_width,5), sharey=True)

    # Upper panel: original feature order.
    result.plot.bar("Amino Acid", "Importance", ax=ax[0], yerr="Error", color=result["color"], **plotparam)
    ax[0].set_xticklabels(result["Amino Acid"],rotation=0)
    ax[0].set_ylim(0, ax[0].get_ylim()[1])

    # Lower panel: most important feature first.
    sorted_results = result.sort_values(by="Importance", ascending=False)
    sorted_results.plot.bar("Amino Acid", "Importance", ax=ax[1], yerr="Error", color=sorted_results["color"], **plotparam)
    ax[1].set_xticklabels(sorted_results["Amino Acid"],rotation=0)
    ax[1].set_ylim(0,ax[1].get_ylim()[1])

    plt.tight_layout()

    if savefigure:
        # Raster formats get an explicit resolution; the PDF stays vector.
        for extension, extra_args in (("png", {"dpi": 200}), ("jpg", {"dpi": 200}), ("pdf", {})):
            target_dir = "graphs/"+extension
            # savefig does not create directories; without this a fresh checkout fails.
            os.makedirs(target_dir, exist_ok=True)
            plt.savefig(target_dir+"/PermutationImportance"+filename_suffix+"."+extension, bbox_inches='tight', **extra_args)
    plt.show()
    


def plotPermutationImportanceFromTSV(tsv_file, filename_suffix="", savefigure=False, positions_to_remove=None):
    """Re-plot a permutation importance result file, optionally hiding some features.

    Args:
        tsv_file: tab-separated result file whose first column is the row index and
            which contains an ``"Amino Acid"`` label column.
        filename_suffix: forwarded to ``plotPermutationImportance``.
        savefigure: forwarded to ``plotPermutationImportance``.
        positions_to_remove: iterable of label prefixes to drop, e.g.
            ``[-2, 0, 9, "-1S", "-1V"]``. Entries are converted to strings and
            compared against the label with its newlines removed. A prefix that
            ends in a digit only matches when the label does not continue with
            another digit, so ``1`` removes position 1 but not positions 10-19.
            ``None`` (the default) or an empty list removes nothing.
            NOTE(review): this assumes labels are "<position><residue>" as in the
            examples above - confirm against the code that builds feature names.
    """
    if positions_to_remove is None:
        positions_to_remove = []
    prefixes = [str(ptr) for ptr in positions_to_remove]  # make sure we compare strings

    result = pd.read_table(tsv_file, index_col=0)
    # Labels may contain line breaks for plotting; match against the flattened text.
    labels = result['Amino Acid'].str.replace("\n", "", regex=False)

    remove = pd.Series(False, index=result.index)
    for prefix in prefixes:
        hit = labels.str.startswith(prefix)
        if prefix[-1:].isdigit():
            # Respect the number boundary: "1" must not swallow "10", "11", ...
            following = labels.str.slice(len(prefix), len(prefix) + 1)
            hit = hit & ~following.str.isdigit()
        remove = remove | hit

    result = result[~remove]

    plotPermutationImportance(result, filename_suffix=filename_suffix, savefigure=savefigure)



def calculatePermutationImportance(model, X, Y, feature_names, filename_suffix="", savefigure=False, n_repeats=10, df=None):
    """Compute, tabulate and plot the permutation importance of every feature.

    Args:
        model: fitted estimator handed to ``permutation_importance``.
        X, Y: features and targets handed to ``permutation_importance``.
        feature_names: one label per feature; becomes the ``"Amino Acid"`` column.
        filename_suffix: appended to the output file names.
        savefigure: if true, the table is written to
            ``PermutationImportance<suffix>.tsv`` and the figure is saved as well.
        n_repeats: number of shuffles per feature.
        df: DataFrame passed to ``calculateImportanceColors`` to derive the bar
            colours and delta means. If omitted, the module-level variable ``df``
            is used, which is what this function silently relied on before.

    Returns:
        DataFrame with the columns ``"Amino Acid"``, ``"Importance"``, ``"Error"``,
        ``"dMean"`` and ``"color"``.

    Raises:
        NameError: if ``df`` is not passed and no module-level ``df`` exists.
    """
    if df is None:
        # Backward compatibility with callers that depend on a notebook-global ``df``.
        # Resolved up front so a missing frame fails before the expensive computation.
        try:
            df = globals()["df"]
        except KeyError:
            raise NameError(
                "calculatePermutationImportance needs the peptide DataFrame: "
                "pass it as df=... or define a global 'df'"
            ) from None

    colors, delta_means = calculateImportanceColors(df)

    perm_importance = permutation_importance(model, X, Y, random_state=0, n_repeats=n_repeats, n_jobs=-1)

    result = pd.DataFrame({"Amino Acid":feature_names, "Importance":perm_importance["importances_mean"], "Error":perm_importance["importances_std"]})
    result["dMean"] = delta_means
    result["color"] = colors
    if savefigure:
        result.to_csv("PermutationImportance"+filename_suffix+".tsv", sep="\t")

    plotPermutationImportance(result, filename_suffix=filename_suffix, savefigure=savefigure)

    return result


def mergePermutationImportanceResults(result_list, feature_names, filename_suffix=""):
    """Average several permutation importance tables into one, save it and plot it.

    The ``"Importance"`` columns of all tables are lined up by ``"Amino Acid"``
    (left join onto the labels of the first table). Their row-wise mean becomes the
    merged ``"Importance"`` and their row-wise standard deviation the merged
    ``"Error"``. All remaining columns are taken from the first table.

    The merged table is always written to ``PermutationImportance<suffix>.tsv`` and
    the figure is always saved.

    Args:
        result_list: non-empty sequence of result DataFrames, each with the columns
            ``"Amino Acid"``, ``"Importance"`` and ``"Error"``. The inputs are not
            modified.
        feature_names: unused; kept so existing calls keep working.
        filename_suffix: appended to the output file names.

    Returns:
        The merged result DataFrame.

    Raises:
        ValueError: if ``result_list`` is empty.
    """
    if len(result_list) == 0:
        raise ValueError("result_list must contain at least one result table")

    first = result_list[0]

    # One "Importance <i>" column per run, indexed by feature label.
    merged_result = first[["Amino Acid", "Importance"]].set_index("Amino Acid")
    merged_result.columns = ["Importance 0"]
    for i, run in enumerate(result_list[1:], start=1):
        run_importance = run.set_index("Amino Acid")["Importance"].rename("Importance "+str(i))
        merged_result = merged_result.join(run_importance)

    means = merged_result.agg(["mean", "std"], axis=1)
    means.columns = ["Importance", "Error"]

    # Swap the first run's own statistics for the across-run ones; keep its other columns.
    result = first.set_index("Amino Acid").drop(["Importance", "Error"], axis=1).join(means).reset_index()
    result.to_csv("PermutationImportance"+filename_suffix+'.tsv', sep='\t')

    plotPermutationImportance(result, filename_suffix=filename_suffix, savefigure=True)

    return result

