In [1]:
import sys
import os
import itertools

import pandas as pd
import numpy  as np
import multiprocessing as mp

from lib import nar
from lib import pdb
from scipy.spatial import KDTree


In [2]:
# Options in `key=value` form, standing in for command-line arguments.
argv = [
    'r=data/4kqy_ori.pdb',    # reference structure
    'q=data/2gis.pdb',        # query structure (superimposed onto r)
    'rmsdmax=3',
    'sizemin=20',
    'saveto=results/4kqy',
    'threads=-1',             # -1: one worker per CPU
]

In [3]:
# Residue selections passed to get_sub_struct ('' presumably means "all
# residues" — TODO confirm) and default input formats.
rres = qres = ''
rformat = qformat = 'PDB'

# Acceptance windows for a match: each quantity must lie within [min, max].
rmsdmin, rmsdmax = 0., 1e10
sizemin, sizemax = 0., 1e10
rmsdsizemin, rmsdsizemax = 0., 1e10
# Max distance between residue centres for them to be paired
# (coordinate units — presumably Å).
matchrange = 3.

# Output options.
saveto = saveres = None
saveformat = 'PDB'

threads = 1

# Atom representation per residue type, one per stage of the search.
seed_res_repr = (
    nar.five_atom_repr,    # primary (seed) superposition
    nar.five_atom_repr,    # centres of mass used for neighbour matching
    nar.three_atom_repr,   # secondary superposition over matched residues
    nar.three_atom_repr,   # RMSD evaluation
)

# Which alternative atom location to keep: 'first', 'last' or False.
keep = 'last'

In [4]:
def get_transform(r:np.ndarray, q:np.ndarray):
    """Least-squares rigid superposition of ``q`` onto ``r`` (Kabsch algorithm).

    Parameters
    ----------
    r, q : (N, D) arrays of paired points; row k of ``q`` corresponds to row k of ``r``.

    Returns
    -------
    rot : (D, D) proper rotation (det +1), applied to row vectors as ``q @ rot``.
    tran : (D,) translation, such that ``q @ rot + tran`` best fits ``r``.
    """
    r_avg = r.mean(axis=0)
    q_avg = q.mean(axis=0)

    # Cross-covariance of the centred point sets.
    M = (q - q_avg).T @ (r - r_avg)
    u, _, vh = np.linalg.svd(M)

    rot = u @ vh
    if np.linalg.det(rot) < 0:
        # The optimal orthogonal map is a reflection. Flip the singular vector
        # of the smallest singular value (last row of vh) to get the best
        # proper rotation. Uses -1 rather than a hard-coded 2, so any D works.
        vh[-1] = -vh[-1]
        rot = u @ vh
    tran = r_avg - q_avg @ rot

    return rot, tran


def RMSD(r:np.ndarray, q:np.ndarray) -> float:
    """Root-mean-square deviation between row-paired coordinate arrays.

    The divisor is ``len(r)`` (number of rows/points). Empty input gives
    0/0, i.e. nan.
    """
    delta = r - q
    squared_total = (delta * delta).sum()
    return np.sqrt(squared_total / len(r))

def RMSD_simple(r:np.ndarray, q:np.ndarray) -> float:
    """Sum of squared deviations between ``r`` and ``q``.

    Despite the name, there is no division by the point count and no square root.
    """
    return np.square(r - q).sum()


def apply_transform(coord:np.ndarray, rotran):
    """Apply a ``(rotation, translation)`` pair (as from get_transform) to row-vector coordinates."""
    rotation, translation = rotran
    rotated = np.dot(coord, rotation)
    return rotated + translation


def mutual_nb(dist) -> list:
    """Return index pairs ``(i, j)`` that are each other's nearest neighbour.

    ``dist`` is an iterable of ``(i, j, d)`` triples, e.g. the structured array
    produced by ``KDTree.sparse_distance_matrix(..., output_type='ndarray')``.
    On ties the first-seen candidate is kept. The result is ordered by first
    appearance of ``j`` in ``dist``.
    """
    best_for_ref = {}   # i -> (closest j, distance)
    best_for_qry = {}   # j -> (closest i, distance)

    for ref_idx, qry_idx, d in dist:
        current = best_for_ref.get(ref_idx)
        if current is None or d < current[1]:
            best_for_ref[ref_idx] = (qry_idx, d)

        current = best_for_qry.get(qry_idx)
        if current is None or d < current[1]:
            best_for_qry[qry_idx] = (ref_idx, d)

    return [
        (ref_idx, qry_idx)
        for qry_idx, (ref_idx, _) in best_for_qry.items()
        if best_for_ref[ref_idx][0] == qry_idx
    ]


def vstack(alt):
    """Split a sequence of ``(ref, alg)`` coordinate pairs into two stacked arrays.

    Each side is combined with ``np.vstack``. Empty input raises ValueError.
    """
    ref_rows, alg_rows = zip(*alt)
    ref_block = np.vstack(ref_rows)
    alg_block = np.vstack(alg_rows)
    return ref_block, alg_block

In [5]:
# Parse `key=value` options. Split on the first '=' only, so values
# (e.g. paths) may themselves contain '='.
kwargs  = dict(arg.split('=', 1) for arg in argv)
threads = int(kwargs.get('threads', threads))
if threads != 1:
    # 'fork' so pool workers inherit the globals that task()/saver() read.
    # force=True: re-running this cell would otherwise raise
    # RuntimeError('context has already been set').
    mp.set_start_method('fork', force=True)

r       = kwargs.get('r')
q       = kwargs.get('q')
if r is None or q is None:
    raise ValueError("both 'r=<file>' and 'q=<file>' options are required")
rres    = kwargs.get('rres', rres)
qres    = kwargs.get('qres', qres)
rformat = kwargs.get('rformat', rformat)
qformat = kwargs.get('qformat', qformat)

rmsdmin     = float(kwargs.get('rmsdmin', rmsdmin))
rmsdmax     = float(kwargs.get('rmsdmax', rmsdmax))
sizemin     = float(kwargs.get('sizemin', sizemin))
sizemax     = float(kwargs.get('sizemax', sizemax))
rmsdsizemin = float(kwargs.get('rmsdsizemin', rmsdsizemin))
rmsdsizemax = float(kwargs.get('rmsdsizemax', rmsdsizemax))
matchrange  = float(kwargs.get('matchrange', matchrange))

saveto     = kwargs.get('saveto', saveto)
saveres    = kwargs.get('saveres', saveres)

# Structure name = file name without its extension. A recognised extension
# overrides the input format. splitext tolerates extra dots in the name and
# a missing extension (which keeps the configured/default format).
rname, rext = os.path.splitext(os.path.basename(r))
qname, qext = os.path.splitext(os.path.basename(q))

rext = rext.lstrip('.').upper()
qext = qext.lstrip('.').upper()

if rext in ['PDB', 'CIF']:
    rformat = rext

if qext in ['PDB', 'CIF']:
    qformat = qext

# Saved structures default to the query's format.
saveformat = kwargs.get('saveformat', qformat)

In [6]:
# Parse both input files, restrict them to the requested residue selections,
# then resolve alternative atom locations per `keep` (the return value is
# discarded, so presumably in place — confirm).
rstruct, qstruct = pdb.parser(r, rformat, rname), pdb.parser(q, qformat, qname)

rsstruct, qsstruct = rstruct.get_sub_struct(rres), qstruct.get_sub_struct(qres)

rsstruct.drop_duplicates_alt_id(keep=keep)
qsstruct.drop_duplicates_alt_id(keep=keep)

In [7]:
# Residue types that have a representation at every stage of seed_res_repr.
carrier = set.intersection(*(set(rr.keys()) for rr in seed_res_repr))

# For each such residue type: its per-stage representations, in stage order.
res_repr = {res: [rr[res] for rr in seed_res_repr] for res in carrier}

In [8]:
# Per-residue descriptors. Judging by later use, *rres entries unpack as
# (code, prim, avg, scnd, eval), and *ures holds residue codes that are
# reported with SIZE 0 (presumably residues without a representation).
(rrres, rures), (qrres, qures) = rsstruct.artem_desc(res_repr), qsstruct.artem_desc(res_repr)

In [9]:
# Fail early with a readable message if a structure has no residues to
# superimpose. The next cell unpacks zip(*rrres), which would otherwise raise
# an opaque "not enough values to unpack" error.
if len(rrres) == 0:
    raise ValueError('no superimposable residues in {}'.format(rsstruct.name))
if len(qrres) == 0:
    raise ValueError('no superimposable residues in {}'.format(qsstruct.name))

In [10]:
# Transpose per-residue descriptors into per-field tuples.
r_code, r_prim, r_avg, r_scnd, r_eval = zip(*rrres)
q_code, q_prim, q_avg, q_scnd, q_eval = zip(*qrres)

# Residue centres as (n_residues, dim) arrays for KD-tree matching.
r_avg = np.vstack(r_avg)
q_avg = np.vstack(q_avg)

# All (reference, query) seed pairs, by code and by index. Both lists share
# one order, so result[i] of the search cell corresponds to cpairs[i].
# ipairs is materialised (not a one-shot iterator) so the search cell can be
# re-run without silently producing an empty result.
cpairs = list(itertools.product(r_code, q_code))
ipairs = list(itertools.product(range(len(r_code)), range(len(q_code))))

r_avg_tree = KDTree(r_avg)

In [11]:
def task(m, n):
    """Score one seed: reference residue ``m`` paired with query residue ``n``.

    1. Superimpose the query onto the reference using only the seed residues'
       primary atoms.
    2. Pair residues whose centres are mutual nearest neighbours within
       ``matchrange``.
    3. Re-superimpose on all matched residues (secondary atoms) and compute
       the RMSD on the evaluation atoms.

    Returns ``[size, rmsd, rmsdsize, nb, transform]``, where ``nb`` is a sorted
    tuple of ``(ref_index, query_index)`` pairs. Returns None if any of the
    size / rmsd / rmsd-per-size windows rejects the match.
    """
    transform  = get_transform(r_prim[m], q_prim[n])

    q_avg_tree = KDTree(apply_transform(q_avg, transform))
    dist = r_avg_tree.sparse_distance_matrix(
        q_avg_tree,
        matchrange,
        p=2,
        output_type='ndarray'
    )

    nb   = mutual_nb(dist)
    size = len(nb)
    # An empty match cannot be superimposed (vstack raises on zip(*[])), so
    # reject it even when sizemin <= 0, which is the default.
    if size == 0 or not sizemin <= size <= sizemax:
        return None

    scnd = vstack([[r_scnd[i], q_scnd[j]] for i, j in nb])
    transform  = get_transform(*scnd)

    r_coord, q_coord = vstack([[r_eval[i], q_eval[j]] for i, j in nb])
    q_coord = apply_transform(q_coord, transform)

    rmsd = RMSD(r_coord, q_coord)
    if not rmsdmin <= rmsd <= rmsdmax:
        return None

    rmsdsize = rmsd / size
    if not rmsdsizemin <= rmsdsize <= rmsdsizemax:
        return None

    # Canonical order so identical matches from different seeds compare equal.
    nb.sort()
    return [size, rmsd, rmsdsize, tuple(nb), transform]

In [12]:
# Structure written out for each match: the `saveres` selection of the full
# query if one was given, otherwise the query subset used for matching.
if saveto:
    sstruct = qstruct.get_sub_struct(saveres) if saveres else qsstruct

In [13]:
def saver(row):
    """Write the save-structure, superimposed with one result row's transform.

    ``row`` is a row of ``tab``: ``row['TRAN']`` is the transform to apply and
    ``row.name`` (its ID label) is appended to the structure name.
    """
    moved = sstruct.apply_transform(row['TRAN'])
    # NOTE(review): formats the structure object itself; presumably its str()
    # is its name — confirm.
    new_name = '{}_{}'.format(moved, row.name)
    moved.rename(new_name)
    moved.saveto(saveto, saveformat)

In [14]:
# Evaluate every seed pair. Both paths preserve ipairs order, so result[i]
# corresponds to cpairs[i]. The pool is left open because the save cell
# reuses it.
if threads != 1:
    # threads == -1 means one worker per CPU.
    pool = mp.Pool(mp.cpu_count() if threads == -1 else threads)
    result = pool.starmap(task, ipairs)
else:
    result = list(itertools.starmap(task, ipairs))

In [15]:
# Collapse seeds that converge to the same set of matched residue pairs (nb).
# Each entry becomes [size, rmsd, rmsdsize, nb, transform, seed_indices].
# RMSD, RMSD/size and the transform are always taken from the same
# (lowest-RMSD) seed, so the reported TRAN matches the reported RMSD.
items = {}
for i, item in enumerate(result):
    if not item:
        continue  # seed rejected by task()
    nb = item[-2]
    best = items.get(nb)
    if best is None:
        items[nb] = item + [[i]]
        continue
    best[-1].append(i)
    if item[1] < best[1]:
        best[1], best[2], best[4] = item[1], item[2], item[4]

In [16]:
columns = ['SIZE', 'RMSD', 'RMSDSIZE', 'PRIM', 'SCND', 'TRAN']

# Residue codes from rures/qures (presumably unmatched residues) get
# placeholder rows: SIZE 0, structure name in PRIM, residue code in SCND.
tab = [[0, None, None, rsstruct.name, code, None] for code in rures]
tab += [[0, None, None, qsstruct.name, code, None] for code in qures]
zero_count = len(tab)

# One row per distinct match:
#   PRIM - seed pairs that converged to it, as "rcode=qcode", comma-separated,
#          in seed-index order;
#   SCND - matched residue pairs, same format;
#   TRAN - the stored (rot, tran) transform.
tab += [
    [
        *item[:3],
        ','.join('='.join(cpairs[seed]) for seed in sorted(item[-1])),
        ','.join('='.join([r_code[m], q_code[n]]) for m, n in item[-3]),
        item[4],
    ]
    for item in items.values()
]

tab = pd.DataFrame(tab, columns=columns)

In [17]:
# Order rows by SIZE (ascending) and, within a size, by RMSDSIZE
# (descending); then number them 1..N as the ID index.
tab = tab.sort_values(['SIZE', 'RMSDSIZE'], ascending=[True, False])
tab.index = pd.RangeIndex(1, len(tab) + 1, name='ID')

# Tab-separated report on stdout. TRAN (transform objects) is omitted.
tab.to_csv(
    sys.stdout,
    sep='\t',
    columns=['SIZE', 'RMSD', 'RMSDSIZE', 'PRIM', 'SCND'],
    float_format='{:0.3f}'.format,
)

ID	SIZE	RMSD	RMSDSIZE	PRIM	SCND
1	0			4kqy_ori	1.A.G.0.
2	0			4kqy_ori	1.A.SAM.201.
3	0			2gis	1.A.MG.205.
4	0			2gis	1.A.MG.206.
5	0			2gis	1.A.IRI.201.
6	0			2gis	1.A.IRI.202.
7	0			2gis	1.A.IRI.203.
8	0			2gis	1.A.IRI.204.
9	0			2gis	1.A.SAM.301.
10	0			2gis	1.A.HOH.401.
11	0			2gis	1.A.HOH.402.
12	0			2gis	1.A.HOH.403.
13	0			2gis	1.A.HOH.404.
14	0			2gis	1.A.HOH.405.
15	0			2gis	1.A.HOH.406.
16	0			2gis	1.A.HOH.407.
17	0			2gis	1.A.HOH.408.
18	0			2gis	1.A.HOH.409.
19	0			2gis	1.A.HOH.410.
20	0			2gis	1.A.HOH.411.
21	0			2gis	1.A.HOH.412.
22	0			2gis	1.A.HOH.413.
23	0			2gis	1.A.HOH.415.
24	0			2gis	1.A.HOH.416.
25	0			2gis	1.A.HOH.417.
26	0			2gis	1.A.HOH.418.
27	0			2gis	1.A.HOH.419.
28	0			2gis	1.A.HOH.420.
29	0			2gis	1.A.HOH.421.
30	0			2gis	1.A.HOH.422.
31	0			2gis	1.A.HOH.423.
32	0			2gis	1.A.HOH.424.
33	0			2gis	1.A.HOH.425.
34	0			2gis	1.A.HOH.426.
35	0			2gis	1.A.HOH.427.
36	0			2gis	1.A.HOH.428.
37	0			2gis	1.A.HOH.429.
38	0			2gis	1.A.HOH.430.
39	0			2gis	1.A.HOH.431.


In [20]:
# Write a superimposed copy of the save-structure for every non-trivial match.
# Keyed on `saveto` (which is what defines `sstruct`) rather than on locals(),
# so stale kernel state cannot trigger saving. A pool exists only when
# threads != 1; otherwise save serially.
if saveto:
    rows_to_save = [row for _, row in tab[tab['SIZE'] > 0].iterrows()]
    if threads == 1:
        for row in rows_to_save:
            saver(row)
    else:
        pool.map(saver, rows_to_save)