# Init

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
#imports
import logging
import sys
logging.basicConfig(
    stream=sys.stdout,
    level=logging.DEBUG,
    format='%(asctime)s %(name)s-%(levelname)s: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S')
import mdtraj as md
import numpy as np
import scipy.ndimage.filters
import matplotlib.pyplot as plt
import os
import math
import json
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import StandardScaler
sys.path.append('MD_common/')
sys.path.append('heatmapping/')
#sys.path.append('/home/oliverfl/git/interprettensor/interprettensor/modules/')
from helpfunc import *
from colvars import *
import nbimporter
import AnalyzeClusteredFrames as ancf
import MD_fun
import modules, utils
from trajclassifier import *
from relevancepropagator import *

# Instance of the MD_functions helper class from the MD_fun module.
fun = MD_fun.MD_functions()
# Switch the working directory to whatever get_project_path() returns
# (presumably the project root, so the relative paths used later resolve).
os.chdir(get_project_path())
#simulations = [("A", "08"), ("A", "00")]  #, ("A", "16")]
logger = logging.getLogger("learnclust")
# Dataset identifier. The second "_"-separated token is parsed as the number
# of clusters further down. Other values that have been used are kept in the
# trailing comment.
traj_type = "drorA_3_clusters"#"strings_apo_holo"#"drorD","freemd_apo"
# Feature type. Its prefix ("contact", "distance" or "cvs") selects how
# frame_distances is computed; for the first two, the text after the last "_"
# is used as the mdtraj contact scheme.
distance_metric ="distance_closest-heavy"#"contact_closest-heavy" #"CA" #cvs-len5, CA, CAonlyCvAtoms, distance_closest-heavy
cvs_name = "cvs-{}".format(traj_type)
# Passed as `query` to create_classifier_cvs; None in this configuration.
CA_query = None
logger.info("Done")

## Load clustering data from other module

In [None]:
# The cluster count is the second "_"-separated token of traj_type
# (e.g. "drorA_3_clusters" -> 3).
# NOTE(review): this raises for traj_type values without a numeric second
# token (e.g. "strings_apo_holo" or "drorD") -- confirm how those are handled.
nclusters = int(traj_type.split("_")[1])
cluster_simu = Simulation({
    "condition": "A",
    "number": "00",
    "name": "all",
    "stride": 100
})
cluster_simu.clusterpath="Result_Data/beta2-dror/clustering/"
# Replace the bare Simulation with the object returned by the clustering
# module's loader.
cluster_simu = ancf.load_default(cluster_simu)
clustering_id = "drorA"

## Compute distance metric, e.g. CA distances

In [None]:
# Build the per-frame feature array `frame_distances`. The prefix of
# distance_metric selects the branch; unsupported values raise.
if distance_metric.startswith("contact") or distance_metric.startswith("distance"):
    # The text after the last "_" is handed to mdtraj as the contact scheme
    # (e.g. "closest-heavy").
    scheme = distance_metric.split("_")[-1]
    logger.debug("Using scheme %s for computing distance metric %s", scheme, distance_metric)
    atoms = get_atoms("protein and name CA",cluster_simu.traj.top, sort=False)
    CA_atoms, cv_atoms = atoms, None
    protein_residues = [a.residue.index for a in atoms]
    # NOTE(review): the residue indices are sorted here, while CA_atoms keeps
    # the order returned by get_atoms(sort=False). The matrix axes only line up
    # with CA_atoms if that order is already ascending -- confirm.
    protein_residues = sorted(protein_residues)
    # One residue x residue matrix per frame; only off-diagonal entries are
    # filled below, so the diagonal stays 0.
    frame_distances = np.zeros((len(cluster_simu.traj), len(protein_residues), len(protein_residues)))
    # Distance threshold used by the "contact" variant (mdtraj reports
    # distances in nanometers).
    cutoff = 0.5
    for idx, r1 in enumerate(protein_residues):
        if idx == len(protein_residues) - 1:
            # The last residue has no later partner left to pair with.
            break
        if idx % 10 == 0:
            logger.debug("Computing contacts for residue %s/%s", idx + 1, len(protein_residues))
        # Pair r1 only with later residues; the mirrored entry is filled in
        # explicitly below.
        res_pairs = [(r1,r2) for r2 in protein_residues[idx+1:]]
        dists, dist_atoms = md.compute_contacts(cluster_simu.traj,
                                               contacts=res_pairs,
                                               scheme=scheme,
                                               ignore_nonprotein=True)
        if distance_metric.startswith("contact"):
            # Binarize in place (contacts aliases dists): values above the
            # cutoff become 0, every remaining positive value becomes 1.
            contacts  = dists
            contacts[contacts > cutoff] = 0
            contacts[contacts > 0] = 1
            frame_distances[:,idx,(idx+1):] = contacts
            frame_distances[:,(idx+1):,idx] = contacts
        elif distance_metric.startswith("distance"):
            # Inverse distances are stored, so close pairs get large values.
            inv_dists = 1/dists
            frame_distances[:,idx,(idx+1):] = inv_dists
            frame_distances[:,(idx+1):,idx] = inv_dists
elif distance_metric.startswith("cvs"):
    # Features are the values of previously saved CVs evaluated on the trajectory.
    cvs = load_object("cvs/" + distance_metric)
    frame_distances = eval_cvs(cluster_simu.traj, cvs)
    CA_atoms = None
    CA_query=None
    # One tuple of CA atoms per CV, selected by the CV's two residue numbers.
    cv_atoms = []
    for idx, cv in enumerate(cvs):
        resq = "name CA and (resSeq {} or resSeq {})".format(cv.res1, cv.res2)
        res_atoms =  get_atoms(resq, cluster_simu.traj.topology, sort=False)
        cv_atoms.append(tuple(res_atoms))
    logger.debug(cv_atoms)
else:
    raise Exception("Unsupported value " + distance_metric)

logger.debug("Done. Loaded distances into a matrix of shape %s",
         frame_distances.shape)

# Train Network
- Using http://scikit-learn.org/stable/modules/generated/sklearn.neural_network.MLPClassifier.html

In [None]:
# Forwarded to transform_and_train as its `trainingstep` keyword.
trainingstep = 2 #set to something else to test prediction power
####Optionally shuffle indices:
#indices = np.arange(len(cluster_simu.traj))
#np.random.shuffle(indices)
#frame_distances = frame_distances[indices]
#cluster_simu.traj = cluster_simu.traj[indices]
# Returns the samples and targets that were used plus the fitted scaler and
# classifier; all four are reused by later cells.
training_samples, target_values, scaler, classifier = transform_and_train(
    frame_distances, cluster_simu, trainingstep=trainingstep)
logger.debug("Done with learning (trainingstep=%s)", trainingstep)

## Check accuracy of predictions

In [None]:
# Hand the classifier's predictions for the training samples, the samples
# themselves and the expected targets to check_predictions.
check_predictions(
    classifier.predict(training_samples), training_samples, target_values)
logger.debug("Done")

# Implementing Layer-Wise Relevance Propagation 
* **relevance propagation method** described at http://heatmapping.org/tutorial/

* **Some info on MLP** (from https://www.hiit.fi/u/ahonkela/dippa/node41.html):

The computations performed by such a feedforward network with a single hidden layer with nonlinear activation functions and a linear output layer can be written mathematically as

$$\mathbf{x} = \mathbf{f}(\mathbf{s}) = \mathbf{B}\,\boldsymbol{\varphi}(\mathbf{A}\mathbf{s} + \mathbf{a}) + \mathbf{b} \tag{4.15}$$

where  $ \mathbf{s}$ is a vector of inputs and  $ \mathbf{x}$ a vector of outputs.  $ \mathbf{A}$ is the matrix of weights of the first layer,  $ \mathbf{a}$ is the bias vector of the first layer.  $ \mathbf{B}$ and  $ \mathbf{b}$ are, respectively, the weight matrix and the bias vector of the second layer. The function  $ \boldsymbol{\varphi}$ denotes an elementwise nonlinearity. The generalisation of the model to more hidden layers is obvious.

* **About the MLP implementation we use**:

To extract the MLP weights and biases after training the model, use its public attributes `coefs_` and `intercepts_`.
- `coefs_` is a list of weight matrices, where the weight matrix at index i represents the weights between layer i and layer i+1.
- `intercepts_` is a list of bias vectors, where the vector at index i represents the bias values added to layer i+1.

In [None]:
# Per-layer weight matrices and bias vectors of the trained MLP.
weights = classifier.coefs_
biases = classifier.intercepts_

# Propagate over every training sample by default.
propagation_samples = training_samples
propagation_values = target_values
#Using cluster reps gives good results
# propagation_samples = training_samples[cluster_simu.cluster_rep_indices]
# propagation_values = target_values[cluster_simu.cluster_rep_indices]

# Both analyses receive the same network parameters and the same samples.
relevance = relevance_propagation(weights, biases, propagation_samples,
                                  propagation_values)
sensitivity = sensitivity_analysis(weights, biases, propagation_samples,
                                   propagation_values)
logger.info("Done")

## Analyze the relevance propagation results

In [None]:
# Presumably reduces the per-sample relevance and sensitivity to one average
# value per input feature (see analyze_relevance); plot and max_scale are
# forwarded unchanged.
avg_relevance, avg_sensitivity = analyze_relevance(relevance, sensitivity,
                                                   target_values, plot=True, max_scale=True)
# Per-cluster variant, currently disabled:
# relevance_per_cluster, sensitivity_per_cluster = analyze_relevance_per_cluster(relevance, sensitivity, target_values)
def to_atom_pairs(avg_relevance, avg_sensitivity, rowcount, atoms):
    """Wrap every feature in an ``ancf.AtomPair`` scored by its relevance.

    The feature's relevance is passed to ``AtomPair`` in place of a real
    distance. The two atoms of feature ``i`` come from the module-level
    ``cv_atoms[i]`` when ``cv_atoms`` is not None; otherwise the flat feature
    index is converted with ``to_matrix_indices(i, rowcount)`` and the two
    resulting indices are looked up in ``atoms``.

    Returns a tuple ``(pairs, feature_to_resids)``: an object array with one
    AtomPair per feature (carrying ``relevance`` and ``sensitivity``
    attributes) and an integer array with one row per feature holding the
    ``resSeq`` of both atoms.
    """
    feature_count = len(avg_relevance)
    pairs = np.empty((feature_count, ), dtype=object)
    feature_to_resids = np.empty((feature_count, 2), dtype=int)
    for feature_idx, feature_relevance in enumerate(avg_relevance):
        if cv_atoms is None:
            # Features are entries of the residue x residue matrix.
            row, col = to_matrix_indices(feature_idx, rowcount)
            first_atom, second_atom = atoms[row], atoms[col]
        else:
            # Features are CVs, each with its own pair of atoms.
            first_atom, second_atom = cv_atoms[feature_idx]
        atom_pair = ancf.AtomPair(feature_relevance, first_atom, second_atom)
        atom_pair.relevance = feature_relevance
        atom_pair.sensitivity = avg_sensitivity[feature_idx]
        pairs[feature_idx] = atom_pair
        feature_to_resids[feature_idx] = (first_atom.residue.resSeq,
                                          second_atom.residue.resSeq)
    return pairs, feature_to_resids
# frame_distances.shape[1] is passed as the row count used to map flat
# feature indices back to matrix indices.
atom_pairs, feature_to_resids = to_atom_pairs(avg_relevance, avg_sensitivity, frame_distances.shape[1], CA_atoms)
logger.debug("Done")

# CVs evaluation
## Picking those with highest relevance

In [None]:
# Turn every atom pair whose relevance reaches `cutoff` into a normalized
# CA-distance CV, most relevant first, and collect a JSON-serializable
# definition of each CV alongside it.
cvs = []
cutoff = 0.98 #0.8
cvs_definition = []
# sorted() has no `cmp` argument on Python 3. A key function with reverse=True
# gives the same descending-relevance order on both Python 2 and Python 3.
for ap in sorted(atom_pairs, key=lambda pair: pair.relevance[0, 0], reverse=True):
    rel = ap.relevance[0,0]
    if rel < cutoff:
        # Pairs are visited in descending order, so no later pair can qualify.
        break
    a1,a2 = ap.atom1, ap.atom2
    cvid = "{}-{}".format(a1.residue,a2.residue)
    res1, res2 = a1.residue.resSeq, a2.residue.resSeq
    cv = CADistanceCv(cvid, res1, res2, periodic=True)
    logger.debug("%s has relevance %s",cv, rel)
    # Normalization is computed on the clustered trajectory only.
    cv.normalize(trajs=[cluster_simu.traj])
    cvs.append(cv)
    cvs_definition.append({"@class":"CADistanceCv", "periodic": True, "id":cvid, "res1":res1, "res2":res2, "scale":cv._norm_scale+0,"offset": cv._norm_offset+0})

logger.debug("%s CVs in total", len(cvs))

def to_vmd_query(ca_cvs):
    """Build a VMD selection string for the CA atoms of the given CVs.

    Each CV contributes its ``res1`` and ``res2`` values, in order and
    separated by spaces, to a single ``resid`` list.
    """
    residue_ids = " ".join(
        "{} {}".format(cv.res1, cv.res2) for cv in ca_cvs)
    return "name CA and resid {}".format(residue_ids)
# Log the selection string so it can be copied into VMD.
logger.debug("VMD query for plotting CAs:\n%s", to_vmd_query(cvs))

## save them

In [None]:
# Base name shared by both outputs: it is passed to persist_object as is and
# used with a "cvs/" prefix and ".json" suffix for the JSON definition file.
json_filename = "cvs-%s-len%s"%(traj_type, len(cvs))
logger.info("Saving CVs to file %s", json_filename)
persist_object(cvs, json_filename)
# NOTE(review): assumes the "cvs/" directory already exists relative to the
# current working directory -- confirm.
with open("cvs/" + json_filename + ".json", 'w') as fp:
    json.dump({"cvs": cvs_definition},fp, indent=2)

## Visualize CVs

In [None]:
cluster_indices = np.array(cluster_simu.cluster_indices)
# Entry [c, j] is the median of CV j over the frames assigned to cluster
# order_to_cluster[c].
median_cluster_vals =np.empty((nclusters, len(cvs)))
# NOTE(review): hard-coded for exactly three cluster labels (2, 1, 3);
# confirm this ordering whenever nclusters or the clustering changes.
order_to_cluster = [2,1,3]
for cid, cv in enumerate(cvs):
    evals = cv.eval(cluster_simu.traj)
    plt.plot(evals, '--', alpha=0.25, label=cv.id)
    for c in range(nclusters):
        # Positions of the frames whose cluster label matches.
        c_indices = np.argwhere(cluster_indices == order_to_cluster[c])
        median_cluster_vals[c,cid] = np.median(evals[c_indices])
np.savetxt("stringpath-cluster-median-{}-len{}.txt".format(traj_type, len(cvs)), median_cluster_vals)
# Only show a legend when it stays readable.
if len(cvs) < 20:
    plt.legend(loc=(1.02,0))
plt.show()
#to_vmd_query()

# Create CVs with partition graph

In [None]:
def partition_as_graph(atom_pairs,
                       dist_func=lambda p: p.relevance,
                       percentile=99.95,
                       explain=True,
                       max_partition_count=30):
    """Partition the atom pairs into a graph using a percentile cutoff.

    ``dist_func`` is evaluated on every pair and the ``percentile``-th
    percentile of those values is passed as ``cutoff`` to
    ``ancf.partition_as_graph`` (with ``split_subgraphs=True`` and the given
    ``max_partition_count``). When ``explain`` is true the resulting graph's
    ``explain_to_human()`` is called before the graph is returned.
    """
    pair_scores = np.array([dist_func(pair) for pair in atom_pairs])
    threshold = np.percentile(pair_scores, percentile)
    partitioned = ancf.partition_as_graph(
        atom_pairs,
        dist_func=dist_func,
        cutoff=threshold,
        split_subgraphs=True,
        max_partition_count=max_partition_count)
    if not explain:
        return partitioned
    partitioned.explain_to_human()
    return partitioned


# to_atom_pairs returns (pairs, feature_to_resids). Unpack both so that
# atom_pairs is the array of AtomPair objects, not the 2-tuple.
atom_pairs, feature_to_resids = to_atom_pairs(avg_relevance, avg_sensitivity, frame_distances.shape[1], CA_atoms)
logger.info("Partitioning atom pairs to a colored graph")
# Only pairs at or above this relevance percentile end up in the graph cutoff.
percentile=99.5
graph = partition_as_graph(
    atom_pairs, dist_func=lambda p: p.relevance, percentile=percentile)
logger.debug("Done")

## Create and Plot the CVs

In [None]:
# Build one CV per graph element using the generators supplied by the
# clustering module.
cv_generator, id_generator = ancf.most_relevant_dist_generator(graph, atom_pairs)
cvs = ancf.create_cvs(graph, CV_generator=cv_generator, ID_generator=id_generator)
sys.setrecursionlimit(10000) #might be necessary for Pickle...
# Alternative CV generators:
# cvs = ancf.create_cvs(graph, CV_generator=ancf.compute_color_mean_distance)
# cvs = ancf.create_cvs(graph, CV_generator=ancf.compute_color_center_distance)
cvs = normalize_cvs(cvs, simulations=[cluster_simu])
# NOTE(review): unlike the JSON export cell, this file name does not include
# traj_type -- confirm that different datasets cannot overwrite each other.
cvs_filename = "cvs-len%s"%(len(cvs))
logger.info("Saving CVs to file %s", cvs_filename)
persist_object(cvs, cvs_filename)
logger.info("#distances as input=%s, percentile=%s, graph of %s atoms and %s colors -> %s distance CVs",
            len(graph.atompairs), percentile, len(graph.nodes), len(graph.colors), len(cvs))

ancf.create_cluster_plots(cluster_simu, atom_pairs, graph, cvs)
logger.debug("Done")

## Plot order parameters

In [None]:
# plt.plot(cluster_simu.cluster_indices,'--', label="Cluster state", alpha=0.3)
# Same call twice, differing only in `histogram`: first with histogram=True,
# then with histogram=False.
graph.plot_distances(cluster_simu, histogram=True, separate_clusters=False, max_per_plot=10, bincount=10, use_contacts=False)
graph.plot_distances(cluster_simu, histogram=False, separate_clusters=False, max_per_plot=10, bincount=10, use_contacts=False)
logger.debug("Done")

# Other

## Rank the atoms with most relevance

In [None]:
import operator
# The ABC aliases in `collections` were removed in Python 3.10; fall back to
# the old location only on Python 2, where `collections.abc` does not exist.
try:
    from collections.abc import Iterable
except ImportError:
    from collections import Iterable

# Relevance values at or below this threshold are treated as noise and
# ignored when relevance is summed per atom.
relevance_cutoff = 0.3

def compute_relevance_per_atom_for_pairs(atom_pairs):
    """Sum pair relevance per atom and rank the atoms by that sum.

    For every pair, ``pair.relevance[0]`` is added to the running total of
    both ``atom1`` and ``atom2``, but only when it exceeds the module-level
    ``relevance_cutoff``. Returns a list of ``(atom, total[0, 0])`` tuples
    sorted by total, largest first; atoms that never exceeded the cutoff are
    absent.
    """
    totals = {}
    for pair in atom_pairs:
        for atom in (pair.atom1, pair.atom2):
            contribution = pair.relevance[0]
            if contribution > relevance_cutoff:  # get rid of noise
                totals[atom] = totals.get(atom, 0) + contribution
    ranked = sorted(totals.items(), key=operator.itemgetter(1), reverse=True)
    return [(atom, total[0, 0]) for atom, total in ranked]

def compute_relevance_per_atom_for_coordinates(atom_pairs):
    """Sum coordinate relevance per residue and rank the residues.

    Despite the parameter name (kept so existing callers keep working),
    ``atom_pairs`` is the sequence of atoms whose coordinates were used as
    features. The atom at position ``i`` owns the three consecutive entries
    ``avg_relevance[3*i:3*(i+1)]`` of the module-level ``avg_relevance``.
    Entries above the module-level ``relevance_cutoff`` are summed per atom
    and then accumulated per ``residue.resSeq``.

    Returns a list of ``(atom, relevance)`` tuples, one per residue, sorted
    by relevance with the largest first. The atom is the residue's CA atom
    when one was seen, otherwise the first atom encountered for that residue.
    """
    # Iterate over the argument instead of a global atom list, so the
    # function works on exactly what the caller passes in.
    resSeq_to_CA_relevance = {}
    for idx, a in enumerate(atom_pairs):
        # The x, y and z relevance of this atom.
        rels = avg_relevance[3*idx:3*(idx+1)]
        atom_rel = rels[rels > relevance_cutoff].sum()
        resSeq = a.residue.resSeq
        current_atom, current_rel = resSeq_to_CA_relevance.get(resSeq, (a, 0.))
        if a.name == "CA":
            # Prefer the CA atom as the representative of its residue.
            current_atom = a
        resSeq_to_CA_relevance[resSeq] = (current_atom, current_rel + atom_rel)
    return sorted(
        resSeq_to_CA_relevance.values(), key=operator.itemgetter(1), reverse=True)

def compute_relevance_per_atom(atoms):
    """Dispatch to the pair-based or coordinate-based relevance ranking.

    Returns an empty list for empty input. Otherwise the first element
    decides: an iterable element selects the atom-pair variant, anything else
    selects the coordinate variant.
    """
    if len(atoms) == 0:
        return []
    ranker = (compute_relevance_per_atom_for_pairs
              if isinstance(atoms[0], Iterable)
              else compute_relevance_per_atom_for_coordinates)
    return ranker(atoms)
    
def to_full_vmd_beta_script(relevance_count):
    """Build a VMD script that sets a beta value per ranked residue.

    ``relevance_count`` is a sequence of ``(atom, relevance)`` tuples sorted
    by relevance with the largest first. Relevance is rescaled linearly so
    the first entry maps to 10 and the last to 0, and one
    ``to_vmd_beta_value(resSeq, beta)`` snippet is emitted per entry.

    Returns the concatenated script. It is an empty string when there is
    nothing to rank. When all relevances are equal the range is zero and
    every beta value is 0.
    """
    if len(relevance_count) == 0:
        return ""
    max_rel = relevance_count[0][1]
    min_rel = relevance_count[-1][1]
    span = max_rel - min_rel
    commands = []
    for a, r in relevance_count:
        # Guard against a zero range (single entry or identical relevances).
        beta = 10*(r - min_rel)/span if span != 0 else 0.0
        commands.append(to_vmd_beta_value(a.residue.resSeq, beta))
    return "".join(commands)
    
# all_atoms is only evaluated when distance_metric == "coordinates".
# NOTE(review): all_atoms is not defined in the cells shown here -- confirm
# where it comes from before using the "coordinates" metric.
relevance_count = compute_relevance_per_atom(all_atoms if distance_metric == "coordinates" else atom_pairs)
# Log only the top entries; relevance_count is sorted, largest first.
max_to_print = 10
for i, (a, r) in enumerate(relevance_count):
    if i >= max_to_print:
        break
    logger.info("Atom %s has relevance %s", a, r)
vmd_beta_script = to_full_vmd_beta_script(relevance_count)
#Print vmd_beta_script and paste into TK-console. Set color to beta
#logger.debug("Command to color protein residues in VMD:\n%s", vmd_beta_script)

## Create Classifier CVs

Create CVs that can be used like this:

```python
evals = cv.eval(traj)
```


In [None]:
# Wrap the trained scaler and classifier as CVs; the call returns one discrete
# CV and a collection of probability CVs, both of which are persisted later.
discrete_classifier_cv, probaility_classifier_cvs = create_classifier_cvs(clustering_id, training_samples, target_values, scaler, classifier, trainingstep, query=CA_query, cvs=cvs)
logger.debug("Created CVs")

### Save classifier CVs to file

In [None]:
def save_data(save_dir, frame_distances, training_samples, target_values, feature_to_resids):
    """Write the four arrays to ``save_dir`` with ``np.save``.

    Each array is stored under its parameter name; ``np.save`` appends the
    ``.npy`` extension. ``save_dir`` must already exist.
    """
    outputs = (
        ("{}/frame_distances", frame_distances),
        ("{}/training_samples", training_samples),
        ("{}/target_values", target_values),
        ("{}/feature_to_resids", feature_to_resids),
    )
    for path_template, array in outputs:
        np.save(path_template.format(save_dir), array)
    
def save_sklearn_objects(save_dir, scaler, classifier):
    """Persist the classifier and then the scaler below ``save_dir``.

    Both objects are handed to ``persist_object`` with the targets
    ``<save_dir>/classifier`` and ``<save_dir>/scaler``.
    """
    targets = (
        (classifier, "{}/classifier"),
        (scaler, "{}/scaler"),
    )
    for obj, path_template in targets:
        persist_object(obj, path_template.format(save_dir))

def save_cvs(save_dir, discrete_classifier_cv, probaility_classifier_cvs):
    """Persist the discrete CV and then the probability CVs below ``save_dir``.

    Both are handed to ``persist_object`` with the targets
    ``<save_dir>/discrete_classifier_cv`` and
    ``<save_dir>/probability_classifier_cvs``.
    """
    targets = (
        (discrete_classifier_cv, "{}/discrete_classifier_cv"),
        (probaility_classifier_cvs, "{}/probability_classifier_cvs"),
    )
    for obj, path_template in targets:
        persist_object(obj, path_template.format(save_dir))
    
# One output directory per clustering / feature-type combination.
save_dir = "neural_networks/" + clustering_id + "-" + distance_metric
# Create the directory (and any missing parents) on first use.
if not os.path.exists(save_dir):
    os.makedirs(save_dir)
# Store the raw arrays, the fitted sklearn objects and the classifier CVs.
save_data(save_dir, frame_distances, training_samples, target_values, feature_to_resids)
save_sklearn_objects(save_dir, scaler, classifier)
save_cvs(save_dir, discrete_classifier_cv, probaility_classifier_cvs)

logger.debug("Done")