### Setup

In [1]:
import sys,importlib
import torch
import os
import time
from tqdm import tqdm
import random
import numpy as np
import pandas as pd
from IPython.display import clear_output

import matplotlib.pyplot as plt
import plotly.express as px

#import pykeops
#pykeops.clean_pykeops()

%matplotlib inline 

In [142]:
# Hot-reload the local project modules so that edits made on disk are picked up
# without restarting the kernel. Each module is handled on its own: a module that
# has not been imported yet is simply skipped (the `from ... import` lines below
# load it for the first time) and does not prevent the other one from reloading.
for _module_name in ('dmasif_surface', 'data'):
    try:
        importlib.reload(sys.modules[_module_name])
    except KeyError:
        # Not imported yet in this kernel -- nothing to reload.
        pass
from dmasif_surface import atoms_to_points_normals
from data import AtomSurfaceDataset
from losses import chamfer_distance
from IPython.display import clear_output

### Оптимизируем параметры dmasif для тяжелых атомов

In [None]:
# Build the dataset (AtomSurfaceDataset from the local `data` module) on top of
# the 'structure_data.pkl' storage file.
dataset=AtomSurfaceDataset(storage='structure_data.pkl') 

In [None]:
# For every structure, record the fraction of atoms whose radius exceeds 1.2
# (presumably the non-hydrogen atoms -- TODO confirm radius units and threshold).
xs=[]
for data in tqdm(dataset):
    xs.append((data['atom_rad']>1.2).sum().item()/data['atom_rad'].shape[0])

# Mean fraction over the dataset, followed by 20 divided by that mean; the note
# below this cell uses the ratio to pick the sup_sampling value.
print(np.mean(xs))
print(20/np.mean(xs))

Установим sup_sampling=34 (округлённое значение 20/np.mean(xs) из ячейки выше), чтобы скомпенсировать количество точек, генерируемых на каждый атом водорода.

In [None]:
def comp_loss(dataset, args):
    """Evaluate surface generation from heavy atoms only against the target surface.

    For every item of ``dataset`` the atoms with ``atom_rad > 1.2`` are selected,
    ``atoms_to_points_normals`` generates points and normals from them, and
    ``chamfer_distance`` compares the result with the item's 'target_xyz' /
    'target_normals'.

    Args:
        dataset: iterable of dict-like items with the keys 'atom_xyz', 'atom_rad',
            'target_xyz' and 'target_normals'.
        args: dict of keyword arguments forwarded to ``atoms_to_points_normals``.

    Returns:
        Tuple ``(ch_losses, norm_losses, sizes)`` of three lists with one entry per
        successfully processed structure: the two values returned by
        ``chamfer_distance`` and the number of generated points.

    Structures for which generation or the loss raises ``RuntimeError`` are skipped;
    a single ``RuntimeWarning`` reports how many were skipped, because the returned
    lists (and any mean taken over them) then cover fewer structures.
    """
    import warnings

    ch_losses=[]
    norm_losses=[]
    sizes=[]
    n_total=0
    n_skipped=0
    for data in tqdm(dataset):
        n_total+=1
        # Compute the heavy-atom mask once and reuse it for coordinates and radii.
        # (Drop the mask to run on all atoms, hydrogens included.)
        heavy=data['atom_rad']>1.2
        xyz=data['atom_xyz'][heavy]
        # NOTE(review): radii are scaled by 100 before being passed on -- presumably
        # a unit conversion expected by atoms_to_points_normals; confirm.
        rad=data['atom_rad'][heavy]*100
        try:
            # The second positional argument is an all-zeros integer vector with one
            # entry per atom (presumably the batch index of each atom -- TODO confirm).
            dmasif_vert, dmasif_norm, _ = atoms_to_points_normals(xyz,
                                                       torch.zeros(xyz.shape[0], dtype=int),
                                                        atom_rad=rad,
                                                        **args)
            d_ch_loss, d_norm_loss=chamfer_distance( dmasif_vert, data['target_xyz'],
                                             dmasif_norm, data['target_normals'])
        except RuntimeError:
            # Best effort: skip the structure, but keep count so it can be reported.
            n_skipped+=1
            continue
        ch_losses.append(d_ch_loss)
        norm_losses.append(d_norm_loss)
        sizes.append(dmasif_vert.shape[0])
    if n_skipped:
        warnings.warn(
            f"comp_loss: skipped {n_skipped} of {n_total} structures because of a RuntimeError",
            RuntimeWarning,
        )
    return ch_losses, norm_losses, sizes


In [None]:
# Sweep the `distance` argument while the other settings stay fixed, and track
# the mean of each list returned by comp_loss for every tested value.
ch=[]
no=[]
ss=[]
par=[1,1.05,1.1,1.15,1.2,1.25,1.3,1.35]
for parameter in par:
    args={'smoothness':0.5,'distance':parameter, 'sup_sampling':34, 'variance': 0.1 }
    # Only the slice dataset[:103] is evaluated (assumes the dataset supports slicing).
    a, b, c=comp_loss(dataset[:103], args)
    ch.append(np.mean(a))
    no.append(np.mean(b))
    ss.append(np.mean(c))
    # Redraw the three curves after every parameter value so that progress is
    # visible while the sweep is still running.
    clear_output()
    plt.figure(figsize=(10,3))
    plt.subplot(131)
    plt.title('chamfer')
    plt.plot(par[:len(ch)],ch, label='chamfer')
    plt.subplot(132)
    plt.title('normal')
    plt.plot(par[:len(no)],no, label='normal')
    plt.subplot(133)
    plt.title('size')
    plt.plot(par[:len(ss)],ss, label='size')
    plt.show()

По результатам перебора выше установим параметр distance=1.25 для генерации поверхности только по тяжёлым атомам.

### Martinize

Рассчитаем идеальные положения псевдоатомов в локальной системе координат остова N-CA-C.

In [3]:
# from https://github.com/baker-laboratory/rf_diffusion_all_atom/blob/main/util.py

def rigid_from_3_points(N, Ca, C, non_ideal=False, eps=1e-8):
    """Build a local backbone frame from N, CA and C coordinates.

    Args:
        N, Ca, C: tensors of shape [B, L, 3] with the backbone atom coordinates.
        non_ideal: if True, the frame is additionally rotated about its third axis
            by an angle derived from the difference between the observed and the
            ideal N-CA-C bond angle.
        eps: small constant guarding the normalisations and square roots.

    Returns:
        (R, Ca): R has shape [B, L, 3, 3]; its columns are e1 (unit vector CA->C),
        e2 (the part of CA->N orthogonal to e1, normalised) and e3 = e1 x e2.
        Ca is returned unchanged as the frame origin.
    """
    #N, Ca, C - [B,L, 3]
    #R - [B,L, 3, 3], det(R)=1, inv(R) = R.T, R is a rotation matrix
    B,L = N.shape[:2]

    v1 = C-Ca
    v2 = N-Ca
    e1 = v1/(torch.norm(v1, dim=-1, keepdim=True)+eps)
    # Gram-Schmidt: remove the e1 component from CA->N before normalising.
    u2 = v2-(torch.einsum('bli, bli -> bl', e1, v2)[...,None]*e1)
    e2 = u2/(torch.norm(u2, dim=-1, keepdim=True)+eps)
    e3 = torch.cross(e1, e2, dim=-1)
    R = torch.cat([e1[...,None], e2[...,None], e3[...,None]], axis=-1) #[B,L,3,3] - rotation matrix

    if non_ideal:
        # Cosine of the ideal N-CA-C angle. The copied code read a module-level
        # `cos_ideal_NCAC` that is not defined in this notebook (NameError), so it
        # is computed here from the ideal backbone positions that ideal_coords
        # uses for N and C, with CA at the origin.
        ideal_N = torch.tensor([-0.5272, 1.3593, 0.0])
        ideal_C = torch.tensor([1.5233, 0.0, 0.0])
        costgt = (torch.dot(ideal_N, ideal_C)/(torch.norm(ideal_N)*torch.norm(ideal_C))).item()

        v2 = v2/(torch.norm(v2, dim=-1, keepdim=True)+eps)
        cosref = torch.sum(e1*v2, dim=-1) # cosine of current N-CA-C bond angle
        # cos(ref - tgt), clamped against round-off, then half-angle formulas.
        cos2del = torch.clamp( cosref*costgt + torch.sqrt((1-cosref*cosref)*(1-costgt*costgt)+eps), min=-1.0, max=1.0 )
        cosdel = torch.sqrt(0.5*(1+cos2del)+eps)
        sindel = torch.sign(costgt-cosref) * torch.sqrt(1-0.5*(1+cos2del)+eps)
        # Match dtype as well as device so that float64 inputs work in the einsum.
        Rp = torch.eye(3, device=N.device, dtype=R.dtype).repeat(B,L,1,1)
        Rp[:,:,0,0] = cosdel
        Rp[:,:,0,1] = -sindel
        Rp[:,:,1,0] = sindel
        Rp[:,:,1,1] = cosdel

        R = torch.einsum('blij,bljk->blik', R,Rp)

    return R, Ca

In [27]:
# from https://github.com/baker-laboratory/rf_diffusion_all_atom/blob/main/chemical.py

# Residue-type index -> three-letter residue name. The 20 standard amino acids are
# followed by 'UNK' and 'MAS' (labelled "unk" and "mask" in the ideal_coords
# table, which lists its entries in the same order).
num2aa=[
    'ALA','ARG','ASN','ASP','CYS',
    'GLN','GLU','GLY','HIS','ILE',
    'LEU','LYS','MET','PHE','PRO',
    'SER','THR','TRP','TYR','VAL',
    'UNK','MAS',
    ]
ideal_coords = [
    [ # 0 ala
        [' N  ', 0, (-0.5272, 1.3593, 0.000)],
        [' CA ', 0, (0.000, 0.000, 0.000)],
        [' C  ', 0, (1.5233, 0.000, 0.000)],
        [' O  ', 3, (0.6303, 1.0574, 0.000)],
        [' H  ', 2, (0.4920,-0.8821,  0.0000)],
        [' HA ', 0, (-0.3341, -0.4928,  0.9132)],
        [' CB ', 8, (-0.5289,-0.7734,-1.1991)],
        ['1HB ', 8, (-0.1265, -1.7863, -1.1851)],
        ['2HB ', 8, (-1.6173, -0.8147, -1.1541)],
        ['3HB ', 8, (-0.2229, -0.2744, -2.1172)],
    ],
    [ # 1 arg
        [' N  ', 0, (-0.5272, 1.3593, 0.000)],
        [' CA ', 0, (0.000, 0.000, 0.000)],
        [' C  ', 0, (1.5233, 0.000, 0.000)],
        [' O  ', 3, (0.6303, 1.0574, 0.000)],
        [' H  ', 2, (0.4920,-0.8821,  0.0000)],
        [' HA ', 0, (-0.3467, -0.5055,  0.9018)],
        [' CB ', 8, (-0.5042,-0.7698,-1.2118)],
        ['1HB ', 4, ( 0.3635, -0.5318,  0.8781)],
        ['2HB ', 4, ( 0.3639, -0.5323, -0.8789)],
        [' CG ', 4, (0.6396,1.3794, 0.000)],
        ['1HG ', 5, (0.3639, -0.5139,  0.8900)],
        ['2HG ', 5, (0.3641, -0.5140, -0.8903)],
        [' CD ', 5, (0.5492,1.3801, 0.000)],
        ['1HD ', 6, (0.3637, -0.5135,  0.8895)],
        ['2HD ', 6, (0.3636, -0.5134, -0.8893)],
        [' NE ', 6, (0.5423,1.3491, 0.000)],
        [' NH1', 7, (0.2012,2.2965, 0.000)],
        [' NH2', 7, (2.0824,1.0030, 0.000)],
        [' CZ ', 7, (0.7650,1.1090, 0.000)],
        [' HE ', 7, (0.4701,-0.8955, 0.000)],
        ['1HH1', 7, (-0.8059,2.3776, 0.000)],
        ['1HH2', 7, (2.5160,0.0898, 0.000)],
        ['2HH1', 7, (0.7745,3.1277, 0.000)],
        ['2HH2', 7, (2.6554,1.8336, 0.000)],
    ],
    [ # 2 asn
        [' N  ', 0, (-0.5272, 1.3593, 0.000)],
        [' CA ', 0, (0.000, 0.000, 0.000)],
        [' C  ', 0, (1.5233, 0.000, 0.000)],
        [' O  ', 3, (0.6303, 1.0574, 0.000)],
        [' H  ', 2, (0.4920,-0.8821,  0.0000)],
        [' HA ', 0, (-0.3233, -0.4967,  0.9162)],
        [' CB ', 8, (-0.5341,-0.7799,-1.1874)],
        ['1HB ', 4, ( 0.3641, -0.5327,  0.8795)],
        ['2HB ', 4, ( 0.3639, -0.5323, -0.8789)],
        [' CG ', 4, (0.5778,1.3881, 0.000)],
        [' ND2', 5, (0.5839,-1.1711, 0.000)],
        [' OD1', 5, (0.6331,1.0620, 0.000)],
        ['1HD2', 5, (1.5825, -1.2322, 0.000)],
        ['2HD2', 5, (0.0323, -2.0046, 0.000)],
    ],
    [ # 3 asp
        [' N  ', 0, (-0.5272, 1.3593, 0.000)],
        [' CA ', 0, (0.000, 0.000, 0.000)],
        [' C  ', 0, (1.5233, 0.000, 0.000)],
        [' O  ', 3, (0.6303, 1.0574, 0.000)],
        [' H  ', 2, (0.4920,-0.8821,  0.0000)],
        [' HA ', 0, (-0.3233, -0.4967,  0.9162)],
        [' CB ', 8, (-0.5162,-0.7757,-1.2144)],
        ['1HB ', 4, ( 0.3639, -0.5324,  0.8791)],
        ['2HB ', 4, ( 0.3640, -0.5325, -0.8792)],
        [' CG ', 4, (0.5926,1.4028, 0.000)],
        [' OD1', 5, (0.5746,1.0629, 0.000)],
        [' OD2', 5, (0.5738,-1.0627, 0.000)],
    ],
    [ # 4 cys
        [' N  ', 0, (-0.5272, 1.3593, 0.000)],
        [' CA ', 0, (0.000, 0.000, 0.000)],
        [' C  ', 0, (1.5233, 0.000, 0.000)],
        [' O  ', 3, (0.6303, 1.0574, 0.000)],
        [' H  ', 2, (0.4920,-0.8821,  0.0000)],
        [' HA ', 0, (-0.3481, -0.5059,  0.9006)],
        [' CB ', 8, (-0.5046,-0.7727,-1.2189)],
        ['1HB ', 4, ( 0.3639, -0.5324,  0.8791)],
        ['2HB ', 4, ( 0.3638, -0.5322, -0.8787)],
        [' SG ', 4, (0.7386,1.6511, 0.000)],
        [' HG ', 5, (0.1387,1.3221, 0.000)],
    ],
    [ # 5 gln
        [' N  ', 0, (-0.5272, 1.3593, 0.000)],
        [' CA ', 0, (0.000, 0.000, 0.000)],
        [' C  ', 0, (1.5233, 0.000, 0.000)],
        [' O  ', 3, (0.6303, 1.0574, 0.000)],
        [' H  ', 2, (0.4920,-0.8821,  0.0000)],
        [' HA ', 0, (-0.3363, -0.5013,  0.9074)],
        [' CB ', 8, (-0.5226,-0.7776,-1.2109)],
        ['1HB ', 4, ( 0.3638, -0.5323,  0.8789)],
        ['2HB ', 4, ( 0.3638, -0.5322, -0.8788)],
        [' CG ', 4, (0.6225,1.3857, 0.000)],
        ['1HG ', 5, ( 0.3531, -0.5156,  0.8931)],
        ['2HG ', 5, ( 0.3531, -0.5156, -0.8931)],
        [' CD ', 5, (0.5788,1.4021, 0.000)],
        [' NE2', 6, (0.5908,-1.1895, 0.000)],
        [' OE1', 6, (0.6347,1.0584, 0.000)],
        ['1HE2', 6, (1.5825, -1.2525, 0.000)],
        ['2HE2', 6, (0.0380, -2.0229, 0.000)],
    ],
    [ # 6 glu
        [' N  ', 0, (-0.5272, 1.3593, 0.000)],
        [' CA ', 0, (0.000, 0.000, 0.000)],
        [' C  ', 0, (1.5233, 0.000, 0.000)],
        [' O  ', 3, (0.6303, 1.0574, 0.000)],
        [' H  ', 2, (0.4920,-0.8821,  0.0000)],
        [' HA ', 0, (-0.3363, -0.5013,  0.9074)],
        [' CB ', 8, (-0.5197,-0.7737,-1.2137)],
        ['1HB ', 4, ( 0.3638, -0.5323,  0.8789)],
        ['2HB ', 4, ( 0.3638, -0.5322, -0.8788)],
        [' CG ', 4, (0.6287,1.3862, 0.000)],
        ['1HG ', 5, ( 0.3531, -0.5156,  0.8931)],
        ['2HG ', 5, ( 0.3531, -0.5156, -0.8931)],
        [' CD ', 5, (0.5850,1.3849, 0.000)],
        [' OE1', 6, (0.5752,1.0618, 0.000)],
        [' OE2', 6, (0.5741,-1.0635, 0.000)],
    ],
    [ # 7 gly
        [' N  ', 0, (-0.5272, 1.3593, 0.000)],
        [' CA ', 0, (0.000, 0.000, 0.000)],
        [' C  ', 0, (1.5233, 0.000, 0.000)],
        [' O  ', 3, (0.6303, 1.0574, 0.000)],
        [' H  ', 2, (0.4920,-0.8821,  0.0000)],
        ['1HA ', 0, ( -0.3676, -0.5329,  0.8771)],
        ['2HA ', 0, ( -0.3674, -0.5325, -0.8765)],
    ],
    [ # 8 his
        [' N  ', 0, (-0.5272, 1.3593, 0.000)],
        [' CA ', 0, (0.000, 0.000, 0.000)],
        [' C  ', 0, (1.5233, 0.000, 0.000)],
        [' O  ', 3, (0.6303, 1.0574, 0.000)],
        [' H  ', 2, (0.4920,-0.8821,  0.0000)],
        [' HA ', 0, (-0.3299, -0.5180,  0.9001)],
        [' CB ', 8, (-0.5163,-0.7809,-1.2129)],
        ['1HB ', 4, ( 0.3640, -0.5325,  0.8793)],
        ['2HB ', 4, ( 0.3637, -0.5321, -0.8786)],
        [' CG ', 4, (0.6016,1.3710, 0.000)],
        [' CD2', 5, (0.8918,-1.0184, 0.000)],
        [' CE1', 5, (2.0299,0.8564, 0.000)],
        [' HE1', 5, (2.8542, 1.5693,  0.000)],
        [' HD2', 5, ( 0.6584, -2.0835, 0.000) ],
        [' ND1', 6, (-1.8631, -1.0722,  0.000)],
        [' NE2', 6, (-1.8625,  1.0707, 0.000)],
        [' HE2', 6, (-1.5439,  2.0292, 0.000)],
    ],
    [ # 9 ile
        [' N  ', 0, (-0.5272, 1.3593, 0.000)],
        [' CA ', 0, (0.000, 0.000, 0.000)],
        [' C  ', 0, (1.5233, 0.000, 0.000)],
        [' O  ', 3, (0.6303, 1.0574, 0.000)],
        [' H  ', 2, (0.4920,-0.8821,  0.0000)],
        [' HA ', 0, (-0.3405, -0.5028,  0.9044)],
        [' CB ', 8, (-0.5140,-0.7885,-1.2184)],
        [' HB ', 4, (0.3637, -0.4714,  0.9125)],
        [' CG1', 4, (0.5339,1.4348,0.000)],
        [' CG2', 4, (0.5319,-0.7693,-1.1994)],
        ['1HG2', 4, (1.6215, -0.7588, -1.1842)],
        ['2HG2', 4, (0.1785, -1.7986, -1.1569)],
        ['3HG2', 4, (0.1773, -0.3016, -2.1180)],
        [' CD1', 5, (0.6106,1.3829, 0.000)],
        ['1HG1', 5, (0.3637, -0.5338,  0.8774)],
        ['2HG1', 5, (0.3640, -0.5322, -0.8793)],
        ['1HD1', 5, (1.6978,  1.3006, 0.000)],
        ['2HD1', 5, (0.2873,  1.9236, -0.8902)],
        ['3HD1', 5, (0.2888, 1.9224, 0.8896)],
    ],
    [ # 10 leu
        [' N  ', 0, (-0.5272, 1.3593, 0.000)],
        [' CA ', 0, (0.000, 0.000, 0.000)],
        [' C  ', 0, (1.525, -0.000, -0.000)],
        [' O  ', 3, (0.6303, 1.0574, 0.000)],
        [' H  ', 2, (0.4920,-0.8821,  0.0000)],
        [' HA ', 0, (-0.3435, -0.5040,  0.9027)],
        [' CB ', 8, (-0.5175,-0.7692,-1.2220)],
        ['1HB ', 4, ( 0.3473, -0.5346,  0.8827)],
        ['2HB ', 4, ( 0.3476, -0.5351, -0.8836)],
        [' CG ', 4, (0.6652,1.3823, 0.000)],
        [' CD1', 5, (0.5083,1.4353, 0.000)],
        [' CD2', 5, (0.5079,-0.7600,1.2163)],
        [' HG ', 5, (0.3640, -0.4825, -0.9075)],
        ['1HD1', 5, (1.5984,  1.4353, 0.000)],
        ['2HD1', 5, (0.1462,  1.9496, -0.8903)],
        ['3HD1', 5, (0.1459, 1.9494, 0.8895)],
        ['1HD2', 5, (1.5983, -0.7606,  1.2158)],
        ['2HD2', 5, (0.1456, -0.2774,  2.1243)],
        ['3HD2', 5, (0.1444, -1.7871,  1.1815)],
    ],
    [ # 11 lys
        [' N  ', 0, (-0.5272, 1.3593, 0.000)],
        [' CA ', 0, (0.000, 0.000, 0.000)],
        [' C  ', 0, (1.5233, 0.000, 0.000)],
        [' O  ', 3, (0.6303, 1.0574, 0.000)],
        [' H  ', 2, (0.4920,-0.8821,  0.0000)],
        [' HA ', 0, (-0.3335, -0.5005,  0.9097)],
        ['1HB ', 4, ( 0.3640, -0.5324,  0.8791)],
        ['2HB ', 4, ( 0.3639, -0.5324, -0.8790)],
        [' CB ', 8, (-0.5259,-0.7785,-1.2069)],
        ['1HG ', 5, (0.3641, -0.5229,  0.8852)],
        ['2HG ', 5, (0.3637, -0.5227, -0.8841)],
        [' CG ', 4, (0.6291,1.3869, 0.000)],
        [' CD ', 5, (0.5526,1.4174, 0.000)],
        ['1HD ', 6, (0.3641, -0.5239,  0.8848)],
        ['2HD ', 6, (0.3638, -0.5219, -0.8850)],
        [' CE ', 6, (0.5544,1.4170, 0.000)],
        [' NZ ', 7, (0.5566,1.3801, 0.000)],
        ['1HE ', 7, (0.4199, -0.4638,  0.9482)],
        ['2HE ', 7, (0.4202, -0.4631, -0.8172)],
        ['1HZ ', 7, (1.6223, 1.3980, 0.0658)],
        ['2HZ ', 7, (0.2970,  1.9326, -0.7584)],
        ['3HZ ', 7, (0.2981, 1.9319, 0.8909)],
    ],
    [ # 12 met
        [' N  ', 0, (-0.5272, 1.3593, 0.000)],
        [' CA ', 0, (0.000, 0.000, 0.000)],
        [' C  ', 0, (1.5233, 0.000, 0.000)],
        [' O  ', 3, (0.6303, 1.0574, 0.000)],
        [' H  ', 2, (0.4920,-0.8821,  0.0000)],
        [' HA ', 0, (-0.3303, -0.4990,  0.9108)],
        ['1HB ', 4, ( 0.3635, -0.5318,  0.8781)],
        ['2HB ', 4, ( 0.3641, -0.5326, -0.8795)],
        [' CB ', 8, (-0.5331,-0.7727,-1.2048)],
        ['1HG ', 5, (0.3637, -0.5256,  0.8823)],
        ['2HG ', 5, (0.3638, -0.5249, -0.8831)],
        [' CG ', 4, (0.6298,1.3858,0.000)],
        [' SD ', 5, (0.6953,1.6645,0.000)],
        [' CE ', 6, (0.3383,1.7581,0.000)],
        ['1HE ', 6, (1.7054,  2.0532, -0.0063)],
        ['2HE ', 6, (0.1906,  2.3099, -0.9072)],
        ['3HE ', 6, (0.1917, 2.3792, 0.8720)],
    ],
    [ # 13 phe
        [' N  ', 0, (-0.5272, 1.3593, 0.000)],
        [' CA ', 0, (0.000, 0.000, 0.000)],
        [' C  ', 0, (1.5233, 0.000, 0.000)],
        [' O  ', 3, (0.6303, 1.0574, 0.000)],
        [' H  ', 2, (0.4920,-0.8821,  0.0000)],
        [' HA ', 0, (-0.3303, -0.4990,  0.9108)],
        ['1HB ', 4, ( 0.3635, -0.5318,  0.8781)],
        ['2HB ', 4, ( 0.3641, -0.5326, -0.8795)],
        [' CB ', 8, (-0.5150,-0.7729,-1.2156)],
        [' CG ', 4, (0.6060,1.3746, 0.000)],
        [' CD1', 5, (0.7078,1.1928, 0.000)],
        [' CD2', 5, (0.7084,-1.1920, 0.000)],
        [' CE1', 5, (2.0900,1.1940, 0.000)],
        [' CE2', 5, (2.0897,-1.1939, 0.000)],
        [' CZ ', 5, (2.7809, 0.000, 0.000)],
        [' HD1', 5, (0.1613, 2.1362, 0.000)],
        [' HD2', 5, (0.1621, -2.1360, 0.000)],
        [' HE1', 5, (2.6335,  2.1384, 0.000)],
        [' HE2', 5, (2.6344, -2.1378, 0.000)],
        [' HZ ', 5, (3.8700, 0.000, 0.000)],
    ],
    [ # 14 pro
        [' N  ', 0, (-0.5272, 1.3593, 0.000)],
        [' CA ', 0, (0.000, 0.000, 0.000)],
        [' C  ', 0, (1.5233, 0.000, 0.000)],
        [' O  ', 3, (0.6303, 1.0574, 0.000)],
        [' HA ', 0, (-0.3868, -0.5380,  0.8781)],
        ['1HB ', 4, ( 0.3762, -0.5355,  0.8842)],
        ['2HB ', 4, ( 0.3762, -0.5355, -0.8842)],
        [' CB ', 8, (-0.5649,-0.5888,-1.2966)],
        [' CG ', 4, (0.3657,1.4451,0.0000)],
        [' CD ', 5, (0.3744,1.4582, 0.0)],
        ['1HG ', 5, (0.3798, -0.5348,  0.8830)],
        ['2HG ', 5, (0.3798, -0.5348, -0.8830)],
        ['1HD ', 6, (0.3798, -0.5348,  0.8830)],
        ['2HD ', 6, (0.3798, -0.5348, -0.8830)],
    ],
    [ # 15 ser
        [' N  ', 0, (-0.5272, 1.3593, 0.000)],
        [' CA ', 0, (0.000, 0.000, 0.000)],
        [' C  ', 0, (1.5233, 0.000, 0.000)],
        [' O  ', 3, (0.6303, 1.0574, 0.000)],
        [' H  ', 2, (0.4920,-0.8821,  0.0000)],
        [' HA ', 0, (-0.3425, -0.5041,  0.9048)],
        ['1HB ', 4, ( 0.3637, -0.5321,  0.8786)],
        ['2HB ', 4, ( 0.3636, -0.5319, -0.8782)],
        [' CB ', 8, (-0.5146,-0.7595,-1.2073)],
        [' OG ', 4, (0.5021,1.3081, 0.000)],
        [' HG ', 5, (0.2647, 0.9230, 0.000)],
    ],
    [ # 16 thr
        [' N  ', 0, (-0.5272, 1.3593, 0.000)],
        [' CA ', 0, (0.000, 0.000, 0.000)],
        [' C  ', 0, (1.5233, 0.000, 0.000)],
        [' O  ', 3, (0.6303, 1.0574, 0.000)],
        [' H  ', 2, (0.4920,-0.8821,  0.0000)],
        [' HA ', 0, (-0.3364, -0.5015,  0.9078)],
        [' HB ', 4, ( 0.3638, -0.5006,  0.8971)],
        ['1HG2', 4, ( 1.6231, -0.7142, -1.2097)],
        ['2HG2', 4, ( 0.1792, -1.7546, -1.2237)],
        ['3HG2', 4, ( 0.1808, -0.2222, -2.1269)],
        [' CB ', 8, (-0.5172,-0.7952,-1.2130)],
        [' CG2', 4, (0.5334,-0.7239,-1.2267)],
        [' OG1', 4, (0.4804,1.3506,0.000)],
        [' HG1', 5, (0.3194,  0.9056, 0.000)],
    ],
    [ # 17 trp
        [' N  ', 0, (-0.5272, 1.3593, 0.000)],
        [' CA ', 0, (0.000, 0.000, 0.000)],
        [' C  ', 0, (1.5233, 0.000, 0.000)],
        [' O  ', 3, (0.6303, 1.0574, 0.000)],
        [' H  ', 2, (0.4920,-0.8821,  0.0000)],
        [' HA ', 0, (-0.3436, -0.5042,  0.9031)],
        ['1HB ', 4, ( 0.3639, -0.5323,  0.8790)],
        ['2HB ', 4, ( 0.3638, -0.5322, -0.8787)],
        [' CB ', 8, (-0.5136,-0.7712,-1.2173)],
        [' CG ', 4, (0.5984,1.3741, 0.000)],
        [' CD1', 5, (0.8151,1.0921, 0.000)],
        [' CD2', 5, (0.8753,-1.1538, 0.000)],
        [' CE2', 5, (2.1865,-0.6707, 0.000)],
        [' CE3', 5, (0.6541,-2.5366, 0.000)],
        [' NE1', 5, (2.1309,0.7003, 0.000)],
        [' CH2', 5, (3.0315,-2.8930, 0.000)],
        [' CZ2', 5, (3.2813,-1.5205, 0.000)],
        [' CZ3', 5, (1.7521,-3.3888, 0.000)],
        [' HD1', 5, (0.4722, 2.1252,  0.000)],
        [' HE1', 5, ( 2.9291,  1.3191,  0.000)],
        [' HE3', 5, (-0.3597, -2.9356,  0.000)],
        [' HZ2', 5, (4.3053, -1.1462,  0.000)],
        [' HZ3', 5, ( 1.5712, -4.4640,  0.000)],
        [' HH2', 5, ( 3.8700, -3.5898,  0.000)],
    ],
    [ # 18 tyr
        [' N  ', 0, (-0.5272, 1.3593, 0.000)],
        [' CA ', 0, (0.000, 0.000, 0.000)],
        [' C  ', 0, (1.5233, 0.000, 0.000)],
        [' O  ', 3, (0.6303, 1.0574, 0.000)],
        [' H  ', 2, (0.4920,-0.8821,  0.0000)],
        [' HA ', 0, (-0.3305, -0.4992,  0.9112)],
        ['1HB ', 4, ( 0.3642, -0.5327,  0.8797)],
        ['2HB ', 4, ( 0.3637, -0.5321, -0.8785)],
        [' CB ', 8, (-0.5305,-0.7799,-1.2051)],
        [' CG ', 4, (0.6104,1.3840, 0.000)],
        [' CD1', 5, (0.6936,1.2013, 0.000)],
        [' CD2', 5, (0.6934,-1.2011, 0.000)],
        [' CE1', 5, (2.0751,1.2013, 0.000)],
        [' CE2', 5, (2.0748,-1.2011, 0.000)],
        [' OH ', 5, (4.1408, 0.000, 0.000)],
        [' CZ ', 5, (2.7648, 0.000, 0.000)],
        [' HD1', 5, (0.1485, 2.1455,  0.000)],
        [' HD2', 5, (0.1484, -2.1451,  0.000)],
        [' HE1', 5, (2.6200, 2.1450,  0.000)],
        [' HE2', 5, (2.6199, -2.1453,  0.000)],
        [' HH ', 6, (0.3190, 0.9057,  0.000)],
    ],
    [ # 19 val
        [' N  ', 0, (-0.5272, 1.3593, 0.000)],
        [' CA ', 0, (0.000, 0.000, 0.000)],
        [' C  ', 0, (1.5233, 0.000, 0.000)],
        [' O  ', 3, (0.6303, 1.0574, 0.000)],
        [' H  ', 2, (0.4920,-0.8821,  0.0000)],
        [' HA ', 0, (-0.3497, -0.5068,  0.9002)],
        [' CB ', 8, (-0.5105,-0.7712,-1.2317)],
        [' CG1', 4, (0.5326,1.4252, 0.000)],
        [' CG2', 4, (0.5177,-0.7693,1.2057)],
        [' HB ', 4, (0.3541, -0.4754, -0.9148)],
        ['1HG1', 4, (1.6228,  1.4063,  0.000)],
        ['2HG1', 4, (0.1790,  1.9457, -0.8898)],
        ['3HG1', 4, (0.1798, 1.9453, 0.8903)],
        ['1HG2', 4, (1.6073, -0.7659,  1.1989)],
        ['2HG2', 4, (0.1586, -0.2971,  2.1203)],
        ['3HG2', 4, (0.1582, -1.7976,  1.1631)],
    ],
    [ # 20 unk
        [' N  ', 0, (-0.5272, 1.3593, 0.000)],
        [' CA ', 0, (0.000, 0.000, 0.000)],
        [' C  ', 0, (1.5233, 0.000, 0.000)],
        [' O  ', 3, (0.6303, 1.0574, 0.000)],
        [' H  ', 2, (0.4920,-0.8821,  0.0000)],
        [' HA ', 0, (-0.3341, -0.4928,  0.9132)],
        [' CB ', 8, (-0.5289,-0.7734,-1.1991)],
        ['1HB ', 8, (-0.1265, -1.7863, -1.1851)],
        ['2HB ', 8, (-1.6173, -0.8147, -1.1541)],
        ['3HB ', 8, (-0.2229, -0.2744, -2.1172)],
    ],
    [ # 21 mask
        [' N  ', 0, (-0.5272, 1.3593, 0.000)],
        [' CA ', 0, (0.000, 0.000, 0.000)],
        [' C  ', 0, (1.5233, 0.000, 0.000)],
        [' O  ', 3, (0.6303, 1.0574, 0.000)],
        [' H  ', 2, (0.4920,-0.8821,  0.0000)],
        [' HA ', 0, (-0.3341, -0.4928,  0.9132)],
        [' CB ', 8, (-0.5289,-0.7734,-1.1991)],
        ['1HB ', 8, (-0.1265, -1.7863, -1.1851)],
        ['2HB ', 8, (-1.6173, -0.8147, -1.1541)],
        ['3HB ', 8, (-0.2229, -0.2744, -2.1172)],
    ],
]

In [82]:
# Per residue type, count the atoms whose 4-character name does not have 'H' as
# its second character (i.e. skip names such as ' HA ' or '1HB ').
[len([a for a in x if a[0][1]!='H']) for x in ideal_coords]

[5, 11, 8, 8, 6, 9, 9, 4, 10, 8, 8, 9, 8, 11, 7, 6, 7, 14, 12, 7, 5, 5]

In [75]:
# from https://github.com/cgmartini/martinize.py/

def nsplit(*x):
    """Split every positional string on whitespace; return one token list per argument."""
    groups = []
    for text in x:
        groups.append(text.split())
    return groups

# Masses used as weights when averaging atom positions; martinize() looks them up
# by the first character of the stripped atom name.
# NOTE(review): the purpose of the zero-mass 'M' entry is not visible here.
mass = {'H': 1, 'C': 12, 'N': 14, 'O': 16, 'S': 32, 'P': 31, 'M': 0}

# Atom names that make up the backbone bead of every amino acid.
bb = "N CA C O H H1 H2 H3 O1 O2"
# Residue name -> list of beads, each bead being the list of atom names that are
# averaged into it. For amino acids the first bead is the backbone (ALA/UNK/MAS
# fold CB into it); the remaining beads are side-chain groups.
mapping = {
        "ALA":  nsplit(bb + " CB"),
    
        "UNK":  nsplit(bb + " CB"), # added unknown atoms as alanines
        "MAS":  nsplit(bb + " CB"),
    
        "CYS":  nsplit(bb, "CB SG"),
        "ASP":  nsplit(bb, "CB CG OD1 OD2"),
        "GLU":  nsplit(bb, "CB CG CD OE1 OE2"),
        "PHE":  nsplit(bb, "CB CG CD1 HD1", "CD2 HD2 CE2 HE2", "CE1 HE1 CZ HZ"),
        "GLY":  nsplit(bb),
        "HIS":  nsplit(bb, "CB CG", "CD2 HD2 NE2 HE2", "ND1 HD1 CE1 HE1"),
        "HIH":  nsplit(bb, "CB CG", "CD2 HD2 NE2 HE2", "ND1 HD1 CE1 HE1"),     # Charged Histidine.
        "ILE":  nsplit(bb, "CB CG1 CG2 CD CD1"),
        "LYS":  nsplit(bb, "CB CG CD", "CE NZ HZ1 HZ2 HZ3"),
        "LEU":  nsplit(bb, "CB CG CD1 CD2"),
        "MET":  nsplit(bb, "CB CG SD CE"),
        "ASN":  nsplit(bb, "CB CG ND1 ND2 OD1 OD2 HD11 HD12 HD21 HD22"),
        "PRO":  nsplit(bb, "CB CG CD"),
        "HYP":  nsplit(bb, "CB CG CD OD"),
        "GLN":  nsplit(bb, "CB CG CD OE1 OE2 NE1 NE2 HE11 HE12 HE21 HE22"),
        "ARG":  nsplit(bb, "CB CG CD", "NE HE CZ NH1 NH2 HH11 HH12 HH21 HH22"),
        "SER":  nsplit(bb, "CB OG HG"),
        "THR":  nsplit(bb, "CB OG1 HG1 CG2"),
        "VAL":  nsplit(bb, "CB CG1 CG2"),
        "TRP":  nsplit(bb, "CB CG CD2", "CD1 HD1 NE1 HE1 CE2", "CE3 HE3 CZ3 HZ3", "CZ2 HZ2 CH2 HH2"),
        "TYR":  nsplit(bb, "CB CG CD1 HD1", "CD2 HD2 CE2 HE2", "CE1 HE1 CZ OH HH"),
        # DNA nucleotides; these entries do not include the amino-acid backbone bead.
        "DA": nsplit("P OP1 OP2 O5' O3' O1P O2P", "C5' O4' C4'", "C3' C2' C1'", "N9 C4", "C8 N7 C5", "C6 N6 N1", "C2 N3"),
        "DG": nsplit("P OP1 OP2 O5' O3' O1P O2P", "C5' O4' C4'", "C3' C2' C1'", "N9 C4", "C8 N7 C5", "C6 O6 N1", "C2 N2 N3"),
        "DC": nsplit("P OP1 OP2 O5' O3' O1P O2P", "C5' O4' C4'", "C3' C2' C1'", "N1 C6", "C5 C4 N4", "N3 C2 O2"),
        "DT": nsplit("P OP1 OP2 O5' O3' O1P O2P", "C5' O4' C4'", "C3' C2' C1'", "N1 C6", "C5 C4 O4 C7 C5M", "N3 C2 O2"),
}
# Residue name -> bead type labels; martinize() returns this list next to the
# bead coordinates it builds from `mapping`, so entries are expected to be in the
# same bead order as `mapping`.
# NOTE(review): "HIH" and "HYP" exist in `mapping` but have no entry here, while
# the RNA keys "A", "G", "C", "U" below have no `mapping` entry.
pseudoatom_types = {
        "ALA":  ['P4'],
    
        "UNK":  ['P4'],
        "MAS":  ['P4'],
    
        "CYS":  ['P5','C5'],
        "ASP":  ['P5','Qa'],
        "GLU":  ['P5','Qa'],
        "PHE":  ['P5','SC4','SC4','SC4'],
        "GLY":  ['P5'],
        "HIS":  ['P5','SC4','SP1','SP1'],
        "ILE":  ['P5','AC1'],
        "LYS":  ['P5','C3','P1'],
        "LEU":  ['P5','AC1'],
        "MET":  ['P5','C5'],
        "ASN":  ['P5','P5'],
        "PRO":  ['P5','AC2'],
        "GLN":  ['P5','P4'],
        "ARG":  ['P5','N0','Qd'],
        "SER":  ['P5','P1'],
        "THR":  ['P5','P1'],
        "VAL":  ['P5','AC2'],
        "TRP":  ['P5','SC4','SP1','SC4','SC4'],
        "TYR":  ['P5','SC4','SC4','SP1'],
        "DA": ["Q0","SN0","SC2", "TN0", "TA2", "TA3", "TNa"],
        "DG": ["Q0","SN0","SC2", "TN0", "TG2", "TG3", "TNa"],
        "DC": ["Q0","SN0","SC2", "TN0", "TY2", "TY3"],
        "DT": ["Q0","SN0","SC2", "TN0", "TT2", "TT3"],
        "A": ["Q0","SN0","SNda", "TN0", "TA2", "TA3", "TNa"],
        "G": ["Q0","SN0","SNda", "TN0", "TG2", "TG3", "TNa"],
        "C": ["Q0","SN0","SNda", "TN0", "TY2", "TY3"],
        "U": ["Q0","SN0","SNda", "TN0", "TT2", "TT3"],
}


In [90]:
# Distinct bead types used by the residue names that appear in num2aa
# (the nucleotide entries of pseudoatom_types are filtered out).
set([y for x in pseudoatom_types for y in pseudoatom_types[x] if x in num2aa])

{'AC1', 'AC2', 'C3', 'C5', 'N0', 'P1', 'P4', 'P5', 'Qa', 'Qd', 'SC4', 'SP1'}

In [77]:
def martinize(seq, atoms, coords):
    """Coarse-grain residues into pseudoatoms (beads) using the `mapping` table.

    Each bead position is the mass-weighted mean of the coordinates of the atoms
    that `mapping` assigns to that bead; atoms not listed in any bead are ignored.

    Args:
        seq: residue names, one per residue (keys of `mapping` / `pseudoatom_types`).
        atoms: per residue, the atom names (surrounding whitespace is stripped).
        coords: per residue, the (x, y, z) coordinates matching `atoms`.

    Returns:
        (ps_coords, ps_types): per residue, a list of [x, y, z] bead centres in
        the bead order of `mapping`, and the `pseudoatom_types` entry for it.

    Raises:
        KeyError: unknown residue name, or an element letter missing from `mass`.
        ZeroDivisionError: a bead of some residue received no atoms (zero mass).
    """
    ps_coords=[]
    ps_types=[]
    for aa, names, xyzs in zip(seq, atoms, coords):
        beads=mapping[aa]
        # One independent accumulator per bead. Building them with `[...]*n` would
        # make every bead share the same inner list, so all atoms would be summed
        # into one centre and every bead would get identical coordinates.
        sums=[[0.0, 0.0, 0.0] for _ in beads]
        totals=[0.0 for _ in beads]
        for name, xyz in zip(names, xyzs):
            name=name.strip()
            for j, members in enumerate(beads):
                if name in members:
                    # Element letter = first character after any leading digits
                    # (names such as '1HB' start with a digit).
                    w=mass[name.lstrip('0123456789')[0]]
                    sums[j][0]+=xyz[0]*w
                    sums[j][1]+=xyz[1]*w
                    sums[j][2]+=xyz[2]*w
                    totals[j]+=w
        centres=[]
        for j, (acc, total) in enumerate(zip(sums, totals)):
            if total==0:
                # Same exception type as the bare division, with context added.
                raise ZeroDivisionError(
                    f"residue {aa!r}: no atoms with non-zero mass found for bead {j} {beads[j]}")
            centres.append([acc[0]/total, acc[1]/total, acc[2]/total])
        ps_coords.append(centres)
        ps_types.append(pseudoatom_types[aa])
    return ps_coords, ps_types

In [None]:
# Run martinize on the ideal_coords table: one "residue" per entry of num2aa, with
# atom names taken from column 0 and coordinates from column 2 of each record.
ps, ts=martinize(num2aa, 
                 [[y[0] for y in x] for x in ideal_coords],
                 [[y[2] for y in x] for x in ideal_coords] )

# Show residue name, bead coordinates and bead types side by side.
print(list(zip(num2aa, ps, ts)))

In [97]:
# Sanity check: (2 * 0.24145 / (4 * 5.6)) ** (1/6). The result agrees to about 1e-6
# with 2**(1/6)*0.47 from the next cell.
# NOTE(review): presumably a Lennard-Jones minimum distance recovered from a C6
# coefficient (0.24145) and epsilon (5.6) -- confirm against the force-field file.
(0.24145E-00/4/5.6*2)**(1/6)

0.5275555519127099

In [98]:
# 2**(1/6) times 0.47 (presumably the Lennard-Jones minimum for sigma = 0.47 --
# TODO confirm units); compare with the value of the previous cell.
2**(1/6)*0.47

0.5275571627054053

### Write functions to get bb atoms from protein structures

In [307]:
# Re-import the local `data` module so that edits to AtomSurfaceDataset take effect.
# Unlike the setup cell, this raises KeyError if `data` has not been imported yet.
importlib.reload(sys.modules['data'])
from data import AtomSurfaceDataset


In [308]:
class same_dict:
    """Identity 'encoder': offers dict-style lookup but maps every key to itself."""

    def __getitem__(self, idx):
        # Subscript access hands the key straight back.
        return idx

    def get(self, idx, d=None):
        # `d` is accepted for dict compatibility only; the key is always "found".
        return self[idx]
    
class ReshapeBB:
    """Pre-transform that reduces per-atom data to per-residue backbone coordinates.

    The item is expected to be a dict of tensors indexed by atom, with at least
    'atom_resid', 'atom_type' and 'atom_xyz'. The dict is modified in place.
    """

    def __init__(self,encoder=None):
        # encoder maps backbone atom names to the integer codes used in 'atom_type'.
        # A fresh dict per instance avoids sharing one mutable default object.
        self.encoder={'N': 0, 'CA': 1, 'C': 2} if encoder is None else encoder

    def __call__(self,data):
        """Sort atoms by residue id, drop incomplete residues and stack N/CA/C.

        Returns the same dict with every key containing 'atom' removed and two
        keys added: 'seq' (sorted unique residue ids that were kept) and
        'bb_xyz' (tensor [n_residues, 3, 3] holding N, CA, C coordinates).

        Raises:
            AssertionError: if the remaining atoms are not exactly 3 per residue.
        """
        # Order every per-atom tensor by residue id.
        resid_sorted, idx=torch.sort(data['atom_resid'])
        for key in data:
            data[key]=data[key][idx]

        # Count atoms per residue. The previous check, len(mask) < 3, measured the
        # length of the boolean mask (the total atom count), so incomplete residues
        # were never removed.
        resids, inverse, counts=torch.unique(resid_sorted, return_inverse=True, return_counts=True)
        complete=counts>=3
        if not bool(complete.all()):
            # Build the atom mask once, before any tensor (including 'atom_resid'
            # itself) is filtered, so that all keys are cut with the same mask.
            keep=complete[inverse]
            for key in data:
                data[key]=data[key][keep]
        # Only the residues that survived the filter belong to the sequence.
        seq=resids[complete]
        assert len(seq)*3==len(data['atom_resid']), \
            'every kept residue must contribute exactly 3 backbone atoms'

        bb_xyz=torch.stack((data['atom_xyz'][data['atom_type']==self.encoder['N']],
                            data['atom_xyz'][data['atom_type']==self.encoder['CA']],
                            data['atom_xyz'][data['atom_type']==self.encoder['C']]),
                           dim=1)
        # Per-atom fields are replaced by the per-residue representation.
        for key in list(data.keys()):
            if 'atom' in key:
                data.pop(key)
        data['seq']=seq
        data['bb_xyz']=bb_xyz
        return data

In [309]:
# Read the structure identifiers (one per line) and turn each into a PDB file
# name. Blank lines are skipped so that they do not become a bogus '.pdb' entry.
with open('prot_nrd.txt','r') as f:
    a=[x.strip()+'.pdb' for x in f if x.strip()]

In [None]:
# Build the backbone-only dataset stored in 'protein_data_bb.pkl'.
# NOTE(review): the comments below describe the encoders as read from this call;
# how AtomSurfaceDataset applies them (including the '-' key, presumably a
# fallback for names that are not listed) is defined in the `data` module -- confirm.
dataset=AtomSurfaceDataset(storage='protein_data_bb.pkl', encoders={
    # Atom names are encoded twice: as a backbone atom code ('atom_type') and as
    # a 0/1 flag ('mask') that is 1 only for N, CA and C.
    'atom_names':[{'name': 'atom_type',
                      'encoder': {'N': 0, 'CA': 1, 'C': 2, '-': 0}
                     },
                  {'name': 'mask',
                      'encoder': {'N': 1, 'CA': 1, 'C': 1, '-': 0}
                     }],
    # Residue names -> index in num2aa, with '-' appended as one extra class.
    'atom_resnames':[{'name': 'atom_resname',
                      'encoder': {a: i for i, a in enumerate(num2aa+['-'])}
                     }
                    ],
    # Residue ids are passed through unchanged by the identity encoder.
    'atom_resids':[{'name': 'atom_resid',
                      'encoder': same_dict()
                     }
                    ]},
                    # `a` is the list of '.pdb' file names built in the previous cell.
                    list=a,
                    pre_transform=ReshapeBB())

 39%|███████████████▍                        | 900/2337 [12:29<23:15,  1.03it/s]

Failed to load 3OQ5_D.pdb


 62%|████████████████████████▎              | 1455/2337 [22:26<11:22,  1.29it/s]