In [137]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Scanning the Chemistry of Proteins

In [143]:
from moleculib.protein.dataset import MonomerDataset
from moleculib.protein.transform import (
    ProteinCrop,
    DescribeChemistry,
)

from collections import defaultdict
from moleculib.protein.alphabet import all_residues
import numpy as np
import matplotlib.pyplot as plt
from einops import rearrange


def norm(vector: np.ndarray) -> np.ndarray:
    norms_sqr = np.sum(vector**2, axis=-1)
    norms = norms_sqr ** 0.5
    return norms

def normalize(vector: np.ndarray) -> np.ndarray:
    return vector / norm(vector)[..., None]


data_path = '/mas/projects/molecularmachines/db/PDB'
min_seq_len = 16
max_seq_len = sequence_length = 512
dataset = MonomerDataset(
    base_path=data_path,
    attrs="all",
    max_resolution=1.7,
    min_sequence_length=min_seq_len,
    max_sequence_length=max_seq_len,
    frac=1.0,
    transform=[
        ProteinCrop(crop_size=sequence_length),
        DescribeChemistry(),
    ],
)

In [154]:
datum.residue_token

array([14, 13, 21,  2, 21,  2,  9,  5, 17,  2,  9,  8,  9, 12,  2, 13, 21,
       12,  2,  5, 10, 12, 13,  5,  3, 15,  8, 21, 17,  8, 11, 17,  4, 12,
       17,  5,  3, 21,  2, 17,  2, 21, 12,  5,  9, 18, 20,  5,  3,  2, 11,
       12, 21,  6,  9, 18,  9, 11,  9, 21,  6, 11,  2,  2,  4, 13, 21, 16,
        9, 11,  3,  2,  2, 12, 18, 10,  5, 18, 20, 17,  2,  8,  3,  2,  2,
       12, 17,  4,  4,  2,  7, 11, 11, 18, 14,  9,  2,  3, 21, 11,  9,  2,
        8, 21,  2, 13, 18, 11,  2,  5,  2, 15, 12,  2,  7, 18, 15])

### Scan Chemistry

In [158]:
from einops import rearrange
import numpy as np

from moleculib.protein.datum import ProteinDatum
import jax.numpy as jnp
from typing import List


class ProteinMetric:
    def __call__(self, datum: ProteinDatum):
        raise NotImplementedError("ProteinMetric is abstract")

def norm(vector: np.ndarray) -> np.ndarray:
    norms_sqr = np.sum(vector**2, axis=-1)
    norms = norms_sqr ** 0.5
    return norms

def normalize(vector: np.ndarray) -> np.ndarray:
    return vector / norm(vector)[..., None]

def measure_bonds(coord, idx):
    v, u = idx.T
    bonds_len = np.sqrt(np.square(coord[v] - coord[u]).sum(-1))
    return bonds_len 

def measure_angles(coords, idx):
    i, j, k = rearrange(idx, "... a -> a ...")
    v1, v2 = coords[i] - coords[j], coords[k] - coords[j]
    v1, v2 = normalize(v1), normalize(v2)
    x, y = norm(v1 + v2), norm(v1 - v2)
    return  2 * np.arctan2(y, x)

def measure_dihedrals(coords, indices):
    p, q, v, u = rearrange(indices, "... b -> b ...")
    v1 = normalize(coords[q] - coords[p])
    v2 = normalize(coords[v] - coords[q])
    v3 = normalize(coords[u] - coords[v])

    n1 = np.cross(v1, v2)
    n2 = np.cross(v2, v3)

    x = (n1 * n2).sum(-1)
    y = (np.cross(n1, v2) * n2).sum(-1)

    x = np.where(x == 0.0, 1e-6, x)
    return x

measure_functions = dict(
    bonds=measure_bonds,
    angles=measure_angles,
    dihedrals=measure_dihedrals,
)


In [159]:
datum.apply_bonds(measure_bonds).shape, datum.apply_angles(measure_angles).shape, datum.apply_dihedrals(measure_dihedrals).shape

  return vector / norm(vector)[..., None]


((20, 16), (20, 22), (20, 29))

In [199]:
counter = 0
num_data = 5

measures_dict = defaultdict(lambda: defaultdict(list))

for datum in dataset:
    for prop in ('bonds', 'angles', 'dihedrals'):
        chem_props = datum._apply_chemistry(
            prop, measure_functions[prop])
        for res_token, res_prop in zip(datum.residue_token, chem_props):
            code = all_residues[res_token]
            if token not in measures_dict[prop]:
                measures_dict[prop][token] = res_prop[None]
            else:
                measures_dict[prop][token] = np.concatenate(
                    (measures_dict[prop][token], res_prop[None]), axis=0)
    counter += 1
    if counter == num_data:
        break


TypeError: unhashable type: 'numpy.ndarray'

(708, 16)

In [200]:
datum = ProteinDatum.fetch_pdb_id("1L2Y")
print(datum.to_pdb_str())

ATOM  1     N    ASN  1         -8.901   4.127  -0.555                       N  
ATOM  2     CA   ASN  1         -8.608   3.135  -1.618                       C  
ATOM  3     C    ASN  1         -7.117   2.964  -1.897                       C  
ATOM  4     O    ASN  1         -6.634   1.849  -1.758                       O  
ATOM  5     CB   ASN  1         -9.437   3.396  -2.889                       C  
ATOM  6     CG   ASN  1        -10.915   3.130  -2.611                       C  
ATOM  7     OD1  ASN  1        -11.269   2.700  -1.524                       O  
ATOM  8     ND2  ASN  1        -11.806   3.406  -3.543                       N  
ATOM  9     N    LEU  2         -6.379   4.031  -2.228                       N  
ATOM  10    CA   LEU  2         -4.923   4.002  -2.452                       C  
ATOM  11    C    LEU  2         -4.136   3.187  -1.404                       C  
ATOM  12    O    LEU  2         -3.391   2.274  -1.760                       O  
ATOM  13    CB   LEU  2     