
# Convolutional neural network for pocket prediction

In this notebook, we will implement a convolutional neural network using Keras and TensorFlow.



Install the dependencies by uncommenting and running the following cell:



In [None]:
#!pip install scikit-learn matplotlib tensorflow numpy gdown torch biopython

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com


### Load the necessary libraries

The following code imports the packages needed for this CNN.

In [2]:
import tensorflow as tf
import torch
from torch.utils.data import Dataset
import torch
import os
import glob
from Bio.PDB import PDBParser
import numpy as np

# Report whether TensorFlow can see a GPU in the current runtime.
print('Check GPU runtime type... ')
# An empty physical-device list means no GPU is visible to TensorFlow.
if len(tf.config.list_physical_devices('GPU')) == 0:
  print('Change Runtime Type in top menu for GPU acceleration')
  print(' "Runtime" -> "Change Runtime Type" -> "GPU"')
else:
  print('OK!')

2025-04-22 12:40:19.641394: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-22 12:40:19.670811: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-04-22 12:40:19.670833: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-04-22 12:40:19.671427: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-22 12:40:19.676830: I tensorflow/core/platform/cpu_feature_guar

Check GPU runtime type... 
Change Runtype Type in top menu for GPU acceleration
 "Runtime" -> "Change Runtime Type" -> "GPU"


2025-04-22 12:40:23.282670: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2025-04-22 12:40:23.283462: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2256] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


### Prepare the dataset

The dataset can be downloaded using the gdown package. It will be used for training, validation and testing. Note that, because we are reusing a pretrained model, we must keep the same input size and data preprocessing as in the original model.

In [3]:
#!gdown 1mtU0n_-ejTE9yA_G7iuOQJ4BCs9FVUhs --output dataset.zip
#!unzip dataset.zip

The **collect_pdb_pairs** function and the **PocketGridDataset** class are used to prepare the dataset:

In [4]:

def collect_pdb_pairs(dataset_root):
    """
    Collect index-aligned protein/pocket PDB file paths from a dataset folder.

    Every immediate sub-directory of ``dataset_root`` is searched for files
    matching ``*_protein.pdb`` and ``*_pocket.pdb``. A sub-directory contributes
    one pair only if it contains at least one file of each kind; plain files in
    ``dataset_root`` and incomplete sub-directories are skipped.

    Parameters:
        dataset_root (str): directory that holds one sub-directory per entry.

    Returns:
        (protein_paths, pocket_paths): two lists of equal length where element
        ``i`` of each list comes from the same sub-directory. Entries follow the
        sorted order of the sub-directory names.
    """
    protein_paths = []
    pocket_paths = []

    for pdb_dir in sorted(os.listdir(dataset_root)):
        full_path = os.path.join(dataset_root, pdb_dir)
        if not os.path.isdir(full_path):
            continue

        # glob returns matches in arbitrary order; sorting makes the file that
        # is picked (the first one) identical across runs and platforms.
        protein_pdb = sorted(glob.glob(os.path.join(full_path, "*_protein.pdb")))
        pocket_pdb = sorted(glob.glob(os.path.join(full_path, "*_pocket.pdb")))

        if protein_pdb and pocket_pdb:
            protein_paths.append(protein_pdb[0])
            pocket_paths.append(pocket_pdb[0])

    return protein_paths, pocket_paths


**PocketGridDataset** uses the following functions (**voxelize_structure** and **generate_label_grid**):

In [5]:


def voxelize_structure(pdb_path, origin=None, grid_size=32, voxel_size=1.0, channels=['C', 'N', 'O', 'S'], return_origin=False):
    """
    Converts a protein PDB file into a 3D voxel grid of per-element atom counts.

    Parameters:
        pdb_path (str): path to the protein PDB file.
        origin (np.array or None): coordinate of the grid's low corner. If None,
            the grid is centred on the mean of all atom coordinates.
        grid_size (int): number of voxels along each axis.
        voxel_size (float): edge length of each voxel, in the units of the PDB
            coordinates (Å).
        channels (list): element symbols to count, one grid channel per entry
            (default: C, N, O, S). Atoms of other elements are ignored. The
            list is only read, never modified.
        return_origin (bool): if True, also return the grid origin used.

    Returns:
        grid (np.array): float32 array of shape
            (len(channels), grid_size, grid_size, grid_size). The three spatial
            axes follow the order of the atom coordinate components. Each cell
            holds the number of atoms of that channel's element in the voxel.
        origin (np.array): [only if return_origin=True] the origin used to align the grid.
    """
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure("structure", pdb_path)
    # Walk the structure once; the atoms are needed for both centring and binning.
    atoms = list(structure.get_atoms())

    if origin is None:
        coords = np.array([atom.coord for atom in atoms])
        origin = coords.mean(axis=0) - (grid_size * voxel_size / 2)

    grid = np.zeros((len(channels), grid_size, grid_size, grid_size), dtype=np.float32)

    # Element -> channel index. setdefault keeps the first occurrence, which is
    # what list.index would return if an element were listed twice.
    channel_of = {}
    for position, element in enumerate(channels):
        channel_of.setdefault(element, position)

    for atom in atoms:
        idx = channel_of.get(atom.element.strip())
        if idx is None:
            continue
        # floor, not truncation: a plain int cast rounds toward zero, which would
        # count atoms lying up to one voxel below the grid's low edge in voxel 0.
        voxel = np.floor((np.asarray(atom.coord) - origin) / voxel_size).astype(int)
        if all(0 <= v < grid_size for v in voxel):
            grid[idx, voxel[0], voxel[1], voxel[2]] += 1

    return (grid, origin) if return_origin else grid

In [6]:
def generate_label_grid(pocket_pdb_path, origin, grid_size=32, voxel_size=1.0):
    """
    Builds a binary occupancy grid marking the voxels that contain pocket atoms.

    Parameters:
        pocket_pdb_path (str): path to the pocket PDB file.
        origin (np.array): coordinate of the grid's low corner; pass the origin
            used for the matching protein grid so both grids are aligned.
        grid_size (int): number of voxels along each axis.
        voxel_size (float): edge length of each voxel, in the units of the PDB
            coordinates.

    Returns:
        label_grid (np.array): uint8 array of shape
            (grid_size, grid_size, grid_size); 1 where at least one pocket atom
            falls inside the voxel, 0 elsewhere. Atoms outside the grid are ignored.
    """
    # PDBParser comes from the notebook's import cell.
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure("pocket", pocket_pdb_path)
    label_grid = np.zeros((grid_size, grid_size, grid_size), dtype=np.uint8)

    for atom in structure.get_atoms():
        # floor, not truncation: a plain int cast rounds toward zero, which would
        # mark voxel 0 for atoms lying up to one voxel below the grid's low edge.
        voxel = np.floor((np.asarray(atom.coord) - origin) / voxel_size).astype(int)
        if all(0 <= v < grid_size for v in voxel):
            label_grid[voxel[0], voxel[1], voxel[2]] = 1

    return label_grid

In [7]:

class PocketGridDataset(Dataset):
    """
    Dataset that voxelizes protein/pocket PDB pairs on demand.

    Item ``i`` is built from ``protein_paths[i]`` and ``pocket_paths[i]``: the
    protein is turned into a multi-channel grid and the pocket into a label
    grid aligned to the same origin.
    """

    def __init__(self, protein_paths, pocket_paths, grid_size=32, voxel_size=1.0):
        """Store the index-aligned path lists and the grid settings."""
        self.protein_paths, self.pocket_paths = protein_paths, pocket_paths
        self.grid_size, self.voxel_size = grid_size, voxel_size

    def __len__(self):
        """Number of samples, taken from the protein path list."""
        return len(self.protein_paths)

    def __getitem__(self, idx):
        """Return ``(X, Y)`` float32 tensors for sample ``idx``; Y gets a leading channel axis."""
        protein_file = self.protein_paths[idx]
        pocket_file = self.pocket_paths[idx]
        grid_options = {"grid_size": self.grid_size, "voxel_size": self.voxel_size}

        # The origin chosen for the protein grid is reused so the labels line up.
        features, grid_origin = voxelize_structure(
            protein_file, return_origin=True, **grid_options
        )
        mask = generate_label_grid(pocket_file, grid_origin, **grid_options)

        inputs = torch.tensor(features, dtype=torch.float32)
        target = torch.tensor(mask, dtype=torch.float32).unsqueeze(0)
        return inputs, target


In [8]:
# Folder that holds one sub-directory per entry; relative to the notebook's working directory.
dataset_root = "refined-set"  
protein_paths, pocket_paths = collect_pdb_pairs(dataset_root)

print(f"Found {len(protein_paths)} protein-pocket pairs.")

# On-the-fly dataset: every item access parses the PDB files and voxelizes them.
dataset = PocketGridDataset(protein_paths, pocket_paths, grid_size=32, voxel_size=1.0)

Found 5316 protein-pocket pairs.


In [9]:
# input size and preprocessing method
X, Y = dataset[0] # get the first sample

print("Input shape:", X.shape)  # should be (4, 32, 32, 32) - has 4 channels
print("Label shape:", Y.shape)  # should be (1, 32, 32, 32)
print("Pocket voxels in label:", Y.sum().item())


Input shape: torch.Size([4, 32, 32, 32])
Label shape: torch.Size([1, 32, 32, 32])
Pocket voxels in label: 367.0


In [10]:
# Disabled export step, kept as a string literal so it does not execute.
# When run as code it iterates the on-the-fly dataset and writes each sample to
# voxels/X_{i}.npy and voxels/Y_{i}.npy, the file names VoxelNPYDataset reads.
# NOTE(review): the "voxels" folder must already exist for the cells below to
# work on a fresh kernel - confirm how it is meant to be produced.
'''os.makedirs("voxels", exist_ok=True)

for i, (X, Y) in enumerate(dataset):  # original dataset class
    np.save(f"voxels/X_{i}.npy", X.numpy())
    np.save(f"voxels/Y_{i}.npy", Y.numpy())'''

'os.makedirs("voxels", exist_ok=True)\n\nfor i, (X, Y) in enumerate(dataset):  # original dataset class\n    np.save(f"voxels/X_{i}.npy", X.numpy())\n    np.save(f"voxels/Y_{i}.npy", Y.numpy())'

In [11]:
class VoxelNPYDataset(torch.utils.data.Dataset):
    """
    Dataset backed by pre-computed voxel grids stored as ``.npy`` files.

    Sample ``i`` is read from ``<voxel_dir>/X_<i>.npy`` (input grid) and
    ``<voxel_dir>/Y_<i>.npy`` (label grid).

    Parameters:
        voxel_dir (str): directory containing the ``X_<i>.npy`` / ``Y_<i>.npy`` files.
        total_samples (int): number of samples exposed; valid indices are
            ``0 .. total_samples - 1``.
    """

    def __init__(self, voxel_dir, total_samples):
        self.voxel_dir = voxel_dir
        self.total_samples = total_samples

    def __len__(self):
        """Return the number of samples declared at construction time."""
        return self.total_samples

    def __getitem__(self, idx):
        """
        Load sample ``idx`` and return ``(X, Y)`` as float32 tensors.

        Negative indices count from the end, as for a list.

        Raises:
            IndexError: if ``idx`` is outside ``[-total_samples, total_samples)``.
        """
        if idx < 0:
            idx += self.total_samples
        # Honour __len__: without this check an out-of-range index surfaces as
        # FileNotFoundError (or loads an extra file), and plain iteration over
        # the dataset never sees the IndexError it needs in order to stop.
        if not 0 <= idx < self.total_samples:
            raise IndexError(
                f"sample index out of range for dataset of size {self.total_samples}"
            )

        X_path = os.path.join(self.voxel_dir, f"X_{idx}.npy")
        Y_path = os.path.join(self.voxel_dir, f"Y_{idx}.npy")

        X = torch.tensor(np.load(X_path)).float()
        Y = torch.tensor(np.load(Y_path)).float()

        return X, Y


In [12]:
# Rebind `dataset` to the cached grids in the "voxels" folder instead of re-voxelizing PDB files.
# NOTE(review): total_samples=5305 is hard-coded and differs from the 5316 pairs
# reported by collect_pdb_pairs above - confirm how many X_/Y_ files exist.
dataset = VoxelNPYDataset("voxels", total_samples=5305)

In [13]:
# Confirm which dataset class `dataset` currently refers to (it was rebound above).
print("dataset loaded:", type(dataset))


dataset loaded: <class '__main__.VoxelNPYDataset'>
