In [None]:
# Automatically reload edited imported modules before each cell executes,
# so changes to library code (e.g. moleculib) are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

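# Scanning the Chemistry of Proteins

Collect per-residue bond lengths, bond angles and dihedral angles from a sample of PDB monomers (max resolution 1.7 Å, 16–512 residues) and group them by residue type.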

In [None]:
import os
from collections import defaultdict
from itertools import islice

import matplotlib.pyplot as plt
import numpy as np
from einops import rearrange

from moleculib.protein.alphabet import all_residues
from moleculib.protein.dataset import MonomerDataset
from moleculib.protein.transform import (
    ProteinCrop,
    DescribeChemistry,
)


def norm(vector: np.ndarray) -> np.ndarray:
    """Euclidean length of ``vector`` along its last axis.

    Args:
        vector: array whose last axis holds the vector components.

    Returns:
        Array of lengths with the last axis reduced away.
    """
    return (vector ** 2).sum(axis=-1) ** 0.5

def normalize(vector: np.ndarray) -> np.ndarray:
    """Scale ``vector`` to unit length along its last axis.

    Zero-length vectors are not special-cased: dividing by a zero norm
    yields nan (NumPy emits a RuntimeWarning).
    """
    lengths = norm(vector)
    # re-insert the reduced axis so lengths broadcast against the components
    return vector / lengths[..., np.newaxis]


# Root of the local PDB mirror. Set the PDB_DATA_PATH environment variable to
# point elsewhere; the default keeps the original cluster location.
data_path = os.environ.get("PDB_DATA_PATH", '/mas/projects/molecularmachines/db/PDB')
# Sequence-length filter bounds; the crop size reuses the upper bound.
min_seq_len = 16
max_seq_len = sequence_length = 512
dataset = MonomerDataset(
    base_path=data_path,
    attrs="all",
    max_resolution=1.7,  # presumably Å, as PDB resolutions usually are — TODO confirm
    min_sequence_length=min_seq_len,
    max_sequence_length=max_seq_len,
    frac=1.0,
    transform=[
        ProteinCrop(crop_size=sequence_length),
        DescribeChemistry(),
    ],
)

In [None]:
# Bind a sample datum (the first dataset entry) so this cell, and the shape
# check further down, work on a fresh kernel instead of relying on leftover state.
datum = next(iter(dataset))
datum.residue_token

### Scan Chemistry

In [None]:
from einops import rearrange
import numpy as np

from moleculib.protein.datum import ProteinDatum
import jax.numpy as jnp
from typing import List


class ProteinMetric:
    """Abstract base for metrics evaluated on a single ``ProteinDatum``.

    Subclasses override ``__call__``; the base implementation always raises.
    """

    def __call__(self, datum: ProteinDatum):
        """Compute the metric for ``datum``; must be overridden.

        Raises:
            NotImplementedError: always, on the base class.
        """
        message = "ProteinMetric is abstract"
        raise NotImplementedError(message)

def norm(vector: np.ndarray) -> np.ndarray:
    """Euclidean length of ``vector`` along its last axis.

    NOTE(review): this re-defines the identical ``norm`` from the setup cell;
    consider moving both helpers into a module and importing them once.
    """
    return (vector ** 2).sum(axis=-1) ** 0.5

def normalize(vector: np.ndarray) -> np.ndarray:
    """Scale ``vector`` to unit length along its last axis (zero vectors give nan).

    NOTE(review): re-defines the identical ``normalize`` from the setup cell.
    """
    lengths = norm(vector)
    return vector / lengths[..., np.newaxis]

def measure_bonds(coord, idx):
    """Bond length for each (v, u) atom-index pair.

    Args:
        coord: coordinate array indexed by the atom indices in ``idx``;
            ``coord[i]`` must yield vectors along the last axis
            (presumably (num_atoms, 3) positions — TODO confirm).
        idx: integer array whose last axis has length 2, one (v, u) pair per
            bond; any leading shape is allowed.

    Returns:
        Euclidean distance between the paired atoms; for a (num_atoms, 3)
        ``coord`` the result has shape ``idx.shape[:-1]``.
    """
    # Split on the *last* axis only. ``idx.T`` reverses every axis, which
    # transposes the leading dims whenever idx.ndim > 2; moveaxis matches the
    # "... a -> a ..." split used by measure_angles / measure_dihedrals.
    v, u = np.moveaxis(idx, -1, 0)
    return np.sqrt(np.square(coord[v] - coord[u]).sum(-1))

def measure_angles(coords, idx):
    """Bond angle, in radians within [0, pi], at the middle atom of each (i, j, k) triple.

    Args:
        coords: coordinate array indexed by the atom indices in ``idx``
            (presumably (num_atoms, 3) positions — TODO confirm).
        idx: integer array whose last axis has length 3; j is the vertex.

    Returns:
        Angle between the j->i and j->k directions for every triple.
    """
    # Move the triple axis to the front so each part keeps idx's leading shape.
    first, center, last = rearrange(idx, "... a -> a ...")
    u1 = normalize(coords[first] - coords[center])
    u2 = normalize(coords[last] - coords[center])
    # Half-angle form: for unit vectors |u1 - u2| = 2 sin(t/2) and
    # |u1 + u2| = 2 cos(t/2), so t = 2 * atan2(|u1 - u2|, |u1 + u2|).
    # Unlike arccos(u1 . u2), this stays well-conditioned near 0 and pi.
    return 2 * np.arctan2(norm(u1 - u2), norm(u1 + u2))

def measure_dihedrals(coords, indices):
    """Signed torsion angle, in radians within [-pi, pi], for each (p, q, v, u) quadruple.

    Args:
        coords: coordinate array indexed by the atom indices in ``indices``
            (presumably (num_atoms, 3) positions — TODO confirm).
        indices: integer array whose last axis has length 4, giving the atoms
            p-q-v-u that define each torsion; any leading shape is allowed.

    Returns:
        Torsion angle about the q-v bond for every quadruple.
    """
    # Split on the last axis so p, q, v, u each keep the leading shape of `indices`.
    p, q, v, u = rearrange(indices, "... b -> b ...")
    # Unit vectors along the consecutive bonds p->q, q->v and v->u.
    v1 = normalize(coords[q] - coords[p])
    v2 = normalize(coords[v] - coords[q])
    v3 = normalize(coords[u] - coords[v])

    # Normals of the (p, q, v) and (q, v, u) planes.
    n1 = np.cross(v1, v2)
    n2 = np.cross(v2, v3)

    # x and y are the cosine and sine of the torsion, both scaled by the same
    # positive factor |n1||n2|. arctan2 therefore recovers the full signed angle.
    x = (n1 * n2).sum(-1)
    y = (np.cross(n1, v2) * n2).sum(-1)

    # NOTE(review): because y is built from cross(n1, v2), the sign returned here
    # is the opposite of the IUPAC convention atan2(v1 . (v2 x v3), n1 . n2).
    # Confirm which convention downstream comparisons expect before changing it.
    # Guard kept from the original: it only nudges an exact-zero x.
    x = np.where(x == 0.0, 1e-6, x)
    # Previously this returned `x` (an unnormalized cosine), leaving y unused.
    return np.arctan2(y, x)

# Dispatch table: chemistry property name -> function that measures it.
measure_functions = {
    "bonds": measure_bonds,
    "angles": measure_angles,
    "dihedrals": measure_dihedrals,
}


In [None]:
# Shapes of the bond, angle and dihedral measurements for one datum.
(
    datum.apply_bonds(measure_bonds).shape,
    datum.apply_angles(measure_angles).shape,
    datum.apply_dihedrals(measure_dihedrals).shape,
)

In [None]:
# Scan the first `num_data` dataset entries. Measurements are grouped by
# chemistry property and residue code.
counter = 0  # ends as the number of entries actually scanned
num_data = 5

# measures_dict[prop][code] holds a list of per-residue measurements while
# scanning. Each list is then replaced by one stacked array (one row per residue
# occurrence). The rows must share a shape to be stacked.
measures_dict = defaultdict(lambda: defaultdict(list))

for counter, datum in enumerate(islice(dataset, num_data), start=1):
    for prop in ('bonds', 'angles', 'dihedrals'):
        # NOTE(review): relies on the private ProteinDatum._apply_chemistry API.
        chem_props = datum._apply_chemistry(prop, measure_functions[prop])
        for res_token, res_prop in zip(datum.residue_token, chem_props):
            # Key by the residue code looked up in all_residues.
            code = all_residues[res_token]
            # Append now and stack once below. Concatenating inside the loop would
            # re-copy the whole accumulated array for every residue (quadratic).
            measures_dict[prop][code].append(res_prop)

for per_code in measures_dict.values():
    for code, rows in per_code.items():
        # Only values are reassigned (no keys added), so mutating during iteration is safe.
        per_code[code] = np.stack(rows, axis=0)


In [None]:
# Fetch a single structure by PDB id and print it back out as PDB-format text.
# Note: this rebinds `datum`, replacing the sample used in earlier cells.
datum = ProteinDatum.fetch_pdb_id("1L2Y")
print(datum.to_pdb_str())