**BEWARE** This notebook is made available to the public for reproducibility purposes only. The code below is not maintained. 

# Init

In [None]:
import sys
import pandas as pd
from IPython.display import SVG, display, Image
from os.path import dirname
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.neural_network import MLPClassifier, MLPRegressor
from sklearn.model_selection import train_test_split
import numpy as np
import subprocess
from sklearn.manifold import *
import os
import glob
import mdtraj as md
import scipy
import matplotlib as mpl
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
import matplotlib.pyplot as plt; 
plt.style.use('seaborn-colorblind')
from functools import reduce

# Make sure to add state sampling and demystifying to your python path, e.g
sys.path.append("/home/oliverfl/git/delemottelab/demystifying")
sys.path.append(dirname("../../../state_sampling"))
import demystifying as dm
from statesampling import log, colvars, utils


# Logger shared by all cells of this notebook.
_log =  log.getLogger("analysis")
_log.setLevel('DEBUG')

# Ligand names. The position of a name in this array is used further down as
# the column index into `labels` and as the index into `colors`.
ligands = np.array([
    'apo', 
    'carazolol',      
    'alprenolol', 
    'timolol',  
    'salmeterol',      
    'adrenaline',      
    'p0g'
])
# Matplotlib color names, looked up by ligand index (there are more colors than ligands).
colors = np.array(['darkkhaki', 
                   'olive', 
                   'forestgreen', 
                   'chartreuse', 
                   'darkslategray',
                   'slateblue', 
                   'midnightblue', 
                   'silver', 
                   'pink', 
                   'darksalmon',
                  ])
# Scatter-plot marker per ligand name.
markers = dict(
    apo="o", 
    carazolol=">",      
    alprenolol="^", 
    timolol="<",  
    salmeterol="s",      
    adrenaline="p",      
    p0g="h"
)

trajs=[] #Optionally you can load MDTraj trajectories into this list.
# Directory with the input files read by this notebook (topology, residue mapping, CV definitions, snake plot template).
working_dir = "../.simu/"
# Only the topology is kept; it is the default `topology` argument of the helper functions below.
topology = md.load(working_dir + "apo-equilibrated.gro").topology
traj_type="strings" #choose between 'strings' and 'single_state_sampling'
# Feature definition string; it is part of the dataset/output paths and is parsed by to_features().
feature_type="inv__contacts__closest-heavy" 
# When True, labels and label names are grouped by ligand type instead of per ligand.
group_by_type=False
_log.info("Done. Using traj_type %s", traj_type)

# Ligand Signaling Effects
## Define values

In [None]:
# Pharmacological class assigned to every ligand.
ligand_to_type = {
    'carazolol' : 'not agonist',
    'apo' :  'not agonist',
    'adrenaline' : 'agonist',
    'alprenolol': 'not agonist',
    'p0g' : 'agonist',
    'salmeterol': 'agonist',
    'timolol': 'not agonist'
}
# The distinct values of ligand_to_type; used as label names when group_by_type is True.
ligand_types = [
    'not agonist',
    'agonist'
]
# Three-letter abbreviation of every ligand.
ligand_to_abbreviation = {
    'carazolol' : 'CAU',
    'apo' :  'APO',
    'adrenaline' : 'ALE',
    'alprenolol': 'ALP',
    'p0g' : 'P0G',
    'salmeterol': 'SAL',
    'timolol': 'TIM', 
}
# Abbreviations in the same order as `ligands`.
ligand_abbreviations = np.array([ligand_to_abbreviation[l] for l in ligands])    
ligand_to_effect = {
    #See http://molpharm.aspetjournals.org/content/85/3/492/tab-figures-data
    #The lower the EC50, the lower the drug concentration required to produce 50% of the maximum effect, i.e. the higher the potency.
    #Potency is given as pEC50 (high values -> strong response)
    'apo' :  dict(
        cAMP_pEC50=0, 
        cAMP_Emax=0.,
        cAMP_Emax_ste=None,
        pERK12_Emax=0,
        Ca2_Emax=0,
        Endocytosis_Emax=0,
        exval_tm5_bulge=1.24309369,
        exval_Connector_deltaRMSD=0.02319871, 
        exval_TM6_TM3_distance=0.86821019,
        exval_Ionic_lock_distance=0.98442891,
        exval_YY_motif=1.23168429,        
        exval_Pro211_Phe282=0.61553677,  
        exval_tm5_bulge_ste=0.00020138,
        exval_Connector_deltaRMSD_ste=0.00011046, 
        exval_TM6_TM3_distance_ste=0.00146633,
        exval_Ionic_lock_distance_ste=0.0023686,
        exval_YY_motif_ste=0.00167255,        
        exval_Pro211_Phe282_ste=None,   
        awh_nb_tm5=1.2993,
        exval_TM6_TM3_distance_ca=1.05442884,
        exval_TM6_TM3_distance_ca_ste=None,                
    ),    
    'carazolol' : dict(
        cAMP_pEC50=0, 
        cAMP_Emax=0.,
        cAMP_Emax_ste=None,        
        pERK12_Emax=0,
        Ca2_Emax=0,
        Endocytosis_Emax=0,
        exval_tm5_bulge=1.22956958,
        exval_Connector_deltaRMSD=0.017588282, 
        exval_TM6_TM3_distance=0.81141607,
        exval_Ionic_lock_distance=0.92297367,
        exval_YY_motif=1.1612276,        
        exval_Pro211_Phe282=0.59886003,  
        exval_tm5_bulge_ste=0.00036111,
        exval_Connector_deltaRMSD_ste=8.41365600e-05, 
        exval_TM6_TM3_distance_ste=0.00203712,
        exval_Ionic_lock_distance_ste=0.00321631,
        exval_YY_motif_ste=0.00243157,        
        exval_Pro211_Phe282_ste=None,   
        awh_nb_tm5=1.3218,    
        exval_TM6_TM3_distance_ca=1.105435,
        exval_TM6_TM3_distance_ca_ste=None,  
        
    ), # carazolol's experimental values are taken from apo. Carvedilol cannot be used since it is a G protein antagonist and an arrestin agonist...
    'alprenolol':  dict(
        cAMP_pEC50=9.81, 
        cAMP_Emax=35.76,
        cAMP_Emax_ste=3.70,        
        pERK12_Emax=102.3,
        Ca2_Emax=0,
        Endocytosis_Emax=0,      
        exval_tm5_bulge=1.21801455,
        exval_Connector_deltaRMSD=0.02275429, 
        exval_TM6_TM3_distance=0.75567965,
        exval_Ionic_lock_distance=0.94340876,
        exval_YY_motif=1.23447094,           
        exval_Pro211_Phe282=0.57185033,   
        exval_tm5_bulge_ste=0.0001231,
        exval_Connector_deltaRMSD_ste=0.00014594, 
        exval_TM6_TM3_distance_ste=0.00232501,
        exval_Ionic_lock_distance_ste=0.00246458,
        exval_YY_motif_ste=0.00228454,        
        exval_Pro211_Phe282_ste=None,   
        awh_nb_tm5=1.3143,    
        exval_TM6_TM3_distance_ca=1.05058275,
        exval_TM6_TM3_distance_ca_ste=None,       
    ), 
    'timolol': dict(
        cAMP_pEC50=8.81, 
        cAMP_Emax=-44.45,
        cAMP_Emax_ste=5.28,        
        pERK12_Emax=0,
        Ca2_Emax=0,
        Endocytosis_Emax=0,
        exval_tm5_bulge=1.28462016,
        exval_Connector_deltaRMSD=0.01608262, 
        exval_TM6_TM3_distance=0.83716437,
        exval_Ionic_lock_distance=0.95976886,
        exval_YY_motif=1.42438547,          
        exval_Pro211_Phe282=0.61874432,   
        exval_tm5_bulge_ste=0.00030679,
        exval_Connector_deltaRMSD_ste=0.00021453, 
        exval_TM6_TM3_distance_ste=0.0033251,
        exval_Ionic_lock_distance_ste=0.00569959,
        exval_YY_motif_ste=0.00087478,        
        exval_Pro211_Phe282_ste=None,   
        awh_nb_tm5=1.347,       
        exval_TM6_TM3_distance_ca=1.0556941,
        exval_TM6_TM3_distance_ca_ste=None,     
    ),     
    'salmeterol': dict(
        cAMP_pEC50=8.63, 
        cAMP_Emax=105.5,
        cAMP_Emax_ste=6.11, 
        pERK12_Emax=74.39,
        Ca2_Emax=34.45,
        Endocytosis_Emax=0,
        exval_tm5_bulge=1.14754131,
        exval_Connector_deltaRMSD=0.04149619, 
        exval_TM6_TM3_distance=0.85799139,
        exval_Ionic_lock_distance=1.00344138,
        exval_YY_motif=1.14292594,          
        exval_Pro211_Phe282=0.57731706,   
        exval_tm5_bulge_ste=0.00016367,
        exval_Connector_deltaRMSD_ste=4.00103042e-05, 
        exval_TM6_TM3_distance_ste=0.00184956,
        exval_Ionic_lock_distance_ste=0.00448551,
        exval_YY_motif_ste=0.00067195,        
        exval_Pro211_Phe282_ste=None,    
        awh_nb_tm5=1.2028,  
        exval_TM6_TM3_distance_ca=1.06233784,
        exval_TM6_TM3_distance_ca_ste=None,          
    ),
    'adrenaline' : dict(
        cAMP_pEC50=7.78, 
        cAMP_Emax=74.04,
        cAMP_Emax_ste=6.19, 
        pERK12_Emax=157.7,
        Ca2_Emax=112.0,
        Endocytosis_Emax=115.5,  
        exval_tm5_bulge=1.12983003,
        exval_Connector_deltaRMSD=0.02935141, 
        exval_TM6_TM3_distance=0.8329062,
        exval_Ionic_lock_distance=0.96943942,
        exval_YY_motif=0.86182756,            
        exval_Pro211_Phe282=0.58856012,    
        exval_tm5_bulge_ste=0.00017935,
        exval_Connector_deltaRMSD_ste=4.66225665e-05, 
        exval_TM6_TM3_distance_ste=0.00114867,
        exval_Ionic_lock_distance_ste=0.00524182,
        exval_YY_motif_ste=0.000854,        
        exval_Pro211_Phe282_ste=None,   
        awh_nb_tm5=1.237, 
        exval_TM6_TM3_distance_ca=1.08249786,
        exval_TM6_TM3_distance_ca_ste=None,     
    ),    
    'p0g' : dict(
        # experimental values taken from isoproterenol.
        cAMP_pEC50=8.23, 
        cAMP_Emax=100.,
        cAMP_Emax_ste=0.15, 
        pERK12_Emax=100.,
        Ca2_Emax=100.,
        Endocytosis_Emax=100.,
        exval_tm5_bulge=1.11438486,
        exval_Connector_deltaRMSD=0.04348406, 
        exval_TM6_TM3_distance=0.78827263,
        exval_Ionic_lock_distance=0.95074948,
        exval_YY_motif=1.22048298,              
        exval_Pro211_Phe282=0.57743262,    
        exval_tm5_bulge_ste=0.00010445,
        exval_Connector_deltaRMSD_ste=3.66579016e-05, 
        exval_TM6_TM3_distance_ste=0.00057464,
        exval_Ionic_lock_distance_ste=0.00098435,
        exval_YY_motif_ste=0.00270393,        
        exval_Pro211_Phe282_ste=None,   
        awh_nb_tm5=1.2462,      
        exval_TM6_TM3_distance_ca=1.03998389,
        exval_TM6_TM3_distance_ca_ste=None,  
    ),     
}
_log.info("Done")

## Plot correlations

In [None]:
def _fix_label(l):
    if l.startswith("exval_"):
        l = "E[{}] [nm]".format(l.replace("exval_", ""))
    l = l.replace("_", " ")
    l = l.replace("tm", "TM")
    l = l.replace("delta delta", "deltadelta")
    l = l.replace("delta", "$\Delta$")
    l = l.replace("connector", "Connector")
    l = l.replace("p0g", "BI-167107")
    l = l.replace("YY", "Y-Y")
    l = l.replace("Emax", "Emax [%]")
    return l

def correlation_plot(xdim='tm5_basin', ydim='cAMP_Emax', 
                     ligands=ligands, 
                     ligand_to_effect=ligand_to_effect,
                     prefix='', 
                     predict=False):
    """Plot two per-ligand quantities against each other with a linear fit.

    For every ligand the values stored under ``xdim`` and ``ydim`` in
    ``ligand_to_effect`` are drawn as one point (error bars are taken from the
    ``<dim>_ste`` entries when present). A least-squares line is fitted through
    the points, drawn, and its R and p values are put in the title. The figure
    is saved to ``output/correlations/<prefix><xdim>_<ydim>.svg`` and shown.

    :param xdim: key of the quantity on the x axis
    :param ydim: key of the quantity on the y axis
    :param ligands: ligand names; the position also selects the point color
    :param ligand_to_effect: dict mapping ligand name to a dict of quantities
    :param prefix: prefix of the output file name
    :param predict: if True, ligands that have an x but no y value are placed
        on the fitted line and marked with ``*``
    :return: None
    """
    # `import scipy` alone does not guarantee that the stats submodule is loaded.
    from scipy import stats

    fig = plt.figure(figsize=(4,4))
    xvals = []
    yvals = []
    to_predict = []

    def add_to_graph(l_idx, x, y, prediction=False, xerr=None, yerr=None):
        xvals.append(x)
        yvals.append(y)
        color = colors[l_idx]
        plt.errorbar(x, y, fmt='o', color=color, xerr=xerr, yerr=yerr)
        txt = "  " + _fix_label(ligands[l_idx]).capitalize() + ("*" if prediction else "")
        plt.text(x, y, txt)

    for l_idx, l in enumerate(ligands):
        vals = ligand_to_effect.get(l, None)
        if vals is None:
            continue
        x = vals.get(xdim, None)
        y = vals.get(ydim, None)
        if x is None or (y is None and not predict):
            continue
        elif y is None:
            # Only x is known: remember the ligand and predict y after the fit.
            to_predict.append((l_idx, x))
            continue
        add_to_graph(l_idx, x, y,
                     xerr=vals.get(xdim + "_ste", None),
                     yerr=vals.get(ydim + "_ste", None))

    # A regression needs at least two points. (The previous check
    # `len(xvals) is None` could never be true.)
    if len(xvals) < 2:
        if not xvals:
            _log.warning("No input found")
        else:
            _log.warning("Only one data point found; cannot fit a regression line")
        plt.close(fig)  # do not leave an empty figure behind
        return

    #Regression
    slope, intercept, r_value, p_value, _ = stats.linregress(xvals, yvals)

    #Make predictions
    for (l_idx, x) in to_predict:
        add_to_graph(l_idx, x, x*slope + intercept, prediction=True)

    #Visualize
    xvals = np.array(xvals)
    # Extend the fitted line a quarter of a standard deviation beyond the data.
    xlin = np.linspace(xvals.min() - xvals.std()/4, xvals.max() + xvals.std()/4, 10)
    plt.plot(xlin, xlin*slope + intercept, linestyle='--', color="grey", alpha=0.3, linewidth=5)
    plt.ylabel(_fix_label(ydim))
    plt.xlabel(_fix_label(xdim))
    plt.xlim([xlin.min(), xlin.max()])
    plt.title("R = {:.2f}, p = {:.3f}".format(r_value, p_value))
    plt.tight_layout(pad=0.3)
    os.makedirs("output/correlations", exist_ok=True)
    plt.savefig("output/correlations/{}{}_{}.svg".format(prefix, xdim, ydim))
    plt.show()
        
# Example: correlate the TM5 bulge expectation value with the cAMP Emax and
# predict the y value of ligands that lack one.
correlation_plot(
    xdim='exval_tm5_bulge',   # alternatives: exval_tm5_bulge, exval_Connector_deltaRMSD, exval_YY_motif
    # exval_TM6_TM3_distance, exval_Ionic_lock_distance, exval_Pro211_Phe282
    ydim='cAMP_Emax', # alternatives: cAMP_Emax, Endocytosis_Emax, pERK12_Emax, Ca2_Emax
    predict=True,
)
_log.info("Done")

# Load Data

## NumPy array IO
**NOTE** that the dataset files need to be downloaded from https://drive.google.com/drive/folders/16_9JX3z2Vmly4ZdNTTG_WqiByzS5mNfc?usp=sharing

In [None]:
# Directory of the dataset for the selected trajectory type and feature type.
dataset_dir="output/datasets/{}/{}/".format(traj_type, feature_type)
_log.info("Using dataset_dir %s", dataset_dir)

## Load samples

In [None]:
# SECURITY NOTE: allow_pickle=True lets NumPy unpickle object arrays, which can
# execute arbitrary code. Only load data.npz files obtained from a trusted source.
with np.load(dataset_dir + "data.npz", allow_pickle=True) as data:
    samples = data['samples']
    labels = data['labels']
    ligand_labels = data['ligand_labels']
    feature_to_resids = data['feature_to_resids']
    # The scaler is stored as an object array with a single element; .item()
    # unwraps it without reshaping the loaded array in place.
    scaler = data['scaler'].item()
# The columns of `labels` must correspond one-to-one to the entries of `ligands`.
if labels.shape[1] != len(ligands):
    raise Exception("Number of loaded classes {} differs from the number of defined ligands {}".format(labels.shape[1], len(ligands)))
_log.info("Done. Loaded samples of shape %s and labels for %s classes", samples.shape, labels.shape[1])

## Save samples

In [None]:
# Write the in-memory dataset back to <dataset_dir>/data.npz, using the same
# array names that the "Load samples" cell reads.
if not os.path.exists(dataset_dir):
    os.makedirs(dataset_dir)
if samples is not None:
    # savez_compressed appends the ".npz" extension to the given file name.
    np.savez_compressed(dataset_dir + "data", 
                    samples=samples, 
                    labels=labels, 
                    ligand_labels=ligand_labels, 
                    feature_to_resids=feature_to_resids,
                    scaler=scaler
                   )
_log.info("Done. Saved to %s", dataset_dir)

## Extract features

In [None]:
# Table read from a ';'-separated CSV file with a 'beta2' column (integer
# residue numbers) and a 'generic' column (kept as strings/objects).
residue_mapping = pd.read_csv(working_dir + "beta2_generic_residues.csv", 
                              delimiter=';',
                             dtype={'generic':'object', 'beta2':'int'})

def fix_generic_numbers(residues):
    """Annotate protein residues with ``generic`` and ``fullname`` attributes.

    The generic number is the first ``generic`` entry of ``residue_mapping``
    whose ``beta2`` value equals the residue's ``resSeq`` (``None`` when there
    is no such row). ``fullname`` is ``<code><resSeq>`` followed by the generic
    number in parentheses when one was found. Non-protein residues are skipped.

    :param residues: iterable of residue objects; modified in place
    """
    for residue in (candidate for candidate in residues if candidate.is_protein):
        is_match = residue_mapping['beta2'] == residue.resSeq
        hits = residue_mapping.loc[is_match, 'generic']
        generic_number = None
        if len(hits) > 0:
            generic_number = hits.values[0]
        residue.generic = generic_number
        residue.fullname = (
            "{}{}".format(residue.code, residue.resSeq)
            if generic_number is None
            else "{}{}({})".format(residue.code, residue.resSeq, residue.generic)
        )

def find_residue(res_id, topology=topology):
    """Return the protein residue with sequence number ``res_id``.

    :param res_id: residue sequence number (``resSeq``)
    :param topology: topology to search
    :return: the residue owning the first atom matched by the selection
    :raises IndexError: if the selection matches no atom
    """
    selection = "protein and resSeq {}".format(res_id)
    atom_indices = topology.select(selection)
    first_atom = topology.atom(atom_indices[0])
    return first_atom.residue
            
def create_index_to_residue(feature_to_resids, topology=topology):
    """List every residue referenced by ``feature_to_resids`` exactly once.

    The array is scanned column by column; a residue id is looked up in the
    topology the first time it is seen, so the result is ordered by first
    occurrence.

    :param feature_to_resids: 2-D array of residue ids, one row per feature
    :param topology: topology used to resolve ids to residue objects
    :return: numpy array of residue objects
    """
    ordered_residues = []
    visited_ids = set()
    n_columns = feature_to_resids.shape[1]
    for column in range(n_columns):
        for resid in feature_to_resids[:, column]:
            if resid not in visited_ids:
                ordered_residues.append(find_residue(resid, topology))
                visited_ids.add(resid)
    return np.array(ordered_residues)


def to_relevant_residues(topology=topology, 
                         ignored_residues=None,
                         included_residues=None):
    """Select the protein residues that features should be computed for.

    :param topology: topology whose residues are filtered
    :param ignored_residues: residue numbers (``resSeq``) to leave out; the
        argument is never modified. Residue 24 is always left out in addition.
    :param included_residues: if given and non-empty, only residues with these
        numbers are kept
    :return: numpy array of residue objects
    """
    # Copy into a local set: the previous version used a mutable default and
    # appended to the caller's list, so it grew on every call.
    ignored = set(ignored_residues) if ignored_residues is not None else set()
    # For some reason this residue gives really huge RMSD for a few frames, maybe broken PBCs.
    ignored.add(24)
    residues = [
        r for r in topology.residues
        if r.is_protein and r.resSeq not in ignored
    ]
    if included_residues is not None and len(included_residues) > 0:
        included = set(included_residues)
        residues = [r for r in residues if r.resSeq in included]
    return np.array(residues)


def to_rmsd_cvs(topology=topology, ignored_residues=None, included_residues=None):
    """Create two per-residue RMSD collective variables for every relevant residue.

    For each residue returned by ``to_relevant_residues`` one ``RmsdCv`` is
    built against ``active_traj`` and one against ``inactive_traj``, both
    restricted to the residue's non-hydrogen protein atoms.
    NOTE(review): ``active_traj`` and ``inactive_traj`` are module-level names
    that are not defined in the visible part of this notebook.

    :param topology: topology providing the residues
    :param ignored_residues: residue numbers to skip (not modified)
    :param included_residues: optional residue numbers to restrict to
    :return: tuple ``(cvs, feature_to_resids)``; ``feature_to_resids[i]`` holds
        the residue number that ``cvs[i]`` refers to
    """
    # Hand a private copy to the helper so that neither a shared default nor
    # the caller's list can be modified downstream.
    ignored = [] if ignored_residues is None else list(ignored_residues)
    cvs = []
    feature_to_resids = []
    for r in to_relevant_residues(topology, ignored_residues=ignored, included_residues=included_residues):
        q = "protein and resSeq {} and element != 'H'".format(r.resSeq)
        active_cv = colvars.cvs.RmsdCv(ID="active-rmsd_{}".format(r), 
                                   name="Active RMSD {}".format(r.fullname), 
                                   reference_structure=active_traj,
                                   query=q)
        inactive_cv = colvars.cvs.RmsdCv(ID="inactive-rmsd_{}".format(r), 
                                   name="Inactive RMSD {}".format(r.fullname), 
                                   reference_structure=inactive_traj,
                                   query=q)
        # Keep cvs and feature_to_resids aligned: one entry each per CV.
        cvs.append(active_cv)
        feature_to_resids.append([r.resSeq])
        cvs.append(inactive_cv)
        feature_to_resids.append([r.resSeq])
    return np.array(cvs), np.array(feature_to_resids)


def to_contact_cvs(topology=topology, scheme="ca", inverse=True, ignored_residues=None, included_residues=None):
    """Create one (inverse) contact collective variable per residue pair.

    Pairs are formed from the residues returned by ``to_relevant_residues``;
    a residue is only paired with residues at least three positions further
    down that list.

    :param topology: topology providing the residues
    :param scheme: passed on to the CV class as ``scheme``
    :param inverse: use ``colvars.InverseContactCv`` if True, otherwise
        ``colvars.ContactCv``
    :param ignored_residues: residue numbers to skip (not modified)
    :param included_residues: optional residue numbers to restrict to
    :return: tuple ``(cvs, feature_to_resids)`` where ``feature_to_resids[i]``
        is ``[res1, res2]`` of ``cvs[i]``
    """
    cvclass = colvars.InverseContactCv if inverse else colvars.ContactCv
    # Hand a private copy to the helper so that neither a shared default nor
    # the caller's list can be modified downstream.
    ignored = [] if ignored_residues is None else list(ignored_residues)
    residues = to_relevant_residues(topology, ignored_residues=ignored, included_residues=included_residues)
    residue_combos = [
        (r1, r2)
        for idx, r1 in enumerate(residues)
        for r2 in residues[idx+3:]
    ]
    cvs = np.array([
        cvclass(res1=res1.resSeq, 
                res2=res2.resSeq, 
                scheme=scheme,
                name="{}-{}".format(res1.fullname, res2.fullname), 
                ID="|{}-{}|^{}({})".format(res1, res2, -1 if inverse else 1, scheme))
        for (res1, res2) in residue_combos
    ])
    feature_to_resids = np.array([[cv.res1, cv.res2] for cv in cvs])
    return cvs, feature_to_resids

def to_features(trajs, feature_type):
    """Build the collective variables selected by ``feature_type`` and evaluate them.

    Substrings of ``feature_type`` that are recognised:

    * ``noligand`` - additionally ignore the residues in ``ligand_interactions``
    * ``npxxy`` / ``conserved`` - restrict to the listed residue numbers
    * ``demystifying-cvs`` - use the preloaded ``demystifying_cvs``
    * ``sidechain-rmsd`` - per-residue RMSD CVs (``to_rmsd_cvs``)
    * ``contacts__<scheme>`` - contact CVs (``to_contact_cvs``), inverse
      contacts when ``inv__`` is present; the scheme is the text after the
      last ``__``

    :param trajs: iterable of trajectories to evaluate the CVs on
    :param feature_type: feature definition string
    :return: tuple ``(features, cvs, feature_to_resids)`` with one evaluated
        array per trajectory in ``features``
    :raises Exception: if ``feature_type`` selects no known CV family
    """
    # Work on a copy: the previous version aliased the module-level
    # caps_residues list and extended it in place on every call.
    ignored_residues = list(caps_residues)
    included_residues = []
    if "noligand" in feature_type:
        ignored_residues.extend(ligand_interactions)
        _log.info("Excluding ligand binding site interactions")
    if "npxxy" in feature_type:
        included_residues += [322, 323, 324, 325, 326, 327]
    if "conserved" in feature_type:
        included_residues += [51, 79, 131, 158, 211, 288, 323]
    if 'demystifying-cvs' in feature_type:
        cvs = demystifying_cvs
        feature_to_resids = np.array([[cv.res1, cv.res2] for cv in cvs])
    elif 'sidechain-rmsd' in feature_type:
        cvs, feature_to_resids = to_rmsd_cvs(topology, ignored_residues=ignored_residues, included_residues=included_residues)
    elif 'contacts__' in feature_type:
        scheme = feature_type.split("__")[-1]
        inverse = 'inv__' in feature_type
        cvs, feature_to_resids = to_contact_cvs(topology, inverse=inverse, scheme=scheme, ignored_residues=ignored_residues, included_residues=included_residues)
    else:
        raise Exception("Invalid feature type {}".format(feature_type))
    # Fall back to the CV id for CVs without a name.
    for cv in cvs:
        if cv.name is None:
            cv.name = cv.id
    return [colvars.eval_cvs(cvs, t) for t in trajs], cvs, feature_to_resids

def fix_residue_format_on_cv_names(cvs, topology=topology):
    """Rename two-residue CVs to ``<fullname of res1>-<fullname of res2>``.

    CVs lacking a ``res1`` or ``res2`` attribute are left untouched.

    :param cvs: iterable of CV objects; names are changed in place
    :param topology: topology used to resolve residue numbers
    """
    pair_cvs = (cv for cv in cvs if hasattr(cv, "res1") and hasattr(cv, "res2"))
    for cv in pair_cvs:
        first = find_residue(cv.res1, topology)
        second = find_residue(cv.res2, topology)
        cv.name = "{}-{}".format(first.fullname, second.fullname)

# Add `generic` and `fullname` attributes to the residues of the topology.
fix_generic_numbers(topology.residues)
# Residue numbers that to_features() always ignores.
caps_residues=[23, 27, 227, 266, 344]
# Residue numbers that to_features() ignores in addition when the feature type contains "noligand".
ligand_interactions = [109, 113, 114, 117, 193, 195, 203, 204, 207, 286, 289, 290, 293, 308, 309, 312]

demystifying_cvs = colvars.io.load_cvs(working_dir + "/cvs.json")
fix_residue_format_on_cv_names(demystifying_cvs)
# NOTE(review): to_rmsd_cvs returns a (cvs, feature_to_resids) tuple, so when
# trajectories are loaded len(rmsd_cvs) below is 2, not the number of CVs.
rmsd_cvs = to_rmsd_cvs(topology) if len(trajs) > 0 else []
_log.info("Loaded %s demystifying cvs and %s rmsd cvs", len(demystifying_cvs), len(rmsd_cvs))
# With an empty `trajs` list, `features` is empty but `cvs` and `feature_to_resids` are still built.
features, cvs, feature_to_resids = to_features(trajs, feature_type)
index_to_residue = create_index_to_residue(feature_to_resids)
if len(trajs) > 0:
    _log.info("Computed %s features for %s datasets (%s)", features[0].shape[1] , len(features), feature_type)
else:
    _log.info("Loaded nothing else since there are no trajectories")

# Demystifying

In [None]:
def _get_important_residues(importance, index_to_residue=index_to_residue, importance_cutoff=0.5, 
                            count_cutoff=100):
    """Collect the most important residues plus the default highlighted groups.

    Residues are visited in order of decreasing importance. The scan stops at
    the first residue whose importance is below ``importance_cutoff`` or once
    ``count_cutoff`` residues have been collected.

    :param importance: sequence with one importance value per residue index
    :param index_to_residue: maps a residue index to a residue object
    :param importance_cutoff: minimum importance for a residue to be included
    :param count_cutoff: maximum number of residues collected from the ranking
    :return: dict mapping residue ``fullname`` to ``resSeq``, merged with the
        result of ``_get_default_important_residues()``
    """
    ranking = sorted(enumerate(importance), key=lambda pair: pair[1], reverse=True)
    selected = {}
    for position, score in ranking:
        if score < importance_cutoff or len(selected) == count_cutoff:
            break
        hit = index_to_residue[position]
        selected[hit.fullname] = hit.resSeq
    defaults = _get_default_important_residues()
    return dict(**selected, **defaults)

def _get_default_important_residues(supervised=None, feature_type=feature_type):
    """Return the residue groups that are always highlighted (from 2020 BPJ paper).

    Both parameters are currently unused. An older version (from the original
    paper) chose the groups from them:

    * supervised with "rmsd" in the feature type: PIF, M82 ([82]) and DRY
    * supervised otherwise: E268 ([268]) and NPxxY
    * unsupervised: NPxxY, End of TM6 ([268, 272, 275, 279]) and L144 ([144])

    :param supervised: unused
    :param feature_type: unused
    :return: dict mapping group name to a list of residue numbers
    """
    # Groups that are defined but currently disabled:
    #   npxxy = [322, 323, 324, 325, 326]
    #   yy = [219, 326]
    #   ligand_interactions = [109, 113, 114, 117, 193, 195, 203, 204, 207, 286, 289, 290, 293, 308, 309, 312]
    #   dry = [130, 131, 132]
    #   pif = [121, 211, 282]
    #   m82 = [82]
    conserved_tm = [51, 79, 131, 158, 211, 288, 323, 332]
    g_protein_contacts = [
        131, 134, 135, 136, 138, 139, 141, 142, 143,
        222, 225, 226, 228, 229, 230, 232, 233, 271, 274, 275,
    ]
    return {
        "most_conserved_TM_residues": conserved_tm,
        "g_prot_interactions": g_protein_contacts,
    }


# Root directory for everything the demystifying analysis writes.
demystifying_dir = "output/demystifying/"
def extract_features(feature_extractors, overwrite=False):
    """Run (or load cached results of) every feature extractor.

    An extractor is considered cached when ``overwrite`` is False and an
    ``importance_per_residue.npy`` file exists in its sub directory of the
    results directory. Cached extractors are only loaded; the others are run,
    averaged, evaluated and persisted.

    :param feature_extractors: iterable of extractor objects
    :param overwrite: recompute even when cached results exist
    :return: numpy array containing one single-element list ``[postprocessor]``
        per extractor
    """
    results_dir = "{wdir}/{traj_type}/{feature_type}/{by_type}/".format(
        wdir=demystifying_dir,
        traj_type=traj_type,
        feature_type=feature_type,
        by_type="by_type" if group_by_type else "",
    )
    collected = []
    for extractor in feature_extractors:
        cached_files = []
        if not overwrite and os.path.exists(results_dir):
            pattern = "{}/{}/importance_per_residue.npy".format(results_dir, extractor.name)
            cached_files = glob.glob(pattern)
        must_compute = len(cached_files) == 0
        if must_compute:
            _log.info("Computing importance for extractor %s", extractor.name)
            extractor.extract_features()
        else:
            _log.debug("File %s already exists. skipping computations", cached_files[0])
        p = extractor.postprocessing(working_dir=results_dir,
                                     pdb_file=demystifying_dir + "/all.pdb",
                                     feature_to_resids=feature_to_resids)
        if must_compute:
            p.average()
            p.evaluate_performance()
            p.persist()
        else:
            p.load()
        collected.append([p])
    return np.array(collected)

    
def _generate_structures(postprocessor):
    """Call ``single_render.sh`` once per view and state for a postprocessor.

    The state list always contains the empty state; for supervised extractors
    the ligand types (when grouping by type) or the ligands are added. A
    failing command is logged and does not stop the remaining renders.

    :param postprocessor: postprocessor whose extractor name is passed to the
        script as classifier type
    """
    cmd_template = "bash single_render.sh {wdir} {view} {feature_type} {classifier_type} {traj_type} {state}"
    states = np.array([""])
    if postprocessor.extractor.supervised:
        states = np.append(states, ligand_types if group_by_type else ligands)
    render_jobs = [(view, state) for view in ('side', 'top') for state in states]
    for view, state in render_jobs:
        state_arg = state + "_grouped" if group_by_type else state
        cmd = cmd_template.format(
            wdir=demystifying_dir,
            view=view,
            feature_type=feature_type,
            classifier_type=postprocessor.extractor.name,
            traj_type=traj_type,
            state=state_arg,
        )
        # The command is split on single spaces after stripping, so an empty
        # state yields no trailing argument.
        argv = cmd.strip().split(" ")
        try:
            subprocess.run(argv)
        except Exception as err:
            _log.exception(err)
            _log.warning("Failed to execute command %s", cmd)
                

def _generate_line_graphs(postprocessor):
    """Render one importance-per-residue line graph per state.

    The first graph shows the overall importance. For supervised
    postprocessors one further graph per ligand (or per ligand type when
    grouping by type) is rendered from the matching column of
    ``importance_per_residue_and_cluster``. Every graph is written as SVG
    into the postprocessor's output directory and displayed.

    ``postprocessor.importance_per_residue`` is overwritten temporarily
    because the visualization reads it; the original value is restored
    afterwards, also when rendering fails.

    :param postprocessor: the postprocessor to visualize
    """
    p = postprocessor
    states = np.array([None])
    if postprocessor.supervised:
        states = np.append(states, ligand_types if group_by_type else ligands)
    importance_per_residue = p.importance_per_residue
    try:
        for index, state in enumerate(states):
            if state is None:
                p.importance_per_residue = importance_per_residue
            else:
                # states[0] is the overall entry, hence the offset of one.
                p.importance_per_residue = p.importance_per_residue_and_cluster[:, index-1]
            highlighted_residues = _get_important_residues(p.importance_per_residue, 
                                                           count_cutoff=10,
                                                           importance_cutoff=0.2 if p.supervised else 0.1)
            outfile ="{outdir}/importance_per_residue_{traj_type}_{feature_type}_{classifier}.svg".format(
                outdir=p.get_output_dir(),
                traj_type=traj_type,
                feature_type=feature_type,
                classifier=p.extractor.name + ("" if state is None else "_" + state),
            ) 
            dm.visualization.visualize([[p]],
                                    show_importance=True,
                                    show_performance=False,
                                    show_projected_data=False,
                                    mixed_classes=False,
                                    plot_title=p.extractor.name + ("" if state is None else " - " + state),
                                    highlighted_residues=highlighted_residues,
                                    outfile=outfile)
            display(SVG(filename=outfile))
            plt.close()
    finally:
        # Never leave the postprocessor with a per-state importance profile.
        p.importance_per_residue = importance_per_residue

def _generate_snakeplots(postprocessor):
    """Color the snake plot template by residue importance, once per state.

    The template SVG is read from the working directory and parsed once. For
    every state the ``circle`` element whose id equals a residue number gets a
    fill color taken from the "Blues" colormap at the residue's importance.
    The same parsed tree is reused for all states. Each result is written to
    the postprocessor's output directory and displayed.

    :param postprocessor: postprocessor providing the importance values
    """
    p = postprocessor
    states = np.array([None])
    if postprocessor.supervised:
        states = np.append(states, ligand_types if group_by_type else ligands)
    #TODO iterate over states
    # see https://stackoverflow.com/questions/24726528/replacing-inner-contents-of-an-svg-in-python
    from lxml import etree
    svg_namespace = u"http://www.w3.org/2000/svg"
    # Open the snake plot downloaded from GPCRdb
    with open(working_dir + "/snake_adrb2_human.svg", 'r') as template_file:
        template_text = template_file.read()

    blues = plt.get_cmap("Blues")

    svg_tree = etree.fromstring(template_text)
    for position, state in enumerate(states):
        if state is None:
            importances = p.importance_per_residue
        else:
            # states[0] is the overall entry, hence the offset of one.
            importances = p.importance_per_residue_and_cluster[:, position - 1]
        for residue_index, importance in enumerate(importances):
            resid = index_to_residue[residue_index].resSeq
            # Look for the circle element with id=<resid> in the SVG namespace
            circle_query = etree.ETXPath("//{%s}circle[@id='%d']" % (svg_namespace, resid))
            # The query returns a list; only its first element is recolored.
            circles = circle_query(svg_tree)
            if len(circles) == 0:
                _log.warning("No SVG element found for residue %s", resid)
                continue
            circles[0].set('fill', mpl.colors.to_hex(blues(importance)))
        # Save
        classifier_name = p.extractor.name + ("" if state is None else "_" + state)
        outfile = "{outdir}/snakeplot_importance_{traj_type}_{feature_type}_{classifier}.svg".format(
            outdir=p.get_output_dir(),
            traj_type=traj_type,
            feature_type=feature_type,
            classifier=classifier_name,
        )
        with open(outfile, "wb") as svg_out:
            svg_out.write(etree.tostring(svg_tree))
        display(SVG(filename=outfile))
    

def visualize_importance(postprocessors):
    """Visit every postprocessor; all renderers are currently disabled.

    Each entry of ``postprocessors`` must contain exactly one postprocessor
    (the layout returned by ``extract_features``).

    :param postprocessors: iterable of single-element sequences
    :raises ValueError: if an entry does not contain exactly one element
    """
    for entry in postprocessors:
        (p,) = entry
        # Enable the renderers that are needed:
        #_generate_line_graphs(p)
        #_generate_snakeplots(p)
        #_generate_structures(p)
        
        
# Keyword arguments shared by the feature extractors created in the cells below.
kwargs = dict(
    samples=samples.copy(),
    # NOTE(review): group_labels_by_type is not defined in the visible part of
    # this notebook; it is only evaluated when group_by_type is True.
    labels=group_labels_by_type(labels) if group_by_type else labels.copy(),
    label_names=ligand_types if group_by_type else ligands,
    filter_by_distance_cutoff=False,
    use_inverse_distances=True,
    n_splits=1,
    shuffle_datasets=False,
)
_log.info("Done")

## Supervised

In [None]:
# Supervised extractors: a KL-based extractor and a one-vs-rest random forest,
# both configured with the shared `kwargs`.
supervised_feature_extractors = [
    dm.feature_extraction.KLFeatureExtractor(n_iterations=3,**kwargs),
    dm.feature_extraction.RandomForestFeatureExtractor(
        n_iterations=10,
        one_vs_rest=True,
        classifier_kwargs=dict(n_estimators=100, n_jobs=-1),
        **kwargs)
]
# overwrite=False: results already stored on disk are loaded instead of recomputed.
supervised_postprocessors = extract_features(supervised_feature_extractors, overwrite=False)
visualize_importance(supervised_postprocessors)
_log.debug("Done")

## Unsupervised

In [None]:
# RBM requires data to be scaled with an upper limit of 1, so the stored
# scaling is undone and the samples are rescaled with a MinMaxScaler.
# rbm_kwargs is only used by the RbmFeatureExtractor that is commented out below.
rbm_kwargs = dict(**kwargs)
rbm_kwargs['samples'] = MinMaxScaler().fit_transform(scaler.inverse_transform(samples))
rbm_kwargs['shuffle_datasets'] = True
unsupervised_feature_extractors = [
    # NOTE(review): supervised=True is passed although this is the unsupervised section - confirm this is intended.
    dm.feature_extraction.PCAFeatureExtractor(classifier_kwargs=dict(n_components=2),
                            supervised=True,
                           variance_cutoff='2_components',
                           **kwargs),
   #dm.feature_extraction.RbmFeatureExtractor(
    #    supervised=True,
    #    n_iterations=50,
    #    classifier_kwargs=dict(n_components=100, learning_rate=1e-3),
    #    **rbm_kwargs 
    #)
]
# overwrite=False: results already stored on disk are loaded instead of recomputed.
unsupervised_postprocessors = extract_features(unsupervised_feature_extractors, 
                                               overwrite=False)
visualize_importance(unsupervised_postprocessors)

# Visualize data with projections

## Help methods

In [None]:
def create_importance_cmap(color=None, N = 1024):
    """Create a colormap fading linearly from ``color`` to a near-white grey.

    See https://matplotlib.org/3.1.0/tutorials/colors/colormap-manipulation.html

    :param color: start color, either a matplotlib color name or a sequence
        whose first three entries are RGB values in [0, 1]; defaults to a
        dark red
    :param N: number of entries in the colormap
    :return: a ``ListedColormap`` with ``N`` fully opaque colors
    """
    if color is None:
        color = np.array([135, 21, 0])/256
    elif isinstance(color, str):
        color = mpl.colors.to_rgb(color)
    # Start from ones so that the alpha channel (last column) is 1. The
    # previous zeros-initialised table produced fully transparent colors.
    vals = np.ones((N, 4))
    max_color = np.array([1, 1, 1])*0.95
    min_color = color
    for channel in range(3):
        vals[:, channel] = np.linspace(min_color[channel], max_color[channel], N)
    return ListedColormap(vals)


"""
Marks every snapshot with a ligand-specific marker and
colors every snapshot according to its index
"""
def plot_activation_ligands(X, ligands=ligands, labels=labels, 
                 xlabel=None, ylabel=None, show_title=True,
                 method=None, savefig=True, alpha=1., subplots=False, ncols=3):
    """Plot projected snapshots with one marker per ligand, colored by snapshot index.

    Two-dimensional data is drawn as a scatter plot in which the color of a
    point follows its position within the ligand's snapshots (importance
    colormap). One-dimensional data is drawn as one density histogram per
    ligand using the ligand color.

    :param X: projected samples, one row per snapshot
    :param ligands: ligand names; entry ``i`` belongs to column ``i`` of ``labels``
    :param labels: array whose column ``i`` equals 1 for snapshots of ligand ``i``
    :param xlabel: optional x axis label
    :param ylabel: optional y axis label
    :param show_title: put the method and trajectory type in the title
    :param method: name of the projection method (title and file name)
    :param savefig: save the figure as SVG below ``output/projections``
    :param alpha: opacity of the plotted points/bars
    :param subplots: draw every ligand in its own subplot
    :param ncols: number of subplot columns
    """
    #values = ligand_labels if group_by_type else ligands
    values = ligands
    if subplots:
        # Ceiling division: just enough rows for all ligands, at least one.
        nrows = max(1, -(-len(ligands) // ncols))
        # squeeze=False keeps `axs` two-dimensional even for a single row or
        # column, so that axs[row, col] below always works.
        fig, axs = plt.subplots(ncols=ncols, 
                        nrows=nrows, 
                        figsize=(12, 10),
                        squeeze=False,
                        sharey=True, 
                        sharex=True)
    else:
        plt.figure(figsize=(4,4))
    row, col = 0, 0
    cmap = create_importance_cmap()
    for i, lig in enumerate(values):
        l = _fix_label(lig)
        indices = labels[:, i] == 1
        if subplots:
            plt.sca(axs[row, col])
        if len(X.shape) < 2 or X.shape[1] == 1:
            plt.hist(X[indices], label=l, alpha=alpha, color=colors[i], density=True)
        else:
            xx = X[indices]
            plt.scatter(xx[:, 0], xx[:, 1], 
                        label=l, 
                        alpha=alpha, 
                        marker=markers.get(lig ,"."),
                        color=cmap(np.linspace(0, 1, len(xx))), 
                        s=8)
        # Advance to the next subplot cell, row by row.
        col += 1
        if col >= ncols:
            col = 0
            row += 1
        plt.legend()
    if show_title:
        plt.title("{}\n{}".format("" if method is None else method, traj_type))
    if xlabel is not None:
        plt.xlabel(xlabel)
    if ylabel is not None:
        plt.ylabel(ylabel)
    if savefig:
        # savefig does not create missing directories.
        os.makedirs("output/projections", exist_ok=True)
        plt.savefig("output/projections/{}_{}_{}.svg".format(method, feature_type, traj_type))
    plt.show()
    
"""
Marks every snapshot with a ligand-specific color
"""
def plot_state_ligands(X, ligands=ligands, labels=labels, 
                 xlabel=None, ylabel=None, show_title=True,
                 method=None, savefig=True, alpha=1., subplots=False, ncols=3):
    """Plot the samples of every ligand, using one fixed colour per ligand.

    If ``X`` is one-dimensional (or has a single column) a density histogram is
    drawn per ligand; otherwise the first two columns are drawn as a scatter
    plot. In both cases ligand ``i`` is drawn in ``colors[i]`` from the
    module-level ``colors`` array.

    Parameters
    ----------
    X : array
        Samples to plot, one row per snapshot.
    ligands : sequence of str
        Ligand names; entry ``i`` is paired with column ``i`` of ``labels``.
    labels : 2-D array
        Rows where ``labels[:, i] == 1`` are plotted for ligand ``i``.
    xlabel, ylabel : str, optional
        Axis labels; omitted when ``None``.
    show_title : bool
        Add a title made of ``method`` and the module-level ``traj_type``.
    method : str, optional
        Name used in the title and in the saved file name.
    savefig : bool
        Save the figure as
        ``output/projections/<method>_<feature_type>_<traj_type>.svg``.
    alpha : float
        Opacity passed to the histogram/scatter call.
    subplots : bool
        Draw every ligand on its own axes of a shared-axis grid.
    ncols : int
        Number of grid columns when ``subplots`` is true.

    Notes
    -----
    Title and axis labels are set through ``pyplot`` on the current axes, so
    with ``subplots=True`` they end up on the axes of the last ligand drawn.
    """
    if subplots:
        # Just enough rows to hold one axes per ligand (at least one row).
        nrows = max(1, int(np.ceil(len(ligands) / ncols)))
        # squeeze=False keeps ``axs`` two-dimensional even when the grid has a
        # single row or column, so that ``axs[row, col]`` is always valid.
        _, axs = plt.subplots(ncols=ncols,
                              nrows=nrows,
                              figsize=(12, 10),
                              squeeze=False,
                              sharey=True,
                              sharex=True)
    else:
        plt.figure(figsize=(4, 4))
    for i, lig in enumerate(ligands):
        legend_label = _fix_label(lig)
        indices = labels[:, i] == 1
        if subplots:
            # Fill the grid row by row: divmod gives (row, col) for ligand i.
            plt.sca(axs[divmod(i, ncols)])
        if len(X.shape) < 2 or X.shape[1] == 1:
            plt.hist(X[indices], label=legend_label, alpha=alpha, color=colors[i], density=True)
        else:
            plt.scatter(X[indices, 0], X[indices, 1],
                        label=legend_label,
                        alpha=alpha,
                        color=colors[i],
                        marker=markers.get(lig, "."),
                        s=2)
        plt.legend()
    if show_title:
        plt.title("{}\n{}".format("" if method is None else method, traj_type))
    if xlabel is not None:
        plt.xlabel(xlabel)
    if ylabel is not None:
        plt.ylabel(ylabel)
    if savefig:
        plt.savefig("output/projections/{}_{}_{}.svg".format(method, feature_type, traj_type))
    plt.show()
    
def plot_ligands(X, **kwargs):
    """Plot ``X`` with the plotting routine that matches ``traj_type``.

    String trajectories (``traj_type == 'strings'``) are drawn with
    ``plot_activation_ligands``; everything else with ``plot_state_ligands``.
    All keyword arguments are forwarded unchanged.
    """
    plotter = plot_activation_ligands if traj_type == 'strings' else plot_state_ligands
    return plotter(X, **kwargs)

## PCA

In [None]:
from sklearn.decomposition import PCA
# Fit a PCA on all samples; n_components is left at its default here.
pca = PCA()#n_components=2)
X_pca = pca.fit_transform(samples)
# Plot components 1-2 first, then components 3-4 under a separate method name.
plot_ligands(X_pca[:, :2], method="PCA")
plot_ligands(X_pca[:, 2:4], method="PCA2")

## TSNE

In [None]:
from sklearn.manifold import *
# Fixed parameters (including random_state) to make the method more deterministic,
# see https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html
tsne_config=dict(n_components=2, random_state=0, n_jobs=-1,perplexity=30, learning_rate=200)
_log.debug("Using config %s", tsne_config)
tsne = TSNE(**tsne_config)
# Embed all samples in two dimensions and plot the embedding per ligand.
X_tsne = tsne.fit_transform(samples)
plot_ligands(X_tsne, method="TSNE")

## MDS
see https://scikit-learn.org/stable/modules/generated/sklearn.manifold.MDS.html#sklearn.manifold.MDS

In [None]:
from sklearn.manifold import *

# Two-dimensional metric MDS with a fixed random_state.
config=dict(n_components=2, 
            random_state=0, 
            metric=True,
            n_jobs=-1)
_log.debug("Using config %s", config)
mds = MDS(**config)
# Embed all samples and plot the embedding per ligand.
X_mds = mds.fit_transform(samples)
plot_ligands(X_mds, method="MDS")

## Demystifying results
### Methods

In [None]:
def clean_id(cv):
    """Return ``cv.name`` with the ``^-1(closest-heavy)`` marker and all ``|`` removed."""
    label = cv.name
    # Strip the tokens in this fixed order.
    for token in ("^-1(closest-heavy)", "|"):
        label = label.replace(token, "")
    return label

def project_top_features(postprocessors, max_features=6, distance_cutoff=0.5):
    """Plot the samples along consecutive pairs of important features.

    For every postprocessor the important features are taken two at a time
    (1st with 2nd, 3rd with 4th, ...). Each pair is plotted with
    ``plot_ligands`` and its pairwise ligand similarity is computed with
    ``compute_pairwise_similarity``. A pair is skipped when the smallest
    distance observed in either of its two features exceeds
    ``distance_cutoff``.

    Parameters
    ----------
    postprocessors : iterable
        Every item must be a length-1 sequence holding one postprocessor that
        provides ``get_important_features()`` and ``extractor.name``.
    max_features : int
        Maximum number of feature pairs plotted per postprocessor.
    distance_cutoff : float
        Pairs whose minimum distance is above this value are skipped.

    Notes
    -----
    Reads the module-level ``samples``, ``scaler``, ``feature_type`` and
    ``cvs``. ``get_important_features()`` is presumably sorted by decreasing
    importance -- TODO confirm.
    """
    distances = scaler.inverse_transform(samples)
    if 'inv' in feature_type:
        # The features are inverse distances; convert them back to distances.
        distances = 1 / distances

    for [p] in postprocessors:
        imps = p.get_important_features()
        counter = 0
        # Walk the features two at a time; a trailing feature without a
        # partner is ignored instead of indexing past the end.
        for i in range(0, len(imps) - 1, 2):
            idx1, imp1 = imps[i]
            idx2, imp2 = imps[i + 1]
            idx1, idx2 = int(idx1), int(idx2)
            id1, id2 = clean_id(cvs[idx1]), clean_id(cvs[idx2])
            X_imps = distances[:, [idx1, idx2]]
            if X_imps.min() > distance_cutoff:
                continue
            _log.debug("#%s with importance %s and %s: %s-%s", i, imp1, imp2, id1, id2)
            method = "{}_{}_{}".format(p.extractor.name, id1, id2)
            _log.debug(method)
            plot_ligands(X_imps,
                         xlabel=id1,
                         ylabel=id2,
                         method=method)
            compute_pairwise_similarity(X_imps,
                                        method=neg_frame_to_frame_distance,
                                        title=method)
            counter += 1
            if counter >= max_features:
                break


### Supervised

In [None]:
# Plot the top feature pairs of the supervised postprocessors.
project_top_features(supervised_postprocessors)
_log.info("Done")

### Automatically find top ranked features for certain residues with close contacts

In [None]:
def find_close_contacts(postprocessors, limit = 10, distance_cutoff=0.6,
                        importance_cutoff=0.5,
                        states=None,
                       to_find = [321, 206, 207, 327, 312, 315],
                    to_ignore=[]
                       ):
    """Plot important feature pairs that involve selected residues in close contact.

    The important features of every postprocessor are taken two at a time.
    A pair is plotted with ``plot_ligands`` when

    * its name does not contain any entry of ``to_ignore``,
    * its name contains at least one entry of ``to_find``, and
    * the smallest distance observed in either feature is at most
      ``distance_cutoff``.

    The search through one postprocessor stops at the first pair whose first
    importance is below ``importance_cutoff`` or once ``limit`` pairs have
    been plotted.

    Parameters
    ----------
    postprocessors : iterable
        Every item must be a length-1 sequence holding one postprocessor that
        provides ``get_important_features(states=...)`` and ``extractor.name``.
    limit : int
        Maximum number of pairs plotted per postprocessor.
    distance_cutoff : float
        Pairs whose minimum distance is above this value are skipped.
    importance_cutoff : float
        Importance threshold at which the search stops.
    states : list, optional
        Forwarded to ``get_important_features``.
    to_find, to_ignore : sequences
        Residue identifiers; they are converted to strings and matched as
        substrings of ``"<extractor name>_<id1>_<id2>"``. The default lists
        are never mutated.

    Notes
    -----
    Reads the module-level ``samples``, ``scaler``, ``feature_type`` and
    ``cvs``. Matching is a plain substring test, so e.g. ``321`` also matches
    a name containing ``1321``.
    """
    distances = scaler.inverse_transform(samples)
    if 'inv' in feature_type:
        # The features are inverse distances; convert them back to distances.
        distances = 1 / distances
    to_find = [str(s) for s in to_find]
    to_ignore = [str(s) for s in to_ignore]
    for [p] in postprocessors:
        counter = 0
        imps = p.get_important_features(states=states)
        # Walk the features two at a time; a trailing feature without a
        # partner is ignored instead of indexing past the end.
        for i in range(0, len(imps) - 1, 2):
            idx1, imp1 = imps[i]
            idx2, imp2 = imps[i + 1]
            if imp1 < importance_cutoff:
                break
            idx1, idx2 = int(idx1), int(idx2)
            id1, id2 = clean_id(cvs[idx1]), clean_id(cvs[idx2])
            method = "{}_{}_{}".format(p.extractor.name, id1, id2)
            if any(ti in method for ti in to_ignore):
                continue
            if not any(tf in method for tf in to_find):
                continue
            X_imps = distances[:, [idx1, idx2]]
            if X_imps.min() > distance_cutoff:
                # Only close contacts are of interest.
                continue
            _log.debug("#%s with importance %s and %s: %s-%s. States: %s", i, imp1, imp2, id1, id2, states)
            plot_ligands(X_imps,
                         xlabel=id1, ylabel=id2,
                         method=method)
            counter += 1
            if counter == limit:
                break

                
# When grouping by ligand type, search once with the default settings;
# otherwise search per ligand (one state per ligand index) with its own
# limit and cutoffs.
if group_by_type:
    find_close_contacts(supervised_postprocessors)
else:
    for l_idx, ligand in enumerate(ligands):
        _log.info("---------%s-------", ligand)
        find_close_contacts(supervised_postprocessors, states=[l_idx], 
                            limit=3, 
                            distance_cutoff=0.5, 
                            importance_cutoff=0.1)
_log.info("Done")

### Unsupervised

In [None]:
# Plot the top feature pairs of the unsupervised postprocessors.
project_top_features(unsupervised_postprocessors)
_log.info("Done")

### Select residues to plot against

In [None]:
def find_feature_index(residues, feature_to_resids=feature_to_resids):
    """Return the index of the first feature made up of the given residues.

    A row of ``feature_to_resids`` matches when it has as many entries as
    ``residues`` and each of its entries occurs in ``residues``. When no row
    matches, a warning is logged and ``None`` is returned.
    """
    wanted = len(residues)
    for idx, row in enumerate(feature_to_resids):
        if len(row) != wanted:
            continue
        if all(r in residues for r in row):
            return idx
    _log.warning("No features found for residues %s", residues)
            
def find_for_residues(postprocessor, 
                       residue_pairs = [],
                      ligands=ligands,
                       ):
    """Plot the samples along explicitly chosen pairs of residue-residue features.

    Parameters
    ----------
    postprocessor
        Object providing ``feature_importances`` (averaged over axis 1 and
        min-max normalised here, for logging only) and ``extractor.name``.
    residue_pairs : sequence
        Items of the form ``[pair1, pair2]`` where each pair holds the
        residues of one feature; every item yields one plot with ``pair1`` on
        the x axis and ``pair2`` on the y axis. The default list is never
        mutated.
    ligands : sequence of str
        Forwarded to ``plot_ligands``.

    Notes
    -----
    Reads the module-level ``samples``, ``scaler``, ``feature_type`` and
    ``cvs``. Items for which ``find_feature_index`` finds no feature are
    skipped; ``find_feature_index`` already logs a warning in that case.
    """
    p = postprocessor
    distances = scaler.inverse_transform(samples)
    if 'inv' in feature_type:
        # The features are inverse distances; convert them back to distances.
        distances = 1 / distances
    imps = p.feature_importances.mean(axis=1)
    imps = (imps - imps.min()) / (imps.max() - imps.min())
    for pair1, pair2 in residue_pairs:
        idx1 = find_feature_index(pair1)
        idx2 = find_feature_index(pair2)
        if idx1 is None or idx2 is None:
            # No such feature: indexing with None would fail or misbehave.
            continue
        idx1, idx2 = int(idx1), int(idx2)
        imp1, imp2 = imps[idx1], imps[idx2]
        id1, id2 = clean_id(cvs[idx1]), clean_id(cvs[idx2])
        method = "{}_{}_{}".format(p.extractor.name, id1, id2)
        X_imps = distances[:, [idx1, idx2]]
        _log.debug("with importance %s and %s: %s-%s", imp1, imp2, id1, id2)
        plot_ligands(X_imps,
                     ligands=ligands,
                     show_title=False,
                     xlabel=id1, ylabel=id2,
                     method=method)

# Pathways: every entry below is one scatter plot, given as
# [residue pair on the x axis, residue pair on the y axis].
find_for_residues(
    supervised_postprocessors[0, 0],
    [
        [(118, 206), (284, 321)],
        [(277, 327), (281, 325)],
        [(207, 307), (203, 338)],
        [(131, 272), (326, 285)], #Include if we need a scatter plot with C285
        
    ]
)

_log.info("Done")

### Ligand specific - primarily single states 

In [None]:
# Singling out salmeterol and alprenolol ligands.
# Residue pairs of interest per ligand; the find_for_residues call in this
# cell passes its own literal list and does not read this mapping.
ligand_projection_mapping =  dict(
    apo=[(222, 271), (225, 268), (272, 131), (75, 322), (127, 321)],
    carazolol=[(51, 319)],
    alprenolol=[(79, 321), (51, 319)],
    timolol=[(274, 321), (50, 327)],
    salmeterol=[(285, 326), (136, 272), (321, 326)],
    # For adrenaline and p0g, see also agonist vs non-agonist plots above
    adrenaline=[(113, 308),(275, 326)],
    p0g=[(275, 326)]
)
# Each active entry is one scatter plot: [x-axis residue pair, y-axis residue pair].
find_for_residues(
    supervised_postprocessors[0, 0],
[
    [(225, 268), (272, 131)], # apo
   # [(79, 321), (51, 319)], # carazolol and alprenolol
    #[(127, 321), (51, 319)], # agonists, apo, alprenolol and carazolol
    [(79, 321), (79, 322)], # apo and alprenolol     
    [(274, 321), (50, 327)], # timolol
    [(321, 326), (136, 272)], # salmeterol
    [(51, 319), (275, 326)], # apo, alprenolol and carazolol, adrenaline and p0g
]    
)

# Compare similarity 
## Methods

In [None]:
from scipy.stats import entropy
# Small constant added to the denominator by inv_frame_to_frame_distance and inv_KL.
eps = 1e-4 

def jaccard_similarity(x1, x2):
    """Return a Jaccard-like overlap score between two sequences.

    The numerator is the number of distinct shared values. The denominator is
    ``len(x1) + len(x2)`` minus that number, so repeated values in the inputs
    are counted in the denominator.

    Probably not a good measure after all!
    Adapted from https://stackoverflow.com/questions/46975929/how-can-i-calculate-the-jaccard-similarity-of-two-lists-containing-strings-in-py
    """
    n_shared = len(set(x1) & set(x2))
    n_total = len(x1) + len(x2) - n_shared
    return float(n_shared) / n_total

def cluster_similarity(x1, x2):
    """Compare two cluster assignments by their per-cluster population counts.

    For every cluster id occurring in either input, the absolute difference
    between the number of members in ``x1`` and in ``x2`` is accumulated.
    The result is ``1 / (1 + total difference)``, i.e. 1 for equal counts.

    Raises
    ------
    Exception
        If the two inputs differ in length.
    """
    if len(x1) != len(x2):
        raise Exception("Clusters must be of same size")
    first, second = x1.squeeze(), x2.squeeze()
    cluster_ids = set(first).union(second)
    mismatch = 0
    for cid in cluster_ids:
        n_first = len(first[first == cid])
        n_second = len(second[second == cid])
        mismatch += abs(n_first - n_second)
    return 1 / (1 + mismatch)
        
def inv_center_dist(x1, x2):
    """Return ``1 / (1 + d)`` where ``d`` is the distance between the row-wise means of ``x1`` and ``x2``."""
    offset = x1.mean(axis=0) - x2.mean(axis=0)
    return 1 / (1 + np.linalg.norm(offset))

def KL_divergence(x1, x2, bin_width=None, symmetric=False):
    """
    Compute the Kullback-Leibler divergence between two sample sets, per feature.
    From demystifying repo

    Both sets are histogrammed feature by feature on a common range (the
    min/max over both sets) and the divergence between the two histograms is
    computed with ``scipy.stats.entropy``.

    Parameters
    ----------
    x1, x2 : 2-D arrays
        Samples, one row per sample and one column per feature.
    bin_width : float, optional
        Histogram bin width. When ``None`` the standard deviation of the
        feature in ``x1`` is used (0.1 if that is zero).
    symmetric : bool
        Return the average of the divergences in both directions instead of
        the divergence from ``x1`` to ``x2``.

    Returns
    -------
    numpy.ndarray
        One value per feature; 0 where the bin width is at least as large as
        the value range of the feature.
    """
    n_features = x1.shape[1]
    DKL = np.zeros(n_features)

    for i_feature in range(n_features):
        col1 = x1[:, i_feature]
        col2 = x2[:, i_feature]
        xy = np.concatenate((col1, col2))
        bin_min = np.min(xy)
        bin_max = np.max(xy)

        if bin_width is None:
            tmp_bin_width = np.std(col1)
            if tmp_bin_width == 0:
                tmp_bin_width = 0.1  # Set arbitrary bin width if zero
        else:
            # Use the caller-supplied width (this is a plain function, there is no ``self``).
            tmp_bin_width = bin_width

        if tmp_bin_width >= (bin_max - bin_min):
            # Everything would fall into a single bin: no measurable divergence.
            DKL[i_feature] = 0
            continue

        bin_n = int((bin_max - bin_min) / tmp_bin_width)
        # A tiny offset keeps empty bins from producing infinite divergences.
        x1_prob = np.histogram(col1, bins=bin_n, range=(bin_min, bin_max), density=True)[0] + 1e-9
        x2_prob = np.histogram(col2, bins=bin_n, range=(bin_min, bin_max), density=True)[0] + 1e-9
        #TODO should we use symmetrized KL as done below?
        if symmetric:
            DKL[i_feature] = 0.5 * (entropy(x1_prob, x2_prob) + entropy(x2_prob, x1_prob))
        else:
            DKL[i_feature] = entropy(x1_prob, x2_prob)
    return DKL

def avg_KL(x1,x2):
    """Return the mean over all features of ``KL_divergence(x1, x2)``."""
    per_feature = KL_divergence(x1, x2)
    return per_feature.mean()

def neg_avg_KL(x1,x2):
    """Return the negated ``avg_KL(x1, x2)``."""
    mean_kl = avg_KL(x1, x2)
    return -mean_kl


def frame_to_frame_distance(x1, x2):
    """Return the mean Euclidean distance over all (row of ``x1``, row of ``x2``) pairs."""
    total = 0
    for frame in x1:
        # Distances from this frame to every row of x2, computed in one call.
        total += np.linalg.norm(frame - x2, axis=1).sum()
    n_pairs = x1.shape[0] * x2.shape[0]
    return total / n_pairs
        
def compute_pairwise_similarity(X, method, title=None, ligand_labels=ligand_labels,
                                ligands=ligands, ligand_types=ligand_types):
    """Compute, plot and save a heatmap of pairwise similarities between groups.

    The rows of ``X`` are split per group (``ligand_types`` when the
    module-level ``group_by_type`` is set, otherwise ``ligands``) using
    ``ligand_labels[:, i] == 1``. ``method(x1, x2)`` is evaluated for every
    ordered pair of groups, the resulting matrix is min-max normalised and
    shown as a heatmap, which is also written to ``output/similarities/``.

    Parameters
    ----------
    X : array
        Samples, one row per snapshot; a 1-D array is treated as one column.
    method : callable
        Takes the samples of two groups and returns a scalar score.
    title : str, optional
        Used in the figure title and in the file name.
    ligand_labels : 2-D array
        Column ``i`` marks the rows belonging to group ``i`` with 1.
    ligands, ligand_types : sequences of str
        Group names used for the tick labels.
    """
    if X.ndim < 2:
        X = X[:, np.newaxis]
    names = ligand_types if group_by_type else ligands
    n_groups = len(names)

    # Split the samples once per group, then score every ordered pair.
    partitions = [X[ligand_labels[:, k] == 1] for k in range(n_groups)]
    data = np.full((n_groups, n_groups), np.nan)
    for idx1, x1 in enumerate(partitions):
        for idx2, x2 in enumerate(partitions):
            data[idx1, idx2] = method(x1, x2)

    # Rescale all scores to the [0, 1] range.
    lowest, highest = data.min(), data.max()
    data = (data - lowest) / (highest - lowest)

    # Plot, see https://matplotlib.org/3.1.1/gallery/images_contours_and_fields/image_annotated_heatmap.html
    fig, ax = plt.subplots()
    im = plt.imshow(data, cmap=plt.get_cmap("Blues"))
    # Show one tick per group, labelled with the group name.
    tick_positions = np.arange(n_groups)
    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set_xticklabels(names)
    ax.set_yticklabels(names)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")
    # Colorbar
    cbar = ax.figure.colorbar(im, ax=ax)
    cbar.ax.set_ylabel("Similarity", rotation=-90, va="bottom")
    # Show and save
    plt.title("{}\n{}".format(title, traj_type))
    plt.tight_layout(pad=0.3)
    suffix = "_bytype" if group_by_type else ""
    plt.savefig("output/similarities/{}_feature_type_{}_{}{}.svg".format(title, feature_type,
                                                                         traj_type,
                                                                         suffix))
    plt.show()

def inv_frame_to_frame_distance(x1, x2):
    """Return the reciprocal of ``eps`` plus ``frame_to_frame_distance(x1, x2)``."""
    shifted = eps + frame_to_frame_distance(x1, x2)
    return 1 / shifted

def neg_frame_to_frame_distance(x1, x2):
    """Return the negated ``frame_to_frame_distance(x1, x2)``."""
    mean_distance = frame_to_frame_distance(x1, x2)
    return -mean_distance

def inv_KL(x1, x2):
    """Return the reciprocal of ``eps`` plus ``KL_divergence(x1, x2)``."""
    shifted = eps + KL_divergence(x1, x2)
    return 1 / shifted

def neg_KL(x1, x2):
    """Return the negated ``KL_divergence(x1, x2)``."""
    divergence = KL_divergence(x1, x2)
    return -divergence

# Short alias for KL_divergence.
KL = KL_divergence

_log.info("Done")

## Using Euclidean distance

In [None]:
# Pairwise similarity on the full feature set, scored as the negated mean
# frame-to-frame distance.
compute_pairwise_similarity(samples, 
                            method=neg_frame_to_frame_distance, 
                            title="full-frame_to_frame_distance") 
_log.info("Done")