In [1]:
import numpy as np
import torch

# Check GPU availability and pick the compute device. The GPU name can only be
# queried when CUDA is present, so that call is guarded; otherwise we fall back
# to the CPU instead of crashing.
cuda_available = torch.cuda.is_available()
print('CUDA available: {}'.format(cuda_available))
if cuda_available:
    print('Current GPU: {}'.format(torch.cuda.get_device_name(torch.cuda.current_device())))
device = torch.device('cuda' if cuda_available else 'cpu')

CUDA available: True
Current GPU: GeForce RTX 2080 SUPER


In [2]:
import deepchem as dc

# Load the pre-featurized train/valid/test splits previously saved with DeepChem.
loaded, datasets, transformers = dc.utils.load_dataset_from_disk('data/combined/amg')
if not loaded:
    # Per DeepChem's docs, a missing/incomplete save directory is reported as
    # loaded=False with datasets=None; fail with a clear message rather than a
    # "cannot unpack non-iterable NoneType" TypeError on the next line.
    raise FileNotFoundError(
        "No saved DeepChem train/valid/test splits found under 'data/combined/amg'")
train_dataset, valid_dataset, test_dataset = datasets
print('train/val/test split: {}/{}/{}'.format(
    len(train_dataset), len(valid_dataset), len(test_dataset)))

train/val/test split: 483/60/61


In [3]:
# Derive model input/output sizes from the first training example.
first_graph = train_dataset.X[0]
num_node_features = first_graph.num_node_features
num_edge_features = first_graph.num_edge_features
num_classes = train_dataset.y[0].shape[-1]  # number of prediction targets (last axis of a label row)
print(num_node_features, num_classes)

79 1


In [4]:
from torch_geometric.data import Data, DataLoader

def get_data_loader(dc_dataset, batch_size=64, shuffle=True):
    """Wrap a DeepChem graph dataset as a PyTorch Geometric DataLoader.

    Each element of ``dc_dataset.X`` is converted with ``to_pyg_graph()``; the
    matching label row and sample-weight row are attached as ``y`` and ``w``,
    reshaped to ``(1, -1)`` so batching stacks them row-wise per graph.

    Args:
        dc_dataset: DeepChem dataset whose ``X`` items provide ``to_pyg_graph()``.
        batch_size: number of graphs per mini-batch.
        shuffle: reshuffle graphs every epoch (use False for evaluation loaders).

    Returns:
        A DataLoader over the converted graphs.
    """
    # Read X, y and w exactly once. On a DeepChem DiskDataset these properties
    # rebuild the full array from every shard on each access, so indexing
    # ``dc_dataset.y[i]`` inside the loop made conversion O(n^2).
    labels = dc_dataset.y
    weights = dc_dataset.w
    graphs = [x.to_pyg_graph() for x in dc_dataset.X]
    for i, graph in enumerate(graphs):
        graph.y = torch.from_numpy(labels[i].reshape(1, -1))
        graph.w = torch.from_numpy(weights[i].reshape(1, -1))
    return DataLoader(graphs, batch_size=batch_size, shuffle=shuffle)

# Only the training loader is shuffled. Evaluation loaders keep a fixed order so
# batch-dependent quantities (e.g. get_loss with a ranking-based loss) are
# reproducible between calls and across epochs.
train_loader = get_data_loader(train_dataset, batch_size=64, shuffle=True)
val_loader = get_data_loader(valid_dataset, batch_size=64, shuffle=False)
test_loader = get_data_loader(test_dataset, batch_size=64, shuffle=False)

# Sanity-check batching: report the size and tensor shapes of each training batch.
for step, batch in enumerate(train_loader, start=1):
    print('Step {}:'.format(step))
    print('=======')
    print('Number of graphs in the current batch: {}'.format(batch.num_graphs))
    print(batch, '\n')



Step 1:
Number of graphs in the current batch: 64
Batch(batch=[1484], edge_attr=[3200, 12], edge_index=[2, 3200], ptr=[65], w=[64, 1], x=[1484, 79], y=[64, 1]) 

Step 2:
Number of graphs in the current batch: 64
Batch(batch=[1330], edge_attr=[2860, 12], edge_index=[2, 2860], ptr=[65], w=[64, 1], x=[1330, 79], y=[64, 1]) 

Step 3:
Number of graphs in the current batch: 64
Batch(batch=[1414], edge_attr=[3040, 12], edge_index=[2, 3040], ptr=[65], w=[64, 1], x=[1414, 79], y=[64, 1]) 

Step 4:
Number of graphs in the current batch: 64
Batch(batch=[1255], edge_attr=[2698, 12], edge_index=[2, 2698], ptr=[65], w=[64, 1], x=[1255, 79], y=[64, 1]) 

Step 5:
Number of graphs in the current batch: 64
Batch(batch=[1599], edge_attr=[3432, 12], edge_index=[2, 3432], ptr=[65], w=[64, 1], x=[1599, 79], y=[64, 1]) 

Step 6:
Number of graphs in the current batch: 64
Batch(batch=[1515], edge_attr=[3246, 12], edge_index=[2, 3246], ptr=[65], w=[64, 1], x=[1515, 79], y=[64, 1]) 

Step 7:
Number of graphs in 

In [5]:
from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.metrics import precision_score, recall_score

def train(model, optimizer, criterion, loader=None):
    """Run one training epoch.

    Args:
        model: graph model called as ``model(x, edge_index, edge_attr, batch)``.
        optimizer: torch optimizer over (a subset of) ``model``'s parameters.
        criterion: loss callable invoked as ``criterion(out, y, weights=w)``;
            it must return a scalar tensor (``backward()`` is called on it).
        loader: iterable of PyG batches. Defaults to the notebook-level
            ``train_loader`` so existing ``train(model, optimizer, criterion)``
            calls keep working.

    Returns:
        Mean of the per-batch training losses for this epoch (measured before
        each optimizer step), or ``nan`` if ``loader`` yields no batches.
    """
    if loader is None:
        loader = train_loader
    model.train()

    batch_losses = []
    for data in loader:  # iterate in mini-batches over the training data
        data = data.to(device)

        # Clear gradients *before* the forward/backward pass so stale gradients
        # (e.g. from earlier cells) can never leak into the first update.
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.edge_attr, data.batch)
        loss = criterion(out, data.y, weights=data.w)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())
    return float(np.mean(batch_losses)) if batch_losses else float('nan')

def test(model, loader):
    """Evaluate threshold-free ranking metrics of ``model`` on ``loader``.

    Predictions and labels from every batch are concatenated, then scored with
    scikit-learn's ``roc_auc_score`` and ``average_precision_score``.

    Args:
        model: graph model called as ``model(x, edge_index, edge_attr, batch)``.
        loader: iterable of PyG batches carrying ``y`` labels.

    Returns:
        ``(roc_auc, ave_prec)`` over the whole loader.
    """
    model.eval()

    outs = []
    ys = []
    # Inference only: disable autograd so no computation graph is built/kept.
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = model(data.x, data.edge_index, data.edge_attr, data.batch)
            outs.append(out.cpu().numpy())
            ys.append(data.y.cpu().numpy())

    pred = np.concatenate(outs, axis=0)
    y = np.concatenate(ys, axis=0)

    roc_auc = roc_auc_score(y, pred)
    ave_prec = average_precision_score(y, pred)
    return roc_auc, ave_prec

def test_precision_recall(model, loader, pred_th=0.5):
    """Precision and recall of ``model`` on ``loader`` at a fixed decision threshold.

    Args:
        model: graph model called as ``model(x, edge_index, edge_attr, batch)``.
        loader: iterable of PyG batches carrying ``y`` labels.
        pred_th: an output strictly greater than this counts as a positive.
            NOTE(review): the raw model output is thresholded as-is; whether it
            is a probability or a logit depends on the model's forward() --
            confirm before reading 0.5 as "50% probability".

    Returns:
        ``(precision, recall)``; each is reported as 0 when undefined
        (``zero_division=0``).
    """
    model.eval()

    outs = []
    ys = []
    # Inference only: disable autograd so no computation graph is built/kept.
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = model(data.x, data.edge_index,
                        data.edge_attr, data.batch)
            outs.append(out.cpu().numpy())
            ys.append(data.y.cpu().numpy())

    pred = np.concatenate(outs, axis=0)
    y = np.concatenate(ys, axis=0).astype(int)
    binary_pred = (pred > pred_th).astype(int)

    precision = precision_score(y, binary_pred, zero_division=0)
    recall = recall_score(y, binary_pred, zero_division=0)
    return precision, recall

def get_loss(model, loader, criterion=None):
    """Mean per-batch loss of ``model`` over ``loader`` without updating it.

    Args:
        model: graph model called as ``model(x, edge_index, edge_attr, batch)``.
        loader: iterable of PyG batches carrying ``y`` labels and ``w`` weights.
        criterion: loss callable invoked as ``criterion(out, y, weights=w)``,
            the same convention ``train`` uses. Defaults to the notebook-level
            ``criterion`` global, so ``get_loss(model, loader)`` still works.

    Returns:
        Unweighted mean of the per-batch losses (every batch counts equally,
        regardless of its size), or ``nan`` if ``loader`` yields no batches.
    """
    if criterion is None:
        # Backward-compatible default: the global defined in a later cell.
        criterion = globals()['criterion']
    model.eval()

    losses = []
    # Evaluation only: disable autograd so no computation graph is built/kept.
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = model(data.x, data.edge_index, data.edge_attr, data.batch)
            loss = criterion(out, data.y, weights=data.w)
            losses.append(loss.detach().cpu().numpy())
    return float(np.mean(losses)) if losses else float('nan')

In [6]:
from models import AttentiveFP

# Build AttentiveFP: input, edge and output sizes come from the dataset; the
# remaining entries are fixed architecture hyperparameters.
afp_hparams = dict(
    in_channels=num_node_features,
    hidden_channels=512,
    out_channels=num_classes,
    edge_dim=num_edge_features,
    num_layers=4,
    num_timesteps=3,
    dropout=0.4,
)
model = AttentiveFP(**afp_hparams).to(device)
print(model)

AttentiveFP(
  (lin1): Linear(in_features=79, out_features=512, bias=True)
  (atom_convs): ModuleList(
    (0): GATEConv(
      (lin1): Linear(in_features=524, out_features=512, bias=False)
      (lin2): Linear(in_features=512, out_features=512, bias=False)
    )
    (1): GATConv(512, 512, heads=1)
    (2): GATConv(512, 512, heads=1)
    (3): GATConv(512, 512, heads=1)
  )
  (atom_grus): ModuleList(
    (0): GRUCell(512, 512)
    (1): GRUCell(512, 512)
    (2): GRUCell(512, 512)
    (3): GRUCell(512, 512)
  )
  (atom_norms): ModuleList(
    (0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (mol_conv): GATConv(512, 512, heads=1)
  (mol_gru): GRUCell(512, 512)
  (lin2): Linear(in_features=512, out_features=1, bias=True)
)


In [7]:
# Transfer-learning step: copy every pretrained tensor whose name exists in the
# freshly built model and whose shape matches; everything else keeps its
# current (random) initialisation.
PRETRAINED_PATH = 'pretrained_models/AFP_0.72710'  # alternative: 'pretrained_models/AFP_0.71991'

model_state_dict = model.state_dict()

# NOTE(review): torch.load unpickles a whole nn.Module, which can execute
# arbitrary code -- only load checkpoints from trusted sources.
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
pretrained_model = torch.load(PRETRAINED_PATH, map_location=device)
pretrained_state_dict = pretrained_model.state_dict()
sel_pretrained_state_dict = {
    k: v for k, v in pretrained_state_dict.items()
    # `k in ...` guards against pretrained-only keys, which previously raised KeyError.
    if k in model_state_dict and model_state_dict[k].size() == v.size()
}
skipped_keys = sorted(set(pretrained_state_dict) - set(sel_pretrained_state_dict))
if skipped_keys:
    print('Skipped {} pretrained tensors (missing or shape mismatch): {}'.format(
        len(skipped_keys), skipped_keys))
model_state_dict.update(sel_pretrained_state_dict)
model.load_state_dict(model_state_dict)

<All keys matched successfully>

In [8]:
# Freeze the pretrained backbone so only the final linear head (lin2) is trained:
# first switch every parameter off, then switch the head's parameters back on.
for module, trainable in ((model, False), (model.lin2, True)):
    for param in module.parameters():
        param.requires_grad = trainable

In [9]:
from loss import PrecisionAtRecallLoss, PrecisionRecallAUCLoss, PrecisionRecallAUCLoss2

# Stage 1: fit only the trainable head (lin2) on top of the frozen backbone,
# optimising a PR-AUC surrogate loss with RMSprop.
# Earlier experiments (removed): Adam(lr=3e-5), BCEWithLogitsLoss,
# PrecisionAtRecallLoss and PrecisionRecallAUCLoss with other targets/ranges.
optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-4)
criterion = PrecisionRecallAUCLoss2(num_classes=num_classes,
                                    recall_range=(0.0, 1.0)).to(device)

head_epochs = 30
for epoch in range(1, head_epochs + 1):
    train(model, optimizer, criterion)

    # Losses on both splits (no parameter updates).
    tr_loss = get_loss(model, train_loader)
    val_loss = get_loss(model, val_loader)
    print('Epoch: {:03d}, Train loss: {:.5f}, Val loss: {:.5f}'.format(
        epoch, tr_loss, val_loss))

    # Threshold-free ranking metrics.
    tr_roc_auc, tr_ave_prec = test(model, train_loader)
    val_roc_auc, val_ave_prec = test(model, val_loader)
    print('Train - roc_auc:{:.4f}, ave_prec:{:.4f}, Val - roc_auc:{:.4f}, ave_prec:{:.4f}'.format(
        tr_roc_auc, tr_ave_prec, val_roc_auc, val_ave_prec))

    # Precision / recall at the default decision threshold.
    tr_precision, tr_recall = test_precision_recall(model, train_loader)
    val_precision, val_recall = test_precision_recall(model, val_loader)
    print('Train - precision:{:.4f}, recall:{:.4f}, Val - precision:{:.4f}, recall:{:.4f}'.format(
        tr_precision, tr_recall, val_precision, val_recall))

# Held-out test split after the head-only stage.
te_loss = get_loss(model, test_loader)
print('Test loss: {:.5f}'.format(te_loss))
te_roc_auc, te_ave_prec = test(model, test_loader)
print('Test - roc_auc:{:.4f}, ave_prec:{:.4f}'.format(te_roc_auc, te_ave_prec))
te_precision, te_recall = test_precision_recall(model, test_loader)
print('Test - precision:{:.4f}, recall:{:.4f}'.format(te_precision, te_recall))

Epoch: 001, Train loss: 1.21446, Val loss: 1.09994
Train - roc_auc:0.7037, ave_prec:0.2062, Val - roc_auc:0.8895, ave_prec:0.5831
Train - precision:0.1000, recall:0.0185, Val - precision:0.0000, recall:0.0000
Epoch: 002, Train loss: 1.11450, Val loss: 1.06081
Train - roc_auc:0.7460, ave_prec:0.2268, Val - roc_auc:0.9353, ave_prec:0.6560
Train - precision:0.2083, recall:0.0926, Val - precision:0.5000, recall:0.1429
Epoch: 003, Train loss: 1.10612, Val loss: 1.02447
Train - roc_auc:0.7695, ave_prec:0.2406, Val - roc_auc:0.9488, ave_prec:0.6669
Train - precision:0.2292, recall:0.2037, Val - precision:0.7500, recall:0.4286
Epoch: 004, Train loss: 1.10180, Val loss: 1.00316
Train - roc_auc:0.7789, ave_prec:0.2481, Val - roc_auc:0.9515, ave_prec:0.6561
Train - precision:0.3158, recall:0.3333, Val - precision:0.7500, recall:0.4286
Epoch: 005, Train loss: 1.06779, Val loss: 0.98100
Train - roc_auc:0.7851, ave_prec:0.2523, Val - roc_auc:0.9488, ave_prec:0.6235
Train - precision:0.2933, recall:0

In [10]:
# Re-score every split with a stricter decision threshold (train, val, test in that order).
strict_th = 0.8
(tr_precision, tr_recall), (val_precision, val_recall), (te_precision, te_recall) = [
    test_precision_recall(model, split_loader, pred_th=strict_th)
    for split_loader in (train_loader, val_loader, test_loader)
]
print('Train - precision:{:.4f}, recall:{:.4f}, Val - precision:{:.4f}, recall:{:.4f}'.format(
    tr_precision, tr_recall, val_precision, val_recall))
print('Test - precision:{:.4f}, recall:{:.4f}'.format(te_precision, te_recall))

Train - precision:0.3387, recall:0.3889, Val - precision:0.5000, recall:0.2857
Test - precision:0.4000, recall:0.2857


In [11]:
# Unfreeze the whole network for end-to-end fine-tuning.
for param in model.parameters():
    param.requires_grad_(True)

In [12]:
# Stage 2: fine-tune all layers with a much smaller learning rate, keeping the
# same PR-AUC surrogate loss as stage 1.
# Earlier experiments (removed): BCEWithLogitsLoss, PrecisionAtRecallLoss and
# PrecisionRecallAUCLoss with other targets/ranges.
optimizer = torch.optim.RMSprop(model.parameters(), lr=3e-6)
criterion = PrecisionRecallAUCLoss2(num_classes=num_classes,
                                    recall_range=(0.0, 1.0)).to(device)

finetune_epochs = 100
for epoch in range(1, finetune_epochs + 1):
    train(model, optimizer, criterion)

    # Losses on both splits (no parameter updates).
    tr_loss = get_loss(model, train_loader)
    val_loss = get_loss(model, val_loader)
    print('Epoch: {:03d}, Train loss: {:.5f}, Val loss: {:.5f}'.format(
        epoch, tr_loss, val_loss))

    # Threshold-free ranking metrics.
    tr_roc_auc, tr_ave_prec = test(model, train_loader)
    val_roc_auc, val_ave_prec = test(model, val_loader)
    print('Train - roc_auc:{:.4f}, ave_prec:{:.4f}, Val - roc_auc:{:.4f}, ave_prec:{:.4f}'.format(
        tr_roc_auc, tr_ave_prec, val_roc_auc, val_ave_prec))

    # Precision / recall at the default decision threshold.
    tr_precision, tr_recall = test_precision_recall(model, train_loader)
    val_precision, val_recall = test_precision_recall(model, val_loader)
    print('Train - precision:{:.4f}, recall:{:.4f}, Val - precision:{:.4f}, recall:{:.4f}'.format(
        tr_precision, tr_recall, val_precision, val_recall))

# Held-out test split after full fine-tuning.
te_loss = get_loss(model, test_loader)
print('Test loss: {:.5f}'.format(te_loss))
te_roc_auc, te_ave_prec = test(model, test_loader)
print('Test - roc_auc:{:.4f}, ave_prec:{:.4f}'.format(te_roc_auc, te_ave_prec))
te_precision, te_recall = test_precision_recall(model, test_loader)
print('Test - precision:{:.4f}, recall:{:.4f}'.format(te_precision, te_recall))

Epoch: 001, Train loss: 0.89846, Val loss: 0.74855
Train - roc_auc:0.8248, ave_prec:0.3139, Val - roc_auc:0.9650, ave_prec:0.7512
Train - precision:0.3053, recall:0.5370, Val - precision:0.5833, recall:1.0000
Epoch: 002, Train loss: 0.89706, Val loss: 0.72863
Train - roc_auc:0.8369, ave_prec:0.3379, Val - roc_auc:0.9677, ave_prec:0.7639
Train - precision:0.3200, recall:0.5926, Val - precision:0.5833, recall:1.0000
Epoch: 003, Train loss: 0.87955, Val loss: 0.71016
Train - roc_auc:0.8424, ave_prec:0.3564, Val - roc_auc:0.9704, ave_prec:0.7854
Train - precision:0.3232, recall:0.5926, Val - precision:0.5833, recall:1.0000
Epoch: 004, Train loss: 0.87211, Val loss: 0.70264
Train - roc_auc:0.8468, ave_prec:0.3647, Val - roc_auc:0.9704, ave_prec:0.7854
Train - precision:0.3173, recall:0.6111, Val - precision:0.5833, recall:1.0000
Epoch: 005, Train loss: 0.88644, Val loss: 0.68930
Train - roc_auc:0.8485, ave_prec:0.3730, Val - roc_auc:0.9704, ave_prec:0.7854
Train - precision:0.3300, recall:0

Epoch: 041, Train loss: 0.66125, Val loss: 0.51131
Train - roc_auc:0.9075, ave_prec:0.5172, Val - roc_auc:0.9811, ave_prec:0.8806
Train - precision:0.4149, recall:0.7222, Val - precision:0.6667, recall:0.8571
Epoch: 042, Train loss: 0.67496, Val loss: 0.50425
Train - roc_auc:0.9090, ave_prec:0.5129, Val - roc_auc:0.9811, ave_prec:0.8806
Train - precision:0.4211, recall:0.7407, Val - precision:0.6667, recall:0.8571
Epoch: 043, Train loss: 0.65283, Val loss: 0.50957
Train - roc_auc:0.9104, ave_prec:0.5224, Val - roc_auc:0.9811, ave_prec:0.8806
Train - precision:0.4272, recall:0.8148, Val - precision:0.6667, recall:0.8571
Epoch: 044, Train loss: 0.64932, Val loss: 0.50153
Train - roc_auc:0.9101, ave_prec:0.5162, Val - roc_auc:0.9811, ave_prec:0.8806
Train - precision:0.4343, recall:0.7963, Val - precision:0.6667, recall:0.8571
Epoch: 045, Train loss: 0.65056, Val loss: 0.50115
Train - roc_auc:0.9112, ave_prec:0.5167, Val - roc_auc:0.9811, ave_prec:0.8806
Train - precision:0.4433, recall:0

Epoch: 081, Train loss: 0.50808, Val loss: 0.41801
Train - roc_auc:0.9373, ave_prec:0.5878, Val - roc_auc:0.9784, ave_prec:0.8766
Train - precision:0.5106, recall:0.8889, Val - precision:0.6250, recall:0.7143
Epoch: 082, Train loss: 0.51624, Val loss: 0.42161
Train - roc_auc:0.9377, ave_prec:0.5870, Val - roc_auc:0.9784, ave_prec:0.8766
Train - precision:0.5106, recall:0.8889, Val - precision:0.8333, recall:0.7143
Epoch: 083, Train loss: 0.49490, Val loss: 0.41323
Train - roc_auc:0.9383, ave_prec:0.5903, Val - roc_auc:0.9784, ave_prec:0.8766
Train - precision:0.5106, recall:0.8889, Val - precision:0.8333, recall:0.7143
Epoch: 084, Train loss: 0.50014, Val loss: 0.41185
Train - roc_auc:0.9401, ave_prec:0.6013, Val - roc_auc:0.9784, ave_prec:0.8766
Train - precision:0.4948, recall:0.8889, Val - precision:0.7143, recall:0.7143
Epoch: 085, Train loss: 0.49395, Val loss: 0.41499
Train - roc_auc:0.9403, ave_prec:0.6044, Val - roc_auc:0.9784, ave_prec:0.8766
Train - precision:0.4898, recall:0

In [13]:
# Re-score every split with a stricter decision threshold (train, val, test in that order).
strict_th = 0.8
(tr_precision, tr_recall), (val_precision, val_recall), (te_precision, te_recall) = [
    test_precision_recall(model, split_loader, pred_th=strict_th)
    for split_loader in (train_loader, val_loader, test_loader)
]
print('Train - precision:{:.4f}, recall:{:.4f}, Val - precision:{:.4f}, recall:{:.4f}'.format(
    tr_precision, tr_recall, val_precision, val_recall))
print('Test - precision:{:.4f}, recall:{:.4f}'.format(te_precision, te_recall))

Train - precision:0.5375, recall:0.7963, Val - precision:0.8333, recall:0.7143
Test - precision:0.5556, recall:0.7143


In [14]:
# Persist the fine-tuned model as a whole pickled nn.Module (not just its state_dict).
# An earlier run of this cell saved to 'trained_models/AFP_tl_amg_0.85093'.
torch.save(obj=model, f='trained_models/AFP_tl_amg_test')

In [None]:
# Reload a previously saved fine-tuned model and score it on the test split.
# NOTE(review): torch.load unpickles an arbitrary object -- only load trusted files.
# map_location lets a GPU-saved model load on a CPU-only machine.
model = torch.load('trained_models/AFP_tl_amg_0.85093', map_location=device)
# test() returns a (roc_auc, ave_prec) tuple; formatting the tuple directly with
# '{:.5f}' raised TypeError, so unpack it and report both metrics.
te_roc_auc, te_ave_prec = test(model, test_loader)
print('Test - roc_auc:{:.4f}, ave_prec:{:.4f}'.format(te_roc_auc, te_ave_prec))