In [1]:
import os
from functools import partial

import jax.numpy as jnp
import numpy as np
from jax import jit

from common.residue_library import ResidueLibrary
from shrake_rupley_jax import calculate_sasa
from utils import input_to_jax_structure, process_pdb_to_arrays
# Residue parameter library, consumed by process_pdb_to_arrays further down.
residue_library = ResidueLibrary()

# Sample points for the Shrake-Rupley surface test. The first line of the
# .xyz file is skipped (presumably a point-count header — verify).
_THOMSON_POINTS_FILE = "./common/thomson1000.xyz"
sphere_points = jnp.asarray(np.loadtxt(_THOMSON_POINTS_FILE, skiprows=1))

In [None]:
class ShrakeRupleyCalculator:
    """Shrake-Rupley solvent-accessible surface area (SASA) calculator in JAX.

    Every atom is modelled as a sphere of radius ``vdw_radius + probe_radius``
    sampled at the points loaded from ``points_file`` (expected to lie on the
    unit sphere, since they are scaled by the expanded radius). A sample point
    is buried when it falls inside the expanded sphere of another interacting
    atom. An atom's SASA is its expanded-sphere area times the fraction of its
    points that are not buried.

    NOTE: the jitted methods treat ``self`` as a static argument
    (``static_argnums=(0,)``), so compiled code is cached per instance.
    Mutating ``probe_radius`` or the sphere points after the first call will
    not trigger a retrace; create a new instance instead.
    """

    def __init__(self, probe_radius: float = 1.4, points_file: str = "./common/thomson1000.xyz"):
        """Load the sphere sample points and store the probe radius.

        Args:
            probe_radius: Solvent probe radius added to every atomic radius
                (same length units as the coordinates).
            points_file: Whitespace-delimited text file with one ``x y z``
                point per line; the first line is skipped (assumed header).

        Raises:
            ValueError: if the file yields no points or points that are not 3-D.
        """
        # ndmin=2 keeps a single-point file as shape (1, 3) instead of (3,),
        # which would otherwise make n_points == 3.
        points = np.loadtxt(points_file, skiprows=1, ndmin=2)
        if points.shape[0] == 0 or points.shape[1] != 3:
            raise ValueError(
                f"expected a non-empty (P, 3) array of sphere points in {points_file!r}, "
                f"got shape {points.shape}"
            )
        self._sphere_points = jnp.array(points)
        self.n_points = len(self._sphere_points)
        self.probe_radius = jnp.array(probe_radius)

    @partial(jit, static_argnums=(0,))
    def _compute_interaction_matrix(self, coords: jnp.ndarray, vdw_radii: jnp.ndarray) -> jnp.ndarray:
        """Return the [N, N] bool matrix of atom pairs whose expanded spheres overlap.

        Args:
            coords: [N, 3] atom coordinates.
            vdw_radii: [N] atomic radii (the probe radius is added here).

        Returns:
            [N, N] bool; entry (i, j) is True when the expanded spheres of i
            and j touch or overlap, with the diagonal forced to False.
        """
        radii = vdw_radii + self.probe_radius

        # Pairwise squared centre distances, compared against squared radius
        # sums to avoid a sqrt.
        diff = coords[:, None, :] - coords[None, :, :]
        dist2 = jnp.sum(diff * diff, axis=-1)

        radsum = radii[:, None] + radii[None, :]
        radsum2 = radsum * radsum

        # An atom never buries its own surface points.
        return (dist2 <= radsum2) & ~jnp.eye(coords.shape[0], dtype=bool)

    @partial(jit, static_argnums=(0,))
    def _compute_all_atom_sasa(self, coords: jnp.ndarray, vdw_radii: jnp.ndarray,
                               interaction_matrix: jnp.ndarray) -> jnp.ndarray:
        """Per-atom SASA from the Shrake-Rupley point test.

        Args:
            coords: [N, 3] atom coordinates.
            vdw_radii: [N] atomic radii (the probe radius is added here).
            interaction_matrix: [N, N] bool; True where atom j may bury points
                of atom i.

        Returns:
            [N] accessible area per atom.

        Note: this materializes an [N, P, N, 3] difference array (P = number of
        sphere points), so memory grows as N**2 * P.
        """
        radii = vdw_radii + self.probe_radius

        # Place every sample point on every atom's expanded sphere: [N, P, 3].
        scaled_points = self._sphere_points[None, :, :] * radii[:, None, None] + coords[:, None, :]

        # Squared distance from each point of atom i to every atom j: [N, P, N].
        diff = scaled_points[:, :, None, :] - coords[None, None, :, :]
        dist2 = jnp.sum(diff * diff, axis=-1)

        # A point is buried if it lies inside the expanded sphere of any atom
        # that interacts with its owner.
        radii2 = jnp.square(radii)
        is_buried = (dist2 <= radii2[None, None, :]) & interaction_matrix[:, None, :]

        buried_points = jnp.any(is_buried, axis=-1)
        n_accessible = self.n_points - jnp.sum(buried_points, axis=-1)

        areas = 4.0 * jnp.pi * jnp.square(radii)
        return areas * (n_accessible / self.n_points)

    @partial(jit, static_argnums=(0,))
    def calculate_all(self, coords: jnp.ndarray, vdw_radii: jnp.ndarray,
                      mask: jnp.ndarray = None) -> jnp.ndarray:
        """Compute the per-atom SASA for a possibly padded set of atoms.

        Args:
            coords: [N, 3] atom coordinates.
            vdw_radii: [N] van der Waals radii.
            mask: optional [N] mask (1 = valid atom, 0 = masked/padding in AF2).
                Defaults to all atoms valid.

        Returns:
            [N] SASA per atom. Masked atoms get exactly 0 and never occlude
            valid atoms.
        """
        if mask is None:
            mask = jnp.ones_like(vdw_radii)  # All atoms valid
        valid = mask.astype(bool)

        # Zeroing a padded atom's radius is not enough to hide it: the probe
        # radius is still added, so it would act as a probe-sized sphere at the
        # origin. Scrub its values (where() also discards NaN padding) and
        # remove it from the interaction matrix entirely.
        safe_coords = jnp.where(valid[:, None], coords, 0.0)
        safe_radii = jnp.where(valid, vdw_radii, 0.0)

        interaction_matrix = self._compute_interaction_matrix(safe_coords, safe_radii)
        interaction_matrix = interaction_matrix & valid[:, None] & valid[None, :]

        sasa = self._compute_all_atom_sasa(safe_coords, safe_radii, interaction_matrix)
        return jnp.where(valid, sasa, 0.0)

In [2]:
# Example complex to run SASA on. Override with the SASA_TEST_PDB environment
# variable rather than editing this machine-specific absolute default.
test_pdb = os.environ.get(
    "SASA_TEST_PDB",
    "/home/alessio/dr_sasa_python/data/PRODIGYdataset_fixed/1A2K.pdb",
)
if not os.path.isfile(test_pdb):
    raise FileNotFoundError(f"Test PDB not found: {test_pdb} (set SASA_TEST_PDB to override)")

(positions, aatype_list, atom_mask, residue_index_list,
 b_factors, atom_names,
 residue_names, chain_ids,
 residue_numbers, elements, atom_radii) = process_pdb_to_arrays(test_pdb, residue_library)
jax_structure_data = input_to_jax_structure(
    atom_positions=positions,
    atom_radii=atom_radii,
    aatype=aatype_list,
    atom_mask=atom_mask,
    residue_index=residue_index_list,
    b_factors=b_factors,
    atom_names=atom_names,
    residue_names=residue_names,
    chain_ids=chain_ids,
    residue_numbers=residue_numbers,
    elements=elements,
    structure_id="protein"
)


In [3]:
# calculate_sasa (from shrake_rupley_jax) takes positions, radii, the atom
# mask as a JAX array, and the sphere sample points loaded in the first cell.
atom_mask_jnp = jnp.array(jax_structure_data.atom_mask)
sasa = calculate_sasa(
    jax_structure_data.atom_positions,
    jax_structure_data.atom_radii,
    atom_mask_jnp,
    sphere_points,
)



In [4]:
# Summarize instead of dumping the whole per-atom array into the notebook:
# shape, sum of all entries, and the first few rows.
sasa.shape, float(jnp.sum(sasa)), sasa[:5]

: 