In [39]:
rejects = []

def flatten(seq):
    """
    Append every leaf element of a nested tuple/list/set structure to the
    module-level ``rejects`` list, depth first.

    Only objects whose exact type is ``tuple``, ``list`` or ``set`` are
    descended into; everything else is treated as a leaf.

    Nothing is returned. ``rejects`` is never cleared here, so results
    accumulate across calls until the caller empties the list.
    """
    containers = (tuple, list, set)
    for item in seq:
        if type(item) not in containers:
            rejects.append(item)
            continue
        flatten(item)

def scramble(array, sequence):
    """Return a new array holding ``array[s]`` for each ``s`` in ``sequence``, in that order."""
    picked = []
    for idx in sequence:
        picked.append(array[idx])
    return np.array(picked)


def prune_conformers(structures, atomnos, k = 1, max_rmsd = 1):
    """
    Discard near-duplicate structures based on pairwise RMSD.

    The structures are split into ``k`` consecutive groups (the last group
    takes the remainder) and only structures belonging to the same group are
    compared. Within a group, two structures are linked when their RMSD is
    below ``max_rmsd``; every connected set of linked structures is reduced
    to a single representative (the first node networkx lists for it).

    Parameters
    ----------
    structures : np.ndarray
        Array whose first axis indexes the structures. It must support
        boolean-mask and integer-array indexing.
    atomnos : array-like
        Passed unchanged to ``rmsd`` as the atomic properties of both
        structures being compared.
    k : int
        Number of groups. When ``k != 1`` the structures are randomly
        shuffled before splitting, so repeated calls compare different
        subsets; the original order is restored before returning.
    max_rmsd : float
        Two structures with RMSD strictly below this value are duplicates.

    Returns
    -------
    (np.ndarray, np.ndarray)
        ``structures[mask]`` and the boolean ``mask`` itself (True = kept),
        both in the input order.
    """
    n_structures = len(structures)

    if k != 1:
        # Shuffle before splitting, so that multiple runs of group pruning
        # compare different subsets of structures.
        sequence = np.random.permutation(n_structures)
        # The argsort of a permutation is its inverse permutation.
        inv_sequence = np.argsort(sequence)
        structures = structures[sequence]

    mask_out = []
    d = n_structures // k

    for step in range(k):
        start = d * step
        # The last group also takes the remainder of the integer division.
        stop = None if step == k - 1 else d * (step + 1)
        structures_subset = structures[start:stop]
        n_subset = len(structures_subset)

        # NaN marks pairs that were never evaluated; NaN < max_rmsd is False.
        rmsd_mat = np.full((n_subset, n_subset), np.nan)
        for i, tgt in enumerate(structures_subset):
            for j in range(i + 1, n_subset):
                val = rmsd(tgt, structures_subset[j], atomnos, atomnos, center=True, minimize=True)
                rmsd_mat[i, j] = val
                if val < max_rmsd:
                    # One link is enough to attach i to a cluster.
                    break

        matches = list(zip(*np.where(rmsd_mat < max_rmsd)))

        g = nx.Graph(matches)
        groups = [tuple(g.subgraph(c).nodes) for c in nx.connected_components(g)]

        # Keep the first member of each cluster, reject the others.
        # The list is local to this group: indices are relative to
        # structures_subset and must not leak into other groups or calls.
        rejects = [node for group in groups for node in group[1:]]

        mask = np.ones(n_subset, dtype=bool)
        for i in rejects:
            mask[i] = False

        mask_out.append(mask)

    mask = np.concatenate(mask_out)

    if k != 1:
        # Undo the shuffling, therefore preserving the input order.
        mask = mask[inv_sequence]
        structures = structures[inv_sequence]

    return structures[mask], mask

In [2]:
from cclib.io import ccread
import os
import numpy as np
from spyrmsd.rmsd import rmsd
import networkx as nx

# from prune import prune_conformers as p
# from prune import prune_conformers_v2 as p2
# Enable the %%cython cell magic used by the compiled cells of this notebook.
%load_ext cython
# NOTE: this changes the working directory of the kernel using a relative path,
# so running this cell a second time without restarting looks for
# 'Resources/SN2' inside the directory entered by the first run.
os.chdir('Resources/SN2')
mol = ccread('TS_out_test.xyz')
# Build a larger test set by stacking the same coordinates 50 times: every
# structure in `stack` therefore has exact duplicates.
stack = np.vstack([mol.atomcoords for _ in range(50)])
# The recorded output for this cell is (200, 19, 3).
stack.shape

(200, 19, 3)

In [2]:
from ase import Atoms
from ase.visualize import view
# Open the ASE viewer on the first set of coordinates read from the file.
# Requires `mol` from the data-loading cell: the NameError recorded below
# shows this cell was last executed in a kernel where `mol` was not defined.
view(Atoms(mol.atomnos, positions=mol.atomcoords[0]))

NameError: name 'mol' is not defined

In [46]:
%%cython --compile=-fopenmp --link-args=-fopenmp
import numpy as np
cimport numpy as np
import networkx as nx
cimport cython

from libc.stdio cimport stdout, fprintf
import sys
from time import time
# sys.stdout.flush()


# `np.float` was only an alias of the builtin `float` (i.e. float64); the alias
# was deprecated in NumPy 1.20 and removed in 1.24, so name the dtype explicitly.
DTYPE = np.float64
ctypedef np.float64_t DTYPE_t


# cdef int en(tup, energies):
#     ens = [energies[t] for t in tup]
#     return tup[ens.index(min(ens))]
        

@cython.boundscheck(False) # turn off bounds-checking for entire function
@cython.wraparound(False)  # turn off negative index wrapping for entire function
cdef inline np.ndarray scramble_mask(np.ndarray array,
                                     np.ndarray[np.int_t, ndim=1, mode='c'] sequence):
    """
    Return a new array holding ``array[s]`` for each ``s`` in ``sequence``.

    Used to reorder the 1-D keep/reject mask. The return type is a plain
    ``np.ndarray``: the previous declaration advertised a 3-D buffer, which
    does not describe the mask this helper is applied to.
    """
    return np.array([array[s] for s in sequence])

@cython.boundscheck(False) # turn off bounds-checking for entire function
@cython.wraparound(False)  # turn off negative index wrapping for entire function
cdef inline np.ndarray scramble(np.ndarray[DTYPE_t, ndim=3, mode='c'] array,
                                np.ndarray[np.int_t, ndim=1, mode='c'] sequence):
    """
    Return a new array holding ``array[s]`` for each ``s`` in ``sequence``.

    Used to reorder the stack of structures along its first axis. The return
    type is a plain ``np.ndarray``: the previous declaration advertised an
    integer buffer although the elements have the floating-point dtype of
    ``array``.
    """
    return np.array([array[s] for s in sequence])

@cython.boundscheck(False) # turn off bounds-checking for entire function
@cython.wraparound(False)  # turn off negative index wrapping for entire function
cpdef np.ndarray[np.int_t, ndim=1, mode='c'] prune_conformers_v2(np.ndarray[DTYPE_t, ndim=3, mode='c'] structures,
                                                                 np.ndarray[np.int_t, ndim=1, mode='c'] atomnos,
                                                                 int k = 1,
                                                                 double max_rmsd = 1.):
    """
    Flag near-duplicate structures based on pairwise RMSD (compiled version).

    The structures are split into ``k`` consecutive groups (the last group
    takes the remainder) and only structures of the same group are compared
    with ``rmsd_c``. Structures linked by an RMSD below ``max_rmsd`` form a
    cluster, and all members of a cluster but one are rejected.

    structures : C-contiguous 3-D float array, first axis indexes structures.
    atomnos    : accepted for signature parity with the Python version; it is
                 not used, since ``rmsd_c`` only takes the two coordinate sets.
    k          : number of groups. When ``k != 1`` the structures are shuffled
                 before splitting and the mask is mapped back to input order.
    max_rmsd   : RMSD strictly below this value marks two structures as
                 duplicates.

    Returns a 1-D integer array in input order: 1 = keep, 0 = reject.
    """

    cdef np.ndarray[np.int_t, ndim=1, mode='c'] r, sequence, inv_sequence, mask

    if k != 1:

        r = np.arange(structures.shape[0])
        sequence = np.random.permutation(r)
        # The argsort of a permutation is its inverse permutation; cast to the
        # same integer dtype the typed buffers are declared with.
        inv_sequence = np.argsort(sequence).astype(int)

        # scrambling array before splitting, so to improve efficiency when doing
        # multiple runs of group pruning
        structures = scramble(structures, sequence)

    cdef list mask_out = []
    cdef unsigned int step, d = len(structures) // k
    cdef unsigned int l
    cdef unsigned int i, j
    cdef np.ndarray[DTYPE_t, ndim=2, mode='c'] rmsd_mat
    cdef double[:,:] rmsd_mat_view
    cdef np.ndarray[DTYPE_t, ndim=3, mode='c'] structures_subset
    cdef tuple where
    cdef list matches, groups, rejects
    cdef object g
    cdef double val

    for step in range(k):
        if step == k-1:
            # the last group also takes the remainder of the integer division
            structures_subset = structures[d*step:]
        else:
            structures_subset = structures[d*step:d*(step+1)]

        l = structures_subset.shape[0]
        # Entries never evaluated stay at max_rmsd, hence are not < max_rmsd.
        rmsd_mat = np.full((l, l), max_rmsd)
        rmsd_mat_view = rmsd_mat

        for i in range(l):
            for j in range(i+1, l):
                val = rmsd_c(structures_subset[i], structures_subset[j])
                rmsd_mat_view[i, j] = val
                if val < max_rmsd:
                    # one link is enough to attach i to a cluster
                    break

        where = np.where(rmsd_mat < max_rmsd)
        # Plain Python ints as graph nodes; avoids reusing the typed loop
        # indices i, j as comprehension variables.
        matches = list(zip(where[0].tolist(), where[1].tolist()))

        g = nx.Graph(matches)
        groups = [tuple(g.subgraph(c).nodes) for c in nx.connected_components(g)]

        # Keep the first member of each cluster, reject the others.
        rejects = [node for group in groups for node in group[1:]]

        mask = np.ones(l, dtype=int)
        for i in rejects:
            mask[i] = 0

        mask_out.append(mask)

    mask = np.concatenate(mask_out)

    if k != 1:
        # undoing the previous shuffling, therefore preserving the input order
        # (only the mask is returned, so the structures need no reordering)
        mask = scramble_mask(mask, inv_sequence)

    return mask


@cython.boundscheck(False) # turn off bounds-checking for entire function
@cython.wraparound(False)  # turn off negative index wrapping for entire function
cdef inline double rmsd_c(np.ndarray[DTYPE_t, ndim=2, mode='c'] coords1, np.ndarray[DTYPE_t, ndim=2, mode='c'] coords2):
    """
    RMSD between two coordinate sets after centering both on their mean.

    The value is computed as sqrt((Ga + Gb - 2 * l_max) / N), where Ga and Gb
    are the sums of squared centered coordinates, l_max is the largest
    eigenvalue of the 4x4 matrix built by ``K_mtx`` from ``M_mtx(A, B)``, and
    N is the number of rows of ``coords1``.

    Rows of the two arrays are compared pairwise in the given order.
    """

    cdef double atol = 1e-9
    cdef np.ndarray[DTYPE_t, ndim=2, mode='c'] A = coords1 - np.mean(coords1, axis=0)
    cdef np.ndarray[DTYPE_t, ndim=2, mode='c'] B = coords2 - np.mean(coords2, axis=0)

    cdef int N = A.shape[0]

    cdef double Ga = np.trace(A.T @ A)
    cdef double Gb = np.trace(B.T @ B)

    cdef np.ndarray[DTYPE_t, ndim=2, mode='c'] M = M_mtx(A, B)
    cdef np.ndarray[DTYPE_t, ndim=2, mode='c'] K = K_mtx(M)

    # The polynomial coefficients returned by coefficients(M, K) are not
    # needed when the eigenvalue is obtained by direct diagonalization, so
    # they are no longer computed here (saves two determinants per pair).
    cdef double l_max = _lambda_max_eig(K)

    cdef double s = Ga + Gb - 2 * l_max

    if abs(s) < atol:  # Avoid numerical errors when Ga + Gb = 2 * l_max
        return 0.0

    return np.sqrt(s / N)

@cython.boundscheck(False) # turn off bounds-checking for entire function
@cython.wraparound(False)  # turn off negative index wrapping for entire function
cdef inline tuple coefficients(np.ndarray[DTYPE_t, ndim=2, mode='c'] M, np.ndarray[DTYPE_t, ndim=2, mode='c'] K):
    """
    Return the tuple ``(c2, c1, c0)`` with
    c2 = -2 * trace(M^T M), c1 = -8 * det(M), c0 = det(K).
    """
    cdef double det_K = np.linalg.det(K)
    cdef double det_M = np.linalg.det(M)
    cdef double sq_norm = np.trace(M.T @ M)

    return (-2 * sq_norm, -8 * det_M, det_K)

@cython.boundscheck(False) # turn off bounds-checking for entire function
@cython.wraparound(False)  # turn off negative index wrapping for entire function
cdef inline np.ndarray M_mtx(np.ndarray[DTYPE_t, ndim=2, mode='c'] A, np.ndarray[DTYPE_t, ndim=2, mode='c'] B):
    """Return the matrix product of the transpose of ``B`` with ``A``."""
    return np.matmul(B.T, A)

@cython.boundscheck(False) # turn off bounds-checking for entire function
@cython.wraparound(False)  # turn off negative index wrapping for entire function
cdef inline np.ndarray K_mtx(np.ndarray[DTYPE_t, ndim=2, mode='c'] M):
    """
    Build the symmetric 4x4 matrix K from the top-left 3x3 entries of ``M``.

    Naming: the two letters are the (row, column) of the entry of ``M``,
    with x, y, z standing for indices 0, 1, 2.
    """
    cdef double xx = M[0, 0], xy = M[0, 1], xz = M[0, 2]
    cdef double yx = M[1, 0], yy = M[1, 1], yz = M[1, 2]
    cdef double zx = M[2, 0], zy = M[2, 1], zz = M[2, 2]

    # Off-diagonal entries appear twice (K is symmetric), so name them once.
    cdef double yz_minus_zy = yz - zy
    cdef double zx_minus_xz = zx - xz
    cdef double xy_minus_yx = xy - yx
    cdef double xy_plus_yx = xy + yx
    cdef double zx_plus_xz = zx + xz
    cdef double yz_plus_zy = yz + zy

    return np.array(
        [
            [xx + yy + zz, yz_minus_zy, zx_minus_xz, xy_minus_yx],
            [yz_minus_zy, xx - yy - zz, xy_plus_yx, zx_plus_xz],
            [zx_minus_xz, xy_plus_yx, -xx + yy - zz, yz_plus_zy],
            [xy_minus_yx, zx_plus_xz, yz_plus_zy, -xx - yy + zz],
        ]
    )

@cython.boundscheck(False) # turn off bounds-checking for entire function
@cython.wraparound(False)  # turn off negative index wrapping for entire function
cdef inline double _lambda_max_eig(np.ndarray[DTYPE_t, ndim=2, mode='c'] K):
    """
    Return the largest eigenvalue of ``K``.

    Assumes ``K`` is symmetric, as is the matrix produced by ``K_mtx``. The
    symmetric solver always returns real eigenvalues, whereas the general
    ``eig`` may return a complex array that cannot be converted to ``double``.
    """
    return np.max(np.linalg.eigvalsh(K))




In [50]:
# Prune the test stack with the compiled implementation and report how many
# structures survive (entries equal to 1 in the returned mask).
mask = prune_conformers_v2(stack, mol.atomnos)
print(f'{sum(1 for flag in mask if flag == 1)} structures kept')
mask

4 structures kept


array([1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0])

In [49]:
# Same check with the pure-Python implementation; only the mask is of
# interest here, the pruned structures are discarded.
_, mask = prune_conformers(stack, mol.atomnos)
print(f'{sum(1 for flag in mask if flag == 1)} structures kept')
mask

4 structures kept


array([ True,  True,  True,  True, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False,

In [47]:
# Time both implementations on the same input, as a single group (k=1).
# Each label is printed after the timing line it refers to.
%timeit prune_conformers_v2(stack, mol.atomnos, k=1)
print('CYTHON\n')
%timeit prune_conformers(stack, mol.atomnos, k=1)
print('PYTHON')


138 ms ± 5.41 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
CYTHON

338 ms ± 10.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
PYTHON


In [48]:
# Speed-up of the Cython version over the Python one, from the timings
# recorded above: 338 ms (Python) / 138 ms (Cython).
(138/338)**-1

2.449275362318841

In [25]:
%%cython 
import numpy as np
cimport numpy as np
cimport cython
@cython.boundscheck(False) # turn off bounds-checking for entire function
@cython.wraparound(False)  # turn off negative index wrapping for entire function
cpdef test(np.ndarray[double, ndim=3, mode='c'] structures_subset):
    """
    Benchmark kernel: fill the upper triangle of an (l, l) matrix with a
    cheap stand-in for a pairwise metric (the sum of all coordinates of the
    two structures of each pair). Entries that are not written keep the
    value 5.

    Entry [i, j] (j > i) is computed from structures i and j, so the value
    stored at a position belongs to the pair that position designates.
    """
    cdef Py_ssize_t l = structures_subset.shape[0]
    cdef np.ndarray rmsd_mat = np.full((l, l), 5, dtype=float)
    cdef double[:,:] rmsd_mat_view = rmsd_mat
    cdef Py_ssize_t i, j
    for i in range(l):
        for j in range(i + 1, l):
            rmsd_mat_view[i, j] = np.sum(np.stack((structures_subset[i], structures_subset[j])))
    return rmsd_mat

In [17]:
# %timeit -r 10 prune_conformers_v2(stack, mol.atomnos, k=1)
# %timeit -r 1 get_rmsd_mat(stack)
# Time the compiled benchmark kernel on the stacked test set.
%timeit test(stack)
# test(stack)[0,2]

1.19 s ± 22 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [6]:
import numpy as np
def ptest(structures_subset):
    """
    Pure-Python counterpart of the compiled benchmark kernel.

    Fills the upper triangle of an (n, n) matrix with a cheap stand-in for a
    pairwise metric (the sum of all coordinates of the two structures of each
    pair). Entries that are not written keep the value 5.

    Entry [i, j] (j > i) is computed from structures i and j, so the value
    stored at a position belongs to the pair that position designates.
    """
    n_structures = len(structures_subset)
    rmsd_mat = np.full((n_structures, n_structures), 5.0)
    for i in range(n_structures):
        for j in range(i + 1, n_structures):
            rmsd_mat[i, j] = np.sum(np.stack((structures_subset[i], structures_subset[j])))
    return rmsd_mat

In [7]:
# Time the pure-Python benchmark kernel on the same input, for comparison
# with the compiled `test` function.
%timeit ptest(stack)

1.31 s ± 47.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# COMPENETRATION

In [None]:
%%cython
import numpy as np
cimport numpy as np
from math import sqrt
cimport cython

DTYPE = np.float64
ctypedef np.float64_t DTYPE_t

@cython.boundscheck(False) # turn off bounds-checking for entire function
@cython.wraparound(False)  # turn off negative index wrapping for entire function
cdef double norm_of(double[:] v):
    """
    Euclidean norm of the first three components of ``v``.

    The intermediate sum is a typed local. The module-level ``s`` and
    ``norm`` declarations were never read or written by this function
    (assignments inside a function create locals), so they are dropped.
    """
    cdef double s = v[0]*v[0] + v[1]*v[1] + v[2]*v[2]
    return sqrt(s)

@cython.boundscheck(False) # turn off bounds-checking for entire function
@cython.wraparound(False)  # turn off negative index wrapping for entire function
cdef int _count_clashes(np.ndarray[DTYPE_t, ndim=2] m1,
                        np.ndarray[DTYPE_t, ndim=2] m2,
                        double thresh, int clashes, int max_clashes):
    """Add to ``clashes`` the m1/m2 row pairs closer than ``thresh``; stop as soon as the count exceeds ``max_clashes``."""
    cdef Py_ssize_t i, j
    cdef double dx, dy, dz
    # Compare squared distances: same ordering, no square root needed.
    cdef double thresh_sq = thresh * thresh

    for i in range(m1.shape[0]):
        for j in range(m2.shape[0]):
            dx = m1[i, 0] - m2[j, 0]
            dy = m1[i, 1] - m2[j, 1]
            dz = m1[i, 2] - m2[j, 2]
            if dx*dx + dy*dy + dz*dz < thresh_sq:
                clashes += 1
                if clashes > max_clashes:
                    return clashes
    return clashes


@cython.boundscheck(False) # turn off bounds-checking for entire function
@cython.wraparound(False)  # turn off negative index wrapping for entire function
cpdef int compenetration_check(np.ndarray[DTYPE_t, ndim=2] coords, list ids):
    """
    Check that the fragments stored in ``coords`` do not overlap too much.

    ``coords`` holds the rows of all fragments one after the other and
    ``ids`` gives the number of rows of each fragment. With two ids the
    fragments are ``coords[:ids[0]]`` and the rest; otherwise three
    fragments are used: ``coords[:ids[0]]``, the next ``ids[1]`` rows, and
    everything after them.

    A clash is a pair of rows from two different fragments whose distance
    (first three columns) is below 1.2. Returns 1 when at most 2 clashes
    are found over all fragment pairs, 0 otherwise.
    """

    cdef double thresh = 1.2
    cdef int clashes = 0
    cdef int max_clashes = 2
    # max_clashes clashes is good, max_clashes + 1 is not
    cdef list fragments
    cdef Py_ssize_t a, b, n_fragments

    if len(ids) == 2:
        fragments = [coords[0:ids[0]],
                     coords[ids[0]:]]
    else:
        fragments = [coords[0:ids[0]],
                     coords[ids[0]:ids[0]+ids[1]],
                     coords[ids[0]+ids[1]:]]

    n_fragments = len(fragments)
    for a in range(n_fragments):
        for b in range(a + 1, n_fragments):
            clashes = _count_clashes(fragments[a], fragments[b], thresh, clashes, max_clashes)
            if clashes > max_clashes:
                return 0

    return 1

In [None]:
# NOTE(review): the required second argument `ids` (list with the number of
# rows of each fragment) is missing, so this call raises TypeError as written.
compenetration_check(stack[0], )