This code visualizes the attention weights for the Graph Attention Network.

The results are stored as a dictionary in "t5_GTP_attention_misclassified.p", which contains 4 example proteins whose binding sites are misclassified by the GCN but correctly classified by the GAT.

To run the code, please install the following packages using pip:

`pip install numpy`

`pip install py3Dmol`

`pip install dgl`

`pip install biopython`

`pip install ipywidgets`



In [2]:
import py3Dmol
import pickle
import numpy as np
import dgl

from Bio.PDB import *

parser=PDBParser()

from ipywidgets import interact, interactive, fixed, interact_manual


In [8]:
# Residue names that is_AC always accepts, whatever their hetero flag is.
MODIFIED_AA = [
    "TOX",
    "MSE",
    "LLP",
    "TPO",
    "CME",
    "CSD",
    "MLY",
    "SEP",
    "CSO",
]


def is_AC(res):
    """Tell whether a residue should be kept when listing a chain.

    A residue is kept when its name is listed in MODIFIED_AA, or when the
    first field of its residue id (element 3 of ``get_full_id()``) is a
    single space.  In Bio.PDB that field is the hetero flag and a blank
    value marks a standard residue, so this presumably drops waters and
    ligands while keeping the listed modified amino acids.

    Args:
        res: object exposing ``get_resname()`` and ``get_full_id()``
            (a Bio.PDB residue in this notebook).

    Returns:
        True if the residue is kept, False otherwise.
    """
    return res.get_resname() in MODIFIED_AA or res.get_full_id()[3][0] == " "

def get_residues(protein):
    """Collect the residues of one chain that pass ``is_AC``.

    Args:
        protein: name of the form ``"<pdb id>_<chain id>"``.  The part before
            the first underscore is lower-cased and used as the base name of
            a ``.pdb`` file in the working directory; the part after it is
            compared against the chain ids.

    Returns:
        A list of residues, in the order the parser yields them.

    Raises:
        IndexError: if ``protein`` contains no underscore.
    """
    name_parts = protein.split("_")
    pdb_code = name_parts[0].lower()
    wanted_chain = name_parts[1]

    structure = parser.get_structure("X", f"{pdb_code}.pdb")

    # NOTE(review): get_chains() iterates the chains of every model in the
    # structure, so a multi-model file would contribute the chain once per
    # model -- confirm the inputs are single-model files.
    return [
        residue
        for chain in structure.get_chains()
        if chain.id == wanted_chain
        for residue in chain.get_residues()
        if is_AC(residue)
    ]


# Choose the ligand
LIGAND = "GTP"
# NOTE: unpickling can execute arbitrary code, so only load result files that
# come from a trusted source.  The context manager closes the file handle.
with open(f"t5_{LIGAND}_attention_misclassified.p", "rb") as results_file:
    result = pickle.load(results_file)


# Choose an example (k indexes the entries of `result`; the original note
# listed k = 0, 1, 2).
k = 0

# Look up the selected example and load the residues of its protein chain.
protein_results = result[k]
protein = protein_results["protein_name"]
all_res = get_residues(protein)

# Predictions (used below to index the residue array).
preds = protein_results["misclassified"]
# True labels.
# NOTE(review): this reads the same "misclassified" key as `preds`, which
# looks like a copy-paste slip; `true_bs` is not used further down.  Confirm
# the intended key against the code that wrote the pickle.
true_bs = protein_results["misclassified"]


# Store the residues in a 1-D object array so they can be indexed with `preds`.
# np.array(all_res) is deliberately avoided: Bio.PDB residues behave like
# sequences of atoms, so NumPy tries to build a nested array from them, which
# triggers a ragged-sequence deprecation warning on older NumPy releases and a
# ValueError from NumPy 1.24 on.
residue_list = all_res
all_res = np.empty(len(residue_list), dtype=object)
for node_id, residue in enumerate(residue_list):
    all_res[node_id] = residue
res_node_ids = np.arange(len(all_res))

# Residue numbers (id[1]) and node indices of the predicted binding sites.
pred_binding_sites = [x.id[1] for x in all_res[preds]]
pred_bs_node_ids = res_node_ids[preds]
# Attention weights; indexed below as A[node, :], presumably one row per
# residue node -- confirm against the code that produced the pickle.
A = np.asarray(protein_results["attention_weights"])

def plot_attention(i):
    """Visualize the 5 most relevant neighbors of a predicted binding site.

    Reads the notebook globals ``pred_binding_sites``, ``pred_bs_node_ids``,
    ``A``, ``all_res`` and ``protein``.  Prints how many entries of the
    site's attention row are non-zero and the share of attention carried by
    the five largest entries, then shows a py3Dmol view in which

    * residues with non-zero attention are drawn as blue cartoon,
    * the five highest-attention residues are drawn as green sticks,
    * the selected binding site is drawn as a yellow stick,
    * hetero atoms are drawn as sticks.

    Args:
        i: position of the binding site in ``pred_binding_sites``.
    """
    binding_site = pred_binding_sites[i]
    # Position i is used directly.  Looking the residue number up again with
    # list.index() would return the first residue carrying that number and so
    # pick the wrong attention row when residue numbers repeat.
    node_id = pred_bs_node_ids[i]

    # Column indices of the five largest entries in this node's attention row.
    relevant_neighbors = np.argsort(A[node_id, :].flatten())[::-1][:5]
    all_neighbors = [x.id[1] for x in all_res[A[node_id, :] != 0]]
    print("Total neighbors :", len(all_neighbors))
    relevant_residues = [x.id[1] for x in all_res[relevant_neighbors]]
    print("Total attention for 5 most relevant neighbors", A[node_id, relevant_neighbors].sum() * 100, "%")

    # NOTE(review): residues are selected by number only ("resi"), so a
    # residue with the same number in another chain would be styled as well.
    viewer = py3Dmol.view(f'{protein.lower()[:4]}.pdb')
    viewer.setStyle({}, {'cartoon': {}})
    viewer.addSurface(py3Dmol.VDW, {'opacity': 0.8, 'color': 'lightblue'}, {"hetflag": False})
    # Later setStyle calls replace earlier ones for the atoms they select, so
    # the order below decides which highlight wins.
    viewer.setStyle({"resi": all_neighbors}, {"cartoon": {"color": "blue"}})
    viewer.setStyle({"resi": relevant_residues}, {"stick": {"color": "green"}})
    viewer.setStyle({"resi": binding_site}, {"stick": {"color": "yellow"}})
    viewer.setStyle({"hetflag": True}, {"stick": {}})
    viewer.show()



  all_res=np.array(all_res)


In [9]:
# Build an interactive selector over the positions of the predicted binding
# sites; each selection calls plot_attention with the chosen index.
interact(plot_attention,i=range(len(pred_binding_sites)))

interactive(children=(Dropdown(description='i', options=(0,), value=0), Output()), _dom_classes=('widget-inter…

<function __main__.plot_attention(i)>