In [1]:
import sys
pl_path = "/mnt/c/Users/mm851/PycharmProjects/ProteinLearning"
if pl_path not in sys.path:
    sys.path.append(pl_path)

In [2]:
from protein_learning.common.data.datasets.protein_dataset import ProteinDataset
from protein_learning.common.protein_constants import ALL_ATOMS
from protein_learning.models.sc_packing.sc_feature_generator import SCFeatureGenerator
from protein_learning.common.data.protein import Protein
from protein_learning.protein_utils.sidechain_utils import (
    align_symmetric_sidechains,
    per_residue_chi_indices_n_mask,

)
from protein_learning.common.data.model_data import ExtraInput
from protein_learning.common.helpers import batched_index_select, exists
from protein_learning.protein_utils.dihedral.angle_utils import signed_dihedral_4
from einops import repeat, rearrange # noqa
from typing import Any, NamedTuple, Optional
from torch import Tensor
import torch



In [3]:
class ExtraSCInput(ExtraInput):
    def __init__(
            self,
            alt_truth_atom_coords: Tensor,
            chi_indices: Optional[Tensor],
            chi_mask: Optional[Tensor],
    ):
        self.alt_truth_atom_coords = alt_truth_atom_coords
        self.chi_mask = chi_mask
        self.chi_indices = chi_indices

    def get_sc_dihedral(self, coords: Tensor) -> Tensor:
        m = self.chi_mask.shape[0]
        assert m <= 4, f"{self.chi_mask.shape}"
        rep_coords = repeat(coords, "b n a c -> (m b) n a c", m=m)
        assert rep_coords.shape[:2] == self.chi_indices.shape[:2]
        assert self.chi_indices.shape[2] == 4
        select_coords = batched_index_select(rep_coords, self.chi_indices, dim=2)
        assert select_coords.shape[:3] == self.chi_indices.shape


    def crop(self, start: int, end: int):
        self.alt_truth_atom_coords = self.alt_truth_atom_coords

    def to(self, device: Any):
        self.alt_truth_atom_coords = self.alt_truth_atom_coords.to(device)


In [4]:
def get_sc_dihedral(coords: Tensor, chi_mask: Tensor, chi_indices: Tensor) -> Tensor:
    m = chi_mask.shape[0]
    assert m <= 4, f"{chi_mask.shape}"
    rep_coords = repeat(coords, "b n a c -> (m b) n a c", m=m)
    print(rep_coords.shape)
    assert rep_coords.shape[:2] == chi_indices.shape[:2]
    assert chi_indices.shape[2] == 4
    select_coords = batched_index_select(rep_coords, chi_indices, dim=2)
    assert select_coords.shape[:3] == chi_indices.shape
    print(select_coords.shape, chi_indices.shape)
    select_coords = rearrange(select_coords,"b n a c ->  a b n c")
    return signed_dihedral_4(ps=select_coords)

In [15]:
seq = "ARNDCQEGHILKMFPSTWYV"
n,a = len(seq),36
atom_masks = torch.randn(1,n,a)>0
coords = torch.randn(1,n,a,3)
chi_indices, chi_mask = per_residue_chi_indices_n_mask(
    coord_mask=atom_masks, seq=seq
)
get_sc_dihedral(coords, chi_mask, chi_indices).shape

torch.Size([4, 20, 36, 3])
torch.Size([4, 20, 4, 3]) torch.Size([4, 20, 4])


torch.Size([4, 20])

In [6]:
s="""
    [0.0, 0.0, 0.0, 0.0],  # ALA
    [0.0, 0.0, 0.0, 0.0],  # ARG
    [0.0, 0.0, 0.0, 0.0],  # ASN
    [0.0, 1.0, 0.0, 0.0],  # ASP
    [0.0, 0.0, 0.0, 0.0],  # CYS
    [0.0, 0.0, 0.0, 0.0],  # GLN
    [0.0, 0.0, 1.0, 0.0],  # GLU
    [0.0, 0.0, 0.0, 0.0],  # GLY
    [0.0, 0.0, 0.0, 0.0],  # HIS
    [0.0, 0.0, 0.0, 0.0],  # ILE
    [0.0, 0.0, 0.0, 0.0],  # LEU
    [0.0, 0.0, 0.0, 0.0],  # LYS
    [0.0, 0.0, 0.0, 0.0],  # MET
    [0.0, 1.0, 0.0, 0.0],  # PHE
    [0.0, 0.0, 0.0, 0.0],  # PRO
    [0.0, 0.0, 0.0, 0.0],  # SER
    [0.0, 0.0, 0.0, 0.0],  # THR
    [0.0, 0.0, 0.0, 0.0],  # TRP
    [0.0, 1.0, 0.0, 0.0],  # TYR
    [0.0, 0.0, 0.0, 0.0],  # VAL
    [0.0, 0.0, 0.0, 0.0],  # UNK
"""
s = s.split("\n")
for line in s:
    line = line.strip()
    if len(line)>0:
        arr,name = line.split("#")
        print(f"'{name.strip()}':{arr[:-1]}")

'ALA':[0.0, 0.0, 0.0, 0.0], 
'ARG':[0.0, 0.0, 0.0, 0.0], 
'ASN':[0.0, 0.0, 0.0, 0.0], 
'ASP':[0.0, 1.0, 0.0, 0.0], 
'CYS':[0.0, 0.0, 0.0, 0.0], 
'GLN':[0.0, 0.0, 0.0, 0.0], 
'GLU':[0.0, 0.0, 1.0, 0.0], 
'GLY':[0.0, 0.0, 0.0, 0.0], 
'HIS':[0.0, 0.0, 0.0, 0.0], 
'ILE':[0.0, 0.0, 0.0, 0.0], 
'LEU':[0.0, 0.0, 0.0, 0.0], 
'LYS':[0.0, 0.0, 0.0, 0.0], 
'MET':[0.0, 0.0, 0.0, 0.0], 
'PHE':[0.0, 1.0, 0.0, 0.0], 
'PRO':[0.0, 0.0, 0.0, 0.0], 
'SER':[0.0, 0.0, 0.0, 0.0], 
'THR':[0.0, 0.0, 0.0, 0.0], 
'TRP':[0.0, 0.0, 0.0, 0.0], 
'TYR':[0.0, 1.0, 0.0, 0.0], 
'VAL':[0.0, 0.0, 0.0, 0.0], 
'UNK':[0.0, 0.0, 0.0, 0.0], 


In [7]:
AA_ALPHABET = "ARNDCQEGHILKMFPSTWYV-"
N_AMINO_ACID_KEYS = 21
BB_ATOMS = ['N', 'CA', 'C', 'O']
BB_ATOM_POSNS = {a: i for i, a in enumerate(BB_ATOMS)}
SC_ATOMS = ['CE3', 'CZ', 'SD', 'CB', 'CD1', 'NH1', 'OG1', 'CE1', 'OE1', 'CZ2', 'OH', 'CG',
            'CZ3', 'NE', 'CH2', 'OD1', 'NH2', 'ND2', 'OG', 'CG2', 'OE2', 'CD2', 'ND1', 'NE2',
            'NZ', 'CD', 'CE2', 'CE', 'OD2', 'SG', 'NE1', 'CG1']
SC_ATOM_POSNS = {a: i for i, a in enumerate(SC_ATOMS)}

AA3LetterCode = ['ALA', 'ARG', 'ASN', 'ASP', 'ASX', 'CYS', 'GLU', 'GLN', 'GLX', 'GLY', 'HIS', 'ILE', 'LEU', 'LYS',
                 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL', "UNK"]
AA1LetterCode = ['A', 'R', 'N', 'D', 'B', 'C', 'E', 'Q', 'Z', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W',
                 'Y', 'V', "-"]
VALID_AA_3_LETTER = set(AA3LetterCode)
VALID_AA_1_LETTER = set(AA1LetterCode)

ALL_ATOMS = BB_ATOMS + SC_ATOMS
ALL_ATOM_POSNS = {a: i for i, a in enumerate(ALL_ATOMS)}

THREE_TO_ONE = {three: one for three, one in zip(AA3LetterCode, AA1LetterCode)}
ONE_TO_THREE = {one: three for three, one in THREE_TO_ONE.items()}
chi_pi_periodic = {
    'ALA': [0.0, 0.0, 0.0, 0.0],
    'ARG': [0.0, 0.0, 0.0, 0.0],
    'ASN': [0.0, 0.0, 0.0, 0.0],
    'ASP': [0.0, 1.0, 0.0, 0.0],
    'CYS': [0.0, 0.0, 0.0, 0.0],
    'GLN': [0.0, 0.0, 0.0, 0.0],
    'GLU': [0.0, 0.0, 1.0, 0.0],
    'GLY': [0.0, 0.0, 0.0, 0.0],
    'HIS': [0.0, 0.0, 0.0, 0.0],
    'ILE': [0.0, 0.0, 0.0, 0.0],
    'LEU': [0.0, 0.0, 0.0, 0.0],
    'LYS': [0.0, 0.0, 0.0, 0.0],
    'MET': [0.0, 0.0, 0.0, 0.0],
    'PHE': [0.0, 1.0, 0.0, 0.0],
    'PRO': [0.0, 0.0, 0.0, 0.0],
    'SER': [0.0, 0.0, 0.0, 0.0],
    'THR': [0.0, 0.0, 0.0, 0.0],
    'TRP': [0.0, 0.0, 0.0, 0.0],
    'TYR': [0.0, 1.0, 0.0, 0.0],
    'VAL': [0.0, 0.0, 0.0, 0.0],
    'UNK': [0.0, 0.0, 0.0, 0.0],
}
chi_pi_periodic.update({THREE_TO_ONE[r]: arr for r, arr in chi_pi_periodic.items()})

In [8]:
def get_chi_pi_periodic_mask(seq : str):
    masks = [chi_pi_periodic[r] for r in seq]
    return torch.tensor(masks).bool()

get_chi_pi_periodic_mask(seq)

tensor([[False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False,  True, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False,  True, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False,  True, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False,  True, False, False],
        [False, False, False, False]])

In [9]:
from protein_learning.protein_utils.sidechain_utils import (
    swap_symmetric_atoms,
    per_residue_chi_indices_n_mask,
    get_chi_pi_periodic_mask,
    get_symmetric_residue_keys_n_indices,
)

In [14]:
symm_keys_n_indices = get_symmetric_residue_keys_n_indices(seq)
alt_truth_coords = swap_symmetric_atoms(coords.squeeze(), symm_keys_n_indices)
torch.norm(coords-swap_symmetric_atoms(alt_truth_coords, symm_keys_n_indices))
equal_mask = torch.norm(coords-alt_truth_coords,dim=-1)<=1e-6
print(equal_mask.numel(), equal_mask[equal_mask].numel())

720 694
