In [1]:
from lib.datasets import ATPBind3D

from torchdrug import transforms
from torchdrug import data, core, layers, tasks, metrics, utils, models
from torchdrug.layers import functional
from torchdrug.core import Registry as R

import torch
from torch.utils import data as torch_data
from torch.nn import functional as F
from lib.tasks import NodePropertyPrediction


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Per-sample preprocessing: TruncateProtein (max_length=350, random=False)
# followed by ProteinView(view='residue'), chained with Compose.
pipeline_steps = [
    transforms.TruncateProtein(random=False, max_length=350),
    transforms.ProteinView(view='residue'),
]
truncuate_transform, protein_view_transform = pipeline_steps
transform = transforms.Compose(pipeline_steps)

# Build the dataset with the composed transform and take its three splits.
dataset = ATPBind3D(transform=transform)
train_set, valid_set, test_set = dataset.split()

# Report how many samples ended up in each split.
split_sizes = (len(train_set), len(valid_set), len(test_set))
print("train samples: %d, valid samples: %d, test samples: %d" % split_sizes)


Split num:  [337, 41, 41]
train samples: 337, valid samples: 41, test samples: 41


In [3]:
from torchdrug import core, layers
from torchdrug.layers import geometry
import torch
from lib.tasks import NodePropertyPrediction

# Graph construction: alpha-carbon nodes, three edge layers (spatial, KNN,
# sequential) and "gearnet" edge features.
node_layers = [geometry.AlphaCarbonNode()]
edge_layers = [
    geometry.SpatialEdge(radius=10.0, min_distance=5),
    geometry.KNNEdge(k=10, min_distance=5),
    geometry.SequentialEdge(max_distance=2),
]
graph_construction_model = layers.GraphConstruction(
    node_layers=node_layers,
    edge_layers=edge_layers,
    edge_feature="gearnet",
)

# GearNet encoder: six hidden dims of 512, 7 relations, concatenated hidden
# states, shortcut connections and sum readout.
gearnet = models.GearNet(
    input_dim=21,
    hidden_dims=[512] * 6,
    num_relation=7,
    edge_input_dim=59,
    num_angle_bin=8,
    batch_norm=True,
    concat_hidden=True,
    short_cut=True,
    readout="sum",
)
# NOTE(review): .cuda() is called without a device index here, while the
# Engine below is given gpus=[3] -- confirm this explicit move is intended.
gearnet = gearnet.cuda()

# Node-level prediction task wrapping the encoder; evaluated with
# micro_auroc and mcc.
task = NodePropertyPrediction(
    gearnet,
    graph_construction_model=graph_construction_model,
    normalization=False,
    num_mlp_layer=2,
    metric=("micro_auroc", "mcc"),
)

# Adam over all task parameters, and the Engine that drives training and
# evaluation on the three splits with batch size 1.
optimizer = torch.optim.Adam(task.parameters(), lr=1e-4)
solver = core.Engine(
    task,
    train_set,
    valid_set,
    test_set,
    optimizer,
    gpus=[3],
    batch_size=1,
)


15:49:29   Preprocess training set
15:49:41   {'batch_size': 1,
 'class': 'core.Engine',
 'gpus': [3],
 'gradient_interval': 1,
 'log_interval': 100,
 'logger': 'logging',
 'num_worker': 0,
 'optimizer': {'amsgrad': False,
               'betas': (0.9, 0.999),
               'class': 'optim.Adam',
               'eps': 1e-08,
               'lr': 0.0001,
               'weight_decay': 0},
 'scheduler': None,
 'task': {'class': 'NodePropertyPrediction',
          'criterion': 'bce',
          'graph_construction_model': {'class': 'layers.GraphConstruction',
                                       'edge_feature': 'gearnet',
                                       'edge_layers': [SpatialEdge(),
                                                       KNNEdge(),
                                                       SequentialEdge()],
                                       'node_layers': [AlphaCarbonNode()]},
          'metric': ('micro_auroc', 'mcc'),
          'model': {'activation': 'relu',

In [4]:

# Accumulator for the evaluation results appended after each training epoch.
# NOTE: this rebinds the name `metrics`, which the import cell also binds to
# the torchdrug `metrics` module; from here on `metrics` is this list.
metrics = list()



In [6]:
# Ten rounds of: train for one epoch, evaluate on the test split, and keep
# the returned scores in `metrics`.
for epoch_round in range(10):
    solver.train(num_epoch=1)
    test_scores = solver.evaluate("test")
    metrics.append(test_scores)

15:49:58   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
15:49:58   Epoch 0 begin
15:49:59   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
15:49:59   binary cross entropy: 0.73706
15:50:24   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
15:50:24   binary cross entropy: 0.190648
15:50:50   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
15:50:50   binary cross entropy: 0.275289
15:51:26   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
15:51:26   binary cross entropy: 0.149468
15:51:39   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
15:51:39   Epoch 0 end
15:51:39   duration: 1.96 mins
15:51:39   speed: 2.86 batch / sec
15:51:39   ETA: 0.00 secs
15:51:39   max GPU memory: 1589.4 MiB
15:51:39   ------------------------------
15:51:39   average binary cross entropy: 0.185216
15:51:39   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
15:51:39   Evaluate on test
15:51:50   ------------------------------
15:51:50   mcc: 0.265743
15:51:50   micro_auroc: 0.775958
15:51:50   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
15:51:50   Epoch 1 begin
15:52:13   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
15:52:13   binary cross entropy:

In [7]:
# Evaluate on the training split; the result is left as the cell's last
# expression so the notebook displays it. It is not appended to `metrics`.
train_scores = solver.evaluate("train")
train_scores

16:11:30   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:11:30   Evaluate on train
16:12:46   ------------------------------
16:12:46   average mcc: 0.80135
16:12:46   average micro_auroc: 0.987594


{'micro_auroc': tensor(0.9876, device='cuda:3'), 'mcc': 0.8013499550227059}

In [None]:
# Ten further rounds of single-epoch training, each followed by a test-split
# evaluation whose scores are appended to `metrics`.
for epoch_round in range(10):
    solver.train(num_epoch=1)
    test_scores = solver.evaluate("test")
    metrics.append(test_scores)

16:12:46   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:12:46   Epoch 10 begin
16:12:57   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:12:57   binary cross entropy: 0.0010001
16:13:30   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:13:30   binary cross entropy: 0.00329947
16:14:02   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:14:02   binary cross entropy: 0.0137138
16:14:34   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:14:34   binary cross entropy: 0.0397973
16:14:36   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:14:36   Epoch 10 end
16:14:36   duration: 3.23 mins
16:14:36   speed: 1.74 batch / sec
16:14:36   ETA: 0.00 secs
16:14:36   max GPU memory: 1589.4 MiB
16:14:36   ------------------------------
16:14:36   average binary cross entropy: 0.01586
16:14:36   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:14:36   Evaluate on test
16:14:46   ------------------------------
16:14:46   mcc: 0.39796
16:14:46   micro_auroc: 0.843575
16:14:46   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:14:46   Epoch 11 begin
16:15:17   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:15:17   binary cross e

In [None]:
# Display the list of test-split evaluation results collected so far.
metrics

In [None]:
# Thirty further rounds of single-epoch training, each followed by a
# test-split evaluation whose scores are appended to `metrics`.
for epoch_round in range(30):
    solver.train(num_epoch=1)
    test_scores = solver.evaluate("test")
    metrics.append(test_scores)

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline


# `metrics` is the list of evaluation dicts collected by the training cells.
# The 'micro_auroc' entries are tensors and are converted to plain floats;
# the 'mcc' entries are used as stored.
micro_auroc = [float(entry['micro_auroc'].cpu().numpy()) for entry in metrics]
mcc = [entry['mcc'] for entry in metrics]

# 1-based x positions, one per collected evaluation.
num_points = len(micro_auroc)
x_values = [position + 1 for position in range(num_points)]

auroc_color = 'tab:blue'
mcc_color = 'tab:red'

# Left y-axis: micro_auroc in blue.
fig, ax1 = plt.subplots()
ax1.set_xlabel('Training Epoch')
ax1.set_ylabel('micro_auroc', color=auroc_color)
ax1.plot(x_values, micro_auroc, color=auroc_color)
ax1.tick_params(axis='y', labelcolor=auroc_color)

# Right y-axis (sharing the x-axis): mcc in red.
ax2 = ax1.twinx()
ax2.set_ylabel('mcc', color=mcc_color)
ax2.plot(x_values, mcc, color=mcc_color)
ax2.tick_params(axis='y', labelcolor=mcc_color)

# Title, layout and render.
plt.title('micro_auroc and mcc over training epoch')
fig.tight_layout()
plt.show()
