In [1]:
import numpy as np
import torch

# Check GPU availability and choose the device used by the rest of the notebook.
cuda_available = torch.cuda.is_available()
print('CUDA available: {}'.format(cuda_available))
if cuda_available:
    # Only query the device name when CUDA exists; on a CPU-only machine
    # torch.cuda.current_device() raises instead of returning a value.
    print('Current GPU: {}'.format(torch.cuda.get_device_name(torch.cuda.current_device())))
else:
    print('Current GPU: none (falling back to CPU)')
device = torch.device('cuda' if cuda_available else 'cpu')

CUDA available: True
Current GPU: GeForce RTX 2080 SUPER


In [2]:
import deepchem as dc

# Load the train/valid/test splits previously saved to disk with DeepChem.
loaded, datasets, transformers = dc.utils.load_dataset_from_disk('data/combined/cisplatin')
if not loaded:
    # Fail with a clear message instead of a confusing unpacking error on the
    # next line when the split directories are missing.
    raise FileNotFoundError(
        "Could not load a saved dataset from 'data/combined/cisplatin'")
train_dataset, valid_dataset, test_dataset = datasets
print('train/val/test split: {}/{}/{}'.format(
    len(train_dataset), len(valid_dataset), len(test_dataset)))

train/val/test split: 490/61/61


In [3]:
# Read the feature / label dimensions off the first training example.
first_graph = train_dataset.X[0]
num_node_features = first_graph.num_node_features
num_edge_features = first_graph.num_edge_features
# Size of the last axis of one label entry.
num_classes = train_dataset.y[0].shape[-1]
print(num_node_features, num_classes)

79 1


In [4]:
from torch_geometric.data import Data, DataLoader

def get_data_loader(dc_dataset, batch_size=64, shuffle=True):
    """Wrap a DeepChem dataset in a ``DataLoader`` of PyG graphs.

    Every entry of ``dc_dataset.X`` is converted with ``to_pyg_graph()`` and
    gets its label and sample weight attached as ``y`` and ``w``, each
    reshaped to a ``(1, -1)`` tensor.

    Args:
        dc_dataset: object exposing ``X`` (items with ``to_pyg_graph()``),
            ``y`` and ``w`` (indexable collections of numpy arrays).
        batch_size: number of graphs per batch.
        shuffle: whether the loader reshuffles the graphs on each iteration.

    Returns:
        A ``DataLoader`` over the converted graphs.
    """
    # Read X / y / w exactly once instead of on every loop iteration.
    # NOTE(review): for disk-backed DeepChem datasets these properties
    # presumably rebuild the full array on each access - confirm.
    features, labels, weights = dc_dataset.X, dc_dataset.y, dc_dataset.w

    graphs = []
    for i, feat in enumerate(features):
        graph = feat.to_pyg_graph()
        graph.y = torch.from_numpy(labels[i].reshape(1, -1))
        graph.w = torch.from_numpy(weights[i].reshape(1, -1))
        graphs.append(graph)
    return DataLoader(graphs, batch_size=batch_size, shuffle=shuffle)

# Only the training loader needs shuffling; the validation and test loaders
# are used purely for evaluation, so they keep a fixed order and give the same
# batch composition on every pass.
train_loader = get_data_loader(train_dataset, batch_size=64, shuffle=True)
val_loader = get_data_loader(valid_dataset, batch_size=64, shuffle=False)
test_loader = get_data_loader(test_dataset, batch_size=64, shuffle=False)

# Sanity check: go through the training loader once and describe every batch.
for step, data in enumerate(train_loader):
    header = 'Step {}:'.format(step + 1)
    print(header, '=======', sep='\n')
    print('Number of graphs in the current batch: {}'.format(data.num_graphs))
    print(data, '\n')



Step 1:
Number of graphs in the current batch: 64
Batch(batch=[1395], edge_attr=[3012, 12], edge_index=[2, 3012], ptr=[65], w=[64, 1], x=[1395, 79], y=[64, 1]) 

Step 2:
Number of graphs in the current batch: 64
Batch(batch=[1444], edge_attr=[3084, 12], edge_index=[2, 3084], ptr=[65], w=[64, 1], x=[1444, 79], y=[64, 1]) 

Step 3:
Number of graphs in the current batch: 64
Batch(batch=[1390], edge_attr=[2930, 12], edge_index=[2, 2930], ptr=[65], w=[64, 1], x=[1390, 79], y=[64, 1]) 

Step 4:
Number of graphs in the current batch: 64
Batch(batch=[1402], edge_attr=[3016, 12], edge_index=[2, 3016], ptr=[65], w=[64, 1], x=[1402, 79], y=[64, 1]) 

Step 5:
Number of graphs in the current batch: 64
Batch(batch=[1468], edge_attr=[3184, 12], edge_index=[2, 3184], ptr=[65], w=[64, 1], x=[1468, 79], y=[64, 1]) 

Step 6:
Number of graphs in the current batch: 64
Batch(batch=[1387], edge_attr=[2982, 12], edge_index=[2, 2982], ptr=[65], w=[64, 1], x=[1387, 79], y=[64, 1]) 

Step 7:
Number of graphs in 

In [5]:
from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.metrics import precision_score, recall_score

def train(model, optimizer, criterion):
    """Run one pass over the module-level ``train_loader``, updating ``model``.

    For each batch: forward pass, weighted loss, backward pass, optimizer
    step, then the gradients are cleared for the next batch.

    Args:
        model: callable as ``model(x, edge_index, edge_attr, batch)``.
        optimizer: optimizer exposing ``step()`` and ``zero_grad()``.
        criterion: callable as ``criterion(out, y, weights=w)`` returning a
            value that supports ``backward()``.
    """
    model.train()

    for batch in train_loader:
        # Move the batch to the module-level `device` before the forward pass.
        batch = batch.to(device)
        output = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)

        # Per-sample weights are forwarded to the criterion by keyword.
        batch_loss = criterion(output, batch.y, weights=batch.w)
        batch_loss.backward()

        optimizer.step()
        optimizer.zero_grad()

def test(model, loader):
    model.eval()

    outs = []
    ys = []
    for data in loader:  # Iterate in batches over the training/test dataset.
        data = data.to(device)
        
        #out = model(data.x, data.edge_index, data.batch)  # Perform a single forward pass.
        out = model(data.x, data.edge_index, data.edge_attr, data.batch)
        
        outs.append(out.detach().cpu().numpy())
        ys.append(data.y.detach().cpu().numpy())
    
    pred = np.concatenate(outs, axis=0)
    y = np.concatenate(ys, axis=0)
    
    roc_auc = roc_auc_score(y, pred)
    ave_prec = average_precision_score(y, pred)
    return roc_auc, ave_prec

def test_precision_recall(model, loader, pred_th=0.5):
    model.eval()

    outs = []
    ys = []
    for data in loader:  # Iterate in batches over the training/test dataset.
        data = data.to(device)
        
        out = model(data.x, data.edge_index, 
                     data.edge_attr, data.batch)
        
        outs.append(out.detach().cpu().numpy())
        ys.append(data.y.detach().cpu().numpy())
    
    pred = np.concatenate(outs, axis=0)
    y = np.concatenate(ys, axis=0).astype(int)
    binary_pred = (pred > pred_th).astype(int)
    
    precision = precision_score(y, binary_pred, zero_division=0)
    recall = recall_score(y, binary_pred, zero_division=0)
    return precision, recall

def get_loss(model, loader):
    """Return the mean of the per-batch losses of ``model`` over ``loader``.

    The loss is computed with the module-level ``criterion`` (it is not a
    parameter of this function), so ``criterion`` must be defined before the
    first call.

    Args:
        model: callable as ``model(x, edge_index, edge_attr, batch)``.
        loader: iterable of batches exposing ``to(device)``, ``x``,
            ``edge_index``, ``edge_attr``, ``batch``, ``y`` and ``w``.

    Returns:
        ``np.mean`` over the list of per-batch loss values.
    """
    model.eval()

    losses = []
    # Evaluation only: disable autograd so no computation graph is built
    # (and kept alive) for each forward pass.
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = model(data.x, data.edge_index, data.edge_attr, data.batch)
            # Pass the weights by keyword, exactly as train() does, so both
            # code paths evaluate the same loss whatever the positional
            # order of the criterion's parameters is.
            loss = criterion(out, data.y, weights=data.w)
            losses.append(loss.detach().cpu().numpy())
    return np.mean(losses)

In [6]:
from models import AttentiveFP

# Build the AttentiveFP network from the dataset-derived dimensions and move
# it to the selected device.
afp_config = dict(
    in_channels=num_node_features,
    hidden_channels=512,
    out_channels=num_classes,
    edge_dim=num_edge_features,
    num_layers=4,
    num_timesteps=3,
    dropout=0.4,
)
model = AttentiveFP(**afp_config).to(device)
print(model)

AttentiveFP(
  (lin1): Linear(in_features=79, out_features=512, bias=True)
  (atom_convs): ModuleList(
    (0): GATEConv(
      (lin1): Linear(in_features=524, out_features=512, bias=False)
      (lin2): Linear(in_features=512, out_features=512, bias=False)
    )
    (1): GATConv(512, 512, heads=1)
    (2): GATConv(512, 512, heads=1)
    (3): GATConv(512, 512, heads=1)
  )
  (atom_grus): ModuleList(
    (0): GRUCell(512, 512)
    (1): GRUCell(512, 512)
    (2): GRUCell(512, 512)
    (3): GRUCell(512, 512)
  )
  (atom_norms): ModuleList(
    (0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (mol_conv): GATConv(512, 512, heads=1)
  (mol_gru): GRUCell(512, 512)
  (lin2): Linear(in_features=512, out_features=1, bias=True)
)


In [7]:
# Transfer-learning step: copy every pretrained tensor whose name and shape
# match the freshly built model; all other entries keep their initial values.

model_state_dict = model.state_dict()

# NOTE(security): torch.load unpickles a complete model object, which can run
# arbitrary code - only load checkpoints from a trusted source.
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
pretrained_model = torch.load('pretrained_models/AFP_0.72710', map_location=device)
#pretrained_model = torch.load('pretrained_models/AFP_0.71991')
pretrained_state_dict = pretrained_model.state_dict()
# Check `k in model_state_dict` first: a pretrained key that the new model
# does not have would otherwise raise KeyError instead of being skipped.
sel_pretrained_state_dict = {k: v for k, v in pretrained_state_dict.items()
                             if k in model_state_dict
                             and model_state_dict[k].size() == v.size()}
model_state_dict.update(sel_pretrained_state_dict)
model.load_state_dict(model_state_dict)

<All keys matched successfully>

In [8]:
# Freeze every parameter, then re-enable gradients for the `lin2` layer only.
for p in model.parameters():
    p.requires_grad_(False)
for p in model.lin2.parameters():
    p.requires_grad_(True)

In [9]:
from loss import PrecisionAtRecallLoss, PrecisionRecallAUCLoss, PrecisionRecallAUCLoss2

# Stage 1 setup: optimizer and loss used while only `lin2` is trainable.
# Alternatives tried earlier (kept for reference):
#   optimizer: torch.optim.Adam(model.parameters(), lr=3e-5)
#   criterion: torch.nn.BCEWithLogitsLoss(reduction='none')
#              RecallAtPrecisionLoss(num_classes=num_classes, target_precision=0.9)
#              PrecisionAtRecallLoss(num_classes=num_classes, target_recall=0.5)
#              PrecisionAtRecallLoss(num_classes=num_classes, target_recall=0.1,
#                                    surrogate_type='hinge')
#              PrecisionRecallAUCLoss(num_classes=num_classes,
#                                     precision_range=(0.0, 0.5))
optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-4)
criterion = PrecisionRecallAUCLoss2(num_classes=num_classes,
                                    recall_range=(0.0, 1.0))
criterion = criterion.to(device)


# Stage 1 training: 30 epochs. After each epoch report the loss, ROC-AUC /
# average precision and precision / recall on the train and validation loaders.
for epoch in range(1, 31):
    train(model, optimizer, criterion)

    # Train loader is evaluated first, then validation, for every metric.
    tr_loss, val_loss = (
        get_loss(model, ldr) for ldr in (train_loader, val_loader))
    print('Epoch: {:03d}, Train loss: {:.5f}, Val loss: {:.5f}'.format(
        epoch, tr_loss, val_loss))

    (tr_roc_auc, tr_ave_prec), (val_roc_auc, val_ave_prec) = (
        test(model, ldr) for ldr in (train_loader, val_loader))
    print('Train - roc_auc:{:.4f}, ave_prec:{:.4f}, Val - roc_auc:{:.4f}, ave_prec:{:.4f}'.format(
        tr_roc_auc, tr_ave_prec, val_roc_auc, val_ave_prec))

    (tr_precision, tr_recall), (val_precision, val_recall) = (
        test_precision_recall(model, ldr) for ldr in (train_loader, val_loader))
    print('Train - precision:{:.4f}, recall:{:.4f}, Val - precision:{:.4f}, recall:{:.4f}'.format(
        tr_precision, tr_recall, val_precision, val_recall))

# Final report on the held-out test loader once the loop has finished.
te_loss = get_loss(model, test_loader)
print('Test loss: {:.5f}'.format(te_loss))
te_roc_auc, te_ave_prec = test(model, test_loader)
print('Test - roc_auc:{:.4f}, ave_prec:{:.4f}'.format(te_roc_auc, te_ave_prec))
te_precision, te_recall = test_precision_recall(model, test_loader)
print('Test - precision:{:.4f}, recall:{:.4f}'.format(te_precision, te_recall))

Epoch: 001, Train loss: 1.12833, Val loss: 1.12112
Train - roc_auc:0.7254, ave_prec:0.2464, Val - roc_auc:0.7453, ave_prec:0.4107
Train - precision:0.2692, recall:0.2333, Val - precision:0.4000, recall:0.2500
Epoch: 002, Train loss: 1.08298, Val loss: 1.09386
Train - roc_auc:0.7567, ave_prec:0.2668, Val - roc_auc:0.7877, ave_prec:0.4834
Train - precision:0.3095, recall:0.4333, Val - precision:0.3333, recall:0.3750
Epoch: 003, Train loss: 1.08307, Val loss: 1.07962
Train - roc_auc:0.7642, ave_prec:0.2706, Val - roc_auc:0.7948, ave_prec:0.4887
Train - precision:0.3187, recall:0.4833, Val - precision:0.3333, recall:0.3750
Epoch: 004, Train loss: 1.05844, Val loss: 1.06894
Train - roc_auc:0.7703, ave_prec:0.2727, Val - roc_auc:0.7995, ave_prec:0.4905
Train - precision:0.3158, recall:0.5000, Val - precision:0.2500, recall:0.3750
Epoch: 005, Train loss: 1.04664, Val loss: 1.05952
Train - roc_auc:0.7735, ave_prec:0.2736, Val - roc_auc:0.8019, ave_prec:0.4924
Train - precision:0.3000, recall:0

In [10]:
# Precision / recall again, this time with the decision threshold at 0.8
# instead of the default 0.5.
(tr_precision, tr_recall), (val_precision, val_recall) = (
    test_precision_recall(model, ldr, pred_th=0.8)
    for ldr in (train_loader, val_loader))
print('Train - precision:{:.4f}, recall:{:.4f}, Val - precision:{:.4f}, recall:{:.4f}'.format(
    tr_precision, tr_recall, val_precision, val_recall))
te_precision, te_recall = test_precision_recall(model, test_loader, pred_th=0.8)
print('Test - precision:{:.4f}, recall:{:.4f}'.format(te_precision, te_recall))

Train - precision:0.3243, recall:0.4000, Val - precision:0.4286, recall:0.3750
Test - precision:0.3750, recall:0.3750


In [11]:
# Make every parameter trainable again for the second training stage.
for p in model.parameters():
    p.requires_grad_(True)

In [12]:
# Stage 2 setup: a new optimizer with a smaller learning rate and a fresh
# criterion instance for training the whole network.
# Alternative criteria tried earlier (kept for reference):
#   torch.nn.BCEWithLogitsLoss(reduction='none')
#   RecallAtPrecisionLoss(num_classes=num_classes, target_precision=0.99)
#   PrecisionAtRecallLoss(num_classes=num_classes, target_recall=0.05)
#   PrecisionAtRecallLoss(num_classes=num_classes, target_recall=0.1,
#                         surrogate_type='hinge')
#   PrecisionRecallAUCLoss(num_classes=num_classes, precision_range=(0.5, 1.))
optimizer = torch.optim.RMSprop(model.parameters(), lr=3e-6)
criterion = PrecisionRecallAUCLoss2(num_classes=num_classes,
                                    recall_range=(0.0, 1.0))
criterion = criterion.to(device)


# Stage 2 training: 125 epochs. After each epoch report the loss, ROC-AUC /
# average precision and precision / recall on the train and validation loaders.
for epoch in range(1, 126):
    train(model, optimizer, criterion)

    # Train loader is evaluated first, then validation, for every metric.
    tr_loss, val_loss = (
        get_loss(model, ldr) for ldr in (train_loader, val_loader))
    print('Epoch: {:03d}, Train loss: {:.5f}, Val loss: {:.5f}'.format(
        epoch, tr_loss, val_loss))

    (tr_roc_auc, tr_ave_prec), (val_roc_auc, val_ave_prec) = (
        test(model, ldr) for ldr in (train_loader, val_loader))
    print('Train - roc_auc:{:.4f}, ave_prec:{:.4f}, Val - roc_auc:{:.4f}, ave_prec:{:.4f}'.format(
        tr_roc_auc, tr_ave_prec, val_roc_auc, val_ave_prec))

    (tr_precision, tr_recall), (val_precision, val_recall) = (
        test_precision_recall(model, ldr) for ldr in (train_loader, val_loader))
    print('Train - precision:{:.4f}, recall:{:.4f}, Val - precision:{:.4f}, recall:{:.4f}'.format(
        tr_precision, tr_recall, val_precision, val_recall))

# Final report on the held-out test loader once the loop has finished.
te_loss = get_loss(model, test_loader)
print('Test loss: {:.5f}'.format(te_loss))
te_roc_auc, te_ave_prec = test(model, test_loader)
print('Test - roc_auc:{:.4f}, ave_prec:{:.4f}'.format(te_roc_auc, te_ave_prec))
te_precision, te_recall = test_precision_recall(model, test_loader)
print('Test - precision:{:.4f}, recall:{:.4f}'.format(te_precision, te_recall))

Epoch: 001, Train loss: 0.94413, Val loss: 0.93856
Train - roc_auc:0.8111, ave_prec:0.3035, Val - roc_auc:0.8231, ave_prec:0.5313
Train - precision:0.3495, recall:0.6000, Val - precision:0.3077, recall:0.5000
Epoch: 002, Train loss: 0.91640, Val loss: 0.93348
Train - roc_auc:0.8181, ave_prec:0.3125, Val - roc_auc:0.8231, ave_prec:0.5314
Train - precision:0.3646, recall:0.5833, Val - precision:0.3333, recall:0.5000
Epoch: 003, Train loss: 0.90515, Val loss: 0.93619
Train - roc_auc:0.8217, ave_prec:0.3177, Val - roc_auc:0.8231, ave_prec:0.5314
Train - precision:0.3596, recall:0.5333, Val - precision:0.3333, recall:0.5000
Epoch: 004, Train loss: 0.89451, Val loss: 0.91279
Train - roc_auc:0.8286, ave_prec:0.3278, Val - roc_auc:0.8349, ave_prec:0.5432
Train - precision:0.3776, recall:0.6167, Val - precision:0.3077, recall:0.5000
Epoch: 005, Train loss: 0.88383, Val loss: 0.90003
Train - roc_auc:0.8343, ave_prec:0.3323, Val - roc_auc:0.8373, ave_prec:0.5500
Train - precision:0.3981, recall:0

Epoch: 041, Train loss: 0.72195, Val loss: 0.79699
Train - roc_auc:0.8947, ave_prec:0.4490, Val - roc_auc:0.8703, ave_prec:0.5793
Train - precision:0.4519, recall:0.7833, Val - precision:0.3636, recall:0.5000
Epoch: 042, Train loss: 0.70971, Val loss: 0.79294
Train - roc_auc:0.8964, ave_prec:0.4506, Val - roc_auc:0.8726, ave_prec:0.5821
Train - precision:0.4519, recall:0.7833, Val - precision:0.3636, recall:0.5000
Epoch: 043, Train loss: 0.70554, Val loss: 0.79495
Train - roc_auc:0.8974, ave_prec:0.4527, Val - roc_auc:0.8726, ave_prec:0.5821
Train - precision:0.4615, recall:0.8000, Val - precision:0.3636, recall:0.5000
Epoch: 044, Train loss: 0.70753, Val loss: 0.79805
Train - roc_auc:0.8984, ave_prec:0.4535, Val - roc_auc:0.8703, ave_prec:0.5800
Train - precision:0.4706, recall:0.8000, Val - precision:0.3636, recall:0.5000
Epoch: 045, Train loss: 0.70434, Val loss: 0.79237
Train - roc_auc:0.8995, ave_prec:0.4535, Val - roc_auc:0.8703, ave_prec:0.5800
Train - precision:0.4712, recall:0

Epoch: 081, Train loss: 0.56346, Val loss: 0.70229
Train - roc_auc:0.9281, ave_prec:0.5316, Val - roc_auc:0.9009, ave_prec:0.6130
Train - precision:0.5294, recall:0.9000, Val - precision:0.5556, recall:0.6250
Epoch: 082, Train loss: 0.56232, Val loss: 0.69845
Train - roc_auc:0.9287, ave_prec:0.5307, Val - roc_auc:0.9009, ave_prec:0.6130
Train - precision:0.5294, recall:0.9000, Val - precision:0.5000, recall:0.6250
Epoch: 083, Train loss: 0.56631, Val loss: 0.69177
Train - roc_auc:0.9298, ave_prec:0.5354, Val - roc_auc:0.9057, ave_prec:0.6234
Train - precision:0.5347, recall:0.9000, Val - precision:0.5556, recall:0.6250
Epoch: 084, Train loss: 0.57403, Val loss: 0.69902
Train - roc_auc:0.9306, ave_prec:0.5387, Val - roc_auc:0.9080, ave_prec:0.6252
Train - precision:0.5400, recall:0.9000, Val - precision:0.5556, recall:0.6250
Epoch: 085, Train loss: 0.55493, Val loss: 0.69018
Train - roc_auc:0.9310, ave_prec:0.5468, Val - roc_auc:0.9080, ave_prec:0.6227
Train - precision:0.5347, recall:0

Epoch: 121, Train loss: 0.42517, Val loss: 0.61103
Train - roc_auc:0.9538, ave_prec:0.6387, Val - roc_auc:0.9269, ave_prec:0.6407
Train - precision:0.5567, recall:0.9000, Val - precision:0.4444, recall:0.5000
Epoch: 122, Train loss: 0.41940, Val loss: 0.60249
Train - roc_auc:0.9542, ave_prec:0.6483, Val - roc_auc:0.9292, ave_prec:0.6444
Train - precision:0.5670, recall:0.9167, Val - precision:0.4444, recall:0.5000
Epoch: 123, Train loss: 0.41912, Val loss: 0.60184
Train - roc_auc:0.9551, ave_prec:0.6547, Val - roc_auc:0.9316, ave_prec:0.6569
Train - precision:0.5500, recall:0.9167, Val - precision:0.4444, recall:0.5000
Epoch: 124, Train loss: 0.41972, Val loss: 0.60555
Train - roc_auc:0.9552, ave_prec:0.6348, Val - roc_auc:0.9292, ave_prec:0.6444
Train - precision:0.5556, recall:0.9167, Val - precision:0.4444, recall:0.5000
Epoch: 125, Train loss: 0.40580, Val loss: 0.61111
Train - roc_auc:0.9558, ave_prec:0.6373, Val - roc_auc:0.9292, ave_prec:0.6444
Train - precision:0.5745, recall:0

In [13]:
# Precision / recall of the fine-tuned model with the decision threshold at
# 0.8 instead of the default 0.5.
(tr_precision, tr_recall), (val_precision, val_recall) = (
    test_precision_recall(model, ldr, pred_th=0.8)
    for ldr in (train_loader, val_loader))
print('Train - precision:{:.4f}, recall:{:.4f}, Val - precision:{:.4f}, recall:{:.4f}'.format(
    tr_precision, tr_recall, val_precision, val_recall))
te_precision, te_recall = test_precision_recall(model, test_loader, pred_th=0.8)
print('Test - precision:{:.4f}, recall:{:.4f}'.format(te_precision, te_recall))

Train - precision:0.6000, recall:0.8500, Val - precision:0.5714, recall:0.5000
Test - precision:0.6250, recall:0.6250


In [14]:
# Save the whole model object (not only its state_dict) to disk.
# Earlier run was saved as: 'trained_models/AFP_tl_cisplatin_0.85093'
save_path = 'trained_models/AFP_tl_cisplatin_test'
torch.save(model, save_path)

In [None]:
# Reload a previously saved model and score it on the test loader.
# NOTE(security): torch.load unpickles a complete model object, which can run
# arbitrary code - only load checkpoints from a trusted source.
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
model = torch.load('trained_models/AFP_tl_cisplatin_0.85093', map_location=device)
# test() returns a (roc_auc, ave_prec) tuple; formatting that tuple with a
# single '{:.5f}' raises TypeError, so unpack it and print both values.
te_roc_auc, te_ave_prec = test(model, test_loader)
print('Test - roc_auc:{:.5f}, ave_prec:{:.5f}'.format(te_roc_auc, te_ave_prec))