In [None]:
import sys
import os

import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import numpy as np

sys.path.append(os.getcwd())
from load_data import *
from plotting_tools import *

sys.path.append('../3_score/')
from optimize import Scores
from fingerprint import FuzzyFingerPrint

In [None]:
## useful class specifications:
# Ligand -- attributes: crystal (currently not used/none), poses (dict of pose_num:pose).
# Pose -- attributes: rmsd (float), fp (Fingerprint), num (integer), gscore (float).
# FuzzyFingerPrint -- attributes: pdb, feats (dict of residue name:interaction list).

# crystals: maps structures to Poses, 
#           e.g., crystals['4LDO'] = Pose(rmsd=0, fingerprint, num=0, gscore=0)
# glides:   maps ligand_struct, grid_struct to Ligands, 
#           e.g., glides['4LDO']['4LDO'] = Ligand(...)

## STEP 1:
## load in the crystal structures, docking results, and fingerprints

receptor = 'B1AR_all'
# NOTE(review): the meaning of each entry of `w` is defined inside load_data — confirm there.
xcrystals, xglides, ligs, structs = load_data(
    receptor,
    w=[10, 10, 10, 1, 0],
    require_fp=False,
    combine_structs=False,
    load_docking=False,
)

In [None]:
## STEP 2:
## visualize docking results

# in the heatmaps, each row is a structure and each column is a ligand

n = 500  # pose-count cutoff passed to get_docking_stats; also reused by the scoring cell

get_ipython().run_line_magic(u'matplotlib', u'inline')

print('new docking')
# Best (minimum) RMSD per (structure, ligand) pair, NaN when the pose list is empty.
# len() works for both lists and numpy arrays, whereas `x != []` is ambiguous for arrays.
best_rmsd_matrix = get_docking_stats(ligs, structs, xglides, n,
                                     lambda x: np.min(x) if len(x) > 0 else np.nan)
heatmap(best_rmsd_matrix, structs, ligs, red=4)
print(np.nanmean(best_rmsd_matrix))

# Alternative statistics (rows are structures, columns are ligands, as above):
#top_rmsd_matrix = get_docking_stats(ligs, structs, xglides, n, lambda x: x[0] if len(x) > 0 else np.nan)
#heatmap(top_rmsd_matrix, structs, ligs, red=10)
#print(np.nanmean(top_rmsd_matrix))

#num_poses_matrix = get_docking_stats(ligs, structs, xglides, n, lambda x: len(x) if len(x) > 0 else np.nan)
#heatmap(num_poses_matrix, structs, ligs, red=300)
#print(np.nanmean(num_poses_matrix))

#var_mat = get_docking_stats(ligs, structs, xglides, n, lambda x: np.var(x) if len(x) > 0 else np.nan)
#heatmap(var_mat, structs, ligs)

helpfully_frozen(ligs, structs, xglides, n)

In [None]:
get_ipython().run_line_magic(u'matplotlib', u'inline')

def plot_docking_by_structure(ligs, structs, glides, n=25, title='', threshold=2.0, xmax=6.0):
    """Plot, per grid structure, the cumulative proportion of ligands docked within a given RMSD.

    For every structure ``s`` in ``structs``, the ligands with ``s in glides[l]`` are
    collected. For each of them, the minimum ``rmsd`` over pose numbers ``0 .. min(n, #poses)-1``
    is taken. The empirical cumulative distribution of those minima is drawn as a step curve.
    The denominator is the number of ligands docked to ``s``.

    Args:
        ligs: ligand names; ``glides[l]`` must exist for each.
        structs: structure names, one curve each.
        glides: nested mapping ``glides[ligand][struct]`` -> object with a ``poses``
            dict whose values have an ``rmsd`` attribute. Pose numbers are assumed to be
            ``0 .. len(poses)-1`` (TODO confirm against load_data).
        n: maximum number of poses per ligand to consider.
        title: figure title.
        threshold: x position of the dashed vertical reference line.
        xmax: right edge of the x axis; curves are extended flat up to it.

    Structures with no docked ligands are skipped. Ligands with no poses count toward
    the denominator but never reach the curve, instead of raising as before.
    """
    plt.plot([threshold, threshold], [0, 1], '--k')
    for s in structs:
        docked = [l for l in ligs if s in glides[l]]
        if not docked:
            # nothing was docked to this structure: no proportion to plot
            continue
        min_rmsds = []
        for l in docked:
            poses = glides[l][s].poses
            k = min(n, len(poses))
            if k > 0:
                min_rmsds.append(min(poses[i].rmsd for i in range(k)))
        min_rmsds.sort()

        frac = 1.0 / len(docked)
        # ECDF: at x = k-th smallest RMSD the proportion rises to k/len(docked)
        xs = [0] + min_rmsds
        ys = [frac * i for i in range(len(xs))]
        if xs[-1] < xmax:
            # carry the final proportion flat to the right edge of the plot
            xs.append(xmax)
            ys.append(ys[-1])
        plt.step(xs, ys, where='post', label=s)

    ax = plt.gca()
    ax.set_xlim([0, xmax])
    ax.set_ylim([0, 1])
    plt.xlabel('RMSD [A]')
    plt.ylabel('Cumulative Proportion of Ligands')
    plt.title(title)
    plt.legend()
    plt.show()
    
plot_docking_by_structure(ligs, structs, xglides, n=500,
                          title='{} cross docked. all ~300 poses.'.format(receptor))

In [None]:
get_ipython().run_line_magic(u'matplotlib', u'inline')

import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm as cm

def interaction_heatmap(A, structs, res, cluster_labels=None):
    """Draw ``A`` as a heatmap whose columns are coloured by interaction type.

    Args:
        A: 2-D array; row k belongs to ``structs[k]`` and column j to ``res[j]``.
        structs: row labels.
        res: column labels. ``res[j][1]`` is read as the interaction type (0-3) of
            column j, e.g. the (residue, type) tuples produced by ``get_fp_vectors``.
        cluster_labels: optional per-row cluster ids. Separator lines are placed
            from label counts only, so rows are assumed to be ordered by label.

    Columns of type 0/1/2/3 use the Greens/Blues/Reds/Purples colormaps, all on a
    shared 0..max(A) scale. The figure is displayed with ``plt.show()``.
    """
    A = np.asarray(A, dtype=float)
    fig, ax = plt.subplots()
    vmax = np.max(A)  # shared colour scale for all interaction types

    def type_layer(i_type):
        """Copy of A with every column whose interaction type is not ``i_type`` set to NaN."""
        layer = np.full(A.shape, np.nan)
        for j, r in enumerate(res):
            if r[1] == i_type:
                layer[:, j] = A[:, j]
        return layer

    # One layer per type; relies on NaN cells being rendered transparent so layers combine.
    for i_type, cmap in ((3, cm.Purples), (2, cm.Reds), (1, cm.Blues), (0, cm.Greens)):
        ax.matshow(type_layer(i_type), cmap=cmap, vmin=0, vmax=vmax)

    # put the major ticks at the middle of each cell
    ax.set_xticks(np.arange(A.shape[1]), minor=False)
    ax.set_yticks(np.arange(A.shape[0]), minor=False)
    ax.xaxis.tick_top()

    if cluster_labels is not None and len(cluster_labels) > 0:
        labels = list(cluster_labels)
        for c in range(max(labels)):
            # rows with label <= c occupy the top `count` rows; draw the line below them
            boundary = sum(1 for lab in labels if lab <= c) - 0.5
            ax.axhline(boundary, linewidth=4, color='k')

    ax.set_xticklabels(res, minor=False, rotation='vertical')
    ax.set_yticklabels(structs, minor=False)
    ax.set_xlabel('interactions')
    ax.set_ylabel('crystal structures')

    square_size = 0.8
    # width scales with the number of columns (res), height with the number of rows (structs)
    fig.set_size_inches(square_size * len(res), square_size * len(structs))

    plt.show()

In [None]:
# To do:
# 1) fingerprint -> vector (dict)
# 2) cluster vectors
# 3) sort based on cluster
# 4) show clusters

def get_fp_vectors(crystals, i_type=(0, 1, 2)):
    """Convert crystal-structure fingerprints into aligned feature vectors.

    Every (residue, interaction type) pair with a non-zero fingerprint entry becomes a
    feature. Its value for a structure is the square of that entry.

    Args:
        crystals: mapping structure -> object whose ``fp.feats`` maps a residue to a
            sequence indexed by interaction type.
        i_type: interaction-type indices to include.

    Returns:
        ``(all_interactions, vectors)``.
        ``all_interactions`` is the list of (residue, type) features, ordered by
        descending total squared value over all structures. Ties are broken by the
        feature itself so the order is deterministic.
        ``vectors`` maps each structure to its feature values in that order, with 0
        where the structure lacks the feature.
    """
    totals = {}
    fp_map = {}
    for s, crystal in crystals.items():
        feats = crystal.fp.feats
        values = {}
        for r in feats:
            for i in i_type:
                v = feats[r][i]
                if v != 0:
                    sq = v ** 2
                    values[(r, i)] = sq
                    totals[(r, i)] = totals.get((r, i), 0) + sq
        fp_map[s] = values

    all_interactions = sorted(totals, key=lambda x: (-totals[x], x))
    return all_interactions, {s: [fp_map[s].get(x, 0) for x in all_interactions] for s in crystals}

In [None]:
from sklearn.cluster import KMeans

i_indices = [0, 1, 2]  # interaction types to include in the fingerprint vectors
all_interactions, fp_vectors = get_fp_vectors(xcrystals, i_indices)
# Materialize as a list: it is indexed positionally later and must keep one fixed order
# (a Python 3 keys view cannot be indexed).
unsorted_ligs = list(fp_vectors.keys())

def get_fp_matrix(sorted_ligs, fp_vectors):
    """Stack the fingerprint vectors of ``sorted_ligs`` (in that order) into a float matrix.

    Row ``i`` is ``fp_vectors[sorted_ligs[i]]``; all vectors must share one length
    (a ragged input raises ``ValueError``). ``sorted_ligs`` may be any iterable.
    An empty ligand list yields a ``(0, 0)`` array instead of raising.
    """
    rows = [fp_vectors[l] for l in sorted_ligs]
    if not rows:
        return np.zeros((0, 0))
    return np.array(rows, dtype=float)

fp_mat_1 = get_fp_matrix(unsorted_ligs, fp_vectors)

# n_init is pinned to the historical default so results do not shift with sklearn versions.
kmeans = KMeans(n_clusters=3, random_state=0, n_init=10).fit(fp_mat_1)
print(kmeans.labels_)

# Stable sort by cluster label. Row k of fp_mat_1 (and label k) belongs to unsorted_ligs[k].
sorted_ligs = sorted(unsorted_ligs, key=dict(zip(unsorted_ligs, kmeans.labels_)).get)
fp_mat_2 = get_fp_matrix(sorted_ligs, fp_vectors)

num_i = 10  # number of top-ranked interactions (columns) to display

interaction_heatmap(fp_mat_2[:, :num_i], sorted_ligs, all_interactions[:num_i], kmeans.labels_)
# One row per fitted centre, so label rows by the centre count rather than by observed labels.
interaction_heatmap(kmeans.cluster_centers_[:, :num_i],
                    list(range(kmeans.cluster_centers_.shape[0])),
                    all_interactions[:num_i])

In [None]:
get_ipython().run_line_magic(u'matplotlib', u'inline')

def load_alignments(receptor, base_dir='/scratch/PI/rondror/docking_data'):
    """Read per-residue pairwise RMSDs between structures of ``receptor``.

    Reads ``<base_dir>/<receptor>/residue_alignments.txt``. Each non-blank line must
    have five whitespace-separated fields: residue number, residue name, first
    structure, second structure, RMSD. Blank lines are skipped.

    Returns:
        dict mapping residue number (str) -> {(struct1, struct2): rmsd (float)}.
    """
    fname = os.path.join(base_dir, receptor, 'residue_alignments.txt')

    all_res = {}
    with open(fname) as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue
            rnum, _rname, struct1, struct2, rmsd = fields
            all_res.setdefault(rnum, {})[(struct1, struct2)] = float(rmsd)

    return all_res

def show_alignments(structs, residues, alignments):
    """Placeholder for visualizing residue alignments; not implemented yet.

    Accepts its arguments, does nothing, and returns None.
    """
    return None

prot_rmsds = load_alignments(receptor)
# list of residue numbers that have alignment data
print(list(prot_rmsds))

In [None]:
## STEP 4:
## score all pairs of ligands
#print glides.keys()
#best_rmsd_matrix = best_pose(glides.keys(), glides.keys(), glides, 50)
#best_rmsd_matrix2 = best_pose(xglides.keys(), xglides.keys(), xglides, 25)

get_ipython().run_line_magic(u'matplotlib', u'inline')

# Only ligands whose best docked RMSD (best_rmsd_matrix) is below this cutoff are scored.
MAX_BEST_RMSD = 1.8

all_final = {}

for struct in ['3V4A']:
    s_idx = structs.index(struct)
    # NaN entries (no docking result) compare False and are therefore excluded.
    filt_lig = [l for j, l in enumerate(ligs) if best_rmsd_matrix[s_idx][j] < MAX_BEST_RMSD]

    scores = Scores(xglides, xcrystals, filt_lig, struct, n)
    title = '{} docked to: {}'.format(receptor, struct)
    final_rmsds = plot_final_rmsds(scores, title)

    crystal_cluster = {(l, -1): xcrystals[l].fp for l in filt_lig}
    # optimized_scores.items() yields (ligand, pose number) pairs used to index xglides.
    opt_cluster = {(l, p): xglides[l][struct].poses[p].fp
                   for l, p in scores.optimized_scores.items()}
    plot_shared_interactions(crystal_cluster, c2=opt_cluster, max_r=20,
                             lab1='Crys', lab2='Opt', title=title, interactions=[0, 1, 2, 3])

    print(list(opt_cluster.keys()))

    a = scores.all_analysis
    print('{} performance:'.format(struct))
    print_table({i: a[i][1][:] for i in a})
    for i in a:
        # accumulate this structure's per-metric results across all structures
        all_final.setdefault(i, []).extend(a[i][1])

print('average across all data:')
print_table(all_final)