December 2020

Note: If the interactive plots do not display, save the notebook as HTML:<br>
`jupyter-nbconvert RISM_3D_plots.ipynb --to html --output RISM_3D_plots.html`

# 3D and interactive RISM plots

Some examples of plotting RISM surfaces in 3D plots alongside molecular structures.

In [None]:
import RISM_grids as rg
print('RISM_grids version: {}'.format(rg.__version__))

from IPython.display import Image

from rdkit.Chem.Draw import IPythonConsole
from rdkit import Chem
from rdkit.Chem import AllChem

import numpy as np
import plotly.graph_objects as go

In [None]:
# Path to an example 3D-RISM output file in .dx format:
dx_file = 'examples/mobley_7754849_conf0_guv.O.1.dx'
# Parse the file into a RISM_grids 3D grid object used by all later cells:
guv_3D = rg.Grid3D.FromDxFile(dx_file)

## 2D plots

In [None]:
# 2D heatmap of the slice at grid index 34 along x, colour scale capped at 6:
guv_x34 = guv_3D.slice('x', 34)
guv_x34.plt(vmax=6)

## 3D Plots

For 3D plots, the 2D surfaces for each plane can be plotted on 3D axes alongside the molecular structure.

#### Read molecular structure

The molecular structure is read and processed using RDKit.

In [None]:
pdb_file = 'examples/mobley_7754849_conf0.pdb'

# Keep explicit hydrogens (removeHs=False) so they can be drawn in the 3D plots:
mol = AllChem.MolFromPDBFile(pdb_file, removeHs=False)
# RDKit returns None when it cannot build a molecule from the file. Fail here
# with a clear message rather than later with an AttributeError on None:
if mol is None:
    raise ValueError('RDKit could not read a molecule from PDB file: {}'.format(pdb_file))
display(mol)

In [None]:
# Process molecular structure:

# Set colours and relative sizes for different atoms:
atom_colours = {'C'  : 'cyan',
                'H'  : 'white',
                'O'  : 'red',
                'N'  : 'blue',
                'P'  : 'gold',
                'S'  : 'yellow',
                'Cl' : 'green'}
atom_sizes = {'C'  : 4,
              'H'  : 2,
              'O'  : 4,
              'N'  : 4,
              'P'  : 5,
              'S'  : 5,
              'Cl' : 5}
# Fallback colour/size for any element not listed above (e.g. F, Br, I), so
# that molecules containing such elements do not raise a KeyError:
default_atom_colour = 'magenta'
default_atom_size = 4

# Get atom coordinates from the first conformer; rows are indexed by RDKit
# atom index, columns are x, y, z:
atom_coords = mol.GetConformers()[0].GetPositions()

# Group coordinates by element symbol so each element gets a single trace:
atomtype_coords = {}
for i, atom in enumerate(mol.GetAtoms()):
    atomtype_coords.setdefault(atom.GetSymbol(), []).append(atom_coords[i])

# Add atoms as points on a 3D scatter plot, one trace per element.
# Might be better to use shapes so that sizes are constant
# and don't change when you zoom in and out of the plot:
atoms = []
for atomtype, coords in atomtype_coords.items():
    ac = np.array(coords)
    atoms.append(go.Scatter3d(
        x=ac[:, 0],
        y=ac[:, 1],
        z=ac[:, 2],
        mode='markers',
        marker=dict(size=atom_sizes.get(atomtype, default_atom_size),
                    color=atom_colours.get(atomtype, default_atom_colour)),
        text=atomtype,
        name=atomtype,
        hoverinfo='text'
        ))

# Add each bond as a line segment between its two atoms:
bonds = []
for bond_i in mol.GetBonds():
    bond0 = bond_i.GetBeginAtomIdx()
    bond1 = bond_i.GetEndAtomIdx()
    bonds.append(go.Scatter3d(
        x=atom_coords[[bond0, bond1], 0],
        y=atom_coords[[bond0, bond1], 1],
        z=atom_coords[[bond0, bond1], 2],
        mode='lines',
        line=dict(width=3, color='orange'),
        name='Bond',
        ))

#### Functions for 3D plots

The 3D plots are made using plotly.

In [None]:
# Plot the RISM surface:
def get_slice(x, y, z, surfacecolor, opacity=0.6):
    """Wrap coordinate and colour arrays in a plotly ``go.Surface`` trace.

    The trace is attached to the layout colour axis named 'coloraxis', so its
    colour scale and limits come from ``layout.coloraxis`` of the figure.

    Parameters
    ----------
    x, y, z : array-like
        Surface vertex coordinates.
    surfacecolor : array-like
        Values used to colour the surface.
    opacity : float
        Surface opacity (default 0.6).
    """
    surface_props = {
        'x': x,
        'y': y,
        'z': z,
        'surfacecolor': surfacecolor,
        'opacity': opacity,
        'coloraxis': 'coloraxis',
    }
    return go.Surface(**surface_props)

# Get colour limits for a slice:
def get_colour_lims(surfacecolor):
    """Return ``(min, max)`` of all values in ``surfacecolor`` (numpy reductions)."""
    lowest = np.min(surfacecolor)
    highest = np.max(surfacecolor)
    return lowest, highest

# Get RISM surface:
def defsurf(rism_3D_grid, x_i, y_i, z_i, slice_d, plot_at_edge=False, opacity=1):
    """Build a plotly surface for one planar slice of a 3D RISM grid.

    The surface lies in the plane spanned by grid axes ``x_i`` and ``y_i``,
    with ``z_i`` as the surface normal. Different planes come from different
    permutations of (0, 1, 2), e.g. ``(0, 1, 2)`` gives an xy plane.

    Parameters
    ----------
    rism_3D_grid : RISM_grids.Grid3D
        Grid providing ``origin``, ``gridsize``, ``gridstep`` and ``slice``.
    x_i, y_i, z_i : int
        Axis indices (a permutation of 0, 1, 2): the two in-plane axes and
        the normal axis.
    slice_d : int
        Grid index along axis ``z_i`` at which the slice is taken.
    plot_at_edge : bool
        If True, draw the surface at the grid origin along ``z_i`` (a
        "shadow" on the box wall) instead of at the slice position.
    opacity : float
        Surface opacity passed to ``get_slice``.

    Returns
    -------
    tuple
        ``(surface_trace, vmin, vmax)``: the surface and the min/max RISM
        values on the slice.
    """
    origin = rism_3D_grid.origin
    size = rism_3D_grid.gridsize
    step = rism_3D_grid.gridstep

    # Grid point n along an axis sits at origin + n*step, so the last of the
    # `size` points is at origin + (size - 1)*step. This matches the
    # normal-axis position used below. Ending at origin + size*step instead
    # would stretch the surface by one grid spacing.
    def axis_points(ax):
        return np.linspace(origin[ax], origin[ax] + (size[ax] - 1) * step[ax], size[ax])

    i, j = np.meshgrid(axis_points(x_i), axis_points(y_i))

    # Coordinate along the axis perpendicular to the surface:
    if plot_at_edge:
        k_pos = origin[z_i]
    else:
        k_pos = origin[z_i] + slice_d * step[z_i]
    k = np.full(i.shape, k_pos, dtype=float)

    # Convert index z_i into axis label and get RISM values for the surface:
    cut_ax = {0: 'x', 1: 'y', 2: 'z'}
    rism_vals = rism_3D_grid.slice(cut_ax[z_i], slice_d).grid

    sminz, smaxz = get_colour_lims(rism_vals)

    # Convert i, j, k coordinates back into x, y, z: argsort of the axis
    # permutation gives, for each output axis, which of i/j/k holds it.
    coords = [i, j, k]
    co = np.argsort([x_i, y_i, z_i])

    slice_z = get_slice(coords[co[0]], coords[co[1]], coords[co[2]], rism_vals, opacity=opacity)

    return slice_z, sminz, smaxz

#### Plot figure with RISM surfaces as "shadows"

In [None]:
# Mid-box slices normal to x, y and z, each drawn flat against the wall of the
# plot (plot_at_edge=True) like a "shadow" of the RISM distribution:
slice_x, vminx, vmaxx = defsurf(guv_3D, 2, 1, 0, guv_3D.gridsize[0] // 2, plot_at_edge=True)
slice_y, vminy, vmaxy = defsurf(guv_3D, 0, 2, 1, guv_3D.gridsize[1] // 2, plot_at_edge=True)
slice_z, vminz, vmaxz = defsurf(guv_3D, 1, 0, 2, guv_3D.gridsize[2] // 2, plot_at_edge=True)

# Colour-bar limits taken from the data...
vmin = min(vminx, vminy, vminz)
vmax = max(vmaxx, vmaxy, vmaxz)
# ...but, as with 2D heatmaps, capping the upper limit below the maximum makes
# smaller fluctuations more clearly visible:
vmax = 4

colour_axis = dict(colorscale='BrBG',
                   colorbar_thickness=25,
                   colorbar_len=0.75,
                   cmin=vmin,
                   cmax=vmax)

fig1 = go.Figure(data=[slice_x, slice_y, slice_z] + atoms + bonds)
fig1.update_layout(title_text='RISM slices', title_x=0.5,
                   width=700, height=700,
                   showlegend=False,
                   coloraxis=colour_axis)

This graph could also be combined with a molecule rendered in VMD.

In [None]:
# Pre-rendered PNG combining the RISM plot with a VMD rendering of the molecule:
vmd_png = "examples/3D_RISM_VMD_plot_modified.png"
Image(filename=vmd_png, width=300, height=300)

#### Plot figure with RISM surfaces at midplanes through box

To see how the surfaces fit around the molecule they can also be plotted at the correct point in the 3D plot.

In [None]:
# Mid-box slices normal to x, y and z, this time drawn at their true position
# inside the box (semi-transparent so the molecule stays visible). Each tuple
# is (in-plane axis, in-plane axis, normal axis) for defsurf:
(slice_x, vminx, vmaxx), (slice_y, vminy, vmaxy), (slice_z, vminz, vmaxz) = [
    defsurf(guv_3D, *axes, guv_3D.gridsize[axes[2]] // 2, opacity=0.6)
    for axes in ((2, 1, 0), (0, 2, 1), (1, 0, 2))
]

# Data-derived colour limits, with the upper limit then capped at 4:
vmin = min(vminx, vminy, vminz)
vmax = max(vmaxx, vmaxy, vmaxz)
vmax = 4

fig1 = go.Figure(data=[slice_x, slice_y, slice_z] + atoms + bonds)
fig1.update_layout(
    coloraxis=dict(colorscale='BrBG', colorbar_thickness=25, colorbar_len=0.75,
                   cmin=vmin, cmax=vmax),
    showlegend=False,
    width=700,
    height=700,
    title_x=0.5,
    title_text='RISM slices')

## Interactive plots

Add a slider to the 3D plots to sweep a RISM slice through the molecule.

In [None]:
# Add sliders:
# https://plotly.com/python/sliders/

# Map axis labels to grid axis indices:
coord_idxs = {'x' : 0, 'y' : 1, 'z' : 2}

# Function for producing interactive plot:
def slider_fig(slice_idxs, norm, step_int=1, vmax=None, rism_3D_grid=None, mol_traces=None):
    """Show and save a 3D plot with a slider that sweeps a RISM slice through the molecule.

    The figure is written to ``interactive_fig_<norm>.html`` and then shown.

    Parameters
    ----------
    slice_idxs : sequence of 3 ints
        ``(x_i, y_i, z_i)`` axis indices passed to ``defsurf``. ``z_i`` is the
        axis normal to the slices.
    norm : str
        Label ('x', 'y' or 'z') of the axis the slices are perpendicular to.
        It is used to count the slices (via ``coord_idxs``), in the titles and
        in the output file name.
    step_int : int
        Interval, in grid points, between successive slices on the slider.
    vmax : float, optional
        Upper colour-bar limit. Defaults to the maximum over all slices.
    rism_3D_grid : RISM_grids.Grid3D, optional
        Grid to slice. Defaults to the notebook-level ``guv_3D``.
    mol_traces : list of plotly traces, optional
        Static traces (molecular structure, any fixed slices) that the slider
        never hides. Defaults to ``bonds + atoms``.

    Raises
    ------
    ValueError
        If ``step_int`` is less than 1.
    """
    if rism_3D_grid is None:
        rism_3D_grid = guv_3D
    if mol_traces is None:
        mol_traces = bonds + atoms
    if step_int < 1:
        raise ValueError('step_int must be >= 1, got {}'.format(step_int))

    x_i, y_i, z_i = slice_idxs
    # Grid indices of the slices the slider steps through:
    slice_ds = list(range(0, rism_3D_grid.gridsize[coord_idxs[norm]], step_int))

    # Static traces go first so they occupy the first n_mol_traces entries:
    fig1 = go.Figure(data=mol_traces)
    n_mol_traces = len(fig1.data)

    # One surface trace per slice, tracking colour limits. Start from +/-inf so
    # the limits come only from the data (fixed starting values such as 10 / 0
    # would be wrong for data lying entirely above 10 or below 0).
    vmin, vmax_data = np.inf, -np.inf
    for slice_d in slice_ds:
        surface, s_min, s_max = defsurf(rism_3D_grid, x_i, y_i, z_i, slice_d, opacity=0.6)
        vmin = min(vmin, s_min)
        vmax_data = max(vmax_data, s_max)
        fig1.add_trace(surface)

    # Take vmax calculated from slices if vmax not given to function:
    if vmax is None:
        vmax = vmax_data

    # Hide every slice except the first, which matches the slider's initial
    # position (active=0 below):
    for d in fig1.data[n_mol_traces + 1:]:
        d.visible = False

    title_prefix = "Interactive plot of RISM slice perpendicular to " + norm + "<br>"

    # Step n shows all static traces plus slice n only:
    n_slices = len(slice_ds)
    steps = []
    for n, slice_d in enumerate(slice_ds):
        visible = [True] * n_mol_traces + [False] * n_slices
        visible[n_mol_traces + n] = True
        steps.append(dict(
            method="update",
            # Label with the grid index so the slider readout is correct
            # when step_int > 1:
            label=str(slice_d),
            # "title.text" changes only the title text, leaving other title
            # settings such as title_x untouched. Use HTML <br> for new line:
            args=[{"visible": visible},
                  {"title.text": title_prefix + "Showing slice: " + str(slice_d)}],
        ))

    sliders = [dict(
        active=0,
        currentvalue={"prefix": norm + " slice: "},
        steps=steps,
    )]

    origin = rism_3D_grid.origin
    step = rism_3D_grid.gridstep
    size = rism_3D_grid.gridsize

    def axis_range(ax):
        # Range slightly larger than the RISM box along axis `ax`:
        return [origin[ax] - step[ax], origin[ax] + (size[ax] + 1) * step[ax]]

    fig1.update_layout(
        title_text=title_prefix + "Showing slice: 0",
        title_x=0.5,
        width=700,
        height=700,
        showlegend=False,
        # Fixed axis ranges keep the plot size constant as the slider moves:
        scene_xaxis_range=axis_range(0),
        scene_yaxis_range=axis_range(1),
        scene_zaxis_range=axis_range(2),
        scene_aspectratio=dict(x=1, y=1, z=1),
        coloraxis=dict(colorscale='BrBG',
                       colorbar_thickness=25,
                       colorbar_len=0.75,
                       cmin=vmin,
                       cmax=vmax),
        sliders=sliders)

    fig1.write_html("interactive_fig_" + norm + ".html")
    fig1.show()

In [None]:
# Show plots: one interactive figure per slicing direction (separate graphs,
# since multiple sliders could not be combined on one plot). Each entry is
# (defsurf axis indices: in-plane, in-plane, normal), normal-axis label.
slider_specs = (
    ([2, 1, 0], 'x'),
    ([0, 2, 1], 'y'),
    ([1, 0, 2], 'z'),
)
for slice_idxs, norm in slider_specs:
    # Every grid slice (step_int=1); colour bar truncated at 4:
    slider_fig(slice_idxs, norm, step_int=1, vmax=4)