# Experiments for Identifying Buried Water Sites

&nbsp;

### Preparation
I'd like everything to be easily visualizable, so let's work in 2D. Admittedly, this limits the complexity of possible cavity arrangements, but it should still serve as a sufficient sandbox for trying various approaches to the site-identification problem.



-------------------------------

In [None]:
## imports & configuration

# standard imports
import json

# custom imports
import ipywidgets
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches
import scipy.optimize
import scipy.spatial
import tqdm.notebook

# inline plots
# %config InlineBackend.figure_formats = ['svg']
%matplotlib notebook

# jupyter theme (optional): apply the jupyterthemes plot style when the package
# is installed. Import the `jtplot` submodule explicitly (the documented usage)
# instead of relying on the package __init__ to expose it as an attribute.
try:
    from jupyterthemes import jtplot
except ImportError:
    pass
else:
    jtplot.style()

-------------------------------

In [None]:
## class to represent drawable matplotlib canvas
class Canvas:
    """Square matplotlib drawing surface spanning the unit box [0, 1] x [0, 1].

    Call ``init_figure`` before drawing; it creates ``self.fig`` and ``self.ax``.
    """

    def __init__(self):
        # the figure and axes are created by init_figure
        pass

    def init_figure(self, figsize=(5, 3)):
        """Create the figure and a tick-free, equal-aspect unit-square axes.

        Parameters
        ----------
        figsize : tuple of float
            Figure size in inches, forwarded to ``plt.subplots``.
        """
        self.fig, self.ax = plt.subplots(figsize=figsize)
        # thick border around the box
        for side in ('top', 'bottom', 'left', 'right'):
            self.ax.spines[side].set_linewidth(2)
        self.ax.set_aspect('equal', 'box')
        self.ax.set_xlim(0, 1)
        self.ax.set_ylim(0, 1)
        self.ax.xaxis.set_ticks([])
        self.ax.yaxis.set_ticks([])
        self.fig.tight_layout()

    def draw_circles(self, positions, radii, style=None):
        """Add one circle patch per position to the axes and return the patches.

        Parameters
        ----------
        positions : sequence of (x, y)
            Circle centres in data coordinates.
        radii : scalar or sequence of float
            A single radius shared by every circle, or one radius per circle.
            Any 0-d value (Python or NumPy scalar) counts as a single radius.
        style : dict, optional
            Keyword arguments forwarded to ``matplotlib.patches.Circle``;
            defaults to an unfilled, 2-pt-wide ``C0`` outline.

        Returns
        -------
        list of matplotlib.patches.Circle
            The added patches, in the same order as ``positions``.
        """
        # set default circle style
        style = style or dict(edgecolor="C0", linewidth=2, fill=None)

        # a scalar radius (including NumPy scalars such as np.float32) is
        # broadcast to every circle
        if np.ndim(radii) == 0:
            radii = np.full(len(positions), radii, dtype=float)

        return [
            self.ax.add_patch(matplotlib.patches.Circle(xy=p, radius=r, **style))
            for p, r in zip(positions, radii)
        ]

In [None]:
## custom slider ipywidget class
class Slider:
    """A Play button linked to an IntSlider that drives an update callback.

    Display ``self.player`` (an HBox of the play button and the interactive
    slider) to get the control.
    """

    def __init__(self, update_method, index_var, **kwargs):
        """Build the play button, the slider and the link between them.

        Parameters
        ----------
        update_method : callable
            Called by ``ipywidgets.interactive`` with the slider value passed
            as the keyword argument named ``index_var``.
        index_var : str
            Name of the keyword of ``update_method`` that receives the index.
        **kwargs
            Overrides for the ``ipywidgets.Play`` settings (value, min, max,
            step, interval, disabled).
        """
        play_options = {
            "value": 0,
            "min": 0,
            "max": 99,
            "step": 1,
            "interval": 1000,
            "disabled": False,
            **kwargs,
        }
        self.play = ipywidgets.Play(**play_options)

        # slider mirrors the play button's range and starting value
        self.int_slider = ipywidgets.IntSlider(
            min=self.play.min,
            max=self.play.max,
            step=self.play.step,
            value=self.play.value,
        )
        self.slider = ipywidgets.interactive(update_method, **{index_var: self.int_slider})

        # keep play button and slider in sync on the browser side
        ipywidgets.jslink((self.play, 'value'), (self.slider.children[0], 'value'))

        self.player = ipywidgets.HBox([self.play, self.slider])

In [None]:
## trajectory configuration

# frames per synthetic trajectory, and edge length of the square box (in Å)
n_frames, box_size = 100, 70

In [None]:
## construct fake protein trajectory

# construct grid of hexagonally-spaced points
def construct_hex_grid(m, n):
    """Return hexagonally packed points with nearest-neighbour spacing 2.

    Points sit at lattice index (c, r) with x = sqrt(3)*c and y = r, keeping
    only the checkerboard sites where c + r is even, for 0 <= c < m and
    0 <= r < n. Everything is shifted by +1 so unit-radius disks centred on
    the points start at the origin. Odd ``m``/``n`` keep their last
    column/row, and ``m`` or ``n`` below 2 are handled (possibly returning
    an empty ``(0, 2)`` array).

    Parameters
    ----------
    m : int
        Number of lattice columns along x (column spacing sqrt(3)).
    n : int
        Number of lattice rows along y (row spacing 1).

    Returns
    -------
    numpy.ndarray of shape (k, 2)
        Even-column points first, then odd-column points; each group is
        ordered column-major (x outer, y inner).
    """
    def tzip2d(xs, ys):
        # all (x, y) pairs, x-major; reshape keeps this valid for empty inputs
        return np.stack(np.meshgrid(xs, ys, indexing='ij'), axis=-1).reshape(-1, 2)

    # even columns (c = 0, 2, ...) pair with even rows (r = 0, 2, ...)
    even = tzip2d(2*np.sqrt(3)*np.arange((m + 1)//2), 2*np.arange((n + 1)//2))
    # odd columns (c = 1, 3, ...) pair with odd rows (r = 1, 3, ...)
    odd = tzip2d(np.sqrt(3)*(1 + 2*np.arange(m//2)), 1 + 2*np.arange(n//2))
    return 1 + np.concatenate([even, odd], axis=0)

# construct protein grid
grid_shape = (32, 56)  # (m, n) for construct_hex_grid: m sets the x extent, n the y extent
margin = 0.4  # fraction of grid radius reserved for the random perturbations below
l_max = max(grid_shape[0]*np.sqrt(3), grid_shape[1])  # longer lattice side, in lattice units
grid_radius = 1/(l_max)
scale = 1/(l_max+1)  # lattice units -> unit box
grid_pos = scale*(np.array([0.5, 0])+construct_hex_grid(*grid_shape))
# number of lattice sites actually generated; the hex lattice only fills a
# checkerboard of the m x n index grid, so np.prod(grid_shape) was 2x too large
n_prot_base = len(grid_pos)
prot_radius = (1 - margin)*grid_radius

# propagate grid positions through vector field: each step moves points by
# 0.001*field_func, then contracts them by 0.999 toward the centre (0.5, 0.5)
field_func = lambda x: np.sin(10*(x-0.5))[:, ::-1]
flow_positions = np.zeros((n_frames, len(grid_pos), 2))
flow_positions[0] = grid_pos
for i in range(1, n_frames):
    flow_positions[i] = flow_positions[i-1]+0.001*field_func(flow_positions[i-1])
    flow_positions[i] = (0.999*(flow_positions[i]-0.5))+0.5

# construct random perturbations: uniformly random directions with magnitudes
# uniform in [0, margin*grid_radius), so prot_radius + |pert| <= grid_radius
perts = np.random.randn(n_frames, len(grid_pos), 2)
perts /= np.linalg.norm(perts, axis=-1)[:, :, None]
perts *= margin*grid_radius*np.random.random((n_frames, len(grid_pos)))[:, :, None]

# use SDF to subselect from grid
def smin(a, b, k=32):
    """Smooth (exponential / log-sum-exp) minimum of ``a`` and ``b``.

    Computes ``-log2(2**(-k*a) + 2**(-k*b)) / k``. For ``k > 0`` the result
    never exceeds ``min(a, b)`` and approaches it as ``k`` grows. It is
    evaluated with ``np.logaddexp2`` so large ``|k*a|`` no longer overflows
    or underflows to ``±inf``.

    Parameters
    ----------
    a, b : float or numpy.ndarray
        Values to blend (broadcast together).
    k : float
        Sharpness; larger values give a harder minimum.

    Returns
    -------
    float or numpy.ndarray
        Smooth minimum with the broadcast shape of ``a`` and ``b``.
    """
    return np.logaddexp2(-k * a, -k * b) / -k
def sub(x, y, **kwargs):
    """Smooth subtraction of distance field ``y`` from ``x`` (a smooth max(x, -y))."""
    return -smin(-x, y, **kwargs)


def _circle_sdf(x, center, radius):
    """Signed distance from points ``x`` (last axis = coordinates) to a circle."""
    return np.linalg.norm(x - np.array(center), axis=-1) - radius


def SDF(x):
    """Signed distance field of the fake protein shape.

    The shape is a disk of radius 0.4 at (0.5, 0.5) with several smaller,
    smoothly blended regions subtracted from it. Protein sites are selected
    where this field is negative.
    """
    # main disk minus two round holes (default sharpness)
    shape = _circle_sdf(x, [0.5, 0.5], 0.4)
    shape = sub(shape, _circle_sdf(x, [0.3, 0.3], 0.04))
    shape = sub(shape, _circle_sdf(x, [0.28, 0.52], 0.1))

    # four small circles softly merged (k=16) into one removed region
    upper_pair = smin(
        _circle_sdf(x, [0.5, 0.8], 0.01),
        _circle_sdf(x, [0.6, 0.7], 0.005),
        k=16,
    )
    lower_pair = smin(
        _circle_sdf(x, [0.7, 0.6], 0.005),
        _circle_sdf(x, [0.8, 0.5], 0.01),
        k=16,
    )
    shape = sub(shape, smin(upper_pair, lower_pair, k=16))

    # two more small circles softly merged into a final removed region
    last_pair = smin(
        _circle_sdf(x, [0.6, 0.3], 0.01),
        _circle_sdf(x, [0.6, 0.2], 0.01),
        k=16,
    )
    return sub(shape, last_pair)
# lattice sites where the SDF is negative belong to the protein
prot_sel = np.less(SDF(grid_pos), 0.)

# the lattice was laid out for a 50 Å box: rescale the radius, and rescale the
# flowed + jittered positions of the selected sites about the centre (0.5, 0.5)
prot_radius = prot_radius * (50/box_size)
prot_pos = ((flow_positions + perts)[:, prot_sel] - 0.5) * 50 / box_size + 0.5

In [None]:
## construct water trajectories

# compute water-box size at standard state concentration
# (55.345 mol/L * 6.022e23 w/mol * 1e3 L/m3 * 1e-30 m3/Å3) == 0.033328759 w/Å3
s_wat = (1/(55.345 * 6.022e23 * 1e3 * 1e-30))**(1/3) # or 3.10737465 Å / water

# choose a size for water: 1.925 Å is used as the diameter, so r_wat is half of it
r_wat = 1.925*0.5 # Å

# number of waters: box area divided by the area per water (s_wat**2)
n_waters = int((box_size**2)/(s_wat**2))

# maximum # of packing attempts per water
n_attempts = 10000

# container for per-frame water positions, filled in below
wat_pos = np.full((n_frames, n_waters, 2), np.nan)

# random sequential packing: every water gets a uniformly random position that
# does not overlap any water already placed in the same frame
for i in tqdm.notebook.tqdm(range(n_frames)):
    for j in range(n_waters):
        for k in range(n_attempts):
            # random position keeping the whole disk inside the box
            pos = r_wat+(box_size-2*r_wat)*np.random.random(2)

            # displacements to the waters already placed in this frame
            disp = wat_pos[i][:j] - pos[None]

            # if no collisions, add position and stop trying
            if np.all((disp*disp).sum(axis=-1) > (2*r_wat)**2):
                wat_pos[i][j] = pos
                break
        else:
            # for/else: reached only when no attempt succeeded. The previous
            # `k+1 >= n_attempts` check also fired on a success at the last attempt.
            raise RuntimeError(f"frame {i} water {j} couldn't be packed with {n_attempts} attempts")

# scale waters into the unit box
wat_pos /= box_size
wat_radius = r_wat/box_size

# resolve water-protein collisions
# distances have shape (n_frames, n_waters, n_prot)
wat_prot_distances = np.linalg.norm(prot_pos[:, None] - wat_pos[:, :, None], axis=-1)
collisions = np.any(wat_prot_distances < prot_radius+wat_radius, axis=-1)
# park overlapping waters at (-1, -1), outside the unit box
wat_pos[collisions] = np.array([-1, -1])

In [None]:
# construct graphs, drop colliding waters
# pairwise water-water distances, shape (n_frames, n_waters, n_waters)
wat_wat_distances = np.linalg.norm(wat_pos[:, None] - wat_pos[:, :, None], axis=-1)
wat_indices = np.triu_indices(n_waters, k=1)
# every unordered water pair (i < j) as rows of an (n_pairs, 2) array, built once
wat_pairs = np.stack(wat_indices).T

# one graph per frame: link waters whose centres are closer than 4 radii
wat_Gs = [
    nx.Graph(list(wat_pairs[sel]))
    for sel in (wat_wat_distances[(slice(None),) + wat_indices] < 4*wat_radius)
]
for G in wat_Gs:
    # waters without any neighbour are not created by the edge list; add them
    G.add_nodes_from(set(range(n_waters))-set(G.nodes))
for G, sel in zip(wat_Gs, collisions):
    # drop waters that overlapped the protein in this frame
    G.remove_nodes_from(np.where(sel)[0])

In [None]:
# identify non-bulk clusters and their contacts
# per frame: {component index: (water node set, indices of contacted protein atoms)},
# keeping only components of at most `bulk_cutoff` waters that touch the protein
bulk_cutoff = 20
contact_cutoff = 2*(wat_radius+prot_radius)
clusters_contacts = []
for frame_idx, frame_G in enumerate(wat_Gs):
    frame_clusters = {}
    for comp_idx, nodes in enumerate(nx.connected_components(frame_G)):
        if len(nodes) > bulk_cutoff:
            continue
        # protein atoms within contact_cutoff of any water in the component
        near = wat_prot_distances[frame_idx, sorted(nodes)] < contact_cutoff
        contacts = np.where(np.any(near, axis=0))[0]
        if len(contacts) > 0:
            frame_clusters[comp_idx] = (nodes, contacts)
    clusters_contacts.append(frame_clusters)

# all non-bulk water indices per frame; set().union(...) also works for frames
# without any qualifying cluster (set.union(*[]) raised TypeError)
nonbulks = [set().union(*(c[0] for c in cs.values())) for cs in clusters_contacts]

In [None]:
## draw interactive canvas

# initialize canvas: water circles first, then protein circles
canvas = Canvas()
canvas.init_figure(figsize=(7, 7))
circles = []
circles += canvas.draw_circles(wat_pos[0], wat_radius)
circles += canvas.draw_circles(prot_pos[0], prot_radius, style=dict(edgecolor="k", facecolor="#C0C", linewidth=2))
# per-frame centres in the same order as `circles`
positions = np.concatenate([wat_pos, prot_pos], axis=1)
canvas.lines = []

# method to update canvas each frame
def update(frame):
    """Redraw the water-graph edges and restyle/reposition circles for ``frame``."""
    # clear old edge lines and draw the current frame's edges
    for line in canvas.lines:
        line.remove()
    canvas.lines = [
        canvas.ax.add_line(matplotlib.lines.Line2D(*wat_pos[frame][edge, :].T, c='k'))
        for edge in wat_Gs[frame].edges
    ]

    # dim everything except non-bulk waters and move circles to this frame
    highlighted = nonbulks[frame]
    for i, circle in enumerate(circles):
        circle.set_alpha(1.0 if i in highlighted else 0.2)
        circle.center = positions[frame][i]

    # highlight protein atoms in contact with a non-bulk cluster
    for _, contacts in clusters_contacts[frame].values():
        for j in contacts:
            circles[n_waters+j].set_alpha(1.0)

# create interactive widget
Slider(update, "frame", max=len(positions)-1).player

------------------------------------

In [None]:
# water sites are reconstructed relative to a reference protein structure:
# the protein coordinates of the middle frame of the trajectory
reference_frame = n_frames // 2
ref_prot_pos = prot_pos[reference_frame]

In [None]:
# displacement-based
# For every water of every non-bulk cluster in every frame, store per protein
# atom an inverse-distance weight and the water's displacement from that atom
# (non-zero only for the cluster's contact atoms). Each site is then re-expressed
# in the reference structure as the weighted mean of (reference atom + displacement).
# NOTE(review): `zip(*sites)` below fails if no site was collected.
sites = []
for i, data in enumerate(clusters_contacts):
    # `atoms` holds the cluster's water indices; `contacts` the protein atom indices
    for k, (atoms, contacts) in data.items():
        # displacement of each cluster water from each contact atom,
        # shape (n_cluster_waters, n_contacts, 2)
        disps = wat_pos[i][sorted(atoms)][:, None]-prot_pos[i][contacts][None, :]
        assert disps.shape == (len(atoms), len(contacts), 2)
        for disp in disps:
            # dense per-protein-atom arrays so all sites can be stacked below
            site_wts = np.zeros((len(prot_pos[0]), 1))
            site_dsp = np.zeros((len(prot_pos[0]), 2))
            inv_dist = 1/np.linalg.norm(disp, axis=-1)
            # inverse-distance weights, normalized to sum to 1 over the contacts
            site_wts[contacts, :] = inv_dist[:, None]/inv_dist.sum()
            site_dsp[contacts, :] = disp
            sites.append((site_wts, site_dsp))

# wts: (n_sites, n_prot, 1), dsps: (n_sites, n_prot, 2)
wts, dsps = map(np.array, zip(*sites))
# reconstructed site positions in the reference structure, shape (n_sites, 2)
xtal_pos = (wts*(ref_prot_pos[None]+dsps)).sum(axis=1)

In [None]:
# LAWS-based
# Each site is described by its distances to its contact protein atoms. Its
# position in the reference structure is recovered by a weighted least-squares
# fit of those distances. Residual weights are sqrt of the normalized 1/d^2,
# so the squared residuals carry 1/d^2 weights.
import warnings

sites = []
for i, data in enumerate(clusters_contacts):
    for k, (atoms, contacts) in data.items():
        disps = wat_pos[i][sorted(atoms)][:, None]-prot_pos[i][contacts][None, :]
        assert disps.shape == (len(atoms), len(contacts), 2)
        for disp in disps:
            d2s = (disp*disp).sum(axis=-1)
            # separate name so the displacement-based `wts` array is not clobbered
            laws_wts = (1/d2s)/(1/d2s).sum()
            sites.append((np.sqrt(laws_wts), np.sqrt(d2s), contacts))

# Only sites with >= 2 contacts are fitted: method="lm" needs at least as many
# residuals as unknowns (2 coordinates). sites[i] lines up with xtal_pos[i]
# because the displacement-based cell walks the clusters in the same order.
fits = [
    scipy.optimize.least_squares(
        lambda x: w*(np.linalg.norm(x-ref_prot_pos[p], axis=-1)-d),
        xtal_pos[i],
        method="lm",
    )
    for i, (w, d, p) in enumerate(sites)
    if len(p) >= 2
]
n_failed = sum(not fit.success for fit in fits)
if n_failed:
    warnings.warn(f"{n_failed} of {len(fits)} LAWS fits did not converge")
solns = [fit.x for fit in fits]
laws_pos = np.array(solns)

In [None]:
def distance_graph_clustering(positions, cutoff):
    """Build a proximity graph whose connected components are the clusters.

    Nodes are the point indices ``0 .. len(positions)-1``; an edge joins every
    pair of points whose Euclidean distance is strictly below ``cutoff``.

    Parameters
    ----------
    positions : array-like of shape (n, d)
        Point coordinates.
    cutoff : float
        Distance threshold for an edge.

    Returns
    -------
    networkx.Graph
        Points without any neighbour are included as isolated nodes, so every
        point belongs to exactly one connected component.
    """
    positions = np.asarray(positions)
    distances = np.linalg.norm(positions[None] - positions[:, None], axis=-1)
    indices = np.triu_indices(len(positions), k=1)
    sel = distances[indices] < cutoff
    G = nx.Graph(list(np.stack(indices).T[sel]))
    # the edge list only introduces connected points; append the isolated ones
    # afterwards so the node (and component) order of connected points is unchanged
    G.add_nodes_from(range(len(positions)))
    return G

In [None]:
# link LAWS sites closer than 0.02; keep components with more than 5 sites
G = distance_graph_clustering(laws_pos, 0.02)
clusters = list(filter(lambda comp: len(comp) > 5, nx.connected_components(G)))

In [None]:
# flatten the graph clusters into per-site positions (vs) and colors (cs);
# colors cycle through the palette, so more than five clusters no longer raise KeyError
cluster_palette = ['r', 'g', 'b', '#FF0', '#0FF']
site_colors = [
    (laws_pos[x], cluster_palette[i % len(cluster_palette)])
    for i, xs in enumerate(clusters)
    for x in xs
]
vs, cs = zip(*site_colors) if site_colors else ((), ())

In [None]:
import sklearn.cluster
db = sklearn.cluster.DBSCAN(eps=0.01, min_samples=10).fit(laws_pos)
# noise (label -1) is black; clusters cycle through the palette, because DBSCAN
# may find any number of clusters and a fixed 5-entry lookup raised KeyError
palette = ['r', 'g', 'b', '#FF0', '#0FF']
colors = ['k' if i == -1 else palette[i % len(palette)] for i in db.labels_]

In [None]:
import sklearn.cluster

# Birch with 4 final clusters; each label picks a fixed plot color
label_colors = {-1: 'k', 0: 'r', 1: 'g', 2: 'b', 3: '#FF0', 4: '#0FF'}
labels = sklearn.cluster.Birch(threshold=0.0001, n_clusters=4).fit_predict(laws_pos)
colors = [label_colors[label] for label in labels]

In [None]:
import sklearn.mixture

# 4-component Gaussian mixture; each component label picks a fixed plot color
label_colors = {-1: 'k', 0: 'r', 1: 'g', 2: 'b', 3: '#FF0', 4: '#0FF'}
labels = sklearn.mixture.GaussianMixture(n_components=4).fit_predict(laws_pos)
colors = [label_colors[label] for label in labels]

In [None]:
import sklearn.cluster

# agglomerative clustering into 4 groups; each label picks a fixed plot color
label_colors = {-1: 'k', 0: 'r', 1: 'g', 2: 'b', 3: '#FF0', 4: '#0FF'}
labels = sklearn.cluster.AgglomerativeClustering(n_clusters=4).fit_predict(laws_pos)
colors = [label_colors[label] for label in labels]

In [None]:
import scipy.stats

# Gaussian KDE of the LAWS sites evaluated on an n x n grid over the unit box
xmin, xmax, ymin, ymax, n = 0, 1, 0, 1, 100
X, Y = np.mgrid[xmin:xmax:complex(n), ymin:ymax:complex(n)]
rs = np.vstack([X.ravel(), Y.ravel()])
kernel = scipy.stats.gaussian_kde(laws_pos.T)
Z = kernel(rs).T.reshape(X.shape)

# density image (rotated so +y points up) with the graph-clustered sites on top
fig, ax = plt.subplots()
ax.imshow(np.rot90(Z), cmap=plt.cm.gist_earth_r, extent=[xmin, xmax, ymin, ymax])
ax.scatter(*np.stack(vs).T, s=2, c=cs)
ax.set(xlim=[xmin, xmax], ylim=[ymin, ymax])
plt.show()

In [None]:
# visualize results: reference protein plus the reconstructed LAWS sites
_canvas = Canvas()
_canvas.init_figure(figsize=(7, 7))
_canvas.circles = (
    _canvas.draw_circles(ref_prot_pos, prot_radius, style=dict(edgecolor="k", facecolor="#C0C", linewidth=2))
    + _canvas.draw_circles(laws_pos, wat_radius/4)
)

------------------------------------------

In [None]:
## visualize protein selection SDF

# evaluate SDF over a 50 x 50 grid. Stacking (X, Y) on the last axis gives
# points[r, c] = (X[r, c], Y[r, c]), so Z lines up with X and Y directly; the
# previous transpose/swap form was only correct because both axes were identical.
X, Y = np.meshgrid(np.linspace(0, 1, 50), np.linspace(0, 1, 50))
Z = SDF(np.stack([X, Y], axis=-1))

# view SDF
plt.figure()
plt.title("SDF")
plt.contourf(X, Y, Z, levels=10)
plt.colorbar()
plt.tight_layout()

In [None]:
## visualize distortion vector field

# sample points on a 30 x 30 grid over the unit box, shape (900, 2)
pts = np.concatenate(np.stack(np.meshgrid(np.linspace(0, 1, 30), np.linspace(0, 1, 30)), axis=0).T)

# view vector field; the title now matches field_func exactly
# (it swaps the components: f_x depends on y and f_y on x)
plt.figure()
plt.title(r"$f(x, y) = (\sin(10(y - 0.5)), \sin(10(x - 0.5)))$")
plt.quiver(*pts.T, *field_func(pts).T)
plt.tight_layout()

In [None]:
## view raw trajectory

# initialize canvas: one circle per water, then one per protein atom
_canvas = Canvas()
_canvas.init_figure(figsize=(7, 7))
_circles = (
    _canvas.draw_circles(wat_pos[0], wat_radius)
    + _canvas.draw_circles(prot_pos[0], prot_radius, style=dict(edgecolor="k", facecolor="#C0C", linewidth=2))
)
# per-frame centres in the same order as `_circles`
_positions = np.concatenate([wat_pos, prot_pos], axis=1)

def old_update(frame):
    """Move every circle to its centre in ``frame`` (no restyling)."""
    for circle, center in zip(_circles, _positions[frame]):
        circle.center = center

# create interactive widget
Slider(old_update, "frame", max=len(_positions)-1, interval=200).player

-------------------------