In [None]:
import nglview
import torchdrug
from torchdrug import data, utils

# Alternative input: download a structure from the RCSB instead of using the local file, e.g.
# pdb_file = utils.download("https://files.rcsb.org/download/2LWZ.pdb", "./")
pdb_file = "enzymes/A0B8R0_relaxed.pdb"

# Parse the PDB file into a torchdrug Protein, requesting the "default" atom and bond
# featurizers, the "symbol" residue featurizer, and no molecule-level feature.
protein = data.Protein.from_pdb(
    pdb_file, atom_feature="default", bond_feature="default", residue_feature="symbol", mol_feature=None
)

# Summary of the parsed protein, followed by the shapes of the residue, atom and
# bond feature tensors (one per line).
print(protein)
print(protein.residue_feature.shape, protein.atom_feature.shape, protein.bond_feature.shape, sep="\n")

In [None]:
# One-hot encode the residue name "LYS" against torchdrug's residue vocabulary.
# NOTE(review): allow_unknown=True presumably reserves an extra slot for names that
# are not in the vocabulary — confirm against data.feature.onehot.
lys = data.feature.onehot(
    "LYS",
    data.feature.residue_vocab,
    allow_unknown=True,
)

In [None]:
import nglview

# Build an nglview widget for the PDB file loaded earlier in the notebook
# (`pdb_file` must already be defined by a previous cell).
view = nglview.show_file(pdb_file)
# Bare last expression: the notebook renders the widget as this cell's output.
view

In [None]:
# Print the atom name and coordinate entry of the first `first_x` atoms (graph nodes).
first_x = 15
print("node: 3d coords")

# Convert to plain Python lists, then keep only the leading `first_x` entries.
atom_ids = protein.atom_name.tolist()[:first_x]
positions = protein.node_position.tolist()[:first_x]

for atom, position in zip(atom_ids, positions):
    # Map the integer atom-name id back to its string form for display.
    atom_label = data.Protein.id2atom_name[atom]
    print("%s: %s" % (atom_label, position))

In [None]:
from torchdrug import datasets

# Load the BetaLactamase dataset from ~/protein-datasets/ with residue-level
# features only (atom and bond features disabled, no transform).
dataset = datasets.BetaLactamase(
    "~/protein-datasets/",
    atom_feature=None,
    bond_feature=None,
    residue_feature="default",
    transform=None,
)
train_set, valid_set, test_set = dataset.split()

# Show the first target field of the first sample, then the size of each split.
first_sample = dataset[0]
first_target = dataset.target_fields[0]
print("The label of first sample: ", first_sample[first_target])

split_sizes = (len(train_set), len(valid_set), len(test_set))
print("train samples: %d, valid samples: %d, test samples: %d" % split_sizes)

In [None]:
from torchdrug import tasks
from torchdrug import models

# Sequence encoder: a ProteinCNN with two 1024-wide hidden layers and max readout.
# NOTE(review): input_dim=21 presumably matches the width of the "default" residue
# feature used when loading the dataset — confirm if the featurizer changes.
model = models.ProteinCNN(
    input_dim=21,
    hidden_dims=[1024, 1024],
    kernel_size=5,
    padding=2,
    readout="max",
)

# Property-prediction task wrapping the encoder: MSE training criterion, evaluated
# with MAE, RMSE, Spearman correlation and R2 on the dataset's task list.
task = tasks.PropertyPrediction(
    model,
    task=dataset.tasks,
    criterion="mse",
    metric=("mae", "rmse", "spearmanr", "r2"),
    normalization=False,
    num_mlp_layer=2,
)

In [None]:
import torch
from torchdrug import core

# Adam over all parameters exposed by the task.
optimizer = torch.optim.Adam(task.parameters(), lr=1e-4)

# Train on GPU 0 when CUDA is available; otherwise pass gpus=None so the engine
# runs on CPU instead of failing on machines without a GPU.
solver = core.Engine(
    task,
    train_set,
    valid_set,
    test_set,
    optimizer,
    gpus=[0] if torch.cuda.is_available() else None,
    batch_size=64
)
solver.train(num_epoch=10)
# Bare last expression: the notebook displays whatever evaluate() returns for the
# validation split.
solver.evaluate("valid")

In [None]:
# Load the BindingDB dataset from ~/protein-datasets/ with residue-level features
# only. NOTE: this rebinds `dataset`, `train_set`, `valid_set` and `test_set`,
# replacing the values assigned by any earlier cell.
dataset = datasets.BindingDB(
    "~/protein-datasets/",
    atom_feature=None,
    bond_feature=None,
    residue_feature="default",
    transform=None,
)

# Use the named "train", "valid" and "holdout_test" splits.
split_names = ["train", "valid", "holdout_test"]
train_set, valid_set, test_set = dataset.split(split_names)

# Show the first target field of the first sample, then the size of each split.
first_sample = dataset[0]
first_target = dataset.target_fields[0]
print("The label of first sample: ", first_sample[first_target])

split_sizes = (len(train_set), len(valid_set), len(test_set))
print("train samples: %d, valid samples: %d, test samples: %d" % split_sizes)