In [1]:
[a for a in [1,2,3] if a not in {'OX'}]

[1, 2, 3]

In [2]:
import torch as th
import itertools
import atomium
import pandas as pd
import params

In [5]:
aa_atoms = ['N','C','O','CA','CB','CG','CD','CE','CZ','OD','NH','NE','OG','ND','SG','OE','CH','NZ','OH','SD','OX']

#atom_id = {ch : i for i, ch in enumerate(charge.keys())}
AT_INT = {k : i for i, k in enumerate(aa_atoms)}
INT_AT = {v:k for k, v in AT_INT.items()}

at_at_list = [[f'{at1}-{at2}' for at1 in aa_atoms] for at2 in aa_atoms]
AT_AT_name = list(itertools.chain(*at_at_list))
ATAT_INT = {k : i for i,k in enumerate(AT_AT_name)}
num_atat_feats = len(ATAT_INT)


In [6]:
HYDROPHOBIC = {'ALA', 'VAL', 'LEU', 'ILE', 'MET', 'PHE', 'TRP', 'RPO', 'TYR'}
AROMATIC = {'TRP', 'TYR', 'PHE'}
CATION_PI = {'CG', 'CD', 'CE', 'CZ'}
SALT_BRIDGE_C1 = {'NH', 'NE'}
SALT_BRIDGE_C2 = {'OE', 'OE'}
HYDROGEN_ACCEPTOR = {'NH', 'OH'}
HYDROGEN_DONOR = {'F', 'N', 'O'}

In [20]:
'''script for calculating residue - residue interactions'''
import atomium
import torch as th
from params import (HYDROPHOBIC,
                     AROMATIC,
                     CATION_PI,
                     SALT_BRIDGE_C1,
                     SALT_BRIDGE_C2,
                     CHARGE,
                    VDW_RADIUS,
                    HYDROGEN_ACCEPTOR,
                    HYDROGEN_DONOR)


nan_type = float('nan')
atom_id = {ch : i for i, ch in enumerate(CHARGE.keys())}
EPSILON = th.Tensor([78*10e-2]) # unit Farad / angsterm

t = 8
data = atomium.fetch('2lyz')
data = data.model

atoms = []
name = []
ca_xyz, cb_xyz = [], []
residues = []
residues_name = []
is_side_chain = []
res_at_num = []
for chain in data.chains():
    for i, res in enumerate(chain.residues()):
        r_at_name = [r.name for r in res.atoms()]
        res_at_num.append(len(r_at_name))
        for atom in res.atoms():
            n = atom.name
            if n == 'CA':
                ca_xyz.append(atom.location)
            elif n == 'CB':
                cb_xyz.append(atom.location)
            elif len(n) == 3:
                n = n[:2]
            name.append(n)
            is_side_chain.append(atom.is_side_chain)
            atoms.append(atom.location)
            residues.append(i)
            residues_name.append(res.name)
        if 'CB' not in r_at_name:
            cb_xyz.append((nan_type, nan_type, nan_type))
        if 'CA' not in r_at_name:
            raise KeyError('missing CA atom')

name_base = [n[0] for n in name]
at_charge = [CHARGE[n] for n in name_base]
at_vdw = [VDW_RADIUS[n] for n in name_base]
atom_arr = [atom_id[n] for n in name_base]

res_id = th.LongTensor(residues)
res_xyz = th.FloatTensor(ca_xyz)
res_dist = th.cdist(res_xyz, res_xyz)
res_cb = th.FloatTensor(cb_xyz)

is_res_hf = [True if r in HYDROPHOBIC else False for r in residues_name]
#is_at_rg = [True if at is False else False for at in is_side_chain ]
is_at_hb_a = [True if r in HYDROGEN_ACCEPTOR else False for r in name]
is_at_hb_d = [True if r in HYDROGEN_DONOR else False for r in name]
is_res_ar = [True if r in AROMATIC else False for r in residues_name]

is_res_cpi = [True if at in CATION_PI else False for at in name]
is_res_arg = [True if r in {'ARG'} else False for r in residues_name]

is_at_sb_c1 = [True if at in SALT_BRIDGE_C1 else False for at in name]
is_res_sb_c1 = [True if at in {'ARG', 'LYS'} else False for at in residues_name]
is_at_sb_c2 = [True if at in {'ARG', 'GLU'} else False for at in name]
is_res_sb_c2 = [True if at in SALT_BRIDGE_C2 else False for at in residues_name]

at_xyz = th.FloatTensor(atoms)
at_dist = th.cdist(at_xyz, at_xyz)
at_id = th.LongTensor(atom_arr)
sigma = th.FloatTensor(at_vdw)
at_is_side = th.BoolTensor(is_side_chain)
at_is_hba = th.BoolTensor(is_at_hb_a)
at_is_hbd = th.BoolTensor(is_at_hb_d)

at_dist_inv = 1/(at_dist + 1e-6)
# set inverse of the atom self distance to zero to avoid nan/inf when summing
at_dist_inv.fill_diagonal_(0) 
atat_charge = th.FloatTensor(at_charge).view(-1, 1)
atat_charge = atat_charge * atat_charge.view(1, -1)
sigma = sigma.view(-1, 1) * sigma.view(1, -1)

lj_r = sigma*at_dist_inv * (at_dist < 10)
lj6 = th.pow(lj_r, 6) 
lj12 = th.pow(lj_r, 12)
#print('lj', lj_r)
#print('lj12', lj12)
#print('lj12', lj6)
disulfde = (at_id == 4) & (at_dist < 2.2)
hydrophobic = (at_dist < 5.0) & (at_is_side == False) & th.BoolTensor(is_res_hf)
cation_pi = (at_dist < 6) & th.BoolTensor(is_res_cpi)
arg_arg = (at_dist < 5.0) & th.BoolTensor(is_res_arg)
hbond = at_is_hba.view(-1, 1) & at_is_hba.view(1, -1)

sb_tmp1 = th.BoolTensor(is_at_sb_c1).view(-1, 1) & th.BoolTensor(is_at_sb_c2).view(1, -1)
sb_tmp2 = th.BoolTensor(is_res_sb_c1).view(-1, 1) & th.BoolTensor(is_res_sb_c2).view(1, -1)

salt_bridge = sb_tmp1 & (at_dist < 5.0) & sb_tmp2

feats = th.cat((disulfde.unsqueeze(2),
               hydrophobic.unsqueeze(2),
               cation_pi.unsqueeze(2),
               arg_arg.unsqueeze(2),
               salt_bridge.unsqueeze(2),
               hbond.unsqueeze(2)), dim=2)
feats = feats.float()
coulomb_energy =  (1/3.14*EPSILON) * atat_charge * at_dist_inv
lenard_jones_energy = 1e-2 * (lj12 - lj6 )
energy_sum = coulomb_energy + lenard_jones_energy
feats = th.cat((feats, 
                coulomb_energy.unsqueeze(2),
               lenard_jones_energy.unsqueeze(2),
               energy_sum.unsqueeze(2)),
               dim=2)
# change feature resolution
# from atomic level to residue level
efeat_list = list()
first_dim_split = feats.split(res_at_num, 0)
for i in range(len(res_at_num)):
    efeat_list.extend(list(first_dim_split[i].split(res_at_num, 1)))

u, v = th.where(res_dist < t)
uv = th.where(res_dist.ravel() < t)[0]
feats_at = th.cat([efeat_list[e].sum((0,1), keepdim=True) for e in uv], dim=0)
efeats = th.zeros_like(res_dist)
# gather residue level feature, such as edge criteria
cb1 = th.linalg.norm(res_cb - res_xyz, dim=1, keepdim=True)
cb2 = cb1.clone().swapdims(0, 1)
tn_cb12 = cb1 / (cb2 + 1e-2)
tn_cb12[th.isnan(tn_cb12)] = -1
inv_ca12 = 1/(res_dist + 1e-5)
inv_ca12.fill_diagonal_(0)
res_id_short = th.arange(0, res_id.max()+1, 1)
is_seq = th.abs(res_id_short.unsqueeze(0) - res_id_short.unsqueeze(1))
is_self = is_seq == 0
is_seq_0 = is_seq == 1
is_seq_1 = is_seq == 2
is_struct_0 = ~is_seq_0
feats_res = th.cat((tn_cb12.unsqueeze(2),
                   inv_ca12.unsqueeze(2),
                    is_self.unsqueeze(2),
                   is_seq_0.unsqueeze(2),
                   is_seq_1.unsqueeze(2),
                   is_struct_0.unsqueeze(2)), dim=2)
feats_res = feats_res[u,v]
feats_all = th.cat((feats_at.squeeze(), feats_res), dim=-1)

In [24]:
uv

tensor([    0,     1,     2,  ..., 16638, 16639, 16640])

In [21]:
feats_all.shape

torch.Size([1387, 15])

In [None]:
is_res_hf = [True if r in HYDROPHOBIC else False for r in residues_name]
#is_at_rg = [True if at is False else False for at in is_side_chain ]
is_at_hb_a = [True if r in HYDROGEN_ACCEPTOR else False for r in name]
is_at_hb_d = [True if r in HYDROGEN_DONOR else False for r in name]

is_res_ar = [True if r in AROMATIC else False for r in residues_name]

is_res_cpi = [True if at in CATION_PI else False for at in name]

is_res_arg = [True if r in {'ARG'} else False for r in residues_name]

is_at_sb_c1 = [True if at in SALT_BRIDGE_C1 else False for at in name]
is_res_sb_c1 = [True if at in {'ARG', 'LYS'} else False for at in residues_name]
is_at_sb_c2 = [True if at in {'ARG', 'GLU'} else False for at in name]
is_res_sb_c2 = [True if at in SALT_BRIDGE_C2 else False for at in residues_name]

In [11]:
res_id = th.LongTensor(residues)
res_xyz = th.FloatTensor(ca_xyz)
res_dist = th.cdist(res_xyz, res_xyz)
res_cb = th.FloatTensor(cb_xyz)

In [16]:
ca_xyz

[(2.386, 10.407, 9.247),
 (2.387, 13.773, 7.48),
 (-1.159, 15.139, 7.344),
 (-2.756, 17.229, 4.601),
 (-4.131, 20.653, 5.5),
 (-7.708, 19.507, 4.913),
 (-7.368, 16.019, 6.393),
 (-6.046, 17.23, 9.745),
 (-8.771, 19.878, 9.908),
 (-11.387, 17.163, 9.423),
 (-9.749, 14.968, 12.051),
 (-9.266, 17.634, 14.717),
 (-12.939, 18.423, 14.147),
 (-13.997, 14.793, 14.524),
 (-12.024, 14.969, 17.772),
 (-13.836, 17.826, 19.506),
 (-11.464, 20.717, 18.811),
 (-13.423, 23.137, 16.628),
 (-14.105, 25.897, 19.156),
 (-12.733, 24.019, 22.165),
 (-12.235, 26.657, 24.859),
 (-13.192, 29.142, 22.143),
 (-10.459, 28.582, 19.56),
 (-11.578, 28.139, 15.955),
 (-10.558, 24.97, 14.12),
 (-8.206, 27.147, 12.075),
 (-6.336, 28.276, 15.184),
 (-5.416, 24.609, 15.552),
 (-4.46, 23.679, 11.995),
 (-2.381, 26.86, 11.989),
 (-0.472, 26.09, 15.183),
 (0.226, 22.552, 13.989),
 (1.503, 23.646, 10.582),
 (4.058, 26.051, 12.04),
 (4.851, 23.738, 14.949),
 (5.359, 20.464, 13.089),
 (4.475, 20.822, 9.415),
 (1.605, 18.366, 

In [10]:
name_base = [n[0] for n in name]
at_charge = [charge[n] for n in name_base]
at_vdw = []
atom_arr = [atom_id[n] for n in name_base]

at_xyz = th.FloatTensor(atoms)
at_dist = th.cdist(at_xyz, at_xyz)
at_id = th.LongTensor(atom_arr)
at_is_side = th.BoolTensor(is_side_chain)
at_is_hba = th.BoolTensor(is_at_hb_a)
at_is_hbd = th.BoolTensor(is_at_hb_d)

res_id = th.LongTensor(residues)
res_xyz = th.FloatTensor(ca_xyz)
res_dist = th.cdist(res_xyz, res_xyz)
res_cb = th.FloatTensor(cb_xyz)

NameError: name 'charge' is not defined

In [9]:
res_xyz

NameError: name 'res_xyz' is not defined

In [None]:
atat_charge = th.FloatTensor(at_charge).view(-1, 1)
atat_charge = atat_charge * th.FloatTensor(at_charge).view(1, -1)

at_dist_inv = 1e-6/th.pow(at_dist + 1e-2, 2)

In [None]:
disulfde = (at_id == 4) & (at_dist < 2.2)
hydrophobic = (at_dist < 5.0) & (at_is_side == False) & th.BoolTensor(is_res_hf)
cation_pi = (at_dist < 6) & th.BoolTensor(is_res_cpi)
arg_arg = (at_dist < 5.0) & th.BoolTensor(is_res_arg)
hbond = at_is_hba,view(-1, 1) & at_is_hba.view(1, -1)

sb_tmp1 = th.BoolTensor(is_at_sb_c1).view(-1, 1) & th.BoolTensor(is_at_sb_c2).view(1, -1)
sb_tmp2 = th.BoolTensor(is_res_sb_c1).view(-1, 1) & th.BoolTensor(is_res_sb_c2).view(1, -1)

salt_bridge = sb_tmp1 & (at_dist < 5.0) & sb_tmp2

feats = th.cat((disulfde.unsqueeze(2),
               hydrophobic.unsqueeze(2),
               cation_pi.unsqueeze(2),
               arg_arg.unsqueeze(2),
               salt_bridge.unsqueeze(2),
               hbond.unsqueeze(2)), dim=2)
feats = feats.float()
coulomb_force = at_dist_inv * atat_charge
feats = th.cat((feats, 
                coulomb_force.unsqueeze(2)), dim=2)

In [None]:
efeat_list = list()
first_dim_split = feats.split(res_at_num, 0)
for i in range(len(res_at_num)):
    efeat_list.extend(list(first_dim_split[i].split(res_at_num, 1)))

In [None]:
u, v = th.where(res_dist < 8)
uv = th.where(res_dist.ravel() < 8)[0]

In [None]:
feats_at = th.cat([efeat_list[e].sum((0,1), keepdim=True) for e in uv], dim=0)

In [None]:
feats_at.shape

In [None]:
efeats = th.zeros_like(res_dist)


In [None]:
cb1 = th.linalg.norm(res_cb - res_xyz, dim=1, keepdim=True)
cb2 = cb1.clone().swapdims(0, 1)
tn_cb12 = cb1 / (cb2 + 1e-2)
tn_cb12[th.isnan(tn_cb12)] = -1

inv_ca12 = 1/(res_dist - 1e-3)

res_id_short = th.arange(0, res_id.max()+1, 1)
is_seq = th.abs(res_id_short.unsqueeze(0) - res_id_short.unsqueeze(1))
is_self = is_seq == 0
is_seq_0 = is_seq == 1
is_seq_1 = is_seq == 2
is_struct_0 = ~is_seq_0


feats_res = th.cat((tn_cb12.unsqueeze(2),
                   inv_ca12.unsqueeze(2),
                    is_self.unsqueeze(2),
                   is_seq_0.unsqueeze(2),
                   is_seq_1.unsqueeze(2),
                   is_struct_0.unsqueeze(2)), dim=2)
feats_res = feats_res[u,v]

In [None]:
feats_all = th.cat((feats_at.squeeze(), feats_res), dim=-1)

In [None]:
feats_all.mean(0)

In [None]:
dist_where = (dist < 2)
res1, res2 = th.nonzero(dist_where, as_tuple=True)

In [None]:
ca_adj = res_dist < 8

In [None]:
th.index_copy_()

In [None]:
aa1, aa2 = res_num[res1], res_num[res2]
at1, at2 = atom_id[res1], atom_id[res2]
inter = aa1 != aa2
inter1, inter2 = aa1[inter], aa2[inter]

In [None]:
print('num:', inter.shape)
bond_list = list()
bond_dict = {i : [] for i in range(num_resid)}
for i in range(inter1.shape[0]-1):
    rid1, rid2 = inter1[i].item(), inter2[i].item()
    atid1, atid2 = at1[i].item(), at2[i].item()
    #print(rid1, ' - ', rid2, ': ', INT_AT[atid1], '-',  INT_AT[atid2])
    atat_name = INT_AT[atid1]+'-'+INT_AT[atid2]
    bond_list.append(atat_name)
    bond_dict[rid1].append(atat_name)

In [None]:


def bond_to_vector(bond_dict):
    stack = list()
    for res, bond_list in bond_dict.items():
        
        bonds_unique = set(bond_list)
        bonds = th.zeros(params.num_atat_feats)
        
        if bonds_unique:
            
            bonds_id = [params.ATAT_INT[b] for b in bonds_unique if b in params.ATAT_INT]
            bonds[bonds_id] = 1
        stack.append(bonds.unsqueeze(0))
    feats = th.cat(stack, 0)
    return feats