In [1]:
from lib.datasets import ATPBind3D
dataset = ATPBind3D(limit=10).initialize_mask_and_weights()

dataset.split(valid_fold_num=0)

17:58:50   ATPBind3D: loaded 20 proteins and 20 targets
17:58:53   ATPBind3D: loaded 20 gvp graphs. length of data: 20
Initialize Undersampling: all ones
Initialize Weighting: all ones


[<torch.utils.data.dataset.Subset at 0x7f0061e47910>,
 <torch.utils.data.dataset.Subset at 0x7f0061e47790>,
 <torch.utils.data.dataset.Subset at 0x7f0061e47700>]

In [2]:
dataset.get_item(0)

{'graph': Protein(num_atom=2264, num_bond=4524, num_residue=566),
 'gvp_data': Data(x=[566, 3], edge_index=[2, 16980], seq=[566], name='', node_s=[566, 6], node_v=[566, 3, 3], edge_s=[16980, 32], edge_v=[16980, 1, 3], mask=[566])}

In [3]:
from torchdrug import data, utils
from lib.pipeline import graph_collate_with_gvp
import torch

# Create a DataLoader for the dataset
batch_size = 1  # You can adjust this as needed
dataloader = data.DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=graph_collate_with_gvp)

# Retrieve a single batch
batch = next(iter(dataloader))
batch = utils.cuda(batch)
for key, value in batch.items():
    if isinstance(value, torch.Tensor):
        print(f"{key}: {value.shape}")
    else:
        print(f"{key}: {type(value)}")

# Access specific elements of the batch
print("\nAccessing specific elements:")
print("Graph:", batch['graph'])
print("GVP data:", batch['gvp_data'])


graph: <class 'torchdrug.data.protein.PackedProtein'>
gvp_data: <class 'torch_geometric.data.batch.DataBatch'>

Accessing specific elements:
Graph: PackedProtein(batch_size=1, num_atoms=[2264], num_bonds=[4524], num_residues=[566], device='cuda:0')
GVP data: DataBatch(x=[566, 3], edge_index=[2, 16980], seq=[566], name=[1], node_s=[566, 6], node_v=[566, 3, 3], edge_s=[16980, 32], edge_v=[16980, 1, 3], mask=[566], batch=[566], ptr=[2])


In [4]:
batch['graph'].atom_name[:8]

tensor([17,  1,  0, 26, 17,  1,  0, 26], device='cuda:0')

In [5]:
data.Protein.atom_name2id["CA"]

1

In [6]:
batch['graph'].atom2residue[0:12]

tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2], device='cuda:0')

In [7]:
[i for i in dir(batch['graph']) if i.startswith('atom')]

['atom',
 'atom2graph',
 'atom2residue',
 'atom2valence',
 'atom_map',
 'atom_name',
 'atom_name2id',
 'atom_reference',
 'atom_type']

In [8]:
batch['graph'].node_position.shape

torch.Size([2264, 3])

In [9]:
batch['graph'].residue_feature.shape

torch.Size([566, 21])

In [10]:
# Get the first protein from batch['graph']
import numpy as np
from Bio.PDB.Polypeptide import protein_letters_3to1
first_protein = batch['graph'][0]

print(f"Node positions shape: {first_protein.node_position.shape}")
protein = first_protein

protein.to_sequence()

Node positions shape: torch.Size([2264, 3])


'PRGLELLIAQTILQGFDAQYGRFLEVTSGAQQRFEQADWHAVQQAMKNRIHLYDHHVGLVVEQLRCITNGQSTDAEFLLRVKEHYTRLLPDYPRFEIAESFFNSVYCRLFDHRSLTPERLFIFSSQPERRFRTIPRPLAKDFHPDHGWESLLMRVISDLPLRLHWQNKSRDIHYIIRHLTETLGPENLSKSHLQVANELFYRNKAAWLVGKLITPSGTLPFLLPIHQTDDGELFIDTCLTTTAEASIVFGFARSYFMVYAPLPAALVEWLREILPGKTTAELYMAIGCQKHAKTESYREYLVYLQGCNEQFIEAPGIRGMVMLVFTLPGFDRVFKVIKDKFAPQKEMSAAHVRACYQLVKEHDRVGRMADTQEFENFVLEKRHISPALMELLLQEAAEKITDLGEQIVIRHLYIERRMVPLNIWLEQVEGQQLRDAIEEYGNAIRQLAAANIFPGDMLFKNFGVTRHGRVVFYDYDEICYMTEVNFRDIPPPRYP.PWYSVSPGDVFPEEFRHWLCADPRIGPLFEEMHADLFRADYWRALQNRIREGHVEDVYAYRRRQRFSVRYG'

In [11]:
def get_residue_atom_data(protein, residue_index):
    """Extract atom data for a single residue."""
    atom_indices = protein.residue2atom(residue_index)
    atom_positions = protein.node_position[atom_indices]
    atom_name_ids = protein.atom_name[atom_indices]
    return atom_positions, atom_name_ids

def get_residue_coords(protein, residue_index, target_atoms):
    """Extract coordinates for target atoms of a single residue."""
    atom_positions, atom_name_ids = get_residue_atom_data(protein, residue_index)
    
    residue_coords = []
    existing_coords = []
    for target_atom in target_atoms:
        target_atom_id = data.Protein.atom_name2id[target_atom]
        matching_positions = atom_positions[(atom_name_ids == target_atom_id).nonzero()[0][0]]
        
        if matching_positions.numel() > 0:
            coord = matching_positions.cpu().numpy()
            residue_coords.append(coord)
            existing_coords.append(coord)
        else:
            residue_coords.append(None)
    
    # Fallback to mean coordinate if some atoms are missing
    if len(existing_coords) > 0:
        mean_coord = np.mean(existing_coords, axis=0)
        residue_coords = [coord if coord is not None else mean_coord for coord in residue_coords]
    else:
        raise ValueError(f"No coordinates found for residue {residue_index+1}")
    
    return residue_coords

def parse_protein_to_json_record(protein, name):
    """Convert a torchprotein Protein structure to coordinates of target atoms from all AAs

    Args:
        protein: a torchprotein.Protein object representing the protein structure
        name: String. Name of the protein

    Return:
        Dictionary with the protein sequence, atom 3D coordinates and name.
    """
    output = {}
    
    # Get AA sequence
    output["seq"] = protein.to_sequence()
    
    # Get atom coordinates
    coords = []
    target_atoms = ["N", "CA", "C", "O"]
    for residue_index in range(protein.num_residue):
        residue_coords = get_residue_coords(protein, residue_index, target_atoms)
        coords.append(residue_coords)
    
    coords = np.asarray(coords)
    
    output["coords"] = coords
    output["name"] = name
    
    return output

record = parse_protein_to_json_record(protein, name='')

In [12]:
import time

start_time = time.time()
result = protein.atom_name[protein.atom2residue == 1]
end_time = time.time()

execution_time = end_time - start_time
print(f"Execution time: {execution_time:.6f} seconds")
print("Result:", result)

Execution time: 0.001078 seconds
Result: tensor([17,  1,  0, 26], device='cuda:0')


In [13]:
protein.node_position[protein.atom2residue == 1]

tensor([[ 33.5700,  19.2910, -10.2110],
        [ 34.5020,  18.2620, -10.6680],
        [ 35.8620,  18.3450,  -9.9930],
        [ 36.3470,  17.3590,  -9.4360]], device='cuda:0')

In [14]:
from lib.custom_models import GVPWrapModel

model = GVPWrapModel(node_in_dim=(6, 3), node_h_dim=(128, 128), edge_in_dim=(
    32, 1), edge_h_dim=(128, 128), num_layers=3, drop_rate=0.1, gpu=0)
model.to('cuda')

GVPWrapModel(
  (W_v): Sequential(
    (0): GVP(
      (wh): Linear(in_features=3, out_features=128, bias=False)
      (ws): Linear(in_features=134, out_features=128, bias=True)
      (wv): Linear(in_features=128, out_features=128, bias=False)
    )
    (1): LayerNorm(
      (scalar_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
  )
  (W_e): Sequential(
    (0): GVP(
      (wh): Linear(in_features=1, out_features=128, bias=False)
      (ws): Linear(in_features=160, out_features=128, bias=True)
      (wv): Linear(in_features=128, out_features=128, bias=False)
    )
    (1): LayerNorm(
      (scalar_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
  )
  (encoder_layers): ModuleList(
    (0): GVPConvLayer(
      (conv): GVPConv()
      (norm): ModuleList(
        (0): LayerNorm(
          (scalar_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        )
        (1): LayerNorm(
          (scalar_norm): LayerNorm((128,), eps=1e-05, elementw

In [15]:
from lib.pipeline import Pipeline

pipeline_kwargs = {
    'model': model,
    'dataset': 'atpbind3d-minimal',
    'gpus': [0],
    'model_kwargs': {},
    'optimizer_kwargs': {'lr': 1e-4},
    'task_kwargs': {
        'normalization': False,
        'num_mlp_layer': 2,
        'metric': ['mcc'],
        'node_feature_type': 'gvp_data',
    },
    'batch_size': 1,
    'gradient_interval': 1,
    'verbose': False,
    'valid_fold_num': 0,
    'dataset_kwargs': {
        'to_slice': True,
        'max_slice_length': 550,
        'padding': 100
    },
    'num_mlp_layer': 2,
}

pipeline = Pipeline(**pipeline_kwargs)
task = pipeline.task

# Test the task
task.predict(batch)

init pipeline, model: GVPWrapModel(
  (W_v): Sequential(
    (0): GVP(
      (wh): Linear(in_features=3, out_features=128, bias=False)
      (ws): Linear(in_features=134, out_features=128, bias=True)
      (wv): Linear(in_features=128, out_features=128, bias=False)
    )
    (1): LayerNorm(
      (scalar_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
  )
  (W_e): Sequential(
    (0): GVP(
      (wh): Linear(in_features=1, out_features=128, bias=False)
      (ws): Linear(in_features=160, out_features=128, bias=True)
      (wv): Linear(in_features=128, out_features=128, bias=False)
    )
    (1): LayerNorm(
      (scalar_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
  )
  (encoder_layers): ModuleList(
    (0): GVPConvLayer(
      (conv): GVPConv()
      (norm): ModuleList(
        (0): LayerNorm(
          (scalar_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        )
        (1): LayerNorm(
          (scalar_norm): LayerNorm((128,

tensor([[0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8254],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8339],
        [0.8341],
        [0.8330],
        [0.8345],
        [0.8339],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0.8345],
        [0

In [16]:
pipeline_kwargs = {
    'model': 'gearnet',
    'dataset': 'atpbind3d-minimal',
    'gpus': [0],
    'model_kwargs': {
        'input_dim': 21,
        'hidden_dims': [512] * 4,
        'gpu': 0,
    },
    'optimizer_kwargs': {'lr': 2e-3},
    'task_kwargs': {
        'normalization': False,
        'num_mlp_layer': 2,
        'metric': ['mcc'],
    },
    'batch_size': 1,
    'gradient_interval': 1,
    'verbose': False,
    'valid_fold_num': 0,
    'dataset_kwargs': {
        'to_slice': True,
        'max_slice_length': 550,
        'padding': 100
    },
    'num_mlp_layer': 2,
}
pipeline = Pipeline(**pipeline_kwargs)

data.dataloader.graph_collate = graph_collate_with_gvp
train_record = pipeline.train_until_fit(max_epoch=10, patience=10)

init pipeline, model: gearnet, dataset: atpbind3d-minimal, gpus: [0]
load model gearnet, kwargs: {'input_dim': 21, 'hidden_dims': [512, 512, 512, 512], 'gpu': 0}
get dataset with kwargs: {'to_slice': True, 'max_slice_length': 550, 'padding': 100}
Initialize Undersampling: all ones
Initialize Weighting: all ones
train samples: 4, valid samples: 2, test samples: 5
no scheduler
pipeline batch_size: 1, gradient_interval: 1
0m2s {'mcc': -0.0068, 'valid_mcc': 0.0, 'train_bce': 0.9336, 'valid_bce': 38.4902, 'best_threshold': -3.0}
0m1s {'mcc': -0.0136, 'valid_mcc': 0.0, 'train_bce': 0.4173, 'valid_bce': 4.5143, 'best_threshold': -1.8}
0m1s {'mcc': 0.0, 'valid_mcc': 0.0, 'train_bce': 0.1301, 'valid_bce': 0.5955, 'best_threshold': 0.3}
0m1s {'mcc': 0.1262, 'valid_mcc': 0.1005, 'train_bce': 0.0937, 'valid_bce': 0.3288, 'best_threshold': -1.5}
0m1s {'mcc': 0.1081, 'valid_mcc': 0.0767, 'train_bce': 0.022, 'valid_bce': 0.3913, 'best_threshold': -0.8}
0m1s {'mcc': 0.1094, 'valid_mcc': 0.0641, 'train

In [32]:
import logging
pipeline_kwargs = {
    'model': 'gvp',
    'dataset': 'atpbind3d-minimal',
    'gpus': [0],
    'model_kwargs': {
        'node_in_dim': (6, 3),
        'node_h_dim': (100, 16),
        'edge_in_dim': (32, 1),
        'edge_h_dim': (32, 1),
        'num_layers': 1,
        'drop_rate': 0.1,
        'output_dim': 20,
        'gpu': 0,
    },
    'task_kwargs': {
        'normalization': False,
        'num_mlp_layer': 1,
        'metric': ['mcc'],
        'node_feature_type': 'gvp_data',
    },
    'batch_size': 1,
    'gradient_interval': 1,
    'verbose': False,
    'valid_fold_num': 0,
    'dataset_kwargs': {
        'to_slice': True,
        'max_slice_length': 550,
        'padding': 100
    },
    'num_mlp_layer': 2,
}
logging.basicConfig(level=logging.DEBUG)
pipeline = Pipeline(**pipeline_kwargs)

data.dataloader.graph_collate = graph_collate_with_gvp
train_record = pipeline.train_until_fit(max_epoch=10, patience=10)

init pipeline, model: gvp, dataset: atpbind3d-minimal, gpus: [0]
load model gvp, kwargs: {'node_in_dim': (6, 3), 'node_h_dim': (100, 16), 'edge_in_dim': (32, 1), 'edge_h_dim': (32, 1), 'num_layers': 1, 'drop_rate': 0.1, 'output_dim': 20, 'gpu': 0}
get dataset with kwargs: {'to_slice': True, 'max_slice_length': 550, 'padding': 100}
Initialize Undersampling: all ones
Initialize Weighting: all ones
train samples: 4, valid samples: 2, test samples: 5
no scheduler
pipeline batch_size: 1, gradient_interval: 1
0m1s {'mcc': 0.1044, 'valid_mcc': 0.1392, 'train_bce': 0.4086, 'valid_bce': 0.1612, 'best_threshold': -2.7}
0m1s {'mcc': 0.1007, 'valid_mcc': 0.144, 'train_bce': 0.2615, 'valid_bce': 0.1574, 'best_threshold': -3.0}
0m1s {'mcc': 0.0542, 'valid_mcc': 0.1363, 'train_bce': 0.2624, 'valid_bce': 0.1593, 'best_threshold': -3.0}
0m1s {'mcc': 0.0871, 'valid_mcc': 0.1718, 'train_bce': 0.2591, 'valid_bce': 0.1663, 'best_threshold': -2.7}
0m1s {'mcc': 0.0932, 'valid_mcc': 0.1439, 'train_bce': 0.254

In [48]:
from lib.utils import protein_to_sequence


def protein_to_sequence(protein):
    seq = protein.to_sequence()
    if isinstance(seq, list):
        seq = seq[0]
    return ''.join(i for i in seq if i != '.')

# Extract the task object from the pipeline
task = pipeline.task

# Get a batch from the train set
train_loader = data.DataLoader(
    pipeline.train_set,
    batch_size=pipeline.batch_size,
    shuffle=False,
    collate_fn=graph_collate_with_gvp
)

# Get the first batch
batch = next(iter(train_loader))

# Move the batch to the correct device
batch = utils.cuda(batch, device=torch.device(f'cuda:{pipeline.gpus[0]}'))
print(f'sequence: {protein_to_sequence(batch["graph"])[29:39]}')
print(f'prediction: {task.predict(batch)[29:39]}')
print(f'label: {batch["graph"].target[29:39]}')

sequence: IGKGFEDLMT
prediction: tensor([[-3.4665],
        [-1.8042],
        [-1.8622],
        [-2.1003],
        [-2.7256],
        [-2.9655],
        [-3.0620],
        [-3.5202],
        [-2.1327],
        [-2.1367]], device='cuda:0', grad_fn=<SliceBackward0>)
label: tensor([[1],
        [1],
        [1],
        [1],
        [1],
        [0],
        [0],
        [0],
        [1],
        [0]], device='cuda:0')


In [142]:
import torch
for i in range(100):
    task.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    optimizer.zero_grad()
    loss, metric = task(batch)
    loss.backward()
    optimizer.step()

In [161]:
task.eval()
print(f'sequence: {protein_to_sequence(batch["graph"])[29:39]}')
print(f'prediction: {task.predict(batch)[29:39]}')
print(f'label: {batch["graph"].target[29:39]}')

sequence: IGKGFEDLMT
prediction: tensor([[-3.4665],
        [-1.8042],
        [-1.8622],
        [-2.1003],
        [-2.7256],
        [-2.9655],
        [-3.0620],
        [-3.5202],
        [-2.1327],
        [-2.1367]], device='cuda:0', grad_fn=<SliceBackward0>)
label: tensor([[1],
        [1],
        [1],
        [1],
        [1],
        [0],
        [0],
        [0],
        [1],
        [0]], device='cuda:0')
