In [1]:
%load_ext autoreload
# Reload edited modules (e.g. a local moleculib checkout) before each cell runs.
%autoreload 2

# Scanning the Chemistry of Proteins

In [2]:
from moleculib.protein.dataset import MonomerDataset
from moleculib.protein.transform import (
    ProteinCrop,
    DescribeChemistry,
)
from tqdm import tqdm
from collections import defaultdict
from moleculib.protein.alphabet import all_residues
import numpy as np
import matplotlib.pyplot as plt
from einops import rearrange


def norm(vector: np.ndarray) -> np.ndarray:
    """Euclidean length of ``vector`` along its last axis.

    Any leading axes are treated as batch dimensions, so the result has
    shape ``vector.shape[:-1]``.
    """
    squared_length = (vector ** 2).sum(axis=-1)
    return squared_length ** 0.5

def normalize(vector: np.ndarray) -> np.ndarray:
    """Scale ``vector`` to unit length along its last axis.

    Zero-length vectors divide by zero and come out as NaN; no guard is
    applied here.
    """
    lengths = norm(vector)
    return vector / lengths[..., np.newaxis]


import os

# Root of the PDB mirror. It can be overridden with the PDB_DATA_PATH environment
# variable so the notebook runs outside the original cluster; the default keeps
# the previously hard-coded location.
data_path = os.environ.get("PDB_DATA_PATH", '/mas/projects/molecularmachines/db/PDB')
min_seq_len = 16
# Proteins are filtered to at most this length and cropped to the same size.
max_seq_len = sequence_length = 512
dataset = MonomerDataset(
    base_path=data_path,
    attrs="all",
    max_resolution=1.7,  # presumably Angstrom — confirm against MonomerDataset
    min_sequence_length=min_seq_len,
    max_sequence_length=max_seq_len,
    frac=1.0,
    transform=[
        ProteinCrop(crop_size=sequence_length),
        DescribeChemistry(),
    ],
)


KeyboardInterrupt



In [None]:
# Peek at one example. Pull it from the dataset explicitly instead of relying
# on `datum` leaking from the scan loop further down, which does not exist yet
# on a fresh kernel.
sample_datum = next(iter(dataset))
sample_datum.residue_token

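## Scan chemistry

Measure bond lengths, bond angles, and dihedral angles for every residue in the first `num_data` proteins of the dataset, then summarize them per residue type.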

In [None]:
from einops import rearrange
import numpy as np

from moleculib.protein.datum import ProteinDatum
import jax.numpy as jnp
from typing import List


class ProteinMetric:
    """Base for callables that compute a metric from a ``ProteinDatum``.

    Subclasses must override ``__call__``; the base implementation always
    raises ``NotImplementedError``.
    """

    def __call__(self, datum: ProteinDatum):
        message = "ProteinMetric is abstract"
        raise NotImplementedError(message)

def norm(vector: np.ndarray) -> np.ndarray:
    """Euclidean length of ``vector`` along its last axis.

    NOTE(review): same definition as the helper in the first code cell; it is
    repeated so this cell is self-contained. Consider keeping one copy (or
    moving it into a module).
    """
    return (vector * vector).sum(axis=-1) ** 0.5

def normalize(vector: np.ndarray) -> np.ndarray:
    """Return ``vector`` divided by its last-axis Euclidean length.

    Zero-length vectors produce NaN (0/0). NOTE(review): same definition as
    the helper in the first code cell; consider keeping a single copy.
    """
    magnitude = norm(vector)
    return vector / magnitude[..., None]

def measure_bonds(coord, idx):
    """Bond length for every atom pair listed in ``idx``.

    Args:
        coord: atom coordinates; rows are selected with the integer entries of
            ``idx`` and the last axis holds the Cartesian components.
        idx: integer array of index pairs; ``idx.T`` must have length 2 along
            its first axis (e.g. ``idx`` of shape (num_bonds, 2)).

    Returns:
        Euclidean distance per pair, or 0.0 for any pair in which either
        atom's coordinate row is all zeros (treated as a missing atom).
    """
    v, u = idx.T
    start, end = coord[v], coord[u]
    lengths = np.sqrt(np.square(start - end).sum(-1))
    # A missing atom is an all-zero row. Test each component rather than the
    # coordinate sum. The previous check (`!= 0` on one end, `> 0` on the other)
    # dropped real bonds whose endpoint coordinates summed to a negative value.
    present = np.any(start != 0.0, axis=-1) & np.any(end != 0.0, axis=-1)
    return lengths * present

def measure_angles(coords, idx):
    """Bond angle (radians) at the central atom ``j`` of each ``(i, j, k)`` triple.

    Uses the half-angle identity ``2 * atan2(|u1 - u2|, |u1 + u2|)`` on the
    unit vectors ``u1 = (i - j)/|i - j|`` and ``u2 = (k - j)/|k - j|``. This
    stays accurate near 0 and pi, where ``arccos`` loses precision.

    Args:
        coords: atom coordinates indexed along axis 0 by the entries of
            ``idx``; the last axis holds the Cartesian components.
        idx: integer array whose last axis has length 3.

    Returns:
        Angles with shape ``idx.shape[:-1]``. An entry is 0.0 when any atom of
        its triple has an all-zero coordinate row (treated as missing). Before
        this change such entries could be NaN, because NaN * False is NaN.
    """
    i, j, k = np.moveaxis(np.asarray(idx), -1, 0)
    ci, cj, ck = coords[i], coords[j], coords[k]
    present = (
        np.any(ci != 0.0, axis=-1)
        & np.any(cj != 0.0, axis=-1)
        & np.any(ck != 0.0, axis=-1)
    )
    v1, v2 = ci - cj, ck - cj
    # Missing atoms can make v1 or v2 zero-length; those 0/0 results are masked below.
    with np.errstate(invalid="ignore", divide="ignore"):
        u1 = v1 / np.linalg.norm(v1, axis=-1, keepdims=True)
        u2 = v2 / np.linalg.norm(v2, axis=-1, keepdims=True)
        angle = 2 * np.arctan2(
            np.linalg.norm(u1 - u2, axis=-1), np.linalg.norm(u1 + u2, axis=-1)
        )
    return np.where(present, angle, 0.0)


def measure_dihedrals(coords, indices):
    """Signed dihedral angle (radians) for each ``(p, q, v, u)`` atom quadruple.

    The magnitude is the angle between the normals of the planes (p, q, v)
    and (q, v, u). The sign is the sign of the scalar triple product
    ``b1 . (b2 x b3)``, with ``b1 = q - p``, ``b2 = v - q``, ``b3 = u - v``.
    The result lies in [-pi, pi].

    Args:
        coords: atom coordinates indexed along axis 0 by the entries of
            ``indices``; the last axis holds the Cartesian components.
        indices: integer array whose last axis has length 4.

    Returns:
        Numpy array of angles with shape ``indices.shape[:-1]``. An entry is
        0.0 when any of its four atoms has an all-zero coordinate row (treated
        as missing).
    """
    p, q, v, u = np.moveaxis(np.asarray(indices), -1, 0)
    x1, x2, x3, x4 = coords[p], coords[q], coords[v], coords[u]

    b1 = x2 - x1
    b2 = x3 - x2
    b3 = x4 - x3

    # Degenerate (collinear or missing) atoms give zero-length normals; 0/0 is
    # tolerated here and missing atoms are masked at the end.
    with np.errstate(invalid="ignore", divide="ignore"):
        n1 = np.cross(b1, b2)
        n1 = n1 / np.linalg.norm(n1, axis=-1, keepdims=True)
        n2 = np.cross(b2, b3)
        n2 = n2 / np.linalg.norm(n2, axis=-1, keepdims=True)
        # Clip so rounding slightly past +/-1 cannot turn arccos into NaN.
        rad = np.arccos(np.clip((n1 * n2).sum(-1), -1.0, 1.0))

    # Apply the sign where it is defined. The previous
    # `where(sign == 0, rad * sign, rad)` was inverted and never applied it.
    sign = np.sign((n1 * b3).sum(-1))
    rad = np.where(sign == 0, rad, rad * sign)

    present = (
        np.any(x1 != 0.0, axis=-1)
        & np.any(x2 != 0.0, axis=-1)
        & np.any(x3 != 0.0, axis=-1)
        & np.any(x4 != 0.0, axis=-1)
    )
    return np.where(present, rad, 0.0)

# Dispatch table: chemistry property name -> function that measures it.
# The keys are the property names the scan loop passes to
# `datum._apply_chemistry`.
measure_functions = {
    "bonds": measure_bonds,
    "angles": measure_angles,
    "dihedrals": measure_dihedrals,
}


from itertools import islice

counter = 0
num_data = 200

# Gather per-residue measurements in lists and stack them once at the end.
# Calling np.concatenate for every residue re-copied the whole accumulated
# array each time, which is quadratic in the number of residues scanned.
collected = defaultdict(lambda: defaultdict(list))

with tqdm(total=num_data) as pbar:
    # islice stops after exactly num_data proteins (and after none when
    # num_data == 0, which the old counter/break check missed).
    for datum in islice(dataset, num_data):
        for prop in ('bonds', 'angles', 'dihedrals'):
            chem_props = datum._apply_chemistry(prop, measure_functions[prop])
            for token, res_prop in zip(datum.residue_token, chem_props):
                collected[prop][all_residues[token]].append(np.asarray(res_prop))
        counter += 1
        pbar.update(1)

# measures_dict[prop][residue_code] -> array of shape (num_occurrences, *res_prop.shape)
measures_dict = defaultdict(lambda: defaultdict(lambda: np.array([])))
for prop, per_code in collected.items():
    for code, rows in per_code.items():
        measures_dict[prop][code] = np.stack(rows, axis=0)


In [None]:
import jax
    
def stats(array):
    """Masked mean and variance over the leading axis.

    NaNs and exact zeros both count as "missing" and are excluded from the
    mean and the variance (population variance). The input is not modified.
    The previous in-place NaN replacement mutated the caller's data and failed
    on read-only or immutable (JAX) arrays.

    NOTE(review): this is an arithmetic mean. For periodic quantities such as
    signed dihedral angles (values near +pi and -pi describe the same geometry)
    a circular mean would be more appropriate — TODO decide.

    Returns:
        ``(mean, var)``, both float32, with shape ``array.shape[1:]``.
    """
    array = np.asarray(array)
    values = np.where(np.isnan(array), 0.0, array)
    mask = (values != 0.0).astype(np.float32)
    # The epsilon keeps columns with no valid entries at mean = var = 0.
    count = mask.sum(0) + 1e-6
    mean = (values * mask).sum(0) / count
    var = ((values - mean) ** 2 * mask).sum(0) / count
    return mean.astype(np.float32), var.astype(np.float32)

stats_dict = jax.tree_util.tree_map(stats, measures_dict)

# aeho[prop] = (mean, var), each stacked with one row per entry of
# all_residues, in token order. Row 0 (all_residues[0], presumably the padding
# token) is always zeros. Residue types that never appeared in the scanned
# proteins are also zero-filled, matching what `stats` reports for "no data".
# Previously they broke the tuple unpacking below.
aeho = dict()
residue_codes = list(all_residues)
for prop, residues in stats_dict.items():
    observed = [residues[code] for code in residue_codes[1:] if code in residues]
    if not observed:
        continue
    zero_mean = np.zeros_like(observed[0][0])
    zero_var = np.zeros_like(observed[0][1])
    means, variances = [zero_mean], [zero_var]
    for code in residue_codes[1:]:
        mean, var = residues[code] if code in residues else (zero_mean, zero_var)
        means.append(mean)
        variances.append(var)
    aeho[prop] = (np.stack(means), np.stack(variances))

np.set_printoptions(precision=3)
aeho