# Visualisation

This notebook showcases different ways of visualising ligand-protein and protein-protein interactions, either with atomistic detail or simply at the residue level.

### This is a work in progress...

In [None]:
import MDAnalysis as mda
import prolif as plf
# load topology
u = mda.Universe(plf.datafiles.TOP)
lig = u.select_atoms("resname LIG")
prot = u.select_atoms("protein")

In [None]:
# create RDKit-like molecules for visualisation
lmol = plf.Molecule.from_mda(lig)
pmol = plf.Molecule.from_mda(prot)

In [None]:
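# compute ligand-protein interactions, keeping the indices of the interacting atoms
fp = plf.Fingerprint(["HBDonor", "HBAcceptor", "Cationic", "PiStacking"])
fp.run(u.trajectory, lig, prot, return_atoms=True)
df = fp.to_dataframe()
# one row per (ligand residue, protein residue, interaction),
# each value is a (ligand atom index, protein atom index) tuple
df.T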

## py3Dmol (3Dmol.js)

With py3Dmol we can display the interactions in 3D, as dashed cylinders between the interacting atoms.

For interactions involving a ring (pi-cation, pi-stacking, etc.), ProLIF returns the index of one of the ring atoms, but for visualisation it looks nicer to start or end the interaction at the centroid of the ring. We'll start by writing a function that finds the centroid, given the index of one of the ring atoms.

In [None]:
from rdkit import Chem
from rdkit.Geometry import Point3D

def get_ring_centroid(mol, index):
    # find ring using the atom index
    Chem.SanitizeMol(mol, Chem.SanitizeFlags.SANITIZE_SETAROMATICITY)
    ri = mol.GetRingInfo()
    for r in ri.AtomRings():
        if index in r:
            break
    else:
        raise ValueError("No ring containing this atom index was found in the given molecule")
    # get centroid
    coords = mol.xyz[list(r)]
    ctd = plf.utils.get_centroid(coords)
    return Point3D(*ctd)
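The ring lookup itself does not depend on RDKit: given the rings as tuples of atom indices (which is what `GetRingInfo().AtomRings()` returns) and the atom coordinates, the centroid is simply the mean position of the ring atoms. Below is a minimal pure-Python sketch of the same logic; `ring_centroid` is a hypothetical helper for illustration, not part of ProLIF, and the coordinates are a made-up toy example.

```python
from typing import Sequence, Tuple

Coord = Tuple[float, float, float]


def ring_centroid(rings: Sequence[Sequence[int]],
                  coords: Sequence[Coord],
                  index: int) -> Coord:
    """Return the centroid of the first ring that contains atom `index`."""
    # first ring (in the given order) containing the atom
    ring = next((r for r in rings if index in r), None)
    if ring is None:
        raise ValueError(f"atom {index} is not part of any ring")
    # average x, y and z over the ring atoms
    n = len(ring)
    xs, ys, zs = zip(*(coords[i] for i in ring))
    return (sum(xs) / n, sum(ys) / n, sum(zs) / n)


# toy example: a square "ring" (atoms 0-3) plus one non-ring atom (4)
rings = [(0, 1, 2, 3)]
coords = [(0.0, 0.0, 0.0), (2.0, 0.0, 0.0), (2.0, 2.0, 0.0),
          (0.0, 2.0, 0.0), (5.0, 5.0, 5.0)]
print(ring_centroid(rings, coords, 2))  # (1.0, 1.0, 0.0)
```

As in `get_ring_centroid`, an atom shared by fused rings is assigned to the first ring that contains it, which may not be the ring actually involved in the interaction.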

Finally, the actual visualisation code. py3Dmol exposes the same methods as the GLViewer class of 3Dmol.js, which is documented [here](https://3dmol.csb.pitt.edu/doc/$3Dmol.GLViewer.html). Only the interactions of the first frame are displayed.

In [None]:
import py3Dmol

colors = {
    "HBAcceptor": "blue",
    "HBDonor": "red",
    "Cationic": "green",
    "PiStacking": "purple",
}

v = py3Dmol.view(800, 600)
v.removeAllModels()

displayed = []  # residues already added to the viewer
for i, row in df.T.iterrows():
    lresid, presid, interaction = i
    lindex, pindex = row.iloc[0]  # atom indices in the first frame
    lres = lmol[lresid]
    pres = pmol[presid]
    if lresid not in displayed:
        displayed.append(lresid)
        v.addModel(Chem.MolToMolBlock(lres), "sdf")
        v.setStyle({"model": -1}, {"stick": {"colorscheme": "cyanCarbon"}})
    if presid not in displayed:
        displayed.append(presid)
        v.addModel(Chem.MolToMolBlock(pres), "sdf")
        v.setStyle({"model": -1}, {"stick": {}})
    p1 = lres.GetConformer().GetAtomPosition(lindex)
    p2 = pres.GetConformer().GetAtomPosition(pindex)
    # for ring-based interactions, use the ring centroid instead of the atom
    if interaction in ["PiStacking", "EdgeToFace", "FaceToFace", "PiCation"]:
        p1 = get_ring_centroid(lres, lindex)
    if interaction in ["PiStacking", "EdgeToFace", "FaceToFace", "CationPi"]:
        p2 = get_ring_centroid(pres, pindex)
    v.addCylinder({"start": dict(x=p1.x, y=p1.y, z=p1.z),
                   "end":   dict(x=p2.x, y=p2.y, z=p2.z),
                   "color": colors[interaction],
                   "radius": .15,
                   "dashed": True,
                   "fromCap": 1,
                   "toCap": 1,
                  })
    
v.zoomTo()
v.show()
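The per-interaction logic of the cell above can be pulled out of the loop, so that it can be checked without opening a viewer. The sketch below uses hypothetical helpers (not part of ProLIF or py3Dmol): `ring_sides` decides which end of the cylinder should sit on a ring centroid, using the same interaction lists as the cell above, and `cylinder_spec` builds the dict passed to `v.addCylinder`, with the same keys and values.

```python
from typing import Dict, Tuple

Coord = Tuple[float, float, float]

# interactions where the ligand (resp. protein) side is an aromatic ring,
# same lists as in the py3Dmol cell
LIGAND_RING = {"PiStacking", "EdgeToFace", "FaceToFace", "PiCation"}
PROTEIN_RING = {"PiStacking", "EdgeToFace", "FaceToFace", "CationPi"}


def ring_sides(interaction: str) -> Tuple[bool, bool]:
    """Return (ligand_uses_centroid, protein_uses_centroid) for an interaction."""
    return interaction in LIGAND_RING, interaction in PROTEIN_RING


def to_xyz(point: Coord) -> Dict[str, float]:
    """Convert an (x, y, z) tuple into the {"x", "y", "z"} dict used by 3Dmol.js."""
    x, y, z = point
    return {"x": x, "y": y, "z": z}


def cylinder_spec(start: Coord, end: Coord, color: str,
                  radius: float = 0.15) -> Dict:
    """Dashed cylinder with capped ends, as drawn in the py3Dmol cell."""
    return {
        "start": to_xyz(start),
        "end": to_xyz(end),
        "color": color,
        "radius": radius,
        "dashed": True,
        "fromCap": 1,
        "toCap": 1,
    }


print(ring_sides("PiCation"))  # (True, False)
spec = cylinder_spec((0.0, 0.0, 0.0), (1.0, 2.0, 3.0), "purple")
print(spec["end"])  # {'x': 1.0, 'y': 2.0, 'z': 3.0}
```

Inside the loop, the call would then read `v.addCylinder(cylinder_spec(...))`, with the two endpoints converted to `(x, y, z)` tuples.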

## RDKit

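RDKit isn't really designed for 3D visualisation, but it does the trick as a "2D" alternative. The approach below tags each interacting atom with an atom property, combines all the residues into a single molecule, then uses these tags to highlight the atoms and to draw a wavy line between each interacting pair. This script is meant as a starting point for visualising the interactions with RDKit; there are probably better ways to do this.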

In [None]:
from rdkit import Chem, Geometry
from rdkit.Chem import AllChem, Draw
from collections import defaultdict
from IPython import display

colors = {
    "HBAcceptor": (0,0,1),
    "HBDonor": (1,0,0),
    "Cationic": (0,1,0),
    "PiStacking": (0.7,0.7,0.3),
}

# save displayed residues in a dict for easy access
mols = {lresid: lmol[lresid] for lresid in df.columns.get_level_values("ligand").unique()}
mols.update({presid: pmol[presid] for presid in df.columns.get_level_values("protein").unique()})

# tag the interacting atoms of each residue with the interaction type and a
# unique key, so that interacting pairs can be recovered after combining residues
for key, (i, row) in enumerate(df.T.iterrows()):
    lresid, presid, interaction = i
    lindex, pindex = row.iloc[0]  # atom indices in the first frame
    lres = mols[lresid]
    pres = mols[presid]
    # one property per interaction, so that an atom involved in several
    # interactions of the same type doesn't overwrite the previous one
    name = f"highlight_{interaction}_{key}"
    lres.GetAtomWithIdx(lindex).SetBoolProp(name, True)
    pres.GetAtomWithIdx(pindex).SetBoolProp(name, True)

# create molecule that combines all residues
mol = Chem.Mol()
for res in mols.values():
    mol = AllChem.CombineMols(mol, res)

# retrieve the indices of atoms to highlight in the new molecule
highlight_atoms = defaultdict(list)
highlight_ints = defaultdict(list)
for atom in mol.GetAtoms():
    for prop in atom.GetPropsAsDict():
        if prop.startswith("highlight_"):
            idx = atom.GetIdx()
            _, interaction, key = prop.split("_")
            highlight_atoms[idx].append(colors[interaction])
            highlight_ints[(interaction, key)].append(idx)

# draw molecule
d = Draw.MolDraw2DSVG(800, 600)
opts = Draw.MolDrawOptions()
opts.fillHighlights = False
d.SetDrawOptions(opts)
d.DrawMoleculeWithHighlights(mol, "", dict(highlight_atoms), {}, {}, {})

# add a wavy line between each pair of interacting atoms
conf = mol.GetConformer()
for (interaction, _), (a1, a2) in highlight_ints.items():
    pos1 = conf.GetAtomPosition(a1)
    pos2 = conf.GetAtomPosition(a2)
    p1 = Geometry.Point2D(pos1.x, pos1.y)
    p2 = Geometry.Point2D(pos2.x, pos2.y)
    d.DrawWavyLine(p1, p2, colors[interaction], colors[interaction])

# display
d.FinishDrawing()
display.SVG(d.GetDrawingText())
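Tagging atoms with properties keeps the bookkeeping inside RDKit. An alternative, sketched below under the assumption that residues are combined in the order of the `mols` dict (Python dicts keep insertion order), is to compute the combined indices directly: `CombineMols` places the atoms of the second molecule after those of the first, so each residue's atoms are shifted by the number of atoms combined before it. `atom_offsets` and `combined_pairs` are hypothetical helpers, and the residue names and sizes are made-up toy values.

```python
from itertools import accumulate
from typing import Dict, Hashable, List, Tuple

# (ligand residue, ligand atom index, protein residue, protein atom index, interaction)
Contact = Tuple[Hashable, int, Hashable, int, str]


def atom_offsets(sizes: Dict[Hashable, int]) -> Dict[Hashable, int]:
    """Index of each residue's first atom once all residues are combined in order."""
    # running total of atom counts, starting at 0; the last total is not needed
    starts = list(accumulate(sizes.values(), initial=0))[:-1]
    return dict(zip(sizes, starts))


def combined_pairs(contacts: List[Contact],
                   sizes: Dict[Hashable, int]) -> List[Tuple[str, int, int]]:
    """Translate per-residue atom indices into indices of the combined molecule."""
    offset = atom_offsets(sizes)
    return [(interaction, offset[lres] + lidx, offset[pres] + pidx)
            for lres, lidx, pres, pidx, interaction in contacts]


# made-up example: residue -> number of atoms, in combination order
sizes = {"LIG1": 10, "ASP129": 8, "TYR359": 12}
contacts = [("LIG1", 3, "ASP129", 5, "Cationic"),
            ("LIG1", 7, "TYR359", 0, "PiStacking")]
print(combined_pairs(contacts, sizes))
# [('Cationic', 3, 15), ('PiStacking', 7, 18)]
```

With RDKit, `sizes` would be built as `{resid: res.GetNumAtoms() for resid, res in mols.items()}`, which avoids storing anything on the residues themselves.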